diff --git a/youmubot-db-sql/src/lib.rs b/youmubot-db-sql/src/lib.rs index fd0c3b0..b15939e 100644 --- a/youmubot-db-sql/src/lib.rs +++ b/youmubot-db-sql/src/lib.rs @@ -36,9 +36,27 @@ pub mod errors { #[derive(thiserror::Error, Debug)] pub enum Error { #[error("sqlx error: {:?}", .0)] - SQLx(#[from] sqlx::Error), + SQLx(sqlx::Error), #[error("sqlx migration error: {:?}", .0)] Migration(#[from] sqlx::migrate::MigrateError), + #[error("values already existed for: {}", .0)] + Duplicate(String), + } + + impl From for Error { + fn from(value: sqlx::Error) -> Self { + match value { + // if we can match a constraint error, give it a special case. + sqlx::Error::Database(database_error) => { + let msg = database_error.message(); + match msg.strip_prefix("UNIQUE constraint failed: ") { + Some(con) => Error::Duplicate(con.to_owned()), + None => Error::SQLx(sqlx::Error::Database(database_error)), + } + } + e => Error::SQLx(e), + } + } } } diff --git a/youmubot-osu/src/discord/commands.rs b/youmubot-osu/src/discord/commands.rs index ac6f150..c21019d 100644 --- a/youmubot-osu/src/discord/commands.rs +++ b/youmubot-osu/src/discord/commands.rs @@ -186,21 +186,31 @@ pub async fn save( CreateReply::default() .content(save_request_message(&u.username, score.beatmap_id, mode)) .embed(beatmap_embed(&beatmap, mode, Mods::NOMOD, &info)) - .components(vec![beatmap_components(mode, ctx.guild_id())]), + .components(vec![ + beatmap_components(mode, ctx.guild_id()), + save_button(), + ]), ) - .await? - .into_message() .await?; - handle_save_respond( + let mut p = (reply, ctx.clone()); + match handle_save_respond( ctx.serenity_context(), &env, ctx.author().id, - reply, + &mut p, &beatmap, u, mode, ) - .await?; + .await + { + Ok(_) => (), + Err(e) => { + p.0.delete(ctx).await?; + return Err(e.into()); + } + }; + Ok(()) } diff --git a/youmubot-osu/src/discord/db.rs b/youmubot-osu/src/discord/db.rs index 8896bbe..eaa5501 100644 --- a/youmubot-osu/src/discord/db.rs +++ b/youmubot-osu/src/discord/db.rs @@ -58,7 +58,14 @@ impl OsuSavedUsers { let mut t = self.pool.begin().await?; model::OsuUser::delete(u.user_id.get() as i64, &mut *t).await?; assert!( - model::OsuUser::from(u).store(&mut t).await?, + match model::OsuUser::from(u).store(&mut t).await { + Ok(v) => v, + Err(youmubot_db_sql::Error::Duplicate(_)) => + return Err(Error::msg( + "another Discord user has already saved your account with the same id!" + )), + Err(e) => return Err(e.into()), + }, "Should be updated" ); t.commit().await?; diff --git a/youmubot-osu/src/discord/mod.rs b/youmubot-osu/src/discord/mod.rs index d60e139..30c10c8 100644 --- a/youmubot-osu/src/discord/mod.rs +++ b/youmubot-osu/src/discord/mod.rs @@ -8,8 +8,8 @@ use link_parser::EmbedType; use oppai_cache::BeatmapInfoWithPP; use rand::seq::IteratorRandom; use serenity::{ - builder::{CreateMessage, EditMessage}, - collector, + all::{CreateActionRow, CreateButton}, + builder::CreateMessage, framework::standard::{ macros::{command, group}, Args, CommandResult, @@ -257,10 +257,17 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult CreateMessage::new() .content(save_request_message(&u.username, score.beatmap_id, mode)) .embed(beatmap_embed(&beatmap, mode, Mods::NOMOD, &info)) - .components(vec![beatmap_components(mode, msg.guild_id)]), + .components(vec![beatmap_components(mode, msg.guild_id), save_button()]), ) .await?; - handle_save_respond(ctx, &env, msg.author.id, reply, &beatmap, u, mode).await?; + let mut p = (reply, ctx); + match handle_save_respond(ctx, &env, msg.author.id, &mut p, &beatmap, u, mode).await { + Ok(_) => (), + Err(e) => { + p.0.delete(&ctx).await?; + return Err(e.into()); + } + }; Ok(()) } @@ -325,11 +332,18 @@ pub(crate) async fn find_save_requirements( Ok((u, mode, score, beatmap, info)) } +const SAVE_BUTTON: &str = "youmubot::osu::save"; +pub(crate) fn save_button() -> CreateActionRow { + CreateActionRow::Buttons(vec![CreateButton::new(SAVE_BUTTON) + .label("I'm done!") + .emoji('👌') + .style(serenity::all::ButtonStyle::Primary)]) +} pub(crate) async fn handle_save_respond( ctx: &Context, env: &OsuEnv, sender: serenity::all::UserId, - mut reply: Message, + reply: &mut impl CanEdit, beatmap: &Beatmap, user: crate::models::User, mode: Mode, @@ -343,50 +357,36 @@ pub(crate) async fn handle_save_respond( .take(1) .any(|s| s.beatmap_id == map_id)) } - let reaction = reply.react(&ctx, '👌').await?; + let msg_id = reply.get_message().await?.id; + let recv = InteractionCollector::create(&ctx, msg_id).await?; + let timeout = std::time::Duration::from_secs(300) + beatmap.difficulty.total_length; let completed = loop { - let emoji = reaction.emoji.clone(); - let user_reaction = collector::ReactionCollector::new(ctx) - .message_id(reply.id) - .author_id(sender) - .filter(move |r| r.emoji == emoji) - .timeout(std::time::Duration::from_secs(300) + beatmap.difficulty.total_length) - .next() - .await; - if let Some(ur) = user_reaction { - if check(osu_client, &user, mode, beatmap.beatmap_id).await? { - break true; - } - ur.delete(&ctx).await?; - } else { + let Some(reaction) = recv.next(timeout).await else { break false; + }; + if reaction == SAVE_BUTTON && check(osu_client, &user, mode, beatmap.beatmap_id).await? { + break true; } }; if !completed { reply - .edit( - &ctx, - EditMessage::new() + .apply_edit( + CreateReply::default() .content(format!( "Setting username to **{}** failed due to timeout. Please try again!", user.username )) - .embeds(vec![]) .components(vec![]), ) .await?; - reaction.delete(&ctx).await?; return Ok(()); } add_user(sender, &user, &env).await?; let ex = UserExtras::from_user(env, &user, mode).await?; reply - .channel_id - .send_message( - &ctx, - CreateMessage::new() - .reference_message(&reply) + .apply_edit( + CreateReply::default() .content( MessageBuilder::new() .push("Youmu is now tracking user ") @@ -395,7 +395,8 @@ pub(crate) async fn handle_save_respond( .push(user.mention().to_string()) .build(), ) - .add_embed(user_embed(user.clone(), ex)), + .embed(user_embed(user.clone(), ex)) + .components(vec![]), ) .await?; Ok(()) diff --git a/youmubot-prelude/src/interaction_collector.rs b/youmubot-prelude/src/interaction_collector.rs index 3de42ac..09007ad 100644 --- a/youmubot-prelude/src/interaction_collector.rs +++ b/youmubot-prelude/src/interaction_collector.rs @@ -3,7 +3,7 @@ use serenity::{ all::{CreateInteractionResponse, Interaction, MessageId}, prelude::TypeMapKey, }; -use std::{ops::Deref, sync::Arc}; +use std::sync::Arc; #[derive(Debug, Clone)] /// Handles distributing interaction to the handlers. @@ -13,16 +13,25 @@ pub struct InteractionCollector { /// Wraps the interfaction receiver channel, automatically cleaning up upon drop. #[derive(Debug)] -struct InteractionCollectorGuard { +pub struct InteractionCollectorGuard { msg_id: MessageId, ch: flume::Receiver, collector: InteractionCollector, } -impl Deref for InteractionCollectorGuard { - type Target = flume::Receiver; +impl InteractionCollectorGuard { + /// Returns the next fetched interaction, with the given timeout. + pub async fn next(&self, timeout: std::time::Duration) -> Option { + match tokio::time::timeout(timeout, self.ch.clone().into_recv_async()).await { + Err(_) => None, + Ok(Err(_)) => None, + Ok(Ok(interaction)) => Some(interaction), + } + } +} - fn deref(&self) -> &Self::Target { +impl AsRef> for InteractionCollectorGuard { + fn as_ref(&self) -> &flume::Receiver { &self.ch } } @@ -40,7 +49,7 @@ impl InteractionCollector { } } /// Create a new collector, returning a receiver. - pub fn create_collector(&self, msg: MessageId) -> impl Deref> { + pub fn create_collector(&self, msg: MessageId) -> InteractionCollectorGuard { let (send, recv) = flume::unbounded(); self.channels.insert(msg.clone(), send); InteractionCollectorGuard { @@ -51,10 +60,7 @@ impl InteractionCollector { } /// Create a new collector, returning a receiver. - pub(crate) async fn create( - ctx: &Context, - msg: MessageId, - ) -> Result>> { + pub async fn create(ctx: &Context, msg: MessageId) -> Result { Ok(ctx .data .read() diff --git a/youmubot-prelude/src/pagination.rs b/youmubot-prelude/src/pagination.rs index 43a9400..19d0808 100644 --- a/youmubot-prelude/src/pagination.rs +++ b/youmubot-prelude/src/pagination.rs @@ -8,12 +8,6 @@ use serenity::{ builder::CreateMessage, model::{channel::Message, id::ChannelId}, }; -use tokio::time as tokio_time; - -// const ARROW_RIGHT: &str = "➡️"; -// const ARROW_LEFT: &str = "⬅️"; -// const REWIND: &str = "⏪"; -// const FAST_FORWARD: &str = "⏩"; const NEXT: &str = "youmubot_pagination_next"; const PREV: &str = "youmubot_pagination_prev"; @@ -269,10 +263,9 @@ pub async fn paginate_with_first_message( // Loop the handler function. let res: Result<()> = loop { - match tokio_time::timeout(timeout, recv.clone().into_recv_async()).await { - Err(_) => break Ok(()), - Ok(Err(_)) => break Ok(()), - Ok(Ok(reaction)) => { + match recv.next(timeout).await { + None => break Ok(()), + Some(reaction) => { page = match pager .handle_reaction(page, ctx, &mut message, &reaction) .await