diff --git a/db.go b/db.go index d8ad656..af2774f 100644 --- a/db.go +++ b/db.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "os" + "strings" "sync" "time" @@ -75,16 +76,30 @@ func (m *MemDB) NSForToken(ns namespace, tok Token) error { // Package fetches the package associated with path. func (m *MemDB) Package(pth string) (Package, error) { m.l.RLock() + defer m.l.RUnlock() + pkg, ok := m.Packages[path(pth)] - m.l.RUnlock() + if ok { + return pkg, nil + } + + var longest Package + for _, p := range m.Packages { + if splitPathHasPrefix(strings.Split(pth, "/"), strings.Split(p.Path, "/")) { + if len(p.Path) > len(longest.Path) { + longest = p + } + } + } + var err error - if !ok { + if longest.Path == "" { err = verrors.HTTP{ Message: fmt.Sprintf("couldn't find package %q", pth), Code: http.StatusNotFound, } } - return pkg, err + return longest, err } // AddPackage adds p into packages table. diff --git a/db_test.go b/db_test.go new file mode 100644 index 0000000..0300686 --- /dev/null +++ b/db_test.go @@ -0,0 +1,80 @@ +package vain + +import ( + "errors" + "testing" +) + +func TestPartialPackage(t *testing.T) { + db, done := TestDB(t) + if db == nil { + t.Fatalf("could not create temp db") + } + defer done() + + paths := []path{ + "a/b", + "a/c", + "a/d/c", + "a/d/e", + + "f/b/c/d", + "f/b/c/e", + } + + for _, p := range paths { + db.Packages[p] = Package{Path: string(p)} + } + + tests := []struct { + pth string + pkg Package + err error + }{ + // obvious + { + pth: "a/b", + pkg: db.Packages["a/b"], + }, + { + pth: "a/d/c", + pkg: db.Packages["a/d/c"], + }, + + // here we exercise the code that matches closest submatch + { + pth: "a/b/c", + pkg: db.Packages["a/b"], + }, + { + pth: "f/b/c/d/e/f/g", + pkg: db.Packages["f/b/c/d"], + }, + + // some errors + { + pth: "foo", + err: errors.New("shouldn't find"), + }, + + { + pth: "a/d/f", + err: errors.New("shouldn't find"), + }, + } + + for _, test := range tests { + p, err := db.Package(test.pth) + + if got, want := p, test.pkg; got != want { + t.Errorf("bad package fetched: got %+v, want %+v", got, want) + } + + got := err + want := test.err + if (got == nil) != (want == nil) { + t.Errorf("unexpected error; got %v, want %v", got, want) + } + } + +}