diff --git a/api_test.go b/api_test.go index 1f16ee2..6d426cc 100644 --- a/api_test.go +++ b/api_test.go @@ -10,8 +10,11 @@ import ( "net/url" "strings" "testing" + "time" ) +const window = 5 * time.Minute + func TestAdd(t *testing.T) { db, done := testDB(t) if db == nil { @@ -20,7 +23,7 @@ func TestAdd(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") if err != nil { @@ -166,7 +169,7 @@ func TestInvalidPath(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") if err != nil { @@ -198,7 +201,7 @@ func TestCannotDuplicateExistingPath(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -244,7 +247,7 @@ func TestCannotAddExistingSubPath(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -292,7 +295,7 @@ func TestMissingRepo(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -325,7 +328,7 @@ func TestBadJson(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -358,7 +361,7 @@ func TestNoAuth(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) u := fmt.Sprintf("%s/foo", ts.URL) @@ -387,7 +390,7 @@ func TestBadVcs(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -418,7 +421,7 @@ func TestUnsupportedMethod(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") @@ -450,7 +453,7 @@ func TestDelete(t *testing.T) { defer done() sm := http.NewServeMux() - NewServer(sm, db, "") + NewServer(sm, db, "", window) ts := httptest.NewServer(sm) tok, err := db.addUser("sm@example.org") diff --git a/cmd/vaind/main.go b/cmd/vaind/main.go index e073478..00f310b 100644 --- a/cmd/vaind/main.go +++ b/cmd/vaind/main.go @@ -45,6 +45,7 @@ import ( "log" "net/http" "os" + "time" "mcquay.me/vain" @@ -60,6 +61,8 @@ type config struct { Key string Static string + + EmailTimeout time.Duration `envconfig:"email_timeout"` } func main() { @@ -96,12 +99,14 @@ func main() { } c := &config{ - Port: 4040, + Port: 4040, + EmailTimeout: 5 * time.Minute, } if err := envconfig.Process("vain", c); err != nil { fmt.Fprintf(os.Stderr, "problem processing environment: %v", err) os.Exit(1) } + log.Printf("%+v", c) if len(os.Args) > 1 { switch os.Args[1] { case "env", "e", "help", "h": @@ -117,7 +122,7 @@ func main() { } log.Printf("serving at: http://%s:%d/", hostname, c.Port) sm := http.NewServeMux() - vain.NewServer(sm, db, c.Static) + vain.NewServer(sm, db, c.Static, c.EmailTimeout) addr := fmt.Sprintf(":%d", c.Port) if c.Cert == "" || c.Key == "" { diff --git a/db.go b/db.go index f1c5fbc..83a12be 100644 --- a/db.go +++ b/db.go @@ -245,15 +245,47 @@ func (db *DB) Confirm(token string) (string, error) { return newToken, nil } -func (db *DB) forgot(email string) (string, error) { - var token string - if err := db.conn.Get(&token, "SELECT token FROM users WHERE email = ?", email); err != nil { +func (db *DB) forgot(email string, window time.Duration) (string, error) { + txn, err := db.conn.Beginx() + if err != nil { return "", verrors.HTTP{ - Message: fmt.Sprintf("could not search for email %q in db: %v", email, err), + Message: fmt.Sprintf("problem creating transaction: %v", err), Code: http.StatusInternalServerError, } } - return token, nil + defer func() { + if err != nil { + txn.Rollback() + } else { + txn.Commit() + } + }() + + out := struct { + Token string + Requested time.Time + }{} + if err = txn.Get(&out, "SELECT token, requested FROM users WHERE email = ?", email); err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("could not find email %q in db", email), + Code: http.StatusNotFound, + } + } + + if out.Requested.After(time.Now()) { + return "", verrors.HTTP{ + Message: fmt.Sprintf("rate limit hit for %q; try again in %0.2f mins", email, out.Requested.Sub(time.Now()).Minutes()), + Code: http.StatusTooManyRequests, + } + } + _, err = txn.Exec("UPDATE users SET requested = ? WHERE email = ?", time.Now().Add(window), email) + if err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("could not update last requested time for %q: %v", email, err), + Code: http.StatusInternalServerError, + } + } + return out.Token, nil } func (db *DB) addUser(email string) (string, error) { diff --git a/server.go b/server.go index 55bed3a..fce3dd7 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "net/http" "net/mail" "strings" + "time" "github.com/elazarl/go-bindata-assetfs" @@ -28,10 +29,11 @@ func init() { } // NewServer populates a server, adds the routes, and returns it for use. -func NewServer(sm *http.ServeMux, store *DB, static string) *Server { +func NewServer(sm *http.ServeMux, store *DB, static string, emailTimeout time.Duration) *Server { s := &Server{ - db: store, - static: static, + db: store, + static: static, + emailTimeout: emailTimeout, } addRoutes(sm, s) return s @@ -39,8 +41,9 @@ func NewServer(sm *http.ServeMux, store *DB, static string) *Server { // Server serves up the http. type Server struct { - db *DB - static string + db *DB + static string + emailTimeout time.Duration } func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -183,7 +186,7 @@ func (s *Server) forgot(w http.ResponseWriter, req *http.Request) { return } - tok, err := s.db.forgot(addr) + tok, err := s.db.forgot(addr, s.emailTimeout) if err := verrors.ToHTTP(err); err != nil { http.Error(w, err.Message, err.Code) return