323 lines
9.0 KiB
Go
323 lines
9.0 KiB
Go
// Copyright 2017 Google Inc. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package pubsub
|
|
|
|
// This file provides a fake/mock in-memory pubsub server.
|
|
|
|
import (
|
|
"io"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"cloud.google.com/go/internal/testutil"
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/golang/protobuf/ptypes"
|
|
durpb "github.com/golang/protobuf/ptypes/duration"
|
|
emptypb "github.com/golang/protobuf/ptypes/empty"
|
|
"golang.org/x/net/context"
|
|
pb "google.golang.org/genproto/googleapis/pubsub/v1"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
type fakeServer struct {
|
|
pb.PublisherServer
|
|
pb.SubscriberServer
|
|
|
|
Addr string
|
|
|
|
mu sync.Mutex
|
|
Acked map[string]bool // acked message IDs
|
|
Deadlines map[string]int32 // deadlines by message ID
|
|
pullResponses []*pullResponse
|
|
wg sync.WaitGroup
|
|
subs map[string]*pb.Subscription
|
|
topics map[string]*pb.Topic
|
|
}
|
|
|
|
type pullResponse struct {
|
|
msgs []*pb.ReceivedMessage
|
|
err error
|
|
}
|
|
|
|
func newFakeServer() (*fakeServer, error) {
|
|
srv, err := testutil.NewServer()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fake := &fakeServer{
|
|
Addr: srv.Addr,
|
|
Acked: map[string]bool{},
|
|
Deadlines: map[string]int32{},
|
|
subs: map[string]*pb.Subscription{},
|
|
topics: map[string]*pb.Topic{},
|
|
}
|
|
pb.RegisterPublisherServer(srv.Gsrv, fake)
|
|
pb.RegisterSubscriberServer(srv.Gsrv, fake)
|
|
srv.Start()
|
|
return fake, nil
|
|
}
|
|
|
|
// Each call to addStreamingPullMessages results in one StreamingPullResponse.
|
|
func (s *fakeServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
|
|
s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil})
|
|
}
|
|
|
|
func (s *fakeServer) addStreamingPullError(err error) {
|
|
s.pullResponses = append(s.pullResponses, &pullResponse{nil, err})
|
|
}
|
|
|
|
func (s *fakeServer) wait() {
|
|
s.wg.Wait()
|
|
}
|
|
|
|
func (s *fakeServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
|
|
s.wg.Add(1)
|
|
defer s.wg.Done()
|
|
errc := make(chan error, 1)
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
for {
|
|
req, err := stream.Recv()
|
|
if err != nil {
|
|
errc <- err
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
for _, id := range req.AckIds {
|
|
s.Acked[id] = true
|
|
}
|
|
for i, id := range req.ModifyDeadlineAckIds {
|
|
s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
}()
|
|
// Send responses.
|
|
for {
|
|
s.mu.Lock()
|
|
if len(s.pullResponses) == 0 {
|
|
s.mu.Unlock()
|
|
// Nothing to send, so wait for the client to shut down the stream.
|
|
err := <-errc // a real error, or at least EOF
|
|
if err == io.EOF {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
pr := s.pullResponses[0]
|
|
s.pullResponses = s.pullResponses[1:]
|
|
s.mu.Unlock()
|
|
if pr.err != nil {
|
|
// Add a slight delay to ensure the server receives any
|
|
// messages en route from the client before shutting down the stream.
|
|
// This reduces flakiness of tests involving retry.
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
if pr.err == io.EOF {
|
|
return nil
|
|
}
|
|
if pr.err != nil {
|
|
return pr.err
|
|
}
|
|
// Return any error from Recv.
|
|
select {
|
|
case err := <-errc:
|
|
return err
|
|
default:
|
|
}
|
|
res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
|
|
if err := stream.Send(res); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Inclusive bounds that checkMRD enforces on message_retention_duration.
const (
	minMessageRetentionDuration = 10 * time.Minute
	maxMessageRetentionDuration = 7 * 24 * time.Hour // 168h, one week
)
|
|
|
|
// defaultMessageRetentionDuration is the retention CreateSubscription assigns
// when a request leaves message_retention_duration unset. It equals
// maxMessageRetentionDuration, expressed as a protobuf Duration.
var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)
|
|
|
|
func checkMRD(pmrd *durpb.Duration) error {
|
|
mrd, err := ptypes.Duration(pmrd)
|
|
if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
|
|
return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func checkAckDeadline(ads int32) error {
|
|
if ads < 10 || ads > 600 {
|
|
// PubSub service returns Unknown.
|
|
return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *fakeServer) CreateSubscription(ctx context.Context, sub *pb.Subscription) (*pb.Subscription, error) {
|
|
if s.subs[sub.Name] != nil {
|
|
return nil, status.Errorf(codes.AlreadyExists, "subscription %q", sub.Name)
|
|
}
|
|
sub2 := proto.Clone(sub).(*pb.Subscription)
|
|
if err := checkAckDeadline(sub.AckDeadlineSeconds); err != nil {
|
|
return nil, err
|
|
}
|
|
if sub.MessageRetentionDuration == nil {
|
|
sub2.MessageRetentionDuration = defaultMessageRetentionDuration
|
|
}
|
|
if err := checkMRD(sub2.MessageRetentionDuration); err != nil {
|
|
return nil, err
|
|
}
|
|
if sub.PushConfig == nil {
|
|
sub2.PushConfig = &pb.PushConfig{}
|
|
}
|
|
s.subs[sub.Name] = sub2
|
|
return sub2, nil
|
|
}
|
|
|
|
func (s *fakeServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
|
|
if sub := s.subs[req.Subscription]; sub != nil {
|
|
return sub, nil
|
|
}
|
|
return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
|
|
}
|
|
|
|
func (s *fakeServer) UpdateSubscription(ctx context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
|
|
sub := s.subs[req.Subscription.Name]
|
|
if sub == nil {
|
|
return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription.Name)
|
|
}
|
|
for _, path := range req.UpdateMask.Paths {
|
|
switch path {
|
|
case "push_config":
|
|
sub.PushConfig = req.Subscription.PushConfig
|
|
|
|
case "ack_deadline_seconds":
|
|
a := req.Subscription.AckDeadlineSeconds
|
|
if err := checkAckDeadline(a); err != nil {
|
|
return nil, err
|
|
}
|
|
sub.AckDeadlineSeconds = a
|
|
|
|
case "retain_acked_messages":
|
|
sub.RetainAckedMessages = req.Subscription.RetainAckedMessages
|
|
|
|
case "message_retention_duration":
|
|
if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
|
|
return nil, err
|
|
}
|
|
sub.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
|
|
|
|
// TODO(jba): labels
|
|
default:
|
|
return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
|
|
}
|
|
}
|
|
return sub, nil
|
|
}
|
|
|
|
func (s *fakeServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
|
|
if s.subs[req.Subscription] == nil {
|
|
return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
|
|
}
|
|
delete(s.subs, req.Subscription)
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
func (s *fakeServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
|
|
if s.topics[t.Name] != nil {
|
|
return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
|
|
}
|
|
t2 := proto.Clone(t).(*pb.Topic)
|
|
s.topics[t.Name] = t2
|
|
return t2, nil
|
|
}
|
|
|
|
func (s *fakeServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
|
|
if t := s.topics[req.Topic]; t != nil {
|
|
return t, nil
|
|
}
|
|
return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
|
|
}
|
|
|
|
func (s *fakeServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
|
|
if s.topics[req.Topic] == nil {
|
|
return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
|
|
}
|
|
delete(s.topics, req.Topic)
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
func (s *fakeServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
|
|
var names []string
|
|
for n := range s.topics {
|
|
if strings.HasPrefix(n, req.Project) {
|
|
names = append(names, n)
|
|
}
|
|
}
|
|
sort.Strings(names)
|
|
from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res := &pb.ListTopicsResponse{NextPageToken: nextToken}
|
|
for i := from; i < to; i++ {
|
|
res.Topics = append(res.Topics, s.topics[names[i]])
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (s *fakeServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
|
|
var names []string
|
|
for _, sub := range s.subs {
|
|
if strings.HasPrefix(sub.Name, req.Project) {
|
|
names = append(names, sub.Name)
|
|
}
|
|
}
|
|
sort.Strings(names)
|
|
from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
|
|
for i := from; i < to; i++ {
|
|
res.Subscriptions = append(res.Subscriptions, s.subs[names[i]])
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (s *fakeServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
|
|
var names []string
|
|
for _, sub := range s.subs {
|
|
if sub.Topic == req.Topic {
|
|
names = append(names, sub.Name)
|
|
}
|
|
}
|
|
sort.Strings(names)
|
|
from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &pb.ListTopicSubscriptionsResponse{
|
|
Subscriptions: names[from:to],
|
|
NextPageToken: nextToken,
|
|
}, nil
|
|
}
|