diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..ea75d2d --- /dev/null +++ b/auth.go @@ -0,0 +1,36 @@ +package allowances + +import ( + "encoding/json" + "log" + "net/http" +) + +const sessionName = "allowances" + +type authed func(w http.ResponseWriter, r *http.Request, uid string) error + +func (a *Allowances) protected(handler authed) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + session, err := a.store.Get(r, sessionName) + if err != nil { + log.Printf("%+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + u, ok := session.Values["uuid"] + if !ok { + http.Redirect( + w, r, + prefix["login"], + http.StatusTemporaryRedirect, + ) + return + } + err = handler(w, r, u.(string)) + if err != nil { + json.NewEncoder(w).Encode(NewFailure(err.Error())) + return + } + } +} diff --git a/cmd/allowances/main.go b/cmd/allowances/main.go new file mode 100644 index 0000000..8e5ab7e --- /dev/null +++ b/cmd/allowances/main.go @@ -0,0 +1,105 @@ +package main + +import ( + "fmt" + "log" + "net/http" + "os" + "strconv" + + "github.com/bgentry/speakeasy" + + "mcquay.me/allowances" +) + +const usage = `allowances app + +subcommands: + pw -- manage password file + serve -- serve webapp +` + +const pwUsage = `allowances pw + +subcommands: + add + test +` + +func main() { + if len(os.Args) < 2 { + fmt.Fprintf(os.Stderr, usage) + os.Exit(1) + } + + subcommand := os.Args[1] + + switch subcommand { + case "pw": + pwCmd := os.Args[2:] + if len(pwCmd) != 2 { + fmt.Fprintf(os.Stderr, "%s\n", pwUsage) + os.Exit(1) + } + switch pwCmd[0] { + case "add": + pw, err := speakeasy.Ask("new pass: ") + if err != nil { + fmt.Fprintf(os.Stderr, "failure to get password: %v", err) + os.Exit(1) + } + if err := allowances.AddPassword(pwCmd[1], pw); err != nil { + fmt.Fprintf(os.Stderr, "problem adding password: %v", err) + os.Exit(1) + } + case "test": + passes, _, err := allowances.GetHashes(pwCmd[1]) + if err != nil { + fmt.Fprintf(os.Stderr, "problem opening passes file: %v", err) + os.Exit(1) + } + pw, err := speakeasy.Ask("check password: ") + if err != nil { + panic(err) + } + ok, err := passes.Check(pw) + if err != nil { + panic(err) + } + if !ok { + fmt.Fprintf(os.Stderr, "bad password") + os.Exit(1) + } + default: + fmt.Fprintf(os.Stderr, "%s\n", pwUsage) + os.Exit(1) + } + case "serve": + sm := http.NewServeMux() + dbfile := os.Getenv("DB") + passfile := os.Getenv("PASSES") + _, err := allowances.NewAllowances(sm, dbfile, passfile, os.Getenv("STATIC")) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to initialize web server: %v\n", err) + os.Exit(1) + } + port := 8000 + if os.Getenv("PORT") != "" { + p, err := strconv.Atoi(os.Getenv("PORT")) + if err != nil { + fmt.Fprintf(os.Stderr, "problem parsing port from env: %v\n", err) + os.Exit(1) + } + port = p + } + addr := fmt.Sprintf(":%d", port) + log.Printf("%+v", addr) + err = http.ListenAndServe(addr, sm) + if err != nil { + panic(err) + } + default: + fmt.Fprintf(os.Stderr, "unknown subcommand %s\n\n%s\n", subcommand, usage) + os.Exit(1) + } +} diff --git a/db.go b/db.go index 9393448..5109e55 100644 --- a/db.go +++ b/db.go @@ -1,58 +1,67 @@ -package main +package allowances import ( - "golang.org/x/crypto/bcrypt" "encoding/json" "io/ioutil" "log" + "os" "sync" + + "golang.org/x/crypto/bcrypt" ) -var dbMutex = sync.RWMutex{} - -func get_passes(filename string) (cur_passes []string, err error) { - b, err := ioutil.ReadFile(filename) - if err != nil { - log.Fatal(err) +func GetHashes(filename string) (Passes, bool, error) { + r := []string{} + exists := false + if !Exists(filename) { + return r, exists, nil } - err = json.Unmarshal(b, &cur_passes) + exists = true + f, err := os.Open(filename) if err != nil { - log.Fatal(err) + return nil, exists, err } - return + err = json.NewDecoder(f).Decode(&r) + if err != nil { + return nil, exists, err + } + return r, exists, nil } -func add_password(filename, new_pw string) (err error) { - cur_passes, err := get_passes(filename) +func AddPassword(filename, pw string) error { + curPasses, _, err := GetHashes(filename) if err != nil { - log.Fatal(err) + return err } hpass, err := bcrypt.GenerateFromPassword( - []byte(*add_pw), bcrypt.DefaultCost) - cur_passes = append(cur_passes, string(hpass)) - b, err := json.Marshal(cur_passes) - err = ioutil.WriteFile(filename, b, 0644) + []byte(pw), bcrypt.DefaultCost) + curPasses = append(curPasses, string(hpass)) + + f, err := os.Create(filename) if err != nil { - log.Fatal(err) + return err } - return + if err := json.NewEncoder(f).Encode(curPasses); err != nil { + return err + } + return nil } -func check_password(filename, attempt string) (result bool) { - hashes, err := get_passes(filename) - if err != nil { - log.Fatal(err) - } - for _, hash := range hashes { +type Passes []string + +func (p Passes) Check(attempt string) (bool, error) { + // TODO: parallelize + for _, hash := range p { err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(attempt)) if err == nil { - result = true - return + return true, err } } - return + return false, nil } +var dbMutex = sync.RWMutex{} + func loadChildren(filename string) (children map[string]int) { dbMutex.RLock() defer dbMutex.RUnlock() @@ -70,7 +79,7 @@ func loadChildren(filename string) (children map[string]int) { func dumpChildren(filename string, children map[string]int) { dbMutex.Lock() defer dbMutex.Unlock() - b, err := json.Marshal(children) + b, err := json.Marshal(children) err = ioutil.WriteFile(filename, b, 0644) if err != nil { log.Fatal("serious issue writing children db file", err) diff --git a/fs.go b/fs.go new file mode 100644 index 0000000..027f1eb --- /dev/null +++ b/fs.go @@ -0,0 +1,12 @@ +package allowances + +import "os" + +func Exists(path string) bool { + if _, err := os.Stat(path); err != nil { + return false + } else if os.IsNotExist(err) { + return false + } + return true +} diff --git a/gen.go b/gen.go new file mode 100644 index 0000000..2ffe8a6 --- /dev/null +++ b/gen.go @@ -0,0 +1,6 @@ +package allowances + +//go:generate go get github.com/jteeuwen/go-bindata/... +//go:generate go get github.com/elazarl/go-bindata-assetfs/... +//go:generate rm -vf static.go +//go:generate go-bindata -o static.go -pkg=allowances static/... templates/... diff --git a/handlers.go b/handlers.go index 6d789fd..c43cc52 100644 --- a/handlers.go +++ b/handlers.go @@ -1,60 +1,118 @@ -package main +package allowances import ( "encoding/json" + "fmt" "log" "net/http" "strconv" "strings" + + "github.com/gorilla/sessions" ) -func homeHandler(w http.ResponseWriter, req *http.Request) { - session, _ := store.Get(req, "creds") - loggedIn := session.Values["logged in"] - if loggedIn == nil { - http.Redirect(w, req, "/login", http.StatusSeeOther) - return - } - children := loadChildren(*db_file) - T("index.html").Execute(w, map[string]interface{}{ - "children": children}) +func init() { + log.SetFlags(log.Lshortfile | log.Ltime) } -func loginHandler(w http.ResponseWriter, req *http.Request) { - pwAttempt := req.FormValue("passwd") - if check_password(*passes_file, pwAttempt) { - session, _ := store.Get(req, "creds") - session.Values["logged in"] = true +type failure struct { + Success bool `json:"success"` + Error string `json:"error"` +} + +func NewFailure(msg string) *failure { + return &failure{ + Success: false, + Error: msg, + } +} + +type Allowances struct { + db string + hashes Passes + store *sessions.CookieStore +} + +func NewAllowances(sm *http.ServeMux, dbfile, passfile, staticFiles string) (*Allowances, error) { + var err error + tmpls, err = getTemplates() + if err != nil { + return nil, err + } + hashes, exists, err := GetHashes(passfile) + if !exists { + return nil, fmt.Errorf("passes file doesn't exist: %q", passfile) + } + if err != nil { + return nil, err + } + + if !Exists(dbfile) { + return nil, fmt.Errorf("child db file doesn't exist: %q", dbfile) + } + r := &Allowances{ + db: dbfile, + hashes: hashes, + store: sessions.NewCookieStore([]byte("hello world")), + } + addRoutes(sm, r, staticFiles) + return r, nil +} + +func (a *Allowances) home(w http.ResponseWriter, req *http.Request, uid string) error { + children := loadChildren(a.db) + tmpls["home"].Execute(w, map[string]interface{}{"children": children}) + return nil +} + +func (a *Allowances) login(w http.ResponseWriter, req *http.Request) { + attempt := req.FormValue("passwd") + ok, err := a.hashes.Check(attempt) + if err != nil { + b, _ := json.Marshal(NewFailure(err.Error())) + http.Error(w, string(b), http.StatusBadRequest) + return + } + if ok { + session, _ := a.store.Get(req, sessionName) + session.Values["uuid"] = "me" session.Save(req, w) http.Redirect(w, req, "/", http.StatusSeeOther) return } - T("login.html").Execute(w, map[string]interface{}{}) + tmpls["login"].Execute(w, map[string]interface{}{}) } -func logoutHandler(w http.ResponseWriter, req *http.Request) { - session, _ := store.Get(req, "creds") - delete(session.Values, "logged in") +func (a *Allowances) logout(w http.ResponseWriter, req *http.Request, u string) error { + session, err := a.store.Get(req, sessionName) + if err != nil { + return err + } + delete(session.Values, "uuid") session.Save(req, w) http.Redirect(w, req, "/", http.StatusSeeOther) - return + return nil } -func addHandler(w http.ResponseWriter, req *http.Request) { - path := req.URL.Path[len(addPath):] +func (a *Allowances) add(w http.ResponseWriter, req *http.Request, uid string) error { + path := req.URL.Path[len(prefix["add"]):] bits := strings.Split(path, "/") child := bits[0] amount, err := strconv.Atoi(bits[1]) if err != nil { - log.Fatal("couldn't parse a dollar amount", err) + return fmt.Errorf("couldn't parse a dollar amount: %v", err) } - children := loadChildren(*db_file) + children := loadChildren(a.db) children[child] += amount - defer dumpChildren(*db_file, children) + defer dumpChildren(a.db, children) w.Header().Set("Content-Type", "application/json") - b, err := json.Marshal(map[string]interface{}{ - "amount": dollarize(children[child]), - "name": child, - }) + b, err := json.Marshal(map[string]interface{}{ + "amount": dollarize(children[child]), + "name": child, + }) + if err != nil { + return err + } w.Write(b) + return nil } diff --git a/main.go b/main.go deleted file mode 100644 index f23eb56..0000000 --- a/main.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "github.com/gorilla/sessions" - "html/template" - "log" - "net/http" -) - -var addr = flag.String("addr", ":8000", "address I'll listen on.") -var static_files = flag.String("static", "./static", "location of static files") -var passes_file = flag.String("passes", "passwds.json", "the password database") -var db_file = flag.String("children", "children.json", "the children database") -var template_dir = flag.String("templates", "templates", "template dir") -var add_pw = flag.String("passwd", "", "add this pass to the db") -var check_pw = flag.String("checkpw", "", "check if this pw is in db") - -var addPath = "/add/" - -var store = sessions.NewCookieStore([]byte("hello world")) -var templates *template.Template - -func main() { - flag.Parse() - if *add_pw != "" { - add_password(*passes_file, *add_pw) - } else if *check_pw != "" { - fmt.Printf("valid password: %v\n", - check_password(*passes_file, *check_pw)) - } else { - http.HandleFunc("/", homeHandler) - http.HandleFunc("/login", loginHandler) - http.HandleFunc("/logout", logoutHandler) - http.HandleFunc(addPath, addHandler) - http.Handle("/s/", http.StripPrefix("/s/", - http.FileServer(http.Dir(*static_files)))) - if err := http.ListenAndServe(*addr, nil); err != nil { - log.Fatal("ListenAndServe:", err) - } - } -} diff --git a/routes.go b/routes.go new file mode 100644 index 0000000..9852ecf --- /dev/null +++ b/routes.go @@ -0,0 +1,53 @@ +package allowances + +import ( + "net/http" + + "github.com/elazarl/go-bindata-assetfs" + "github.com/gorilla/context" +) + +var prefix map[string]string + +func addRoutes(sm *http.ServeMux, a *Allowances, staticFiles string) { + prefix = map[string]string{ + "static": "/s/", + "auth": "/api/v0/auth/", + "reset": "/api/v0/auth/reset/", + "add": "/add/", + "login": "/login/", + "logout": "/logout/", + } + + sm.HandleFunc("/", a.protected(a.home)) + sm.HandleFunc(prefix["login"], a.login) + sm.HandleFunc(prefix["logout"], a.protected(a.logout)) + sm.HandleFunc(prefix["add"], a.protected(a.add)) + + if staticFiles == "" { + sm.Handle( + prefix["static"], + http.StripPrefix( + prefix["static"], + http.FileServer( + &assetfs.AssetFS{ + Asset: Asset, + AssetDir: AssetDir, + Prefix: "static", + }, + ), + ), + ) + + } else { + sm.Handle( + prefix["static"], + http.StripPrefix( + prefix["static"], + http.FileServer(http.Dir(staticFiles)), + ), + ) + } + + context.ClearHandler(sm) +} diff --git a/template.go b/template.go deleted file mode 100644 index 52f0f88..0000000 --- a/template.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "fmt" - "html/template" - "path/filepath" - "sync" -) - -var cachedTemplates = map[string]*template.Template{} -var cachedMutex sync.Mutex - -func dollarize(value int) string { - return fmt.Sprintf("$%0.2f", float32(value)/100.0) -} - -var funcs = template.FuncMap{ - "dollarize": dollarize, -} - -func T(name string) *template.Template { - cachedMutex.Lock() - defer cachedMutex.Unlock() - if t, ok := cachedTemplates[name]; ok { - return t - } - t := template.New("_base.html").Funcs(funcs) - t = template.Must(t.ParseFiles( - "templates/_base.html", - filepath.Join(*template_dir, name), - )) - cachedTemplates[name] = t - return t -} diff --git a/templates.go b/templates.go new file mode 100644 index 0000000..949e587 --- /dev/null +++ b/templates.go @@ -0,0 +1,60 @@ +package allowances + +import ( + "fmt" + "html/template" + "log" + "strings" +) + +func dollarize(value int) string { + return fmt.Sprintf("$%0.2f", float32(value)/100.0) +} + +type tmap map[string]*template.Template + +var tmpls tmap + +func getTemplates() (tmap, error) { + var err error + funcMap := template.FuncMap{ + "title": strings.Title, + "dollarize": dollarize, + } + base, err := Asset("templates/base.html") + if err != nil { + return nil, err + } + tmpl, err := template.New("base").Funcs(funcMap).Parse(string(base)) + if err != nil { + return nil, err + } + templates := make(map[string]*template.Template) + templateFiles := []struct { + name string + path string + }{ + {"home", "templates/index.html"}, + {"login", "templates/login.html"}, + } + + for _, tf := range templateFiles { + a, err := Asset(tf.path) + if err != nil { + return nil, err + } + t, err := tmpl.Clone() + if err != nil { + return nil, err + } + + t, err = t.Parse(string(a)) + if err != nil { + log.Printf("XXX: %+v", err) + return nil, err + } + templates[tf.name] = t + } + + return templates, nil +} diff --git a/templates/_base.html b/templates/base.html similarity index 100% rename from templates/_base.html rename to templates/base.html diff --git a/templates/login.html b/templates/login.html index b9f40fd..db34a5e 100644 --- a/templates/login.html +++ b/templates/login.html @@ -3,7 +3,7 @@ {{ define "content" }}
-
+