forked from electricdusk/rushlink
		
	
		
			
				
	
	
		
			150 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			150 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package rushlink
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"mime/multipart"
 | 
						|
	"net/http"
 | 
						|
	"net/http/httptest"
 | 
						|
	"net/url"
 | 
						|
	"os"
 | 
						|
	"path/filepath"
 | 
						|
	"strings"
 | 
						|
	"testing"
 | 
						|
 | 
						|
	"gitea.hashru.nl/dsprenkels/rushlink/internal/db"
 | 
						|
	"github.com/gorilla/mux"
 | 
						|
	"go.etcd.io/bbolt"
 | 
						|
)
 | 
						|
 | 
						|
// createTemporaryRouter initializes a rushlink instance, with temporary
 | 
						|
// filestore and database.
 | 
						|
//
 | 
						|
// It will use testing.T.Cleanup to cleanup after itself.
 | 
						|
func createTemporaryRouter(t *testing.T) (*mux.Router, *rushlink) {
 | 
						|
	tempDir, err := ioutil.TempDir("", "rushlink-tmp-*")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("creating temporary directory: %s\n", err)
 | 
						|
	}
 | 
						|
	t.Cleanup(func() {
 | 
						|
		os.RemoveAll(tempDir)
 | 
						|
	})
 | 
						|
 | 
						|
	fileStore, err := db.OpenFileStore(filepath.Join(tempDir, "filestore"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("opening temporary filestore: %s\n", err)
 | 
						|
	}
 | 
						|
	databasePath := filepath.Join(tempDir, "rushlink.db")
 | 
						|
	database, err := db.OpenDB(databasePath, fileStore)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("opening temporary database: %s\n", err)
 | 
						|
	}
 | 
						|
	t.Cleanup(func() {
 | 
						|
		if err := database.Close(); err != nil {
 | 
						|
			t.Errorf("closing database: %d\n", err)
 | 
						|
		}
 | 
						|
	})
 | 
						|
 | 
						|
	// *.invalid. is guaranteed not to exist (RFC 6761).
 | 
						|
	rootURL, err := url.Parse("https://rushlink.invalid")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("parsing URL: %s\n", err)
 | 
						|
	}
 | 
						|
 | 
						|
	rl := rushlink{
 | 
						|
		db:      database,
 | 
						|
		fs:      fileStore,
 | 
						|
		rootURL: rootURL,
 | 
						|
	}
 | 
						|
	return CreateMainRouter(&rl), &rl
 | 
						|
}
 | 
						|
 | 
						|
// checkStatusCode checks whether the status code from a recorded response is equal
// to some response code.
//
// On a mismatch it logs the recorded response body (which usually contains
// the handler's error message) and fails the test immediately.
func checkStatusCode(t *testing.T, rr *httptest.ResponseRecorder, code int) {
	t.Helper()
	if actual := rr.Code; actual != code {
		t.Logf("response body:\n%v\n", rr.Body.String())
		t.Fatalf("handler returned wrong status code: got %v want %v\n",
			actual, code)
	}
}
 | 
						|
 | 
						|
// checkLocationHeader checks whether the Location header of a recorded
// response is equal to some expected URL, and fails the test immediately if
// it is not. A missing header compares as the empty string.
func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected string) {
	t.Helper()
	if location := rr.Header().Get("Location"); location != expected {
		t.Fatalf("handler returned bad redirect location: got %v want %v", location, expected)
	}
}
 | 
						|
 | 
						|
func TestIssue43(t *testing.T) {
 | 
						|
	srv, _ := createTemporaryRouter(t)
 | 
						|
 | 
						|
	// Put a URL with a fragment identifier into the database.
 | 
						|
	var body bytes.Buffer
 | 
						|
	form := multipart.NewWriter(&body)
 | 
						|
	form.WriteField("shorten", "https://example.com#fragment")
 | 
						|
	form.Close()
 | 
						|
	req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes()))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	req.Header.Add("Content-Type", form.FormDataContentType())
 | 
						|
	rr := httptest.NewRecorder()
 | 
						|
	srv.ServeHTTP(rr, req)
 | 
						|
	checkStatusCode(t, rr, http.StatusFound)
 | 
						|
	rawURL := strings.SplitN(rr.Body.String(), "\n", 2)[0]
 | 
						|
	pasteURL, err := url.Parse(rawURL)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Check if the URL was encoded correctly.
 | 
						|
	req, err = http.NewRequest("GET", pasteURL.Path, nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	req = mux.SetURLVars(req, map[string]string{"key": pasteURL.Path[1:]})
 | 
						|
	rr = httptest.NewRecorder()
 | 
						|
	srv.ServeHTTP(rr, req)
 | 
						|
	checkStatusCode(t, rr, http.StatusTemporaryRedirect)
 | 
						|
	checkLocationHeader(t, rr, "https://example.com#fragment")
 | 
						|
}
 | 
						|
 | 
						|
func TestIssue53(t *testing.T) {
 | 
						|
	srv, rl := createTemporaryRouter(t)
 | 
						|
 | 
						|
	// Put a URL with a fragment identifier into the database.
 | 
						|
	var body bytes.Buffer
 | 
						|
	form := multipart.NewWriter(&body)
 | 
						|
	if _, err := form.CreateFormFile("file", "../directory-traversal/file.txt"); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	form.Close()
 | 
						|
	req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes()))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	req.Header.Add("Content-Type", form.FormDataContentType())
 | 
						|
	rr := httptest.NewRecorder()
 | 
						|
	srv.ServeHTTP(rr, req)
 | 
						|
	checkStatusCode(t, rr, http.StatusFound)
 | 
						|
 | 
						|
	// Check that any attempt to do directory traversal has failed.
 | 
						|
	rl.db.Bolt.View(func(tx *bbolt.Tx) error {
 | 
						|
		fus, err := db.AllFileUploads(tx)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatal(err)
 | 
						|
		}
 | 
						|
		for _, fu := range fus {
 | 
						|
			if strings.ContainsAny(fu.FileName, "/\\") {
 | 
						|
				t.Fatalf(fmt.Sprintf("found a slash in file name: %v", fu.FileName))
 | 
						|
			}
 | 
						|
		}
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
 | 
						|
}
 |