forked from electricdusk/rushlink
222 lines
5.9 KiB
Go
222 lines
5.9 KiB
Go
package rushlink
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"gitea.hashru.nl/dsprenkels/rushlink/internal/boltdb"
|
|
"github.com/gorilla/mux"
|
|
"go.etcd.io/bbolt"
|
|
)
|
|
|
|
// createTemporaryRouter initializes a rushlink instance, with temporary
|
|
// filestore and database.
|
|
//
|
|
// It will use testing.T.Cleanup to cleanup after itself.
|
|
func createTemporaryRouter(t *testing.T) (*mux.Router, *rushlink) {
|
|
tempDir, err := ioutil.TempDir("", "rushlink-tmp-*")
|
|
if err != nil {
|
|
t.Fatalf("creating temporary directory: %s\n", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
os.RemoveAll(tempDir)
|
|
})
|
|
|
|
fileStore, err := boltdb.OpenFileStore(filepath.Join(tempDir, "filestore"))
|
|
if err != nil {
|
|
t.Fatalf("opening temporary filestore: %s\n", err)
|
|
}
|
|
databasePath := filepath.Join(tempDir, "rushlink.db")
|
|
database, err := boltdb.OpenDB(databasePath, fileStore)
|
|
if err != nil {
|
|
t.Fatalf("opening temporary database: %s\n", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
if err := database.Close(); err != nil {
|
|
t.Errorf("closing database: %d\n", err)
|
|
}
|
|
})
|
|
|
|
// *.invalid. is guaranteed not to exist (RFC 6761).
|
|
rootURL, err := url.Parse("https://rushlink.invalid")
|
|
if err != nil {
|
|
t.Fatalf("parsing URL: %s\n", err)
|
|
}
|
|
|
|
rl := rushlink{
|
|
db: database,
|
|
fs: fileStore,
|
|
rootURL: rootURL,
|
|
}
|
|
r := mux.NewRouter()
|
|
InitMainRouter(r, &rl)
|
|
return r, &rl
|
|
}
|
|
|
|
// checkStatusCode fails the test immediately when the status code of a
// recorded response differs from the expected code. The recorded response
// body is logged before failing to help diagnose the mismatch.
func checkStatusCode(t *testing.T, rr *httptest.ResponseRecorder, code int) {
	// Mark as helper so the failure is attributed to the calling test line.
	t.Helper()
	if actual := rr.Code; actual != code {
		// rr.Body holds what the handler wrote, i.e. the response body.
		t.Logf("response body:\n%v\n", rr.Body.String())
		t.Fatalf("handler returned wrong status code: got %v want %v\n",
			actual, code)
	}
}
|
|
|
|
// checkLocationHeader fails the test immediately when the Location header of
// a recorded response differs from the expected URL. A missing header reads
// as the empty string.
func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected string) {
	// Mark as helper so the failure is attributed to the calling test line.
	t.Helper()
	location := rr.Header().Get("Location")
	if location != expected {
		t.Fatalf("handler returned bad redirect location: got %v want %v", location, expected)
	}
}
|
|
|
|
func TestIssue43(t *testing.T) {
|
|
srv, _ := createTemporaryRouter(t)
|
|
|
|
// Put a URL with a fragment identifier into the database.
|
|
var body bytes.Buffer
|
|
form := multipart.NewWriter(&body)
|
|
form.WriteField("shorten", "https://example.com#fragment")
|
|
form.Close()
|
|
req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes()))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Add("Content-Type", form.FormDataContentType())
|
|
rr := httptest.NewRecorder()
|
|
srv.ServeHTTP(rr, req)
|
|
checkStatusCode(t, rr, http.StatusFound)
|
|
rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0]
|
|
pasteURL, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Check if the URL was encoded correctly.
|
|
req, err = http.NewRequest("GET", pasteURL.Path, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req = mux.SetURLVars(req, map[string]string{"key": pasteURL.Path[1:]})
|
|
rr = httptest.NewRecorder()
|
|
srv.ServeHTTP(rr, req)
|
|
checkStatusCode(t, rr, http.StatusTemporaryRedirect)
|
|
checkLocationHeader(t, rr, "https://example.com#fragment")
|
|
}
|
|
|
|
func TestIssue53(t *testing.T) {
|
|
srv, rl := createTemporaryRouter(t)
|
|
|
|
// Put a URL with a fragment identifier into the database.
|
|
var body bytes.Buffer
|
|
form := multipart.NewWriter(&body)
|
|
if _, err := form.CreateFormFile("file", "../directory-traversal/file.txt"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
form.Close()
|
|
req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes()))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Add("Content-Type", form.FormDataContentType())
|
|
rr := httptest.NewRecorder()
|
|
srv.ServeHTTP(rr, req)
|
|
checkStatusCode(t, rr, http.StatusFound)
|
|
|
|
// Check that any attempt to do directory traversal has failed.
|
|
rl.db.Bolt.View(func(tx *bbolt.Tx) error {
|
|
fus, err := boltdb.AllFileUploads(tx)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for _, fu := range fus {
|
|
if strings.ContainsAny(fu.FileName, "/\\") {
|
|
t.Fatalf(fmt.Sprintf("found a slash in file name: %v", fu.FileName))
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func TestIssue60(t *testing.T) {
|
|
srv, _ := createTemporaryRouter(t)
|
|
|
|
// Request a nonexistent static file
|
|
req, err := http.NewRequest("GET", "/css/nonexistent_file.css", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rr := httptest.NewRecorder()
|
|
srv.ServeHTTP(rr, req)
|
|
checkStatusCode(t, rr, http.StatusNotFound)
|
|
}
|
|
|
|
func TestIssue56(t *testing.T) {
|
|
srv, _ := createTemporaryRouter(t)
|
|
|
|
// Make a POST request with both a 'file' *and* a 'shorten' part.
|
|
var body bytes.Buffer
|
|
form := multipart.NewWriter(&body)
|
|
if _, err := form.CreateFormFile("file", "empty.txt"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := form.CreateFormField("shorten"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
form.Close()
|
|
req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes()))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Add("Content-Type", form.FormDataContentType())
|
|
rr := httptest.NewRecorder()
|
|
srv.ServeHTTP(rr, req)
|
|
checkStatusCode(t, rr, http.StatusBadRequest)
|
|
}
|
|
|
|
func TestIssue68(t *testing.T) {
|
|
srv, _ := createTemporaryRouter(t)
|
|
originalURL := "https://example.com"
|
|
|
|
var body bytes.Buffer
|
|
form := multipart.NewWriter(&body)
|
|
form.WriteField("shorten", "https://example.com")
|
|
form.Close()
|
|
req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes()))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Add("Content-Type", form.FormDataContentType())
|
|
rr := httptest.NewRecorder()
|
|
srv.ServeHTTP(rr, req)
|
|
checkStatusCode(t, rr, http.StatusFound)
|
|
rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0]
|
|
pasteURL, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Check if the no-redirect handler works properly.
|
|
req, err = http.NewRequest("GET", pasteURL.Path+"/nr", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rr = httptest.NewRecorder()
|
|
srv.ServeHTTP(rr, req)
|
|
checkStatusCode(t, rr, http.StatusOK)
|
|
if rr.Body.String() != originalURL {
|
|
t.Errorf("incorrect URL = %v, want %v", rr.Body.String(), originalURL)
|
|
}
|
|
}
|