use crate::diesel::BoolExpressionMethods; use crate::diesel::Connection; use crate::diesel::ExpressionMethods; use crate::diesel::QueryDsl; use argon2::password_hash::SaltString; use argon2::PasswordHash; use argon2::PasswordVerifier; use argon2::{ password_hash::{rand_core::OsRng, PasswordHasher}, Argon2, }; use diesel::dsl::count; use diesel::RunQueryDsl; use diesel_derive_enum::DbEnum; use rocket::{Build, Rocket}; use rocket_sync_db_pools::database; use serde::{Deserialize, Serialize}; use std::ops::Deref; use validator::{Validate, ValidationError}; #[database("gamenight_database")] pub struct DbConn(diesel::SqliteConnection); impl Deref for DbConn { type Target = rocket_sync_db_pools::Connection; fn deref(&self) -> &Self::Target { &self.0 } } table! { gamenight (id) { id -> Integer, game -> Text, datetime -> Text, } } table! { known_games (game) { id -> Integer, game -> Text, } } table! { use diesel::sql_types::Integer; use diesel::sql_types::Text; use super::RoleMapping; user(id) { id -> Integer, username -> Text, email -> Text, role -> RoleMapping, } } table! { pwd(id) { id -> Integer, password -> Text, } } allow_tables_to_appear_in_same_query!(gamenight, known_games,); pub enum DatabaseError { Hash(password_hash::Error), Query(String), } impl std::fmt::Display for DatabaseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { match self { DatabaseError::Hash(err) => write!(f, "{}", err), DatabaseError::Query(err) => write!(f, "{}", err), } } } pub async fn get_all_gamenights(conn: DbConn) -> Vec { conn.run(|c| gamenight::table.load::(c).unwrap()) .await } pub async fn insert_gamenight(conn: DbConn, new_gamenight: GameNightNoId) -> () { conn.run(|c| { diesel::insert_into(gamenight::table) .values(new_gamenight) .execute(c) .unwrap() }) .await; } pub async fn insert_user(conn: DbConn, new_user: Register) -> Result<(), DatabaseError> { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let password_hash = match argon2.hash_password(new_user.password.as_bytes(), &salt) { Ok(hash) => hash.to_string(), Err(error) => return Err(DatabaseError::Hash(error)), }; let user_insert_result = conn .run(move |c| { c.transaction(|| { diesel::insert_into(user::table) .values(( user::username.eq(&new_user.username), user::email.eq(&new_user.email), user::role.eq(Role::User), )) .execute(c)?; let ids: Vec = match user::table .filter( user::username .eq(&new_user.username) .and(user::email.eq(&new_user.email)), ) .select(user::id) .get_results(c) { Ok(id) => id, Err(e) => return Err(e), }; diesel::insert_into(pwd::table) .values((pwd::id.eq(ids[0]), pwd::password.eq(&password_hash))) .execute(c) }) }) .await; match user_insert_result { Err(e) => Err(DatabaseError::Query(e.to_string())), _ => Ok(()), } } pub async fn login(conn: DbConn, login: Login) -> Result { conn.run(move |c| -> Result { let id: i32 = match user::table .filter(user::username.eq(&login.username)) .or_filter(user::email.eq(&login.username)) .select(user::id) .get_results(c) { Ok(id) => id[0], Err(error) => return Err(DatabaseError::Query(error.to_string())), }; let pwd: String = match pwd::table .filter(pwd::id.eq(id)) .select(pwd::password) .get_results::(c) { Ok(pwd) => pwd[0].clone(), Err(error) => return Err(DatabaseError::Query(error.to_string())), }; let parsed_hash = match PasswordHash::new(&pwd) { Ok(hash) => hash, Err(error) => return Err(DatabaseError::Hash(error)), }; if Argon2::default() .verify_password(&login.password.as_bytes(), &parsed_hash) .is_ok() { let role: Role = match user::table .filter(user::id.eq(id)) .select(user::role) .get_results::(c) { Ok(role) => role[0].clone(), Err(error) => return Err(DatabaseError::Query(error.to_string())), }; Ok(LoginResult { result: true, id: Some(id), role: Some(role), }) } else { Ok(LoginResult { result: false, id: None, role: None, }) } }) .await } pub async fn get_user(conn: DbConn, id: i32) -> User { conn.run(move |c| user::table.filter(user::id.eq(id)).first(c).unwrap()) .await } pub fn unique_username( username: &String, conn: &diesel::SqliteConnection, ) -> Result<(), ValidationError> { match user::table .select(count(user::username)) .filter(user::username.eq(username)) .execute(conn) { Ok(0) => Ok(()), Ok(_) => Err(ValidationError::new("User already exists")), Err(_) => Err(ValidationError::new("Database error while validating user")), } } pub fn unique_email( email: &String, conn: &diesel::SqliteConnection, ) -> Result<(), ValidationError> { match user::table .select(count(user::email)) .filter(user::email.eq(email)) .execute(conn) { Ok(0) => Ok(()), Ok(_) => Err(ValidationError::new("email already exists")), Err(_) => Err(ValidationError::new( "Database error while validating email", )), } } pub async fn run_migrations(rocket: Rocket) -> Rocket { // This macro from `diesel_migrations` defines an `embedded_migrations` // module containing a function named `run`. This allows the example to be // run and tested without any outside setup of the database. embed_migrations!(); let conn = DbConn::get_one(&rocket).await.expect("database connection"); conn.run(|c| embedded_migrations::run(c)) .await .expect("can run migrations"); rocket } #[derive(Debug, Serialize, Deserialize, DbEnum, Clone)] pub enum Role { Admin, User, } #[derive(Serialize, Deserialize, Debug, Insertable, Queryable)] #[table_name = "user"] pub struct User { pub id: i32, pub username: String, pub email: String, pub role: Role, } #[derive(Serialize, Deserialize, Debug, FromForm, Insertable)] #[table_name = "known_games"] pub struct GameNoId { pub game: String, } #[derive(Serialize, Deserialize, Debug, FromForm, Queryable)] pub struct Game { pub id: i32, pub game: String, } #[derive(Serialize, Deserialize, Debug, FromForm, Insertable)] #[table_name = "gamenight"] pub struct GameNightNoId { pub game: String, pub datetime: String, } #[derive(Serialize, Deserialize, Debug, FromForm, Queryable)] pub struct GameNight { pub id: i32, pub game: String, pub datetime: String, } #[derive(Serialize, Deserialize, Debug, Validate, Clone)] pub struct Register { #[validate( length(min = 1), custom(function = "unique_username", arg = "&'v_a diesel::SqliteConnection") )] pub username: String, #[validate( email, custom(function = "unique_email", arg = "&'v_a diesel::SqliteConnection") )] pub email: String, #[validate(length(min = 10), must_match = "password_repeat")] pub password: String, pub password_repeat: String, } #[derive(Serialize, Deserialize, Debug)] pub struct Login { pub username: String, pub password: String, } #[derive(Serialize, Deserialize, Debug)] pub struct LoginResult { pub result: bool, pub id: Option, pub role: Option, }