diff --git a/youmubot-osu/src/discord/hook.rs b/youmubot-osu/src/discord/hook.rs index 5efb21b..1e44c76 100644 --- a/youmubot-osu/src/discord/hook.rs +++ b/youmubot-osu/src/discord/hook.rs @@ -25,49 +25,54 @@ lazy_static! { ).unwrap(); } -pub async fn hook(ctx: &Context, msg: &Message) -> Result<()> { - if msg.author.bot { - return Ok(()); - } - let (old_links, new_links, short_links) = ( - handle_old_links(ctx, &msg.content), - handle_new_links(ctx, &msg.content), - handle_short_links(ctx, &msg, &msg.content), - ); - let last_beatmap = stream::select(old_links, stream::select(new_links, short_links)) - .then(|l| async move { - let mut bm: Option = None; - msg.channel_id - .send_message(&ctx, |m| match l.embed { - EmbedType::Beatmap(b, info, mods) => { - let t = handle_beatmap(&b, info, l.link, l.mode, mods, m); - let mode = l.mode.unwrap_or(b.mode); - bm = Some(super::BeatmapWithMode(b, mode)); - t +pub fn hook<'a>( + ctx: &'a Context, + msg: &'a Message, +) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + if msg.author.bot { + return Ok(()); + } + let (old_links, new_links, short_links) = ( + handle_old_links(ctx, &msg.content), + handle_new_links(ctx, &msg.content), + handle_short_links(ctx, &msg, &msg.content), + ); + let last_beatmap = stream::select(old_links, stream::select(new_links, short_links)) + .then(|l| async move { + let mut bm: Option = None; + msg.channel_id + .send_message(&ctx, |m| match l.embed { + EmbedType::Beatmap(b, info, mods) => { + let t = handle_beatmap(&b, info, l.link, l.mode, mods, m); + let mode = l.mode.unwrap_or(b.mode); + bm = Some(super::BeatmapWithMode(b, mode)); + t + } + EmbedType::Beatmapset(b) => handle_beatmapset(b, l.link, l.mode, m), + }) + .await?; + let r: Result<_> = Ok(bm); + r + }) + .filter_map(|v| async move { + match v { + Ok(v) => v, + Err(e) => { + eprintln!("{}", e); + None } - EmbedType::Beatmapset(b) => handle_beatmapset(b, l.link, l.mode, m), - }) - .await?; - let r: Result<_> = Ok(bm); - r - }) - .filter_map(|v| async move { - match v { - Ok(v) => v, - Err(e) => { - eprintln!("{}", e); - None } - } - }) - .fold(None, |_, v| async move { Some(v) }) - .await; + }) + .fold(None, |_, v| async move { Some(v) }) + .await; - // Save the beatmap for query later. - if let Some(t) = last_beatmap { - super::cache::save_beatmap(&*ctx.data.read().await, msg.channel_id, &t)?; - } - Ok(()) + // Save the beatmap for query later. + if let Some(t) = last_beatmap { + super::cache::save_beatmap(&*ctx.data.read().await, msg.channel_id, &t)?; + } + Ok(()) + }) } enum EmbedType { diff --git a/youmubot-osu/src/discord/mod.rs b/youmubot-osu/src/discord/mod.rs index d64fc48..49c9abf 100644 --- a/youmubot-osu/src/discord/mod.rs +++ b/youmubot-osu/src/discord/mod.rs @@ -49,7 +49,7 @@ impl TypeMapKey for OsuClient { /// - Commands on the "osu" prefix /// - Hooks. Hooks are completely opt-in. /// -pub async fn setup( +pub fn setup( path: &std::path::Path, data: &mut TypeMap, announcers: &mut AnnouncerHandler, diff --git a/youmubot-prelude/src/hook.rs b/youmubot-prelude/src/hook.rs new file mode 100644 index 0000000..5760b17 --- /dev/null +++ b/youmubot-prelude/src/hook.rs @@ -0,0 +1,24 @@ +use crate::{async_trait, future, Context, Result}; +use serenity::model::channel::Message; + +/// Hook represents the asynchronous hook that is run on every message. +#[async_trait] +pub trait Hook: Send + Sync { + async fn call(&mut self, ctx: &Context, message: &Message) -> Result<()>; +} + +#[async_trait] +impl Hook for T +where + T: for<'a> FnMut( + &'a Context, + &'a Message, + ) + -> std::pin::Pin> + 'a + Send>> + + Send + + Sync, +{ + async fn call(&mut self, ctx: &Context, message: &Message) -> Result<()> { + self(ctx, message).await + } +} diff --git a/youmubot-prelude/src/lib.rs b/youmubot-prelude/src/lib.rs index 1efc063..b979bae 100644 --- a/youmubot-prelude/src/lib.rs +++ b/youmubot-prelude/src/lib.rs @@ -5,11 +5,13 @@ use std::sync::Arc; pub mod announcer; pub mod args; +pub mod hook; pub mod pagination; pub mod setup; pub use announcer::{Announcer, AnnouncerHandler}; pub use args::{Duration, UsernameArg}; +pub use hook::Hook; pub use pagination::paginate; /// Re-exporting async_trait helps with implementing Announcer. diff --git a/youmubot/src/main.rs b/youmubot/src/main.rs index 67345e9..1f3fe4a 100644 --- a/youmubot/src/main.rs +++ b/youmubot/src/main.rs @@ -13,13 +13,17 @@ use serenity::{ use youmubot_prelude::*; struct Handler { - hooks: Vec ()>, + hooks: Vec>>, } impl Handler { fn new() -> Handler { Handler { hooks: vec![] } } + + fn push_hook(&mut self, f: T) { + self.hooks.push(RwLock::new(Box::new(f))); + } } #[async_trait] @@ -28,8 +32,22 @@ impl EventHandler for Handler { println!("{} is connected!", ready.user.name); } - async fn message(&self, mut ctx: Context, message: Message) { - self.hooks.iter().for_each(|f| f(&mut ctx, &message)); + async fn message(&self, ctx: Context, message: Message) { + self.hooks + .iter() + .map(|hook| { + let ctx = ctx.clone(); + let message = message.clone(); + hook.write() + .then(|mut h| async move { h.call(&ctx, &message).await }) + }) + .collect::>() + .for_each(|v| async move { + if let Err(e) = v { + eprintln!("{}", e) + } + }) + .await; } } @@ -53,12 +71,12 @@ async fn main() { println!("Loaded dotenv from {:?}", path); } - let handler = Handler::new(); + let mut handler = Handler::new(); // Set up hooks #[cfg(feature = "osu")] - handler.hooks.push(youmubot_osu::discord::hook); + handler.push_hook(youmubot_osu::discord::hook); #[cfg(feature = "codeforces")] - handler.hooks.push(youmubot_cf::codeforces_info_hook); + handler.push_hook(youmubot_cf::codeforces_info_hook); // Collect the token let token = var("TOKEN").expect("Please set TOKEN as the Discord Bot's token to be used.");