osu: Implement check button!

This commit is contained in:
Natsu Kagami 2024-07-14 00:16:12 +02:00
parent 1d250b8ea7
commit 32053c3fe3
Signed by: nki
GPG key ID: 55A032EB38B49ADB
11 changed files with 274 additions and 75 deletions

View file

@ -23,6 +23,7 @@ use crate::{
}; };
use super::db::{OsuSavedUsers, OsuUser}; use super::db::{OsuSavedUsers, OsuUser};
use super::interaction::score_components;
use super::{calculate_weighted_map_length, OsuEnv}; use super::{calculate_weighted_map_length, OsuEnv};
use super::{embeds::score_embed, BeatmapWithMode}; use super::{embeds::score_embed, BeatmapWithMode};
@ -133,7 +134,7 @@ impl Announcer {
let scores = self.scan_user(osu_user, mode).await?; let scores = self.scan_user(osu_user, mode).await?;
let user = self let user = self
.client .client
.user(UserID::ID(osu_user.id), |f| { .user(&UserID::ID(osu_user.id), |f| {
f.mode(mode) f.mode(mode)
.event_days(days_since_last_update.min(31) as u8) .event_days(days_since_last_update.min(31) as u8)
}) })
@ -339,7 +340,8 @@ impl<'a> CollectedScore<'a> {
ScoreType::WorldRecord(rank) => b.world_record(rank), ScoreType::WorldRecord(rank) => b.world_record(rank),
} }
.build() .build()
}), })
.components(vec![score_components()]),
) )
.await?; .await?;

View file

