Prevent directory traversal in file upload

Fixes 
This commit is contained in:
Daan Sprenkels 2020-05-12 20:01:03 +02:00
parent 737a26fee3
commit 2c889e0808
2 changed files with 64 additions and 5 deletions

View File

@ -2,6 +2,7 @@ package rushlink
import ( import (
"bytes" "bytes"
"fmt"
"io/ioutil" "io/ioutil"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@ -14,13 +15,14 @@ import (
"gitea.hashru.nl/dsprenkels/rushlink/internal/db" "gitea.hashru.nl/dsprenkels/rushlink/internal/db"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"go.etcd.io/bbolt"
) )
// createTemporaryRouter initializes a rushlink instance, with temporary // createTemporaryRouter initializes a rushlink instance, with temporary
// filestore and database. // filestore and database.
// //
// It will use testing.T.Cleanup to cleanup after itself. // It will use testing.T.Cleanup to cleanup after itself.
func createTemporaryRouter(t *testing.T) *mux.Router { func createTemporaryRouter(t *testing.T) (*mux.Router, *rushlink) {
tempDir, err := ioutil.TempDir("", "rushlink-tmp-*") tempDir, err := ioutil.TempDir("", "rushlink-tmp-*")
if err != nil { if err != nil {
t.Fatalf("creating temporary directory: %s\n", err) t.Fatalf("creating temporary directory: %s\n", err)
@ -55,7 +57,7 @@ func createTemporaryRouter(t *testing.T) *mux.Router {
fs: fileStore, fs: fileStore,
rootURL: rootURL, rootURL: rootURL,
} }
return CreateMainRouter(&rl) return CreateMainRouter(&rl), &rl
} }
// checkStatusCode checks whether the status code from a recorded response is equal // checkStatusCode checks whether the status code from a recorded response is equal
@ -78,7 +80,7 @@ func checkLocationHeader(t *testing.T, rr *httptest.ResponseRecorder, expected s
} }
func TestIssue43(t *testing.T) { func TestIssue43(t *testing.T) {
srv := createTemporaryRouter(t) srv, _ := createTemporaryRouter(t)
// Put a URL with a fragment identifier into the database. // Put a URL with a fragment identifier into the database.
var body bytes.Buffer var body bytes.Buffer
@ -110,3 +112,38 @@ func TestIssue43(t *testing.T) {
checkStatusCode(t, rr, http.StatusTemporaryRedirect) checkStatusCode(t, rr, http.StatusTemporaryRedirect)
checkLocationHeader(t, rr, "https://example.com#fragment") checkLocationHeader(t, rr, "https://example.com#fragment")
} }
func TestIssue53(t *testing.T) {
srv, rl := createTemporaryRouter(t)
// Put a URL with a fragment identifier into the database.
var body bytes.Buffer
form := multipart.NewWriter(&body)
if _, err := form.CreateFormFile("file", "../directory-traversal/file.txt"); err != nil {
t.Fatal(err)
}
form.Close()
req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes()))
if err != nil {
t.Fatal(err)
}
req.Header.Add("Content-Type", form.FormDataContentType())
rr := httptest.NewRecorder()
srv.ServeHTTP(rr, req)
checkStatusCode(t, rr, http.StatusFound)
// Check that any attempt to do directory traversal has failed.
rl.db.Bolt.View(func(tx *bbolt.Tx) error {
fus, err := db.AllFileUploads(tx)
if err != nil {
t.Fatal(err)
}
for _, fu := range fus {
if strings.ContainsAny(fu.FileName, "/\\") {
t.Fatalf(fmt.Sprintf("found a slash in file name: %v", fu.FileName))
}
}
return nil
})
}

View File

@ -126,7 +126,8 @@ func NewFileUpload(fs *FileStore, r io.Reader, fileName string) (*FileUpload, er
contentType := http.DetectContentType(tmpBuf.Bytes()) contentType := http.DetectContentType(tmpBuf.Bytes())
// Open the file on disk for writing // Open the file on disk for writing
filePath := fs.filePath(id, fileName) baseName := filepath.Base(fileName)
filePath := fs.filePath(id, baseName)
if err := os.Mkdir(path.Dir(filePath), dirMode); err != nil { if err := os.Mkdir(path.Dir(filePath), dirMode); err != nil {
return nil, errors.Wrap(err, "creating file dir") return nil, errors.Wrap(err, "creating file dir")
} }
@ -149,7 +150,7 @@ func NewFileUpload(fs *FileStore, r io.Reader, fileName string) (*FileUpload, er
fu := &FileUpload{ fu := &FileUpload{
State: FileUploadStatePresent, State: FileUploadStatePresent,
ID: id, ID: id,
FileName: fileName, FileName: baseName,
ContentType: contentType, ContentType: contentType,
Checksum: hash.Sum32(), Checksum: hash.Sum32(),
} }
@ -169,6 +170,27 @@ func GetFileUpload(tx *bolt.Tx, id uuid.UUID) (*FileUpload, error) {
return decodeFileUpload(storedBytes) return decodeFileUpload(storedBytes)
} }
// AllFileUploads tries to retrieve all FileUpload objects from the bolt database.
func AllFileUploads(tx *bolt.Tx) ([]FileUpload, error) {
bucket := tx.Bucket([]byte(BucketFileUpload))
if bucket == nil {
return nil, errors.Errorf("bucket %v does not exist", BucketFileUpload)
}
var fus []FileUpload
err := bucket.ForEach(func(_, storedBytes []byte) error {
fu, err := decodeFileUpload(storedBytes)
if err != nil {
return err
}
fus = append(fus, *fu)
return nil
})
if err != nil {
return nil, err
}
return fus, nil
}
func decodeFileUpload(storedBytes []byte) (*FileUpload, error) { func decodeFileUpload(storedBytes []byte) (*FileUpload, error) {
fu := &FileUpload{} fu := &FileUpload{}
err := gobmarsh.Unmarshal(storedBytes, fu) err := gobmarsh.Unmarshal(storedBytes, fu)