From a4ad82d6dd1d8057644854c7d4099643d31c59c9 Mon Sep 17 00:00:00 2001 From: Daan Sprenkels Date: Thu, 19 Sep 2019 21:29:25 +0200 Subject: [PATCH] Refactor main into subdirs --- cmd/rushlink/main.go | 55 +++++--------------------------------------- db/db.go | 28 +++++++++++++++++++--- handlers/handlers.go | 14 +++++------ handlers/router.go | 27 ++++++++++++++++++++++ handlers/views.go | 3 +-- 5 files changed, 66 insertions(+), 61 deletions(-) create mode 100644 handlers/router.go diff --git a/cmd/rushlink/main.go b/cmd/rushlink/main.go index e2e1fae..e3d49a0 100644 --- a/cmd/rushlink/main.go +++ b/cmd/rushlink/main.go @@ -3,63 +3,20 @@ package main import ( "flag" "log" - "net/http" - "time" - - "github.com/gorilla/mux" - "github.com/pkg/errors" "gitea.hashru.nl/dsprenkels/rushlink/db" "gitea.hashru.nl/dsprenkels/rushlink/handlers" "gitea.hashru.nl/dsprenkels/rushlink/metrics" ) -type ParsedArguments struct { - databaseName string -} - -var appConfig ParsedArguments - func main() { - // Parse the arguments and construct the ParsedArguments - appConfigRef, err := parseArguments() - if err != nil { - log.Fatal(err) - } - appConfig = *appConfigRef - - db.Init(appConfig.databaseName) - - // Export prometheus metrics - go metrics.StartMetricsServer() - - // Initialize Gorilla router - router := mux.NewRouter() - router.HandleFunc("/", handlers.IndexGetHandler).Methods("GET") - router.HandleFunc("/", handlers.IndexPostHandler).Methods("POST") - router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", handlers.PasteGetHandler).Methods("GET") - router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/nr", handlers.PasteGetHandlerNoRedirect).Methods("GET") - router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/meta", handlers.PasteGetHandlerMeta).Methods("GET") - - // Start the server - srv := &http.Server{ - Handler: router, - Addr: "127.0.0.1:8000", - WriteTimeout: 15 * time.Second, - ReadTimeout: 15 * time.Second, - } - log.Fatal(srv.ListenAndServe()) -} - -// Parse the input arguments and return the initialized application config struct -func parseArguments() (*ParsedArguments, error) { - config := ParsedArguments{} - - flag.StringVar(&config.databaseName, "database", "", "Location of the database file") flag.Parse() - if config.databaseName == "" { - return nil, errors.New("database not set") + if err := db.Open(); err != nil { + log.Fatalln(err) } - return &config, nil + defer db.Close() + + go metrics.StartMetricsServer() + handlers.StartMainServer() } diff --git a/db/db.go b/db/db.go index fdc9842..24506b3 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,7 @@ package db import ( + "flag" "fmt" "log" "time" @@ -11,6 +12,7 @@ import ( "gitea.hashru.nl/dsprenkels/rushlink/gobmarsh" ) +var path = flag.String("database", "", "Location of the database file") var DB *bolt.DB // The current database version @@ -31,15 +33,35 @@ const BUCKET_PASTES = "pastes" const KEY_MIGRATE_VERSION = "migrate_version" // Open the bolt database -func Init(name string) error { +func Open() error { + if *path == "" { + return errors.New("database not set") + } + var err error - DB, err = bolt.Open(name, 0666, &bolt.Options{Timeout: 1 * time.Second}) + DB, err = bolt.Open(*path, 0666, &bolt.Options{Timeout: 1 * time.Second}) if err != nil { - return errors.Wrapf(err, "failed to open database at '%v'", name) + return errors.Wrapf(err, "failed to open database at '%v'", *path) } return DB.Update(migrateDatabase) } +// Close the bolt database +func Close() error { + if DB == nil { + panic("no open database") + } + return DB.Close() +} + +// Get the database path (as was set by flags) +func Path() string { + if path == nil { + return "" + } + return *path +} + // Initialize and migrate the database to the current version func migrateDatabase(tx *bolt.Tx) error { dbVersion, err := dbVersion(tx) diff --git a/handlers/handlers.go b/handlers/handlers.go index 58cb638..2a41f18 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -73,11 +73,11 @@ func (t PasteState) String() string { } } -func IndexGetHandler(w http.ResponseWriter, r *http.Request) { +func indexGetHandler(w http.ResponseWriter, r *http.Request) { Render(w, r, "index", nil) } -func IndexPostHandler(w http.ResponseWriter, r *http.Request) { +func indexPostHandler(w http.ResponseWriter, r *http.Request) { if err := r.ParseMultipartForm(50 * 1000 * 1000); err != nil { log.Printf("error: %v\n", err) RenderInternalServerError(w, r, err) @@ -99,18 +99,18 @@ func IndexPostHandler(w http.ResponseWriter, r *http.Request) { return } - ShortenPostHandler(w, r) + shortenPostHandler(w, r) } -func PasteGetHandler(w http.ResponseWriter, r *http.Request) { +func pasteGetHandler(w http.ResponseWriter, r *http.Request) { pasteGetHandlerInner(w, r, false, false) } -func PasteGetHandlerNoRedirect(w http.ResponseWriter, r *http.Request) { +func pasteGetHandlerNoRedirect(w http.ResponseWriter, r *http.Request) { pasteGetHandlerInner(w, r, true, false) } -func PasteGetHandlerMeta(w http.ResponseWriter, r *http.Request) { +func pasteGetHandlerMeta(w http.ResponseWriter, r *http.Request) { pasteGetHandlerInner(w, r, false, true) } @@ -169,7 +169,7 @@ func pasteGetHandlerInner(w http.ResponseWriter, r *http.Request, noRedirect, sh } } -func ShortenPostHandler(w http.ResponseWriter, r *http.Request) { +func shortenPostHandler(w http.ResponseWriter, r *http.Request) { rawurl := r.PostForm.Get("shorten") userURL, err := url.ParseRequestURI(rawurl) if err != nil { diff --git a/handlers/router.go b/handlers/router.go new file mode 100644 index 0000000..b59c58d --- /dev/null +++ b/handlers/router.go @@ -0,0 +1,27 @@ +package handlers + +import ( + "log" + "net/http" + "time" + + "github.com/gorilla/mux" +) + +func StartMainServer() { + // Initialize Gorilla router + router := mux.NewRouter() + router.HandleFunc("/", indexGetHandler).Methods("GET") + router.HandleFunc("/", indexPostHandler).Methods("POST") + router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", pasteGetHandler).Methods("GET") + router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/nr", pasteGetHandlerNoRedirect).Methods("GET") + router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/meta", pasteGetHandlerMeta).Methods("GET") + + srv := &http.Server{ + Handler: router, + Addr: "127.0.0.1:8000", + WriteTimeout: 15 * time.Second, + ReadTimeout: 15 * time.Second, + } + log.Fatal(srv.ListenAndServe()) +} diff --git a/handlers/views.go b/handlers/views.go index 3ce84ca..0ca6f90 100644 --- a/handlers/views.go +++ b/handlers/views.go @@ -69,8 +69,7 @@ func mustMatch(pattern, name string) bool { } func parseFail(tmplName string, err error) { - err = errors.Wrapf(err, "parsing of %v failed", tmplName) - panic(err) + panic(errors.Wrapf(err, "parsing of %v failed", tmplName)) } func Render(w http.ResponseWriter, r *http.Request, tmplName string, data map[string]interface{}) {