1
0
forked from sm/vain

gate forget email frequency.

Change-Id: Id0f3bd2ec7c6714d23f9989a341855da5c8aa1bf
This commit is contained in:
Stephen McQuay 2016-05-14 21:30:58 -07:00
parent 11f88feef0
commit c99f57527f
No known key found for this signature in database
GPG Key ID: 1ABF428F71BAFC3D
4 changed files with 66 additions and 23 deletions

View File

@ -10,8 +10,11 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
) )
const window = 5 * time.Minute
func TestAdd(t *testing.T) { func TestAdd(t *testing.T) {
db, done := testDB(t) db, done := testDB(t)
if db == nil { if db == nil {
@ -20,7 +23,7 @@ func TestAdd(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
if err != nil { if err != nil {
@ -166,7 +169,7 @@ func TestInvalidPath(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
if err != nil { if err != nil {
@ -198,7 +201,7 @@ func TestCannotDuplicateExistingPath(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
@ -244,7 +247,7 @@ func TestCannotAddExistingSubPath(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
@ -292,7 +295,7 @@ func TestMissingRepo(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
@ -325,7 +328,7 @@ func TestBadJson(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
@ -358,7 +361,7 @@ func TestNoAuth(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
u := fmt.Sprintf("%s/foo", ts.URL) u := fmt.Sprintf("%s/foo", ts.URL)
@ -387,7 +390,7 @@ func TestBadVcs(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
@ -418,7 +421,7 @@ func TestUnsupportedMethod(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")
@ -450,7 +453,7 @@ func TestDelete(t *testing.T) {
defer done() defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
NewServer(sm, db, "") NewServer(sm, db, "", window)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org") tok, err := db.addUser("sm@example.org")

View File

@ -45,6 +45,7 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"time"
"mcquay.me/vain" "mcquay.me/vain"
@ -60,6 +61,8 @@ type config struct {
Key string Key string
Static string Static string
EmailTimeout time.Duration `envconfig:"email_timeout"`
} }
func main() { func main() {
@ -97,11 +100,13 @@ func main() {
c := &config{ c := &config{
Port: 4040, Port: 4040,
EmailTimeout: 5 * time.Minute,
} }
if err := envconfig.Process("vain", c); err != nil { if err := envconfig.Process("vain", c); err != nil {
fmt.Fprintf(os.Stderr, "problem processing environment: %v", err) fmt.Fprintf(os.Stderr, "problem processing environment: %v", err)
os.Exit(1) os.Exit(1)
} }
log.Printf("%+v", c)
if len(os.Args) > 1 { if len(os.Args) > 1 {
switch os.Args[1] { switch os.Args[1] {
case "env", "e", "help", "h": case "env", "e", "help", "h":
@ -117,7 +122,7 @@ func main() {
} }
log.Printf("serving at: http://%s:%d/", hostname, c.Port) log.Printf("serving at: http://%s:%d/", hostname, c.Port)
sm := http.NewServeMux() sm := http.NewServeMux()
vain.NewServer(sm, db, c.Static) vain.NewServer(sm, db, c.Static, c.EmailTimeout)
addr := fmt.Sprintf(":%d", c.Port) addr := fmt.Sprintf(":%d", c.Port)
if c.Cert == "" || c.Key == "" { if c.Cert == "" || c.Key == "" {

42
db.go
View File

@ -245,15 +245,47 @@ func (db *DB) Confirm(token string) (string, error) {
return newToken, nil return newToken, nil
} }
func (db *DB) forgot(email string) (string, error) { func (db *DB) forgot(email string, window time.Duration) (string, error) {
var token string txn, err := db.conn.Beginx()
if err := db.conn.Get(&token, "SELECT token FROM users WHERE email = ?", email); err != nil { if err != nil {
return "", verrors.HTTP{ return "", verrors.HTTP{
Message: fmt.Sprintf("could not search for email %q in db: %v", email, err), Message: fmt.Sprintf("problem creating transaction: %v", err),
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
} }
} }
return token, nil defer func() {
if err != nil {
txn.Rollback()
} else {
txn.Commit()
}
}()
out := struct {
Token string
Requested time.Time
}{}
if err = txn.Get(&out, "SELECT token, requested FROM users WHERE email = ?", email); err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not find email %q in db", email),
Code: http.StatusNotFound,
}
}
if out.Requested.After(time.Now()) {
return "", verrors.HTTP{
Message: fmt.Sprintf("rate limit hit for %q; try again in %0.2f mins", email, out.Requested.Sub(time.Now()).Minutes()),
Code: http.StatusTooManyRequests,
}
}
_, err = txn.Exec("UPDATE users SET requested = ? WHERE email = ?", time.Now().Add(window), email)
if err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not update last requested time for %q: %v", email, err),
Code: http.StatusInternalServerError,
}
}
return out.Token, nil
} }
func (db *DB) addUser(email string) (string, error) { func (db *DB) addUser(email string) (string, error) {

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"net/mail" "net/mail"
"strings" "strings"
"time"
"github.com/elazarl/go-bindata-assetfs" "github.com/elazarl/go-bindata-assetfs"
@ -28,10 +29,11 @@ func init() {
} }
// NewServer populates a server, adds the routes, and returns it for use. // NewServer populates a server, adds the routes, and returns it for use.
func NewServer(sm *http.ServeMux, store *DB, static string) *Server { func NewServer(sm *http.ServeMux, store *DB, static string, emailTimeout time.Duration) *Server {
s := &Server{ s := &Server{
db: store, db: store,
static: static, static: static,
emailTimeout: emailTimeout,
} }
addRoutes(sm, s) addRoutes(sm, s)
return s return s
@ -41,6 +43,7 @@ func NewServer(sm *http.ServeMux, store *DB, static string) *Server {
type Server struct { type Server struct {
db *DB db *DB
static string static string
emailTimeout time.Duration
} }
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@ -183,7 +186,7 @@ func (s *Server) forgot(w http.ResponseWriter, req *http.Request) {
return return
} }
tok, err := s.db.forgot(addr) tok, err := s.db.forgot(addr, s.emailTimeout)
if err := verrors.ToHTTP(err); err != nil { if err := verrors.ToHTTP(err); err != nil {
http.Error(w, err.Message, err.Code) http.Error(w, err.Message, err.Code)
return return