From bcd59c673ca89887219e0b1cedb023cc23be4cac Mon Sep 17 00:00:00 2001 From: Natsu Kagami Date: Tue, 27 Feb 2024 00:28:13 +0100 Subject: [PATCH] Move Pagination to some generic replyable trait --- Cargo.lock | 3 + youmubot-cf/Cargo.toml | 1 + youmubot-cf/src/lib.rs | 19 ++- youmubot-core/Cargo.toml | 1 + youmubot-core/src/community/roles.rs | 12 +- youmubot-core/src/fun/images.rs | 25 ++- youmubot-osu/Cargo.toml | 1 + youmubot-osu/src/discord/announcer.rs | 10 +- youmubot-osu/src/discord/app_commands.rs | 20 ++- youmubot-osu/src/discord/args.rs | 70 ++++++++ youmubot-osu/src/discord/cache.rs | 13 +- youmubot-osu/src/discord/display.rs | 198 ++++++++++------------- youmubot-osu/src/discord/hook.rs | 5 +- youmubot-osu/src/discord/mod.rs | 125 ++++++++++---- youmubot-osu/src/discord/server_rank.rs | 40 ++--- youmubot-osu/src/models/mods.rs | 14 +- youmubot-prelude/src/announcer.rs | 2 +- youmubot-prelude/src/lib.rs | 2 +- youmubot-prelude/src/pagination.rs | 139 ++++++++++------ youmubot-prelude/src/replyable.rs | 76 +++++++++ 20 files changed, 509 insertions(+), 267 deletions(-) create mode 100644 youmubot-osu/src/discord/args.rs create mode 100644 youmubot-prelude/src/replyable.rs diff --git a/Cargo.lock b/Cargo.lock index ec8b79c..19c0331 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3213,6 +3213,7 @@ dependencies = [ "dashmap", "lazy_static", "log", + "poise", "regex", "reqwest", "serde", @@ -3230,6 +3231,7 @@ dependencies = [ "dashmap", "flume 0.10.14", "futures-util", + "poise", "rand", "serde", "serenity", @@ -3280,6 +3282,7 @@ dependencies = [ "serde", "serde_json", "serenity", + "thiserror", "time", "youmubot-db", "youmubot-db-sql", diff --git a/youmubot-cf/Cargo.toml b/youmubot-cf/Cargo.toml index a3a590d..230fdfa 100644 --- a/youmubot-cf/Cargo.toml +++ b/youmubot-cf/Cargo.toml @@ -10,6 +10,7 @@ serde = { version = "1.0.137", features = ["derive"] } tokio = { version = "1.19.2", features = ["time"] } reqwest = "0.11.10" serenity = "0.12" +poise = "0.6" Inflector = "0.11.4" codeforces = "0.3.1" regex = "1.5.6" diff --git a/youmubot-cf/src/lib.rs b/youmubot-cf/src/lib.rs index 2cbae58..a5ec0fc 100644 --- a/youmubot-cf/src/lib.rs +++ b/youmubot-cf/src/lib.rs @@ -1,4 +1,5 @@ use codeforces::Contest; +use poise::CreateReply; use serenity::{ builder::{CreateMessage, EditMessage}, framework::standard::{ @@ -173,14 +174,14 @@ pub async fn ranks(ctx: &Context, m: &Message) -> CommandResult { let last_updated = ranks.iter().map(|(_, cfu)| cfu.last_update).min().unwrap(); paginate_reply_fn( - move |page, ctx, msg| { + move |page, _| { let ranks = ranks.clone(); Box::pin(async move { let page = page as usize; let start = ITEMS_PER_PAGE * page; let end = ranks.len().min(start + ITEMS_PER_PAGE); if start >= end { - return Ok(false); + return Ok(None); } let ranks = &ranks[start..end]; @@ -233,12 +234,11 @@ pub async fn ranks(ctx: &Context, m: &Message) -> CommandResult { last_updated.to_rfc2822() )); - msg.edit(ctx, EditMessage::new().content(m.build())).await?; - Ok(true) + Ok(Some(CreateReply::default().content(m.build()))) }) }, ctx, - m, + m.clone(), std::time::Duration::from_secs(60), ) .await?; @@ -328,7 +328,7 @@ pub(crate) async fn contest_rank_table( let ranks = Arc::new(ranks); paginate_reply_fn( - move |page, ctx, msg| { + move |page, ctx| { let contest = contest.clone(); let problems = problems.clone(); let ranks = ranks.clone(); @@ -337,7 +337,7 @@ pub(crate) async fn contest_rank_table( let start = page * ITEMS_PER_PAGE; let end = ranks.len().min(start + ITEMS_PER_PAGE); if start >= end { - return Ok(false); + return Ok(None); } let ranks = &ranks[start..end]; let hw = ranks @@ -412,12 +412,11 @@ pub(crate) async fn contest_rank_table( .push_line(contest.url()) .push_codeblock(table.build(), None) .push_line(format!("Page **{}/{}**", page + 1, total_pages)); - msg.edit(ctx, EditMessage::new().content(m.build())).await?; - Ok(true) + Ok(Some(CreateReply::default().content(m.build()))) }) }, ctx, - reply_to, + reply_to.clone(), Duration::from_secs(60), ) .await diff --git a/youmubot-core/Cargo.toml b/youmubot-core/Cargo.toml index e1cda78..ec61418 100644 --- a/youmubot-core/Cargo.toml +++ b/youmubot-core/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] serenity = { version = "0.12", features = ["collector"] } +poise = "0.6" rand = "0.8.5" serde = { version = "1.0.137", features = ["derive"] } chrono = "0.4.19" diff --git a/youmubot-core/src/community/roles.rs b/youmubot-core/src/community/roles.rs index 8d340e6..fd2db20 100644 --- a/youmubot-core/src/community/roles.rs +++ b/youmubot-core/src/community/roles.rs @@ -1,6 +1,6 @@ use crate::db::Roles as DB; +use poise::CreateReply; use serenity::{ - builder::EditMessage, framework::standard::{macros::command, Args, CommandResult}, model::{ channel::{Message, ReactionType}, @@ -41,14 +41,14 @@ async fn list(ctx: &Context, m: &Message, _: Args) -> CommandResult { let pages = (roles.len() + ROLES_PER_PAGE - 1) / ROLES_PER_PAGE; paginate_reply_fn( - |page, ctx, msg| { + |page, _| { let roles = roles.clone(); Box::pin(async move { let page = page as usize; let start = page * ROLES_PER_PAGE; let end = roles.len().min(start + ROLES_PER_PAGE); if end <= start { - return Ok(false); + return Ok(None); } let roles = &roles[start..end]; let nw = roles // name width @@ -101,13 +101,11 @@ async fn list(ctx: &Context, m: &Message, _: Args) -> CommandResult { m.push_line("```"); m.push(format!("Page **{}/{}**", page + 1, pages)); - msg.edit(ctx, EditMessage::new().content(m.to_string())) - .await?; - Ok(true) + Ok(Some(CreateReply::default().content(m.to_string()))) }) }, ctx, - m, + m.clone(), std::time::Duration::from_secs(60 * 10), ) .await?; diff --git a/youmubot-core/src/fun/images.rs b/youmubot-core/src/fun/images.rs index 51ff4e7..59951d5 100644 --- a/youmubot-core/src/fun/images.rs +++ b/youmubot-core/src/fun/images.rs @@ -1,3 +1,4 @@ +use poise::CreateReply; use serde::Deserialize; use serenity::builder::EditMessage; use serenity::framework::standard::CommandError as Error; @@ -66,30 +67,24 @@ async fn message_command( } let images = std::sync::Arc::new(images); paginate_reply_fn( - move |page, ctx, msg: &mut Message| { + move |page, _| { let images = images.clone(); Box::pin(async move { let page = page as usize; if page >= images.len() { - Ok(false) + Ok(None) } else { - msg.edit( - ctx, - EditMessage::new().content(format!( - "[🖼️ **{}/{}**] Here's the image you requested!\n\n{}", - page + 1, - images.len(), - images[page] - )), - ) - .await - .map(|_| true) - .map_err(|e| e.into()) + Ok(Some(CreateReply::default().content(format!( + "[🖼️ **{}/{}**] Here's the image you requested!\n\n{}", + page + 1, + images.len(), + images[page] + )))) } }) }, ctx, - msg, + msg.clone(), std::time::Duration::from_secs(120), ) .await?; diff --git a/youmubot-osu/Cargo.toml b/youmubot-osu/Cargo.toml index f42ad13..9736316 100644 --- a/youmubot-osu/Cargo.toml +++ b/youmubot-osu/Cargo.toml @@ -23,6 +23,7 @@ serenity = "0.12" poise = "0.6" zip = "0.6.2" rand = "0.8" +thiserror = "1" youmubot-db = { path = "../youmubot-db" } youmubot-db-sql = { path = "../youmubot-db-sql" } diff --git a/youmubot-osu/src/discord/announcer.rs b/youmubot-osu/src/discord/announcer.rs index e88f2bd..c23150d 100644 --- a/youmubot-osu/src/discord/announcer.rs +++ b/youmubot-osu/src/discord/announcer.rs @@ -320,9 +320,13 @@ impl<'a> CollectedScore<'a> { }), ) .await?; - save_beatmap(&*ctx.data.read().await, channel, bm) - .await - .pls_ok(); + save_beatmap( + ctx.data.read().await.get::().unwrap(), + channel, + bm, + ) + .await + .pls_ok(); Ok(m) } } diff --git a/youmubot-osu/src/discord/app_commands.rs b/youmubot-osu/src/discord/app_commands.rs index 04402c6..9ee2f8e 100644 --- a/youmubot-osu/src/discord/app_commands.rs +++ b/youmubot-osu/src/discord/app_commands.rs @@ -1,9 +1,21 @@ +use serenity::all::Member; use youmubot_prelude::*; +use crate::{discord::args::ScoreDisplay, models::Mods}; + +#[poise::command(slash_command, subcommands("check"))] +pub async fn osu + Sync>(_ctx: CmdContext<'_, T>) -> Result<(), Error> { + Ok(()) +} + #[poise::command(slash_command)] -pub async fn example + Sync>( - context: poise::Context<'_, T, Error>, - arg: String, -) -> Result<(), Error> { +/// Check your/someone's score on the last beatmap in the channel +async fn check + Sync>( + ctx: CmdContext<'_, T>, + #[description = "Pass an osu! username to check for scores"] osu_id: Option, + #[description = "Pass a member of the guild to check for scores"] member: Option, + #[description = "Filter mods that should appear in the scores returned"] mods: Option, + #[description = "Score display style"] style: Option, +) -> Result<()> { todo!() } diff --git a/youmubot-osu/src/discord/args.rs b/youmubot-osu/src/discord/args.rs new file mode 100644 index 0000000..1eeeddd --- /dev/null +++ b/youmubot-osu/src/discord/args.rs @@ -0,0 +1,70 @@ +use serenity::all::Message; +use youmubot_prelude::*; + +// One of the interaction sources. +pub enum InteractionSrc<'a, 'c: 'a, T, E> { + Serenity(&'a Message), + Poise(&'a poise::Context<'c, T, E>), +} + +impl<'a, 'c, T, E> InteractionSrc<'a, 'c, T, E> { + pub async fn reply(&self, ctx: &Context, msg: impl Into) -> Result { + Ok(match self { + InteractionSrc::Serenity(m) => m.reply(ctx, msg).await?, + InteractionSrc::Poise(ctx) => ctx.reply(msg).await?.message().await?.into_owned(), + }) + } +} + +impl<'a, 'c, T, E> From<&'a Message> for InteractionSrc<'a, 'c, T, E> { + fn from(value: &'a Message) -> Self { + Self::Serenity(value) + } +} + +impl<'a, 'c, T, E> From<&'a poise::Context<'c, T, E>> for InteractionSrc<'a, 'c, T, E> { + fn from(value: &'a poise::Context<'c, T, E>) -> Self { + Self::Poise(value) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, poise::ChoiceParameter, Default)] +pub enum ScoreDisplay { + #[name = "table"] + #[default] + Table, + #[name = "grid"] + Grid, +} + +impl std::str::FromStr for ScoreDisplay { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "--table" => Ok(Self::Table), + "--grid" => Ok(Self::Grid), + _ => Err(Error::unknown(s)), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("unknown value: {0}")] + UnknownValue(String), + #[error("parse error: {0}")] + Custom(String), +} + +impl Error { + fn unknown(s: impl AsRef) -> Self { + Self::UnknownValue(s.as_ref().to_owned()) + } +} + +impl From for Error { + fn from(value: String) -> Self { + Error::Custom(value) + } +} diff --git a/youmubot-osu/src/discord/cache.rs b/youmubot-osu/src/discord/cache.rs index 2778eca..6e89f59 100644 --- a/youmubot-osu/src/discord/cache.rs +++ b/youmubot-osu/src/discord/cache.rs @@ -1,29 +1,24 @@ -use super::db::OsuLastBeatmap; use super::BeatmapWithMode; use serenity::model::id::ChannelId; use youmubot_prelude::*; /// Save the beatmap into the server data storage. pub(crate) async fn save_beatmap( - data: &TypeMap, + env: &crate::discord::Env, channel_id: ChannelId, bm: &BeatmapWithMode, ) -> Result<()> { - data.get::() - .unwrap() - .save(channel_id, &bm.0, bm.1) - .await?; + env.last_beatmaps.save(channel_id, &bm.0, bm.1).await?; Ok(()) } /// Get the last beatmap requested from this channel. pub(crate) async fn get_beatmap( - data: &TypeMap, + env: &crate::discord::Env, channel_id: ChannelId, ) -> Result> { - data.get::() - .unwrap() + env.last_beatmaps .by_channel(channel_id) .await .map(|v| v.map(|(bm, mode)| BeatmapWithMode(bm, mode))) diff --git a/youmubot-osu/src/discord/display.rs b/youmubot-osu/src/discord/display.rs index a8b29be..4ad8f75 100644 --- a/youmubot-osu/src/discord/display.rs +++ b/youmubot-osu/src/discord/display.rs @@ -1,47 +1,26 @@ pub use beatmapset::display_beatmapset; pub use scores::ScoreListStyle; -mod scores { +pub(in crate::discord) mod scores { use crate::models::{Mode, Score}; - use serenity::{framework::standard::CommandResult, model::channel::Message}; - use youmubot_prelude::*; + use serenity::{all::ChannelId, framework::standard::CommandResult}; + use youmubot_prelude::{replyable::Replyable, *}; - #[derive(Debug, Clone, Copy, PartialEq, Eq)] /// The style for the scores list to be displayed. - pub enum ScoreListStyle { - Table, - Grid, - } + pub type ScoreListStyle = crate::discord::args::ScoreDisplay; - impl Default for ScoreListStyle { - fn default() -> Self { - Self::Table - } - } - - impl std::str::FromStr for ScoreListStyle { - type Err = Error; - - fn from_str(s: &str) -> Result { - match s { - "--table" => Ok(Self::Table), - "--grid" => Ok(Self::Grid), - _ => Err(Error::msg("unknown value")), - } - } - } - - impl ScoreListStyle { - pub async fn display_scores<'a>( - self, - scores: Vec, - mode: Mode, - ctx: &'a Context, - m: &'a Message, - ) -> CommandResult { - match self { - ScoreListStyle::Table => table::display_scores_table(scores, mode, ctx, m).await, - ScoreListStyle::Grid => grid::display_scores_grid(scores, mode, ctx, m).await, + pub async fn display_scores<'a>( + style: ScoreListStyle, + scores: Vec, + mode: Mode, + ctx: &'a Context, + m: impl Replyable, + channel_id: ChannelId, + ) -> CommandResult { + match style { + ScoreListStyle::Table => table::display_scores_table(scores, mode, ctx, m).await, + ScoreListStyle::Grid => { + grid::display_scores_grid(scores, mode, ctx, m, channel_id).await } } } @@ -51,15 +30,18 @@ mod scores { cache::save_beatmap, BeatmapCache, BeatmapMetaCache, BeatmapWithMode, }; use crate::models::{Mode, Score}; - use serenity::builder::EditMessage; - use serenity::{framework::standard::CommandResult, model::channel::Message}; + use poise::CreateReply; + use serenity::all::ChannelId; + use serenity::framework::standard::CommandResult; + use youmubot_prelude::replyable::Replyable; use youmubot_prelude::*; pub async fn display_scores_grid<'a>( scores: Vec, mode: Mode, ctx: &'a Context, - m: &'a Message, + m: impl Replyable, + channel_id: ChannelId, ) -> CommandResult { if scores.is_empty() { m.reply(&ctx, "No plays found").await?; @@ -67,7 +49,11 @@ mod scores { } paginate_reply( - Paginate { scores, mode }, + Paginate { + channel_id, + scores, + mode, + }, ctx, m, std::time::Duration::from_secs(60), @@ -77,13 +63,14 @@ mod scores { } pub struct Paginate { + channel_id: ChannelId, scores: Vec, mode: Mode, } #[async_trait] impl pagination::Paginate for Paginate { - async fn render(&mut self, page: u8, ctx: &Context, msg: &mut Message) -> Result { + async fn render(&mut self, page: u8, ctx: &Context) -> Result> { let data = ctx.data.read().await; let client = data.get::().unwrap(); let osu = data.get::().unwrap(); @@ -91,7 +78,6 @@ mod scores { let page = page as usize; let score = &self.scores[page]; - let hourglass = msg.react(ctx, '⌛').await?; let mode = self.mode; let beatmap = osu.get_beatmap(score.beatmap_id, mode).await?; let content = beatmap_cache.get_beatmap(beatmap.beatmap_id).await?; @@ -101,20 +87,18 @@ mod scores { .await? .ok_or_else(|| Error::msg("user not found"))?; - msg.edit( - ctx, - EditMessage::new().embed({ - crate::discord::embeds::score_embed(score, &bm, &content, &user) - .footer(format!("Page {}/{}", page + 1, self.scores.len())) - .build() - }), + let edit = CreateReply::default().embed({ + crate::discord::embeds::score_embed(score, &bm, &content, &user) + .footer(format!("Page {}/{}", page + 1, self.scores.len())) + .build() + }); + save_beatmap( + ctx.data.read().await.get::().unwrap(), + self.channel_id, + &bm, ) .await?; - save_beatmap(&*ctx.data.read().await, msg.channel_id, &bm).await?; - - // End - hourglass.delete(ctx).await?; - Ok(true) + Ok(Some(edit)) } fn len(&self) -> Option { @@ -129,15 +113,16 @@ mod scores { use crate::discord::oppai_cache::Accuracy; use crate::discord::{Beatmap, BeatmapCache, BeatmapInfo, BeatmapMetaCache}; use crate::models::{Mode, Score}; - use serenity::builder::EditMessage; - use serenity::{framework::standard::CommandResult, model::channel::Message}; + use poise::CreateReply; + use serenity::framework::standard::CommandResult; + use youmubot_prelude::replyable::Replyable; use youmubot_prelude::*; pub async fn display_scores_table<'a>( scores: Vec, mode: Mode, ctx: &'a Context, - m: &'a Message, + m: impl Replyable, ) -> CommandResult { if scores.is_empty() { m.reply(&ctx, "No plays found").await?; @@ -169,7 +154,7 @@ mod scores { #[async_trait] impl pagination::Paginate for Paginate { - async fn render(&mut self, page: u8, ctx: &Context, msg: &mut Message) -> Result { + async fn render(&mut self, page: u8, ctx: &Context) -> Result> { let data = ctx.data.read().await; let osu = data.get::().unwrap(); let beatmap_cache = data.get::().unwrap(); @@ -177,10 +162,9 @@ mod scores { let start = page * ITEMS_PER_PAGE; let end = self.scores.len().min(start + ITEMS_PER_PAGE); if start >= end { - return Ok(false); + return Ok(None); } - let hourglass = msg.react(ctx, '⌛').await?; let plays = &self.scores[start..end]; let mode = self.mode; let beatmaps = plays @@ -330,10 +314,7 @@ mod scores { self.total_pages() )); m.push_line("[?] means pp was predicted by oppai-rs."); - msg.edit(ctx, EditMessage::new().content(m.to_string())) - .await?; - hourglass.delete(ctx).await?; - Ok(true) + Ok(Some(CreateReply::default().content(m.to_string()))) } fn len(&self) -> Option { @@ -350,13 +331,14 @@ mod beatmapset { }, models::{Beatmap, Mode, Mods}, }; + use poise::CreateReply; use serenity::{ - all::Reaction, - builder::{CreateEmbedFooter, EditMessage}, - model::channel::Message, + all::{ChannelId, Reaction}, + builder::CreateEmbedFooter, model::channel::ReactionType, }; use youmubot_prelude::*; + use youmubot_prelude::{pagination::PageUpdate, replyable::Replyable}; const SHOW_ALL_EMOTE: &str = "🗒️"; @@ -365,16 +347,18 @@ mod beatmapset { beatmapset: Vec, mode: Option, mods: Option, - reply_to: &Message, + reply_to: impl Replyable + Send + 'static, + channel_id: ChannelId, message: impl AsRef, - ) -> Result { + ) -> Result<()> { let mods = mods.unwrap_or(Mods::NOMOD); if beatmapset.is_empty() { - return Ok(false); + return Ok(()); } let p = Paginate { + channel_id, infos: vec![None; beatmapset.len()], maps: beatmapset, mode, @@ -383,16 +367,17 @@ mod beatmapset { }; let ctx = ctx.clone(); - let reply_to = reply_to.clone(); + // let reply_to = reply_to.clone(); spawn_future(async move { - pagination::paginate_reply(p, &ctx, &reply_to, std::time::Duration::from_secs(60)) + pagination::paginate_reply(p, &ctx, reply_to, std::time::Duration::from_secs(60)) .await .pls_ok(); }); - Ok(true) + Ok(()) } struct Paginate { + channel_id: ChannelId, maps: Vec, infos: Vec>, mode: Option, @@ -417,26 +402,15 @@ mod beatmapset { Some(self.maps.len()) } - async fn render( - &mut self, - page: u8, - ctx: &Context, - m: &mut serenity::model::channel::Message, - ) -> Result { + async fn render(&mut self, page: u8, ctx: &Context) -> Result> { let page = page as usize; if page == self.maps.len() { - m.edit( - ctx, - EditMessage::new().embed(crate::discord::embeds::beatmapset_embed( - &self.maps[..], - self.mode, - )), - ) - .await?; - return Ok(true); + return Ok(Some(CreateReply::default().embed( + crate::discord::embeds::beatmapset_embed(&self.maps[..], self.mode), + ))); } if page > self.maps.len() { - return Ok(false); + return Ok(None); } let map = &self.maps[page]; @@ -448,8 +422,7 @@ mod beatmapset { info } }; - m.edit(ctx, - EditMessage::new().content(self.message.as_str()).embed( + let edit = CreateReply::default().content(self.message.as_str()).embed( crate::discord::embeds::beatmap_embed( map, self.mode.unwrap_or(map.mode), @@ -464,47 +437,46 @@ mod beatmapset { SHOW_ALL_EMOTE, )) }) - ) - ) - .await?; + ); save_beatmap( - &*ctx.data.read().await, - m.channel_id, + ctx.data.read().await.get::().unwrap(), + self.channel_id, &BeatmapWithMode(map.clone(), self.mode.unwrap_or(map.mode)), ) .await .pls_ok(); - Ok(true) + Ok(Some(edit)) } - async fn prerender( - &mut self, - ctx: &Context, - m: &mut serenity::model::channel::Message, - ) -> Result<()> { - m.react(&ctx, SHOW_ALL_EMOTE.parse::().unwrap()) - .await?; - Ok(()) + async fn prerender(&mut self, ctx: &Context) -> Result { + // m.react(&ctx, SHOW_ALL_EMOTE.parse::().unwrap()) + // .await?; + Ok(PageUpdate { + react: vec![SHOW_ALL_EMOTE.parse::().unwrap()], + ..Default::default() + }) } async fn handle_reaction( &mut self, page: u8, ctx: &Context, - message: &mut serenity::model::channel::Message, reaction: &Reaction, - ) -> Result> { + ) -> Result { // Render the old style. if let ReactionType::Unicode(s) = &reaction.emoji { if s == SHOW_ALL_EMOTE { - self.render(self.maps.len() as u8, ctx, message).await?; - return Ok(Some(self.maps.len() as u8)); + let message = self.render(self.maps.len() as u8, ctx).await?; + let update = PageUpdate { + message, + page: Some(self.maps.len() as u8), + ..Default::default() + }; + return Ok(update); } } - pagination::handle_pagination_reaction(page, self, ctx, message, reaction) - .await - .map(Some) + pagination::handle_pagination_reaction(page, self, ctx, reaction).await } } } diff --git a/youmubot-osu/src/discord/hook.rs b/youmubot-osu/src/discord/hook.rs index da33442..39c8700 100644 --- a/youmubot-osu/src/discord/hook.rs +++ b/youmubot-osu/src/discord/hook.rs @@ -134,7 +134,7 @@ pub fn hook<'a>( let mode = l.mode.unwrap_or(b.mode); let bm = super::BeatmapWithMode(*b, mode); crate::discord::cache::save_beatmap( - &*ctx.data.read().await, + ctx.data.read().await.get::().unwrap(), msg.channel_id, &bm, ) @@ -413,7 +413,8 @@ async fn handle_beatmapset<'a, 'b>( beatmaps, mode, None, - reply_to, + reply_to.clone(), + reply_to.channel_id, format!("Beatmapset information for `{}`", link), ) .await diff --git a/youmubot-osu/src/discord/mod.rs b/youmubot-osu/src/discord/mod.rs index f540ba3..0e08864 100644 --- a/youmubot-osu/src/discord/mod.rs +++ b/youmubot-osu/src/discord/mod.rs @@ -8,6 +8,7 @@ use crate::{ }; use rand::seq::IteratorRandom; use serenity::{ + all::{ChannelId, Member}, builder::{CreateMessage, EditMessage}, collector, framework::standard::{ @@ -18,10 +19,11 @@ use serenity::{ utils::MessageBuilder, }; use std::{str::FromStr, sync::Arc}; -use youmubot_prelude::*; +use youmubot_prelude::{replyable::Replyable, *}; mod announcer; pub mod app_commands; +mod args; pub(crate) mod beatmap_cache; mod cache; mod db; @@ -386,17 +388,16 @@ impl FromStr for ModeArg { async fn to_user_id_query( s: Option, - data: &TypeMap, - msg: &Message, + env: &Env, + sender: &serenity::all::User, ) -> Result { let id = match s { Some(UsernameArg::Raw(s)) => return Ok(UserID::from_string(s)), Some(UsernameArg::Tagged(r)) => r, - None => msg.author.id, + None => sender.id, }; - data.get::() - .unwrap() + env.saved_users .by_user_id(id) .await? .map(|u| UserID::ID(u.id)) @@ -431,13 +432,14 @@ impl FromStr for Nth { #[max_args(4)] pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; + let env = data.get::().unwrap(); let nth = args.single::().unwrap_or(Nth::All); let style = args.single::().unwrap_or_default(); let mode = args.single::().unwrap_or(ModeArg(Mode::Std)).0; let user = to_user_id_query( args.quoted().trimmed().single::().ok(), - &data, - msg, + &env, + &msg.author, ) .await?; @@ -471,13 +473,14 @@ pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResu .await?; // Save the beatmap... - cache::save_beatmap(&data, msg.channel_id, &beatmap_mode).await?; + cache::save_beatmap(&env, msg.channel_id, &beatmap_mode).await?; } Nth::All => { let plays = osu .user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(50)) .await?; - style.display_scores(plays, mode, ctx, msg).await?; + display::scores::display_scores(style, plays, mode, ctx, msg.clone(), msg.channel_id) + .await?; } } Ok(()) @@ -499,12 +502,11 @@ impl FromStr for OptBeatmapset { /// Load the mentioned beatmap from the given message. pub(crate) async fn load_beatmap( - ctx: &Context, - msg: &Message, + env: &Env, + msg: Option<&Message>, + channel_id: ChannelId, ) -> Option<(BeatmapWithMode, Option)> { - let data = ctx.data.read().await; - - if let Some(replied) = &msg.referenced_message { + if let Some(replied) = msg.and_then(|m| m.referenced_message.as_ref()) { // Try to look for a mention of the replied message. let beatmap_id = SHORT_LINK_REGEX.captures(&replied.content).or_else(|| { replied.embeds.iter().find_map(|e| { @@ -526,8 +528,8 @@ pub(crate) async fn load_beatmap( let mods = caps .name("mods") .and_then(|m| m.as_str().parse::().ok()); - let osu = data.get::().unwrap(); - let bms = osu + let bms = env + .client .beatmaps(BeatmapRequestKind::Beatmap(id), |f| f.maybe_mode(mode)) .await .ok() @@ -536,19 +538,14 @@ pub(crate) async fn load_beatmap( let bm_mode = beatmap.mode; let bm = BeatmapWithMode(beatmap, mode.unwrap_or(bm_mode)); // Store the beatmap in history - cache::save_beatmap(&data, msg.channel_id, &bm) - .await - .pls_ok(); + cache::save_beatmap(env, channel_id, &bm).await.pls_ok(); return Some((bm, mods)); } } } - let b = cache::get_beatmap(&data, msg.channel_id) - .await - .ok() - .flatten(); + let b = cache::get_beatmap(env, channel_id).await.ok().flatten(); b.map(|b| (b, None)) } @@ -560,7 +557,8 @@ pub(crate) async fn load_beatmap( #[max_args(2)] pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; - let b = load_beatmap(ctx, msg).await; + let env = data.get::().unwrap(); + let b = load_beatmap(&env, Some(msg), msg.channel_id).await; let beatmapset = args.find::().is_ok(); match b { @@ -574,7 +572,8 @@ pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult beatmapset, None, Some(mods), - msg, + msg.clone(), + msg.channel_id, "Here is the beatmapset you requested!", ) .await?; @@ -612,7 +611,8 @@ pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult #[max_args(3)] pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; - let bm = load_beatmap(ctx, msg).await; + let env = data.get::().unwrap(); + let bm = load_beatmap(&env, Some(msg), msg.channel_id).await; match bm { None => { @@ -632,7 +632,7 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul None => Some(msg.author.id), _ => None, }; - let user = to_user_id_query(username_arg, &data, msg).await?; + let user = to_user_id_query(username_arg, &env, &msg.author).await?; let osu = data.get::().unwrap(); @@ -662,7 +662,63 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul .pls_ok(); } - style.display_scores(scores, m, ctx, msg).await?; + display::scores::display_scores(style, scores, m, ctx, msg.clone(), msg.channel_id) + .await?; + } + } + + Ok(()) +} + +pub async fn check_impl( + env: &Env, + ctx: &Context, + reply: impl Replyable, + channel_id: ChannelId, + sender: &serenity::all::User, + msg: Option<&Message>, + osu_id: Option, + member: Option, + mods: Option, + style: Option, +) -> CommandResult { + let bm = load_beatmap(&env, msg, channel_id).await; + + match bm { + None => { + reply + .reply(&ctx, "No beatmap queried on this channel.") + .await?; + } + Some((bm, mods_def)) => { + let mods = mods.unwrap_or_default(); + let b = &bm.0; + let m = bm.1; + let style = style.unwrap_or_default(); + let username_arg = member + .map(|m| UsernameArg::Tagged(m.user.id)) + .or(osu_id.map(|id| UsernameArg::Raw(id))); + let user = to_user_id_query(username_arg, env, sender).await?; + + let user = env + .client + .user(user, |f| f) + .await? + .ok_or_else(|| Error::msg("User not found"))?; + let mut scores = env + .client + .scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m)) + .await? + .into_iter() + .filter(|s| s.mods.contains(mods)) + .collect::>(); + scores.sort_by(|a, b| b.pp.unwrap().partial_cmp(&a.pp.unwrap()).unwrap()); + + if scores.is_empty() { + reply.reply(&ctx, "No scores found").await?; + return Ok(()); + } + display::scores::display_scores(style, scores, m, ctx, reply, channel_id).await?; } } @@ -677,6 +733,7 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul #[max_args(4)] pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; + let env = data.get::().unwrap(); let nth = args.single::().unwrap_or(Nth::All); let style = args.single::().unwrap_or_default(); let mode = args @@ -684,7 +741,7 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .map(|ModeArg(t)| t) .unwrap_or(Mode::Std); - let user = to_user_id_query(args.single::().ok(), &data, msg).await?; + let user = to_user_id_query(args.single::().ok(), &env, &msg.author).await?; let meta_cache = data.get::().unwrap(); let osu = data.get::().unwrap(); @@ -726,13 +783,14 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .await?; // Save the beatmap... - cache::save_beatmap(&data, msg.channel_id, &beatmap).await?; + cache::save_beatmap(data.get::().unwrap(), msg.channel_id, &beatmap).await?; } Nth::All => { let plays = osu .user_best(UserID::ID(user.id), |f| f.mode(mode).limit(100)) .await?; - style.display_scores(plays, mode, ctx, msg).await?; + display::scores::display_scores(style, plays, mode, ctx, msg.clone(), msg.channel_id) + .await?; } } Ok(()) @@ -757,7 +815,8 @@ pub async fn clean_cache(ctx: &Context, msg: &Message, args: Args) -> CommandRes async fn get_user(ctx: &Context, msg: &Message, mut args: Args, mode: Mode) -> CommandResult { let data = ctx.data.read().await; - let user = to_user_id_query(args.single::().ok(), &data, msg).await?; + let env = data.get::().unwrap(); + let user = to_user_id_query(args.single::().ok(), &env, &msg.author).await?; let osu = data.get::().unwrap(); let cache = data.get::().unwrap(); let user = osu.user(user, |f| f.mode(mode)).await?; diff --git a/youmubot-osu/src/discord/server_rank.rs b/youmubot-osu/src/discord/server_rank.rs index dd639db..20214b6 100644 --- a/youmubot-osu/src/discord/server_rank.rs +++ b/youmubot-osu/src/discord/server_rank.rs @@ -10,6 +10,7 @@ use crate::{ request::UserID, }; +use poise::CreateReply; use serenity::{ builder::EditMessage, framework::standard::{macros::command, Args, CommandResult}, @@ -88,14 +89,14 @@ pub async fn server_rank(ctx: &Context, m: &Message, mut args: Args) -> CommandR let users = std::sync::Arc::new(users); let last_update = last_update.unwrap(); paginate_reply_fn( - move |page: u8, ctx: &Context, m: &mut Message| { + move |page: u8, ctx: &Context| { const ITEMS_PER_PAGE: usize = 10; let users = users.clone(); Box::pin(async move { let start = (page as usize) * ITEMS_PER_PAGE; let end = (start + ITEMS_PER_PAGE).min(users.len()); if start >= end { - return Ok(false); + return Ok(None); } let total_len = users.len(); let users = &users[start..end]; @@ -142,13 +143,11 @@ pub async fn server_rank(ctx: &Context, m: &Message, mut args: Args) -> CommandR (total_len + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE, last_update.format(""), )); - m.edit(ctx, EditMessage::new().content(content.to_string())) - .await?; - Ok(true) + Ok(Some(CreateReply::default().content(content.to_string()))) }) }, ctx, - m, + m.clone(), std::time::Duration::from_secs(60), ) .await?; @@ -191,9 +190,10 @@ pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> Com let style = args.single::().unwrap_or_default(); let data = ctx.data.read().await; + let env = data.get::().unwrap(); let member_cache = data.get::().unwrap(); - let (bm, _) = match super::load_beatmap(ctx, m).await { + let (bm, _) = match super::load_beatmap(env, Some(m), m.channel_id).await { Some((bm, mods_def)) => { let mods = args.find::().ok().or(mods_def).unwrap_or(Mods::NOMOD); (bm, mods) @@ -295,25 +295,26 @@ pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> Com } if let ScoreListStyle::Grid = style { - style - .display_scores( - scores.into_iter().map(|(_, _, a)| a).collect(), - mode, - ctx, - m, - ) - .await?; + crate::discord::display::scores::display_scores( + style, + scores.into_iter().map(|(_, _, a)| a).collect(), + mode, + ctx, + m.clone(), + m.channel_id, + ) + .await?; return Ok(()); } let has_lazer_score = scores.iter().any(|(_, _, v)| v.score.is_none()); paginate_reply_fn( - move |page: u8, ctx: &Context, m: &mut Message| { + move |page: u8, ctx: &Context| { const ITEMS_PER_PAGE: usize = 5; let start = (page as usize) * ITEMS_PER_PAGE; let end = (start + ITEMS_PER_PAGE).min(scores.len()); if start >= end { - return Box::pin(future::ready(Ok(false))); + return Box::pin(future::ready(Ok(None))); } let total_len = scores.len(); let scores = scores[start..end].to_vec(); @@ -436,12 +437,11 @@ pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> Com content.push_line("PP was calculated by `oppai-rs`, **not** official values."); } - m.edit(&ctx, EditMessage::new().content(content.build())).await?; - Ok(true) + Ok(Some(CreateReply::default().content(content.build()))) }) }, ctx, - m, + m.clone(), std::time::Duration::from_secs(60), ) .await?; diff --git a/youmubot-osu/src/models/mods.rs b/youmubot-osu/src/models/mods.rs index 3d8858c..20faf07 100644 --- a/youmubot-osu/src/models/mods.rs +++ b/youmubot-osu/src/models/mods.rs @@ -97,8 +97,16 @@ impl Mods { } } +#[derive(Debug, thiserror::Error)] +pub enum ModParseError { + #[error("String of odd length is not a mod string")] + OddLength, + #[error("{0} is not a valid mod")] + InvalidMod(String), +} + impl std::str::FromStr for Mods { - type Err = String; + type Err = ModParseError; fn from_str(mut s: &str) -> Result { let mut res = Self::default(); // Strip leading + @@ -134,11 +142,11 @@ impl std::str::FromStr for Mods { "8K" => res |= Mods::KEY8, "9K" => res |= Mods::KEY9, "??" => res |= Mods::UNKNOWN, - v => return Err(format!("{} is not a valid mod", v)), + v => return Err(ModParseError::InvalidMod(v.to_owned())), } } if !s.is_empty() { - Err("String of odd length is not a mod string".to_owned()) + Err(ModParseError::OddLength) } else { Ok(res) } diff --git a/youmubot-prelude/src/announcer.rs b/youmubot-prelude/src/announcer.rs index 34100ee..944eee6 100644 --- a/youmubot-prelude/src/announcer.rs +++ b/youmubot-prelude/src/announcer.rs @@ -18,7 +18,7 @@ use serenity::{ prelude::*, utils::MessageBuilder, }; -use std::{arch::x86_64::_bittestandcomplement, collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use youmubot_db::DB; #[derive(Debug, Clone)] diff --git a/youmubot-prelude/src/lib.rs b/youmubot-prelude/src/lib.rs index c5a653b..e360dd4 100644 --- a/youmubot-prelude/src/lib.rs +++ b/youmubot-prelude/src/lib.rs @@ -1,4 +1,3 @@ -use announcer::AnnouncerChannels; /// Module `prelude` provides a sane set of default imports that can be used inside /// a Youmubot source file. pub use serenity::prelude::*; @@ -11,6 +10,7 @@ pub mod hook; pub mod member_cache; pub mod pagination; pub mod ratelimit; +pub mod replyable; pub mod setup; pub use announcer::{Announcer, AnnouncerHandler}; diff --git a/youmubot-prelude/src/pagination.rs b/youmubot-prelude/src/pagination.rs index fe5335a..2a04210 100644 --- a/youmubot-prelude/src/pagination.rs +++ b/youmubot-prelude/src/pagination.rs @@ -1,10 +1,14 @@ -use crate::{Context, OkPrint, Result}; +use crate::{ + replyable::{Replyable, Updateable}, + Context, OkPrint, Result, +}; use futures_util::{future::Future, StreamExt as _}; +use poise::CreateReply; use serenity::{ builder::CreateMessage, collector, model::{ - channel::{Message, Reaction, ReactionType}, + channel::{Reaction, ReactionType}, id::ChannelId, }, }; @@ -16,15 +20,41 @@ const ARROW_LEFT: &str = "⬅️"; const REWIND: &str = "⏪"; const FAST_FORWARD: &str = "⏩"; +/// Represents a page update. +#[derive(Default)] +pub struct PageUpdate { + pub message: Option, + pub page: Option, + pub react: Vec, +} + +impl From for PageUpdate { + fn from(value: u8) -> Self { + PageUpdate { + page: Some(value), + ..Default::default() + } + } +} + +impl From for PageUpdate { + fn from(value: CreateReply) -> Self { + PageUpdate { + message: Some(value), + ..Default::default() + } + } +} + /// A trait that provides the implementation of a paginator. #[async_trait::async_trait] pub trait Paginate: Send + Sized { /// Render the given page. - async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result; + async fn render(&mut self, page: u8, ctx: &Context) -> Result>; /// Any setting-up before the rendering stage. - async fn prerender(&mut self, _ctx: &Context, _m: &mut Message) -> Result<()> { - Ok(()) + async fn prerender(&mut self, _ctx: &Context) -> Result { + Ok(PageUpdate::default()) } /// Handle the incoming reaction. Defaults to calling `handle_pagination_reaction`, but you can do some additional handling @@ -35,12 +65,9 @@ pub trait Paginate: Send + Sized { &mut self, page: u8, ctx: &Context, - message: &mut Message, reaction: &Reaction, - ) -> Result> { - handle_pagination_reaction(page, self, ctx, message, reaction) - .await - .map(Some) + ) -> Result { + handle_pagination_reaction(page, self, ctx, reaction).await } /// Return the number of pages, if it is known in advance. @@ -60,12 +87,12 @@ where T: for<'m> FnMut( u8, &'m Context, - &'m mut Message, - ) -> std::pin::Pin> + Send + 'm>> - + Send, + ) -> std::pin::Pin< + Box>> + Send + 'm>, + > + Send, { - async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result { - self(page, ctx, m).await + async fn render(&mut self, page: u8, ctx: &Context) -> Result> { + self(page, ctx).await } } @@ -74,13 +101,13 @@ where pub async fn paginate_reply( pager: impl Paginate, ctx: &Context, - reply_to: &Message, + reply_to: impl Replyable, timeout: std::time::Duration, ) -> Result<()> { - let message = reply_to + let update = reply_to .reply(&ctx, "Youmu is loading the first page...") .await?; - paginate_with_first_message(pager, ctx, message, timeout).await + paginate_with_first_message(pager, ctx, update, timeout).await } // Paginate! with a pager function. @@ -103,11 +130,17 @@ pub async fn paginate( async fn paginate_with_first_message( mut pager: impl Paginate, ctx: &Context, - mut message: Message, + mut update: impl Updateable, timeout: std::time::Duration, ) -> Result<()> { - pager.prerender(ctx, &mut message).await?; - pager.render(0, ctx, &mut message).await?; + let message = update.message().await?; + let prerender = pager.prerender(ctx).await?; + if let Some(cr) = prerender.message { + update.edit(ctx, cr).await?; + } + if let Some(cr) = pager.render(0, ctx).await? { + update.edit(ctx, cr).await?; + } // Just quit if there is only one page if pager.len().filter(|&v| v == 1).is_some() { return Ok(()); @@ -115,7 +148,7 @@ async fn paginate_with_first_message( // React to the message let large_count = pager.len().filter(|&p| p > 10).is_some(); let reactions = { - let mut rs = Vec::::with_capacity(4); + let mut rs = Vec::::with_capacity(4 + prerender.react.len()); if large_count { // add >> and << buttons rs.push(message.react(&ctx, ReactionType::try_from(REWIND)?).await?); @@ -138,6 +171,9 @@ async fn paginate_with_first_message( .await?, ); } + for r in prerender.react.into_iter() { + rs.push(message.react(&ctx, r).await?); + } rs }; // Build a reaction collector @@ -161,12 +197,16 @@ async fn paginate_with_first_message( Err(_) => break Ok(()), Ok(None) => break Ok(()), Ok(Some(reaction)) => { - page = match pager - .handle_reaction(page, ctx, &mut message, &reaction) - .await - { - Ok(Some(v)) => v, - Ok(None) => break Ok(()), + page = match pager.handle_reaction(page, ctx, &reaction).await { + Ok(pu) => { + if let Some(cr) = pu.message { + update.edit(ctx, cr).await?; + } + match pu.page { + Some(v) => v, + None => break Ok(()), + } + } Err(e) => break Err(e), }; } @@ -188,9 +228,9 @@ pub async fn paginate_fn( pager: impl for<'m> FnMut( u8, &'m Context, - &'m mut Message, - ) -> std::pin::Pin> + Send + 'm>> - + Send, + ) -> std::pin::Pin< + Box>> + Send + 'm>, + > + Send, ctx: &Context, channel: ChannelId, timeout: std::time::Duration, @@ -203,11 +243,11 @@ pub async fn paginate_reply_fn( pager: impl for<'m> FnMut( u8, &'m Context, - &'m mut Message, - ) -> std::pin::Pin> + Send + 'm>> - + Send, + ) -> std::pin::Pin< + Box>> + Send + 'm>, + > + Send, ctx: &Context, - reply_to: &Message, + reply_to: impl Replyable, timeout: std::time::Duration, ) -> Result<()> { paginate_reply(pager, ctx, reply_to, timeout).await @@ -218,15 +258,14 @@ pub async fn handle_pagination_reaction( page: u8, pager: &mut impl Paginate, ctx: &Context, - message: &mut Message, reaction: &Reaction, -) -> Result { +) -> Result { let pages = pager.len(); let fast = pages.map(|v| v / 10).unwrap_or(5).max(5) as u8; match &reaction.emoji { ReactionType::Unicode(ref s) => { let new_page = match s.as_str() { - ARROW_LEFT | REWIND if page == 0 => return Ok(page), + ARROW_LEFT | REWIND if page == 0 => return Ok(page.into()), ARROW_LEFT => page - 1, REWIND => { if page < fast { @@ -236,18 +275,26 @@ pub async fn handle_pagination_reaction( } } ARROW_RIGHT if pages.filter(|&pages| page as usize + 1 >= pages).is_some() => { - return Ok(page) + return Ok(page.into()) } ARROW_RIGHT => page + 1, FAST_FORWARD => (pages.unwrap() as u8 - 1).min(page + fast), - _ => return Ok(page), + _ => return Ok(page.into()), }; - Ok(if pager.render(new_page, ctx, message).await? { - new_page - } else { - page - }) + let reply = pager.render(new_page, ctx).await?; + Ok(reply + .map(|cr| PageUpdate { + message: Some(cr), + page: Some(page), + ..Default::default() + }) + .unwrap_or_else(|| page.into())) + // Ok(if pager.render(new_page, ctx, message).await? { + // new_page + // } else { + // page + // }) } - _ => Ok(page), + _ => Ok(page.into()), } } diff --git a/youmubot-prelude/src/replyable.rs b/youmubot-prelude/src/replyable.rs new file mode 100644 index 0000000..ae32319 --- /dev/null +++ b/youmubot-prelude/src/replyable.rs @@ -0,0 +1,76 @@ +use poise::{CreateReply, ReplyHandle}; +use serenity::{all::Message, builder::EditMessage}; + +use crate::*; + +/// Represents a target where replying is possible and returns a message. +#[async_trait] +pub trait Replyable { + type Resp: Updateable + Send; + /// Reply to the context. + async fn reply( + &self, + ctx: impl CacheHttp + Send, + content: impl Into + Send, + ) -> Result; +} + +#[async_trait] +impl Replyable for Message { + type Resp = Message; + async fn reply( + &self, + ctx: impl CacheHttp + Send, + content: impl Into + Send, + ) -> Result { + Ok(Message::reply(self, ctx, content).await?) + } +} + +#[async_trait] +impl<'c, T: Sync, E> Replyable for poise::Context<'c, T, E> { + type Resp = (ReplyHandle<'c>, Self); + async fn reply( + &self, + _ctx: impl CacheHttp + Send, + content: impl Into + Send, + ) -> Result { + let handle = poise::Context::reply(*self, content).await?; + Ok((handle, *self)) + } +} + +/// Represents a message representation that allows deletion and editing. +#[async_trait] +pub trait Updateable { + async fn message(&self) -> Result; + async fn edit(&mut self, ctx: impl CacheHttp + Send, content: CreateReply) -> Result<()>; + async fn delete(&self, ctx: impl CacheHttp + Send) -> Result<()>; +} + +#[async_trait] +impl Updateable for Message { + async fn message(&self) -> Result { + Ok(self.clone()) + } + async fn edit(&mut self, ctx: impl CacheHttp + Send, content: CreateReply) -> Result<()> { + let content = content.to_prefix_edit(EditMessage::new()); + Ok(Message::edit(self, ctx, content).await?) + } + async fn delete(&self, ctx: impl CacheHttp + Send) -> Result<()> { + Ok(Message::delete(self, ctx).await?) + } +} + +#[async_trait] +impl<'a, T: Sync, E> Updateable for (poise::ReplyHandle<'a>, poise::Context<'a, T, E>) { + async fn message(&self) -> Result { + Ok(poise::ReplyHandle::message(&self.0).await?.into_owned()) + } + async fn edit(&mut self, _ctx: impl CacheHttp, content: CreateReply) -> Result<()> { + Ok(poise::ReplyHandle::edit(&self.0, self.1, content).await?) + } + async fn delete(&self, _ctx: impl CacheHttp) -> Result<()> { + Ok(poise::ReplyHandle::delete(&self.0, self.1).await?) + } +}