From 480304bd87f5d3d468650e3e17dfd0ba53e9218a Mon Sep 17 00:00:00 2001 From: derek mcquay Date: Fri, 3 Jun 2016 13:42:05 -0700 Subject: [PATCH] add support for sending emails Change-Id: I9749cda3b997d70271cb4ca709a7cca82a9a0948 --- api_test.go | 222 +++++++++++++++++++++++++++++++++++++++++++--- cmd/vaind/main.go | 18 +++- db.go | 16 ++++ mail.go | 69 ++++++++++++++ server.go | 69 +++++++++----- vain.go | 10 +++ 6 files changed, 370 insertions(+), 34 deletions(-) create mode 100644 mail.go diff --git a/api_test.go b/api_test.go index 7f57892..61dd5a7 100644 --- a/api_test.go +++ b/api_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -23,7 +24,7 @@ func TestAdd(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") if err != nil { @@ -169,7 +170,7 @@ func TestInvalidPath(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") if err != nil { @@ -201,7 +202,7 @@ func TestCannotDuplicateExistingPath(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -247,7 +248,7 @@ func TestCannotAddExistingSubPath(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -295,7 +296,7 @@ func TestMissingRepo(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -328,7 +329,7 @@ func TestBadJson(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -361,7 +362,7 @@ func TestNoAuth(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) u := fmt.Sprintf("%s/foo", ts.URL) @@ -390,7 +391,7 @@ func TestBadVcs(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -421,7 +422,7 @@ func TestUnsupportedMethod(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -453,7 +454,7 @@ func TestDelete(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, false) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -517,7 +518,7 @@ func TestSingleGet(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "", window) + NewServer(sm, db, nil, "", window, true) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -562,3 +563,202 @@ func TestSingleGet(t *testing.T) { } } } + +func TestRegister(t *testing.T) { + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + + sm := http.NewServeMux() + mm := &mockMail{} + NewServer(sm, db, mm, "", window, true) + ts := httptest.NewServer(sm) + + u := fmt.Sprintf("%s%s", ts.URL, prefix["register"]) + req, err := http.NewRequest("POST", u, nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + if status := resp.StatusCode; status != http.StatusBadRequest { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, http.StatusBadRequest) + } + + u = fmt.Sprintf("%s%s?email=notARealEmail", ts.URL, prefix["register"]) + req, err = http.NewRequest("POST", u, nil) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + if status := resp.StatusCode; status != http.StatusBadRequest { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, http.StatusBadRequest) + } + + u = fmt.Sprintf("%s%s?email=fake@example.com", ts.URL, prefix["register"]) + req, err = http.NewRequest("POST", u, nil) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + req, err = http.NewRequest("GET", mm.msg, nil) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + _, err = db.user("fake@example.com") + if err != nil { + t.Fatalf("user was no correctly added to database: %v", err) + } +} + +func TestRoundTrip(t *testing.T) { + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + + sm := http.NewServeMux() + mm := &mockMail{} + NewServer(sm, db, mm, "", window, true) + ts := httptest.NewServer(sm) + + u := fmt.Sprintf("%s%s?email=fake@example.com", ts.URL, prefix["register"]) + req, err := http.NewRequest("POST", u, nil) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + req, err = http.NewRequest("GET", mm.msg, nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + _, err = db.user("fake@example.com") + if err != nil { + t.Fatalf("user was no correctly added to database: %v", err) + } + bs, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to parse response body: %v", err) + } + tok := strings.Trim(string(bs), "new token: ") + + u = fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err = http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + if got, want := len(db.Pkgs()), 1; got != want { + t.Fatalf("pkgs should have something in it; got %d, want %d", got, want) + } +} + +func TestForgot(t *testing.T) { + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + + sm := http.NewServeMux() + mm := &mockMail{} + NewServer(sm, db, mm, "", window, true) + ts := httptest.NewServer(sm) + + //try to do forget before user is added + u := fmt.Sprintf("%s%s?email=fake@example.com", ts.URL, prefix["forgot"]) + req, err := http.NewRequest("POST", u, nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + if status := resp.StatusCode; status != http.StatusNotFound { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, http.StatusBadRequest) + } + + u = fmt.Sprintf("%s%s?email=notARealEmail", ts.URL, prefix["forgot"]) + req, err = http.NewRequest("POST", u, nil) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + if status := resp.StatusCode; status != http.StatusBadRequest { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, http.StatusBadRequest) + } + + //register a new user + u = fmt.Sprintf("%s%s?email=fake@example.com", ts.URL, prefix["register"]) + req, err = http.NewRequest("POST", u, nil) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + req, err = http.NewRequest("GET", mm.msg, nil) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + //check database for new user + _, err = db.user("fake@example.com") + if err != nil { + t.Fatalf("user was no correctly added to database: %v", err) + } + bs, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to parse response body: %v", err) + } + iniTok := strings.Trim(string(bs), "new token: ") + + //get new token for user (using forgot) + u = fmt.Sprintf("%s%s?email=fake@example.com", ts.URL, prefix["forgot"]) + req, err = http.NewRequest("POST", u, nil) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + req, err = http.NewRequest("GET", mm.msg, nil) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + ft, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to parse response body: %v", err) + } + recTok := strings.Trim(string(ft), "new token: ") + + if iniTok == recTok { + t.Fatalf("tokens should not be the same; old token %s, new token %s", iniTok, recTok) + } + + //add new pkg using new token + u = fmt.Sprintf("%s/bar", ts.URL) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err = http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", recTok)) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + if got, want := len(db.Pkgs()), 1; got != want { + t.Fatalf("pkgs should have something in it; got %d, want %d", got, want) + } +} diff --git a/cmd/vaind/main.go b/cmd/vaind/main.go index 9a2d942..1e3fb96 100644 --- a/cmd/vaind/main.go +++ b/cmd/vaind/main.go @@ -55,7 +55,8 @@ import ( const usage = "vaind [init] " type config struct { - Port int + Port int + Insecure bool Cert string Key string @@ -63,6 +64,11 @@ type config struct { Static string EmailTimeout time.Duration `envconfig:"email_timeout"` + + SMTPHost string `envconfig:"smtp_host"` + SMTPPort int `envconfig:"smtp_port"` + + From string } func main() { @@ -101,6 +107,7 @@ func main() { c := &config{ Port: 4040, EmailTimeout: 5 * time.Minute, + SMTPPort: 25, } if err := envconfig.Process("vain", c); err != nil { fmt.Fprintf(os.Stderr, "problem processing environment: %v", err) @@ -114,6 +121,13 @@ func main() { os.Exit(0) } } + + m, err := vain.NewEmail(c.From, c.SMTPHost, c.SMTPPort) + if err != nil { + fmt.Fprintf(os.Stderr, "problem initializing mailer: %v", err) + os.Exit(1) + } + hostname := "localhost" if hn, err := os.Hostname(); err != nil { log.Printf("problem getting hostname: %v", err) @@ -122,7 +136,7 @@ func main() { } log.Printf("serving at: http://%s:%d/", hostname, c.Port) sm := http.NewServeMux() - vain.NewServer(sm, db, c.Static, c.EmailTimeout) + vain.NewServer(sm, db, m, c.Static, c.EmailTimeout, c.Insecure) addr := fmt.Sprintf(":%d", c.Port) if c.Cert == "" || c.Key == "" { diff --git a/db.go b/db.go index 3ed7384..2f7c99a 100644 --- a/db.go +++ b/db.go @@ -319,3 +319,19 @@ func (db *DB) addUser(email string) (string, error) { ) return tok, err } + +func (db *DB) user(email string) (User, error) { + u := User{} + err := db.conn.Get( + &u, + "SELECT email, token, registered, requested FROM users WHERE email = ?", + email, + ) + if err == sql.ErrNoRows { + return User{}, verrors.HTTP{ + Message: fmt.Sprintf("could not find requested user's email: %q: %v", email, err), + Code: http.StatusNotFound, + } + } + return u, err +} diff --git a/mail.go b/mail.go new file mode 100644 index 0000000..af611d0 --- /dev/null +++ b/mail.go @@ -0,0 +1,69 @@ +package vain + +import ( + "bytes" + "fmt" + "net/mail" + "net/smtp" +) + +// A Mailer is a type that knows how to send smtp mail. +type Mailer interface { + Send(to mail.Address, subject, msg string) error +} + +// NewEmail returns *Email struct to be able to send smtp +// or an error if it can't correctly parse the email address. +func NewEmail(from, host string, port int) (*Email, error) { + if _, err := mail.ParseAddress(from); err != nil { + return nil, fmt.Errorf("can't parse an email address for 'from': %v", err) + } + r := &Email{ + host: host, + port: port, + from: from, + } + return r, nil +} + +// Email stores information required to use smtp. +type Email struct { + host string + port int + from string +} + +// Send sends a smtp email using the host and port in the Email struct and +//returns an error if there was a problem sending the email. +func (e Email) Send(to mail.Address, subject, msg string) error { + c, err := smtp.Dial(fmt.Sprintf("%s:%d", e.host, e.port)) + if err != nil { + return fmt.Errorf("couldn't dial mail server: %v", err) + } + defer c.Close() + if err := c.Mail(e.from); err != nil { + return err + } + if err := c.Rcpt(to.String()); err != nil { + return err + } + wc, err := c.Data() + if err != nil { + return fmt.Errorf("problem sending mail: %v", err) + } + buf := bytes.NewBufferString("Subject: " + subject + "\n\n" + msg) + buf.WriteTo(wc) + if err := c.Quit(); err != nil { + return nil + } + return err +} + +type mockMail struct { + msg string +} + +func (m *mockMail) Send(to mail.Address, subject, msg string) error { + m.msg = msg + return nil +} diff --git a/server.go b/server.go index 94d2e73..b4139fb 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package vain import ( "encoding/json" "fmt" - "log" "net/http" "net/mail" "strings" @@ -16,6 +15,7 @@ import ( ) const apiPrefix = "/api/v0/" +const emailSubject = "your api token" var prefix map[string]string @@ -29,22 +29,26 @@ func init() { } } -// NewServer populates a server, adds the routes, and returns it for use. -func NewServer(sm *http.ServeMux, store *DB, static string, emailTimeout time.Duration) *Server { - s := &Server{ - db: store, - static: static, - emailTimeout: emailTimeout, - } - addRoutes(sm, s) - return s -} - // Server serves up the http. type Server struct { db *DB static string emailTimeout time.Duration + mail Mailer + insecure bool +} + +// NewServer populates a server, adds the routes, and returns it for use. +func NewServer(sm *http.ServeMux, store *DB, m Mailer, static string, emailTimeout time.Duration, insecure bool) *Server { + s := &Server{ + db: store, + static: static, + emailTimeout: emailTimeout, + mail: m, + insecure: insecure, + } + addRoutes(sm, s) + return s } func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -148,27 +152,37 @@ func (s *Server) register(w http.ResponseWriter, req *http.Request) { return } - addr := email[0] - if _, err := mail.ParseAddress(addr); err != nil { + addr, err := mail.ParseAddress(email[0]) + if err != nil { http.Error(w, fmt.Sprintf("invalid email detected: %v", err), http.StatusBadRequest) return } - tok, err := s.db.Register(addr) + tok, err := s.db.Register(addr.Address) if err := verrors.ToHTTP(err); err != nil { http.Error(w, err.Message, err.Code) return } + proto := "https" - if req.TLS == nil { + if s.insecure { proto = "http" } - log.Printf("%s://%s/api/v0/confirm/%+v", proto, req.Host, tok) resp := struct { Msg string `json:"msg"` }{ Msg: "please check your email\n", } + + err = s.mail.Send( + *addr, + "your api string", + fmt.Sprintf("%s://%s/api/v0/confirm/%+v", proto, req.Host, tok), + ) + if err != nil { + resp.Msg = fmt.Sprintf("problem sending email: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } w.Header().Set("Content-type", "application/json") json.NewEncoder(w).Encode(resp) } @@ -196,23 +210,36 @@ func (s *Server) forgot(w http.ResponseWriter, req *http.Request) { return } - addr := email[0] - if _, err := mail.ParseAddress(addr); err != nil { + addr, err := mail.ParseAddress(email[0]) + if err != nil { http.Error(w, fmt.Sprintf("invalid email detected: %v", err), http.StatusBadRequest) return } - tok, err := s.db.forgot(addr, s.emailTimeout) + tok, err := s.db.forgot(addr.Address, s.emailTimeout) if err := verrors.ToHTTP(err); err != nil { http.Error(w, err.Message, err.Code) return } - log.Printf("http://%s/api/v0/confirm/%+v", req.Host, tok) + proto := "https" + if s.insecure { + proto = "http" + } resp := struct { Msg string `json:"msg"` }{ Msg: "please check your email\n", } + + err = s.mail.Send( + *addr, + emailSubject, + fmt.Sprintf("%s://%s/api/v0/confirm/%+v", proto, req.Host, tok), + ) + if err != nil { + resp.Msg = fmt.Sprintf("problem sending email: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } w.Header().Set("Content-type", "application/json") json.NewEncoder(w).Encode(resp) } diff --git a/vain.go b/vain.go index 07d35cd..146721a 100644 --- a/vain.go +++ b/vain.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "strings" + "time" ) var vcss = map[string]bool{ @@ -41,6 +42,15 @@ type Package struct { Ns string `json:"-"` } +// User stores the information about a user including email used, their +// token, whether they have registerd and the requested timestamp +type User struct { + Email string + Token string + Registered bool + Requested time.Time +} + func (p Package) String() string { return fmt.Sprintf( "",