From 680eecb11144fbdac9e9a764fbe09cbc7e0da27b Mon Sep 17 00:00:00 2001 From: stephen mcquay Date: Mon, 11 Apr 2016 20:43:18 -0700 Subject: [PATCH] Added simple user auth Fixes #14. Change-Id: I748933214f43ac7298f1e93c14bb0ee881976d43 --- .gitignore | 1 + api_test.go | 475 +++++++++++++++++++++++++++++++--------------- cmd/vaind/main.go | 51 +++-- db.go | 268 ++++++++++++++++++++++++++ errors/errors.go | 30 +++ gen.go | 6 + server.go | 127 +++++++++++-- sql/init.sql | 18 ++ storage.go | 126 ------------ storage_test.go | 127 ------------- testing.go | 33 ++++ vain.go | 45 ++++- vain_test.go | 73 +++++-- 13 files changed, 928 insertions(+), 452 deletions(-) create mode 100644 .gitignore create mode 100644 db.go create mode 100644 errors/errors.go create mode 100644 gen.go create mode 100644 sql/init.sql delete mode 100644 storage.go delete mode 100644 storage_test.go create mode 100644 testing.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..98e6ef6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.db diff --git a/api_test.go b/api_test.go index 0903d1f..eab5906 100644 --- a/api_test.go +++ b/api_test.go @@ -13,31 +13,55 @@ import ( ) func TestAdd(t *testing.T) { - ms := NewSimpleStore("") + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + sm := http.NewServeMux() - _ = NewServer(sm, ms) + NewServer(sm, db) ts := httptest.NewServer(sm) + tok, err := db.addUser("sm@example.org") + if err != nil { + t.Error("failure to add user: %v", err) + } + resp, err := http.Get(ts.URL) if err != nil { t.Errorf("couldn't GET: %v", err) } resp.Body.Close() - if len(ms.p) != 0 { - t.Errorf("started with something in it; got %d, want %d", len(ms.p), 0) - } - bad := ts.URL - resp, err = http.Post(bad, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) - if err != nil { - t.Errorf("couldn't POST: %v", err) - } - resp.Body.Close() - if len(ms.p) != 0 { - t.Errorf("started with something in it; got %d, want %d", len(ms.p), 0) + if got, want := len(db.Pkgs()), 0; got != want { + t.Errorf("started with something in it; got %d, want %d", got, want) } { - u := fmt.Sprintf("%s/db/", ts.URL) + bad := ts.URL + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", bad, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("couldn't POST: %v", err) + } + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + buf := &bytes.Buffer{} + io.Copy(buf, resp.Body) + t.Errorf("bad request got incorrect status: got %d, want %d", got, want) + t.Log("%s", buf) + } + resp.Body.Close() + + if got, want := len(db.Pkgs()), 0; got != want { + t.Errorf("started with something in it; got %d, want %d", got, want) + } + } + + { + u := fmt.Sprintf("%s/%s", ts.URL, prefix["pkgs"]) resp, err := http.Get(u) if err != nil { t.Error(err) @@ -53,33 +77,51 @@ func TestAdd(t *testing.T) { } } - u := fmt.Sprintf("%s/foo", ts.URL) - resp, err = http.Post(u, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) - if err != nil { - t.Errorf("couldn't POST: %v", err) - } + { - if len(ms.p) != 1 { - t.Errorf("storage should have something in it; got %d, want %d", len(ms.p), 1) - } + u := fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("problem performing request: %v", err) + } + buf := &bytes.Buffer{} + io.Copy(buf, resp.Body) + t.Logf("%v", buf) + resp.Body.Close() - ur, err := url.Parse(ts.URL) - if err != nil { - t.Error(err) - } - good := fmt.Sprintf("%s/foo", ur.Host) - p, ok := ms.p[good] - if !ok { - t.Fatalf("did not find package for %s; should have posted a valid package", good) - } - if p.path != good { - t.Errorf("package name did not go through as expected; got %q, want %q", p.path, good) - } - if want := "https://s.mcquay.me/sm/vain"; p.Repo != want { - t.Errorf("repo did not go through as expected; got %q, want %q", p.Repo, want) - } - if got, want := p.Vcs, "git"; got != want { - t.Errorf("Vcs did not go through as expected; got %q, want %q", got, want) + if got, want := len(db.Pkgs()), 1; got != want { + t.Errorf("pkgs should have something in it; got %d, want %d", got, want) + } + t.Logf("packages: %v", db.Pkgs()) + + ur, err := url.Parse(ts.URL) + if err != nil { + t.Error(err) + } + + good := fmt.Sprintf("%s/foo", ur.Host) + + if !db.PackageExists(good) { + t.Fatalf("did not find package for %s; should have posted a valid package", good) + } + p, err := db.Package(good) + t.Logf("%+v", p) + if err != nil { + t.Fatalf("problem getting package: %v", err) + } + if got, want := p.Path, good; got != want { + t.Errorf("package name did not go through as expected; got %q, want %q", got, want) + } + if got, want := p.Repo, "https://s.mcquay.me/sm/vain"; got != want { + t.Errorf("repo did not go through as expected; got %q, want %q", got, want) + } + if got, want := p.Vcs, "git"; got != want { + t.Errorf("Vcs did not go through as expected; got %q, want %q", got, want) + } } resp, err = http.Get(ts.URL) @@ -99,7 +141,7 @@ func TestAdd(t *testing.T) { } { - u := fmt.Sprintf("%s/db/", ts.URL) + u := fmt.Sprintf("%s/%s", ts.URL, prefix["pkgs"]) resp, err := http.Get(u) if err != nil { t.Error(err) @@ -117,121 +159,248 @@ func TestAdd(t *testing.T) { } func TestInvalidPath(t *testing.T) { - ms := NewSimpleStore("") - s := &Server{ - storage: ms, + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - ts := httptest.NewServer(s) + defer done() - resp, err := http.Post(ts.URL, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + tok, err := db.addUser("sm@example.org") + if err != nil { + t.Error("failure to add user: %v", err) + } + + bad := ts.URL + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", bad, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("couldn't POST: %v", err) } - if len(ms.p) != 0 { - t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) + if len(db.Pkgs()) != 0 { + t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0) } - if want := http.StatusBadRequest; resp.StatusCode != want { - t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + t.Errorf("should have failed to post at bad route; got %s, want %s", http.StatusText(got), http.StatusText(want)) } } func TestCannotDuplicateExistingPath(t *testing.T) { - ms := NewSimpleStore("") - s := &Server{ - storage: ms, + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - ts := httptest.NewServer(s) + defer done() - url := fmt.Sprintf("%s/foo", ts.URL) - resp, err := http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + + tok, err := db.addUser("sm@example.org") if err != nil { - t.Errorf("couldn't POST: %v", err) + t.Error("failure to add user: %v", err) } - if want := http.StatusOK; resp.StatusCode != want { - t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + + u := fmt.Sprintf("%s/foo", ts.URL) + { + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("couldn't POST: %v", err) + } + if want := http.StatusOK; resp.StatusCode != want { + t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + } } - resp, err = http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) - if err != nil { - t.Errorf("couldn't POST: %v", err) - } - if want := http.StatusConflict; resp.StatusCode != want { - t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + + { + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("couldn't POST: %v", err) + } + if want := http.StatusConflict; resp.StatusCode != want { + t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + } } } func TestCannotAddExistingSubPath(t *testing.T) { - ms := NewSimpleStore("") - s := &Server{ - storage: ms, + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - ts := httptest.NewServer(s) + defer done() - url := fmt.Sprintf("%s/foo/bar", ts.URL) - resp, err := http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + + tok, err := db.addUser("sm@example.org") if err != nil { - t.Errorf("couldn't POST: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + t.Error("failure to add user: %v", err) } - url = fmt.Sprintf("%s/foo", ts.URL) - resp, err = http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) - resp, err = http.Post(url, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) - if err != nil { - t.Errorf("couldn't POST: %v", err) + { + u := fmt.Sprintf("%s/foo/bar", ts.URL) + t.Logf("url: %v", u) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("couldn't POST: %v", err) + } + if want := http.StatusOK; resp.StatusCode != want { + t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + } } - if want := http.StatusConflict; resp.StatusCode != want { - t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + + { + u := fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("couldn't POST: %v", err) + } + if want := http.StatusConflict; resp.StatusCode != want { + t.Errorf("initial post should have worked; got %s, want %s", resp.Status, http.StatusText(want)) + } } } func TestMissingRepo(t *testing.T) { - ms := NewSimpleStore("") - s := &Server{ - storage: ms, + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - ts := httptest.NewServer(s) - url := fmt.Sprintf("%s/foo", ts.URL) - resp, err := http.Post(url, "application/json", strings.NewReader(`{}`)) + defer done() + + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + + tok, err := db.addUser("sm@example.org") + if err != nil { + t.Error("failure to add user: %v", err) + } + + u := fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("couldn't POST: %v", err) } - if len(ms.p) != 0 { - t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) + if len(db.Pkgs()) != 0 { + t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0) } if want := http.StatusBadRequest; resp.StatusCode != want { - t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) + t.Errorf("should have failed to post with bad payload; got %s, want %s", resp.Status, http.StatusText(want)) } } func TestBadJson(t *testing.T) { - ms := NewSimpleStore("") - s := &Server{ - storage: ms, + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - ts := httptest.NewServer(s) - url := fmt.Sprintf("%s/foo", ts.URL) - resp, err := http.Post(url, "application/json", strings.NewReader(`{`)) + defer done() + + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + + tok, err := db.addUser("sm@example.org") + if err != nil { + t.Error("failure to add user: %v", err) + } + + u := fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("couldn't POST: %v", err) } - if len(ms.p) != 0 { - t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) + if len(db.Pkgs()) != 0 { + t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0) } if want := http.StatusBadRequest; resp.StatusCode != want { t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) } } -func TestBadVcs(t *testing.T) { - ms := NewSimpleStore("") - s := &Server{ - storage: ms, +func TestNoAuth(t *testing.T) { + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - ts := httptest.NewServer(s) - url := fmt.Sprintf("%s/foo", ts.URL) - resp, err := http.Post(url, "application/json", strings.NewReader(`{"vcs": "bitbucket", "repo": "https://s.mcquay.me/sm/vain"}`)) + defer done() + + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + + u := fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + + // here we don't set the Authorization header + // req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("couldn't POST: %v", err) + } + resp.Body.Close() + if got, want := resp.StatusCode, http.StatusUnauthorized; got != want { + t.Errorf("posted with missing auth; got %v, want %v", http.StatusText(got), http.StatusText(want)) + } +} + +func TestBadVcs(t *testing.T) { + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + + tok, err := db.addUser("sm@example.org") + if err != nil { + t.Error("failure to add user: %v", err) + } + + u := fmt.Sprintf("%s/foo", ts.URL) + body := strings.NewReader(`{"vcs": "bitbucket", "repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("couldn't POST: %v", err) } @@ -242,40 +411,31 @@ func TestBadVcs(t *testing.T) { } func TestUnsupportedMethod(t *testing.T) { - ms := NewSimpleStore("") - s := &Server{ - storage: ms, + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - ts := httptest.NewServer(s) - url := fmt.Sprintf("%s/foo", ts.URL) - client := &http.Client{} - req, err := http.NewRequest("PUT", url, nil) - resp, err := client.Do(req) - if err != nil { - t.Errorf("couldn't POST: %v", err) - } - if len(ms.p) != 0 { - t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) - } - if want := http.StatusMethodNotAllowed; resp.StatusCode != want { - t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) - } -} + defer done() -func TestNewServer(t *testing.T) { - ms := NewSimpleStore("") sm := http.NewServeMux() - s := NewServer(sm, ms) - ts := httptest.NewServer(s) + NewServer(sm, db) + ts := httptest.NewServer(sm) + + tok, err := db.addUser("sm@example.org") + if err != nil { + t.Error("failure to add user: %v", err) + } + url := fmt.Sprintf("%s/foo", ts.URL) client := &http.Client{} req, err := http.NewRequest("PUT", url, nil) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) resp, err := client.Do(req) if err != nil { t.Errorf("couldn't POST: %v", err) } - if len(ms.p) != 0 { - t.Errorf("should have failed to insert; got %d, want %d", len(ms.p), 0) + if len(db.Pkgs()) != 0 { + t.Errorf("should have failed to insert; got %d, want %d", len(db.Pkgs()), 0) } if want := http.StatusMethodNotAllowed; resp.StatusCode != want { t.Errorf("should have failed to post at bad route; got %s, want %s", resp.Status, http.StatusText(want)) @@ -283,27 +443,37 @@ func TestNewServer(t *testing.T) { } func TestDelete(t *testing.T) { - ms := NewSimpleStore("") - sm := http.NewServeMux() - _ = NewServer(sm, ms) - ts := httptest.NewServer(sm) - resp, err := http.Get(ts.URL) - if err != nil { - t.Errorf("couldn't GET: %v", err) + db, done := testDB(t) + if db == nil { + t.Fatalf("could not create temp db") } - resp.Body.Close() - if len(ms.p) != 0 { - t.Errorf("started with something in it; got %d, want %d", len(ms.p), 0) + defer done() + + sm := http.NewServeMux() + NewServer(sm, db) + ts := httptest.NewServer(sm) + + tok, err := db.addUser("sm@example.org") + if err != nil { + t.Error("failure to add user: %v", err) + } + t.Logf("%v", tok) + if len(db.Pkgs()) != 0 { + t.Errorf("started with something in it; got %d, want %d", len(db.Pkgs()), 0) } u := fmt.Sprintf("%s/foo", ts.URL) - resp, err = http.Post(u, "application/json", strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`)) + body := strings.NewReader(`{"repo": "https://s.mcquay.me/sm/vain"}`) + req, err := http.NewRequest("POST", u, body) + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("couldn't POST: %v", err) } - if got, want := len(ms.p), 1; got != want { - t.Errorf("storage should have something in it; got %d, want %d", got, want) + if got, want := len(db.Pkgs()), 1; got != want { + t.Errorf("pkgs should have something in it; got %d, want %d", got, want) } { @@ -311,6 +481,7 @@ func TestDelete(t *testing.T) { u := fmt.Sprintf("%s/bar", ts.URL) client := &http.Client{} req, err := http.NewRequest("DELETE", u, nil) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) resp, err = client.Do(req) if err != nil { t.Errorf("couldn't POST: %v", err) @@ -319,14 +490,18 @@ func TestDelete(t *testing.T) { t.Errorf("should have not been able to delete unknown package; got %v, want %v", http.StatusText(got), http.StatusText(want)) } } - client := &http.Client{} - req, err := http.NewRequest("DELETE", u, nil) - resp, err = client.Do(req) - if err != nil { - t.Errorf("couldn't POST: %v", err) - } - if got, want := len(ms.p), 0; got != want { - t.Errorf("storage should be empty; got %d, want %d", got, want) + { + client := &http.Client{} + req, err := http.NewRequest("DELETE", u, nil) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) + resp, err = client.Do(req) + if err != nil { + t.Errorf("couldn't POST: %v", err) + } + + if got, want := len(db.Pkgs()), 0; got != want { + t.Errorf("pkgs should be empty; got %d, want %d", got, want) + } } } diff --git a/cmd/vaind/main.go b/cmd/vaind/main.go index 91d0330..7e5baca 100644 --- a/cmd/vaind/main.go +++ b/cmd/vaind/main.go @@ -51,21 +51,45 @@ import ( "github.com/kelseyhightower/envconfig" ) -const usage = `vaind - -environment vars: - -VAIN_PORT: tcp listen port -VAIN_HOST: hostname to use -VAIN_DB: path to json database -` +const usage = "vaind [init] " type config struct { Port int - DB string } func main() { + if len(os.Args) < 2 { + fmt.Fprintf(os.Stderr, "%s\n", usage) + os.Exit(1) + } + + if os.Args[1] == "init" { + if len(os.Args) != 3 { + fmt.Fprintf(os.Stderr, "missing db name: %s\n", usage) + os.Exit(1) + } + + db, err := vain.NewDB(os.Args[2]) + if err != nil { + fmt.Fprintf(os.Stderr, "couldn't open db: %v\n", err) + os.Exit(1) + } + defer db.Close() + + if err := db.Init(); err != nil { + fmt.Fprintf(os.Stderr, "problem initializing the db: %v\n", err) + os.Exit(1) + } + + os.Exit(0) + } + + db, err := vain.NewDB(os.Args[1]) + if err != nil { + fmt.Fprintf(os.Stderr, "couldn't open db: %v\n", err) + os.Exit(1) + } + c := &config{ Port: 4040, } @@ -80,9 +104,6 @@ func main() { os.Exit(0) } } - if c.DB == "" { - log.Printf("warning: in-memory db mode; if you do not want this set VAIN_DB") - } hostname := "localhost" if hn, err := os.Hostname(); err != nil { log.Printf("problem getting hostname:", err) @@ -91,11 +112,7 @@ func main() { } log.Printf("serving at: http://%s:%d/", hostname, c.Port) sm := http.NewServeMux() - ms := vain.NewSimpleStore(c.DB) - if err := ms.Load(); err != nil { - log.Printf("unable to load db: %v; creating fresh database", err) - } - vain.NewServer(sm, ms) + vain.NewServer(sm, db) addr := fmt.Sprintf(":%d", c.Port) if err := http.ListenAndServe(addr, sm); err != nil { log.Printf("problem with http server: %v", err) diff --git a/db.go b/db.go new file mode 100644 index 0000000..f1c5fbc --- /dev/null +++ b/db.go @@ -0,0 +1,268 @@ +package vain + +import ( + "fmt" + "log" + "net/http" + "time" + + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" + + verrors "mcquay.me/vain/errors" + vsql "mcquay.me/vain/sql" +) + +type DB struct { + conn *sqlx.DB +} + +func NewDB(path string) (*DB, error) { + conn, err := sqlx.Open("sqlite3", fmt.Sprintf("file:%s?cache=shared&mode=rwc", path)) + if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil { + return nil, err + } + return &DB{conn}, err +} + +func (db *DB) Init() error { + content, err := vsql.Asset("sql/init.sql") + if err != nil { + return err + } + _, err = db.conn.Exec(string(content)) + return err +} + +func (db *DB) Close() error { + return db.conn.Close() +} + +func (db *DB) AddPackage(p Package) error { + _, err := db.conn.NamedExec( + "INSERT INTO packages(vcs, repo, path, ns) VALUES (:vcs, :repo, :path, :ns)", + &p, + ) + return err +} + +func (db *DB) RemovePackage(path string) error { + _, err := db.conn.Exec("DELETE FROM packages WHERE path = ?", path) + return err +} + +func (db *DB) Pkgs() []Package { + r := []Package{} + rows, err := db.conn.Queryx("SELECT * FROM packages") + if err != nil { + log.Printf("%+v", err) + return nil + } + for rows.Next() { + var p Package + err = rows.StructScan(&p) + if err != nil { + log.Printf("%+v", err) + return nil + } + r = append(r, p) + } + return r +} + +func (db *DB) PackageExists(path string) bool { + var count int + if err := db.conn.Get(&count, "SELECT COUNT(*) FROM packages WHERE path = ?", path); err != nil { + log.Printf("%+v", err) + } + + r := false + switch count { + case 1: + r = true + default: + log.Printf("unexpected count of packages matching %q: %d", path, count) + } + return r +} + +func (db *DB) Package(path string) (Package, error) { + r := Package{} + err := db.conn.Get(&r, "SELECT * FROM packages WHERE path = ?", path) + return r, err +} + +func (db *DB) NSForToken(ns string, tok string) error { + var err error + txn, err := db.conn.Beginx() + if err != nil { + return verrors.HTTP{ + Message: fmt.Sprintf("problem creating transaction: %v", err), + Code: http.StatusInternalServerError, + } + } + defer func() { + if err != nil { + txn.Rollback() + } else { + txn.Commit() + } + }() + + var count int + if err = txn.Get(&count, "SELECT COUNT(*) FROM namespaces WHERE namespaces.ns = ?", ns); err != nil { + return verrors.HTTP{ + Message: fmt.Sprintf("problem matching fetching namespaces matching %q", ns), + Code: http.StatusInternalServerError, + } + } + + if count == 0 { + if _, err = txn.Exec( + "INSERT INTO namespaces(ns, email) SELECT ?, users.email FROM users WHERE users.token = ?", + ns, + tok, + ); err != nil { + return verrors.HTTP{ + Message: fmt.Sprintf("problem inserting %q into namespaces for token %q: %v", ns, tok, err), + Code: http.StatusInternalServerError, + } + } + return err + } + + if err = txn.Get(&count, "SELECT COUNT(*) FROM namespaces JOIN users ON namespaces.email = users.email WHERE users.token = ? AND namespaces.ns = ?", tok, ns); err != nil { + return verrors.HTTP{ + Message: fmt.Sprintf("ns: %q, tok: %q; %v", ns, tok, err), + Code: http.StatusInternalServerError, + } + } + + switch count { + case 1: + err = nil + case 0: + err = verrors.HTTP{ + Message: fmt.Sprintf("not authorized against namespace %q", ns), + Code: http.StatusUnauthorized, + } + default: + err = verrors.HTTP{ + Message: fmt.Sprintf("inconsistent db; found %d results with ns (%s) with token (%s): %d", count, ns, tok), + Code: http.StatusInternalServerError, + } + } + return err +} + +func (db *DB) Register(email string) (string, error) { + var err error + txn, err := db.conn.Beginx() + if err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("problem creating transaction: %v", err), + Code: http.StatusInternalServerError, + } + } + defer func() { + if err != nil { + txn.Rollback() + } else { + txn.Commit() + } + }() + + var count int + if err = txn.Get(&count, "SELECT COUNT(*) FROM users WHERE email = ?", email); err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("could not search for email %q in db: %v", email, err), + Code: http.StatusInternalServerError, + } + } + + if count != 0 { + return "", verrors.HTTP{ + Message: fmt.Sprintf("duplicate email %q", email), + Code: http.StatusConflict, + } + } + + tok := FreshToken() + _, err = txn.Exec( + "INSERT INTO users(email, token, requested) VALUES (?, ?, ?)", + email, + tok, + time.Now(), + ) + return tok, err +} + +func (db *DB) Confirm(token string) (string, error) { + var err error + txn, err := db.conn.Beginx() + if err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("problem creating transaction: %v", err), + Code: http.StatusInternalServerError, + } + } + defer func() { + if err != nil { + txn.Rollback() + } else { + txn.Commit() + } + }() + + var count int + if err = txn.Get(&count, "SELECT COUNT(*) FROM users WHERE token = ?", token); err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("could not perform search for user with token %q in db: %v", token, err), + Code: http.StatusInternalServerError, + } + } + + if count != 1 { + return "", verrors.HTTP{ + Message: fmt.Sprintf("bad token: %s", token), + Code: http.StatusNotFound, + } + } + + newToken := FreshToken() + + _, err = txn.Exec( + "UPDATE users SET token = ?, registered = 1 WHERE token = ?", + newToken, + token, + ) + if err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("couldn't update user with token %q", token), + Code: http.StatusInternalServerError, + } + } + return newToken, nil +} + +func (db *DB) forgot(email string) (string, error) { + var token string + if err := db.conn.Get(&token, "SELECT token FROM users WHERE email = ?", email); err != nil { + return "", verrors.HTTP{ + Message: fmt.Sprintf("could not search for email %q in db: %v", email, err), + Code: http.StatusInternalServerError, + } + } + return token, nil +} + +func (db *DB) addUser(email string) (string, error) { + tok := FreshToken() + _, err := db.conn.Exec( + "INSERT INTO users(email, token, requested) VALUES (?, ?, ?)", + email, + tok, + time.Now(), + ) + return tok, err +} diff --git a/errors/errors.go b/errors/errors.go new file mode 100644 index 0000000..b171165 --- /dev/null +++ b/errors/errors.go @@ -0,0 +1,30 @@ +package errors + +import ( + "fmt" + "net/http" +) + +type HTTP struct { + error + Message string + Code int +} + +func (e HTTP) Error() string { + return fmt.Sprintf("%d: %s", e.Code, e.Message) +} + +func ToHTTP(err error) *HTTP { + if err == nil { + return nil + } + rerr := &HTTP{ + Message: err.Error(), + Code: http.StatusInternalServerError, + } + if e, ok := err.(HTTP); ok { + rerr.Code = e.Code + } + return rerr +} diff --git a/gen.go b/gen.go new file mode 100644 index 0000000..0f186d7 --- /dev/null +++ b/gen.go @@ -0,0 +1,6 @@ +package vain + +//go:generate go get github.com/jteeuwen/go-bindata/... +//go:generate go get github.com/elazarl/go-bindata-assetfs/... +//go:generate rm -f sql/static.go +//go:generate go-bindata -pkg sql -o sql/static.go sql/... diff --git a/server.go b/server.go index de58100..a87d2c9 100644 --- a/server.go +++ b/server.go @@ -3,14 +3,30 @@ package vain import ( "encoding/json" "fmt" + "log" "net/http" "strings" + + verrors "mcquay.me/vain/errors" ) +const apiPrefix = "/api/v0/" + +var prefix map[string]string + +func init() { + prefix = map[string]string{ + "pkgs": apiPrefix + "db/", + "register": apiPrefix + "register/", + "confirm": apiPrefix + "confirm/", + "forgot": apiPrefix + "forgot/", + } +} + // NewServer populates a server, adds the routes, and returns it for use. -func NewServer(sm *http.ServeMux, store Storage) *Server { +func NewServer(sm *http.ServeMux, store *DB) *Server { s := &Server{ - storage: store, + db: store, } addRoutes(sm, s) return s @@ -18,17 +34,43 @@ func NewServer(sm *http.ServeMux, store Storage) *Server { // Server serves up the http. type Server struct { - storage Storage + db *DB } func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { - switch req.Method { - case "GET": + if req.Method == "GET" { + // TODO: perhaps have a nicely formatted page with info as root if + // go-get=1 not in request? fmt.Fprintf(w, "\n\n") - for _, p := range s.storage.All() { + for _, p := range s.db.Pkgs() { fmt.Fprintf(w, "%s\n", p) } - fmt.Fprintf(w, "\n

