Refactor database login into a separate module

This commit is contained in:
Daan Sprenkels 2019-12-03 23:08:58 +01:00
parent 8b87cd0f8a
commit 0cfad96b68
10 changed files with 489 additions and 430 deletions

View File

@ -5,6 +5,7 @@ import (
"log" "log"
"gitea.hashru.nl/dsprenkels/rushlink" "gitea.hashru.nl/dsprenkels/rushlink"
"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
) )
var ( var (
@ -17,14 +18,16 @@ var (
func main() { func main() {
flag.Parse() flag.Parse()
if err := rushlink.OpenDB(*databasePath); err != nil { database, err := db.OpenDB(*databasePath)
if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
defer rushlink.CloseDB() defer database.Close()
if err := rushlink.OpenFileStore(*fileStorePath); err != nil { filestore, err := db.OpenFileStore(*fileStorePath)
if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
go rushlink.StartMetricsServer(*metricsListen) go rushlink.StartMetricsServer(*metricsListen, database)
rushlink.StartMainServer(*httpListen) rushlink.StartMainServer(*httpListen, database, filestore)
} }

135
db.go
View File

@ -1,135 +0,0 @@
package rushlink
import (
"fmt"
"log"
"time"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)
var DB *bolt.DB
// The current database version
//
// If we alter the database format, we bump this number and write a new
// database migration in migrateDatabase().
const CURRENT_MIGRATE_VERSION = 2
// Bucket storing everything that is not a bulk value. This includes stuff like
// the database version, secret site-wide keys.
const BUCKET_CONF = "conf"
// The main bucket for paste values and URL redirects
const BUCKET_PASTES = "pastes"
// The main bucket for file uploads
const BUCKET_FILE_UPLOAD = "fileUpload"
// This value stores the current migration version. If this value is less than
// CURRENT_MIGRATE_VERSION, the database has to be migrated.
const KEY_MIGRATE_VERSION = "migrate_version"
// Open the bolt database
func OpenDB(path string) error {
if path == "" {
return errors.New("database not set")
}
var err error
DB, err = bolt.Open(path, 0666, &bolt.Options{Timeout: 1 * time.Second})
if err != nil {
return errors.Wrapf(err, "failed to open database at '%v'", path)
}
return DB.Update(migrateDatabase)
}
// Close the bolt database
func CloseDB() error {
if DB == nil {
panic("no open database")
}
return DB.Close()
}
// Initialize and migrate the database to the current version
func migrateDatabase(tx *bolt.Tx) error {
dbVersion, err := dbVersion(tx)
if err != nil {
return err
}
// Migrate the database to version 1
if dbVersion < 1 {
log.Println("migrating database to version 1")
// Create conf bucket
_, err := tx.CreateBucket([]byte(BUCKET_CONF))
if err != nil {
return err
}
// Create paste bucket
_, err = tx.CreateBucket([]byte(BUCKET_PASTES))
if err != nil {
return err
}
// Update the version number
if err := setDBVersion(tx, 1); err != nil {
return err
}
}
if dbVersion < 2 {
log.Println("migrating database to version 2")
// Create fileUpload bucket
_, err := tx.CreateBucket([]byte(BUCKET_FILE_UPLOAD))
if err != nil {
return err
}
// Update the version number
if err := setDBVersion(tx, 2); err != nil {
return err
}
}
return nil
}
// Get the current migrate version from the database
func dbVersion(tx *bolt.Tx) (int, error) {
conf := tx.Bucket([]byte(BUCKET_CONF))
if conf == nil {
return 0, nil
}
dbVersionBytes := conf.Get([]byte(KEY_MIGRATE_VERSION))
if dbVersionBytes == nil {
return 0, nil
}
// Version was already stored
var dbVersion int
if err := Unmarshal(dbVersionBytes, &dbVersion); err != nil {
return 0, err
}
if dbVersion == 0 {
return 0, fmt.Errorf("database version is invalid (%v)", dbVersion)
}
if dbVersion > CURRENT_MIGRATE_VERSION {
return 0, fmt.Errorf("database version is too recent (%v > %v)", dbVersion, CURRENT_MIGRATE_VERSION)
}
return dbVersion, nil
}
// Update the current migrate version in the database
func setDBVersion(tx *bolt.Tx, version int) error {
conf, err := tx.CreateBucketIfNotExists([]byte(BUCKET_CONF))
if err != nil {
return err
}
versionBytes, err := Marshal(version)
if err != nil {
return err
}
return conf.Put([]byte(KEY_MIGRATE_VERSION), versionBytes)
}

View File

@ -1,33 +1,22 @@
package rushlink package rushlink
import ( import (
"encoding/hex"
"hash/crc32"
"io"
"log" "log"
"net/http" "net/http"
"net/url"
"os"
"path"
"github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
) )
// Use the Castagnoli checksum because of the acceleration on Intel CPUs
var checksumTable = crc32.MakeTable(crc32.Castagnoli)
// Where to store the uploaded files // Where to store the uploaded files
var fileStoreDir = "" var fileStoreDir = ""
// Custom HTTP filesystem handler // FileUploadFileSystem is a HTTP filesystem handler
type fileUploadFileSystem struct { type FileUploadFileSystem struct {
fs http.FileSystem fs http.FileSystem
} }
// Open opens file // Open opens a file
func (fs fileUploadFileSystem) Open(path string) (http.File, error) { func (fs FileUploadFileSystem) Open(path string) (http.File, error) {
log.Println(path) log.Println(path)
file, err := fs.fs.Open(path) file, err := fs.fs.Open(path)
if err != nil { if err != nil {
@ -42,153 +31,3 @@ func (fs fileUploadFileSystem) Open(path string) (http.File, error) {
} }
return file, nil return file, nil
} }
type fileUploadState int
type fileUpload struct {
State fileUploadState
ID uuid.UUID
FileName string
ContentType string
Checksum uint32
}
const (
dirMode os.FileMode = 0750
fileMode os.FileMode = 0640
)
const (
fileUploadStateUndef fileUploadState = 0
fileUploadStatePresent = 1
fileUploadStateDeleted = 2
)
func (t fileUploadState) String() string {
switch t {
case fileUploadStateUndef:
return "unknown"
case fileUploadStatePresent:
return "present"
case fileUploadStateDeleted:
return "deleted"
default:
return "invalid"
}
}
func OpenFileStore(path string) error {
if path == "" {
return errors.New("file-store not set")
}
// Try to create the file store directory if it does not yet exist
if err := os.MkdirAll(path, dirMode); err != nil {
return errors.Wrap(err, "creating file store directory")
}
fileStoreDir = path[:]
return nil
}
func newFileUpload(tx *bolt.Tx, r io.Reader, fileName string, contentType string) (*fileUpload, error) {
id, err := uuid.NewRandom()
if err != nil {
return nil, errors.Wrap(err, "generating UUID")
}
filePath := fileStorePath(id, fileName)
if err := os.Mkdir(path.Dir(filePath), dirMode); err != nil {
return nil, errors.Wrap(err, "creating file dir")
}
file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, fileMode)
if err != nil {
return nil, errors.Wrap(err, "opening file")
}
defer file.Close()
hash := crc32.New(checksumTable)
tee := io.TeeReader(r, hash)
_, err = io.Copy(file, tee)
if err != nil {
return nil, errors.Wrap(err, "writing to file")
}
fu := &fileUpload{
State: fileUploadStatePresent,
ID: id,
FileName: fileName,
ContentType: contentType,
Checksum: hash.Sum32(),
}
if err := fu.save(tx); err != nil {
return nil, err
}
return fu, nil
}
func getFileUpload(tx *bolt.Tx, id uuid.UUID) (*fileUpload, error) {
bucket := tx.Bucket([]byte(BUCKET_FILE_UPLOAD))
if bucket == nil {
return nil, errors.Errorf("bucket %v does not exist", BUCKET_FILE_UPLOAD)
}
storedBytes := bucket.Get(id[:])
if storedBytes == nil {
return nil, nil
}
fu := &fileUpload{}
err := Unmarshal(storedBytes, fu)
return fu, err
}
func (fu *fileUpload) save(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(BUCKET_FILE_UPLOAD))
if bucket == nil {
return errors.Errorf("bucket %v does not exist", BUCKET_FILE_UPLOAD)
}
buf, err := Marshal(fu)
if err != nil {
return errors.Wrap(err, "encoding for database failed")
}
if err := bucket.Put(fu.ID[:], buf); err != nil {
return errors.Wrap(err, "database transaction failed")
}
return nil
}
func (fu *fileUpload) delete(tx *bolt.Tx) error {
// Remove the file in the backend
filePath := fileStorePath(fu.ID, fu.FileName)
if err := os.Remove(filePath); err != nil {
return err
}
// Update the file in the server
if err := (&fileUpload{
ID: fu.ID,
State: fileUploadStateDeleted,
}).save(tx); err != nil {
return err
}
// Cleanup the parent directory
wrap := "deletion succeeded, but removing the file directory has failed"
return errors.Wrap(os.Remove(path.Dir(filePath)), wrap)
}
func (fu *fileUpload) url() *url.URL {
rawurl := "/uploads/" + hex.EncodeToString(fu.ID[:]) + "/" + fu.FileName
urlParse, err := url.Parse(rawurl)
if err != nil {
panic("could not construct /uploads/ url")
}
return urlParse
}
func fileStorePath(id uuid.UUID, fileName string) string {
if fileStoreDir == "" {
panic("fileStoreDir called while the file store path has not been set")
}
return path.Join(fileStoreDir, hex.EncodeToString(id[:]), fileName)
}

View File

@ -2,7 +2,6 @@ package rushlink
import ( import (
"crypto/subtle" "crypto/subtle"
"encoding/base64"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -12,6 +11,7 @@ import (
"os" "os"
"time" "time"
"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -26,32 +26,25 @@ const (
viewShowMeta viewShowMeta
) )
const CookieDeleteToken = "owner_token" const cookieDeleteToken = "owner_token"
// These keys are designated reserved, and will not be randomly chosen func (rl *rushlink) indexGetHandler(w http.ResponseWriter, r *http.Request) {
var ReservedPasteKeys = []string{"xd42", "example"}
// Base64 encoding and decoding
var base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
var base64Encoder = base64.RawURLEncoding.WithPadding(base64.NoPadding)
func indexGetHandler(w http.ResponseWriter, r *http.Request) {
render(w, r, "index", map[string]interface{}{}) render(w, r, "index", map[string]interface{}{})
} }
func uploadFileGetHandler(w http.ResponseWriter, r *http.Request) { func (rl *rushlink) uploadFileGetHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id := vars["id"] id := vars["id"]
var fu *fileUpload var fu *db.FileUpload
var badID bool var badID bool
if err := DB.View(func(tx *bolt.Tx) error { if err := rl.db.Bolt.View(func(tx *bolt.Tx) error {
fuID, err := uuid.Parse(id) fuID, err := uuid.Parse(id)
if err != nil { if err != nil {
badID = true badID = true
return err return err
} }
fu, err = getFileUpload(tx, fuID) fu, err = db.GetFileUpload(tx, fuID)
return err return err
}); err != nil { }); err != nil {
if badID { if badID {
@ -62,7 +55,7 @@ func uploadFileGetHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
filePath := fileStorePath(fu.ID, fu.FileName) filePath := rl.fs.FilePath(fu.ID, fu.FileName)
file, err := os.Open(filePath) file, err := os.Open(filePath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -77,35 +70,35 @@ func uploadFileGetHandler(w http.ResponseWriter, r *http.Request) {
io.Copy(w, file) io.Copy(w, file)
} }
func viewPasteHandler(w http.ResponseWriter, r *http.Request) { func (rl *rushlink) viewPasteHandler(w http.ResponseWriter, r *http.Request) {
viewPasteHandlerInner(w, r, 0) rl.viewPasteHandlerInner(w, r, 0)
} }
func viewPasteHandlerNoRedirect(w http.ResponseWriter, r *http.Request) { func (rl *rushlink) viewPasteHandlerNoRedirect(w http.ResponseWriter, r *http.Request) {
viewPasteHandlerInner(w, r, viewNoRedirect) rl.viewPasteHandlerInner(w, r, viewNoRedirect)
} }
func viewPasteHandlerMeta(w http.ResponseWriter, r *http.Request) { func (rl *rushlink) viewPasteHandlerMeta(w http.ResponseWriter, r *http.Request) {
viewPasteHandlerInner(w, r, viewShowMeta) rl.viewPasteHandlerInner(w, r, viewShowMeta)
} }
func viewPasteHandlerInner(w http.ResponseWriter, r *http.Request, flags viewPaste) { func (rl *rushlink) viewPasteHandlerInner(w http.ResponseWriter, r *http.Request, flags viewPaste) {
vars := mux.Vars(r) vars := mux.Vars(r)
key := vars["key"] key := vars["key"]
var p *paste var p *db.Paste
var fuID *uuid.UUID var fuID *uuid.UUID
var fu *fileUpload var fu *db.FileUpload
if err := DB.View(func(tx *bolt.Tx) error { if err := rl.db.Bolt.View(func(tx *bolt.Tx) error {
var err error var err error
p, err = getPaste(tx, key) p, err = db.GetPaste(tx, key)
if err != nil { if err != nil {
return err return err
} }
if p != nil && p.Type == pasteTypeFileUpload { if p != nil && p.Type == db.PasteTypeFileUpload {
var id uuid.UUID var id uuid.UUID
copy(id[:], p.Content) copy(id[:], p.Content)
fuID = &id fuID = &id
fu, err = getFileUpload(tx, id) fu, err = db.GetFileUpload(tx, id)
if err != nil { if err != nil {
return err return err
} }
@ -146,17 +139,17 @@ func viewPasteHandlerInner(w http.ResponseWriter, r *http.Request, flags viewPas
} }
switch p.State { switch p.State {
case pasteStatePresent: case db.PasteStatePresent:
var location string var location string
switch p.Type { switch p.Type {
case pasteTypeFileUpload: case db.PasteTypeFileUpload:
if fu == nil { if fu == nil {
panic(fmt.Sprintf("file for id %v does not exist in database\n", fuID)) panic(fmt.Sprintf("file for id %v does not exist in database\n", fuID))
} }
location = fu.url().String() location = fu.URL().String()
break break
case pasteTypeRedirect: case db.PasteTypeRedirect:
location = p.redirectURL().String() location = p.RedirectURL().String()
break break
default: default:
panic("paste type unsupported") panic("paste type unsupported")
@ -165,17 +158,17 @@ func viewPasteHandlerInner(w http.ResponseWriter, r *http.Request, flags viewPas
http.Redirect(w, r, location, http.StatusSeeOther) http.Redirect(w, r, location, http.StatusSeeOther)
} }
fmt.Fprint(w, location) fmt.Fprint(w, location)
case pasteStateDeleted: case db.PasteStateDeleted:
renderError(w, r, http.StatusGone, "paste has been deleted\n") renderError(w, r, http.StatusGone, "paste has been deleted\n")
default: default:
panic(errors.Errorf("invalid paste.State (%v) for key '%v'", p.State, p.Key)) panic(errors.Errorf("invalid paste.State (%v) for key '%v'", p.State, p.Key))
} }
} }
func newPasteHandler(w http.ResponseWriter, r *http.Request) { func (rl *rushlink) newPasteHandler(w http.ResponseWriter, r *http.Request) {
file, fileHeader, err := r.FormFile("file") file, fileHeader, err := r.FormFile("file")
if err == nil { if err == nil {
newFileUploadPasteHandler(w, r, file, *fileHeader) rl.newFileUploadPasteHandler(w, r, file, *fileHeader)
return return
} else if err == http.ErrMissingFile { } else if err == http.ErrMissingFile {
// Fallthrough // Fallthrough
@ -187,23 +180,26 @@ func newPasteHandler(w http.ResponseWriter, r *http.Request) {
shorten := r.FormValue("shorten") shorten := r.FormValue("shorten")
if shorten != "" { if shorten != "" {
newRedirectPasteHandler(w, r, shorten) rl.newRedirectPasteHandler(w, r, shorten)
return return
} }
renderError(w, r, http.StatusBadRequest, "no 'file' and no 'shorten' fields given in form\n") renderError(w, r, http.StatusBadRequest, "no 'file' and no 'shorten' fields given in form\n")
} }
func newFileUploadPasteHandler(w http.ResponseWriter, r *http.Request, file multipart.File, header multipart.FileHeader) { func (rl *rushlink) newFileUploadPasteHandler(w http.ResponseWriter, r *http.Request, file multipart.File, header multipart.FileHeader) {
var fu *fileUpload var fu *db.FileUpload
var paste *paste var paste *db.Paste
if err := DB.Update(func(tx *bolt.Tx) error { if err := rl.db.Bolt.Update(func(tx *bolt.Tx) error {
var err error var err error
// Create the fileUpload in the database // Create the fileUpload in the database
fu, err = newFileUpload(tx, file, header.Filename, header.Header.Get("Content-Type")) fu, err = db.NewFileUpload(rl.fs, file, header.Filename, header.Header.Get("Content-Type"))
if err != nil { if err != nil {
panic(errors.Wrap(err, "creating fileUpload")) panic(errors.Wrap(err, "creating fileUpload"))
} }
if err := fu.Save(tx); err != nil {
panic(errors.Wrap(err, "saving fileUpload in db"))
}
paste, err = shortenFileUploadID(tx, fu.ID) paste, err = shortenFileUploadID(tx, fu.ID)
return err return err
@ -214,7 +210,7 @@ func newFileUploadPasteHandler(w http.ResponseWriter, r *http.Request, file mult
render(w, r, "newFileUploadPasteSuccess", data) render(w, r, "newFileUploadPasteSuccess", data)
} }
func newPasteHandlerURLEncoded(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { func (rl *rushlink) newPasteHandlerURLEncoded(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
next(w, r) next(w, r)
return return
@ -224,10 +220,10 @@ func newPasteHandlerURLEncoded(w http.ResponseWriter, r *http.Request, next http
renderError(w, r, http.StatusBadRequest, "no 'shorten' param given\n") renderError(w, r, http.StatusBadRequest, "no 'shorten' param given\n")
return return
} }
newRedirectPasteHandler(w, r, shorten) rl.newRedirectPasteHandler(w, r, shorten)
} }
func newRedirectPasteHandler(w http.ResponseWriter, r *http.Request, rawurl string) { func (rl *rushlink) newRedirectPasteHandler(w http.ResponseWriter, r *http.Request, rawurl string) {
userURL, err := url.ParseRequestURI(rawurl) userURL, err := url.ParseRequestURI(rawurl)
if err != nil { if err != nil {
msg := fmt.Sprintf("invalid url (%v): %v", err, rawurl) msg := fmt.Sprintf("invalid url (%v): %v", err, rawurl)
@ -243,8 +239,8 @@ func newRedirectPasteHandler(w http.ResponseWriter, r *http.Request, rawurl stri
return return
} }
var paste *paste var paste *db.Paste
if err := DB.Update(func(tx *bolt.Tx) error { if err := rl.db.Bolt.Update(func(tx *bolt.Tx) error {
var err error var err error
paste, err = shortenURL(tx, userURL) paste, err = shortenURL(tx, userURL)
return err return err
@ -256,7 +252,7 @@ func newRedirectPasteHandler(w http.ResponseWriter, r *http.Request, rawurl stri
} }
// Delete a URL from the database // Delete a URL from the database
func deletePasteHandler(w http.ResponseWriter, r *http.Request) { func (rl *rushlink) deletePasteHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
key := vars["key"] key := vars["key"]
@ -267,14 +263,14 @@ func deletePasteHandler(w http.ResponseWriter, r *http.Request) {
} }
var errorCode int var errorCode int
var paste paste var paste db.Paste
if err := DB.Update(func(tx *bolt.Tx) error { if err := rl.db.Bolt.Update(func(tx *bolt.Tx) error {
p, err := getPaste(tx, key) p, err := db.GetPaste(tx, key)
if err != nil { if err != nil {
errorCode = http.StatusNotFound errorCode = http.StatusNotFound
return err return err
} }
if p.State == pasteStateDeleted { if p.State == db.PasteStateDeleted {
errorCode = http.StatusGone errorCode = http.StatusGone
return errors.New("already deleted") return errors.New("already deleted")
} }
@ -282,7 +278,7 @@ func deletePasteHandler(w http.ResponseWriter, r *http.Request) {
errorCode = http.StatusForbidden errorCode = http.StatusForbidden
return errors.New("invalid delete token") return errors.New("invalid delete token")
} }
if err := p.delete(tx); err != nil { if err := p.Delete(tx, rl.fs); err != nil {
errorCode = http.StatusInternalServerError errorCode = http.StatusInternalServerError
return err return err
} }
@ -302,41 +298,41 @@ func deletePasteHandler(w http.ResponseWriter, r *http.Request) {
// //
// Returns the new paste key if the fileUpload was successfully added to the // Returns the new paste key if the fileUpload was successfully added to the
// database // database
func shortenFileUploadID(tx *bolt.Tx, id uuid.UUID) (*paste, error) { func shortenFileUploadID(tx *bolt.Tx, id uuid.UUID) (*db.Paste, error) {
return shorten(tx, pasteTypeFileUpload, id[:]) return shorten(tx, db.PasteTypeFileUpload, id[:])
} }
// Add a new URL to the database // Add a new URL to the database
// //
// Returns the new paste key if the url was successfully shortened // Returns the new paste key if the url was successfully shortened
func shortenURL(tx *bolt.Tx, userURL *url.URL) (*paste, error) { func shortenURL(tx *bolt.Tx, userURL *url.URL) (*db.Paste, error) {
return shorten(tx, pasteTypeRedirect, []byte(userURL.String())) return shorten(tx, db.PasteTypeRedirect, []byte(userURL.String()))
} }
// Add a paste (of any kind) to the database with arbitrary content. // Add a paste (of any kind) to the database with arbitrary content.
func shorten(tx *bolt.Tx, ty pasteType, content []byte) (*paste, error) { func shorten(tx *bolt.Tx, ty db.PasteType, content []byte) (*db.Paste, error) {
// Generate the paste key // Generate the paste key
pasteKey, err := generatePasteKey(tx) pasteKey, err := db.GeneratePasteKey(tx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "generating paste key") return nil, errors.Wrap(err, "generating paste key")
} }
// Also generate a deleteToken // Also generate a deleteToken
deleteToken, err := generateDeleteToken() deleteToken, err := db.GenerateDeleteToken()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "generating delete token") return nil, errors.Wrap(err, "generating delete token")
} }
// Store the new key // Store the new key
p := paste{ p := db.Paste{
Type: ty, Type: ty,
State: pasteStatePresent, State: db.PasteStatePresent,
Content: content, Content: content,
Key: pasteKey, Key: pasteKey,
DeleteToken: deleteToken, DeleteToken: deleteToken,
TimeCreated: time.Now().UTC(), TimeCreated: time.Now().UTC(),
} }
if err := p.save(tx); err != nil { if err := p.Save(tx); err != nil {
return nil, err return nil, err
} }
return &p, nil return &p, nil

147
internal/db/db.go Normal file
View File

@ -0,0 +1,147 @@
package db
import (
"fmt"
"log"
"time"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
gobmarsh "gitea.hashru.nl/dsprenkels/rushlink/pkg/gobmarsh"
)
// Database is the main rushlink database type.
//
// Open a database using DB.Open() and close it in the end using DB.Close().
// Only one instance of DB should exist in a program at any moment.
type Database struct {
Bolt *bolt.DB
}
// CurrentMigrateVersion holds the current "migrate version".
//
// If we alter the database format, we bump this number and write a new
// database migration in migrate().
const CurrentMigrateVersion = 2
// BucketConf holds the name for the "configuration" bucket.
//
// This bucket holds the database version, secret site-wide keys, etc.
const BucketConf = "conf"
// BucketPastes holds the name for the pastes bucket.
const BucketPastes = "pastes"
// BucketFileUpload holds the name for the file-upload bucket.
const BucketFileUpload = "fileUpload"
// KeyMigrateVersion stores the current migration version. If this value is less than
// CurrentMigrateVersion, the database has to be migrated.
const KeyMigrateVersion = "migrate_version"
// OpenDB opens a database file located at path.
func OpenDB(path string) (*Database, error) {
if path == "" {
return nil, errors.New("database not set")
}
db, err := bolt.Open(path, 0666, &bolt.Options{Timeout: 1 * time.Second})
if err != nil {
return nil, errors.Wrapf(err, "failed to open database at '%v'", path)
}
if err := db.Update(migrate); err != nil {
return nil, err
}
return &Database{db}, nil
}
// Close the bolt database
func (db *Database) Close() error {
if db == nil {
panic("no open database")
}
return db.Close()
}
// Initialize and migrate the database to the current version
func migrate(tx *bolt.Tx) error {
dbVersion, err := dbVersion(tx)
if err != nil {
return err
}
// Migrate the database to version 1
if dbVersion < 1 {
log.Println("migrating database to version 1")
// Create conf bucket
_, err := tx.CreateBucket([]byte(BucketConf))
if err != nil {
return err
}
// Create paste bucket
_, err = tx.CreateBucket([]byte(BucketPastes))
if err != nil {
return err
}
// Update the version number
if err := setDBVersion(tx, 1); err != nil {
return err
}
}
if dbVersion < 2 {
log.Println("migrating database to version 2")
// Create fileUpload bucket
_, err := tx.CreateBucket([]byte(BucketFileUpload))
if err != nil {
return err
}
// Update the version number
if err := setDBVersion(tx, 2); err != nil {
return err
}
}
return nil
}
// Get the current migrate version from the database
func dbVersion(tx *bolt.Tx) (int, error) {
conf := tx.Bucket([]byte(BucketConf))
if conf == nil {
return 0, nil
}
dbVersionBytes := conf.Get([]byte(KeyMigrateVersion))
if dbVersionBytes == nil {
return 0, nil
}
// Version was already stored
var dbVersion int
if err := gobmarsh.Unmarshal(dbVersionBytes, &dbVersion); err != nil {
return 0, err
}
if dbVersion == 0 {
return 0, fmt.Errorf("database version is invalid (%v)", dbVersion)
}
if dbVersion > CurrentMigrateVersion {
return 0, fmt.Errorf("database version is too recent (%v > %v)", dbVersion, CurrentMigrateVersion)
}
return dbVersion, nil
}
// Update the current migrate version in the database
func setDBVersion(tx *bolt.Tx, version int) error {
conf, err := tx.CreateBucketIfNotExists([]byte(BucketConf))
if err != nil {
return err
}
versionBytes, err := gobmarsh.Marshal(version)
if err != nil {
return err
}
return conf.Put([]byte(KeyMigrateVersion), versionBytes)
}

184
internal/db/fileupload.go Normal file
View File

@ -0,0 +1,184 @@
package db
import (
"encoding/hex"
"hash/crc32"
"io"
"net/url"
"os"
"path"
"github.com/google/uuid"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
gobmarsh "gitea.hashru.nl/dsprenkels/rushlink/pkg/gobmarsh"
)
// Use the Castagnoli checksum because of the acceleration on Intel CPUs
var checksumTable = crc32.MakeTable(crc32.Castagnoli)
// FileStore holds the path to a file storage location.
type FileStore struct {
path string
}
// FileUploadState determines the current state of a FileUpload object.
type FileUploadState int
// FileUpload models an uploaded file.
type FileUpload struct {
State FileUploadState
ID uuid.UUID
FileName string
ContentType string
Checksum uint32
}
const (
dirMode os.FileMode = 0750
fileMode os.FileMode = 0640
)
const (
// FileUploadStateUndef is an undefined FileUpload.
FileUploadStateUndef FileUploadState = 0
// FileUploadStatePresent denotes the normal (existing) state.
FileUploadStatePresent FileUploadState = 1
// FileUploadStateDeleted denotes a deleted state.
FileUploadStateDeleted FileUploadState = 2
)
func (t FileUploadState) String() string {
switch t {
case FileUploadStateUndef:
return "unknown"
case FileUploadStatePresent:
return "present"
case FileUploadStateDeleted:
return "deleted"
default:
return "invalid"
}
}
// OpenFileStore opens the file storage at path.
func OpenFileStore(path string) (*FileStore, error) {
if path == "" {
return nil, errors.New("file-store not set")
}
// Try to create the file store directory if it does not yet exist
if err := os.MkdirAll(path, dirMode); err != nil {
return nil, errors.Wrap(err, "creating file store directory")
}
return &FileStore{path[:]}, nil
}
// NewFileUpload creates a new FileUpload object.
func NewFileUpload(fs *FileStore, r io.Reader, fileName string, contentType string) (*FileUpload, error) {
id, err := uuid.NewRandom()
if err != nil {
return nil, errors.Wrap(err, "generating UUID")
}
filePath := fs.FilePath(id, fileName)
if err := os.Mkdir(path.Dir(filePath), dirMode); err != nil {
return nil, errors.Wrap(err, "creating file dir")
}
file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, fileMode)
if err != nil {
return nil, errors.Wrap(err, "opening file")
}
defer file.Close()
hash := crc32.New(checksumTable)
tee := io.TeeReader(r, hash)
_, err = io.Copy(file, tee)
if err != nil {
return nil, errors.Wrap(err, "writing to file")
}
fu := &FileUpload{
State: FileUploadStatePresent,
ID: id,
FileName: fileName,
ContentType: contentType,
Checksum: hash.Sum32(),
}
return fu, nil
}
func (fs *FileStore) Path() string {
return fs.path
}
func (fs *FileStore) FilePath(id uuid.UUID, fileName string) string {
if fs.path == "" {
panic("fileStoreDir called while the file store path has not been set")
}
return path.Join(fs.path, hex.EncodeToString(id[:]), fileName)
}
func GetFileUpload(tx *bolt.Tx, id uuid.UUID) (*FileUpload, error) {
bucket := tx.Bucket([]byte(BucketFileUpload))
if bucket == nil {
return nil, errors.Errorf("bucket %v does not exist", BucketFileUpload)
}
storedBytes := bucket.Get(id[:])
if storedBytes == nil {
return nil, nil
}
fu := &FileUpload{}
err := gobmarsh.Unmarshal(storedBytes, fu)
return fu, err
}
// Save saves a FileUpload in the database.
func (fu *FileUpload) Save(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(BucketFileUpload))
if bucket == nil {
return errors.Errorf("bucket %v does not exist", BucketFileUpload)
}
buf, err := gobmarsh.Marshal(fu)
if err != nil {
return errors.Wrap(err, "encoding for database failed")
}
if err := bucket.Put(fu.ID[:], buf); err != nil {
return errors.Wrap(err, "database transaction failed")
}
return nil
}
// Delete deletes a FileUpload from the database.
func (fu *FileUpload) Delete(tx *bolt.Tx, fs *FileStore) error {
// Remove the file in the backend
filePath := fs.FilePath(fu.ID, fu.FileName)
if err := os.Remove(filePath); err != nil {
return err
}
// Update the file in the server
if err := (&FileUpload{
ID: fu.ID,
State: FileUploadStateDeleted,
}).Save(tx); err != nil {
return err
}
// Cleanup the parent directory
wrap := "deletion succeeded, but removing the file directory has failed"
return errors.Wrap(os.Remove(path.Dir(filePath)), wrap)
}
// URL returns the URL for the FileUpload.
func (fu *FileUpload) URL() *url.URL {
rawurl := "/uploads/" + hex.EncodeToString(fu.ID[:]) + "/" + fu.FileName
urlParse, err := url.Parse(rawurl)
if err != nil {
panic("could not construct /uploads/ url")
}
return urlParse
}

View File

@ -1,97 +1,106 @@
package rushlink package db
import ( import (
"crypto/rand" "crypto/rand"
"encoding/base64"
"encoding/hex" "encoding/hex"
"net/url" "net/url"
"strings" "strings"
"time" "time"
gobmarsh "gitea.hashru.nl/dsprenkels/rushlink/pkg/gobmarsh"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
) )
type pasteType int type PasteType int
type pasteState int type PasteState int
type paste struct { type Paste struct {
Type pasteType Type PasteType
State pasteState State PasteState
Content []byte Content []byte
Key string Key string
DeleteToken string DeleteToken string
TimeCreated time.Time TimeCreated time.Time
} }
// Note: we use iota here. That means removals of pasteType* are not allowed, // ReservedPasteKeys keys are designated reserved, and will not be randomly chosen
var ReservedPasteKeys = []string{"xd42", "example"}
// Note: we use iota here. That means removals of PasteType* are not allowed,
// because this changes the value of the constant. Please add the comment // because this changes the value of the constant. Please add the comment
// "// deprecated" if you want to remove the constant. Additions are only // "// deprecated" if you want to remove the constant. Additions are only
// allowed at the bottom of this block, for the same reason. // allowed at the bottom of this block, for the same reason.
const ( const (
pasteTypeUndef pasteType = iota PasteTypeUndef PasteType = iota
pasteTypePaste PasteTypePaste
pasteTypeRedirect PasteTypeRedirect
pasteTypeFileUpload PasteTypeFileUpload
) )
// Note: we use iota here. See the comment above pasteType* // Note: we use iota here. See the comment above PasteType*
const ( const (
pasteStateUndef pasteState = iota PasteStateUndef PasteState = iota
pasteStatePresent PasteStatePresent
pasteStateDeleted PasteStateDeleted
) )
func (t pasteType) String() string { // Base64 encoding and decoding
var base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
var base64Encoder = base64.RawURLEncoding.WithPadding(base64.NoPadding)
func (t PasteType) String() string {
switch t { switch t {
case pasteTypeUndef: case PasteTypeUndef:
return "unknown" return "unknown"
case pasteTypePaste: case PasteTypePaste:
return "paste" return "paste"
case pasteTypeRedirect: case PasteTypeRedirect:
return "redirect" return "redirect"
case pasteTypeFileUpload: case PasteTypeFileUpload:
return "file" return "file"
default: default:
return "invalid" return "invalid"
} }
} }
func (t pasteState) String() string { func (t PasteState) String() string {
switch t { switch t {
case pasteStateUndef: case PasteStateUndef:
return "unknown" return "unknown"
case pasteStatePresent: case PasteStatePresent:
return "present" return "present"
case pasteStateDeleted: case PasteStateDeleted:
return "deleted" return "deleted"
default: default:
return "invalid" return "invalid"
} }
} }
// Retrieve a paste from the database // GetPaste retrieves a paste from the database.
func getPaste(tx *bolt.Tx, key string) (*paste, error) { func GetPaste(tx *bolt.Tx, key string) (*Paste, error) {
pastesBucket := tx.Bucket([]byte(BUCKET_PASTES)) pastesBucket := tx.Bucket([]byte(BucketPastes))
if pastesBucket == nil { if pastesBucket == nil {
return nil, errors.Errorf("bucket %v does not exist", BUCKET_PASTES) return nil, errors.Errorf("bucket %v does not exist", BucketPastes)
} }
storedBytes := pastesBucket.Get([]byte(key)) storedBytes := pastesBucket.Get([]byte(key))
if storedBytes == nil { if storedBytes == nil {
return nil, nil return nil, nil
} }
p := &paste{} p := &Paste{}
err := Unmarshal(storedBytes, p) err := gobmarsh.Unmarshal(storedBytes, p)
return p, err return p, err
} }
func (p *paste) save(tx *bolt.Tx) error { func (p *Paste) Save(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(BUCKET_PASTES)) bucket := tx.Bucket([]byte(BucketPastes))
if bucket == nil { if bucket == nil {
return errors.Errorf("bucket %v does not exist", BUCKET_PASTES) return errors.Errorf("bucket %v does not exist", BucketPastes)
} }
buf, err := Marshal(p) buf, err := gobmarsh.Marshal(p)
if err != nil { if err != nil {
return errors.Wrap(err, "encoding for database failed") return errors.Wrap(err, "encoding for database failed")
} }
@ -101,39 +110,39 @@ func (p *paste) save(tx *bolt.Tx) error {
return nil return nil
} }
func (p *paste) delete(tx *bolt.Tx) error { func (p *Paste) Delete(tx *bolt.Tx, fs *FileStore) error {
// Remove the (maybe) attached file // Remove the (maybe) attached file
if p.Type == pasteTypeFileUpload { if p.Type == PasteTypeFileUpload {
fuID, err := uuid.FromBytes(p.Content) fuID, err := uuid.FromBytes(p.Content)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to parse uuid") return errors.Wrap(err, "failed to parse uuid")
} }
fu, err := getFileUpload(tx, fuID) fu, err := GetFileUpload(tx, fuID)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to find file in database") return errors.Wrap(err, "failed to find file in database")
} }
if err := fu.delete(tx); err != nil { if err := fu.Delete(tx, fs); err != nil {
return errors.Wrap(err, "failed to remove file") return errors.Wrap(err, "failed to remove file")
} }
} }
// Replace the old paste with a new empty paste // Replace the old paste with a new empty paste
p.Type = pasteTypeUndef p.Type = PasteTypeUndef
p.State = pasteStateDeleted p.State = PasteStateDeleted
p.Content = []byte{} p.Content = []byte{}
if err := p.save(tx); err != nil { if err := p.Save(tx); err != nil {
return errors.Wrap(err, "failed to delete paste in database") return errors.Wrap(err, "failed to delete paste in database")
} }
return nil return nil
} }
// Get the URL from this paste. // RedirectURL returns the URL from this paste.
// //
// This function assumes that the paste is valid. If the paste struct is // This function assumes that the paste is valid. If the paste struct is
// corrupted in some way, this function will panic. // corrupted in some way, this function will panic.
func (p *paste) redirectURL() *url.URL { func (p *Paste) RedirectURL() *url.URL {
if p.Type != pasteTypeRedirect { if p.Type != PasteTypeRedirect {
panic("expected p.Type to be pasteTypeRedirect") panic("expected p.Type to be PasteTypeRedirect")
} }
rawurl := string(p.Content) rawurl := string(p.Content)
urlParse, err := url.Parse(rawurl) urlParse, err := url.Parse(rawurl)
@ -143,12 +152,13 @@ func (p *paste) redirectURL() *url.URL {
return urlParse return urlParse
} }
// Generate a key until it is not in the database, this occurs in O(log N), // GeneratePasteKey generates a key until it is not in the database, the
// where N is the amount of keys stored in the url-shorten database. // running time of this function is in O(log N), where N is the amount of
func generatePasteKey(tx *bolt.Tx) (string, error) { // keys stored in the url-shorten database.
pastesBucket := tx.Bucket([]byte(BUCKET_PASTES)) func GeneratePasteKey(tx *bolt.Tx) (string, error) {
pastesBucket := tx.Bucket([]byte(BucketPastes))
if pastesBucket == nil { if pastesBucket == nil {
return "", errors.Errorf("bucket %v does not exist", BUCKET_PASTES) return "", errors.Errorf("bucket %v does not exist", BucketPastes)
} }
epoch := 0 epoch := 0
@ -216,7 +226,7 @@ func generatePasteKeyInner(epoch int) (string, error) {
return string(urlKey), nil return string(urlKey), nil
} }
func generateDeleteToken() (string, error) { func GenerateDeleteToken() (string, error) {
var deleteToken [16]byte var deleteToken [16]byte
_, err := rand.Read(deleteToken[:]) _, err := rand.Read(deleteToken[:])
if err != nil { if err != nil {

View File

@ -5,6 +5,8 @@ import (
"net/http" "net/http"
"time" "time"
"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -13,7 +15,7 @@ import (
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
) )
func StartMetricsServer(addr string) { func StartMetricsServer(addr string, db *db.Database) {
var ( var (
_ = promauto.NewGaugeFunc(prometheus.GaugeOpts{ _ = promauto.NewGaugeFunc(prometheus.GaugeOpts{
Namespace: "rushlink", Namespace: "rushlink",
@ -22,7 +24,7 @@ func StartMetricsServer(addr string) {
Help: "The current amount of pastes in the database.", Help: "The current amount of pastes in the database.",
}, func() float64 { }, func() float64 {
var metric float64 var metric float64
if err := DB.View(func(tx *bolt.Tx) error { if err := db.Bolt.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte("pastes")) bucket := tx.Bucket([]byte("pastes"))
if bucket == nil { if bucket == nil {
return errors.New("bucket 'pastes' could not be found") return errors.New("bucket 'pastes' could not be found")

View File

@ -7,6 +7,7 @@ import (
"encoding/gob" "encoding/gob"
) )
// Marshal serializes the value in v to a byte buffer.
func Marshal(v interface{}) ([]byte, error) { func Marshal(v interface{}) ([]byte, error) {
b := new(bytes.Buffer) b := new(bytes.Buffer)
err := gob.NewEncoder(b).Encode(v) err := gob.NewEncoder(b).Encode(v)
@ -16,6 +17,7 @@ func Marshal(v interface{}) ([]byte, error) {
return b.Bytes(), nil return b.Bytes(), nil
} }
// Unmarshal deserializes the data in data into the object in v.
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
b := bytes.NewBuffer(data) b := bytes.NewBuffer(data)
return gob.NewDecoder(b).Decode(v) return gob.NewDecoder(b).Decode(v)

View File

@ -7,9 +7,15 @@ import (
"runtime/debug" "runtime/debug"
"time" "time"
"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
type rushlink struct {
db *db.Database
fs *db.FileStore
}
func recoveryMiddleware(next http.Handler) http.Handler { func recoveryMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() { defer func() {
@ -31,18 +37,23 @@ func recoveryMiddleware(next http.Handler) http.Handler {
}) })
} }
func StartMainServer(addr string) { func StartMainServer(addr string, db *db.Database, fs *db.FileStore) {
rl := rushlink{
db: db,
fs: fs,
}
// Initialize Gorilla router // Initialize Gorilla router
router := mux.NewRouter() router := mux.NewRouter()
router.Use(recoveryMiddleware) router.Use(recoveryMiddleware)
router.HandleFunc("/", indexGetHandler).Methods("GET") router.HandleFunc("/", rl.indexGetHandler).Methods("GET")
router.HandleFunc("/", newPasteHandler).Methods("POST") router.HandleFunc("/", rl.newPasteHandler).Methods("POST")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", viewPasteHandler).Methods("GET") router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", rl.viewPasteHandler).Methods("GET")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/nr", viewPasteHandlerNoRedirect).Methods("GET") router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/nr", rl.viewPasteHandlerNoRedirect).Methods("GET")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/meta", viewPasteHandlerMeta).Methods("GET") router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/meta", rl.viewPasteHandlerMeta).Methods("GET")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", deletePasteHandler).Methods("DELETE") router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", rl.deletePasteHandler).Methods("DELETE")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/delete", deletePasteHandler).Methods("POST") router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/delete", rl.deletePasteHandler).Methods("POST")
router.HandleFunc("/uploads/{id:[A-Za-z0-9-_]+}/{filename:.+}", uploadFileGetHandler).Methods("GET") router.HandleFunc("/uploads/{id:[A-Za-z0-9-_]+}/{filename:.+}", rl.uploadFileGetHandler).Methods("GET")
srv := &http.Server{ srv := &http.Server{
Handler: router, Handler: router,