@ -2,7 +2,7 @@ pub use beatmapset::display_beatmapset;
pub use scores::ScoreListStyle; pub use scores::ScoreListStyle;
mod scores { mod scores {
use serenity::{framework::standard::CommandResult, model::channel::Message}; use serenity::model::channel::Message;
use youmubot_prelude::*; use youmubot_prelude::*;
@ -39,8 +39,8 @@ mod scores {
scores: Vec<Score>, scores: Vec<Score>,
mode: Mode, mode: Mode,
ctx: &'a Context, ctx: &'a Context,
m: &'a Message, m: Message,
) -> CommandResult { ) -> Result<()> {
match self { match self {
ScoreListStyle::Table => table::display_scores_table(scores, mode, ctx, m).await, ScoreListStyle::Table => table::display_scores_table(scores, mode, ctx, m).await,
ScoreListStyle::Grid => grid::display_scores_grid(scores, mode, ctx, m).await, ScoreListStyle::Grid => grid::display_scores_grid(scores, mode, ctx, m).await,
@ -48,12 +48,14 @@ mod scores {
} }
} }
pub mod grid { mod grid {
use pagination::paginate_with_first_message;
use serenity::builder::EditMessage; use serenity::builder::EditMessage;
use serenity::{framework::standard::CommandResult, model::channel::Message}; use serenity::model::channel::Message;
use youmubot_prelude::*; use youmubot_prelude::*;
use crate::discord::interaction::score_components;
use crate::discord::{cache::save_beatmap, BeatmapWithMode, OsuEnv}; use crate::discord::{cache::save_beatmap, BeatmapWithMode, OsuEnv};
use crate::models::{Mode, Score}; use crate::models::{Mode, Score};
@ -61,17 +63,18 @@ mod scores {
scores: Vec<Score>, scores: Vec<Score>,
mode: Mode, mode: Mode,
ctx: &'a Context, ctx: &'a Context,
m: &'a Message, mut on: Message,
) -> CommandResult { ) -> Result<()> {
if scores.is_empty() { if scores.is_empty() {
m.reply(&ctx, "No plays found").await?; on.edit(&ctx, EditMessage::new().content("No plays found"))
.await?;
return Ok(()); return Ok(());
} }
paginate_reply( paginate_with_first_message(
Paginate { scores, mode }, Paginate { scores, mode },
ctx, ctx,
m, on,
std::time::Duration::from_secs(60), std::time::Duration::from_secs(60),
) )
.await?; .await?;
@ -97,17 +100,19 @@ mod scores {
let bm = BeatmapWithMode(beatmap, mode); let bm = BeatmapWithMode(beatmap, mode);
let user = env let user = env
.client .client
.user(crate::request::UserID::ID(score.user_id), |f| f) .user(&crate::request::UserID::ID(score.user_id), |f| f)
.await? .await?
.ok_or_else(|| Error::msg("user not found"))?; .ok_or_else(|| Error::msg("user not found"))?;
msg.edit( msg.edit(
ctx, ctx,
EditMessage::new().embed({ EditMessage::new()
.embed({
crate::discord::embeds::score_embed(score, &bm, &content, &user) crate::discord::embeds::score_embed(score, &bm, &content, &user)
.footer(format!("Page {}/{}", page + 1, self.scores.len())) .footer(format!("Page {}/{}", page + 1, self.scores.len()))
.build() .build()
}), })
.components(vec![score_components()]),
) )
.await?; .await?;
save_beatmap(&env, msg.channel_id, &bm).await?; save_beatmap(&env, msg.channel_id, &bm).await?;
@ -126,8 +131,9 @@ mod scores {
pub mod table { pub mod table {
use std::borrow::Cow; use std::borrow::Cow;
use pagination::paginate_with_first_message;
use serenity::builder::EditMessage; use serenity::builder::EditMessage;
use serenity::{framework::standard::CommandResult, model::channel::Message}; use serenity::model::channel::Message;
use youmubot_prelude::table_format::Align::{Left, Right}; use youmubot_prelude::table_format::Align::{Left, Right};
use youmubot_prelude::table_format::{table_formatting, Align}; use youmubot_prelude::table_format::{table_formatting, Align};
@ -141,17 +147,18 @@ mod scores {
scores: Vec<Score>, scores: Vec<Score>,
mode: Mode, mode: Mode,
ctx: &'a Context, ctx: &'a Context,
m: &'a Message, mut on: Message,
) -> CommandResult { ) -> Result<()> {
if scores.is_empty() { if scores.is_empty() {
m.reply(&ctx, "No plays found").await?; on.edit(&ctx, EditMessage::new().content("No plays found"))
.await?;
return Ok(()); return Ok(());
} }
paginate_reply( paginate_with_first_message(
Paginate { scores, mode }, Paginate { scores, mode },
ctx, ctx,
m, on,
std::time::Duration::from_secs(60), std::time::Duration::from_secs(60),
) )
.await?; .await?;
@ -332,7 +339,7 @@ mod beatmapset {
use youmubot_prelude::*; use youmubot_prelude::*;
use crate::discord::OsuEnv; use crate::discord::{interaction::beatmap_components, OsuEnv};
use crate::{ use crate::{
discord::{cache::save_beatmap, oppai_cache::BeatmapInfoWithPP, BeatmapWithMode}, discord::{cache::save_beatmap, oppai_cache::BeatmapInfoWithPP, BeatmapWithMode},
models::{Beatmap, Mode, Mods}, models::{Beatmap, Mode, Mods},
@ -439,7 +446,8 @@ mod beatmapset {
SHOW_ALL_EMOTE, SHOW_ALL_EMOTE,
)) ))
}) })
), )
.components(vec![beatmap_components()]),
) )
.await?; .await?;
let env = ctx.data.read().await.get::<OsuEnv>().unwrap().clone(); let env = ctx.data.read().await.get::<OsuEnv>().unwrap().clone();

View file

@ -17,6 +17,7 @@ use crate::{
}; };
use super::embeds::beatmap_embed; use super::embeds::beatmap_embed;
use super::interaction::{beatmap_components, score_components};
use super::link_parser::*; use super::link_parser::*;
/// React to /scores/{id} links. /// React to /scores/{id} links.
@ -83,7 +84,8 @@ pub fn score_hook<'a>(
len len
) )
}) })
.embed(score_embed(&s, &b, &c, h).build()), .embed(score_embed(&s, &b, &c, h).build())
.components(vec![score_components()]),
) )
.await .await
.pls_ok(); .pls_ok();
@ -301,6 +303,7 @@ async fn handle_beatmap<'a, 'b>(
mods, mods,
info, info,
)) ))
.components(vec![beatmap_components()])
.reference_message(reply_to), .reference_message(reply_to),
) )
.await?; .await?;

View file

