gate forget email frequency.
Change-Id: Id0f3bd2ec7c6714d23f9989a341855da5c8aa1bf
This commit is contained in:
parent
11f88feef0
commit
c99f57527f
23
api_test.go
23
api_test.go
@ -10,8 +10,11 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const window = 5 * time.Minute
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
db, done := testDB(t)
|
||||
if db == nil {
|
||||
@ -20,7 +23,7 @@ func TestAdd(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
if err != nil {
|
||||
@ -166,7 +169,7 @@ func TestInvalidPath(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
if err != nil {
|
||||
@ -198,7 +201,7 @@ func TestCannotDuplicateExistingPath(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
@ -244,7 +247,7 @@ func TestCannotAddExistingSubPath(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
@ -292,7 +295,7 @@ func TestMissingRepo(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
@ -325,7 +328,7 @@ func TestBadJson(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
@ -358,7 +361,7 @@ func TestNoAuth(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
u := fmt.Sprintf("%s/foo", ts.URL)
|
||||
@ -387,7 +390,7 @@ func TestBadVcs(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
@ -418,7 +421,7 @@ func TestUnsupportedMethod(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
@ -450,7 +453,7 @@ func TestDelete(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
sm := http.NewServeMux()
|
||||
NewServer(sm, db, "")
|
||||
NewServer(sm, db, "", window)
|
||||
ts := httptest.NewServer(sm)
|
||||
|
||||
tok, err := db.addUser("sm@example.org")
|
||||
|
@ -45,6 +45,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"mcquay.me/vain"
|
||||
|
||||
@ -60,6 +61,8 @@ type config struct {
|
||||
Key string
|
||||
|
||||
Static string
|
||||
|
||||
EmailTimeout time.Duration `envconfig:"email_timeout"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@ -97,11 +100,13 @@ func main() {
|
||||
|
||||
c := &config{
|
||||
Port: 4040,
|
||||
EmailTimeout: 5 * time.Minute,
|
||||
}
|
||||
if err := envconfig.Process("vain", c); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "problem processing environment: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
log.Printf("%+v", c)
|
||||
if len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
case "env", "e", "help", "h":
|
||||
@ -117,7 +122,7 @@ func main() {
|
||||
}
|
||||
log.Printf("serving at: http://%s:%d/", hostname, c.Port)
|
||||
sm := http.NewServeMux()
|
||||
vain.NewServer(sm, db, c.Static)
|
||||
vain.NewServer(sm, db, c.Static, c.EmailTimeout)
|
||||
addr := fmt.Sprintf(":%d", c.Port)
|
||||
|
||||
if c.Cert == "" || c.Key == "" {
|
||||
|
42
db.go
42
db.go
@ -245,15 +245,47 @@ func (db *DB) Confirm(token string) (string, error) {
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
func (db *DB) forgot(email string) (string, error) {
|
||||
var token string
|
||||
if err := db.conn.Get(&token, "SELECT token FROM users WHERE email = ?", email); err != nil {
|
||||
func (db *DB) forgot(email string, window time.Duration) (string, error) {
|
||||
txn, err := db.conn.Beginx()
|
||||
if err != nil {
|
||||
return "", verrors.HTTP{
|
||||
Message: fmt.Sprintf("could not search for email %q in db: %v", email, err),
|
||||
Message: fmt.Sprintf("problem creating transaction: %v", err),
|
||||
Code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
defer func() {
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
} else {
|
||||
txn.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
out := struct {
|
||||
Token string
|
||||
Requested time.Time
|
||||
}{}
|
||||
if err = txn.Get(&out, "SELECT token, requested FROM users WHERE email = ?", email); err != nil {
|
||||
return "", verrors.HTTP{
|
||||
Message: fmt.Sprintf("could not find email %q in db", email),
|
||||
Code: http.StatusNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
if out.Requested.After(time.Now()) {
|
||||
return "", verrors.HTTP{
|
||||
Message: fmt.Sprintf("rate limit hit for %q; try again in %0.2f mins", email, out.Requested.Sub(time.Now()).Minutes()),
|
||||
Code: http.StatusTooManyRequests,
|
||||
}
|
||||
}
|
||||
_, err = txn.Exec("UPDATE users SET requested = ? WHERE email = ?", time.Now().Add(window), email)
|
||||
if err != nil {
|
||||
return "", verrors.HTTP{
|
||||
Message: fmt.Sprintf("could not update last requested time for %q: %v", email, err),
|
||||
Code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
return out.Token, nil
|
||||
}
|
||||
|
||||
func (db *DB) addUser(email string) (string, error) {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/elazarl/go-bindata-assetfs"
|
||||
|
||||
@ -28,10 +29,11 @@ func init() {
|
||||
}
|
||||
|
||||
// NewServer populates a server, adds the routes, and returns it for use.
|
||||
func NewServer(sm *http.ServeMux, store *DB, static string) *Server {
|
||||
func NewServer(sm *http.ServeMux, store *DB, static string, emailTimeout time.Duration) *Server {
|
||||
s := &Server{
|
||||
db: store,
|
||||
static: static,
|
||||
emailTimeout: emailTimeout,
|
||||
}
|
||||
addRoutes(sm, s)
|
||||
return s
|
||||
@ -41,6 +43,7 @@ func NewServer(sm *http.ServeMux, store *DB, static string) *Server {
|
||||
type Server struct {
|
||||
db *DB
|
||||
static string
|
||||
emailTimeout time.Duration
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
@ -183,7 +186,7 @@ func (s *Server) forgot(w http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
tok, err := s.db.forgot(addr)
|
||||
tok, err := s.db.forgot(addr, s.emailTimeout)
|
||||
if err := verrors.ToHTTP(err); err != nil {
|
||||
http.Error(w, err.Message, err.Code)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user