diff --git a/handlers_test.go b/handlers_test.go index 4c3de8d..5f07184 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -2,6 +2,7 @@ package rushlink import ( "bytes" + "fmt" "io/ioutil" "mime/multipart" "net/http" @@ -14,13 +15,14 @@ import ( "gitea.hashru.nl/dsprenkels/rushlink/internal/db" "github.com/gorilla/mux" + "go.etcd.io/bbolt" ) // createTemporaryRouter initializes a rushlink instance, with temporary // filestore and database. // // It will use testing.T.Cleanup to cleanup after itself. -func createTemporaryRouter(t *testing.T) *mux.Router { +func createTemporaryRouter(t *testing.T) (*mux.Router, *rushlink) { tempDir, err := ioutil.TempDir("", "rushlink-tmp-*") if err != nil { t.Fatalf("creating temporary directory: %s\n", err) @@ -55,7 +57,7 @@ func createTemporaryRouter(t *testing.T) *mux.Router { fs: fileStore, rootURL: rootURL, } - return CreateMainRouter(&rl) + return CreateMainRouter(&rl), &rl } // checkStatusCode checks whether the status code from a recorded response is equal @@ -78,7 +80,7 @@ func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected s } func TestIssue43(t *testing.T) { - srv := createTemporaryRouter(t) + srv, _ := createTemporaryRouter(t) // Put a URL with a fragment identifier into the database. var body bytes.Buffer @@ -110,3 +112,38 @@ func TestIssue43(t *testing.T) { checkStatusCode(t, rr, http.StatusTemporaryRedirect) checkLocationHeader(t, rr, "https://example.com#fragment") } + +func TestIssue53(t *testing.T) { + srv, rl := createTemporaryRouter(t) + + // Put a URL with a fragment identifier into the database. + var body bytes.Buffer + form := multipart.NewWriter(&body) + if _, err := form.CreateFormFile("file", "../directory-traversal/file.txt"); err != nil { + t.Fatal(err) + } + form.Close() + req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Content-Type", form.FormDataContentType()) + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + checkStatusCode(t, rr, http.StatusFound) + + // Check that any attempt to do directory traversal has failed. + rl.db.Bolt.View(func(tx *bbolt.Tx) error { + fus, err := db.AllFileUploads(tx) + if err != nil { + t.Fatal(err) + } + for _, fu := range fus { + if strings.ContainsAny(fu.FileName, "/\\") { + t.Fatalf(fmt.Sprintf("found a slash in file name: %v", fu.FileName)) + } + } + return nil + }) + +} diff --git a/internal/db/fileupload.go b/internal/db/fileupload.go index 3f3829e..f390ef2 100644 --- a/internal/db/fileupload.go +++ b/internal/db/fileupload.go @@ -126,7 +126,8 @@ func NewFileUpload(fs *FileStore, r io.Reader, fileName string) (*FileUpload, er contentType := http.DetectContentType(tmpBuf.Bytes()) // Open the file on disk for writing - filePath := fs.filePath(id, fileName) + baseName := filepath.Base(fileName) + filePath := fs.filePath(id, baseName) if err := os.Mkdir(path.Dir(filePath), dirMode); err != nil { return nil, errors.Wrap(err, "creating file dir") } @@ -149,7 +150,7 @@ func NewFileUpload(fs *FileStore, r io.Reader, fileName string) (*FileUpload, er fu := &FileUpload{ State: FileUploadStatePresent, ID: id, - FileName: fileName, + FileName: baseName, ContentType: contentType, Checksum: hash.Sum32(), } @@ -169,6 +170,27 @@ func GetFileUpload(tx *bolt.Tx, id uuid.UUID) (*FileUpload, error) { return decodeFileUpload(storedBytes) } +// AllFileUploads tries to retrieve all FileUpload objects from the bolt database. +func AllFileUploads(tx *bolt.Tx) ([]FileUpload, error) { + bucket := tx.Bucket([]byte(BucketFileUpload)) + if bucket == nil { + return nil, errors.Errorf("bucket %v does not exist", BucketFileUpload) + } + var fus []FileUpload + err := bucket.ForEach(func(_, storedBytes []byte) error { + fu, err := decodeFileUpload(storedBytes) + if err != nil { + return err + } + fus = append(fus, *fu) + return nil + }) + if err != nil { + return nil, err + } + return fus, nil +} + func decodeFileUpload(storedBytes []byte) (*FileUpload, error) { fu := &FileUpload{} err := gobmarsh.Unmarshal(storedBytes, fu)