@ -0,0 +1,78 @@
use std::pin::Pin;
use future::Future;
use serenity::all::{
ComponentInteractionDataKind, CreateActionRow, CreateButton, CreateInteractionResponseMessage,
Interaction,
};
use youmubot_prelude::*;
use crate::Mods;
use super::{display::ScoreListStyle, OsuEnv};
pub(super) const BTN_CHECK: &'static str = "youmubot_osu_btn_check";
// pub(super) const BTN_LAST: &'static str = "youmubot_osu_btn_last";
/// Create an action row for score pages.
pub fn score_components() -> CreateActionRow {
CreateActionRow::Buttons(vec![check_button()])
}
/// Create an action row for score pages.
pub fn beatmap_components() -> CreateActionRow {
CreateActionRow::Buttons(vec![check_button()])
}
/// Creates a new check button.
pub fn check_button() -> CreateButton {
CreateButton::new(BTN_CHECK)
.label("Check your score")
.emoji('🔎')
.style(serenity::all::ButtonStyle::Secondary)
}
/// Implements the `check` button on scores and beatmaps.
pub fn handle_check_button<'a>(
ctx: &'a Context,
interaction: &'a Interaction,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
let comp = match interaction.as_message_component() {
Some(comp)
if comp.data.custom_id == BTN_CHECK
&& matches!(comp.data.kind, ComponentInteractionDataKind::Button) =>
{
comp
}
_ => return Ok(()),
};
let (msg, author) = (&*comp.message, comp.user.id);
let env = ctx.data.read().await.get::<OsuEnv>().unwrap().clone();
let (bm, _) = super::load_beatmap(&env, msg).await.unwrap();
let user_id = super::to_user_id_query(None, &env, author).await?;
let scores = super::do_check(&env, &bm, Mods::NOMOD, &user_id).await?;
let reply = {
comp.create_response(
&ctx,
serenity::all::CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new().content(format!(
"Here are the scores by `{}` on `{}`!",
&user_id,
bm.short_link(Mods::NOMOD)
)),
),
)
.await?;
comp.get_response(&ctx).await?
};
ScoreListStyle::Grid
.display_scores(scores, bm.1, ctx, reply)
.await?;
Ok(())
})
}

View file

