// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
||
|
package firestore
|
||
|
|
||
|
import (
|
||
|
"testing"
|
||
|
|
||
|
"golang.org/x/net/context"
|
||
|
"google.golang.org/grpc/status"
|
||
|
|
||
|
pb "google.golang.org/genproto/googleapis/firestore/v1beta1"
|
||
|
|
||
|
"github.com/golang/protobuf/ptypes/empty"
|
||
|
"google.golang.org/api/iterator"
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
)
|
||
|
|
||
|
func TestRunTransaction(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
const db = "projects/projectID/databases/(default)"
|
||
|
tid := []byte{1}
|
||
|
c, srv := newMock(t)
|
||
|
beginReq := &pb.BeginTransactionRequest{Database: db}
|
||
|
beginRes := &pb.BeginTransactionResponse{Transaction: tid}
|
||
|
commitReq := &pb.CommitRequest{Database: db, Transaction: tid}
|
||
|
// Empty transaction.
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
|
||
|
err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
// Transaction with read and write.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
aDoc := &pb.Document{
|
||
|
Name: db + "/documents/C/a",
|
||
|
CreateTime: aTimestamp,
|
||
|
UpdateTime: aTimestamp2,
|
||
|
Fields: map[string]*pb.Value{"count": intval(1)},
|
||
|
}
|
||
|
srv.addRPC(
|
||
|
&pb.BatchGetDocumentsRequest{
|
||
|
Database: c.path(),
|
||
|
Documents: []string{db + "/documents/C/a"},
|
||
|
ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
|
||
|
}, []interface{}{
|
||
|
&pb.BatchGetDocumentsResponse{
|
||
|
Result: &pb.BatchGetDocumentsResponse_Found{aDoc},
|
||
|
ReadTime: aTimestamp2,
|
||
|
},
|
||
|
})
|
||
|
aDoc2 := &pb.Document{
|
||
|
Name: aDoc.Name,
|
||
|
Fields: map[string]*pb.Value{"count": intval(2)},
|
||
|
}
|
||
|
srv.addRPC(
|
||
|
&pb.CommitRequest{
|
||
|
Database: db,
|
||
|
Transaction: tid,
|
||
|
Writes: []*pb.Write{{
|
||
|
Operation: &pb.Write_Update{aDoc2},
|
||
|
UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
|
||
|
CurrentDocument: &pb.Precondition{
|
||
|
ConditionType: &pb.Precondition_Exists{true},
|
||
|
},
|
||
|
}},
|
||
|
},
|
||
|
&pb.CommitResponse{CommitTime: aTimestamp3},
|
||
|
)
|
||
|
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
|
||
|
docref := c.Collection("C").Doc("a")
|
||
|
doc, err := tx.Get(docref)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
count, err := doc.DataAt("count")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}})
|
||
|
return nil
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
// Query
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(
|
||
|
&pb.RunQueryRequest{
|
||
|
Parent: db,
|
||
|
QueryType: &pb.RunQueryRequest_StructuredQuery{
|
||
|
&pb.StructuredQuery{
|
||
|
From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: "C"}},
|
||
|
},
|
||
|
},
|
||
|
ConsistencySelector: &pb.RunQueryRequest_Transaction{tid},
|
||
|
},
|
||
|
[]interface{}{},
|
||
|
)
|
||
|
srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp3})
|
||
|
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
|
||
|
it := tx.Documents(c.Collection("C"))
|
||
|
defer it.Stop()
|
||
|
_, err := it.Next()
|
||
|
if err != iterator.Done {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
// Retry entire transaction.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
|
||
|
srv.addRPC(
|
||
|
&pb.BeginTransactionRequest{
|
||
|
Database: db,
|
||
|
Options: &pb.TransactionOptions{
|
||
|
Mode: &pb.TransactionOptions_ReadWrite_{
|
||
|
&pb.TransactionOptions_ReadWrite{tid},
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
beginRes,
|
||
|
)
|
||
|
srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
|
||
|
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { return nil })
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestTransactionErrors(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
const db = "projects/projectID/databases/(default)"
|
||
|
c, srv := newMock(t)
|
||
|
var (
|
||
|
tid = []byte{1}
|
||
|
internalErr = status.Errorf(codes.Internal, "so sad")
|
||
|
beginReq = &pb.BeginTransactionRequest{
|
||
|
Database: db,
|
||
|
}
|
||
|
beginRes = &pb.BeginTransactionResponse{Transaction: tid}
|
||
|
getReq = &pb.BatchGetDocumentsRequest{
|
||
|
Database: c.path(),
|
||
|
Documents: []string{db + "/documents/C/a"},
|
||
|
ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
|
||
|
}
|
||
|
rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid}
|
||
|
commitReq = &pb.CommitRequest{Database: db, Transaction: tid}
|
||
|
)
|
||
|
|
||
|
// BeginTransaction has a permanent error.
|
||
|
srv.addRPC(beginReq, internalErr)
|
||
|
err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
|
||
|
if grpc.Code(err) != codes.Internal {
|
||
|
t.Errorf("got <%v>, want Internal", err)
|
||
|
}
|
||
|
|
||
|
// Get has a permanent error.
|
||
|
get := func(_ context.Context, tx *Transaction) error {
|
||
|
_, err := tx.Get(c.Doc("C/a"))
|
||
|
return err
|
||
|
}
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(getReq, internalErr)
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, get)
|
||
|
if grpc.Code(err) != codes.Internal {
|
||
|
t.Errorf("got <%v>, want Internal", err)
|
||
|
}
|
||
|
|
||
|
// Get has a permanent error, but the rollback fails. We still
|
||
|
// return Get's error.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(getReq, internalErr)
|
||
|
srv.addRPC(rollbackReq, status.Errorf(codes.FailedPrecondition, ""))
|
||
|
err = c.RunTransaction(ctx, get)
|
||
|
if grpc.Code(err) != codes.Internal {
|
||
|
t.Errorf("got <%v>, want Internal", err)
|
||
|
}
|
||
|
|
||
|
// Commit has a permanent error.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(getReq, []interface{}{
|
||
|
&pb.BatchGetDocumentsResponse{
|
||
|
Result: &pb.BatchGetDocumentsResponse_Found{&pb.Document{
|
||
|
Name: "projects/projectID/databases/(default)/documents/C/a",
|
||
|
CreateTime: aTimestamp,
|
||
|
UpdateTime: aTimestamp2,
|
||
|
}},
|
||
|
ReadTime: aTimestamp2,
|
||
|
},
|
||
|
})
|
||
|
srv.addRPC(commitReq, internalErr)
|
||
|
err = c.RunTransaction(ctx, get)
|
||
|
if grpc.Code(err) != codes.Internal {
|
||
|
t.Errorf("got <%v>, want Internal", err)
|
||
|
}
|
||
|
|
||
|
// Read after write.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
|
||
|
tx.Delete(c.Doc("C/a"))
|
||
|
if _, err := tx.Get(c.Doc("C/a")); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
if err != errReadAfterWrite {
|
||
|
t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
|
||
|
}
|
||
|
|
||
|
// Read after write, with query.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
|
||
|
tx.Delete(c.Doc("C/a"))
|
||
|
it := tx.Documents(c.Collection("C").Select("x"))
|
||
|
defer it.Stop()
|
||
|
if _, err := it.Next(); err != iterator.Done {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
if err != errReadAfterWrite {
|
||
|
t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
|
||
|
}
|
||
|
|
||
|
// Read after write fails even if the user ignores the read's error.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
|
||
|
tx.Delete(c.Doc("C/a"))
|
||
|
tx.Get(c.Doc("C/a"))
|
||
|
return nil
|
||
|
})
|
||
|
if err != errReadAfterWrite {
|
||
|
t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
|
||
|
}
|
||
|
|
||
|
// Write in read-only transaction.
|
||
|
srv.reset()
|
||
|
srv.addRPC(
|
||
|
&pb.BeginTransactionRequest{
|
||
|
Database: db,
|
||
|
Options: &pb.TransactionOptions{
|
||
|
Mode: &pb.TransactionOptions_ReadOnly_{&pb.TransactionOptions_ReadOnly{}},
|
||
|
},
|
||
|
},
|
||
|
beginRes,
|
||
|
)
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
|
||
|
return tx.Delete(c.Doc("C/a"))
|
||
|
}, ReadOnly)
|
||
|
if err != errWriteReadOnly {
|
||
|
t.Errorf("got <%v>, want <%v>", err, errWriteReadOnly)
|
||
|
}
|
||
|
|
||
|
// Too many retries.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
|
||
|
srv.addRPC(
|
||
|
&pb.BeginTransactionRequest{
|
||
|
Database: db,
|
||
|
Options: &pb.TransactionOptions{
|
||
|
Mode: &pb.TransactionOptions_ReadWrite_{
|
||
|
&pb.TransactionOptions_ReadWrite{tid},
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
beginRes,
|
||
|
)
|
||
|
srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil },
|
||
|
MaxAttempts(2))
|
||
|
if grpc.Code(err) != codes.Aborted {
|
||
|
t.Errorf("got <%v>, want Aborted", err)
|
||
|
}
|
||
|
|
||
|
// Nested transaction.
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error {
|
||
|
return c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
|
||
|
})
|
||
|
if got, want := err, errNestedTransaction; got != want {
|
||
|
t.Errorf("got <%v>, want <%v>", got, want)
|
||
|
}
|
||
|
|
||
|
// Non-transactional operation.
|
||
|
dr := c.Doc("C/d")
|
||
|
|
||
|
for i, op := range []func(ctx context.Context) error{
|
||
|
func(ctx context.Context) error { _, err := c.GetAll(ctx, []*DocumentRef{dr}); return err },
|
||
|
func(ctx context.Context) error { _, _, err := c.Collection("C").Add(ctx, testData); return err },
|
||
|
func(ctx context.Context) error { _, err := dr.Get(ctx); return err },
|
||
|
func(ctx context.Context) error { _, err := dr.Create(ctx, testData); return err },
|
||
|
func(ctx context.Context) error { _, err := dr.Set(ctx, testData); return err },
|
||
|
func(ctx context.Context) error { _, err := dr.Delete(ctx); return err },
|
||
|
func(ctx context.Context) error {
|
||
|
_, err := dr.Update(ctx, []Update{{FieldPath: []string{"*"}, Value: 1}})
|
||
|
return err
|
||
|
},
|
||
|
func(ctx context.Context) error { it := c.Collections(ctx); _, err := it.Next(); return err },
|
||
|
func(ctx context.Context) error { it := dr.Collections(ctx); _, err := it.Next(); return err },
|
||
|
func(ctx context.Context) error {
|
||
|
_, err := c.Batch().Set(dr, testData).Commit(ctx)
|
||
|
return err
|
||
|
},
|
||
|
func(ctx context.Context) error {
|
||
|
it := c.Collection("C").Documents(ctx)
|
||
|
defer it.Stop()
|
||
|
_, err := it.Next()
|
||
|
return err
|
||
|
},
|
||
|
} {
|
||
|
srv.reset()
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
srv.addRPC(rollbackReq, &empty.Empty{})
|
||
|
err = c.RunTransaction(ctx, func(ctx context.Context, _ *Transaction) error {
|
||
|
return op(ctx)
|
||
|
})
|
||
|
if got, want := err, errNonTransactionalOp; got != want {
|
||
|
t.Errorf("#%d: got <%v>, want <%v>", i, got, want)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestTransactionGetAll(t *testing.T) {
|
||
|
c, srv := newMock(t)
|
||
|
defer c.Close()
|
||
|
const dbPath = "projects/projectID/databases/(default)"
|
||
|
tid := []byte{1}
|
||
|
beginReq := &pb.BeginTransactionRequest{Database: dbPath}
|
||
|
beginRes := &pb.BeginTransactionResponse{Transaction: tid}
|
||
|
srv.addRPC(beginReq, beginRes)
|
||
|
req := &pb.BatchGetDocumentsRequest{
|
||
|
Database: dbPath,
|
||
|
Documents: []string{
|
||
|
dbPath + "/documents/C/a",
|
||
|
dbPath + "/documents/C/b",
|
||
|
dbPath + "/documents/C/c",
|
||
|
},
|
||
|
ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
|
||
|
}
|
||
|
err := c.RunTransaction(context.Background(), func(_ context.Context, tx *Transaction) error {
|
||
|
testGetAll(t, c, srv, dbPath,
|
||
|
func(drs []*DocumentRef) ([]*DocumentSnapshot, error) { return tx.GetAll(drs) },
|
||
|
req)
|
||
|
commitReq := &pb.CommitRequest{Database: dbPath, Transaction: tid}
|
||
|
srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
|
||
|
return nil
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|