Core/Prelude: Fix lifetime unsoundness

This commit is contained in:
Natsu Kagami 2020-09-07 02:09:06 -04:00
parent c672a8836c
commit f1719019d1
Signed by: nki
GPG key ID: 73376E117CD20735
9 changed files with 150 additions and 112 deletions

View file

@ -40,7 +40,7 @@ async fn clean(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
messages
.into_iter()
.filter(|v| v.author.id == self_id)
.map(|m| m.delete(&ctx))
.map(|m| async move { m.delete(&ctx).await })
.collect::<stream::FuturesUnordered<_>>()
.try_collect::<()>()
.await?;
@ -54,7 +54,7 @@ async fn clean(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
msg.react(&ctx, '🌋').await?;
if let Channel::Guild(_) = &channel {
tokio::time::delay_for(std::time::Duration::from_secs(2)).await;
msg.delete(&ctx).await;
msg.delete(&ctx).await.ok();
}
Ok(())

View file

@ -30,7 +30,7 @@ pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandRe
};
let guild = msg.guild_id.ok_or(Error::msg("Command is guild only"))?;
let db = SoftBans::open(&*data);
let mut db = SoftBans::open(&*data);
let val = db
.borrow()?
.get(&guild)
@ -82,6 +82,7 @@ pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandRe
#[only_in("guilds")]
pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let role_id = args.single::<RoleId>()?;
let data = ctx.data.read().await;
let guild = msg.guild(&ctx).await.unwrap();
// Check whether the role_id is the one we wanted
if !guild.roles.contains_key(&role_id) {
@ -91,7 +92,7 @@ pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> Comm
)))?;
}
// Check if we already set up
let db = SoftBans::open(&*ctx.data.read().await);
let mut db = SoftBans::open(&*data);
let set_up = db.borrow()?.contains_key(&guild.id);
if !set_up {
@ -104,7 +105,7 @@ pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> Comm
Ok(())
}
// Watch the soft bans.
// Watch the soft bans. Blocks forever.
pub async fn watch_soft_bans(cache_http: Arc<CacheAndHttp>, data: AppData) {
loop {
// Scope so that locks are released

View file

@ -41,10 +41,7 @@ pub async fn choose(ctx: &Context, m: &Message, mut args: Args) -> CommandResult
} else {
args.single::<String>()?
};
let role = match args.single::<RoleId>().ok() {
Some(v) => v.to_role_cached(&ctx).await,
None => None,
};
let role = args.single::<RoleId>().ok();
let users: Result<Vec<_>, Error> = {
let guild = m.guild(&ctx).await.unwrap();
@ -68,16 +65,21 @@ pub async fn choose(ctx: &Context, m: &Message, mut args: Args) -> CommandResult
})
.map(|mem| future::ready(mem))
.collect::<stream::FuturesUnordered<_>>()
.filter(|member| async {
.filter_map(|member| async move {
// Filter by role if provided
if let Some(role) = role {
member
if member
.roles(&ctx)
.await
.map(|roles| roles.into_iter().any(|r| role.id == r.id))
.map(|roles| roles.into_iter().any(|r| role == r.id))
.unwrap_or(false)
{
Some(member)
} else {
None
}
} else {
true
Some(member)
}
})
.collect()

View file

@ -34,66 +34,69 @@ async fn list(ctx: &Context, m: &Message, _: Args) -> CommandResult {
let pages = (roles.len() + ROLES_PER_PAGE - 1) / ROLES_PER_PAGE;
paginate(
|page, ctx, msg| async move {
let page = page as usize;
let start = page * ROLES_PER_PAGE;
let end = roles.len().min(start + ROLES_PER_PAGE);
if end <= start {
return Ok(false);
}
let roles = &roles[start..end];
let nw = roles // name width
.iter()
.map(|(r, _)| r.name.len())
.max()
.unwrap()
.max(6);
let idw = roles[0].0.id.to_string().len();
let dw = roles
.iter()
.map(|v| v.1.len())
.max()
.unwrap()
.max(" Description ".len());
let mut m = MessageBuilder::new();
m.push_line("```");
|page, ctx, msg| {
let roles = roles.clone();
Box::pin(async move {
let page = page as usize;
let start = page * ROLES_PER_PAGE;
let end = roles.len().min(start + ROLES_PER_PAGE);
if end <= start {
return Ok(false);
}
let roles = &roles[start..end];
let nw = roles // name width
.iter()
.map(|(r, _)| r.name.len())
.max()
.unwrap()
.max(6);
let idw = roles[0].0.id.to_string().len();
let dw = roles
.iter()
.map(|v| v.1.len())
.max()
.unwrap()
.max(" Description ".len());
let mut m = MessageBuilder::new();
m.push_line("```");
// Table header
m.push_line(format!(
"{:nw$} | {:idw$} | {:dw$}",
"Name",
"ID",
"Description",
nw = nw,
idw = idw,
dw = dw,
));
m.push_line(format!(
"{:->nw$}---{:->idw$}---{:->dw$}",
"",
"",
"",
nw = nw,
idw = idw,
dw = dw,
));
for (role, description) in roles.iter() {
// Table header
m.push_line(format!(
"{:nw$} | {:idw$} | {:dw$}",
role.name,
role.id,
description,
"Name",
"ID",
"Description",
nw = nw,
idw = idw,
dw = dw,
));
m.push_line(format!(
"{:->nw$}---{:->idw$}---{:->dw$}",
"",
"",
"",
nw = nw,
idw = idw,
dw = dw,
));
}
m.push_line("```");
m.push(format!("Page **{}/{}**", page + 1, pages));
msg.edit(ctx, |f| f.content(m.to_string())).await?;
Ok(true)
for (role, description) in roles.iter() {
m.push_line(format!(
"{:nw$} | {:idw$} | {:dw$}",
role.name,
role.id,
description,
nw = nw,
idw = idw,
dw = dw,
));
}
m.push_line("```");
m.push(format!("Page **{}/{}**", page + 1, pages));
msg.edit(ctx, |f| f.content(m.to_string())).await?;
Ok(true)
})
},
ctx,
m.channel_id,
@ -155,6 +158,7 @@ async fn toggle(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
#[only_in(guilds)]
async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let role = args.single_quoted::<String>()?;
let data = ctx.data.read().await;
let description = args.single::<String>()?;
let guild_id = m.guild_id.unwrap();
let roles = guild_id.to_partial_guild(&ctx).await?.roles;
@ -164,7 +168,7 @@ async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
m.reply(&ctx, "No such role exists").await?;
}
Some(role)
if DB::open(&*ctx.data.read().await)
if DB::open(&*data)
.borrow()?
.get(&guild_id)
.map(|g| g.contains_key(&role.id))
@ -174,7 +178,7 @@ async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
.await?;
}
Some(role) => {
DB::open(&*ctx.data.read().await)
DB::open(&*data)
.borrow_mut()?
.entry(guild_id)
.or_default()
@ -200,6 +204,7 @@ async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
#[only_in(guilds)]
async fn remove(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let role = args.single_quoted::<String>()?;
let data = ctx.data.read().await;
let guild_id = m.guild_id.unwrap();
let roles = guild_id.to_partial_guild(&ctx).await?.roles;
let role = role_from_string(&role, &roles);
@ -208,7 +213,7 @@ async fn remove(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
m.reply(&ctx, "No such role exists").await?;
}
Some(role)
if !DB::open(&*ctx.data.read().await)
if !DB::open(&*data)
.borrow()?
.get(&guild_id)
.map(|g| g.contains_key(&role.id))
@ -218,7 +223,7 @@ async fn remove(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
.await?;
}
Some(role) => {
DB::open(&*ctx.data.read().await)
DB::open(&*data)
.borrow_mut()?
.entry(guild_id)
.or_default()

View file

@ -28,7 +28,7 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
let args = args.quoted();
let _duration = args.single::<ParseDuration>()?;
let duration = &_duration.0;
if *duration < Duration::from_secs(2 * 60) || *duration > Duration::from_secs(60 * 60 * 24) {
if *duration < Duration::from_secs(2) || *duration > Duration::from_secs(60 * 60 * 24) {
msg.reply(ctx, format!("😒 Invalid duration ({}). The voting time should be between **2 minutes** and **1 day**.", _duration)).await?;
return Ok(());
}
@ -97,6 +97,8 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
})
}).await?;
msg.delete(&ctx).await?;
drop(msg);
// React on all the choices
choices
.iter()
@ -110,16 +112,18 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
.await?;
// A handler for votes.
let mut user_reactions: Map<String, Set<UserId>> = choices
let user_reactions: Map<String, Set<UserId>> = choices
.iter()
.map(|(emote, _)| (emote.clone(), Set::new()))
.collect();
// Collect reactions...
msg.await_reactions(&ctx)
let user_reactions = panel
.await_reactions(&ctx)
.removed(true)
.timeout(*duration)
.await
.scan(user_reactions, |set, reaction| async move {
.fold(user_reactions, |mut set, reaction| async move {
let (reaction, is_add) = match &*reaction {
ReactionAction::Added(r) => (r, true),
ReactionAction::Removed(r) => (r, false),
@ -128,23 +132,22 @@ pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
if let Some(users) = set.get_mut(s.as_str()) {
users
} else {
return None;
return set;
}
} else {
return None;
return set;
};
let user_id = match reaction.user_id {
Some(v) => v,
None => return None,
None => return set,
};
if is_add {
users.insert(user_id);
} else {
users.remove(&user_id);
}
Some(())
set
})
.collect::<()>()
.await;
// Handle choices
@ -233,3 +236,4 @@ const REACTIONS: [&'static str; 90] = [
// Assertions
static_assertions::const_assert!(MAX_CHOICES <= REACTIONS.len());
static_assertions::const_assert!(MAX_CHOICES <= REACTIONS.len());

View file

@ -20,26 +20,30 @@ pub use fun::FUN_GROUP;
pub fn setup(
path: &std::path::Path,
client: &serenity::client::Client,
data: &mut youmubot_prelude::ShareMap,
data: &mut TypeMap,
) -> serenity::framework::standard::CommandResult {
db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?;
db::Roles::insert_into(&mut *data, &path.join("roles.yaml"))?;
// Create handler threads
std::thread::spawn(admin::watch_soft_bans(client));
tokio::spawn(admin::watch_soft_bans(
client.cache_and_http.clone(),
client.data.clone(),
));
Ok(())
}
// A help command
#[help]
pub fn help(
context: &mut Context,
pub async fn help(
context: &Context,
msg: &Message,
args: Args,
help_options: &'static HelpOptions,
groups: &[&'static CommandGroup],
owners: HashSet<UserId>,
) -> CommandResult {
help_commands::with_embeds(context, msg, args, help_options, groups, owners)
help_commands::with_embeds(context, msg, args, help_options, groups, owners).await;
Ok(())
}

View file

@ -1,6 +1,5 @@
use crate::{AppData, Result};
use async_trait::async_trait;
use crossbeam_channel::after;
use futures_util::{
future::{join_all, ready, FutureExt},
stream::{FuturesUnordered, StreamExt},
@ -78,7 +77,7 @@ impl MemberToChannels {
pub struct AnnouncerHandler {
cache_http: Arc<CacheAndHttp>,
data: AppData,
announcers: HashMap<&'static str, RwLock<Box<dyn Announcer>>>,
announcers: HashMap<&'static str, RwLock<Box<dyn Announcer + Send + Sync>>>,
}
// Querying for the AnnouncerHandler in the internal data returns a vec of keys.
@ -100,7 +99,11 @@ impl AnnouncerHandler {
/// Insert a new announcer into the handler.
///
/// The handler must take an unique key. If a duplicate is found, this method panics.
pub fn add(&mut self, key: &'static str, announcer: impl Announcer + 'static) -> &mut Self {
pub fn add(
&mut self,
key: &'static str,
announcer: impl Announcer + Send + Sync + 'static,
) -> &mut Self {
if let Some(_) = self
.announcers
.insert(key, RwLock::new(Box::new(announcer)))
@ -132,7 +135,7 @@ impl AnnouncerHandler {
data: AppData,
cache_http: Arc<CacheAndHttp>,
key: &'static str,
announcer: &'_ RwLock<Box<dyn Announcer>>,
announcer: &'_ RwLock<Box<dyn Announcer + Send + Sync>>,
) -> Result<()> {
let channels = MemberToChannels(Self::get_guilds(&data, key).await?);
announcer
@ -151,7 +154,8 @@ impl AnnouncerHandler {
self.data.write().await.insert::<Self>(keys.clone());
loop {
eprintln!("{}: announcer started scanning", chrono::Utc::now());
let after_timer = after(cooldown);
// let after_timer = after(cooldown);
let after = tokio::time::delay_for(cooldown);
join_all(self.announcers.iter().map(|(key, announcer)| {
eprintln!(" - scanning key `{}`", key);
Self::announce(self.data.clone(), self.cache_http.clone(), *key, announcer).map(
@ -164,7 +168,7 @@ impl AnnouncerHandler {
}))
.await;
eprintln!("{}: announcer finished scanning", chrono::Utc::now());
after_timer.recv().ok();
after.await;
}
}
}

View file

@ -13,18 +13,39 @@ use tokio::time as tokio_time;
const ARROW_RIGHT: &'static str = "➡️";
const ARROW_LEFT: &'static str = "⬅️";
/// Paginate! with a pager function.
#[async_trait::async_trait]
pub trait Paginate {
async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result<bool>;
}
#[async_trait::async_trait]
impl<T> Paginate for T
where
T: for<'m> FnMut(
u8,
&'m Context,
&'m mut Message,
) -> std::pin::Pin<Box<dyn Future<Output = Result<bool>> + Send + 'm>>
+ Send,
{
async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result<bool> {
self(page, ctx, m).await
}
}
// Paginate! with a pager function.
/// If awaited, will block until everything is done.
pub async fn paginate<'a, T, F>(
mut pager: T,
ctx: &'a Context,
pub async fn paginate(
mut pager: impl for<'m> FnMut(
u8,
&'m Context,
&'m mut Message,
) -> std::pin::Pin<Box<dyn Future<Output = Result<bool>> + Send + 'm>>
+ Send,
ctx: &Context,
channel: ChannelId,
timeout: std::time::Duration,
) -> Result<()>
where
T: for<'m> FnMut(u8, &'a Context, &'m mut Message) -> F,
F: Future<Output = Result<bool>>,
{
) -> Result<()> {
let mut message = channel
.send_message(&ctx, |e| e.content("Youmu is loading the first page..."))
.await?;
@ -35,8 +56,9 @@ where
message
.react(&ctx, ReactionType::try_from(ARROW_RIGHT)?)
.await?;
pager(0, ctx, &mut message).await?;
// Build a reaction collector
let mut reaction_collector = message.await_reactions(&ctx).await;
let mut reaction_collector = message.await_reactions(&ctx).removed(true).await;
let mut page = 0;
// Loop the handler function.
@ -59,29 +81,25 @@ where
}
// Handle the reaction and return a new page number.
async fn handle_reaction<'a, T, F>(
async fn handle_reaction(
page: u8,
pager: &mut T,
ctx: &'a Context,
message: &'_ mut Message,
pager: &mut impl Paginate,
ctx: &Context,
message: &mut Message,
reaction: &ReactionAction,
) -> Result<u8>
where
T: for<'m> FnMut(u8, &'a Context, &'m mut Message) -> F,
F: Future<Output = Result<bool>>,
{
) -> Result<u8> {
let reaction = match reaction {
ReactionAction::Added(v) | ReactionAction::Removed(v) => v,
};
match &reaction.emoji {
ReactionType::Unicode(ref s) => match s.as_str() {
ARROW_LEFT if page == 0 => Ok(page),
ARROW_LEFT => Ok(if pager(page - 1, ctx, message).await? {
ARROW_LEFT => Ok(if pager.render(page - 1, ctx, message).await? {
page - 1
} else {
page
}),
ARROW_RIGHT => Ok(if pager(page + 1, ctx, message).await? {
ARROW_RIGHT => Ok(if pager.render(page + 1, ctx, message).await? {
page + 1
} else {
page

View file

@ -1,10 +1,10 @@
use serenity::{framework::standard::StandardFramework, prelude::*};
use serenity::prelude::*;
use std::path::Path;
/// Set up the prelude libraries.
///
/// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails.
pub fn setup_prelude(db_path: &Path, data: &mut TypeMap, _: &mut StandardFramework) {
pub fn setup_prelude(db_path: &Path, data: &mut TypeMap) {
// Setup the announcer DB.
crate::announcer::AnnouncerChannels::insert_into(data, db_path.join("announcers.yaml"))
.expect("Announcers DB set up");