go tool metadata in head

\n\n") + fmt.Fprintf(w, "\n

go tool metadata in head

\n\n") + return + } + + const prefix = "Bearer " + var tok string + auth := req.Header.Get("Authorization") + if strings.HasPrefix(auth, prefix) { + tok = strings.TrimPrefix(auth, prefix) + } + if tok == "" { + http.Error(w, "missing token", http.StatusUnauthorized) + return + } + ns, err := parseNamespace(req.URL.Path) + if err != nil { + http.Error(w, fmt.Sprintf("could not parse namespace:%v", err), http.StatusBadRequest) + return + } + + if err := verrors.ToHTTP(s.db.NSForToken(ns, tok)); err != nil { + http.Error(w, err.Message, err.Code) + return + } + + switch req.Method { case "POST": if req.URL.Path == "/" { http.Error(w, fmt.Sprintf("invalid path %q", req.URL.Path), http.StatusBadRequest) @@ -40,7 +82,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } if p.Repo == "" { - http.Error(w, fmt.Sprintf("invalid repository %q", req.URL.Path), http.StatusBadRequest) + http.Error(w, fmt.Sprintf("invalid repository %q", p.Repo), http.StatusBadRequest) return } if p.Vcs == "" { @@ -50,22 +92,24 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Error(w, fmt.Sprintf("invalid vcs %q", p.Vcs), http.StatusBadRequest) return } - p.path = fmt.Sprintf("%s/%s", req.Host, strings.Trim(req.URL.Path, "/")) - if !Valid(p.path, s.storage.All()) { + p.Path = fmt.Sprintf("%s/%s", req.Host, strings.Trim(req.URL.Path, "/")) + p.Ns = ns + if !Valid(p.Path, s.db.Pkgs()) { http.Error(w, fmt.Sprintf("invalid path; prefix already taken %q", req.URL.Path), http.StatusConflict) return } - if err := s.storage.Add(p); err != nil { + if err := s.db.AddPackage(p); err != nil { http.Error(w, fmt.Sprintf("unable to add package: %v", err), http.StatusInternalServerError) return } case "DELETE": p := fmt.Sprintf("%s/%s", req.Host, strings.Trim(req.URL.Path, "/")) - if !s.storage.Contains(p) { + if !s.db.PackageExists(p) { http.Error(w, fmt.Sprintf("package %q not found", p), http.StatusNotFound) return } - if err := s.storage.Remove(p); err != nil { + + if err := s.db.RemovePackage(p); err != nil { http.Error(w, fmt.Sprintf("unable to delete package: %v", err), http.StatusInternalServerError) return } @@ -74,13 +118,62 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } -func (s *Server) db(w http.ResponseWriter, req *http.Request) { - all := s.storage.All() +func (s *Server) register(w http.ResponseWriter, req *http.Request) { + req.ParseForm() + email, ok := req.Form["email"] + if !ok || len(email) != 1 { + http.Error(w, "must provide one email parameter", http.StatusBadRequest) + return + } + tok, err := s.db.Register(email[0]) + if err := verrors.ToHTTP(err); err != nil { + http.Error(w, err.Message, err.Code) + return + } + log.Printf("http://%s/api/v0/confirm/%+v", req.Host, tok) + fmt.Fprintf(w, "please check your email\n") +} + +func (s *Server) confirm(w http.ResponseWriter, req *http.Request) { + tok := req.URL.Path[len(prefix["confirm"]):] + tok = strings.TrimRight(tok, "/") + if tok == "" { + http.Error(w, "must provide one email parameter", http.StatusBadRequest) + return + } + tok, err := s.db.Confirm(tok) + if err := verrors.ToHTTP(err); err != nil { + http.Error(w, err.Message, err.Code) + return + } + fmt.Fprintf(w, "new token: %s\n", tok) +} + +func (s *Server) forgot(w http.ResponseWriter, req *http.Request) { + req.ParseForm() + email, ok := req.Form["email"] + if !ok || len(email) != 1 { + http.Error(w, "must provide one email parameter", http.StatusBadRequest) + return + } + tok, err := s.db.forgot(email[0]) + if err := verrors.ToHTTP(err); err != nil { + http.Error(w, err.Message, err.Code) + return + } + log.Printf("http://%s/api/v0/confirm/%+v", req.Host, tok) + fmt.Fprintf(w, "please check your email\n") +} +func (s *Server) pkgs(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-type", "application/json") - json.NewEncoder(w).Encode(&all) + json.NewEncoder(w).Encode(s.db.Pkgs()) } func addRoutes(sm *http.ServeMux, s *Server) { sm.Handle("/", s) - sm.HandleFunc("/db/", s.db) + + sm.HandleFunc(prefix["pkgs"], s.pkgs) + sm.HandleFunc(prefix["register"], s.register) + sm.HandleFunc(prefix["confirm"], s.confirm) + sm.HandleFunc(prefix["forgot"], s.forgot) } diff --git a/sql/init.sql b/sql/init.sql new file mode 100644 index 0000000..8687c01 --- /dev/null +++ b/sql/init.sql @@ -0,0 +1,18 @@ +CREATE TABLE users ( + email TEXT PRIMARY KEY, + token TEXT UNIQUE, + registered boolean DEFAULT 0, + requested DATETIME +); + +CREATE TABLE namespaces ( + ns TEXT PRIMARY KEY, + email TEXT REFERENCES users(email) ON DELETE CASCADE +); + +CREATE TABLE packages ( + vcs TEXT, + repo TEXT, + path TEXT UNIQUE, + ns TEXT REFERENCES namespaces(ns) ON DELETE CASCADE +); diff --git a/storage.go b/storage.go deleted file mode 100644 index aa2840d..0000000 --- a/storage.go +++ /dev/null @@ -1,126 +0,0 @@ -package vain - -import ( - "encoding/json" - "errors" - "fmt" - "os" - "strings" - "sync" -) - -// Valid checks that p will not confuse the go tool if added to packages. -func Valid(p string, packages []Package) bool { - for _, pkg := range packages { - if strings.HasPrefix(pkg.path, p) || strings.HasPrefix(p, pkg.path) { - return false - } - } - return true -} - -// Storage is a shim to allow for alternate storage types. -type Storage interface { - Contains(name string) bool - Add(p Package) error - Remove(path string) error - All() []Package -} - -// SimpleStore implements a simple json on-disk storage. -type SimpleStore struct { - l sync.RWMutex - p map[string]Package - - dbl sync.Mutex - path string -} - -// NewSimpleStore returns a ready-to-use SimpleStore storing json at path. -func NewSimpleStore(path string) *SimpleStore { - return &SimpleStore{ - path: path, - p: make(map[string]Package), - } -} - -func (ms *SimpleStore) Contains(name string) bool { - _, contains := ms.p[name] - return contains -} - -// Add adds p to the SimpleStore. -func (ss *SimpleStore) Add(p Package) error { - ss.l.Lock() - ss.p[p.path] = p - ss.l.Unlock() - m := "" - if err := ss.Save(); err != nil { - m = fmt.Sprintf("unable to store db: %v", err) - if err := ss.Remove(p.path); err != nil { - m = fmt.Sprintf("%s\nto add insult to injury, could not delete package: %v\n", m, err) - } - return errors.New(m) - } - return nil -} - -// Remove removes p from the SimpleStore. -func (ss *SimpleStore) Remove(path string) error { - ss.l.Lock() - delete(ss.p, path) - ss.l.Unlock() - return nil -} - -// All returns all current packages. -func (ss *SimpleStore) All() []Package { - r := []Package{} - ss.l.RLock() - for _, p := range ss.p { - r = append(r, p) - } - ss.l.RUnlock() - return r -} - -// Save writes the db to disk. -func (ss *SimpleStore) Save() error { - // running in-memory only - if ss.path == "" { - return nil - } - ss.dbl.Lock() - defer ss.dbl.Unlock() - f, err := os.Create(ss.path) - if err != nil { - return err - } - defer f.Close() - return json.NewEncoder(f).Encode(ss.p) -} - -// Load reads the db from disk and populates ss. -func (ss *SimpleStore) Load() error { - // running in-memory only - if ss.path == "" { - return nil - } - ss.dbl.Lock() - defer ss.dbl.Unlock() - f, err := os.Open(ss.path) - if err != nil { - return err - } - defer f.Close() - - in := map[string]Package{} - if err := json.NewDecoder(f).Decode(&in); err != nil { - return err - } - for k, v := range in { - v.path = k - ss.p[k] = v - } - return nil -} diff --git a/storage_test.go b/storage_test.go deleted file mode 100644 index 25beaef..0000000 --- a/storage_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package vain - -import ( - "bytes" - "encoding/json" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "strings" - "testing" -) - -func equiv(a, b map[string]Package) (bool, []error) { - equiv := true - errs := []error{} - if got, want := len(a), len(b); got != want { - equiv = false - errs = append(errs, fmt.Errorf("incorrect number of elements: got %d, want %d", got, want)) - return false, errs - } - - for k := range a { - v, ok := b[k] - if !ok || v != a[k] { - errs = append(errs, fmt.Errorf("missing key: %s", k)) - equiv = false - break - } - } - return equiv, errs -} - -func TestSimpleStorage(t *testing.T) { - root, err := ioutil.TempDir("", "vain-") - if err != nil { - t.Fatalf("problem creating temp dir: %v", err) - } - defer func() { os.RemoveAll(root) }() - db := filepath.Join(root, "vain.json") - ms := NewSimpleStore(db) - orig := map[string]Package{ - "foo": {Vcs: "mercurial"}, - "bar": {Vcs: "bzr"}, - "baz": {}, - } - for k, v := range orig { - v.path = k - orig[k] = v - } - ms.p = orig - if err := ms.Save(); err != nil { - t.Errorf("should have been able to Save: %v", err) - } - ms.p = map[string]Package{} - if err := ms.Load(); err != nil { - t.Errorf("should have been able to Load: %v", err) - } - if ok, errs := equiv(ms.p, orig); !ok { - for _, err := range errs { - t.Error(err) - } - } -} - -func TestRemove(t *testing.T) { - root, err := ioutil.TempDir("", "vain-") - if err != nil { - t.Fatalf("problem creating temp dir: %v", err) - } - defer func() { os.RemoveAll(root) }() - db := filepath.Join(root, "vain.json") - ms := NewSimpleStore(db) - ms.p = map[string]Package{ - "foo": {}, - "bar": {}, - "baz": {}, - } - if err := ms.Remove("foo"); err != nil { - t.Errorf("unexpected error during remove: %v", err) - } - want := map[string]Package{ - "bar": {}, - "baz": {}, - } - if ok, errs := equiv(ms.p, want); !ok { - for _, err := range errs { - t.Error(err) - } - } -} - -func TestPackageJsonParsing(t *testing.T) { - tests := []struct { - input string - output string - parsed Package - }{ - { - input: `{"vcs":"git","repo":"https://s.mcquay.me/sm/ud/"}`, - output: `{"vcs":"git","repo":"https://s.mcquay.me/sm/ud/"}`, - parsed: Package{Vcs: "git", Repo: "https://s.mcquay.me/sm/ud/"}, - }, - { - input: `{"vcs":"hg","repo":"https://s.mcquay.me/sm/ud/"}`, - output: `{"vcs":"hg","repo":"https://s.mcquay.me/sm/ud/"}`, - parsed: Package{Vcs: "hg", Repo: "https://s.mcquay.me/sm/ud/"}, - }, - } - - for _, test := range tests { - p := Package{} - if err := json.NewDecoder(strings.NewReader(test.input)).Decode(&p); err != nil { - t.Error(err) - } - if p != test.parsed { - t.Errorf("got:\n\t%v, want\n\t%v", p, test.parsed) - } - buf := &bytes.Buffer{} - if err := json.NewEncoder(buf).Encode(&p); err != nil { - t.Error(err) - } - if got, want := strings.TrimSpace(buf.String()), test.output; got != want { - t.Errorf("got %v, want %v", got, want) - } - } -} diff --git a/testing.go b/testing.go new file mode 100644 index 0000000..d658450 --- /dev/null +++ b/testing.go @@ -0,0 +1,33 @@ +package vain + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +func testDB(t *testing.T) (*DB, func()) { + dir, err := ioutil.TempDir("", "vain-testing-") + if err != nil { + t.Fatalf("could not create tmpdir for db: %v", err) + return nil, func() {} + } + name := filepath.Join(dir, "test.db") + db, err := NewDB(name) + if err != nil { + t.Fatalf("could not create db: %v", err) + return nil, func() {} + } + + if err := db.Init(); err != nil { + return nil, func() {} + } + + return db, func() { + db.Close() + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("could not clean up tmpdir: %v", err) + } + } +} diff --git a/vain.go b/vain.go index b715c71..2490396 100644 --- a/vain.go +++ b/vain.go @@ -3,7 +3,15 @@ // The executable, cmd/vaind, is located in the respective subdirectory. package vain -import "fmt" +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" +) var vcss = map[string]bool{ "hg": true, @@ -29,14 +37,45 @@ type Package struct { // Repo: the remote repository url Repo string `json:"repo"` - path string + Path string `json:"path"` + Ns string `json:"-"` } func (p Package) String() string { return fmt.Sprintf( "", - p.path, + p.Path, p.Vcs, p.Repo, ) } + +// Valid checks that p will not confuse the go tool if added to packages. +func Valid(p string, packages []Package) bool { + for _, pkg := range packages { + if strings.HasPrefix(pkg.Path, p) || strings.HasPrefix(p, pkg.Path) { + return false + } + } + return true +} + +func parseNamespace(path string) (string, error) { + path = strings.TrimLeft(path, "/") + if path == "" { + return "", errors.New("path does not contain namespace") + } + elems := strings.Split(path, "/") + return elems[0], nil +} + +func FreshToken() string { + buf := &bytes.Buffer{} + io.Copy(buf, io.LimitReader(rand.Reader, 6)) + s := hex.EncodeToString(buf.Bytes()) + r := []string{} + for i := 0; i < len(s)/4; i++ { + r = append(r, s[i*4:(i+1)*4]) + } + return strings.Join(r, "-") +} diff --git a/vain_test.go b/vain_test.go index a7a7ed5..2533b20 100644 --- a/vain_test.go +++ b/vain_test.go @@ -1,6 +1,7 @@ package vain import ( + "errors" "fmt" "testing" ) @@ -8,7 +9,7 @@ import ( func TestString(t *testing.T) { p := Package{ Vcs: "git", - path: "mcquay.me/bps", + Path: "mcquay.me/bps", Repo: "https://s.mcquay.me/sm/bps", } got := fmt.Sprintf("%s", p) @@ -55,59 +56,59 @@ func TestValid(t *testing.T) { }, { pkgs: []Package{ - {path: "bobo"}, + {Path: "bobo"}, }, in: "bobo", want: false, }, { pkgs: []Package{ - {path: "a/b/c"}, + {Path: "a/b/c"}, }, in: "a/b/c", want: false, }, { pkgs: []Package{ - {path: "a/b/c"}, + {Path: "a/b/c"}, }, in: "a/b", want: false, }, { pkgs: []Package{ - {path: "name/db"}, - {path: "name/lib"}, + {Path: "name/db"}, + {Path: "name/lib"}, }, in: "name/foo", want: true, }, { pkgs: []Package{ - {path: "a"}, + {Path: "a"}, }, in: "a/b", want: false, }, { pkgs: []Package{ - {path: "foo"}, + {Path: "foo"}, }, in: "foo/bar", want: false, }, { pkgs: []Package{ - {path: "foo/bar"}, - {path: "foo/baz"}, + {Path: "foo/bar"}, + {Path: "foo/baz"}, }, in: "foo", want: false, }, { pkgs: []Package{ - {path: "bilbo"}, - {path: "frodo"}, + {Path: "bilbo"}, + {Path: "frodo"}, }, in: "foo/bar/baz", want: true, @@ -120,3 +121,51 @@ func TestValid(t *testing.T) { } } } + +func TestNamespaceParsing(t *testing.T) { + tests := []struct { + input string + want string + err error + }{ + { + input: "/sm/foo", + want: "sm", + }, + { + input: "/a/b/c/d", + want: "a", + }, + { + input: "/dm/bar", + want: "dm", + }, + { + input: "/ud", + want: "ud", + }, + // test stripping + { + input: "ud", + want: "ud", + }, + { + input: "/", + err: errors.New("should find no namespace"), + }, + { + input: "", + err: errors.New("should find no namespace"), + }, + } + for _, test := range tests { + got, err := parseNamespace(test.input) + if err != nil && test.err == nil { + t.Errorf("unexpected error parsing %q; got %q, want %q, error: %v", test.input, got, test.want, err) + } + + if got != test.want { + t.Errorf("parse failure: got %q, want %q", got, test.want) + } + } +}