diff --git a/handlers_test.go b/handlers_test.go new file mode 100644 index 0000000..ed2013d --- /dev/null +++ b/handlers_test.go @@ -0,0 +1,112 @@ +package rushlink + +import ( + "bytes" + "io/ioutil" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + + "gitea.hashru.nl/dsprenkels/rushlink/internal/db" + "github.com/gorilla/mux" +) + +// createTemporaryRouter initializes a rushlink instance, with temporary +// filestore and database. +// +// It will use testing.T.Cleanup to cleanup after itself. +func createTemporaryRouter(t *testing.T) *mux.Router { + tempDir, err := ioutil.TempDir("", "rushlink-tmp-*") + if err != nil { + t.Fatalf("creating temporary directory: %s\n", err) + } + t.Cleanup(func() { + os.RemoveAll(tempDir) + }) + + fileStore, err := db.OpenFileStore(filepath.Join(tempDir, "filestore")) + if err != nil { + t.Fatalf("opening temporary filestore: %s\n", err) + } + databasePath := filepath.Join(tempDir, "rushlink.db") + database, err := db.OpenDB(databasePath, fileStore) + if err != nil { + t.Fatalf("opening temporary database: %s\n", err) + } + t.Cleanup(func() { + if err := database.Close(); err != nil { + t.Errorf("closing database: %d\n", err) + } + }) + + // *.invalid. is guaranteed not to exist (RFC 6761). + rootURL, err := url.Parse("https://rushlink.invalid") + if err != nil { + t.Fatalf("parsing URL: %s\n", err) + } + + rl := rushlink{ + db: database, + fs: fileStore, + rootURL: rootURL, + } + return CreateMainRouter(&rl) +} + +// checkStatusCode checks whether the status code from a recorded response is equal +// to some response code. +func checkStatusCode(t *testing.T, rr *httptest.ResponseRecorder, code int) { + if actual := rr.Code; actual != code { + t.Logf("request body:\n%v\n", rr.Body.String()) + t.Fatalf("handler returned wrong status code: got %v want %v\n", + actual, code) + } +} + +// checkLocationHeader checks whether the status code from a recorded response is equal +// to some expected URL. +func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected string) { + location := rr.Header().Get("Location") + if location != expected { + t.Fatalf("handler returned bad redirect location: got %v want %v", location, expected) + } +} + +func TestIssue43(t *testing.T) { + srv := createTemporaryRouter(t) + + // Put a URL with a fragment identifier into the database. + var body bytes.Buffer + form := multipart.NewWriter(&body) + form.WriteField("shorten", "https://example.com#fragment") + form.Close() + req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Content-Type", form.FormDataContentType()) + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + checkStatusCode(t, rr, http.StatusOK) + rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0] + pasteURL, err := url.Parse(rawURL) + if err != nil { + t.Fatal(err) + } + + // Check if the URL was encoded correctly. + req, err = http.NewRequest("GET", pasteURL.Path, nil) + if err != nil { + t.Fatal(err) + } + req = mux.SetURLVars(req, map[string]string{"key": pasteURL.Path[1:]}) + rr = httptest.NewRecorder() + srv.ServeHTTP(rr, req) + checkStatusCode(t, rr, http.StatusTemporaryRedirect) + checkLocationHeader(t, rr, "https://example.com#fragment") +} diff --git a/router.go b/router.go index 64cdfd1..9def677 100644 --- a/router.go +++ b/router.go @@ -79,23 +79,8 @@ func (w *statusResponseWriter) WriteHeader(statusCode int) { w.Inner.WriteHeader(statusCode) } -// StartMainServer starts the main http server listening on addr. -func StartMainServer(addr string, db *db.Database, fs *db.FileStore, rawRootURL string) { - var rootURL *url.URL - if rawRootURL != "" { - var err error - rootURL, err = url.Parse(rawRootURL) - if err != nil { - log.Fatalln(errors.Wrap(err, "could not parse rootURL flag")) - } - } - rl := rushlink{ - db: db, - fs: fs, - rootURL: rootURL, - } - - // Initialize Gorilla router +// CreateMainRouter creates the main Gorilla router for the application. +func CreateMainRouter(rl *rushlink) *mux.Router { router := mux.NewRouter() router.Use(rl.recoveryMiddleware) router.Use(rl.metricsMiddleware) @@ -114,9 +99,27 @@ func StartMainServer(addr string, db *db.Database, fs *db.FileStore, rawRootURL router.HandleFunc("/"+urlKeyWithExtExpr, rl.deletePasteHandler).Methods("DELETE") router.HandleFunc("/"+urlKeyExpr+"/delete", rl.deletePasteHandler).Methods("POST") router.HandleFunc("/"+urlKeyWithExtExpr+"/delete", rl.deletePasteHandler).Methods("POST") + return router +} + +// StartMainServer starts the main http server listening on addr. +func StartMainServer(addr string, db *db.Database, fs *db.FileStore, rawRootURL string) { + var rootURL *url.URL + if rawRootURL != "" { + var err error + rootURL, err = url.Parse(rawRootURL) + if err != nil { + log.Fatalln(errors.Wrap(err, "could not parse rootURL flag")) + } + } + rl := rushlink{ + db: db, + fs: fs, + rootURL: rootURL, + } srv := &http.Server{ - Handler: router, + Handler: CreateMainRouter(&rl), Addr: addr, WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second,