package rushlink

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
	"github.com/gorilla/mux"
)

// createTemporaryRouter initializes a rushlink instance, with temporary
// filestore and database.
//
// Each call gets its own temporary filestore directory and its own named
// in-memory SQLite database, so tests do not see each other's data.
//
// It will use testing.T.Cleanup to remove the temporary directory afterwards.
// NOTE(review): the database handle itself is never closed here.
func createTemporaryRouter(t *testing.T) (*mux.Router, *rushlink) {
	t.Helper()

	tempDir, err := ioutil.TempDir("", "rushlink-tmp-*")
	if err != nil {
		t.Fatalf("creating temporary directory: %v", err)
	}
	t.Cleanup(func() { os.RemoveAll(tempDir) })

	fileStore, err := db.OpenFileStore(filepath.Join(tempDir, "filestore"))
	if err != nil {
		t.Fatalf("opening temporary filestore: %v", err)
	}

	// A plain "file::memory:?cache=shared" names one database that every
	// connection in the process shares, so all tests would reuse it. Naming
	// the in-memory database after the (uniquely named) temporary directory
	// gives each call a fresh database. cache=shared is kept so that multiple
	// connections opened for this instance still see the same data.
	dbPath := "file:" + filepath.Base(tempDir) + "?mode=memory&cache=shared"
	if err := os.Setenv(db.EnvDatabaseDriver, "sqlite"); err != nil {
		t.Fatalf("setting %s: %v", db.EnvDatabaseDriver, err)
	}
	if err := os.Setenv(db.EnvDatabasePath, dbPath); err != nil {
		t.Fatalf("setting %s: %v", db.EnvDatabasePath, err)
	}
	database, err := db.OpenDBFromEnvironment()
	if err != nil {
		t.Fatalf("opening temporary database: %v", err)
	}
	// NOTE(review): if db.Database is a gorm DB, Debug() returns a new
	// session and discarding it has no effect — confirm whether
	// `database = database.Debug()` was intended.
	database.Debug()
	if err := db.Gormigrate(database).Migrate(); err != nil {
		t.Fatalf("migration failed: %v", err)
	}

	// *.invalid. is guaranteed not to exist (RFC 6761).
	rootURL, err := url.Parse("https://rushlink.invalid")
	if err != nil {
		t.Fatalf("parsing URL: %v", err)
	}

	rl := rushlink{
		db:      database,
		fs:      fileStore,
		rootURL: rootURL,
	}
	r := mux.NewRouter()
	InitMainRouter(r, &rl)
	return r, &rl
}

// checkStatusCode checks whether the status code from a recorded response is
// equal to some response code. On a mismatch it logs the response body and
// fails the test.
func checkStatusCode(t *testing.T, rr *httptest.ResponseRecorder, code int) {
	t.Helper()
	if actual := rr.Code; actual != code {
		t.Logf("response body:\n%v\n", rr.Body.String())
		t.Fatalf("handler returned wrong status code: got %v want %v", actual, code)
	}
}

// checkLocationHeader checks whether the Location header of a recorded
// response is equal to some expected URL.
func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected string) { location := rr.Header().Get("Location") if location != expected { t.Fatalf("handler returned bad redirect location: got %v want %v", location, expected) } } func TestIssue43(t *testing.T) { srv, _ := createTemporaryRouter(t) // Put a URL with a fragment identifier into the database. var body bytes.Buffer form := multipart.NewWriter(&body) form.WriteField("shorten", "https://example.com#fragment") form.Close() req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) if err != nil { t.Fatal(err) } req.Header.Add("Content-Type", form.FormDataContentType()) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusFound) rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0] pasteURL, err := url.Parse(rawURL) if err != nil { t.Fatal(err) } // Check if the URL was encoded correctly. req, err = http.NewRequest("GET", pasteURL.Path, nil) if err != nil { t.Fatal(err) } req = mux.SetURLVars(req, map[string]string{"key": pasteURL.Path[1:]}) rr = httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusTemporaryRedirect) checkLocationHeader(t, rr, "https://example.com#fragment") } func TestIssue53(t *testing.T) { srv, rl := createTemporaryRouter(t) // Put a URL with a fragment identifier into the database. var body bytes.Buffer form := multipart.NewWriter(&body) if _, err := form.CreateFormFile("file", "../directory-traversal/file.txt"); err != nil { t.Fatal(err) } form.Close() req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) if err != nil { t.Fatal(err) } req.Header.Add("Content-Type", form.FormDataContentType()) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusFound) // Check that any attempt to do directory traversal has failed. 
rl.db.Transaction(func(tx *db.Database) error { fus, err := db.AllFileUploads(tx) if err != nil { t.Fatal(err) } for _, fu := range fus { if strings.ContainsAny(fu.FileName, "/\\") { t.Fatalf(fmt.Sprintf("found a slash in file name: %v", fu.FileName)) } } return nil }) } func TestIssue60(t *testing.T) { srv, _ := createTemporaryRouter(t) // Request a nonexistent static file req, err := http.NewRequest("GET", "/css/nonexistent_file.css", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusNotFound) } func TestIssue56(t *testing.T) { srv, _ := createTemporaryRouter(t) // Make a POST request with both a 'file' *and* a 'shorten' part. var body bytes.Buffer form := multipart.NewWriter(&body) if _, err := form.CreateFormFile("file", "empty.txt"); err != nil { t.Fatal(err) } if _, err := form.CreateFormField("shorten"); err != nil { t.Fatal(err) } form.Close() req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) if err != nil { t.Fatal(err) } req.Header.Add("Content-Type", form.FormDataContentType()) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusBadRequest) } func TestIssue68(t *testing.T) { srv, _ := createTemporaryRouter(t) originalURL := "https://example.com" var body bytes.Buffer form := multipart.NewWriter(&body) form.WriteField("shorten", "https://example.com") form.Close() req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) if err != nil { t.Fatal(err) } req.Header.Add("Content-Type", form.FormDataContentType()) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusFound) rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0] pasteURL, err := url.Parse(rawURL) if err != nil { t.Fatal(err) } // Check if the no-redirect handler works properly. 
req, err = http.NewRequest("GET", pasteURL.Path+"/nr", nil) if err != nil { t.Fatal(err) } rr = httptest.NewRecorder() srv.ServeHTTP(rr, req) checkStatusCode(t, rr, http.StatusOK) if rr.Body.String() != originalURL { t.Errorf("incorrect URL = %v, want %v", rr.Body.String(), originalURL) } }