rushlink/handlers_test.go
2020-04-27 13:00:46 +02:00

113 lines
3.0 KiB
Go

package rushlink
import (
"bytes"
"io/ioutil"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
"github.com/gorilla/mux"
)
// createTemporaryRouter initializes a rushlink instance backed by a temporary
// filestore and database, and returns the main router for that instance.
//
// All on-disk state lives below a freshly created temporary directory. The
// directory removal and the database close are registered with
// testing.T.Cleanup, so callers do not have to tear anything down themselves.
// Any failure during setup aborts the calling test through t.Fatalf.
func createTemporaryRouter(t *testing.T) *mux.Router {
	// Attribute failures to the calling test instead of this helper.
	t.Helper()

	tempDir, err := ioutil.TempDir("", "rushlink-tmp-*")
	if err != nil {
		t.Fatalf("creating temporary directory: %s\n", err)
	}
	t.Cleanup(func() {
		// Best effort: a leftover directory should not fail the test, but
		// it is worth mentioning in the test log.
		if err := os.RemoveAll(tempDir); err != nil {
			t.Logf("removing temporary directory: %s\n", err)
		}
	})

	fileStore, err := db.OpenFileStore(filepath.Join(tempDir, "filestore"))
	if err != nil {
		t.Fatalf("opening temporary filestore: %s\n", err)
	}

	databasePath := filepath.Join(tempDir, "rushlink.db")
	database, err := db.OpenDB(databasePath, fileStore)
	if err != nil {
		t.Fatalf("opening temporary database: %s\n", err)
	}
	// Cleanup functions run last-added-first, so the database is closed
	// before the temporary directory is removed.
	t.Cleanup(func() {
		if err := database.Close(); err != nil {
			// %s, not %d: err is an error value, not an integer.
			t.Errorf("closing database: %s\n", err)
		}
	})

	// *.invalid. is guaranteed not to exist (RFC 6761).
	rootURL, err := url.Parse("https://rushlink.invalid")
	if err != nil {
		t.Fatalf("parsing URL: %s\n", err)
	}

	rl := rushlink{
		db:      database,
		fs:      fileStore,
		rootURL: rootURL,
	}
	return CreateMainRouter(&rl)
}
// checkStatusCode checks whether the status code of a recorded response is
// equal to the expected code.
//
// On a mismatch the recorded response body is logged to aid debugging and the
// test is aborted with t.Fatalf.
func checkStatusCode(t *testing.T, rr *httptest.ResponseRecorder, code int) {
	// Attribute failures to the calling test instead of this helper.
	t.Helper()

	if actual := rr.Code; actual != code {
		// rr.Body holds what the handler wrote, i.e. the response body.
		t.Logf("response body:\n%v\n", rr.Body.String())
		t.Fatalf("handler returned wrong status code: got %v want %v\n",
			actual, code)
	}
}
// checkLocationHeader checks whether the Location header of a recorded
// response is equal to some expected URL.
//
// The comparison is an exact string match; on a mismatch the test is aborted
// with t.Fatalf.
func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected string) {
	// Attribute failures to the calling test instead of this helper.
	t.Helper()

	location := rr.Header().Get("Location")
	if location != expected {
		t.Fatalf("handler returned bad redirect location: got %v want %v", location, expected)
	}
}
// TestIssue43 shortens a URL that contains a fragment identifier and then
// checks that requesting the returned short link redirects to the original
// URL with the fragment intact.
//
// NOTE(review): presumably a regression test for issue #43 in the project's
// tracker — confirm.
func TestIssue43(t *testing.T) {
	srv := createTemporaryRouter(t)

	// Put a URL with a fragment identifier into the database.
	var body bytes.Buffer
	form := multipart.NewWriter(&body)
	if err := form.WriteField("shorten", "https://example.com#fragment"); err != nil {
		t.Fatal(err)
	}
	// Close writes the trailing multipart boundary; without it the body is
	// not a complete form.
	if err := form.Close(); err != nil {
		t.Fatal(err)
	}
	req, err := http.NewRequest("POST", "/", &body)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Add("Content-Type", form.FormDataContentType())
	rr := httptest.NewRecorder()
	srv.ServeHTTP(rr, req)
	checkStatusCode(t, rr, http.StatusOK)

	// The first line of the response body is taken to be the short URL.
	rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0]
	pasteURL, err := url.Parse(rawURL)
	if err != nil {
		t.Fatal(err)
	}
	// Guard the slice expression below: an empty path would panic instead
	// of producing a readable test failure.
	if pasteURL.Path == "" {
		t.Fatalf("response did not contain a URL with a path: %q", rawURL)
	}

	// Check if the URL was encoded correctly.
	req, err = http.NewRequest("GET", pasteURL.Path, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Pre-populate the "key" route variable with the path minus its leading
	// character (the slash).
	req = mux.SetURLVars(req, map[string]string{"key": pasteURL.Path[1:]})
	rr = httptest.NewRecorder()
	srv.ServeHTTP(rr, req)
	checkStatusCode(t, rr, http.StatusTemporaryRedirect)
	checkLocationHeader(t, rr, "https://example.com#fragment")
}