diff --git a/api_test.go b/api_test.go index 7810d94..95ba0a6 100644 --- a/api_test.go +++ b/api_test.go @@ -9,7 +9,7 @@ import ( ) func TestAdd(t *testing.T) { - ms := NewMemStore() + ms := NewMemStore("") s := &Server{ storage: ms, } @@ -60,7 +60,7 @@ func TestAdd(t *testing.T) { } func TestInvalidPath(t *testing.T) { - ms := NewMemStore() + ms := NewMemStore("") s := &Server{ storage: ms, } @@ -80,7 +80,7 @@ func TestInvalidPath(t *testing.T) { } func TestCannotDuplicateExistingPath(t *testing.T) { - ms := NewMemStore() + ms := NewMemStore("") s := &Server{ storage: ms, } @@ -105,7 +105,7 @@ func TestCannotDuplicateExistingPath(t *testing.T) { } func TestCannotAddExistingSubPath(t *testing.T) { - ms := NewMemStore() + ms := NewMemStore("") s := &Server{ storage: ms, } diff --git a/cmd/ysvd/main.go b/cmd/ysvd/main.go index 720d087..7e7aafd 100644 --- a/cmd/ysvd/main.go +++ b/cmd/ysvd/main.go @@ -22,17 +22,18 @@ environment vars: YSV_PORT: tcp listen port YSV_HOST: hostname to use +YSV_DB: path to json database ` type config struct { Port int Host string + DB string } func main() { c := &config{ Port: 4040, - Host: "localhost", } if err := envconfig.Process("ysv", c); err != nil { fmt.Fprintf(os.Stderr, "problem processing environment: %v", err) @@ -41,10 +42,17 @@ func main() { if len(os.Args) > 1 { switch os.Args[1] { case "env", "e", "help", "h": - fmt.Fprintf(os.Stderr, "%s\n", usage) - os.Exit(1) + fmt.Printf("%s\n", usage) + os.Exit(0) } } + if c.Host == "" { + log.Printf("must set YSV_HOST; please run $(ysvd env) for more information") + os.Exit(1) + } + if c.DB == "" { + log.Printf("warning: in-memory db mode; if you do not want this set YSV_DB") + } hostname := "localhost" if hn, err := os.Hostname(); err != nil { log.Printf("problem getting hostname:", err) @@ -53,7 +61,11 @@ func main() { } log.Printf("serving at: http://%s:%d/", hostname, c.Port) sm := http.NewServeMux() - ms := vain.NewMemStore() + ms := vain.NewMemStore(c.DB) + if err := ms.Load(); err != nil { + log.Printf("unable to load db: %v", err) + os.Exit(1) + } vain.NewServer(sm, ms, c.Host) addr := fmt.Sprintf(":%d", c.Port) if err := http.ListenAndServe(addr, sm); err != nil { diff --git a/server.go b/server.go index eef7095..ff0cc52 100644 --- a/server.go +++ b/server.go @@ -35,7 +35,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Error(w, fmt.Sprintf("invalid path; prefix already taken %q", req.URL.Path), http.StatusConflict) return } - s.storage.Add(p) + if err := s.storage.Add(p); err != nil { + http.Error(w, fmt.Sprintf("unable to add package: %v", err), http.StatusInternalServerError) + return + } + if err := s.storage.Save(); err != nil { + http.Error(w, fmt.Sprintf("unable to store db: %v", err), http.StatusInternalServerError) + if err := s.storage.Remove(p.Path); err != nil { + fmt.Fprintf(w, "to add insult to injury, could not delete package: %v\n", err) + } + return + } default: http.Error(w, fmt.Sprintf("unsupported method %q; accepted: POST, GET", req.Method), http.StatusMethodNotAllowed) } diff --git a/storage.go b/storage.go index e63d33f..8030680 100644 --- a/storage.go +++ b/storage.go @@ -1,7 +1,8 @@ package vain import ( - "errors" + "encoding/json" + "os" "strings" "sync" ) @@ -18,33 +19,63 @@ func Valid(p string, packages []Package) bool { type MemStore struct { l sync.RWMutex p map[string]Package + + dbl sync.Mutex + path string } -func NewMemStore() *MemStore { +func NewMemStore(path string) *MemStore { return &MemStore{ - p: make(map[string]Package), + path: path, + p: make(map[string]Package), } } -func (ms MemStore) Add(p Package) error { +func (ms *MemStore) Add(p Package) error { ms.l.Lock() ms.p[p.Path] = p ms.l.Unlock() return nil } -func (ms MemStore) Remove(path string) error { +func (ms *MemStore) Remove(path string) error { ms.l.Lock() delete(ms.p, path) ms.l.Unlock() return nil } -func (ms MemStore) Save() error { - return errors.New("save is not implemented") +func (ms *MemStore) Save() error { + // running in-memory only + if ms.path == "" { + return nil + } + ms.dbl.Lock() + defer ms.dbl.Unlock() + f, err := os.Create(ms.path) + if err != nil { + return err + } + defer f.Close() + return json.NewEncoder(f).Encode(ms.p) } -func (ms MemStore) All() []Package { +func (ms *MemStore) Load() error { + // running in-memory only + if ms.path == "" { + return nil + } + ms.dbl.Lock() + defer ms.dbl.Unlock() + f, err := os.Open(ms.path) + if err != nil { + return err + } + defer f.Close() + return json.NewDecoder(f).Decode(&ms.p) +} + +func (ms *MemStore) All() []Package { r := []Package{} ms.l.RLock() for _, p := range ms.p {