diff --git a/api_test.go b/api_test.go index f25b8ac..4b98358 100644 --- a/api_test.go +++ b/api_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -508,3 +509,92 @@ func TestDelete(t *testing.T) { } } } + +func TestRegister(t *testing.T) { + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + + sm := http.NewServeMux() + mm := &mockMail{} + NewServer(sm, db, mm, "", window, false) + ts := httptest.NewServer(sm) + + u := fmt.Sprintf("%s%s?email=fake@example.com", ts.URL, prefix["register"]) + req, err := http.NewRequest("POST", u, nil) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + req, err = http.NewRequest( + "GET", + strings.Replace(mm.msg, "https", "http", -1), //ugly hack to get around https vs http + nil, + ) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + err = db.hasUser("fake@example.com") + if err != nil { + t.Fatalf("user was no correctly added to database", err) + } +} + +func TestRoundTrip(t *testing.T) { + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + + sm := http.NewServeMux() + mm := &mockMail{} + NewServer(sm, db, mm, "", window, false) + ts := httptest.NewServer(sm) + + u := fmt.Sprintf("%s%s?email=fake@example.com", ts.URL, prefix["register"]) + req, err := http.NewRequest("POST", u, nil) + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + req, err = http.NewRequest( + "GET", + strings.Replace(mm.msg, "https", "http", -1), //ugly hack to get around https vs http + nil, + ) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + err = db.hasUser("fake@example.com") + if err != nil { + t.Fatalf("user was no correctly added to database", err) + } + bs, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to parse response body", err) + } + tok := strings.Trim(string(bs), "new token: ") + + u = fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err = http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("couldn't POST: %v", err) + } + + if got, want := len(db.Pkgs()), 1; got != want { + t.Fatalf("pkgs should have something in it; got %d, want %d", got, want) + } +} diff --git a/db.go b/db.go index 197888b..df72de0 100644 --- a/db.go +++ b/db.go @@ -312,3 +312,29 @@ func (db *DB) addUser(email string) (string, error) { ) return tok, err } + +func (db *DB) hasUser(email string) error { + result, err := db.conn.Query( + "SELECT EXISTS(SELECT 1 FROM users WHERE email = ? LIMIT 1)", + email, + ) + if err != nil { + return verrors.HTTP{ + Message: fmt.Sprintf("could not find requested user's email: %q: %v", email, err), + Code: http.StatusInternalServerError, + } + } + var exists string + for result.Next() { + if err := result.Scan(&exists); err != nil { + log.Fatal(err) + } + } + if exists == "0" { + return verrors.HTTP{ + Message: fmt.Sprintf("could not find requested user's email: %q: %v", email, err), + Code: http.StatusInternalServerError, + } + } + return nil +} diff --git a/mail.go b/mail.go index b410d6d..e872374 100644 --- a/mail.go +++ b/mail.go @@ -52,3 +52,12 @@ func (e Email) Send(to mail.Address, msg string) error { } return err } + +type mockMail struct { + msg string +} + +func (m *mockMail) Send(to mail.Address, msg string) error { + m.msg = msg + return nil +}