package rushlink

import (
	"bytes"
	"io/ioutil"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
	"github.com/gorilla/mux"
)

// createTemporaryRouter initializes a rushlink instance, with temporary
// filestore and database, and returns the main router for it.
//
// Any setup failure aborts the calling test via t.Fatalf.
// It will use testing.T.Cleanup to cleanup after itself.
func createTemporaryRouter(t *testing.T) *mux.Router {
	// Attribute setup failures to the calling test rather than this helper.
	t.Helper()

	tempDir, err := ioutil.TempDir("", "rushlink-tmp-*")
	if err != nil {
		t.Fatalf("creating temporary directory: %s\n", err)
	}
	// Best-effort removal; a leftover temporary directory does not fail the test.
	t.Cleanup(func() {
		os.RemoveAll(tempDir)
	})

	fileStore, err := db.OpenFileStore(filepath.Join(tempDir, "filestore"))
	if err != nil {
		t.Fatalf("opening temporary filestore: %s\n", err)
	}

	databasePath := filepath.Join(tempDir, "rushlink.db")
	database, err := db.OpenDB(databasePath, fileStore)
	if err != nil {
		t.Fatalf("opening temporary database: %s\n", err)
	}
	// Cleanup functions run last-added-first, so the database is closed
	// before the temporary directory is removed.
	t.Cleanup(func() {
		if err := database.Close(); err != nil {
			t.Errorf("closing database: %s\n", err)
		}
	})

	// *.invalid. is guaranteed not to exist (RFC 6761).
	rootURL, err := url.Parse("https://rushlink.invalid")
	if err != nil {
		t.Fatalf("parsing URL: %s\n", err)
	}

	rl := rushlink{
		db:      database,
		fs:      fileStore,
		rootURL: rootURL,
	}
	return CreateMainRouter(&rl)
}

// checkStatusCode checks whether the status code from a recorded response is equal
// to some response code.
//
// On a mismatch it logs the recorded response body and fails the test immediately.
func checkStatusCode(t *testing.T, rr *httptest.ResponseRecorder, code int) {
	t.Helper()
	if actual := rr.Code; actual != code {
		// rr.Body holds what the handler wrote, i.e. the response body.
		t.Logf("response body:\n%v\n", rr.Body.String())
		t.Fatalf("handler returned wrong status code: got %v want %v\n", actual, code)
	}
}

// checkLocationHeader checks whether the Location header from a recorded response
// is equal to some expected URL.
func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected string) { location := rr.Header().Get("Location") if location != expected { t.Fatalf("handler returned bad redirect location: got %v want %v", location, expected) } } func TestIssue43(t *testing.T) { srv := createTemporaryRouter(t) // Put a URL with a fragment identifier into the database. var body bytes.Buffer form := multipart.NewWriter(&body) form.WriteField("shorten", "https://example.com#fragment") form.Close() req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) if err != nil { t.Fatal(err) } req.Header.Add("Content-Type", form.FormDataContentType()) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusOK) rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0] pasteURL, err := url.Parse(rawURL) if err != nil { t.Fatal(err) } // Check if the URL was encoded correctly. req, err = http.NewRequest("GET", pasteURL.Path, nil) if err != nil { t.Fatal(err) } req = mux.SetURLVars(req, map[string]string{"key": pasteURL.Path[1:]}) rr = httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusTemporaryRedirect) checkLocationHeader(t, rr, "https://example.com#fragment") }