Refactor main into subdirs

This commit is contained in:
Daan Sprenkels 2019-09-19 21:29:25 +02:00
parent 60fd92c956
commit a4ad82d6dd
5 changed files with 66 additions and 61 deletions

View File

@ -3,63 +3,20 @@ package main
import ( import (
"flag" "flag"
"log" "log"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"gitea.hashru.nl/dsprenkels/rushlink/db" "gitea.hashru.nl/dsprenkels/rushlink/db"
"gitea.hashru.nl/dsprenkels/rushlink/handlers" "gitea.hashru.nl/dsprenkels/rushlink/handlers"
"gitea.hashru.nl/dsprenkels/rushlink/metrics" "gitea.hashru.nl/dsprenkels/rushlink/metrics"
) )
type ParsedArguments struct {
databaseName string
}
var appConfig ParsedArguments
func main() { func main() {
// Parse the arguments and construct the ParsedArguments
appConfigRef, err := parseArguments()
if err != nil {
log.Fatal(err)
}
appConfig = *appConfigRef
db.Init(appConfig.databaseName)
// Export prometheus metrics
go metrics.StartMetricsServer()
// Initialize Gorilla router
router := mux.NewRouter()
router.HandleFunc("/", handlers.IndexGetHandler).Methods("GET")
router.HandleFunc("/", handlers.IndexPostHandler).Methods("POST")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", handlers.PasteGetHandler).Methods("GET")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/nr", handlers.PasteGetHandlerNoRedirect).Methods("GET")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/meta", handlers.PasteGetHandlerMeta).Methods("GET")
// Start the server
srv := &http.Server{
Handler: router,
Addr: "127.0.0.1:8000",
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
}
log.Fatal(srv.ListenAndServe())
}
// Parse the input arguments and return the initialized application config struct
func parseArguments() (*ParsedArguments, error) {
config := ParsedArguments{}
flag.StringVar(&config.databaseName, "database", "", "Location of the database file")
flag.Parse() flag.Parse()
if config.databaseName == "" { if err := db.Open(); err != nil {
return nil, errors.New("database not set") log.Fatalln(err)
} }
return &config, nil defer db.Close()
go metrics.StartMetricsServer()
handlers.StartMainServer()
} }

View File

@ -1,6 +1,7 @@
package db package db
import ( import (
"flag"
"fmt" "fmt"
"log" "log"
"time" "time"
@ -11,6 +12,7 @@ import (
"gitea.hashru.nl/dsprenkels/rushlink/gobmarsh" "gitea.hashru.nl/dsprenkels/rushlink/gobmarsh"
) )
var path = flag.String("database", "", "Location of the database file")
var DB *bolt.DB var DB *bolt.DB
// The current database version // The current database version
@ -31,15 +33,35 @@ const BUCKET_PASTES = "pastes"
const KEY_MIGRATE_VERSION = "migrate_version" const KEY_MIGRATE_VERSION = "migrate_version"
// Open the bolt database // Open the bolt database
func Init(name string) error { func Open() error {
if *path == "" {
return errors.New("database not set")
}
var err error var err error
DB, err = bolt.Open(name, 0666, &bolt.Options{Timeout: 1 * time.Second}) DB, err = bolt.Open(*path, 0666, &bolt.Options{Timeout: 1 * time.Second})
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to open database at '%v'", name) return errors.Wrapf(err, "failed to open database at '%v'", *path)
} }
return DB.Update(migrateDatabase) return DB.Update(migrateDatabase)
} }
// Close the bolt database
func Close() error {
if DB == nil {
panic("no open database")
}
return DB.Close()
}
// Get the database path (as was set by flags)
func Path() string {
if path == nil {
return ""
}
return *path
}
// Initialize and migrate the database to the current version // Initialize and migrate the database to the current version
func migrateDatabase(tx *bolt.Tx) error { func migrateDatabase(tx *bolt.Tx) error {
dbVersion, err := dbVersion(tx) dbVersion, err := dbVersion(tx)

View File

@ -73,11 +73,11 @@ func (t PasteState) String() string {
} }
} }
func IndexGetHandler(w http.ResponseWriter, r *http.Request) { func indexGetHandler(w http.ResponseWriter, r *http.Request) {
Render(w, r, "index", nil) Render(w, r, "index", nil)
} }
func IndexPostHandler(w http.ResponseWriter, r *http.Request) { func indexPostHandler(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(50 * 1000 * 1000); err != nil { if err := r.ParseMultipartForm(50 * 1000 * 1000); err != nil {
log.Printf("error: %v\n", err) log.Printf("error: %v\n", err)
RenderInternalServerError(w, r, err) RenderInternalServerError(w, r, err)
@ -99,18 +99,18 @@ func IndexPostHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
ShortenPostHandler(w, r) shortenPostHandler(w, r)
} }
func PasteGetHandler(w http.ResponseWriter, r *http.Request) { func pasteGetHandler(w http.ResponseWriter, r *http.Request) {
pasteGetHandlerInner(w, r, false, false) pasteGetHandlerInner(w, r, false, false)
} }
func PasteGetHandlerNoRedirect(w http.ResponseWriter, r *http.Request) { func pasteGetHandlerNoRedirect(w http.ResponseWriter, r *http.Request) {
pasteGetHandlerInner(w, r, true, false) pasteGetHandlerInner(w, r, true, false)
} }
func PasteGetHandlerMeta(w http.ResponseWriter, r *http.Request) { func pasteGetHandlerMeta(w http.ResponseWriter, r *http.Request) {
pasteGetHandlerInner(w, r, false, true) pasteGetHandlerInner(w, r, false, true)
} }
@ -169,7 +169,7 @@ func pasteGetHandlerInner(w http.ResponseWriter, r *http.Request, noRedirect, sh
} }
} }
func ShortenPostHandler(w http.ResponseWriter, r *http.Request) { func shortenPostHandler(w http.ResponseWriter, r *http.Request) {
rawurl := r.PostForm.Get("shorten") rawurl := r.PostForm.Get("shorten")
userURL, err := url.ParseRequestURI(rawurl) userURL, err := url.ParseRequestURI(rawurl)
if err != nil { if err != nil {

27
handlers/router.go Normal file
View File

@ -0,0 +1,27 @@
package handlers
import (
"log"
"net/http"
"time"
"github.com/gorilla/mux"
)
func StartMainServer() {
// Initialize Gorilla router
router := mux.NewRouter()
router.HandleFunc("/", indexGetHandler).Methods("GET")
router.HandleFunc("/", indexPostHandler).Methods("POST")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}", pasteGetHandler).Methods("GET")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/nr", pasteGetHandlerNoRedirect).Methods("GET")
router.HandleFunc("/{key:[A-Za-z0-9-_]{4,}}/meta", pasteGetHandlerMeta).Methods("GET")
srv := &http.Server{
Handler: router,
Addr: "127.0.0.1:8000",
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
}
log.Fatal(srv.ListenAndServe())
}

View File

@ -69,8 +69,7 @@ func mustMatch(pattern, name string) bool {
} }
func parseFail(tmplName string, err error) { func parseFail(tmplName string, err error) {
err = errors.Wrapf(err, "parsing of %v failed", tmplName) panic(errors.Wrapf(err, "parsing of %v failed", tmplName))
panic(err)
} }
func Render(w http.ResponseWriter, r *http.Request, tmplName string, data map[string]interface{}) { func Render(w http.ResponseWriter, r *http.Request, tmplName string, data map[string]interface{}) {