2023-04-30 12:03:23 +02:00
|
|
|
package db
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/rand"
|
2023-04-30 21:08:36 +02:00
|
|
|
"crypto/subtle"
|
2023-04-30 12:03:23 +02:00
|
|
|
"encoding/base64"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
2023-04-30 18:16:52 +02:00
|
|
|
"strings"
|
2023-04-30 12:03:23 +02:00
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"golang.org/x/crypto/argon2"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
)
|
|
|
|
|
|
|
|
type User struct {
|
|
|
|
ID uint `gorm:"primaryKey"`
|
|
|
|
User string `gorm:"uniqueIndex"`
|
|
|
|
Password string
|
|
|
|
Admin bool
|
|
|
|
CreatedAt time.Time
|
|
|
|
UpdatedAt time.Time
|
|
|
|
DeletedAt gorm.DeletedAt
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewUser(db *gorm.DB, username string, password string, admin bool) error {
|
|
|
|
// Generate a new UUID for the user
|
|
|
|
id, err := uuid.NewRandom()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Hash the password using argon2id
|
|
|
|
hashedPassword, err := HashPassword(password)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a new user record
|
|
|
|
user := &User{
|
|
|
|
ID: uint(id.ID()),
|
|
|
|
User: username,
|
|
|
|
Password: hashedPassword,
|
|
|
|
Admin: admin,
|
|
|
|
}
|
|
|
|
return db.Create(user).Error
|
|
|
|
}
|
|
|
|
|
|
|
|
func Authenticate(db *gorm.DB, username string, password string) (*User, error) {
|
|
|
|
// Get the user record by username
|
|
|
|
var user User
|
|
|
|
if err := db.Where("user = ?", username).First(&user).Error; err != nil {
|
|
|
|
return nil, errors.New("user not found")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Compare the hashed password with the provided password
|
2023-04-30 18:16:52 +02:00
|
|
|
valid, err := comparePassword(user.Password, password)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if !valid {
|
2023-04-30 12:03:23 +02:00
|
|
|
return nil, errors.New("invalid password")
|
|
|
|
}
|
|
|
|
|
|
|
|
return &user, nil
|
|
|
|
}
|
|
|
|
|
2023-04-30 18:16:52 +02:00
|
|
|
// Settings shared by HashPassword and comparePassword.
const (
	// pwdSaltSize is the number of random salt bytes generated per password.
	pwdSaltSize = 16

	// pwdHashSize is the key length, in bytes, requested from argon2.
	pwdHashSize = 32

	// pwdParams is the parameter segment written into every encoded hash and
	// required of every hash that is verified. The numeric arguments passed to
	// argon2.IDKey are separate literals and must be kept in step with it.
	pwdParams = "m=65536,t=2,p=1"

	// pwdAlgo is the algorithm segment of an encoded hash.
	pwdAlgo = "argon2id"
)
|
|
|
|
|
2023-04-30 12:03:23 +02:00
|
|
|
func HashPassword(password string) (string, error) {
|
|
|
|
// Generate a salt for the password hash
|
2023-04-30 18:16:52 +02:00
|
|
|
salt := make([]byte, pwdSaltSize)
|
2023-04-30 12:03:23 +02:00
|
|
|
if _, err := rand.Read(salt); err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Hash the password using argon2id
|
2023-04-30 18:16:52 +02:00
|
|
|
hash := argon2.IDKey([]byte(password), salt, 2, 64*1024, 1, pwdHashSize)
|
2023-04-30 12:03:23 +02:00
|
|
|
|
2023-04-30 18:16:52 +02:00
|
|
|
// Encode the salt and hash as a string in PHC format
|
|
|
|
encodedSalt := base64.RawStdEncoding.EncodeToString(salt)
|
|
|
|
encodedHash := base64.RawStdEncoding.EncodeToString(hash)
|
|
|
|
return fmt.Sprintf("$%s$%s$%s$%s", pwdAlgo, pwdParams, encodedSalt, encodedHash), nil
|
2023-04-30 12:03:23 +02:00
|
|
|
}
|
|
|
|
|
2023-04-30 21:20:41 +02:00
|
|
|
// errInvalidDBPasswordFormat is returned by comparePassword when a stored hash
// does not have the expected number of "$"-separated fields, or names a
// different algorithm or parameter set than this package writes.
var errInvalidDBPasswordFormat = errors.New("invalid password format in db")
|
2023-04-30 18:16:52 +02:00
|
|
|
func comparePassword(hashedPassword string, password string) (bool, error) {
|
|
|
|
// Extract the salt and hash from the hashed password string
|
2023-04-30 21:20:41 +02:00
|
|
|
fields := strings.Split(hashedPassword, "$")
|
|
|
|
if len(fields) != 5 || fields[1] != pwdAlgo || fields[2] != pwdParams {
|
|
|
|
return false, errInvalidDBPasswordFormat
|
2023-04-30 18:16:52 +02:00
|
|
|
}
|
2023-04-30 21:20:41 +02:00
|
|
|
encodedSalt, encodedHash := fields[3], fields[4]
|
2023-04-30 18:16:52 +02:00
|
|
|
|
|
|
|
// Decode the salt and hash from base64
|
|
|
|
salt, err := base64.RawStdEncoding.DecodeString(encodedSalt)
|
|
|
|
if err != nil {
|
|
|
|
return false, err
|
|
|
|
}
|
|
|
|
hash, err := base64.RawStdEncoding.DecodeString(encodedHash)
|
|
|
|
if err != nil {
|
|
|
|
return false, err
|
|
|
|
}
|
2023-04-30 12:03:23 +02:00
|
|
|
|
2023-04-30 18:16:52 +02:00
|
|
|
// Hash the password using the extracted salt and parameters
|
|
|
|
computedHash := argon2.IDKey([]byte(password), salt, 2, 64*1024, 1, pwdHashSize)
|
2023-04-30 12:03:23 +02:00
|
|
|
|
|
|
|
// Compare the computed hash with the stored hash
|
2023-04-30 21:08:36 +02:00
|
|
|
return subtle.ConstantTimeCompare(hash, computedHash) == 1, nil
|
2023-04-30 12:03:23 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
// DeleteUser deletes a user with the specified username from the database.
|
|
|
|
func DeleteUser(db *gorm.DB, username string) error {
|
|
|
|
// Find the user by username
|
|
|
|
var user User
|
|
|
|
if err := db.Where("user = ?", username).First(&user).Error; err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Delete the user
|
|
|
|
if err := db.Delete(&user).Error; err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func ChangeUser(db *gorm.DB, username string, updatedUser *User) error {
|
|
|
|
// Retrieve the existing user
|
|
|
|
var existingUser User
|
|
|
|
if err := db.Where("user = ?", username).First(&existingUser).Error; err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Update the user fields
|
|
|
|
existingUser.User = updatedUser.User
|
|
|
|
existingUser.Password = updatedUser.Password
|
|
|
|
existingUser.Admin = updatedUser.Admin
|
|
|
|
|
|
|
|
// Save the updated user to the database
|
|
|
|
if err := db.Save(&existingUser).Error; err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func CreateAdminUser(db *gorm.DB, adminUsername string) error {
|
|
|
|
// Check if the admin user already exists
|
|
|
|
var admins []User
|
|
|
|
if err := db.Unscoped().Limit(1).Where("user = ?", adminUsername).Find(&admins).Error; err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if len(admins) > 0 {
|
|
|
|
// already exists
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Generate a random 24-char password
|
|
|
|
passwordBytes := make([]byte, 24)
|
|
|
|
if _, err := rand.Read(passwordBytes); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
password := base64.URLEncoding.EncodeToString(passwordBytes)
|
|
|
|
|
|
|
|
// Create the admin user
|
|
|
|
if err := NewUser(db, adminUsername, password, true); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Print out the generated password
|
|
|
|
fmt.Printf("Generated password for admin user %s: %s\n", adminUsername, password)
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|