diff --git a/cmd/hwtc/main.go b/cmd/hwtc/main.go index ee80b92..da462dc 100644 --- a/cmd/hwtc/main.go +++ b/cmd/hwtc/main.go @@ -7,6 +7,9 @@ import ( "os" "strings" + "github.com/twitchtv/twirp" + + "mcquay.me/hwt" pb "mcquay.me/hwt/rpc/hwt" ) @@ -20,8 +23,17 @@ func main() { c := pb.NewHelloWorldProtobufClient(fmt.Sprintf("http://%s", os.Args[1]), &http.Client{}) + h := http.Header{} + h.Set("sm-auth", hwt.PSK) + ctx := context.Background() + ctx, err := twirp.WithHTTPRequestHeaders(ctx, h) + if err != nil { + fmt.Fprintf(os.Stderr, "setting twirp headers: %v\n", err) + os.Exit(1) + } + for i := 0; ; i++ { - resp, err := c.Hello(context.Background(), &pb.HelloReq{Subject: strings.Join(os.Args[2:], " ")}) + resp, err := c.Hello(ctx, &pb.HelloReq{Subject: strings.Join(os.Args[2:], " ")}) if err != nil { fmt.Fprintf(os.Stderr, "hello: %#v\n", err) os.Exit(1) diff --git a/cmd/hwtd/main.go b/cmd/hwtd/main.go index cbea177..9b44a6d 100644 --- a/cmd/hwtd/main.go +++ b/cmd/hwtd/main.go @@ -26,7 +26,7 @@ func main() { hs := hwt.NewMetricsHooks(metrics.HTTPLatency) th := pb.NewHelloWorldServer(s, hs) sm := http.NewServeMux() - sm.Handle("/", th) + sm.HandleFunc("/", hwt.Auth(th.ServeHTTP)) sm.Handle("/metrics", promhttp.Handler()) if err := http.ListenAndServe(":8080", sm); err != nil { log.Fatalf("listen and serve: %v", err) diff --git a/hwt.go b/hwt.go index 01058cf..e1f5e9d 100644 --- a/hwt.go +++ b/hwt.go @@ -17,8 +17,13 @@ func (s *Server) Hello(ctx context.Context, req *pb.HelloReq) (*pb.HelloResp, er return nil, twirp.RequiredArgumentError("subject") } + u, err := getUser(ctx) + if err != nil { + return nil, twirp.InternalErrorWith(err) + } + r := &pb.HelloResp{ - Text: fmt.Sprintf("echo: %v", req.Subject), + Text: fmt.Sprintf("%s said: %v", u, req.Subject), Hostname: s.Hostname, } return r, nil diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..0097a7f --- /dev/null +++ b/middleware.go @@ -0,0 +1,32 @@ +package hwt + +import ( + "context" + "errors" + "net/http" +) + +const PSK = "some key" + +var reqUserKey = new(int) + +func Auth(h http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + k := req.Header.Get("sm-auth") + if k == "" { + w.Header().Set("www-authenticate", "sm-auth") + http.Error(w, "missing/invalid key", http.StatusUnauthorized) + return + } + ctx := context.WithValue(req.Context(), reqUserKey, "valid user") + h(w, req.WithContext(ctx)) + } +} + +func getUser(ctx context.Context) (string, error) { + u, ok := ctx.Value(reqUserKey).(string) + if !ok { + return "", errors.New("user key not found in context") + } + return u, nil +}