package rushlink import ( "bytes" "fmt" "io/ioutil" "mime/multipart" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "strings" "testing" "gitea.hashru.nl/dsprenkels/rushlink/internal/db" "github.com/gorilla/mux" "go.etcd.io/bbolt" ) // createTemporaryRouter initializes a rushlink instance, with temporary // filestore and database. // // It will use testing.T.Cleanup to cleanup after itself. func createTemporaryRouter(t *testing.T) (*mux.Router, *rushlink) { tempDir, err := ioutil.TempDir("", "rushlink-tmp-*") if err != nil { t.Fatalf("creating temporary directory: %s\n", err) } t.Cleanup(func() { os.RemoveAll(tempDir) }) fileStore, err := db.OpenFileStore(filepath.Join(tempDir, "filestore")) if err != nil { t.Fatalf("opening temporary filestore: %s\n", err) } databasePath := filepath.Join(tempDir, "rushlink.db") database, err := db.OpenDB(databasePath, fileStore) if err != nil { t.Fatalf("opening temporary database: %s\n", err) } t.Cleanup(func() { if err := database.Close(); err != nil { t.Errorf("closing database: %d\n", err) } }) // *.invalid. is guaranteed not to exist (RFC 6761). rootURL, err := url.Parse("https://rushlink.invalid") if err != nil { t.Fatalf("parsing URL: %s\n", err) } rl := rushlink{ db: database, fs: fileStore, rootURL: rootURL, } return CreateMainRouter(&rl), &rl } // checkStatusCode checks whether the status code from a recorded response is equal // to some response code. func checkStatusCode(t *testing.T, rr *httptest.ResponseRecorder, code int) { if actual := rr.Code; actual != code { t.Logf("request body:\n%v\n", rr.Body.String()) t.Fatalf("handler returned wrong status code: got %v want %v\n", actual, code) } } // checkLocationHeader checks whether the status code from a recorded response is equal // to some expected URL. func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected string) { location := rr.Header().Get("Location") if location != expected { t.Fatalf("handler returned bad redirect location: got %v want %v", location, expected) } } func TestIssue43(t *testing.T) { srv, _ := createTemporaryRouter(t) // Put a URL with a fragment identifier into the database. var body bytes.Buffer form := multipart.NewWriter(&body) form.WriteField("shorten", "https://example.com#fragment") form.Close() req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) if err != nil { t.Fatal(err) } req.Header.Add("Content-Type", form.FormDataContentType()) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusFound) rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0] pasteURL, err := url.Parse(rawURL) if err != nil { t.Fatal(err) } // Check if the URL was encoded correctly. req, err = http.NewRequest("GET", pasteURL.Path, nil) if err != nil { t.Fatal(err) } req = mux.SetURLVars(req, map[string]string{"key": pasteURL.Path[1:]}) rr = httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusTemporaryRedirect) checkLocationHeader(t, rr, "https://example.com#fragment") } func TestIssue53(t *testing.T) { srv, rl := createTemporaryRouter(t) // Put a URL with a fragment identifier into the database. var body bytes.Buffer form := multipart.NewWriter(&body) if _, err := form.CreateFormFile("file", "../directory-traversal/file.txt"); err != nil { t.Fatal(err) } form.Close() req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) if err != nil { t.Fatal(err) } req.Header.Add("Content-Type", form.FormDataContentType()) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusFound) // Check that any attempt to do directory traversal has failed. rl.db.Bolt.View(func(tx *bbolt.Tx) error { fus, err := db.AllFileUploads(tx) if err != nil { t.Fatal(err) } for _, fu := range fus { if strings.ContainsAny(fu.FileName, "/\\") { t.Fatalf(fmt.Sprintf("found a slash in file name: %v", fu.FileName)) } } return nil }) } func TestIssue60(t *testing.T) { srv, _ := createTemporaryRouter(t) // Request a nonexistent static file req, err := http.NewRequest("GET", "/css/nonexistent_file.css", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusNotFound) }