@ -1,6 +1,7 @@
use std::{str::FromStr, sync::Arc}; use std::{str::FromStr, sync::Arc};
use futures_util::join; use futures_util::join;
use interaction::{beatmap_components, score_components};
use rand::seq::IteratorRandom; use rand::seq::IteratorRandom;
use serenity::{ use serenity::{
builder::{CreateMessage, EditMessage}, builder::{CreateMessage, EditMessage},
@ -36,6 +37,7 @@ mod db;
pub(crate) mod display; pub(crate) mod display;
pub(crate) mod embeds; pub(crate) mod embeds;
mod hook; mod hook;
pub mod interaction;
mod link_parser; mod link_parser;
pub(crate) mod oppai_cache; pub(crate) mod oppai_cache;
mod server_rank; mod server_rank;
@ -201,6 +203,10 @@ pub async fn mania(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
pub(crate) struct BeatmapWithMode(pub Beatmap, pub Mode); pub(crate) struct BeatmapWithMode(pub Beatmap, pub Mode);
impl BeatmapWithMode { impl BeatmapWithMode {
pub fn short_link(&self, mods: Mods) -> String {
self.0.short_link(Some(self.1), Some(mods))
}
fn mode(&self) -> Mode { fn mode(&self) -> Mode {
self.1 self.1
} }
@ -221,7 +227,7 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
let osu_client = &env.client; let osu_client = &env.client;
let user = args.single::<String>()?; let user = args.single::<String>()?;
let u = match osu_client.user(UserID::from_string(user), |f| f).await? { let u = match osu_client.user(&UserID::from_string(user), |f| f).await? {
Some(u) => u, Some(u) => u,
None => { None => {
msg.reply(&ctx, "user not found...").await?; msg.reply(&ctx, "user not found...").await?;
@ -288,7 +294,9 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
reply reply
.edit( .edit(
&ctx, &ctx,
EditMessage::new().embed(beatmap_embed(&beatmap, mode, Mods::NOMOD, info)), EditMessage::new()
.embed(beatmap_embed(&beatmap, mode, Mods::NOMOD, info))
.components(vec![beatmap_components()]),
) )
.await?; .await?;
let reaction = reply.react(&ctx, '👌').await?; let reaction = reply.react(&ctx, '👌').await?;
@ -343,7 +351,7 @@ pub async fn forcesave(ctx: &Context, msg: &Message, mut args: Args) -> CommandR
let username = args.quoted().trimmed().single::<String>()?; let username = args.quoted().trimmed().single::<String>()?;
let user: Option<User> = osu_client let user: Option<User> = osu_client
.user(UserID::from_string(username.clone()), |f| f) .user(&UserID::from_string(username.clone()), |f| f)
.await?; .await?;
match user { match user {
Some(u) => { Some(u) => {
@ -370,7 +378,7 @@ async fn add_user(target: serenity::model::id::UserId, user: User, env: &OsuEnv)
.into_iter() .into_iter()
.map(|mode| async move { .map(|mode| async move {
env.client env.client
.user(UserID::ID(user.id), |f| f.mode(mode)) .user(&UserID::ID(user.id), |f| f.mode(mode))
.await .await
.unwrap_or_else(|err| { .unwrap_or_else(|err| {
eprintln!("{}", err); eprintln!("{}", err);
@ -431,12 +439,12 @@ impl FromStr for ModeArg {
async fn to_user_id_query( async fn to_user_id_query(
s: Option<UsernameArg>, s: Option<UsernameArg>,
env: &OsuEnv, env: &OsuEnv,
msg: &Message, author: serenity::all::UserId,
) -> Result<UserID, Error> { ) -> Result<UserID, Error> {
let id = match s { let id = match s {
Some(UsernameArg::Raw(s)) => return Ok(UserID::from_string(s)), Some(UsernameArg::Raw(s)) => return Ok(UserID::from_string(s)),
Some(UsernameArg::Tagged(r)) => r, Some(UsernameArg::Tagged(r)) => r,
None => msg.author.id, None => author,
}; };
env.saved_users env.saved_users
@ -481,14 +489,14 @@ pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResu
let user = to_user_id_query( let user = to_user_id_query(
args.quoted().trimmed().single::<UsernameArg>().ok(), args.quoted().trimmed().single::<UsernameArg>().ok(),
&env, &env,
msg, msg.author.id,
) )
.await?; .await?;
let osu_client = &env.client; let osu_client = &env.client;
let user = osu_client let user = osu_client
.user(user, |f| f.mode(mode)) .user(&user, |f| f.mode(mode))
.await? .await?
.ok_or_else(|| Error::msg("User not found"))?; .ok_or_else(|| Error::msg("User not found"))?;
match nth { match nth {
@ -512,6 +520,7 @@ pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResu
CreateMessage::new() CreateMessage::new()
.content("Here is the play that you requested".to_string()) .content("Here is the play that you requested".to_string())
.embed(score_embed(&recent_play, &beatmap_mode, &content, &user).build()) .embed(score_embed(&recent_play, &beatmap_mode, &content, &user).build())
.components(vec![score_components()])
.reference_message(msg), .reference_message(msg),
) )
.await?; .await?;
@ -523,7 +532,13 @@ pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResu
let plays = osu_client let plays = osu_client
.user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(50)) .user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(50))
.await?; .await?;
style.display_scores(plays, mode, ctx, msg).await?; let reply = msg
.reply(
ctx,
format!("Here are the recent plays by `{}`!", user.username),
)
.await?;
style.display_scores(plays, mode, ctx, reply).await?;
} }
} }
Ok(()) Ok(())
@ -626,6 +641,7 @@ pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
CreateMessage::new() CreateMessage::new()
.content("Here is the beatmap you requested!") .content("Here is the beatmap you requested!")
.embed(beatmap_embed(&b, m, mods, info)) .embed(beatmap_embed(&b, m, mods, info))
.components(vec![beatmap_components()])
.reference_message(msg), .reference_message(msg),
) )
.await?; .await?;
@ -656,29 +672,51 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul
return Ok(()); return Ok(());
} }
}; };
let mode = bm.1;
let mods = args.find::<Mods>().ok().unwrap_or_default(); let mods = args.find::<Mods>().ok().unwrap_or_default();
let b = &bm.0;
let m = bm.1;
let style = args let style = args
.single::<ScoreListStyle>() .single::<ScoreListStyle>()
.unwrap_or(ScoreListStyle::Grid); .unwrap_or(ScoreListStyle::Grid);
let username_arg = args.single::<UsernameArg>().ok(); let username_arg = args.single::<UsernameArg>().ok();
let user_id = match username_arg.as_ref() { let user = to_user_id_query(username_arg, &env, msg.author.id).await?;
Some(UsernameArg::Tagged(v)) => Some(*v),
None => Some(msg.author.id),
_ => None,
};
let user = to_user_id_query(username_arg, &env, msg).await?;
let osu_client = env.client; let scores = do_check(&env, &bm, mods, &user).await?;
if scores.is_empty() {
msg.reply(&ctx, "No scores found").await?;
return Ok(());
}
let reply = msg
.reply(
&ctx,
format!(
"Here are the scores by `{}` on `{}`!",
&user,
bm.short_link(mods)
),
)
.await?;
style.display_scores(scores, mode, ctx, reply).await?;
Ok(())
}
pub(crate) async fn do_check(
env: &OsuEnv,
bm: &BeatmapWithMode,
mods: Mods,
user: &UserID,
) -> Result<Vec<Score>> {
let BeatmapWithMode(b, m) = bm;
let osu_client = &env.client;
let user = osu_client let user = osu_client
.user(user, |f| f) .user(user, |f| f)
.await? .await?
.ok_or_else(|| Error::msg("User not found"))?; .ok_or_else(|| Error::msg("User not found"))?;
let mut scores = osu_client let mut scores = osu_client
.scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m)) .scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(*m))
.await? .await?
.into_iter() .into_iter()
.filter(|s| s.mods.contains(mods)) .filter(|s| s.mods.contains(mods))
@ -688,23 +726,7 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul
.partial_cmp(&a.pp.unwrap_or(-1.0)) .partial_cmp(&a.pp.unwrap_or(-1.0))
.unwrap() .unwrap()
}); });
Ok(scores)
if scores.is_empty() {
msg.reply(&ctx, "No scores found").await?;
return Ok(());
}
if let Some(user_id) = user_id {
// Save to database
env.user_bests
.save(user_id, m, scores.clone())
.await
.pls_ok();
}
style.display_scores(scores, m, ctx, msg).await?;
Ok(())
} }
#[command] #[command]
@ -722,10 +744,10 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
.map(|ModeArg(t)| t) .map(|ModeArg(t)| t)
.unwrap_or(Mode::Std); .unwrap_or(Mode::Std);
let user = to_user_id_query(args.single::<UsernameArg>().ok(), &env, msg).await?; let user_id = to_user_id_query(args.single::<UsernameArg>().ok(), &env, msg.author.id).await?;
let osu_client = &env.client; let osu_client = &env.client;
let user = osu_client let user = osu_client
.user(user, |f| f.mode(mode)) .user(&user_id, |f| f.mode(mode))
.await? .await?
.ok_or_else(|| Error::msg("User not found"))?; .ok_or_else(|| Error::msg("User not found"))?;
@ -757,6 +779,7 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
.top_record(rank) .top_record(rank)
.build(), .build(),
) )
.components(vec![score_components()])
}) })
.await?; .await?;
@ -767,7 +790,10 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
let plays = osu_client let plays = osu_client
.user_best(UserID::ID(user.id), |f| f.mode(mode).limit(100)) .user_best(UserID::ID(user.id), |f| f.mode(mode).limit(100))
.await?; .await?;
style.display_scores(plays, mode, ctx, msg).await?; let reply = msg
.reply(&ctx, format!("Here are the top plays by `{}`!", user_id))
.await?;
style.display_scores(plays, mode, ctx, reply).await?;
} }
} }
Ok(()) Ok(())
@ -796,10 +822,10 @@ async fn get_user(
mut args: Args, mut args: Args,
mode: Mode, mode: Mode,
) -> CommandResult { ) -> CommandResult {
let user = to_user_id_query(args.single::<UsernameArg>().ok(), &env, msg).await?; let user = to_user_id_query(args.single::<UsernameArg>().ok(), &env, msg.author.id).await?;
let osu_client = &env.client; let osu_client = &env.client;
let meta_cache = &env.beatmaps; let meta_cache = &env.beatmaps;
let user = osu_client.user(user, |f| f.mode(mode)).await?; let user = osu_client.user(&user, |f| f.mode(mode)).await?;
match user { match user {
Some(u) => { Some(u) => {

View file

@ -327,12 +327,21 @@ pub async fn show_leaderboard(ctx: &Context, msg: &Message, mut args: Args) -> C
} }
if let ScoreListStyle::Grid = style { if let ScoreListStyle::Grid = style {
let reply = msg
.reply(
&ctx,
format!(
"Here are the top scores on beatmap `{}` of this server!",
bm.short_link(Mods::NOMOD)
),
)
.await?;
style style
.display_scores( .display_scores(
scores.into_iter().map(|(_, _, a)| a).collect(), scores.into_iter().map(|(_, _, a)| a).collect(),
mode, mode,
ctx, ctx,
msg, reply,
) )
.await?; .await?;
return Ok(()); return Ok(());

View file

@ -56,7 +56,7 @@ impl Client {
pub async fn user( pub async fn user(
&self, &self,
user: UserID, user: &UserID,
f: impl FnOnce(&mut UserRequestBuilder) -> &mut UserRequestBuilder, f: impl FnOnce(&mut UserRequestBuilder) -> &mut UserRequestBuilder,
) -> Result<Option<User>, Error> { ) -> Result<Option<User>, Error> {
let mut r = UserRequestBuilder::new(user.clone()); let mut r = UserRequestBuilder::new(user.clone());
@ -66,7 +66,7 @@ impl Client {
self.user_header_cache self.user_header_cache
.lock() .lock()
.await .await
.insert(id, u.clone().map(|v| v.into())); .insert(*id, u.clone().map(|v| v.into()));
} }
Ok(u) Ok(u)
} }
@ -77,7 +77,7 @@ impl Client {
let v = self.user_header_cache.lock().await.get(&id).cloned(); let v = self.user_header_cache.lock().await.get(&id).cloned();
match v { match v {
Some(v) => v, Some(v) => v,
None => self.user(UserID::ID(id), |f| f).await?.map(|v| v.into()), None => self.user(&UserID::ID(id), |f| f).await?.map(|v| v.into()),
} }
}) })
} }

