From 6aebd3476cebf8f2212fa575c22b1890b1b7628f Mon Sep 17 00:00:00 2001 From: derek mcquay Date: Thu, 25 Aug 2016 11:26:49 -0700 Subject: [PATCH] fixed oauth bug figured out it was how i was using sessions. --- cmd/chipd/main.go | 3 ++- server.go | 27 ++++++++++++++++----------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/cmd/chipd/main.go b/cmd/chipd/main.go index 80c1121..2fff5d8 100644 --- a/cmd/chipd/main.go +++ b/cmd/chipd/main.go @@ -12,6 +12,7 @@ import ( "s.mcquay.me/dm/chipmunk" + "github.com/gorilla/context" "github.com/kelseyhightower/envconfig" "github.com/spf13/cobra" ) @@ -98,7 +99,7 @@ func main() { log.Printf("serving at: http://%s:%d/", hostname, config.Port) addr := fmt.Sprintf("%s:%d", config.Host, config.Port) - err = http.ListenAndServe(addr, sm) + err = http.ListenAndServe(addr, context.ClearHandler(sm)) if err != nil { log.Printf("%+v", err) os.Exit(1) diff --git a/server.go b/server.go index 2d1aca6..9fc9440 100644 --- a/server.go +++ b/server.go @@ -77,7 +77,7 @@ func (s *Server) fakeSetup(w http.ResponseWriter, r *http.Request) { func (s *Server) tranx(w http.ResponseWriter, r *http.Request) { //TODO add back in oauth //w.Header().Set("Content-Type", "application/json") - //session, _ := store.Get(r, "creds") + //session, err := store.Get(r, "creds") //if err != nil { // http.Error(w, err.Error(), http.StatusInternalServerError) // return @@ -133,7 +133,7 @@ func (s *Server) tranx(w http.ResponseWriter, r *http.Request) { func (s *Server) costPerMonth(w http.ResponseWriter, r *http.Request) { //TODO add back in oauth //w.Header().Set("Content-Type", "application/json") - //session, _ := store.Get(r, "creds") + //session, err := store.Get(r, "creds") //if err != nil { // http.Error(w, err.Error(), http.StatusInternalServerError) // return @@ -174,7 +174,7 @@ func (s *Server) costPerMonth(w http.ResponseWriter, r *http.Request) { func (s *Server) listUsers(w http.ResponseWriter, r *http.Request) { //TODO add back in oauth //w.Header().Set("Content-Type", "application/json") - //session, _ := store.Get(r, "creds") + //session, err := store.Get(r, "creds") //if err != nil { // http.Error(w, err.Error(), http.StatusInternalServerError) // return @@ -228,7 +228,6 @@ func (s *Server) oauthCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusTemporaryRedirect) return } - defer email.Body.Close() data, _ := ioutil.ReadAll(email.Body) u := userInfo{} @@ -242,8 +241,10 @@ func (s *Server) oauthCallback(w http.ResponseWriter, r *http.Request) { if authorizedEmail(u.Email) { session, err := store.Get(r, "creds") if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + if !session.IsNew { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } session.Values["authenticated"] = true session.Values["uname"] = u.Email @@ -280,19 +281,19 @@ func (s *Server) auth(w http.ResponseWriter, r *http.Request) { http.Error(w, string(b), http.StatusUnauthorized) } -func (s *Server) logout(w http.ResponseWriter, req *http.Request) { - session, err := store.Get(req, "creds") +func (s *Server) logout(w http.ResponseWriter, r *http.Request) { + session, err := store.Get(r, "creds") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } delete(session.Values, "authenticated") delete(session.Values, "uname") - session.Save(req, w) - http.Redirect(w, req, "/", http.StatusSeeOther) + session.Save(r, w) + http.Redirect(w, r, "/", http.StatusSeeOther) } -func (s *Server) serverInfo(w http.ResponseWriter, req *http.Request) { +func (s *Server) serverInfo(w http.ResponseWriter, r *http.Request) { output := struct { Version string `json:"version"` Start string `json:"start"` @@ -309,6 +310,10 @@ func (s *Server) serverInfo(w http.ResponseWriter, req *http.Request) { func (s *Server) plist(w http.ResponseWriter, r *http.Request) { session, err := store.Get(r, "creds") if err != nil { + if session.IsNew { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } http.Error(w, err.Error(), http.StatusInternalServerError) return }