forked from electricdusk/rushlink
Use sql database instead of bolt
This commit is contained in:
@@ -322,3 +322,24 @@ func GenerateDeleteToken() (string, error) {
|
||||
}
|
||||
return hex.EncodeToString(deleteToken[:]), nil
|
||||
}
|
||||
|
||||
// AllPastes tries to retrieve all the Paste objects from the database.
|
||||
func AllPastes(tx *bolt.Tx) ([]Paste, error) {
|
||||
bucket := tx.Bucket([]byte(BucketPastes))
|
||||
if bucket == nil {
|
||||
return nil, errors.Errorf("bucket %v does not exist", BucketPastes)
|
||||
}
|
||||
var ps []Paste
|
||||
err := bucket.ForEach(func(_, storedBytes []byte) error {
|
||||
p, err := decodePaste(storedBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ps = append(ps, *p)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
74
internal/db/db.go
Normal file
74
internal/db/db.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDatabaseDriver is the key for the 'database driver' environment variable.
|
||||
EnvDatabaseDriver = "RUSHLINK_DATABASE_DRIVER"
|
||||
// EnvDatabasePath is the key for the 'database path' environment variable.
|
||||
EnvDatabasePath = "RUSHLINK_DATABASE_PATH"
|
||||
// EnvPostgresURL is the key for the 'postgresql URL' environment variable.
|
||||
EnvPostgresURL = "RUSHLINK_POSTGRES_URL"
|
||||
// EnvFileStorePath is the key for the environment variable locating the file store.
|
||||
EnvFileStorePath = "RUSHLINK_FILE_STORE_PATH"
|
||||
)
|
||||
|
||||
// Database is the main rushlink database type.
|
||||
//
|
||||
// Open a database using OpenDBFromEnvironment(). Closing is not necessary.
|
||||
// Only one instance of DB should exist in a program at any moment.
|
||||
type Database = gorm.DB
|
||||
|
||||
var (
|
||||
gormLogger logger.Interface = logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
|
||||
SlowThreshold: 50 * time.Millisecond,
|
||||
LogLevel: logger.Error,
|
||||
Colorful: true,
|
||||
})
|
||||
gormConfig = gorm.Config{Logger: gormLogger}
|
||||
)
|
||||
|
||||
// OpenDBFromEnvironment tries to open an SQL database, described by
|
||||
func OpenDBFromEnvironment() (*Database, error) {
|
||||
const envNotSetMsg = "%v environment variable is not set"
|
||||
|
||||
driver, prs := os.LookupEnv(EnvDatabaseDriver)
|
||||
if !prs {
|
||||
return nil, errors.Errorf(envNotSetMsg, EnvDatabaseDriver)
|
||||
}
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
path, prs := os.LookupEnv(EnvDatabasePath)
|
||||
if !prs {
|
||||
return nil, errors.Errorf(envNotSetMsg, EnvDatabasePath)
|
||||
}
|
||||
db, err := gorm.Open(sqlite.Open(path), &gormConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
case "postgres":
|
||||
dsn, prs := os.LookupEnv(EnvPostgresURL)
|
||||
if !prs {
|
||||
return nil, errors.Errorf(envNotSetMsg, EnvPostgresURL)
|
||||
}
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gormConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
default:
|
||||
return nil, errors.Errorf("RUSHLINK_DATABASE_DRIVER should be either 'sqlite' or 'postgres' (not '%v'), "+
|
||||
"for more info see <https://stackoverflow.com/q/3582552/5207081>", driver)
|
||||
}
|
||||
}
|
||||
240
internal/db/fileupload.go
Normal file
240
internal/db/fileupload.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Use the Castagnoli checksum because of the acceleration on Intel CPUs
|
||||
var checksumTable = crc32.MakeTable(crc32.Castagnoli)
|
||||
|
||||
// FileStore holds the path to a file storage location.
|
||||
type FileStore struct {
|
||||
path string
|
||||
}
|
||||
|
||||
// FileUploadState determines the current state of a FileUpload object.
|
||||
type FileUploadState int
|
||||
|
||||
// FileUpload models an uploaded file.
|
||||
type FileUpload struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
|
||||
// State of the FileUpload (present/deleted/etc).
|
||||
State FileUploadState `gorm:"index"`
|
||||
|
||||
// UUID publically identifies this FileUpload.
|
||||
PubID uuid.UUID `gorm:"uniqueIndex"`
|
||||
|
||||
// FileName contains the original filename of this FileUpload.
|
||||
FileName string
|
||||
|
||||
// Content type as determined by http.DetectContentType.
|
||||
ContentType string
|
||||
|
||||
// Checksum holds a crc32c checksum of the file.
|
||||
//
|
||||
// This checksum is only meant to allow for the detection of random
|
||||
// database corruption.
|
||||
Checksum uint32
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
|
||||
const (
|
||||
dirMode os.FileMode = 0750
|
||||
fileMode os.FileMode = 0640
|
||||
)
|
||||
|
||||
const (
|
||||
// FileUploadStateUndef is an undefined FileUpload.
|
||||
FileUploadStateUndef FileUploadState = 0
|
||||
// FileUploadStatePresent denotes the normal (existing) state.
|
||||
FileUploadStatePresent FileUploadState = 1
|
||||
// FileUploadStateDeleted denotes a deleted state.
|
||||
FileUploadStateDeleted FileUploadState = 2
|
||||
)
|
||||
|
||||
func (t FileUploadState) String() string {
|
||||
switch t {
|
||||
case FileUploadStateUndef:
|
||||
return "unknown"
|
||||
case FileUploadStatePresent:
|
||||
return "present"
|
||||
case FileUploadStateDeleted:
|
||||
return "deleted"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
// ErrFileUploadDoesNotExist occurs when a key does not exist in the database.
|
||||
var ErrFileUploadDoesNotExist = errors.New("file not found in the database")
|
||||
|
||||
// OpenFileStoreFromEnvironment tries to open a file store located at ${RUSHLINK_FILE_STORE_PATH}.
|
||||
func OpenFileStoreFromEnvironment() (*FileStore, error) {
|
||||
path, prs := os.LookupEnv(EnvFileStorePath)
|
||||
if !prs {
|
||||
err := errors.Errorf("%v environment variable is not set", EnvFileStorePath)
|
||||
return nil, errors.Wrap(err, "opening file store")
|
||||
}
|
||||
return OpenFileStore(path)
|
||||
}
|
||||
|
||||
// OpenFileStore opens the file storage at path.
|
||||
func OpenFileStore(path string) (*FileStore, error) {
|
||||
if path == "" {
|
||||
return nil, errors.New("file-store not set")
|
||||
}
|
||||
|
||||
// Try to create the file store directory if it does not yet exist
|
||||
if err := os.MkdirAll(path, dirMode); err != nil {
|
||||
return nil, errors.Wrap(err, "creating file store directory")
|
||||
}
|
||||
|
||||
return &FileStore{path[:]}, nil
|
||||
}
|
||||
|
||||
// Path returns the path of the FileStore root.
|
||||
func (fs *FileStore) Path() string {
|
||||
return fs.path
|
||||
}
|
||||
|
||||
// filePath resolves the path of a file in the FileStore given some id and filename.
|
||||
func (fs *FileStore) filePath(pubID uuid.UUID, fileName string) string {
|
||||
if fs.path == "" {
|
||||
panic("fileStoreDir called while the file store path has not been set")
|
||||
}
|
||||
return path.Join(fs.path, hex.EncodeToString(pubID[:]), fileName)
|
||||
}
|
||||
|
||||
// NewFileUpload creates a new FileUpload object.
|
||||
//
|
||||
// Internally, this function detects the type of the file stored in `r` using
|
||||
// `http.DetectContentType`.
|
||||
func NewFileUpload(fs *FileStore, r io.Reader, fileName string) (*FileUpload, error) {
|
||||
// Generate a file ID
|
||||
pubID, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "generating UUID")
|
||||
}
|
||||
|
||||
// Construct a checksum for this file
|
||||
hash := crc32.New(checksumTable)
|
||||
tee := io.TeeReader(r, hash)
|
||||
|
||||
// Detect the file type
|
||||
var tmpBuf bytes.Buffer
|
||||
tmpBuf.Grow(512)
|
||||
io.CopyN(&tmpBuf, tee, 512)
|
||||
contentType := http.DetectContentType(tmpBuf.Bytes())
|
||||
|
||||
// Open the file on disk for writing
|
||||
baseName := filepath.Base(fileName)
|
||||
filePath := fs.filePath(pubID, baseName)
|
||||
if err := os.Mkdir(path.Dir(filePath), dirMode); err != nil {
|
||||
return nil, errors.Wrap(err, "creating file dir")
|
||||
}
|
||||
file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, fileMode)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "opening file")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write the file to disk
|
||||
_, err = io.Copy(file, &tmpBuf)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "writing to file")
|
||||
}
|
||||
_, err = io.Copy(file, tee)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "writing to file")
|
||||
}
|
||||
|
||||
fu := &FileUpload{
|
||||
State: FileUploadStatePresent,
|
||||
PubID: pubID,
|
||||
FileName: baseName,
|
||||
ContentType: contentType,
|
||||
Checksum: hash.Sum32(),
|
||||
}
|
||||
return fu, nil
|
||||
}
|
||||
|
||||
// GetFileUpload tries to retrieve a FileUpload object from the bolt database.
|
||||
func GetFileUpload(db *gorm.DB, pubID uuid.UUID) (*FileUpload, error) {
|
||||
var fus []FileUpload
|
||||
if err := db.Unscoped().Limit(1).Where("pub_id = ?", pubID).Find(&fus).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(fus) == 0 {
|
||||
return nil, ErrFileUploadDoesNotExist
|
||||
}
|
||||
return &fus[0], nil
|
||||
}
|
||||
|
||||
// AllFileUploads tries to retrieve all FileUpload objects from the bolt database.
|
||||
func AllFileUploads(db *gorm.DB) ([]FileUpload, error) {
|
||||
var fus []FileUpload
|
||||
if err := db.Find(&fus).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fus, nil
|
||||
}
|
||||
|
||||
// Save saves a FileUpload in the database.
|
||||
func (fu *FileUpload) Save(db *gorm.DB) error {
|
||||
return db.Save(fu).Error
|
||||
}
|
||||
|
||||
// Delete deletes a FileUpload from the database.
|
||||
func (fu *FileUpload) Delete(db *gorm.DB, fs *FileStore) error {
|
||||
// Remove the file in the backend
|
||||
filePath := fu.Path(fs)
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the file in the server
|
||||
if err := db.Delete(fu).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cleanup the parent directory
|
||||
wrap := "deletion succeeded, but removing the file directory has failed"
|
||||
return errors.Wrap(os.Remove(path.Dir(filePath)), wrap)
|
||||
}
|
||||
|
||||
// Path returns the path to this FileUpload in the FileStore provided in fs.
|
||||
func (fu *FileUpload) Path(fs *FileStore) string {
|
||||
return fs.filePath(fu.PubID, fu.FileName)
|
||||
}
|
||||
|
||||
// URL returns the URL for the FileUpload.
|
||||
func (fu *FileUpload) URL() *url.URL {
|
||||
rawurl := "/uploads/" + hex.EncodeToString(fu.PubID[:]) + "/" + fu.FileName
|
||||
urlParse, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic("could not construct /uploads/ url")
|
||||
}
|
||||
return urlParse
|
||||
}
|
||||
|
||||
// Ext returns the extension of the file attached to this FileUpload.
|
||||
func (fu *FileUpload) Ext() string {
|
||||
return filepath.Ext(fu.FileName)
|
||||
}
|
||||
46
internal/db/migrate.go
Normal file
46
internal/db/migrate.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
gormigrate "github.com/go-gormigrate/gormigrate/v2"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Gormigrate returns a Gormigrate migrator for the database.
|
||||
func Gormigrate(db *gorm.DB) *gormigrate.Gormigrate {
|
||||
return gormigrate.New(db, gormigrate.DefaultOptions, []*gormigrate.Migration{
|
||||
{
|
||||
ID: "202010251337",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
type FileUpload struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
State FileUploadState `gorm:"index"`
|
||||
PubID uuid.UUID `gorm:"uniqueIndex"`
|
||||
FileName string
|
||||
ContentType string
|
||||
Checksum uint32
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
type Paste struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Type PasteType `gorm:"index"`
|
||||
State PasteState `gorm:"index"`
|
||||
Content []byte
|
||||
Key string `gorm:"uniqueIndex"`
|
||||
DeleteToken string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
return tx.AutoMigrate(&FileUpload{}, &Paste{})
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Migrator().DropTable(&FileUpload{}, &Paste{})
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
317
internal/db/paste.go
Normal file
317
internal/db/paste.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gobmarsh "gitea.hashru.nl/dsprenkels/rushlink/pkg/gobmarsh"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PasteType describes the type of Paste (i.e. file, redirect, [...]).
|
||||
type PasteType int
|
||||
|
||||
// PasteState describes the state of a Paste (i.e. present, deleted, [...]).
|
||||
type PasteState int
|
||||
|
||||
// Paste describes the main Paste model in the database.
|
||||
type Paste struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Type PasteType `gorm:"index"`
|
||||
State PasteState `gorm:"index"`
|
||||
Content []byte
|
||||
Key string `gorm:"uniqueIndex"`
|
||||
DeleteToken string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
|
||||
// ReservedPasteKeys keys are designated reserved, and will not be randomly chosen.
|
||||
var ReservedPasteKeys = []string{"xd42", "example"}
|
||||
|
||||
// Note: we use iota here. That means removals of PasteType* are not allowed,
|
||||
// because this changes the value of the constant. Please add the comment
|
||||
// "// deprecated" if you want to remove the constant. Additions are only
|
||||
// allowed at the bottom of this block, for the same reason.
|
||||
const (
|
||||
PasteTypeUndef PasteType = iota
|
||||
// PasteTypePaste is as of yet unused. It is still unclear if this type
|
||||
// will ever get a proper meaning.
|
||||
PasteTypePaste
|
||||
PasteTypeRedirect
|
||||
PasteTypeFileUpload
|
||||
)
|
||||
|
||||
// Note: we use iota here. See the comment above PasteType*
|
||||
const (
|
||||
PasteStateUndef PasteState = iota
|
||||
PasteStatePresent
|
||||
PasteStateDeleted
|
||||
)
|
||||
|
||||
// minKeyLen specifies the mimimum length of a paste key.
|
||||
const minKeyLen = 4
|
||||
|
||||
var (
|
||||
// ErrKeyInvalidChar occurs when a key contains an invalid character.
|
||||
ErrKeyInvalidChar = errors.New("invalid character in key")
|
||||
// ErrKeyInvalidLength occurs when a key embeds a length that is incorrect.
|
||||
ErrKeyInvalidLength = errors.New("key length encoding is incorrect")
|
||||
// ErrPasteDoesNotExist occurs when a key does not exist in the database.
|
||||
ErrPasteDoesNotExist = errors.New("url key not found in the database")
|
||||
)
|
||||
|
||||
// ErrHTTPStatusCode returns the HTTP status code that should correspond to
|
||||
// the provided error.
|
||||
// server error, or false if it is not.
|
||||
func ErrHTTPStatusCode(err error) int {
|
||||
switch err {
|
||||
case nil:
|
||||
return 0
|
||||
case gorm.ErrRecordNotFound, ErrKeyInvalidChar, ErrKeyInvalidLength, ErrPasteDoesNotExist:
|
||||
return http.StatusNotFound
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// Base64 encoding and decoding
|
||||
var base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||
var base64Encoder = base64.RawURLEncoding.WithPadding(base64.NoPadding)
|
||||
|
||||
func (t PasteType) String() string {
|
||||
switch t {
|
||||
case PasteTypeUndef:
|
||||
return "unknown"
|
||||
case PasteTypePaste:
|
||||
return "paste"
|
||||
case PasteTypeRedirect:
|
||||
return "redirect"
|
||||
case PasteTypeFileUpload:
|
||||
return "file"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
func (t PasteState) String() string {
|
||||
switch t {
|
||||
case PasteStateUndef:
|
||||
return "unknown"
|
||||
case PasteStatePresent:
|
||||
return "present"
|
||||
case PasteStateDeleted:
|
||||
return "deleted"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
// GetPaste retrieves a paste from the database.
|
||||
func GetPaste(db *gorm.DB, key string) (*Paste, error) {
|
||||
if err := ValidatePasteKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return GetPasteNoValidate(db, key)
|
||||
}
|
||||
|
||||
// ValidatePasteKey validates the format of the key that has
|
||||
func ValidatePasteKey(key string) error {
|
||||
internalLen := minKeyLen
|
||||
countingOnes := true
|
||||
for _, ch := range key {
|
||||
limb := strings.IndexRune(base64Alphabet, ch)
|
||||
if limb == -1 {
|
||||
return ErrKeyInvalidChar
|
||||
}
|
||||
for i := 5; i >= 0 && countingOnes; i-- {
|
||||
if (limb>>uint(i))&0x1 == 0 {
|
||||
countingOnes = false
|
||||
break
|
||||
}
|
||||
internalLen++
|
||||
}
|
||||
}
|
||||
if internalLen != len(key) {
|
||||
return ErrKeyInvalidLength
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPasteNoValidate retrieves a paste from the database without validating
|
||||
// the key format first.
|
||||
func GetPasteNoValidate(db *gorm.DB, key string) (*Paste, error) {
|
||||
var ps []Paste
|
||||
if err := db.Unscoped().Limit(1).Where("key = ?", key).Find(&ps).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ps) == 0 {
|
||||
return nil, ErrPasteDoesNotExist
|
||||
}
|
||||
return &ps[0], nil
|
||||
}
|
||||
|
||||
func decodePaste(storedBytes []byte) (*Paste, error) {
|
||||
p := &Paste{}
|
||||
err := gobmarsh.Unmarshal(storedBytes, p)
|
||||
return p, err
|
||||
}
|
||||
|
||||
// Save saves this Paste to the database.
|
||||
func (p *Paste) Save(db *gorm.DB) error {
|
||||
return db.Save(p).Error
|
||||
}
|
||||
|
||||
// Delete deletes this Paste from the database.
|
||||
func (p *Paste) Delete(db *gorm.DB, fs *FileStore) error {
|
||||
// Remove the (maybe) attached file
|
||||
if p.Type == PasteTypeFileUpload {
|
||||
fuID, err := uuid.FromBytes(p.Content)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to parse uuid")
|
||||
}
|
||||
fu, err := GetFileUpload(db, fuID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find file in database")
|
||||
}
|
||||
if err := fu.Delete(db, fs); err != nil {
|
||||
return errors.Wrap(err, "failed to remove file")
|
||||
}
|
||||
}
|
||||
|
||||
// Wipe the old paste
|
||||
p.Type = PasteTypeUndef
|
||||
p.State = PasteStateDeleted
|
||||
p.Content = []byte{}
|
||||
if err := db.Save(&p).Error; err != nil {
|
||||
return errors.Wrap(err, "failed to wipe paste in database")
|
||||
}
|
||||
// Soft-delete the paste as well
|
||||
if err := db.Delete(&p).Error; err != nil {
|
||||
return errors.Wrap(err, "failed to delete paste in database")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RedirectURL returns the URL from this paste.
|
||||
//
|
||||
// This function assumes that the paste is valid. If the paste struct is
|
||||
// corrupted in some way, this function will panic.
|
||||
func (p *Paste) RedirectURL() *url.URL {
|
||||
if p.Type != PasteTypeRedirect {
|
||||
panic("expected p.Type to be PasteTypeRedirect")
|
||||
}
|
||||
rawurl := string(p.Content)
|
||||
urlParse, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic(errors.Wrapf(err, "invalid URL ('%v') in database for key '%v'", rawurl, p.Key))
|
||||
}
|
||||
return urlParse
|
||||
}
|
||||
|
||||
// GeneratePasteKey generates a new paste key. It will ensure that the newly
|
||||
// generated paste key does not already exist in the database.
|
||||
// The running time of this function is in O(log N), where N is the amount of
|
||||
// keys stored in the url-shorten database.
|
||||
// In tx, a Bolt transaction is given. Use minimumEntropy to set the mimimum
|
||||
// guessing entropy of the generated key.
|
||||
func GeneratePasteKey(db *gorm.DB, minimumEntropy int) (string, error) {
|
||||
epoch := 0
|
||||
var key string
|
||||
for {
|
||||
var err error
|
||||
key, err = generatePasteKeyInner(epoch, minimumEntropy)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "url-key generation failed")
|
||||
}
|
||||
|
||||
alreadyInUse := true
|
||||
var p Paste
|
||||
err = db.Unscoped().Where("key = ?", []byte(key)).First(&p).Error
|
||||
if err != nil && err == gorm.ErrRecordNotFound {
|
||||
alreadyInUse = false
|
||||
err = nil
|
||||
} else if err != nil {
|
||||
return "", errors.Wrap(err, "failed to check if key already exists")
|
||||
}
|
||||
|
||||
isReserved := false
|
||||
for _, reservedKey := range ReservedPasteKeys {
|
||||
if strings.HasPrefix(key, reservedKey) {
|
||||
isReserved = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !alreadyInUse && !isReserved {
|
||||
break
|
||||
}
|
||||
epoch++
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generatePasteKeyInner generates a new paste key, but leaves the
|
||||
// uniqueness and is-reserved checks to the caller. That is, it only
|
||||
// generates a random key in the correct (syntactical) format.
|
||||
// Both epoch and entropy can be used to set the key length. Epoch is used
|
||||
// to prevent collisions in retrying to generate new keys. Entropy (in bits)
|
||||
// is used to ensure that a new key has at least some amount of guessing
|
||||
// entropy.
|
||||
func generatePasteKeyInner(epoch, entropy int) (string, error) {
|
||||
entropyEpoch := entropy
|
||||
entropyEpoch -= minKeyLen * 6 // First 4 characters provide 24 bits.
|
||||
entropyEpoch++ // One bit less because of '0' bit.
|
||||
entropyEpoch = (entropyEpoch-1)/5 + 1 // 5 bits for every added epoch.
|
||||
if epoch < entropyEpoch {
|
||||
epoch = entropyEpoch
|
||||
}
|
||||
urlKey := make([]byte, minKeyLen+epoch)
|
||||
_, err := rand.Read(urlKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Put all the values in the range 0..64 for easier base64-encoding
|
||||
for i := 0; i < len(urlKey); i++ {
|
||||
urlKey[i] &= 0x3F
|
||||
}
|
||||
// Implement truncate-resistance by forcing the prefix to
|
||||
// 0b111110xxxxxxxxxx
|
||||
// ^----- {epoch} ones followed by a single 0
|
||||
//
|
||||
// Example when epoch is 1: prefix is 0b10.
|
||||
i := 0
|
||||
for i < epoch {
|
||||
// Set this bit to 1
|
||||
limb := i / 6
|
||||
bit := i % 6
|
||||
urlKey[limb] |= 1 << uint(5-bit)
|
||||
i++
|
||||
}
|
||||
// Finally set the next bit to 0
|
||||
limb := i / 6
|
||||
bit := i % 6
|
||||
urlKey[limb] &= ^(1 << uint(5-bit))
|
||||
|
||||
// Convert this ID to a canonical base64 notation
|
||||
for i := range urlKey {
|
||||
urlKey[i] = base64Alphabet[urlKey[i]]
|
||||
}
|
||||
return string(urlKey), nil
|
||||
}
|
||||
|
||||
// GenerateDeleteToken generates a new (random) delete token.
|
||||
func GenerateDeleteToken() (string, error) {
|
||||
var deleteToken [16]byte
|
||||
_, err := rand.Read(deleteToken[:])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(deleteToken[:]), nil
|
||||
}
|
||||
38
internal/db/paste_test.go
Normal file
38
internal/db/paste_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidatePasteKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
wantErr bool
|
||||
errKind error
|
||||
}{
|
||||
{"xd42__", false, nil},
|
||||
{"xd42_*", true, ErrKeyInvalidChar},
|
||||
{"xd42_/", true, ErrKeyInvalidChar},
|
||||
{"xd42_=", true, ErrKeyInvalidChar},
|
||||
{"xd42_", true, ErrKeyInvalidLength},
|
||||
{"xd42", true, ErrKeyInvalidLength},
|
||||
{"xd4", true, ErrKeyInvalidLength},
|
||||
{"xd", true, ErrKeyInvalidLength},
|
||||
{"x", true, ErrKeyInvalidLength},
|
||||
{"", true, ErrKeyInvalidLength},
|
||||
|
||||
{"KoJ5", false, nil},
|
||||
|
||||
{"__dGSJIIbBpr-SD0", false, nil},
|
||||
{"__dGSJIIbBpr-SD", true, ErrKeyInvalidLength},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
err := ValidatePasteKey(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidatePasteKey() got error = %v, want error %v", err != nil, tt.wantErr)
|
||||
}
|
||||
if (err != nil) && err != tt.errKind {
|
||||
t.Errorf("ValidatePasteKey() error = %v, want errKind %v", err, tt.errKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user