vain/db.go
stephen mcquay 680eecb111
Added simple user auth
Fixes #14.

Change-Id: I748933214f43ac7298f1e93c14bb0ee881976d43
2016-04-27 21:29:48 -07:00

269 lines
5.9 KiB
Go

package vain
import (
"fmt"
"log"
"net/http"
"time"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
verrors "mcquay.me/vain/errors"
vsql "mcquay.me/vain/sql"
)
type DB struct {
conn *sqlx.DB
}
func NewDB(path string) (*DB, error) {
conn, err := sqlx.Open("sqlite3", fmt.Sprintf("file:%s?cache=shared&mode=rwc", path))
if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, err
}
return &DB{conn}, err
}
func (db *DB) Init() error {
content, err := vsql.Asset("sql/init.sql")
if err != nil {
return err
}
_, err = db.conn.Exec(string(content))
return err
}
func (db *DB) Close() error {
return db.conn.Close()
}
func (db *DB) AddPackage(p Package) error {
_, err := db.conn.NamedExec(
"INSERT INTO packages(vcs, repo, path, ns) VALUES (:vcs, :repo, :path, :ns)",
&p,
)
return err
}
func (db *DB) RemovePackage(path string) error {
_, err := db.conn.Exec("DELETE FROM packages WHERE path = ?", path)
return err
}
func (db *DB) Pkgs() []Package {
r := []Package{}
rows, err := db.conn.Queryx("SELECT * FROM packages")
if err != nil {
log.Printf("%+v", err)
return nil
}
for rows.Next() {
var p Package
err = rows.StructScan(&p)
if err != nil {
log.Printf("%+v", err)
return nil
}
r = append(r, p)
}
return r
}
func (db *DB) PackageExists(path string) bool {
var count int
if err := db.conn.Get(&count, "SELECT COUNT(*) FROM packages WHERE path = ?", path); err != nil {
log.Printf("%+v", err)
}
r := false
switch count {
case 1:
r = true
default:
log.Printf("unexpected count of packages matching %q: %d", path, count)
}
return r
}
func (db *DB) Package(path string) (Package, error) {
r := Package{}
err := db.conn.Get(&r, "SELECT * FROM packages WHERE path = ?", path)
return r, err
}
func (db *DB) NSForToken(ns string, tok string) error {
var err error
txn, err := db.conn.Beginx()
if err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("problem creating transaction: %v", err),
Code: http.StatusInternalServerError,
}
}
defer func() {
if err != nil {
txn.Rollback()
} else {
txn.Commit()
}
}()
var count int
if err = txn.Get(&count, "SELECT COUNT(*) FROM namespaces WHERE namespaces.ns = ?", ns); err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("problem matching fetching namespaces matching %q", ns),
Code: http.StatusInternalServerError,
}
}
if count == 0 {
if _, err = txn.Exec(
"INSERT INTO namespaces(ns, email) SELECT ?, users.email FROM users WHERE users.token = ?",
ns,
tok,
); err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("problem inserting %q into namespaces for token %q: %v", ns, tok, err),
Code: http.StatusInternalServerError,
}
}
return err
}
if err = txn.Get(&count, "SELECT COUNT(*) FROM namespaces JOIN users ON namespaces.email = users.email WHERE users.token = ? AND namespaces.ns = ?", tok, ns); err != nil {
return verrors.HTTP{
Message: fmt.Sprintf("ns: %q, tok: %q; %v", ns, tok, err),
Code: http.StatusInternalServerError,
}
}
switch count {
case 1:
err = nil
case 0:
err = verrors.HTTP{
Message: fmt.Sprintf("not authorized against namespace %q", ns),
Code: http.StatusUnauthorized,
}
default:
err = verrors.HTTP{
Message: fmt.Sprintf("inconsistent db; found %d results with ns (%s) with token (%s): %d", count, ns, tok),
Code: http.StatusInternalServerError,
}
}
return err
}
func (db *DB) Register(email string) (string, error) {
var err error
txn, err := db.conn.Beginx()
if err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("problem creating transaction: %v", err),
Code: http.StatusInternalServerError,
}
}
defer func() {
if err != nil {
txn.Rollback()
} else {
txn.Commit()
}
}()
var count int
if err = txn.Get(&count, "SELECT COUNT(*) FROM users WHERE email = ?", email); err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not search for email %q in db: %v", email, err),
Code: http.StatusInternalServerError,
}
}
if count != 0 {
return "", verrors.HTTP{
Message: fmt.Sprintf("duplicate email %q", email),
Code: http.StatusConflict,
}
}
tok := FreshToken()
_, err = txn.Exec(
"INSERT INTO users(email, token, requested) VALUES (?, ?, ?)",
email,
tok,
time.Now(),
)
return tok, err
}
func (db *DB) Confirm(token string) (string, error) {
var err error
txn, err := db.conn.Beginx()
if err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("problem creating transaction: %v", err),
Code: http.StatusInternalServerError,
}
}
defer func() {
if err != nil {
txn.Rollback()
} else {
txn.Commit()
}
}()
var count int
if err = txn.Get(&count, "SELECT COUNT(*) FROM users WHERE token = ?", token); err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not perform search for user with token %q in db: %v", token, err),
Code: http.StatusInternalServerError,
}
}
if count != 1 {
return "", verrors.HTTP{
Message: fmt.Sprintf("bad token: %s", token),
Code: http.StatusNotFound,
}
}
newToken := FreshToken()
_, err = txn.Exec(
"UPDATE users SET token = ?, registered = 1 WHERE token = ?",
newToken,
token,
)
if err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("couldn't update user with token %q", token),
Code: http.StatusInternalServerError,
}
}
return newToken, nil
}
func (db *DB) forgot(email string) (string, error) {
var token string
if err := db.conn.Get(&token, "SELECT token FROM users WHERE email = ?", email); err != nil {
return "", verrors.HTTP{
Message: fmt.Sprintf("could not search for email %q in db: %v", email, err),
Code: http.StatusInternalServerError,
}
}
return token, nil
}
func (db *DB) addUser(email string) (string, error) {
tok := FreshToken()
_, err := db.conn.Exec(
"INSERT INTO users(email, token, requested) VALUES (?, ?, ?)",
email,
tok,
time.Now(),
)
return tok, err
}