From f1719019d1e73214e31f8eaa5850cb6b9aa8bec4 Mon Sep 17 00:00:00 2001 From: Natsu Kagami Date: Mon, 7 Sep 2020 02:09:06 -0400 Subject: [PATCH] Core/Prelude: Fix lifetime unsoundness --- youmubot-core/src/admin/mod.rs | 4 +- youmubot-core/src/admin/soft_ban.rs | 7 +- youmubot-core/src/community/mod.rs | 18 +++-- youmubot-core/src/community/roles.rs | 117 ++++++++++++++------------- youmubot-core/src/community/votes.rs | 22 ++--- youmubot-core/src/lib.rs | 14 ++-- youmubot-prelude/src/announcer.rs | 16 ++-- youmubot-prelude/src/pagination.rs | 60 +++++++++----- youmubot-prelude/src/setup.rs | 4 +- 9 files changed, 150 insertions(+), 112 deletions(-) diff --git a/youmubot-core/src/admin/mod.rs b/youmubot-core/src/admin/mod.rs index c9bcfa0..4270883 100644 --- a/youmubot-core/src/admin/mod.rs +++ b/youmubot-core/src/admin/mod.rs @@ -40,7 +40,7 @@ async fn clean(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { messages .into_iter() .filter(|v| v.author.id == self_id) - .map(|m| m.delete(&ctx)) + .map(|m| async move { m.delete(&ctx).await }) .collect::>() .try_collect::<()>() .await?; @@ -54,7 +54,7 @@ async fn clean(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { msg.react(&ctx, '🌋').await?; if let Channel::Guild(_) = &channel { tokio::time::delay_for(std::time::Duration::from_secs(2)).await; - msg.delete(&ctx).await; + msg.delete(&ctx).await.ok(); } Ok(()) diff --git a/youmubot-core/src/admin/soft_ban.rs b/youmubot-core/src/admin/soft_ban.rs index 56f1d84..ce2b510 100644 --- a/youmubot-core/src/admin/soft_ban.rs +++ b/youmubot-core/src/admin/soft_ban.rs @@ -30,7 +30,7 @@ pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandRe }; let guild = msg.guild_id.ok_or(Error::msg("Command is guild only"))?; - let db = SoftBans::open(&*data); + let mut db = SoftBans::open(&*data); let val = db .borrow()? .get(&guild) @@ -82,6 +82,7 @@ pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandRe #[only_in("guilds")] pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let role_id = args.single::()?; + let data = ctx.data.read().await; let guild = msg.guild(&ctx).await.unwrap(); // Check whether the role_id is the one we wanted if !guild.roles.contains_key(&role_id) { @@ -91,7 +92,7 @@ pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> Comm )))?; } // Check if we already set up - let db = SoftBans::open(&*ctx.data.read().await); + let mut db = SoftBans::open(&*data); let set_up = db.borrow()?.contains_key(&guild.id); if !set_up { @@ -104,7 +105,7 @@ pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> Comm Ok(()) } -// Watch the soft bans. +// Watch the soft bans. Blocks forever. pub async fn watch_soft_bans(cache_http: Arc, data: AppData) { loop { // Scope so that locks are released diff --git a/youmubot-core/src/community/mod.rs b/youmubot-core/src/community/mod.rs index e3f29a4..54b5e47 100644 --- a/youmubot-core/src/community/mod.rs +++ b/youmubot-core/src/community/mod.rs @@ -41,10 +41,7 @@ pub async fn choose(ctx: &Context, m: &Message, mut args: Args) -> CommandResult } else { args.single::()? }; - let role = match args.single::().ok() { - Some(v) => v.to_role_cached(&ctx).await, - None => None, - }; + let role = args.single::().ok(); let users: Result, Error> = { let guild = m.guild(&ctx).await.unwrap(); @@ -68,16 +65,21 @@ pub async fn choose(ctx: &Context, m: &Message, mut args: Args) -> CommandResult }) .map(|mem| future::ready(mem)) .collect::>() - .filter(|member| async { + .filter_map(|member| async move { // Filter by role if provided if let Some(role) = role { - member + if member .roles(&ctx) .await - .map(|roles| roles.into_iter().any(|r| role.id == r.id)) + .map(|roles| roles.into_iter().any(|r| role == r.id)) .unwrap_or(false) + { + Some(member) + } else { + None + } } else { - true + Some(member) } }) .collect() diff --git a/youmubot-core/src/community/roles.rs b/youmubot-core/src/community/roles.rs index 5181307..93a75aa 100644 --- a/youmubot-core/src/community/roles.rs +++ b/youmubot-core/src/community/roles.rs @@ -34,66 +34,69 @@ async fn list(ctx: &Context, m: &Message, _: Args) -> CommandResult { let pages = (roles.len() + ROLES_PER_PAGE - 1) / ROLES_PER_PAGE; paginate( - |page, ctx, msg| async move { - let page = page as usize; - let start = page * ROLES_PER_PAGE; - let end = roles.len().min(start + ROLES_PER_PAGE); - if end <= start { - return Ok(false); - } - let roles = &roles[start..end]; - let nw = roles // name width - .iter() - .map(|(r, _)| r.name.len()) - .max() - .unwrap() - .max(6); - let idw = roles[0].0.id.to_string().len(); - let dw = roles - .iter() - .map(|v| v.1.len()) - .max() - .unwrap() - .max(" Description ".len()); - let mut m = MessageBuilder::new(); - m.push_line("```"); + |page, ctx, msg| { + let roles = roles.clone(); + Box::pin(async move { + let page = page as usize; + let start = page * ROLES_PER_PAGE; + let end = roles.len().min(start + ROLES_PER_PAGE); + if end <= start { + return Ok(false); + } + let roles = &roles[start..end]; + let nw = roles // name width + .iter() + .map(|(r, _)| r.name.len()) + .max() + .unwrap() + .max(6); + let idw = roles[0].0.id.to_string().len(); + let dw = roles + .iter() + .map(|v| v.1.len()) + .max() + .unwrap() + .max(" Description ".len()); + let mut m = MessageBuilder::new(); + m.push_line("```"); - // Table header - m.push_line(format!( - "{:nw$} | {:idw$} | {:dw$}", - "Name", - "ID", - "Description", - nw = nw, - idw = idw, - dw = dw, - )); - m.push_line(format!( - "{:->nw$}---{:->idw$}---{:->dw$}", - "", - "", - "", - nw = nw, - idw = idw, - dw = dw, - )); - - for (role, description) in roles.iter() { + // Table header m.push_line(format!( "{:nw$} | {:idw$} | {:dw$}", - role.name, - role.id, - description, + "Name", + "ID", + "Description", + nw = nw, + idw = idw, + dw = dw, + )); + m.push_line(format!( + "{:->nw$}---{:->idw$}---{:->dw$}", + "", + "", + "", nw = nw, idw = idw, dw = dw, )); - } - m.push_line("```"); - m.push(format!("Page **{}/{}**", page + 1, pages)); - msg.edit(ctx, |f| f.content(m.to_string())).await?; - Ok(true) + for (role, description) in roles.iter() { + m.push_line(format!( + "{:nw$} | {:idw$} | {:dw$}", + role.name, + role.id, + description, + nw = nw, + idw = idw, + dw = dw, + )); + } + m.push_line("```"); + m.push(format!("Page **{}/{}**", page + 1, pages)); + + msg.edit(ctx, |f| f.content(m.to_string())).await?; + Ok(true) + }) }, ctx, m.channel_id, @@ -155,6 +158,7 @@ async fn toggle(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { #[only_in(guilds)] async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { let role = args.single_quoted::()?; + let data = ctx.data.read().await; let description = args.single::()?; let guild_id = m.guild_id.unwrap(); let roles = guild_id.to_partial_guild(&ctx).await?.roles; @@ -164,7 +168,7 @@ async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { m.reply(&ctx, "No such role exists").await?; } Some(role) - if DB::open(&*ctx.data.read().await) + if DB::open(&*data) .borrow()? .get(&guild_id) .map(|g| g.contains_key(&role.id)) @@ -174,7 +178,7 @@ async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { .await?; } Some(role) => { - DB::open(&*ctx.data.read().await) + DB::open(&*data) .borrow_mut()? .entry(guild_id) .or_default() @@ -200,6 +204,7 @@ async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { #[only_in(guilds)] async fn remove(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { let role = args.single_quoted::()?; + let data = ctx.data.read().await; let guild_id = m.guild_id.unwrap(); let roles = guild_id.to_partial_guild(&ctx).await?.roles; let role = role_from_string(&role, &roles); @@ -208,7 +213,7 @@ async fn remove(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { m.reply(&ctx, "No such role exists").await?; } Some(role) - if !DB::open(&*ctx.data.read().await) + if !DB::open(&*data) .borrow()? .get(&guild_id) .map(|g| g.contains_key(&role.id)) @@ -218,7 +223,7 @@ async fn remove(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { .await?; } Some(role) => { - DB::open(&*ctx.data.read().await) + DB::open(&*data) .borrow_mut()? .entry(guild_id) .or_default() diff --git a/youmubot-core/src/community/votes.rs b/youmubot-core/src/community/votes.rs index 524ee5d..0eaa09f 100644 --- a/youmubot-core/src/community/votes.rs +++ b/youmubot-core/src/community/votes.rs @@ -28,7 +28,7 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult let args = args.quoted(); let _duration = args.single::()?; let duration = &_duration.0; - if *duration < Duration::from_secs(2 * 60) || *duration > Duration::from_secs(60 * 60 * 24) { + if *duration < Duration::from_secs(2) || *duration > Duration::from_secs(60 * 60 * 24) { msg.reply(ctx, format!("😒 Invalid duration ({}). The voting time should be between **2 minutes** and **1 day**.", _duration)).await?; return Ok(()); } @@ -97,6 +97,8 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult }) }).await?; msg.delete(&ctx).await?; + drop(msg); + // React on all the choices choices .iter() @@ -110,16 +112,18 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .await?; // A handler for votes. - let mut user_reactions: Map> = choices + let user_reactions: Map> = choices .iter() .map(|(emote, _)| (emote.clone(), Set::new())) .collect(); // Collect reactions... - msg.await_reactions(&ctx) + let user_reactions = panel + .await_reactions(&ctx) + .removed(true) .timeout(*duration) .await - .scan(user_reactions, |set, reaction| async move { + .fold(user_reactions, |mut set, reaction| async move { let (reaction, is_add) = match &*reaction { ReactionAction::Added(r) => (r, true), ReactionAction::Removed(r) => (r, false), @@ -128,23 +132,22 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult if let Some(users) = set.get_mut(s.as_str()) { users } else { - return None; + return set; } } else { - return None; + return set; }; let user_id = match reaction.user_id { Some(v) => v, - None => return None, + None => return set, }; if is_add { users.insert(user_id); } else { users.remove(&user_id); } - Some(()) + set }) - .collect::<()>() .await; // Handle choices @@ -233,3 +236,4 @@ const REACTIONS: [&'static str; 90] = [ // Assertions static_assertions::const_assert!(MAX_CHOICES <= REACTIONS.len()); +static_assertions::const_assert!(MAX_CHOICES <= REACTIONS.len()); diff --git a/youmubot-core/src/lib.rs b/youmubot-core/src/lib.rs index 19dcbe2..c9db469 100644 --- a/youmubot-core/src/lib.rs +++ b/youmubot-core/src/lib.rs @@ -20,26 +20,30 @@ pub use fun::FUN_GROUP; pub fn setup( path: &std::path::Path, client: &serenity::client::Client, - data: &mut youmubot_prelude::ShareMap, + data: &mut TypeMap, ) -> serenity::framework::standard::CommandResult { db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?; db::Roles::insert_into(&mut *data, &path.join("roles.yaml"))?; // Create handler threads - std::thread::spawn(admin::watch_soft_bans(client)); + tokio::spawn(admin::watch_soft_bans( + client.cache_and_http.clone(), + client.data.clone(), + )); Ok(()) } // A help command #[help] -pub fn help( - context: &mut Context, +pub async fn help( + context: &Context, msg: &Message, args: Args, help_options: &'static HelpOptions, groups: &[&'static CommandGroup], owners: HashSet, ) -> CommandResult { - help_commands::with_embeds(context, msg, args, help_options, groups, owners) + help_commands::with_embeds(context, msg, args, help_options, groups, owners).await; + Ok(()) } diff --git a/youmubot-prelude/src/announcer.rs b/youmubot-prelude/src/announcer.rs index a916c17..ef6693e 100644 --- a/youmubot-prelude/src/announcer.rs +++ b/youmubot-prelude/src/announcer.rs @@ -1,6 +1,5 @@ use crate::{AppData, Result}; use async_trait::async_trait; -use crossbeam_channel::after; use futures_util::{ future::{join_all, ready, FutureExt}, stream::{FuturesUnordered, StreamExt}, @@ -78,7 +77,7 @@ impl MemberToChannels { pub struct AnnouncerHandler { cache_http: Arc, data: AppData, - announcers: HashMap<&'static str, RwLock>>, + announcers: HashMap<&'static str, RwLock>>, } // Querying for the AnnouncerHandler in the internal data returns a vec of keys. @@ -100,7 +99,11 @@ impl AnnouncerHandler { /// Insert a new announcer into the handler. /// /// The handler must take an unique key. If a duplicate is found, this method panics. - pub fn add(&mut self, key: &'static str, announcer: impl Announcer + 'static) -> &mut Self { + pub fn add( + &mut self, + key: &'static str, + announcer: impl Announcer + Send + Sync + 'static, + ) -> &mut Self { if let Some(_) = self .announcers .insert(key, RwLock::new(Box::new(announcer))) @@ -132,7 +135,7 @@ impl AnnouncerHandler { data: AppData, cache_http: Arc, key: &'static str, - announcer: &'_ RwLock>, + announcer: &'_ RwLock>, ) -> Result<()> { let channels = MemberToChannels(Self::get_guilds(&data, key).await?); announcer @@ -151,7 +154,8 @@ impl AnnouncerHandler { self.data.write().await.insert::(keys.clone()); loop { eprintln!("{}: announcer started scanning", chrono::Utc::now()); - let after_timer = after(cooldown); + // let after_timer = after(cooldown); + let after = tokio::time::delay_for(cooldown); join_all(self.announcers.iter().map(|(key, announcer)| { eprintln!(" - scanning key `{}`", key); Self::announce(self.data.clone(), self.cache_http.clone(), *key, announcer).map( @@ -164,7 +168,7 @@ impl AnnouncerHandler { })) .await; eprintln!("{}: announcer finished scanning", chrono::Utc::now()); - after_timer.recv().ok(); + after.await; } } } diff --git a/youmubot-prelude/src/pagination.rs b/youmubot-prelude/src/pagination.rs index 55d8c2d..c16802e 100644 --- a/youmubot-prelude/src/pagination.rs +++ b/youmubot-prelude/src/pagination.rs @@ -13,18 +13,39 @@ use tokio::time as tokio_time; const ARROW_RIGHT: &'static str = "➡️"; const ARROW_LEFT: &'static str = "⬅️"; -/// Paginate! with a pager function. +#[async_trait::async_trait] +pub trait Paginate { + async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result; +} + +#[async_trait::async_trait] +impl Paginate for T +where + T: for<'m> FnMut( + u8, + &'m Context, + &'m mut Message, + ) -> std::pin::Pin> + Send + 'm>> + + Send, +{ + async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result { + self(page, ctx, m).await + } +} + +// Paginate! with a pager function. /// If awaited, will block until everything is done. -pub async fn paginate<'a, T, F>( - mut pager: T, - ctx: &'a Context, +pub async fn paginate( + mut pager: impl for<'m> FnMut( + u8, + &'m Context, + &'m mut Message, + ) -> std::pin::Pin> + Send + 'm>> + + Send, + ctx: &Context, channel: ChannelId, timeout: std::time::Duration, -) -> Result<()> -where - T: for<'m> FnMut(u8, &'a Context, &'m mut Message) -> F, - F: Future>, -{ +) -> Result<()> { let mut message = channel .send_message(&ctx, |e| e.content("Youmu is loading the first page...")) .await?; @@ -35,8 +56,9 @@ where message .react(&ctx, ReactionType::try_from(ARROW_RIGHT)?) .await?; + pager(0, ctx, &mut message).await?; // Build a reaction collector - let mut reaction_collector = message.await_reactions(&ctx).await; + let mut reaction_collector = message.await_reactions(&ctx).removed(true).await; let mut page = 0; // Loop the handler function. @@ -59,29 +81,25 @@ where } // Handle the reaction and return a new page number. -async fn handle_reaction<'a, T, F>( +async fn handle_reaction( page: u8, - pager: &mut T, - ctx: &'a Context, - message: &'_ mut Message, + pager: &mut impl Paginate, + ctx: &Context, + message: &mut Message, reaction: &ReactionAction, -) -> Result -where - T: for<'m> FnMut(u8, &'a Context, &'m mut Message) -> F, - F: Future>, -{ +) -> Result { let reaction = match reaction { ReactionAction::Added(v) | ReactionAction::Removed(v) => v, }; match &reaction.emoji { ReactionType::Unicode(ref s) => match s.as_str() { ARROW_LEFT if page == 0 => Ok(page), - ARROW_LEFT => Ok(if pager(page - 1, ctx, message).await? { + ARROW_LEFT => Ok(if pager.render(page - 1, ctx, message).await? { page - 1 } else { page }), - ARROW_RIGHT => Ok(if pager(page + 1, ctx, message).await? { + ARROW_RIGHT => Ok(if pager.render(page + 1, ctx, message).await? { page + 1 } else { page diff --git a/youmubot-prelude/src/setup.rs b/youmubot-prelude/src/setup.rs index 7f380f2..2d3704c 100644 --- a/youmubot-prelude/src/setup.rs +++ b/youmubot-prelude/src/setup.rs @@ -1,10 +1,10 @@ -use serenity::{framework::standard::StandardFramework, prelude::*}; +use serenity::prelude::*; use std::path::Path; /// Set up the prelude libraries. /// /// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails. -pub fn setup_prelude(db_path: &Path, data: &mut TypeMap, _: &mut StandardFramework) { +pub fn setup_prelude(db_path: &Path, data: &mut TypeMap) { // Setup the announcer DB. crate::announcer::AnnouncerChannels::insert_into(data, db_path.join("announcers.yaml")) .expect("Announcers DB set up");