View file

@ -1,3 +1,5 @@
use core::fmt;
use crate::models::{Mode, Mods}; use crate::models::{Mode, Mods};
use crate::Client; use crate::Client;
use rosu_v2::error::OsuError; use rosu_v2::error::OsuError;
@ -9,6 +11,15 @@ pub enum UserID {
ID(u64), ID(u64),
} }
impl fmt::Display for UserID {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
UserID::Username(u) => u.fmt(f),
UserID::ID(id) => id.fmt(f),
}
}
}
impl From<UserID> for rosu_v2::prelude::UserId { impl From<UserID> for rosu_v2::prelude::UserId {
fn from(value: UserID) -> Self { fn from(value: UserID) -> Self {
match value { match value {

View file

@ -1,5 +1,5 @@
use crate::{async_trait, future, Context, Result}; use crate::{async_trait, future, Context, Result};
use serenity::model::channel::Message; use serenity::{all::Interaction, model::channel::Message};
/// Hook represents the asynchronous hook that is run on every message. /// Hook represents the asynchronous hook that is run on every message.
#[async_trait] #[async_trait]
@ -22,3 +22,25 @@ where
self(ctx, message).await self(ctx, message).await
} }
} }
/// InteractionHook represents the asynchronous hook that is run on every interaction.
#[async_trait]
pub trait InteractionHook: Send + Sync {
async fn call(&mut self, ctx: &Context, interaction: &Interaction) -> Result<()>;
}
#[async_trait]
impl<T> InteractionHook for T
where
T: for<'a> FnMut(
&'a Context,
&'a Interaction,
)
-> std::pin::Pin<Box<dyn future::Future<Output = Result<()>> + 'a + Send>>
+ Send
+ Sync,
{
async fn call(&mut self, ctx: &Context, interaction: &Interaction) -> Result<()> {
self(ctx, interaction).await
}
}

View file

@ -164,7 +164,8 @@ pub async fn paginate(
paginate_with_first_message(pager, ctx, message, timeout).await paginate_with_first_message(pager, ctx, message, timeout).await
} }
async fn paginate_with_first_message( /// Paginate with the first message already created.
pub async fn paginate_with_first_message(
mut pager: impl Paginate, mut pager: impl Paginate,
ctx: &Context, ctx: &Context,
mut message: Message, mut message: Message,

View file

@ -1,5 +1,7 @@
use dotenv::var; use dotenv::var;
use hook::InteractionHook;
use serenity::{ use serenity::{
all::{CreateInteractionResponseMessage, Interaction},
framework::standard::{ framework::standard::{
macros::hook, BucketBuilder, CommandResult, Configuration, DispatchError, StandardFramework, macros::hook, BucketBuilder, CommandResult, Configuration, DispatchError, StandardFramework,
}, },
@ -19,6 +21,7 @@ mod compose_framework;
struct Handler { struct Handler {
hooks: Vec<RwLock<Box<dyn Hook>>>, hooks: Vec<RwLock<Box<dyn Hook>>>,
interaction_hooks: Vec<RwLock<Box<dyn InteractionHook>>>,
ready_hooks: Vec<fn(&Context) -> CommandResult>, ready_hooks: Vec<fn(&Context) -> CommandResult>,
} }
@ -26,6 +29,7 @@ impl Handler {
fn new() -> Handler { fn new() -> Handler {
Handler { Handler {
hooks: vec![], hooks: vec![],
interaction_hooks: vec![],
ready_hooks: vec![], ready_hooks: vec![],
} }
} }
@ -37,6 +41,10 @@ impl Handler {
fn push_ready_hook(&mut self, f: fn(&Context) -> CommandResult) { fn push_ready_hook(&mut self, f: fn(&Context) -> CommandResult) {
self.ready_hooks.push(f); self.ready_hooks.push(f);
} }
fn push_interaction_hook<T: InteractionHook + 'static>(&mut self, f: T) {
self.interaction_hooks.push(RwLock::new(Box::new(f)));
}
} }
/// Environment to be passed into the framework /// Environment to be passed into the framework
@ -99,6 +107,36 @@ impl EventHandler for Handler {
f(&ctx).pls_ok(); f(&ctx).pls_ok();
} }
} }
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
let ctx = &ctx;
let interaction = &interaction;
self.interaction_hooks
.iter()
.map(|hook| {
hook.write()
.then(|mut h| async move { h.call(&ctx, &interaction).await })
})
.collect::<stream::FuturesUnordered<_>>()
.for_each(|v| async move {
if let Err(e) = v {
let response = serenity::all::CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.ephemeral(true)
.content(format!("Interaction failed: {}", e)),
);
match interaction {
Interaction::Command(c) => c.create_response(ctx, response).await.pls_ok(),
Interaction::Component(c) => {
c.create_response(ctx, response).await.pls_ok()
}
Interaction::Modal(c) => c.create_response(ctx, response).await.pls_ok(),
_ => None,
};
}
})
.await;
}
} }
/// Returns whether the user has "MANAGE_MESSAGES" permission in the channel. /// Returns whether the user has "MANAGE_MESSAGES" permission in the channel.
@ -129,6 +167,7 @@ async fn main() {
handler.push_hook(youmubot_osu::discord::hook); handler.push_hook(youmubot_osu::discord::hook);
handler.push_hook(youmubot_osu::discord::dot_osu_hook); handler.push_hook(youmubot_osu::discord::dot_osu_hook);
handler.push_hook(youmubot_osu::discord::score_hook); handler.push_hook(youmubot_osu::discord::score_hook);
handler.push_interaction_hook(youmubot_osu::discord::interaction::handle_check_button)
} }
#[cfg(feature = "codeforces")] #[cfg(feature = "codeforces")]
handler.push_hook(youmubot_cf::InfoHook); handler.push_hook(youmubot_cf::InfoHook);