1
0
forked from sm/vain

Added simple user auth

Fixes #14.

Change-Id: I748933214f43ac7298f1e93c14bb0ee881976d43
This commit is contained in:
Stephen McQuay 2016-04-11 20:43:18 -07:00 committed by Stephen McQuay (smcquay)
parent 753a225f53
commit 680eecb111
No known key found for this signature in database
GPG Key ID: 1ABF428F71BAFC3D
13 changed files with 928 additions and 452 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.db

View File

@ -13,31 +13,55 @@ import (
) )
func TestAdd(t *testing.T) { func TestAdd(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
if db == nil {
t.Fatalf("could not create temp db")
}
defer done()
sm := http.NewServeMux() sm := http.NewServeMux()
_ = NewServer(sm, ms) NewServer(sm, db)
ts := httptest.NewServer(sm) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
resp, err := http.Get(ts.URL) resp, err := http.Get(ts.URL)
if err != nil { if err != nil {
t.Errorf("couldn't GET: %v", err) t.Errorf("couldn't GET: %v", err)
} }
resp.Body.Close() resp.Body.Close()
if len(ms.p) != 0 {
t.Errorf("started with something in it; got %d, want %d", len(ms.p), 0)
}
bad := ts.URL if got, want := len(db.Pkgs()), 0; got != want {
resp, err = http.Post(bad, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) t.Errorf("started with something in it; got %d, want %d", got, want)
if err != nil {
t.Errorf("couldn't POST: %v", err)
}
resp.Body.Close()
if len(ms.p) != 0 {
t.Errorf("started with something in it; got %d, want %d", len(ms.p), 0)
} }
{ {
u := fmt.Sprintf("%s/db/", ts.URL) bad := ts.URL
body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", bad, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("couldn't POST: %v", err)
}
if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
buf := &bytes.Buffer{}
io.Copy(buf, resp.Body)
t.Errorf("bad request got incorrect status: got %d, want %d", got, want)
t.Log("%s", buf)
}
resp.Body.Close()
if got, want := len(db.Pkgs()), 0; got != want {
t.Errorf("started with something in it; got %d, want %d", got, want)
}
}
{
u := fmt.Sprintf("%s/%s", ts.URL, prefix["pkgs"])
resp, err := http.Get(u) resp, err := http.Get(u)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -53,34 +77,52 @@ func TestAdd(t *testing.T) {
} }
} }
u := fmt.Sprintf("%s/foo", ts.URL) {
resp, err = http.Post(u, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`))
if err != nil {
t.Errorf("couldn't POST: %v", err)
}
if len(ms.p) != 1 { u := fmt.Sprintf("%s/foo", ts.URL)
t.Errorf("storage should have something in it; got %d, want %d", len(ms.p), 1) body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("problem performing request: %v", err)
} }
buf := &bytes.Buffer{}
io.Copy(buf, resp.Body)
t.Logf("%v", buf)
resp.Body.Close()
if got, want := len(db.Pkgs()), 1; got != want {
t.Errorf("pkgs should have something in it; got %d, want %d", got, want)
}
t.Logf("packages: %v", db.Pkgs())
ur, err := url.Parse(ts.URL) ur, err := url.Parse(ts.URL)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
good := fmt.Sprintf("%s/foo", ur.Host) good := fmt.Sprintf("%s/foo", ur.Host)
p, ok := ms.p[good]
if !ok { if !db.PackageExists(good) {
t.Fatalf("did not find package for %s; should have posted a valid package", good) t.Fatalf("did not find package for %s; should have posted a valid package", good)
} }
if p.path != good { p, err := db.Package(good)
t.Errorf("package name did not go through as expected; got %q, want %q", p.path, good) t.Logf("%+v", p)
if err != nil {
t.Fatalf("problem getting package: %v", err)
} }
if want := "https://s.mcquay.me/sm/vain"; p.Repo != want { if got, want := p.Path, good; got != want {
t.Errorf("repo did not go through as expected; got %q, want %q", p.Repo, want) t.Errorf("package name did not go through as expected; got %q, want %q", got, want)
}
if got, want := p.Repo, "https://s.mcquay.me/sm/vain"; got != want {
t.Errorf("repo did not go through as expected; got %q, want %q", got, want)
} }
if got, want := p.Vcs, "git"; got != want { if got, want := p.Vcs, "git"; got != want {
t.Errorf("Vcs did not go through as expected; got %q, want %q", got, want) t.Errorf("Vcs did not go through as expected; got %q, want %q", got, want)
} }
}
resp, err = http.Get(ts.URL) resp, err = http.Get(ts.URL)
if err != nil { if err != nil {
@ -99,7 +141,7 @@ func TestAdd(t *testing.T) {
} }
{ {
u := fmt.Sprintf("%s/db/", ts.URL) u := fmt.Sprintf("%s/%s", ts.URL, prefix["pkgs"])
resp, err := http.Get(u) resp, err := http.Get(u)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -117,121 +159,248 @@ func TestAdd(t *testing.T) {
} }
func TestInvalidPath(t *testing.T) { func TestInvalidPath(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
s := &Server{ if db == nil {
storage: ms, t.Fatalf("could not create temp db")
} }
ts := httptest.NewServer(s) defer done()
resp, err := http.Post(ts.URL, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) sm := http.NewServeMux()
NewServer(sm, db)
ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
bad := ts.URL
body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", bad, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if len(ms.p) != 0 { if len(db.Pkgs()) != 0 {
t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0)
} }
if want := http.StatusBadRequest; resp.StatusCode != want { if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("should have failed to post at bad route; got %s, want %s", http.StatusText(got), http.StatusText(want))
} }
} }
func TestCannotDuplicateExistingPath(t *testing.T) { func TestCannotDuplicateExistingPath(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
s := &Server{ if db == nil {
storage: ms, t.Fatalf("could not create temp db")
} }
ts := httptest.NewServer(s) defer done()
url := fmt.Sprintf("%s/foo", ts.URL) sm := http.NewServeMux()
resp, err := http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) NewServer(sm, db)
ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
u := fmt.Sprintf("%s/foo", ts.URL)
{
body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if want := http.StatusOK; resp.StatusCode != want { if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want))
} }
resp, err = http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) }
{
body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if want := http.StatusConflict; resp.StatusCode != want { if want := http.StatusConflict; resp.StatusCode != want {
t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want))
} }
}
} }
func TestCannotAddExistingSubPath(t *testing.T) { func TestCannotAddExistingSubPath(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
s := &Server{ if db == nil {
storage: ms, t.Fatalf("could not create temp db")
} }
ts := httptest.NewServer(s) defer done()
url := fmt.Sprintf("%s/foo/bar", ts.URL) sm := http.NewServeMux()
resp, err := http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) NewServer(sm, db)
ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
{
u := fmt.Sprintf("%s/foo/bar", ts.URL)
t.Logf("url: %v", u)
body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if want := http.StatusOK; resp.StatusCode != want { if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want))
} }
}
url = fmt.Sprintf("%s/foo", ts.URL) {
resp, err = http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) u := fmt.Sprintf("%s/foo", ts.URL)
resp, err = http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if want := http.StatusConflict; resp.StatusCode != want { if want := http.StatusConflict; resp.StatusCode != want {
t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want))
} }
}
} }
func TestMissingRepo(t *testing.T) { func TestMissingRepo(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
s := &Server{ if db == nil {
storage: ms, t.Fatalf("could not create temp db")
} }
ts := httptest.NewServer(s) defer done()
url := fmt.Sprintf("%s/foo", ts.URL)
resp, err := http.Post(url, "application/json", strings.NewReader(`{}`)) sm := http.NewServeMux()
NewServer(sm, db)
ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
u := fmt.Sprintf("%s/foo", ts.URL)
body := strings.NewReader(`{}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if len(ms.p) != 0 { if len(db.Pkgs()) != 0 {
t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0)
} }
if want := http.StatusBadRequest; resp.StatusCode != want { if want := http.StatusBadRequest; resp.StatusCode != want {
t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("should have failed to post with bad payload; got %s, want %s", resp.Status, http.StatusText(want))
} }
} }
func TestBadJson(t *testing.T) { func TestBadJson(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
s := &Server{ if db == nil {
storage: ms, t.Fatalf("could not create temp db")
} }
ts := httptest.NewServer(s) defer done()
url := fmt.Sprintf("%s/foo", ts.URL)
resp, err := http.Post(url, "application/json", strings.NewReader(`{`)) sm := http.NewServeMux()
NewServer(sm, db)
ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
u := fmt.Sprintf("%s/foo", ts.URL)
body := strings.NewReader(`{`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if len(ms.p) != 0 { if len(db.Pkgs()) != 0 {
t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0)
} }
if want := http.StatusBadRequest; resp.StatusCode != want { if want := http.StatusBadRequest; resp.StatusCode != want {
t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want))
} }
} }
func TestBadVcs(t *testing.T) { func TestNoAuth(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
s := &Server{ if db == nil {
storage: ms, t.Fatalf("could not create temp db")
} }
ts := httptest.NewServer(s) defer done()
url := fmt.Sprintf("%s/foo", ts.URL)
resp, err := http.Post(url, "application/json", strings.NewReader(`{"vcs": "bitbucket", "repo": "https://s.mcquay.me/sm/vain"}`)) sm := http.NewServeMux()
NewServer(sm, db)
ts := httptest.NewServer(sm)
u := fmt.Sprintf("%s/foo", ts.URL)
body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
// here we don't set the Authorization header
// req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("couldn't POST: %v", err)
}
resp.Body.Close()
if got, want := resp.StatusCode, http.StatusUnauthorized; got != want {
t.Errorf("posted with missing auth; got %v, want %v", http.StatusText(got), http.StatusText(want))
}
}
func TestBadVcs(t *testing.T) {
db, done := testDB(t)
if db == nil {
t.Fatalf("could not create temp db")
}
defer done()
sm := http.NewServeMux()
NewServer(sm, db)
ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
u := fmt.Sprintf("%s/foo", ts.URL)
body := strings.NewReader(`{"vcs": "bitbucket", "repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
@ -242,40 +411,31 @@ func TestBadVcs(t *testing.T) {
} }
func TestUnsupportedMethod(t *testing.T) { func TestUnsupportedMethod(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
s := &Server{ if db == nil {
storage: ms, t.Fatalf("could not create temp db")
} }
ts := httptest.NewServer(s) defer done()
url := fmt.Sprintf("%s/foo", ts.URL)
client := &http.Client{}
req, err := http.NewRequest("PUT", url, nil)
resp, err := client.Do(req)
if err != nil {
t.Errorf("couldn't POST: %v", err)
}
if len(ms.p) != 0 {
t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0)
}
if want := http.StatusMethodNotAllowed; resp.StatusCode != want {
t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want))
}
}
func TestNewServer(t *testing.T) {
ms := NewSimpleStore("")
sm := http.NewServeMux() sm := http.NewServeMux()
s := NewServer(sm, ms) NewServer(sm, db)
ts := httptest.NewServer(s) ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
url := fmt.Sprintf("%s/foo", ts.URL) url := fmt.Sprintf("%s/foo", ts.URL)
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("PUT", url, nil) req, err := http.NewRequest("PUT", url, nil)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if len(ms.p) != 0 { if len(db.Pkgs()) != 0 {
t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0)
} }
if want := http.StatusMethodNotAllowed; resp.StatusCode != want { if want := http.StatusMethodNotAllowed; resp.StatusCode != want {
t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want))
@ -283,27 +443,37 @@ func TestNewServer(t *testing.T) {
} }
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
ms := NewSimpleStore("") db, done := testDB(t)
sm := http.NewServeMux() if db == nil {
_ = NewServer(sm, ms) t.Fatalf("could not create temp db")
ts := httptest.NewServer(sm)
resp, err := http.Get(ts.URL)
if err != nil {
t.Errorf("couldn't GET: %v", err)
} }
resp.Body.Close() defer done()
if len(ms.p) != 0 {
t.Errorf("started with something in it; got %d, want %d", len(ms.p), 0) sm := http.NewServeMux()
NewServer(sm, db)
ts := httptest.NewServer(sm)
tok, err := db.addUser("sm@example.org")
if err != nil {
t.Error("failure to add user: %v", err)
}
t.Logf("%v", tok)
if len(db.Pkgs()) != 0 {
t.Errorf("started with something in it; got %d, want %d", len(db.Pkgs()), 0)
} }
u := fmt.Sprintf("%s/foo", ts.URL) u := fmt.Sprintf("%s/foo", ts.URL)
resp, err = http.Post(u, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)
req, err := http.NewRequest("POST", u, body)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if got, want := len(ms.p), 1; got != want { if got, want := len(db.Pkgs()), 1; got != want {
t.Errorf("storage should have something in it; got %d, want %d", got, want) t.Errorf("pkgs should have something in it; got %d, want %d", got, want)
} }
{ {
@ -311,6 +481,7 @@ func TestDelete(t *testing.T) {
u := fmt.Sprintf("%s/bar", ts.URL) u := fmt.Sprintf("%s/bar", ts.URL)
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("DELETE", u, nil) req, err := http.NewRequest("DELETE", u, nil)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err = client.Do(req) resp, err = client.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
@ -319,14 +490,18 @@ func TestDelete(t *testing.T) {
t.Errorf("should have not been able to delete unknown package; got %v, want %v", http.StatusText(got), http.StatusText(want)) t.Errorf("should have not been able to delete unknown package; got %v, want %v", http.StatusText(got), http.StatusText(want))
} }
} }
{
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("DELETE", u, nil) req, err := http.NewRequest("DELETE", u, nil)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok))
resp, err = client.Do(req) resp, err = client.Do(req)
if err != nil { if err != nil {
t.Errorf("couldn't POST: %v", err) t.Errorf("couldn't POST: %v", err)
} }
if got, want := len(ms.p), 0; got != want { if got, want := len(db.Pkgs()), 0; got != want {
t.Errorf("storage should be empty; got %d, want %d", got, want) t.Errorf("pkgs should be empty; got %d, want %d", got, want)
}
} }
} }

View File

@ -51,21 +51,45 @@ import (
"github.com/kelseyhightower/envconfig" "github.com/kelseyhightower/envconfig"
) )
const usage = `vaind const usage = "vaind [init] <dbname>"
environment vars:
VAIN_PORT: tcp listen port
VAIN_HOST: hostname to use
VAIN_DB: path to json database
`
type config struct { type config struct {
Port int Port int
DB string
} }
func main() { func main() {
if len(os.Args) < 2 {
fmt.Fprintf(os.Stderr, "%s\n", usage)
os.Exit(1)
}
if os.Args[1] == "init" {
if len(os.Args) != 3 {
fmt.Fprintf(os.Stderr, "missing db name: %s\n", usage)
os.Exit(1)
}
db, err := vain.NewDB(os.Args[2])
if err != nil {
fmt.Fprintf(os.Stderr, "couldn't open db: %v\n", err)
os.Exit(1)
}
defer db.Close()
if err := db.Init(); err != nil {
fmt.Fprintf(os.Stderr, "problem initializing the db: %v\n", err)
os.Exit(1)
}
os.Exit(0)
}
db, err := vain.NewDB(os.Args[1])
if err != nil {
fmt.Fprintf(os.Stderr, "couldn't open db: %v\n", err)
os.Exit(1)
}
c := &config{ c := &config{
Port: 4040, Port: 4040,
} }
@ -80,9 +104,6 @@ func main() {
os.Exit(0) os.Exit(0)
} }
} }
if c.DB == "" {
log.Printf("warning: in-memory db mode; if you do not want this set VAIN_DB")
}
hostname := "localhost" hostname := "localhost"
if hn, err := os.Hostname(); err != nil { if hn, err := os.Hostname(); err != nil {
log.Printf("problem getting hostname:", err) log.Printf("problem getting hostname:", err)
@ -91,11 +112,7 @@ func main() {
} }
log.Printf("serving at: http://%s:%d/", hostname, c.Port) log.Printf("serving at: http://%s:%d/", hostname, c.Port)
sm := http.NewServeMux() sm := http.NewServeMux()
ms := vain.NewSimpleStore(c.DB) vain.NewServer(sm, db)
if err := ms.Load(); err != nil {
log.Printf("unable to load db: %v; creating fresh database", err)
}
vain.NewServer(sm, ms)
addr := fmt.Sprintf(":%d", c.Port) addr := fmt.Sprintf(":%d", c.Port)
if err := http.ListenAndServe(addr, sm); err != nil { if err := http.ListenAndServe(addr, sm); err != nil {
log.Printf("problem with http server: %v", err) log.Printf("problem with http server: %v", err)

268
db.go Normal file
View File

@ -0,0 +1,268 @@
package vain
import (
"fmt"
"log"
"net/http"
"time"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
verrors "mcquay.me/vain/errors"
vsql "mcquay.me/vain/sql"
)
type DB struct {
conn *sqlx.DB
}
func NewDB(path string) (*DB, error) {
conn, err := sqlx.Open("sqlite3", fmt.Sprintf("file:%s?cache=shared&mode=rwc", path))
if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, err
}
return &DB{conn}, err
}
func (db *DB) Init() error {
content, err := vsql.Asset("sql/init.sql")
if err != nil {
return err
}
_, err = db.conn.Exec(string(content))
return err
}
func (db *DB) Close() error {
return db.conn.Close()
}
func (db *DB) AddPackage(p Package) error {
_, err := db.conn.NamedExec(
"INSERT INTO packages(vcs, repo, path, ns) VALUES (:vcs, :repo, :path, :ns)",
&p,
)
return err
}
func (db *DB) RemovePackage(path string) error {
_, err := db.conn.Exec("DELETE FROM packages WHERE path = ?", path)
return err
}
func (db *DB) Pkgs() []Package {
r := []Package{}
rows, err := db.conn.Queryx("SELECT * FROM packages")
if err != nil {
log.Printf("%+v", err)
return nil
}
for rows.Next() {
var p Package
err = rows.StructScan(&p)
if err != nil {
log.Printf("%+v", err)
return nil
}
r = append(r, p)
}
return r
}
func (db *DB) PackageExists(path string) bool {
var count int
if err := db.conn.Get(&count, "SELECT COUNT(*) FROM packages WHERE path = ?", path); err != nil {
log.Printf("%+v", err)
}
r := false
switch count {
case 1:
r = true
default:
log.Printf("unexpected count of packages matching %q: %d", path, count)
}
return r
}
func (db *DB) Package(path string) (Package, error) {
r := Package{}
err := db.conn.Get(&r, "SELECT * FROM packages WHERE path = ?", path)
return r, err
}
func (db *DB) NSForToken(ns string, tok string) error {
var err error
txn, err := db.conn.Beginx()
if err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("problem creating transaction: %v", err),
Code: http.StatusInternalServerError,
}
}
defer func() {
if err != nil {
txn.Rollback()
} else {
txn.Commit()
}
}()
var count int
if err = txn.Get(&count, "SELECT COUNT(*) FROM namespaces WHERE namespaces.ns = ?", ns); err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("problem matching fetching namespaces matching %q", ns),
Code: http.StatusInternalServerError,
}
}
if count == 0 {
if _, err = txn.Exec(
"INSERT INTO namespaces(ns, email) SELECT ?, users.email FROM users WHERE users.token = ?",
ns,
tok,
); err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("problem inserting %q into namespaces for token %q: %v", ns, tok, err),
Code: http.StatusInternalServerError,
}
}
return err
}
if err = txn.Get(&count, "SELECT COUNT(*) FROM namespaces JOIN users ON namespaces.email = users.email WHERE users.token = ? AND namespaces.ns = ?", tok, ns); err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("ns: %q, tok: %q; %v", ns, tok, err),
Code: http.StatusInternalServerError,
}
}
switch count {
case 1:
err = nil
case 0:
err = verrors.HTTP{
Message: fmt.Sprintf("not authorized against namespace %q", ns),
Code: http.StatusUnauthorized,
}
default:
err = verrors.HTTP{
Message: fmt.Sprintf("inconsistent db; found %d results with ns (%s) with token (%s): %d", count, ns, tok),
Code: http.StatusInternalServerError,
}
}
return err
}
func (db *DB) Register(email string) (string, error) {
var err error
txn, err := db.conn.Beginx()
if err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("problem creating transaction: %v", err),
Code: http.StatusInternalServerError,
}
}
defer func() {
if err != nil {
txn.Rollback()
} else {
txn.Commit()
}
}()
var count int
if err = txn.Get(&count, "SELECT COUNT(*) FROM users WHERE email = ?", email); err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not search for email %q in db: %v", email, err),
Code: http.StatusInternalServerError,
}
}
if count != 0 {
return "", verrors.HTTP{
Message: fmt.Sprintf("duplicate email %q", email),
Code: http.StatusConflict,
}
}
tok := FreshToken()
_, err = txn.Exec(
"INSERT INTO users(email, token, requested) VALUES (?, ?, ?)",
email,
tok,
time.Now(),
)
return tok, err
}
func (db *DB) Confirm(token string) (string, error) {
var err error
txn, err := db.conn.Beginx()
if err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("problem creating transaction: %v", err),
Code: http.StatusInternalServerError,
}
}
defer func() {
if err != nil {
txn.Rollback()
} else {
txn.Commit()
}
}()
var count int
if err = txn.Get(&count, "SELECT COUNT(*) FROM users WHERE token = ?", token); err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not perform search for user with token %q in db: %v", token, err),
Code: http.StatusInternalServerError,
}
}
if count != 1 {
return "", verrors.HTTP{
Message: fmt.Sprintf("bad token: %s", token),
Code: http.StatusNotFound,
}
}
newToken := FreshToken()
_, err = txn.Exec(
"UPDATE users SET token = ?, registered = 1 WHERE token = ?",
newToken,
token,
)
if err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("couldn't update user with token %q", token),
Code: http.StatusInternalServerError,
}
}
return newToken, nil
}
func (db *DB) forgot(email string) (string, error) {
var token string
if err := db.conn.Get(&token, "SELECT token FROM users WHERE email = ?", email); err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not search for email %q in db: %v", email, err),
Code: http.StatusInternalServerError,
}
}
return token, nil
}
func (db *DB) addUser(email string) (string, error) {
tok := FreshToken()
_, err := db.conn.Exec(
"INSERT INTO users(email, token, requested) VALUES (?, ?, ?)",
email,
tok,
time.Now(),
)
return tok, err
}

30
errors/errors.go Normal file
View File

@ -0,0 +1,30 @@
package errors
import (
"fmt"
"net/http"
)
type HTTP struct {
error
Message string
Code int
}
func (e HTTP) Error() string {
return fmt.Sprintf("%d: %s", e.Code, e.Message)
}
func ToHTTP(err error) *HTTP {
if err == nil {
return nil
}
rerr := &HTTP{
Message: err.Error(),
Code: http.StatusInternalServerError,
}
if e, ok := err.(HTTP); ok {
rerr.Code = e.Code
}
return rerr
}

6
gen.go Normal file
View File

@ -0,0 +1,6 @@
package vain
//go:generate go get github.com/jteeuwen/go-bindata/...
//go:generate go get github.com/elazarl/go-bindata-assetfs/...
//go:generate rm -f sql/static.go
//go:generate go-bindata -pkg sql -o sql/static.go sql/...

127
server.go
View File

@ -3,14 +3,30 @@ package vain
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"strings" "strings"
verrors "mcquay.me/vain/errors"
) )
const apiPrefix = "/api/v0/"
var prefix map[string]string
func init() {
prefix = map[string]string{
"pkgs": apiPrefix + "db/",
"register": apiPrefix + "register/",
"confirm": apiPrefix + "confirm/",
"forgot": apiPrefix + "forgot/",
}
}
// NewServer populates a server, adds the routes, and returns it for use. // NewServer populates a server, adds the routes, and returns it for use.
func NewServer(sm *http.ServeMux, store Storage) *Server { func NewServer(sm *http.ServeMux, store *DB) *Server {
s := &Server{ s := &Server{
storage: store, db: store,
} }
addRoutes(sm, s) addRoutes(sm, s)
return s return s
@ -18,17 +34,43 @@ func NewServer(sm *http.ServeMux, store Storage) *Server {
// Server serves up the http. // Server serves up the http.
type Server struct { type Server struct {
storage Storage db *DB
} }
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
switch req.Method { if req.Method == "GET" {
case "GET": // TODO: perhaps have a nicely formatted page with info as root if
// go-get=1 not in request?
fmt.Fprintf(w, "<!DOCTYPE html>\n<html><head>\n") fmt.Fprintf(w, "<!DOCTYPE html>\n<html><head>\n")
for _, p := range s.storage.All() { for _, p := range s.db.Pkgs() {
fmt.Fprintf(w, "%s\n", p) fmt.Fprintf(w, "%s\n", p)
} }
fmt.Fprintf(w, "</head>\n<body><p>go tool metadata in head</p>\n</html>\n") fmt.Fprintf(w, "</head>\n<body><p>go tool metadata in head</p></body>\n</html>\n")
return
}
const prefix = "Bearer "
var tok string
auth := req.Header.Get("Authorization")
if strings.HasPrefix(auth, prefix) {
tok = strings.TrimPrefix(auth, prefix)
}
if tok == "" {
http.Error(w, "missing token", http.StatusUnauthorized)
return
}
ns, err := parseNamespace(req.URL.Path)
if err != nil {
http.Error(w, fmt.Sprintf("could not parse namespace:%v", err), http.StatusBadRequest)
return
}
if err := verrors.ToHTTP(s.db.NSForToken(ns, tok)); err != nil {
http.Error(w, err.Message, err.Code)
return
}
switch req.Method {
case "POST": case "POST":
if req.URL.Path == "/" { if req.URL.Path == "/" {
http.Error(w, fmt.Sprintf("invalid path %q", req.URL.Path), http.StatusBadRequest) http.Error(w, fmt.Sprintf("invalid path %q", req.URL.Path), http.StatusBadRequest)
@ -40,7 +82,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
if p.Repo == "" { if p.Repo == "" {
http.Error(w, fmt.Sprintf("invalid repository %q", req.URL.Path), http.StatusBadRequest) http.Error(w, fmt.Sprintf("invalid repository %q", p.Repo), http.StatusBadRequest)
return return
} }
if p.Vcs == "" { if p.Vcs == "" {
@ -50,22 +92,24 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.Error(w, fmt.Sprintf("invalid vcs %q", p.Vcs), http.StatusBadRequest) http.Error(w, fmt.Sprintf("invalid vcs %q", p.Vcs), http.StatusBadRequest)
return return
} }
p.path = fmt.Sprintf("%s/%s", req.Host, strings.Trim(req.URL.Path, "/")) p.Path = fmt.Sprintf("%s/%s", req.Host, strings.Trim(req.URL.Path, "/"))
if !Valid(p.path, s.storage.All()) { p.Ns = ns
if !Valid(p.Path, s.db.Pkgs()) {
http.Error(w, fmt.Sprintf("invalid path; prefix already taken %q", req.URL.Path), http.StatusConflict) http.Error(w, fmt.Sprintf("invalid path; prefix already taken %q", req.URL.Path), http.StatusConflict)
return return
} }
if err := s.storage.Add(p); err != nil { if err := s.db.AddPackage(p); err != nil {
http.Error(w, fmt.Sprintf("unable to add package: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("unable to add package: %v", err), http.StatusInternalServerError)
return return
} }
case "DELETE": case "DELETE":
p := fmt.Sprintf("%s/%s", req.Host, strings.Trim(req.URL.Path, "/")) p := fmt.Sprintf("%s/%s", req.Host, strings.Trim(req.URL.Path, "/"))
if !s.storage.Contains(p) { if !s.db.PackageExists(p) {
http.Error(w, fmt.Sprintf("package %q not found", p), http.StatusNotFound) http.Error(w, fmt.Sprintf("package %q not found", p), http.StatusNotFound)
return return
} }
if err := s.storage.Remove(p); err != nil {
if err := s.db.RemovePackage(p); err != nil {
http.Error(w, fmt.Sprintf("unable to delete package: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("unable to delete package: %v", err), http.StatusInternalServerError)
return return
} }
@ -74,13 +118,62 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
} }
func (s *Server) db(w http.ResponseWriter, req *http.Request) { func (s *Server) register(w http.ResponseWriter, req *http.Request) {
all := s.storage.All() req.ParseForm()
email, ok := req.Form["email"]
if !ok || len(email) != 1 {
http.Error(w, "must provide one email parameter", http.StatusBadRequest)
return
}
tok, err := s.db.Register(email[0])
if err := verrors.ToHTTP(err); err != nil {
http.Error(w, err.Message, err.Code)
return
}
log.Printf("http://%s/api/v0/confirm/%+v", req.Host, tok)
fmt.Fprintf(w, "please check your email\n")
}
func (s *Server) confirm(w http.ResponseWriter, req *http.Request) {
tok := req.URL.Path[len(prefix["confirm"]):]
tok = strings.TrimRight(tok, "/")
if tok == "" {
http.Error(w, "must provide one email parameter", http.StatusBadRequest)
return
}
tok, err := s.db.Confirm(tok)
if err := verrors.ToHTTP(err); err != nil {
http.Error(w, err.Message, err.Code)
return
}
fmt.Fprintf(w, "new token: %s\n", tok)
}
func (s *Server) forgot(w http.ResponseWriter, req *http.Request) {
req.ParseForm()
email, ok := req.Form["email"]
if !ok || len(email) != 1 {
http.Error(w, "must provide one email parameter", http.StatusBadRequest)
return
}
tok, err := s.db.forgot(email[0])
if err := verrors.ToHTTP(err); err != nil {
http.Error(w, err.Message, err.Code)
return
}
log.Printf("http://%s/api/v0/confirm/%+v", req.Host, tok)
fmt.Fprintf(w, "please check your email\n")
}
func (s *Server) pkgs(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-type", "application/json") w.Header().Set("Content-type", "application/json")
json.NewEncoder(w).Encode(&all) json.NewEncoder(w).Encode(s.db.Pkgs())
} }
func addRoutes(sm *http.ServeMux, s *Server) { func addRoutes(sm *http.ServeMux, s *Server) {
sm.Handle("/", s) sm.Handle("/", s)
sm.HandleFunc("/db/", s.db)
sm.HandleFunc(prefix["pkgs"], s.pkgs)
sm.HandleFunc(prefix["register"], s.register)
sm.HandleFunc(prefix["confirm"], s.confirm)
sm.HandleFunc(prefix["forgot"], s.forgot)
} }

18
sql/init.sql Normal file
View File

@ -0,0 +1,18 @@
CREATE TABLE users (
email TEXT PRIMARY KEY,
token TEXT UNIQUE,
registered boolean DEFAULT 0,
requested DATETIME
);
CREATE TABLE namespaces (
ns TEXT PRIMARY KEY,
email TEXT REFERENCES users(email) ON DELETE CASCADE
);
CREATE TABLE packages (
vcs TEXT,
repo TEXT,
path TEXT UNIQUE,
ns TEXT REFERENCES namespaces(ns) ON DELETE CASCADE
);

View File

@ -1,126 +0,0 @@
package vain
import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"sync"
)
// Valid checks that p will not confuse the go tool if added to packages.
func Valid(p string, packages []Package) bool {
for _, pkg := range packages {
if strings.HasPrefix(pkg.path, p) || strings.HasPrefix(p, pkg.path) {
return false
}
}
return true
}
// Storage is a shim to allow for alternate storage types.
type Storage interface {
Contains(name string) bool
Add(p Package) error
Remove(path string) error
All() []Package
}
// SimpleStore implements a simple json on-disk storage.
type SimpleStore struct {
l sync.RWMutex
p map[string]Package
dbl sync.Mutex
path string
}
// NewSimpleStore returns a ready-to-use SimpleStore storing json at path.
func NewSimpleStore(path string) *SimpleStore {
return &SimpleStore{
path: path,
p: make(map[string]Package),
}
}
func (ms *SimpleStore) Contains(name string) bool {
_, contains := ms.p[name]
return contains
}
// Add adds p to the SimpleStore.
func (ss *SimpleStore) Add(p Package) error {
ss.l.Lock()
ss.p[p.path] = p
ss.l.Unlock()
m := ""
if err := ss.Save(); err != nil {
m = fmt.Sprintf("unable to store db: %v", err)
if err := ss.Remove(p.path); err != nil {
m = fmt.Sprintf("%s\nto add insult to injury, could not delete package: %v\n", m, err)
}
return errors.New(m)
}
return nil
}
// Remove removes p from the SimpleStore.
func (ss *SimpleStore) Remove(path string) error {
ss.l.Lock()
delete(ss.p, path)
ss.l.Unlock()
return nil
}
// All returns all current packages.
func (ss *SimpleStore) All() []Package {
r := []Package{}
ss.l.RLock()
for _, p := range ss.p {
r = append(r, p)
}
ss.l.RUnlock()
return r
}
// Save writes the db to disk.
func (ss *SimpleStore) Save() error {
// running in-memory only
if ss.path == "" {
return nil
}
ss.dbl.Lock()
defer ss.dbl.Unlock()
f, err := os.Create(ss.path)
if err != nil {
return err
}
defer f.Close()
return json.NewEncoder(f).Encode(ss.p)
}
// Load reads the db from disk and populates ss.
func (ss *SimpleStore) Load() error {
// running in-memory only
if ss.path == "" {
return nil
}
ss.dbl.Lock()
defer ss.dbl.Unlock()
f, err := os.Open(ss.path)
if err != nil {
return err
}
defer f.Close()
in := map[string]Package{}
if err := json.NewDecoder(f).Decode(&in); err != nil {
return err
}
for k, v := range in {
v.path = k
ss.p[k] = v
}
return nil
}

View File

@ -1,127 +0,0 @@
package vain
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)
func equiv(a, b map[string]Package) (bool, []error) {
equiv := true
errs := []error{}
if got, want := len(a), len(b); got != want {
equiv = false
errs = append(errs, fmt.Errorf("incorrect number of elements: got %d, want %d", got, want))
return false, errs
}
for k := range a {
v, ok := b[k]
if !ok || v != a[k] {
errs = append(errs, fmt.Errorf("missing key: %s", k))
equiv = false
break
}
}
return equiv, errs
}
func TestSimpleStorage(t *testing.T) {
root, err := ioutil.TempDir("", "vain-")
if err != nil {
t.Fatalf("problem creating temp dir: %v", err)
}
defer func() { os.RemoveAll(root) }()
db := filepath.Join(root, "vain.json")
ms := NewSimpleStore(db)
orig := map[string]Package{
"foo": {Vcs: "mercurial"},
"bar": {Vcs: "bzr"},
"baz": {},
}
for k, v := range orig {
v.path = k
orig[k] = v
}
ms.p = orig
if err := ms.Save(); err != nil {
t.Errorf("should have been able to Save: %v", err)
}
ms.p = map[string]Package{}
if err := ms.Load(); err != nil {
t.Errorf("should have been able to Load: %v", err)
}
if ok, errs := equiv(ms.p, orig); !ok {
for _, err := range errs {
t.Error(err)
}
}
}
func TestRemove(t *testing.T) {
root, err := ioutil.TempDir("", "vain-")
if err != nil {
t.Fatalf("problem creating temp dir: %v", err)
}
defer func() { os.RemoveAll(root) }()
db := filepath.Join(root, "vain.json")
ms := NewSimpleStore(db)
ms.p = map[string]Package{
"foo": {},
"bar": {},
"baz": {},
}
if err := ms.Remove("foo"); err != nil {
t.Errorf("unexpected error during remove: %v", err)
}
want := map[string]Package{
"bar": {},
"baz": {},
}
if ok, errs := equiv(ms.p, want); !ok {
for _, err := range errs {
t.Error(err)
}
}
}
func TestPackageJsonParsing(t *testing.T) {
tests := []struct {
input string
output string
parsed Package
}{
{
input: `{"vcs":"git","repo":"https://s.mcquay.me/sm/ud/"}`,
output: `{"vcs":"git","repo":"https://s.mcquay.me/sm/ud/"}`,
parsed: Package{Vcs: "git", Repo: "https://s.mcquay.me/sm/ud/"},
},
{
input: `{"vcs":"hg","repo":"https://s.mcquay.me/sm/ud/"}`,
output: `{"vcs":"hg","repo":"https://s.mcquay.me/sm/ud/"}`,
parsed: Package{Vcs: "hg", Repo: "https://s.mcquay.me/sm/ud/"},
},
}
for _, test := range tests {
p := Package{}
if err := json.NewDecoder(strings.NewReader(test.input)).Decode(&p); err != nil {
t.Error(err)
}
if p != test.parsed {
t.Errorf("got:\n\t%v, want\n\t%v", p, test.parsed)
}
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(&p); err != nil {
t.Error(err)
}
if got, want := strings.TrimSpace(buf.String()), test.output; got != want {
t.Errorf("got %v, want %v", got, want)
}
}
}

33
testing.go Normal file
View File

@ -0,0 +1,33 @@
package vain
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
)
func testDB(t *testing.T) (*DB, func()) {
dir, err := ioutil.TempDir("", "vain-testing-")
if err != nil {
t.Fatalf("could not create tmpdir for db: %v", err)
return nil, func() {}
}
name := filepath.Join(dir, "test.db")
db, err := NewDB(name)
if err != nil {
t.Fatalf("could not create db: %v", err)
return nil, func() {}
}
if err := db.Init(); err != nil {
return nil, func() {}
}
return db, func() {
db.Close()
if err := os.RemoveAll(dir); err != nil {
t.Fatalf("could not clean up tmpdir: %v", err)
}
}
}

45
vain.go
View File

@ -3,7 +3,15 @@
// The executable, cmd/vaind, is located in the respective subdirectory. // The executable, cmd/vaind, is located in the respective subdirectory.
package vain package vain
import "fmt" import (
"bytes"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
var vcss = map[string]bool{ var vcss = map[string]bool{
"hg": true, "hg": true,
@ -29,14 +37,45 @@ type Package struct {
// Repo: the remote repository url // Repo: the remote repository url
Repo string `json:"repo"` Repo string `json:"repo"`
path string Path string `json:"path"`
Ns string `json:"-"`
} }
func (p Package) String() string { func (p Package) String() string {
return fmt.Sprintf( return fmt.Sprintf(
"<meta name=\"go-import\" content=\"%s %s %s\">", "<meta name=\"go-import\" content=\"%s %s %s\">",
p.path, p.Path,
p.Vcs, p.Vcs,
p.Repo, p.Repo,
) )
} }
// Valid checks that p will not confuse the go tool if added to packages.
func Valid(p string, packages []Package) bool {
for _, pkg := range packages {
if strings.HasPrefix(pkg.Path, p) || strings.HasPrefix(p, pkg.Path) {
return false
}
}
return true
}
func parseNamespace(path string) (string, error) {
path = strings.TrimLeft(path, "/")
if path == "" {
return "", errors.New("path does not contain namespace")
}
elems := strings.Split(path, "/")
return elems[0], nil
}
func FreshToken() string {
buf := &bytes.Buffer{}
io.Copy(buf, io.LimitReader(rand.Reader, 6))
s := hex.EncodeToString(buf.Bytes())
r := []string{}
for i := 0; i < len(s)/4; i++ {
r = append(r, s[i*4:(i+1)*4])
}
return strings.Join(r, "-")
}

View File

@ -1,6 +1,7 @@
package vain package vain
import ( import (
"errors"
"fmt" "fmt"
"testing" "testing"
) )
@ -8,7 +9,7 @@ import (
func TestString(t *testing.T) { func TestString(t *testing.T) {
p := Package{ p := Package{
Vcs: "git", Vcs: "git",
path: "mcquay.me/bps", Path: "mcquay.me/bps",
Repo: "https://s.mcquay.me/sm/bps", Repo: "https://s.mcquay.me/sm/bps",
} }
got := fmt.Sprintf("%s", p) got := fmt.Sprintf("%s", p)
@ -55,59 +56,59 @@ func TestValid(t *testing.T) {
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "bobo"}, {Path: "bobo"},
}, },
in: "bobo", in: "bobo",
want: false, want: false,
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "a/b/c"}, {Path: "a/b/c"},
}, },
in: "a/b/c", in: "a/b/c",
want: false, want: false,
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "a/b/c"}, {Path: "a/b/c"},
}, },
in: "a/b", in: "a/b",
want: false, want: false,
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "name/db"}, {Path: "name/db"},
{path: "name/lib"}, {Path: "name/lib"},
}, },
in: "name/foo", in: "name/foo",
want: true, want: true,
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "a"}, {Path: "a"},
}, },
in: "a/b", in: "a/b",
want: false, want: false,
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "foo"}, {Path: "foo"},
}, },
in: "foo/bar", in: "foo/bar",
want: false, want: false,
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "foo/bar"}, {Path: "foo/bar"},
{path: "foo/baz"}, {Path: "foo/baz"},
}, },
in: "foo", in: "foo",
want: false, want: false,
}, },
{ {
pkgs: []Package{ pkgs: []Package{
{path: "bilbo"}, {Path: "bilbo"},
{path: "frodo"}, {Path: "frodo"},
}, },
in: "foo/bar/baz", in: "foo/bar/baz",
want: true, want: true,
@ -120,3 +121,51 @@ func TestValid(t *testing.T) {
} }
} }
} }
func TestNamespaceParsing(t *testing.T) {
tests := []struct {
input string
want string
err error
}{
{
input: "/sm/foo",
want: "sm",
},
{
input: "/a/b/c/d",
want: "a",
},
{
input: "/dm/bar",
want: "dm",
},
{
input: "/ud",
want: "ud",
},
// test stripping
{
input: "ud",
want: "ud",
},
{
input: "/",
err: errors.New("should find no namespace"),
},
{
input: "",
err: errors.New("should find no namespace"),
},
}
for _, test := range tests {
got, err := parseNamespace(test.input)
if err != nil && test.err == nil {
t.Errorf("unexpected error parsing %q; got %q, want %q, error: %v", test.input, got, test.want, err)
}
if got != test.want {
t.Errorf("parse failure: got %q, want %q", got, test.want)
}
}
}