Merge pull request #5 from natsukagami/async-youmu

Async youmu!!!
This commit is contained in:
Natsu Kagami 2020-09-20 20:57:58 +00:00 committed by GitHub
commit e25701a99c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 3471 additions and 3197 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@ target
.env .env
*.yaml *.yaml
cargo-remote cargo-remote
.vscode

64
.vscode/launch.json vendored
View file

@ -1,64 +0,0 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in library 'youmubot-osu'",
"cargo": {
"args": [
"test",
"--no-run",
"--lib",
"--package=youmubot-osu"
],
"filter": {
"name": "youmubot-osu",
"kind": "lib"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug executable 'youmubot'",
"cargo": {
"args": [
"build",
"--bin=youmubot",
"--package=youmubot"
],
"filter": {
"name": "youmubot",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'youmubot'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=youmubot",
"--package=youmubot"
],
"filter": {
"name": "youmubot",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

1997
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -7,15 +7,15 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
tokio = { version = "0.2", features = ["time"] }
reqwest = "0.10.1" reqwest = "0.10.1"
serenity = "0.8" serenity = "0.9.0-rc.0"
Inflector = "0.11" Inflector = "0.11"
codeforces = { git = "https://github.com/natsukagami/rust-codeforces-api" } codeforces = "0.2.1"
regex = "1" regex = "1"
lazy_static = "1" lazy_static = "1"
rayon = "1"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
crossbeam-channel = "0.4" dashmap = "3.11.4"
youmubot-prelude = { path = "../youmubot-prelude" } youmubot-prelude = { path = "../youmubot-prelude" }
youmubot-db = { path = "../youmubot-db" } youmubot-db = { path = "../youmubot-db" }

View file

@ -1,85 +1,60 @@
use crate::db::{CfSavedUsers, CfUser}; use crate::{
db::{CfSavedUsers, CfUser},
CFClient,
};
use announcer::MemberToChannels; use announcer::MemberToChannels;
use chrono::Utc; use chrono::Utc;
use codeforces::{RatingChange, User}; use codeforces::{RatingChange, User};
use serenity::{ use serenity::{http::CacheHttp, model::id::UserId, CacheAndHttp};
framework::standard::{CommandError, CommandResult},
http::CacheHttp,
model::id::{ChannelId, UserId},
CacheAndHttp,
};
use std::sync::Arc; use std::sync::Arc;
use youmubot_prelude::*; use youmubot_prelude::*;
type Reqwest = <HTTPClient as TypeMapKey>::Value;
/// Updates the rating and rating changes of the users. /// Updates the rating and rating changes of the users.
pub fn updates( pub struct Announcer;
http: Arc<CacheAndHttp>,
data: AppData,
channels: MemberToChannels,
) -> CommandResult {
let mut users = CfSavedUsers::open(&*data.read()).borrow()?.clone();
let reqwest = data.get_cloned::<HTTPClient>();
for (user_id, cfu) in users.iter_mut() { #[async_trait]
if let Err(e) = update_user(http.clone(), &channels, &reqwest, *user_id, cfu) { impl youmubot_prelude::Announcer for Announcer {
dbg!((*user_id, e)); async fn updates(
} &mut self,
http: Arc<CacheAndHttp>,
data: AppData,
channels: MemberToChannels,
) -> Result<()> {
let data = data.read().await;
let client = data.get::<CFClient>().unwrap();
let mut users = CfSavedUsers::open(&*data).borrow()?.clone();
users
.iter_mut()
.map(|(user_id, cfu)| update_user(http.clone(), &channels, &client, *user_id, cfu))
.collect::<stream::FuturesUnordered<_>>()
.try_collect::<()>()
.await?;
*CfSavedUsers::open(&*data).borrow_mut()? = users;
Ok(())
} }
*CfSavedUsers::open(&*data.read()).borrow_mut()? = users;
Ok(())
} }
fn update_user( async fn update_user(
http: Arc<CacheAndHttp>, http: Arc<CacheAndHttp>,
channels: &MemberToChannels, channels: &MemberToChannels,
reqwest: &Reqwest, client: &codeforces::Client,
user_id: UserId, user_id: UserId,
cfu: &mut CfUser, cfu: &mut CfUser,
) -> CommandResult { ) -> Result<()> {
// Ensure this takes 200ms let info = User::info(client, &[cfu.handle.as_str()])
let after = crossbeam_channel::after(std::time::Duration::from_secs_f32(0.2)); .await?
let info = User::info(reqwest, &[cfu.handle.as_str()])?
.into_iter() .into_iter()
.next() .next()
.ok_or(CommandError::from("Not found"))?; .ok_or(Error::msg("Not found"))?;
let rating_changes = info.rating_changes(reqwest)?; let rating_changes = info.rating_changes(client).await?;
let mut channels_list: Option<Vec<ChannelId>> = None; let channels_list = channels.channels_of(&http, user_id).await;
cfu.last_update = Utc::now(); cfu.last_update = Utc::now();
// Update the rating // Update the rating
cfu.rating = info.rating; cfu.rating = info.rating;
let mut send_message = |rc: RatingChange| -> CommandResult {
let channels =
channels_list.get_or_insert_with(|| channels.channels_of(http.clone(), user_id));
if channels.is_empty() {
return Ok(());
}
let (contest, _, _) =
codeforces::Contest::standings(reqwest, rc.contest_id, |f| f.limit(1, 1))?;
for channel in channels {
if let Err(e) = channel.send_message(http.http(), |e| {
e.content(format!("Rating change for {}!", user_id.mention()))
.embed(|c| {
crate::embed::rating_change_embed(
&rc,
&info,
&contest,
&user_id.mention(),
c,
)
})
}) {
dbg!(e);
}
}
Ok(())
};
let rating_changes = match cfu.last_contest_id { let rating_changes = match cfu.last_contest_id {
None => rating_changes, None => rating_changes,
Some(v) => { Some(v) => {
@ -101,12 +76,46 @@ fn update_user(
.or(cfu.last_contest_id); .or(cfu.last_contest_id);
// Check for any good announcements to make // Check for any good announcements to make
for rc in rating_changes { rating_changes
if let Err(v) = send_message(rc) { .into_iter()
dbg!(v); .map(|rc: RatingChange| {
} let channels = channels_list.clone();
} let http = http.clone();
after.recv().ok(); let info = info.clone();
async move {
if channels.is_empty() {
return Ok(());
}
let (contest, _, _) =
codeforces::Contest::standings(client, rc.contest_id, |f| f.limit(1, 1))
.await?;
channels
.iter()
.map(|channel| {
channel.send_message(http.http(), |e| {
e.content(format!("Rating change for {}!", user_id.mention()))
.embed(|c| {
crate::embed::rating_change_embed(
&rc,
&info,
&contest,
&user_id.mention(),
c,
)
})
})
})
.collect::<stream::FuturesUnordered<_>>()
.map(|v| v.map(|_| ()))
.try_collect::<()>()
.await?;
let r: Result<_> = Ok(());
r
}
})
.collect::<stream::FuturesUnordered<_>>()
.try_collect::<()>()
.await?;
Ok(()) Ok(())
} }

View file

@ -1,13 +1,13 @@
use chrono::{TimeZone, Utc}; use chrono::{TimeZone, Utc};
use codeforces::{Contest, Problem}; use codeforces::{Client, Contest, Problem};
use dashmap::DashMap as HashMap;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rayon::{iter::Either, prelude::*};
use regex::{Captures, Regex}; use regex::{Captures, Regex};
use serenity::{ use serenity::{
builder::CreateEmbed, framework::standard::CommandError, model::channel::Message, builder::CreateEmbed, framework::standard::CommandError, model::channel::Message,
utils::MessageBuilder, utils::MessageBuilder,
}; };
use std::{collections::HashMap, sync::Arc}; use std::{sync::Arc, time::Instant};
use youmubot_prelude::*; use youmubot_prelude::*;
lazy_static! { lazy_static! {
@ -27,106 +27,132 @@ enum ContestOrProblem {
} }
/// Caches the contest list. /// Caches the contest list.
#[derive(Clone, Debug, Default)] pub struct ContestCache {
pub struct ContestCache(Arc<RwLock<HashMap<u64, (Contest, Option<Vec<Problem>>)>>>); contests: HashMap<u64, (Contest, Option<Vec<Problem>>)>,
all_list: RwLock<(Vec<Contest>, Instant)>,
http: Arc<Client>,
}
impl TypeMapKey for ContestCache { impl TypeMapKey for ContestCache {
type Value = ContestCache; type Value = ContestCache;
} }
impl ContestCache { impl ContestCache {
fn get( /// Creates a new, empty cache.
&self, pub async fn new(http: Arc<Client>) -> Result<Self> {
http: &<HTTPClient as TypeMapKey>::Value, let contests_list = Contest::list(&*http, true).await?;
contest_id: u64, Ok(Self {
) -> Result<(Contest, Option<Vec<Problem>>), CommandError> { contests: HashMap::new(),
let rl = self.0.read(); all_list: RwLock::new((contests_list, Instant::now())),
match rl.get(&contest_id) { http,
Some(r @ (_, Some(_))) => Ok(r.clone()), })
Some((c, None)) => match Contest::standings(http, contest_id, |f| f.limit(1, 1)) { }
Ok((c, p, _)) => Ok({
drop(rl); /// Gets a contest from the cache, fetching from upstream if possible.
let mut v = self.0.write(); pub async fn get(&self, contest_id: u64) -> Result<(Contest, Option<Vec<Problem>>)> {
let v = v.entry(contest_id).or_insert((c, None)); if let Some(v) = self.contests.get(&contest_id) {
v.1 = Some(p); if v.1.is_some() {
v.clone() return Ok(v.clone());
}),
Err(_) => Ok((c.clone(), None)),
},
None => {
drop(rl);
// Step 1: try to fetch it individually
match Contest::standings(http, contest_id, |f| f.limit(1, 1)) {
Ok((c, p, _)) => Ok(self
.0
.write()
.entry(contest_id)
.or_insert((c, Some(p)))
.clone()),
Err(codeforces::Error::Codeforces(s)) if s.ends_with("has not started") => {
// Fetch the entire list
{
let mut m = self.0.write();
let contests = Contest::list(http, contest_id > 100_000)?;
contests.into_iter().for_each(|c| {
m.entry(c.id).or_insert((c, None));
});
}
self.0
.read()
.get(&contest_id)
.cloned()
.ok_or("No contest found".into())
}
Err(e) => Err(e.into()),
}
// Step 2: try to fetch the entire list.
} }
} }
self.get_and_store_contest(contest_id).await
}
async fn get_and_store_contest(
&self,
contest_id: u64,
) -> Result<(Contest, Option<Vec<Problem>>)> {
let (c, p) = match Contest::standings(&*self.http, contest_id, |f| f.limit(1, 1)).await {
Ok((c, p, _)) => (c, Some(p)),
Err(codeforces::Error::Codeforces(s)) if s.ends_with("has not started") => {
let c = self.get_from_list(contest_id).await?;
(c, None)
}
Err(v) => return Err(Error::from(v)),
};
self.contests.insert(contest_id, (c, p));
Ok(self.contests.get(&contest_id).unwrap().clone())
}
async fn get_from_list(&self, contest_id: u64) -> Result<Contest> {
let last_updated = self.all_list.read().await.1.clone();
if Instant::now() - last_updated > std::time::Duration::from_secs(60 * 60) {
// We update at most once an hour.
*self.all_list.write().await =
(Contest::list(&*self.http, true).await?, Instant::now());
}
self.all_list
.read()
.await
.0
.iter()
.find(|v| v.id == contest_id)
.cloned()
.ok_or_else(|| Error::msg("Contest not found"))
} }
} }
/// Prints info whenever a problem or contest (or more) is sent on a channel. /// Prints info whenever a problem or contest (or more) is sent on a channel.
pub fn codeforces_info_hook(ctx: &mut Context, m: &Message) { pub struct InfoHook;
if m.author.bot {
return; #[async_trait]
impl Hook for InfoHook {
async fn call(&mut self, ctx: &Context, m: &Message) -> Result<()> {
if m.author.bot {
return Ok(());
}
let data = ctx.data.read().await;
let contest_cache = data.get::<ContestCache>().unwrap();
let matches = parse(&m.content[..], contest_cache)
.collect::<Vec<_>>()
.await;
if !matches.is_empty() {
m.channel_id
.send_message(&ctx, |c| {
c.content("Here are the info of the given Codeforces links!")
.embed(|e| print_info_message(&matches[..], e))
})
.await?;
}
Ok(())
} }
let http = ctx.data.get_cloned::<HTTPClient>(); }
let contest_cache = ctx.data.get_cloned::<ContestCache>();
fn parse<'a>(
content: &'a str,
contest_cache: &'a ContestCache,
) -> impl stream::Stream<Item = (ContestOrProblem, &'a str)> + 'a {
let matches = CONTEST_LINK let matches = CONTEST_LINK
.captures_iter(&m.content) .captures_iter(content)
.chain(PROBLEMSET_LINK.captures_iter(&m.content)) .chain(PROBLEMSET_LINK.captures_iter(content))
// .collect::<Vec<_>>() .map(|v| parse_capture(contest_cache, v))
// .into_par_iter() .collect::<stream::FuturesUnordered<_>>()
.filter_map( .filter_map(|v| future::ready(v.ok()));
|v| match parse_capture(http.clone(), contest_cache.clone(), v) { matches
Ok(v) => Some(v),
Err(e) => {
dbg!(e);
None
}
},
)
.collect::<Vec<_>>();
if !matches.is_empty() {
m.channel_id
.send_message(&ctx, |c| {
c.content("Here are the info of the given Codeforces links!")
.embed(|e| print_info_message(&matches[..], e))
})
.ok();
}
} }
fn print_info_message<'a>( fn print_info_message<'a>(
info: &[(ContestOrProblem, &str)], info: &[(ContestOrProblem, &str)],
e: &'a mut CreateEmbed, e: &'a mut CreateEmbed,
) -> &'a mut CreateEmbed { ) -> &'a mut CreateEmbed {
let (mut problems, contests): (Vec<_>, Vec<_>) = let (problems, contests): (Vec<_>, Vec<_>) = info.iter().partition(|(v, _)| match v {
info.par_iter().partition_map(|(v, l)| match v { ContestOrProblem::Problem(_) => true,
ContestOrProblem::Problem(p) => Either::Left((p, l)), ContestOrProblem::Contest(_, _) => false,
ContestOrProblem::Contest(c, p) => Either::Right((c, p, l)), });
}); let mut problems = problems
.into_iter()
.map(|(v, l)| match v {
ContestOrProblem::Problem(p) => (p, l),
_ => unreachable!(),
})
.collect::<Vec<_>>();
let contests = contests
.into_iter()
.map(|(v, l)| match v {
ContestOrProblem::Contest(c, p) => (c, p, l),
_ => unreachable!(),
})
.collect::<Vec<_>>();
problems.sort_by(|(a, _), (b, _)| a.rating.unwrap_or(1500).cmp(&b.rating.unwrap_or(1500))); problems.sort_by(|(a, _), (b, _)| a.rating.unwrap_or(1500).cmp(&b.rating.unwrap_or(1500)));
let mut m = MessageBuilder::new(); let mut m = MessageBuilder::new();
if !problems.is_empty() { if !problems.is_empty() {
@ -190,9 +216,8 @@ fn print_info_message<'a>(
e.description(m.build()) e.description(m.build())
} }
fn parse_capture<'a>( async fn parse_capture<'a>(
http: <HTTPClient as TypeMapKey>::Value, contest_cache: &ContestCache,
contest_cache: ContestCache,
cap: Captures<'a>, cap: Captures<'a>,
) -> Result<(ContestOrProblem, &'a str), CommandError> { ) -> Result<(ContestOrProblem, &'a str), CommandError> {
let contest_id: u64 = cap let contest_id: u64 = cap
@ -200,7 +225,7 @@ fn parse_capture<'a>(
.ok_or(CommandError::from("Contest not captured"))? .ok_or(CommandError::from("Contest not captured"))?
.as_str() .as_str()
.parse()?; .parse()?;
let (contest, problems) = contest_cache.get(&http, contest_id)?; let (contest, problems) = contest_cache.get(contest_id).await?;
match cap.name("problem") { match cap.name("problem") {
Some(p) => { Some(p) => {
for problem in problems.ok_or(CommandError::from("Contest hasn't started"))? { for problem in problems.ok_or(CommandError::from("Contest hasn't started"))? {

View file

@ -2,12 +2,12 @@ use codeforces::Contest;
use serenity::{ use serenity::{
framework::standard::{ framework::standard::{
macros::{command, group}, macros::{command, group},
Args, CommandError as Error, CommandResult, Args, CommandResult,
}, },
model::channel::Message, model::channel::Message,
utils::MessageBuilder, utils::MessageBuilder,
}; };
use std::{collections::HashMap, time::Duration}; use std::{collections::HashMap, sync::Arc, time::Duration};
use youmubot_prelude::*; use youmubot_prelude::*;
mod announcer; mod announcer;
@ -18,16 +18,26 @@ mod hook;
/// Live-commentating a Codeforces round. /// Live-commentating a Codeforces round.
mod live; mod live;
/// The TypeMapKey holding the Client.
struct CFClient;
impl TypeMapKey for CFClient {
type Value = Arc<codeforces::Client>;
}
use db::{CfSavedUsers, CfUser}; use db::{CfSavedUsers, CfUser};
pub use hook::codeforces_info_hook; pub use hook::InfoHook;
/// Sets up the CF databases. /// Sets up the CF databases.
pub fn setup(path: &std::path::Path, data: &mut ShareMap, announcers: &mut AnnouncerHandler) { pub async fn setup(path: &std::path::Path, data: &mut TypeMap, announcers: &mut AnnouncerHandler) {
CfSavedUsers::insert_into(data, path.join("cf_saved_users.yaml")) CfSavedUsers::insert_into(data, path.join("cf_saved_users.yaml"))
.expect("Must be able to set up DB"); .expect("Must be able to set up DB");
data.insert::<hook::ContestCache>(hook::ContestCache::default()); let http = data.get::<HTTPClient>().unwrap();
announcers.add("codeforces", announcer::updates); let client = Arc::new(codeforces::Client::new(http.clone()));
data.insert::<hook::ContestCache>(hook::ContestCache::new(client.clone()).await.unwrap());
data.insert::<CFClient>(client);
announcers.add("codeforces", announcer::Announcer);
} }
#[group] #[group]
@ -43,40 +53,46 @@ pub struct Codeforces;
#[usage = "[handle or tag = yourself]"] #[usage = "[handle or tag = yourself]"]
#[example = "natsukagami"] #[example = "natsukagami"]
#[max_args(1)] #[max_args(1)]
pub fn profile(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn profile(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let handle = args let handle = args
.single::<UsernameArg>() .single::<UsernameArg>()
.unwrap_or(UsernameArg::mention(m.author.id)); .unwrap_or(UsernameArg::mention(m.author.id));
let http = ctx.data.get_cloned::<HTTPClient>(); let http = data.get::<CFClient>().unwrap();
let handle = match handle { let handle = match handle {
UsernameArg::Raw(s) => s, UsernameArg::Raw(s) => s,
UsernameArg::Tagged(u) => { UsernameArg::Tagged(u) => {
let db = CfSavedUsers::open(&*ctx.data.read()); let db = CfSavedUsers::open(&*data);
let db = db.borrow()?; let user = db.borrow()?.get(&u).map(|u| u.handle.clone());
match db.get(&u) { match user {
Some(v) => v.handle.clone(), Some(v) => v,
None => { None => {
m.reply(&ctx, "no saved account found.")?; m.reply(&ctx, "no saved account found.").await?;
return Ok(()); return Ok(());
} }
} }
} }
}; };
let account = codeforces::User::info(&http, &[&handle[..]])? let account = codeforces::User::info(&http, &[&handle[..]])
.await?
.into_iter() .into_iter()
.next(); .next();
match account { match account {
Some(v) => m.channel_id.send_message(&ctx, |send| { Some(v) => {
send.content(format!( m.channel_id
"{}: Here is the user that you requested", .send_message(&ctx, |send| {
m.author.mention() send.content(format!(
)) "{}: Here is the user that you requested",
.embed(|e| embed::user_embed(&v, e)) m.author.mention()
}), ))
None => m.reply(&ctx, "User not found"), .embed(|e| embed::user_embed(&v, e))
})
.await
}
None => m.reply(&ctx, "User not found").await,
}?; }?;
Ok(()) Ok(())
@ -86,28 +102,32 @@ pub fn profile(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult
#[description = "Link your Codeforces account to the Discord account, to enjoy Youmu's tracking capabilities."] #[description = "Link your Codeforces account to the Discord account, to enjoy Youmu's tracking capabilities."]
#[usage = "[handle]"] #[usage = "[handle]"]
#[num_args(1)] #[num_args(1)]
pub fn save(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn save(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let handle = args.single::<String>()?; let handle = args.single::<String>()?;
let http = ctx.data.get_cloned::<HTTPClient>(); let http = data.get::<CFClient>().unwrap();
let account = codeforces::User::info(&http, &[&handle[..]])? let account = codeforces::User::info(&http, &[&handle[..]])
.await?
.into_iter() .into_iter()
.next(); .next();
match account { match account {
None => { None => {
m.reply(&ctx, "cannot find an account with such handle")?; m.reply(&ctx, "cannot find an account with such handle")
.await?;
} }
Some(acc) => { Some(acc) => {
// Collect rating changes data. // Collect rating changes data.
let rating_changes = acc.rating_changes(&http)?; let rating_changes = acc.rating_changes(&http).await?;
let db = CfSavedUsers::open(&*ctx.data.read()); let mut db = CfSavedUsers::open(&*data);
let mut db = db.borrow_mut()?;
m.reply( m.reply(
&ctx, &ctx,
format!("account `{}` has been linked to your account.", &acc.handle), format!("account `{}` has been linked to your account.", &acc.handle),
)?; )
db.insert(m.author.id, CfUser::save(acc, rating_changes)); .await?;
db.borrow_mut()?
.insert(m.author.id, CfUser::save(acc, rating_changes));
} }
} }
@ -118,9 +138,10 @@ pub fn save(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult {
#[description = "See the leaderboard of all people in the server."] #[description = "See the leaderboard of all people in the server."]
#[only_in(guilds)] #[only_in(guilds)]
#[num_args(0)] #[num_args(0)]
pub fn ranks(ctx: &mut Context, m: &Message) -> CommandResult { pub async fn ranks(ctx: &Context, m: &Message) -> CommandResult {
let data = ctx.data.read().await;
let everyone = { let everyone = {
let db = CfSavedUsers::open(&*ctx.data.read()); let db = CfSavedUsers::open(&*data);
let db = db.borrow()?; let db = db.borrow()?;
db.iter() db.iter()
.map(|(k, v)| (k.clone(), v.clone())) .map(|(k, v)| (k.clone(), v.clone()))
@ -129,84 +150,98 @@ pub fn ranks(ctx: &mut Context, m: &Message) -> CommandResult {
let guild = m.guild_id.expect("Guild-only command"); let guild = m.guild_id.expect("Guild-only command");
let mut ranks = everyone let mut ranks = everyone
.into_iter() .into_iter()
.filter_map(|(id, cf_user)| guild.member(&ctx, id).ok().map(|mem| (mem, cf_user))) .map(|(id, cf_user)| {
.collect::<Vec<_>>(); guild
.member(&ctx, id)
.map(|mem| mem.map(|mem| (mem, cf_user)))
})
.collect::<stream::FuturesUnordered<_>>()
.filter_map(|v| future::ready(v.ok()))
.collect::<Vec<_>>()
.await;
ranks.sort_by(|(_, a), (_, b)| b.rating.unwrap_or(-1).cmp(&a.rating.unwrap_or(-1))); ranks.sort_by(|(_, a), (_, b)| b.rating.unwrap_or(-1).cmp(&a.rating.unwrap_or(-1)));
if ranks.is_empty() { if ranks.is_empty() {
m.reply(&ctx, "No saved users in this server.")?; m.reply(&ctx, "No saved users in this server.").await?;
return Ok(()); return Ok(());
} }
let ranks = Arc::new(ranks);
const ITEMS_PER_PAGE: usize = 10; const ITEMS_PER_PAGE: usize = 10;
let total_pages = (ranks.len() + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE; let total_pages = (ranks.len() + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE;
let last_updated = ranks.iter().map(|(_, cfu)| cfu.last_update).min().unwrap(); let last_updated = ranks.iter().map(|(_, cfu)| cfu.last_update).min().unwrap();
ctx.data.get_cloned::<ReactionWatcher>().paginate_fn( paginate(
ctx.clone(), move |page, ctx, msg| {
m.channel_id, let ranks = ranks.clone();
move |page, e| { Box::pin(async move {
let page = page as usize; let page = page as usize;
let start = ITEMS_PER_PAGE * page; let start = ITEMS_PER_PAGE * page;
let end = ranks.len().min(start + ITEMS_PER_PAGE); let end = ranks.len().min(start + ITEMS_PER_PAGE);
if start >= end { if start >= end {
return (e, Err(Error::from("No more pages"))); return Ok(false);
} }
let ranks = &ranks[start..end]; let ranks = &ranks[start..end];
let handle_width = ranks.iter().map(|(_, cfu)| cfu.handle.len()).max().unwrap(); let handle_width = ranks.iter().map(|(_, cfu)| cfu.handle.len()).max().unwrap();
let username_width = ranks let username_width = ranks
.iter() .iter()
.map(|(mem, _)| mem.distinct().len()) .map(|(mem, _)| mem.distinct().len())
.max() .max()
.unwrap(); .unwrap();
let mut m = MessageBuilder::new(); let mut m = MessageBuilder::new();
m.push_line("```"); m.push_line("```");
// Table header // Table header
m.push_line(format!(
"Rank | Rating | {:hw$} | {:uw$}",
"Handle",
"Username",
hw = handle_width,
uw = username_width
));
m.push_line(format!(
"----------------{:->hw$}---{:->uw$}",
"",
"",
hw = handle_width,
uw = username_width
));
for (id, (mem, cfu)) in ranks.iter().enumerate() {
let id = id + start + 1;
m.push_line(format!( m.push_line(format!(
"{:>4} | {:>6} | {:hw$} | {:uw$}", "Rank | Rating | {:hw$} | {:uw$}",
format!("#{}", id), "Handle",
cfu.rating "Username",
.map(|v| v.to_string()) hw = handle_width,
.unwrap_or("----".to_owned()), uw = username_width
cfu.handle, ));
mem.distinct(), m.push_line(format!(
"----------------{:->hw$}---{:->uw$}",
"",
"",
hw = handle_width, hw = handle_width,
uw = username_width uw = username_width
)); ));
}
m.push_line("```"); for (id, (mem, cfu)) in ranks.iter().enumerate() {
m.push(format!( let id = id + start + 1;
"Page **{}/{}**. Last updated **{}**", m.push_line(format!(
page + 1, "{:>4} | {:>6} | {:hw$} | {:uw$}",
total_pages, format!("#{}", id),
last_updated.to_rfc2822() cfu.rating
)); .map(|v| v.to_string())
.unwrap_or("----".to_owned()),
cfu.handle,
mem.distinct(),
hw = handle_width,
uw = username_width
));
}
(e.content(m.build()), Ok(())) m.push_line("```");
m.push(format!(
"Page **{}/{}**. Last updated **{}**",
page + 1,
total_pages,
last_updated.to_rfc2822()
));
msg.edit(ctx, |f| f.content(m.build())).await?;
Ok(true)
})
}, },
ctx,
m.channel_id,
std::time::Duration::from_secs(60), std::time::Duration::from_secs(60),
)?; )
.await?;
Ok(()) Ok(())
} }
@ -216,23 +251,27 @@ pub fn ranks(ctx: &mut Context, m: &Message) -> CommandResult {
#[usage = "[the contest id]"] #[usage = "[the contest id]"]
#[num_args(1)] #[num_args(1)]
#[only_in(guilds)] #[only_in(guilds)]
pub fn contestranks(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn contestranks(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let contest_id: u64 = args.single()?; let contest_id: u64 = args.single()?;
let guild = m.guild_id.unwrap(); // Guild-only command let guild = m.guild_id.unwrap(); // Guild-only command
let members = CfSavedUsers::open(&*ctx.data.read()).borrow()?.clone(); let members = CfSavedUsers::open(&*data).borrow()?.clone();
let members = members let members = members
.into_iter() .into_iter()
.filter_map(|(user_id, cf_user)| { .map(|(user_id, cf_user)| {
guild guild
.member(&ctx, user_id) .member(&ctx, user_id)
.ok() .map(|v| v.map(|v| (cf_user.handle, v)))
.map(|v| (cf_user.handle, v))
}) })
.collect::<HashMap<_, _>>(); .collect::<stream::FuturesUnordered<_>>()
let http = ctx.data.get_cloned::<HTTPClient>(); .filter_map(|v| future::ready(v.ok()))
let (contest, problems, ranks) = Contest::standings(&http, contest_id, |f| { .collect::<HashMap<_, _>>()
.await;
let http = data.get::<CFClient>().unwrap();
let (contest, problems, ranks) = Contest::standings(http, contest_id, |f| {
f.handles(members.iter().map(|(k, _)| k.clone()).collect()) f.handles(members.iter().map(|(k, _)| k.clone()).collect())
})?; })
.await?;
// Table me // Table me
let ranks = ranks let ranks = ranks
@ -252,100 +291,111 @@ pub fn contestranks(ctx: &mut Context, m: &Message, mut args: Args) -> CommandRe
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if ranks.is_empty() { if ranks.is_empty() {
m.reply(&ctx, "No one in this server participated in the contest...")?; m.reply(&ctx, "No one in this server participated in the contest...")
.await?;
return Ok(()); return Ok(());
} }
let ranks = Arc::new(ranks);
const ITEMS_PER_PAGE: usize = 10; const ITEMS_PER_PAGE: usize = 10;
let total_pages = (ranks.len() + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE; let total_pages = (ranks.len() + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE;
ctx.data.get_cloned::<ReactionWatcher>().paginate_fn( paginate(
ctx.clone(), move |page, ctx, msg| {
m.channel_id, let contest = contest.clone();
move |page, e| { let problems = problems.clone();
let page = page as usize; let ranks = ranks.clone();
let start = page * ITEMS_PER_PAGE; Box::pin(async move {
let end = ranks.len().min(start + ITEMS_PER_PAGE); let page = page as usize;
if start >= end { let start = page * ITEMS_PER_PAGE;
return (e, Err(Error::from("no more pages to show"))); let end = ranks.len().min(start + ITEMS_PER_PAGE);
} if start >= end {
let ranks = &ranks[start..end]; return Ok(false);
let hw = ranks }
.iter() let ranks = &ranks[start..end];
.map(|(mem, handle, _)| format!("{} ({})", handle, mem.distinct()).len()) let hw = ranks
.max() .iter()
.unwrap_or(0) .map(|(mem, handle, _)| format!("{} ({})", handle, mem.distinct()).len())
.max(6); .max()
let hackw = ranks .unwrap_or(0)
.iter() .max(6);
.map(|(_, _, row)| { let hackw = ranks
format!( .iter()
"{}/{}", .map(|(_, _, row)| {
row.successful_hack_count, row.unsuccessful_hack_count format!(
) "{}/{}",
.len() row.successful_hack_count, row.unsuccessful_hack_count
}) )
.max() .len()
.unwrap_or(0) })
.max(5); .max()
.unwrap_or(0)
.max(5);
let mut table = MessageBuilder::new(); let mut table = MessageBuilder::new();
let mut header = MessageBuilder::new(); let mut header = MessageBuilder::new();
// Header // Header
header.push(format!( header.push(format!(
" Rank | {:hw$} | Total | {:hackw$}", " Rank | {:hw$} | Total | {:hackw$}",
"Handle", "Handle",
"Hacks", "Hacks",
hw = hw,
hackw = hackw
));
for p in &problems {
header.push(format!(" | {:4}", p.index));
}
let header = header.build();
table
.push_line(&header)
.push_line(format!("{:-<w$}", "", w = header.len()));
// Body
for (mem, handle, row) in ranks {
table.push(format!(
"{:>5} | {:<hw$} | {:>5.0} | {:<hackw$}",
row.rank,
format!("{} ({})", handle, mem.distinct()),
row.points,
format!(
"{}/{}",
row.successful_hack_count, row.unsuccessful_hack_count
),
hw = hw, hw = hw,
hackw = hackw hackw = hackw
)); ));
for p in &row.problem_results { for p in &problems {
table.push(" | "); header.push(format!(" | {:4}", p.index));
if p.points > 0.0 {
table.push(format!("{:^4.0}", p.points));
} else if let Some(_) = p.best_submission_time_seconds {
table.push(format!("{:^4}", "?"));
} else if p.rejected_attempt_count > 0 {
table.push(format!("{:^4}", format!("-{}", p.rejected_attempt_count)));
} else {
table.push(format!("{:^4}", ""));
}
} }
table.push_line(""); let header = header.build();
} table
.push_line(&header)
.push_line(format!("{:-<w$}", "", w = header.len()));
let mut m = MessageBuilder::new(); // Body
m.push_bold_safe(&contest.name) for (mem, handle, row) in ranks {
.push(" ") table.push(format!(
.push_line(contest.url()) "{:>5} | {:<hw$} | {:>5.0} | {:<hackw$}",
.push_codeblock(table.build(), None) row.rank,
.push_line(format!("Page **{}/{}**", page + 1, total_pages)); format!("{} ({})", handle, mem.distinct()),
(e.content(m.build()), Ok(())) row.points,
format!(
"{}/{}",
row.successful_hack_count, row.unsuccessful_hack_count
),
hw = hw,
hackw = hackw
));
for p in &row.problem_results {
table.push(" | ");
if p.points > 0.0 {
table.push(format!("{:^4.0}", p.points));
} else if let Some(_) = p.best_submission_time_seconds {
table.push(format!("{:^4}", "?"));
} else if p.rejected_attempt_count > 0 {
table.push(format!("{:^4}", format!("-{}", p.rejected_attempt_count)));
} else {
table.push(format!("{:^4}", ""));
}
}
table.push_line("");
}
let mut m = MessageBuilder::new();
m.push_bold_safe(&contest.name)
.push(" ")
.push_line(contest.url())
.push_codeblock(table.build(), None)
.push_line(format!("Page **{}/{}**", page + 1, total_pages));
msg.edit(ctx, |e| e.content(m.build())).await?;
Ok(true)
})
}, },
ctx,
m.channel_id,
Duration::from_secs(60), Duration::from_secs(60),
) )
.await?;
Ok(())
} }
#[command] #[command]
@ -354,10 +404,10 @@ pub fn contestranks(ctx: &mut Context, m: &Message, mut args: Args) -> CommandRe
#[num_args(1)] #[num_args(1)]
#[required_permissions(MANAGE_CHANNELS)] #[required_permissions(MANAGE_CHANNELS)]
#[only_in(guilds)] #[only_in(guilds)]
pub fn watch(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn watch(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let contest_id: u64 = args.single()?; let contest_id: u64 = args.single()?;
live::watch_contest(ctx, m.guild_id.unwrap(), m.channel_id, contest_id)?; live::watch_contest(ctx, m.guild_id.unwrap(), m.channel_id, contest_id).await?;
Ok(()) Ok(())
} }

View file

@ -1,8 +1,6 @@
use crate::db::CfSavedUsers; use crate::{db::CfSavedUsers, CFClient};
use codeforces::{Contest, ContestPhase, Problem, ProblemResult, ProblemResultType, RanklistRow}; use codeforces::{Contest, ContestPhase, Problem, ProblemResult, ProblemResultType, RanklistRow};
use rayon::prelude::*;
use serenity::{ use serenity::{
framework::standard::{CommandError, CommandResult},
model::{ model::{
guild::Member, guild::Member,
id::{ChannelId, GuildId, UserId}, id::{ChannelId, GuildId, UserId},
@ -21,56 +19,64 @@ struct MemberResult {
/// Watch and commentate a contest. /// Watch and commentate a contest.
/// ///
/// Does the thing on a channel, block until the contest ends. /// Does the thing on a channel, block until the contest ends.
pub fn watch_contest( pub async fn watch_contest(
ctx: &mut Context, ctx: &Context,
guild: GuildId, guild: GuildId,
channel: ChannelId, channel: ChannelId,
contest_id: u64, contest_id: u64,
) -> CommandResult { ) -> Result<()> {
let db = CfSavedUsers::open(&*ctx.data.read()).borrow()?.clone(); let data = ctx.data.read().await;
let db = CfSavedUsers::open(&*data).borrow()?.clone();
let http = ctx.http.clone(); let http = ctx.http.clone();
// Collect an initial member list. // Collect an initial member list.
// This never changes during the scan. // This never changes during the scan.
let mut member_results: HashMap<UserId, MemberResult> = db let mut member_results: HashMap<UserId, MemberResult> = db
.into_par_iter() .into_iter()
.filter_map(|(user_id, cfu)| { .map(|(user_id, cfu)| {
let member = guild.member(http.clone().as_ref(), user_id).ok(); let http = http.clone();
match member { async move {
Some(m) => Some(( guild.member(http, user_id).await.map(|m| {
user_id, (
MemberResult { user_id,
member: m, MemberResult {
handle: cfu.handle, member: m,
row: None, handle: cfu.handle,
}, row: None,
)), },
None => None, )
})
} }
}) })
.collect(); .collect::<stream::FuturesUnordered<_>>()
.filter_map(|v| future::ready(v.ok()))
.collect()
.await;
let http = ctx.data.get_cloned::<HTTPClient>(); let http = data.get::<CFClient>().unwrap();
let (mut contest, _, _) = Contest::standings(&http, contest_id, |f| f.limit(1, 1))?; let (mut contest, _, _) = Contest::standings(&http, contest_id, |f| f.limit(1, 1)).await?;
channel.send_message(&ctx, |e| { channel
e.content(format!( .send_message(&ctx, |e| {
"Youmu is watching contest **{}**, with the following members:\n{}", e.content(format!(
contest.name, "Youmu is watching contest **{}**, with the following members:\n{}",
member_results contest.name,
.iter() member_results
.map(|(_, m)| format!("- {} as **{}**", m.member.distinct(), m.handle)) .iter()
.collect::<Vec<_>>() .map(|(_, m)| format!("- {} as **{}**", m.member.distinct(), m.handle))
.join("\n"), .collect::<Vec<_>>()
)) .join("\n"),
})?; ))
})
.await?;
loop { loop {
if let Ok(messages) = scan_changes(http.clone(), &mut member_results, &mut contest) { if let Ok(messages) = scan_changes(&*http, &mut member_results, &mut contest).await {
for message in messages { for message in messages {
channel channel
.send_message(&ctx, |e| { .send_message(&ctx, |e| {
e.content(format!("**{}**: {}", contest.name, message)) e.content(format!("**{}**: {}", contest.name, message))
}) })
.await
.ok(); .ok();
} }
} }
@ -78,7 +84,7 @@ pub fn watch_contest(
break; break;
} }
// Sleep for a minute // Sleep for a minute
std::thread::sleep(std::time::Duration::from_secs(60)); tokio::time::delay_for(std::time::Duration::from_secs(60)).await;
} }
// Announce the final results // Announce the final results
@ -93,12 +99,14 @@ pub fn watch_contest(
ranks.sort_by(|(_, a), (_, b)| a.rank.cmp(&b.rank)); ranks.sort_by(|(_, a), (_, b)| a.rank.cmp(&b.rank));
if ranks.is_empty() { if ranks.is_empty() {
channel.send_message(&ctx, |e| { channel
e.content(format!( .send_message(&ctx, |e| {
"**{}** has ended, but I can't find anyone in this server on the scoreboard...", e.content(format!(
contest.name "**{}** has ended, but I can't find anyone in this server on the scoreboard...",
)) contest.name
})?; ))
})
.await?;
return Ok(()); return Ok(());
} }
@ -115,23 +123,23 @@ pub fn watch_contest(
row.problem_results.iter().map(|p| format!("{:.0}", p.points)).collect::<Vec<_>>().join("/"), row.problem_results.iter().map(|p| format!("{:.0}", p.points)).collect::<Vec<_>>().join("/"),
row.successful_hack_count, row.successful_hack_count,
row.unsuccessful_hack_count, row.unsuccessful_hack_count,
)).collect::<Vec<_>>().join("\n"))))?; )).collect::<Vec<_>>().join("\n")))).await?;
Ok(()) Ok(())
} }
fn scan_changes( async fn scan_changes(
http: <HTTPClient as TypeMapKey>::Value, http: &codeforces::Client,
members: &mut HashMap<UserId, MemberResult>, members: &mut HashMap<UserId, MemberResult>,
contest: &mut Contest, contest: &mut Contest,
) -> Result<Vec<String>, CommandError> { ) -> Result<Vec<String>> {
let mut messages: Vec<String> = vec![]; let mut messages: Vec<String> = vec![];
let (updated_contest, problems, ranks) = { let (updated_contest, problems, ranks) = {
let handles = members let handles = members
.iter() .iter()
.map(|(_, h)| h.handle.clone()) .map(|(_, h)| h.handle.clone())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Contest::standings(&http, contest.id, |f| f.handles(handles))? Contest::standings(&http, contest.id, |f| f.handles(handles)).await?
}; };
// Change of phase. // Change of phase.
if contest.phase != updated_contest.phase { if contest.phase != updated_contest.phase {

View file

@ -7,11 +7,13 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
serenity = "0.8" serenity = { version = "0.9.0-rc.0", features = ["collector"] }
rand = "0.7" rand = "0.7"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
chrono = "0.4" chrono = "0.4"
static_assertions = "1.1" static_assertions = "1.1"
futures-util = "0.3"
tokio = { version = "0.2", features = ["time"] }
youmubot-db = { path = "../youmubot-db" } youmubot-db = { path = "../youmubot-db" }
youmubot-prelude = { path = "../youmubot-prelude" } youmubot-prelude = { path = "../youmubot-prelude" }

View file

@ -1,3 +1,4 @@
use futures_util::{stream, TryStreamExt};
use serenity::{ use serenity::{
framework::standard::{ framework::standard::{
macros::{command, group}, macros::{command, group},
@ -9,7 +10,6 @@ use serenity::{
}, },
}; };
use soft_ban::{SOFT_BAN_COMMAND, SOFT_BAN_INIT_COMMAND}; use soft_ban::{SOFT_BAN_COMMAND, SOFT_BAN_INIT_COMMAND};
use std::{thread::sleep, time::Duration};
use youmubot_prelude::*; use youmubot_prelude::*;
mod soft_ban; mod soft_ban;
@ -27,29 +27,34 @@ struct Admin;
#[usage = "clean 50"] #[usage = "clean 50"]
#[min_args(0)] #[min_args(0)]
#[max_args(1)] #[max_args(1)]
fn clean(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { async fn clean(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let limit = args.single().unwrap_or(10); let limit = args.single().unwrap_or(10);
let messages = msg let messages = msg
.channel_id .channel_id
.messages(&ctx.http, |b| b.before(msg.id).limit(limit))?; .messages(&ctx.http, |b| b.before(msg.id).limit(limit))
let channel = msg.channel_id.to_channel(&ctx)?; .await?;
let channel = msg.channel_id.to_channel(&ctx).await?;
match &channel { match &channel {
Channel::Private(_) | Channel::Group(_) => { Channel::Private(_) => {
let self_id = ctx.http.get_current_application_info()?.id; let self_id = ctx.http.get_current_application_info().await?.id;
messages messages
.into_iter() .into_iter()
.filter(|v| v.author.id == self_id) .filter(|v| v.author.id == self_id)
.try_for_each(|m| m.delete(&ctx))?; .map(|m| async move { m.delete(&ctx).await })
.collect::<stream::FuturesUnordered<_>>()
.try_collect::<()>()
.await?;
} }
_ => { _ => {
msg.channel_id msg.channel_id
.delete_messages(&ctx.http, messages.into_iter())?; .delete_messages(&ctx.http, messages.into_iter())
.await?;
} }
}; };
msg.react(&ctx, "🌋")?; msg.react(&ctx, '🌋').await?;
if let Channel::Guild(_) = &channel { if let Channel::Guild(_) = &channel {
sleep(Duration::from_secs(2)); tokio::time::delay_for(std::time::Duration::from_secs(2)).await;
msg.delete(&ctx)?; msg.delete(&ctx).await.ok();
} }
Ok(()) Ok(())
@ -58,25 +63,36 @@ fn clean(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
#[command] #[command]
#[required_permissions(ADMINISTRATOR)] #[required_permissions(ADMINISTRATOR)]
#[description = "Ban an user with a certain reason."] #[description = "Ban an user with a certain reason."]
#[usage = "@user#1234/spam"] #[usage = "tag user/[reason = none]/[days of messages to delete = 0]"]
#[min_args(1)] #[min_args(1)]
#[max_args(2)] #[max_args(2)]
#[only_in("guilds")] #[only_in("guilds")]
fn ban(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { async fn ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user = args.single::<UserId>()?.to_user(&ctx)?; let user = args.single::<UserId>()?.to_user(&ctx).await?;
let reason = args let reason = args.single::<String>().map(|v| format!("`{}`", v)).ok();
.remains() let dmds = args.single::<u8>().unwrap_or(0);
.map(|v| format!("`{}`", v))
.unwrap_or("no provided reason".to_owned());
msg.reply( match reason {
&ctx, Some(reason) => {
format!("🔨 Banning user {} for reason `{}`.", user.tag(), reason), msg.reply(
)?; &ctx,
format!("🔨 Banning user {} for reason `{}`.", user.tag(), reason),
msg.guild_id )
.ok_or("Can't get guild from message?")? // we had a contract .await?;
.ban(&ctx.http, user, &reason)?; msg.guild_id
.ok_or(Error::msg("Can't get guild from message?"))? // we had a contract
.ban_with_reason(&ctx.http, user, dmds, &reason)
.await?;
}
None => {
msg.reply(&ctx, format!("🔨 Banning user {}.", user.tag()))
.await?;
msg.guild_id
.ok_or(Error::msg("Can't get guild from message?"))? // we had a contract
.ban(&ctx.http, user, dmds)
.await?;
}
}
Ok(()) Ok(())
} }
@ -87,14 +103,16 @@ fn ban(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
#[usage = "@user#1234"] #[usage = "@user#1234"]
#[num_args(1)] #[num_args(1)]
#[only_in("guilds")] #[only_in("guilds")]
fn kick(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { async fn kick(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user = args.single::<UserId>()?.to_user(&ctx)?; let user = args.single::<UserId>()?.to_user(&ctx).await?;
msg.reply(&ctx, format!("🔫 Kicking user {}.", user.tag()))?; msg.reply(&ctx, format!("🔫 Kicking user {}.", user.tag()))
.await?;
msg.guild_id msg.guild_id
.ok_or("Can't get guild from message?")? // we had a contract .ok_or("Can't get guild from message?")? // we had a contract
.kick(&ctx.http, user)?; .kick(&ctx.http, user)
.await?;
Ok(()) Ok(())
} }

View file

@ -1,13 +1,15 @@
use crate::db::{ServerSoftBans, SoftBans}; use crate::db::{ServerSoftBans, SoftBans};
use chrono::offset::Utc; use chrono::offset::Utc;
use futures_util::{stream, TryStreamExt};
use serenity::{ use serenity::{
framework::standard::{macros::command, Args, CommandError as Error, CommandResult}, framework::standard::{macros::command, Args, CommandResult},
model::{ model::{
channel::Message, channel::Message,
id::{RoleId, UserId}, id::{GuildId, RoleId, UserId},
}, },
CacheAndHttp,
}; };
use std::cmp::max; use std::sync::Arc;
use youmubot_prelude::*; use youmubot_prelude::*;
#[command] #[command]
@ -18,57 +20,56 @@ use youmubot_prelude::*;
#[min_args(1)] #[min_args(1)]
#[max_args(2)] #[max_args(2)]
#[only_in("guilds")] #[only_in("guilds")]
pub fn soft_ban(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user = args.single::<UserId>()?.to_user(&ctx)?; let user = args.single::<UserId>()?.to_user(&ctx).await?;
let data = ctx.data.read().await;
let duration = if args.is_empty() { let duration = if args.is_empty() {
None None
} else { } else {
Some( Some(args.single::<args::Duration>()?)
args.single::<args::Duration>()
.map_err(|e| Error::from(&format!("{:?}", e)))?,
)
}; };
let guild = msg.guild_id.ok_or(Error::from("Command is guild only"))?; let guild = msg.guild_id.ok_or(Error::msg("Command is guild only"))?;
let db = SoftBans::open(&*ctx.data.read()); let mut db = SoftBans::open(&*data);
let mut db = db.borrow_mut()?; let val = db
let mut server_ban = db.get_mut(&guild).and_then(|v| match v { .borrow()?
ServerSoftBans::Unimplemented => None, .get(&guild)
ServerSoftBans::Implemented(ref mut v) => Some(v), .map(|v| (v.role, v.periodical_bans.get(&user.id).cloned()));
}); let (role, current_ban_deadline) = match val {
match server_ban {
None => { None => {
println!("get here"); msg.reply(&ctx, format!("⚠ This server has not enabled the soft-ban feature. Check out `y!a soft-ban-init`.")).await?;
msg.reply(&ctx, format!("⚠ This server has not enabled the soft-ban feature. Check out `y!a soft-ban-init`."))?; return Ok(());
} }
Some(ref mut server_ban) => { Some(v) => v,
let mut member = guild.member(&ctx, &user)?; };
match duration {
None if member.roles.contains(&server_ban.role) => { let mut member = guild.member(&ctx, &user).await?;
msg.reply(&ctx, format!("⛓ Lifting soft-ban for user {}.", user.tag()))?; match duration {
member.remove_role(&ctx, server_ban.role)?; None if member.roles.contains(&role) => {
return Ok(()); msg.reply(&ctx, format!("⛓ Lifting soft-ban for user {}.", user.tag()))
} .await?;
None => { member.remove_role(&ctx, role).await?;
msg.reply(&ctx, format!("⛓ Soft-banning user {}.", user.tag()))?; return Ok(());
} }
Some(v) => { None => {
let until = Utc::now() + chrono::Duration::from_std(v.0)?; msg.reply(&ctx, format!("⛓ Soft-banning user {}.", user.tag()))
let until = server_ban .await?;
.periodical_bans }
.entry(user.id) Some(v) => {
.and_modify(|v| *v = max(*v, until)) // Add the duration into the ban timeout.
.or_insert(until); let until =
msg.reply( current_ban_deadline.unwrap_or(Utc::now()) + chrono::Duration::from_std(v.0)?;
&ctx, msg.reply(
format!("⛓ Soft-banning user {} until {}.", user.tag(), until), &ctx,
)?; format!("⛓ Soft-banning user {} until {}.", user.tag(), until),
} )
} .await?;
member.add_role(&ctx, server_ban.role)?; db.borrow_mut()?
.get_mut(&guild)
.map(|v| v.periodical_bans.insert(user.id, until));
} }
} }
member.add_role(&ctx, role).await?;
Ok(()) Ok(())
} }
@ -79,86 +80,90 @@ pub fn soft_ban(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResu
#[usage = "{soft_ban_role_id}"] #[usage = "{soft_ban_role_id}"]
#[num_args(1)] #[num_args(1)]
#[only_in("guilds")] #[only_in("guilds")]
pub fn soft_ban_init(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let role_id = args.single::<RoleId>()?; let role_id = args.single::<RoleId>()?;
let guild = msg.guild(&ctx).ok_or(Error::from("Guild-only command"))?; let data = ctx.data.read().await;
let guild = guild.read(); let guild = msg.guild(&ctx).await.unwrap();
// Check whether the role_id is the one we wanted // Check whether the role_id is the one we wanted
if !guild.roles.contains_key(&role_id) { if !guild.roles.contains_key(&role_id) {
return Err(Error::from(format!( Err(Error::msg(format!(
"{} is not a role in this server.", "{} is not a role in this server.",
role_id role_id
))); )))?;
} }
// Check if we already set up // Check if we already set up
let db = SoftBans::open(&*ctx.data.read()); let mut db = SoftBans::open(&*data);
let mut db = db.borrow_mut()?; let set_up = db.borrow()?.contains_key(&guild.id);
let server = db
.get(&guild.id)
.map(|v| match v {
ServerSoftBans::Unimplemented => false,
_ => true,
})
.unwrap_or(false);
if !server { if !set_up {
db.insert(guild.id, ServerSoftBans::new_implemented(role_id)); db.borrow_mut()?
msg.react(&ctx, "👌")?; .insert(guild.id, ServerSoftBans::new(role_id));
Ok(()) msg.react(&ctx, '👌').await?;
} else { } else {
Err(Error::from("Server already set up soft-bans.")) Err(Error::msg("Server already set up soft-bans."))?
}
Ok(())
}
// Watch the soft bans. Blocks forever.
pub async fn watch_soft_bans(cache_http: Arc<CacheAndHttp>, data: AppData) {
loop {
// Scope so that locks are released
{
// Poll the data for any changes.
let db = data.read().await;
let db = SoftBans::open(&*db);
let mut db = db.borrow().unwrap().clone();
let now = Utc::now();
for (server_id, bans) in db.iter_mut() {
let server_name: String = match server_id.to_partial_guild(&*cache_http.http).await
{
Err(_) => continue,
Ok(v) => v.name,
};
let to_remove: Vec<_> = bans
.periodical_bans
.iter()
.filter_map(|(user, time)| if time <= &now { Some(user) } else { None })
.cloned()
.collect();
if let Err(e) = to_remove
.into_iter()
.map(|user_id| {
bans.periodical_bans.remove(&user_id);
lift_soft_ban_for(
&*cache_http,
*server_id,
&server_name[..],
bans.role,
user_id,
)
})
.collect::<stream::FuturesUnordered<_>>()
.try_collect::<()>()
.await
{
eprintln!("Error while scanning soft-bans list: {}", e)
}
}
}
// Sleep the thread for a minute
tokio::time::delay_for(std::time::Duration::from_secs(60)).await
} }
} }
// Watch the soft bans. async fn lift_soft_ban_for(
pub fn watch_soft_bans(client: &serenity::Client) -> impl FnOnce() -> () + 'static { cache_http: &CacheAndHttp,
let cache_http = { server_id: GuildId,
let cache_http = client.cache_and_http.clone(); server_name: &str,
let cache: serenity::cache::CacheRwLock = cache_http.cache.clone().into(); ban_role: RoleId,
(cache, cache_http.http.clone()) user_id: UserId,
}; ) -> Result<()> {
let data = client.data.clone(); let mut m = server_id.member(cache_http, user_id).await?;
return move || { println!(
let cache_http = (&cache_http.0, &*cache_http.1); "Soft-ban for `{}` in server `{}` unlifted.",
loop { m.user.name, server_name
// Scope so that locks are released );
{ m.remove_role(&cache_http.http, ban_role).await?;
// Poll the data for any changes. Ok(())
let db = data.read();
let db = SoftBans::open(&*db);
let mut db = db.borrow_mut().expect("Borrowable");
let now = Utc::now();
for (server_id, soft_bans) in db.iter_mut() {
let server_name: String = match server_id.to_partial_guild(cache_http) {
Err(_) => continue,
Ok(v) => v.name,
};
if let ServerSoftBans::Implemented(ref mut bans) = soft_bans {
let to_remove: Vec<_> = bans
.periodical_bans
.iter()
.filter_map(|(user, time)| if time <= &now { Some(user) } else { None })
.cloned()
.collect();
for user_id in to_remove {
server_id
.member(cache_http, user_id)
.and_then(|mut m| {
println!(
"Soft-ban for `{}` in server `{}` unlifted.",
m.user.read().name,
server_name
);
m.remove_role(cache_http, bans.role)
})
.unwrap_or(());
bans.periodical_bans.remove(&user_id);
}
}
}
}
// Sleep the thread for a minute
std::thread::sleep(std::time::Duration::from_secs(60))
}
};
} }

View file

@ -9,6 +9,7 @@ use serenity::{
}, },
model::{ model::{
channel::{Channel, Message}, channel::{Channel, Message},
id::RoleId,
user::OnlineStatus, user::OnlineStatus,
}, },
utils::MessageBuilder, utils::MessageBuilder,
@ -30,11 +31,12 @@ struct Community;
#[command] #[command]
#[description = r"👑 Randomly choose an active member and mention them! #[description = r"👑 Randomly choose an active member and mention them!
Note that only online/idle users in the channel are chosen from."] Note that only online/idle users in the channel are chosen from."]
#[usage = "[title = the chosen one]"] #[usage = "[limited roles = everyone online] / [title = the chosen one]"]
#[example = "the strongest in Gensokyo"] #[example = "the strongest in Gensokyo"]
#[bucket = "community"] #[bucket = "community"]
#[max_args(1)] #[max_args(2)]
pub fn choose(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn choose(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let role = args.find::<RoleId>().ok();
let title = if args.is_empty() { let title = if args.is_empty() {
"the chosen one".to_owned() "the chosen one".to_owned()
} else { } else {
@ -42,29 +44,39 @@ pub fn choose(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult {
}; };
let users: Result<Vec<_>, Error> = { let users: Result<Vec<_>, Error> = {
let guild = m.guild(&ctx).unwrap(); let guild = m.guild(&ctx).await.unwrap();
let guild = guild.read();
let presences = &guild.presences; let presences = &guild.presences;
let channel = m.channel_id.to_channel(&ctx)?; let channel = m.channel_id.to_channel(&ctx).await?;
if let Channel::Guild(channel) = channel { if let Channel::Guild(channel) = channel {
let channel = channel.read();
Ok(channel Ok(channel
.members(&ctx)? .members(&ctx)
.await?
.into_iter() .into_iter()
.filter(|v| !v.user.read().bot) .filter(|v| !v.user.bot) // Filter out bots
.map(|v| v.user_id())
.filter(|v| { .filter(|v| {
// Filter out only online people
presences presences
.get(v) .get(&v.user.id)
.map(|presence| { .map(|presence| {
presence.status == OnlineStatus::Online presence.status == OnlineStatus::Online
|| presence.status == OnlineStatus::Idle || presence.status == OnlineStatus::Idle
}) })
.unwrap_or(false) .unwrap_or(false)
}) })
.collect()) .map(|mem| future::ready(mem))
.collect::<stream::FuturesUnordered<_>>()
.filter_map(|member| async move {
// Filter by role if provided
match role {
Some(role) if member.roles.iter().any(|r| role == *r) => Some(member),
None => Some(member),
_ => None,
}
})
.collect()
.await)
} else { } else {
panic!() unreachable!()
} }
}; };
let users = users?; let users = users?;
@ -73,7 +85,8 @@ pub fn choose(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult {
m.reply( m.reply(
&ctx, &ctx,
"🍰 Have this cake for yourself because no-one is here for the gods to pick.", "🍰 Have this cake for yourself because no-one is here for the gods to pick.",
)?; )
.await?;
return Ok(()); return Ok(());
} }
@ -83,19 +96,26 @@ pub fn choose(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult {
&users[uniform.sample(&mut rng)] &users[uniform.sample(&mut rng)]
}; };
m.channel_id.send_message(&ctx, |c| { m.channel_id
c.content( .send_message(&ctx, |c| {
MessageBuilder::new() c.content(
.push("👑 The Gensokyo gods have gathered around and decided, out of ") MessageBuilder::new()
.push_bold(format!("{}", users.len())) .push("👑 The Gensokyo gods have gathered around and decided, out of ")
.push(" potential prayers, ") .push_bold(format!("{}", users.len()))
.push(winner.mention()) .push(" ")
.push(" will be ") .push(
.push_bold_safe(title) role.map(|r| r.mention() + "s")
.push(". Congrats! 🎉 🎊 🥳") .unwrap_or("potential prayers".to_owned()),
.build(), )
) .push(", ")
})?; .push(winner.mention())
.push(" will be ")
.push_bold_safe(title)
.push(". Congrats! 🎉 🎊 🥳")
.build(),
)
})
.await?;
Ok(()) Ok(())
} }

View file

@ -1,6 +1,6 @@
use crate::db::Roles as DB; use crate::db::Roles as DB;
use serenity::{ use serenity::{
framework::standard::{macros::command, Args, CommandError as Error, CommandResult}, framework::standard::{macros::command, Args, CommandResult},
model::{channel::Message, guild::Role, id::RoleId}, model::{channel::Message, guild::Role, id::RoleId},
utils::MessageBuilder, utils::MessageBuilder,
}; };
@ -10,18 +10,22 @@ use youmubot_prelude::*;
#[description = "List all available roles in the server."] #[description = "List all available roles in the server."]
#[num_args(0)] #[num_args(0)]
#[only_in(guilds)] #[only_in(guilds)]
fn list(ctx: &mut Context, m: &Message, _: Args) -> CommandResult { async fn list(ctx: &Context, m: &Message, _: Args) -> CommandResult {
let guild_id = m.guild_id.unwrap(); // only_in(guilds) let guild_id = m.guild_id.unwrap(); // only_in(guilds)
let data = ctx.data.read().await;
let db = DB::open(&*ctx.data.read()); let db = DB::open(&*data);
let db = db.borrow()?; let roles = db
let roles = db.get(&guild_id).filter(|v| !v.is_empty()).cloned(); .borrow()?
.get(&guild_id)
.filter(|v| !v.is_empty())
.cloned();
match roles { match roles {
None => { None => {
m.reply(&ctx, "No roles available for assigning.")?; m.reply(&ctx, "No roles available for assigning.").await?;
} }
Some(v) => { Some(v) => {
let roles = guild_id.to_partial_guild(&ctx)?.roles; let roles = guild_id.to_partial_guild(&ctx).await?.roles;
let roles: Vec<_> = v let roles: Vec<_> = v
.into_iter() .into_iter()
.filter_map(|(_, role)| roles.get(&role.id).cloned().map(|r| (r, role.description))) .filter_map(|(_, role)| roles.get(&role.id).cloned().map(|r| (r, role.description)))
@ -29,108 +33,116 @@ fn list(ctx: &mut Context, m: &Message, _: Args) -> CommandResult {
const ROLES_PER_PAGE: usize = 8; const ROLES_PER_PAGE: usize = 8;
let pages = (roles.len() + ROLES_PER_PAGE - 1) / ROLES_PER_PAGE; let pages = (roles.len() + ROLES_PER_PAGE - 1) / ROLES_PER_PAGE;
let watcher = ctx.data.get_cloned::<ReactionWatcher>(); paginate(
watcher.paginate_fn( |page, ctx, msg| {
ctx.clone(), let roles = roles.clone();
m.channel_id, Box::pin(async move {
move |page, e| { let page = page as usize;
let page = page as usize; let start = page * ROLES_PER_PAGE;
let start = page * ROLES_PER_PAGE; let end = roles.len().min(start + ROLES_PER_PAGE);
let end = roles.len().min(start + ROLES_PER_PAGE); if end <= start {
if end <= start { return Ok(false);
return (e, Err(Error::from("No more roles to display"))); }
} let roles = &roles[start..end];
let roles = &roles[start..end]; let nw = roles // name width
let nw = roles // name width .iter()
.iter() .map(|(r, _)| r.name.len())
.map(|(r, _)| r.name.len()) .max()
.max() .unwrap()
.unwrap() .max(6);
.max(6); let idw = roles[0].0.id.to_string().len();
let idw = roles[0].0.id.to_string().len(); let dw = roles
let dw = roles .iter()
.iter() .map(|v| v.1.len())
.map(|v| v.1.len()) .max()
.max() .unwrap()
.unwrap() .max(" Description ".len());
.max(" Description ".len()); let mut m = MessageBuilder::new();
let mut m = MessageBuilder::new(); m.push_line("```");
m.push_line("```");
// Table header // Table header
m.push_line(format!(
"{:nw$} | {:idw$} | {:dw$}",
"Name",
"ID",
"Description",
nw = nw,
idw = idw,
dw = dw,
));
m.push_line(format!(
"{:->nw$}---{:->idw$}---{:->dw$}",
"",
"",
"",
nw = nw,
idw = idw,
dw = dw,
));
for (role, description) in roles.iter() {
m.push_line(format!( m.push_line(format!(
"{:nw$} | {:idw$} | {:dw$}", "{:nw$} | {:idw$} | {:dw$}",
role.name, "Name",
role.id, "ID",
description, "Description",
nw = nw,
idw = idw,
dw = dw,
));
m.push_line(format!(
"{:->nw$}---{:->idw$}---{:->dw$}",
"",
"",
"",
nw = nw, nw = nw,
idw = idw, idw = idw,
dw = dw, dw = dw,
)); ));
}
m.push_line("```");
m.push(format!("Page **{}/{}**", page + 1, pages));
(e.content(m.build()), Ok(())) for (role, description) in roles.iter() {
m.push_line(format!(
"{:nw$} | {:idw$} | {:dw$}",
role.name,
role.id,
description,
nw = nw,
idw = idw,
dw = dw,
));
}
m.push_line("```");
m.push(format!("Page **{}/{}**", page + 1, pages));
msg.edit(ctx, |f| f.content(m.to_string())).await?;
Ok(true)
})
}, },
ctx,
m.channel_id,
std::time::Duration::from_secs(60 * 10), std::time::Duration::from_secs(60 * 10),
)?; )
.await?;
} }
}; };
Ok(()) Ok(())
} }
// async fn list_pager(
#[command("role")] #[command("role")]
#[description = "Toggle a role by its name or ID."] #[description = "Toggle a role by its name or ID."]
#[example = "\"IELTS / TOEFL\""] #[example = "\"IELTS / TOEFL\""]
#[num_args(1)] #[num_args(1)]
#[only_in(guilds)] #[only_in(guilds)]
fn toggle(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { async fn toggle(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let role = args.single_quoted::<String>()?; let role = args.single_quoted::<String>()?;
let guild_id = m.guild_id.unwrap(); let guild_id = m.guild_id.unwrap();
let roles = guild_id.to_partial_guild(&ctx)?.roles; let guild = guild_id.to_partial_guild(&ctx).await?;
let role = role_from_string(&role, &roles); let role = role_from_string(&role, &guild.roles);
match role { match role {
None => { None => {
m.reply(&ctx, "No such role exists")?; m.reply(&ctx, "No such role exists").await?;
} }
Some(role) Some(role)
if !DB::open(&*ctx.data.read()) if !DB::open(&*ctx.data.read().await)
.borrow()? .borrow()?
.get(&guild_id) .get(&guild_id)
.map(|g| g.contains_key(&role.id)) .map(|g| g.contains_key(&role.id))
.unwrap_or(false) => .unwrap_or(false) =>
{ {
m.reply(&ctx, "This role is not self-assignable. Check the `listroles` command to see which role can be assigned.")?; m.reply(&ctx, "This role is not self-assignable. Check the `listroles` command to see which role can be assigned.").await?;
} }
Some(role) => { Some(role) => {
let mut member = m.member(&ctx).ok_or(Error::from("Cannot find member"))?; let mut member = guild.member(&ctx, m.author.id).await.unwrap();
if member.roles.contains(&role.id) { if member.roles.contains(&role.id) {
member.remove_role(&ctx, &role)?; member.remove_role(&ctx, &role).await?;
m.reply(&ctx, format!("Role `{}` has been removed.", role.name))?; m.reply(&ctx, format!("Role `{}` has been removed.", role.name))
.await?;
} else { } else {
member.add_role(&ctx, &role)?; member.add_role(&ctx, &role).await?;
m.reply(&ctx, format!("Role `{}` has been assigned.", role.name))?; m.reply(&ctx, format!("Role `{}` has been assigned.", role.name))
.await?;
} }
} }
}; };
@ -144,27 +156,29 @@ fn toggle(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult {
#[num_args(2)] #[num_args(2)]
#[required_permissions(MANAGE_ROLES)] #[required_permissions(MANAGE_ROLES)]
#[only_in(guilds)] #[only_in(guilds)]
fn add(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { async fn add(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let role = args.single_quoted::<String>()?; let role = args.single_quoted::<String>()?;
let data = ctx.data.read().await;
let description = args.single::<String>()?; let description = args.single::<String>()?;
let guild_id = m.guild_id.unwrap(); let guild_id = m.guild_id.unwrap();
let roles = guild_id.to_partial_guild(&ctx)?.roles; let roles = guild_id.to_partial_guild(&ctx).await?.roles;
let role = role_from_string(&role, &roles); let role = role_from_string(&role, &roles);
match role { match role {
None => { None => {
m.reply(&ctx, "No such role exists")?; m.reply(&ctx, "No such role exists").await?;
} }
Some(role) Some(role)
if DB::open(&*ctx.data.read()) if DB::open(&*data)
.borrow()? .borrow()?
.get(&guild_id) .get(&guild_id)
.map(|g| g.contains_key(&role.id)) .map(|g| g.contains_key(&role.id))
.unwrap_or(false) => .unwrap_or(false) =>
{ {
m.reply(&ctx, "This role already exists in the database.")?; m.reply(&ctx, "This role already exists in the database.")
.await?;
} }
Some(role) => { Some(role) => {
DB::open(&*ctx.data.read()) DB::open(&*data)
.borrow_mut()? .borrow_mut()?
.entry(guild_id) .entry(guild_id)
.or_default() .or_default()
@ -175,7 +189,7 @@ fn add(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult {
description, description,
}, },
); );
m.react(&ctx, "👌🏼")?; m.react(&ctx, '👌').await?;
} }
}; };
Ok(()) Ok(())
@ -188,31 +202,33 @@ fn add(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult {
#[num_args(1)] #[num_args(1)]
#[required_permissions(MANAGE_ROLES)] #[required_permissions(MANAGE_ROLES)]
#[only_in(guilds)] #[only_in(guilds)]
fn remove(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { async fn remove(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let role = args.single_quoted::<String>()?; let role = args.single_quoted::<String>()?;
let data = ctx.data.read().await;
let guild_id = m.guild_id.unwrap(); let guild_id = m.guild_id.unwrap();
let roles = guild_id.to_partial_guild(&ctx)?.roles; let roles = guild_id.to_partial_guild(&ctx).await?.roles;
let role = role_from_string(&role, &roles); let role = role_from_string(&role, &roles);
match role { match role {
None => { None => {
m.reply(&ctx, "No such role exists")?; m.reply(&ctx, "No such role exists").await?;
} }
Some(role) Some(role)
if !DB::open(&*ctx.data.read()) if !DB::open(&*data)
.borrow()? .borrow()?
.get(&guild_id) .get(&guild_id)
.map(|g| g.contains_key(&role.id)) .map(|g| g.contains_key(&role.id))
.unwrap_or(false) => .unwrap_or(false) =>
{ {
m.reply(&ctx, "This role does not exist in the assignable list.")?; m.reply(&ctx, "This role does not exist in the assignable list.")
.await?;
} }
Some(role) => { Some(role) => {
DB::open(&*ctx.data.read()) DB::open(&*data)
.borrow_mut()? .borrow_mut()?
.entry(guild_id) .entry(guild_id)
.or_default() .or_default()
.remove(&role.id); .remove(&role.id);
m.react(&ctx, "👌🏼")?; m.react(&ctx, '👌').await?;
} }
}; };
Ok(()) Ok(())

View file

@ -1,14 +1,18 @@
use serenity::framework::standard::CommandError as Error; use serenity::framework::standard::CommandError as Error;
use serenity::{ use serenity::{
collector::ReactionAction,
framework::standard::{macros::command, Args, CommandResult}, framework::standard::{macros::command, Args, CommandResult},
model::{ model::{
channel::{Message, Reaction, ReactionType}, channel::{Message, ReactionType},
id::UserId, id::UserId,
}, },
utils::MessageBuilder, utils::MessageBuilder,
}; };
use std::collections::{HashMap as Map, HashSet as Set};
use std::time::Duration; use std::time::Duration;
use std::{
collections::{HashMap as Map, HashSet as Set},
convert::TryFrom,
};
use youmubot_prelude::{Duration as ParseDuration, *}; use youmubot_prelude::{Duration as ParseDuration, *};
#[command] #[command]
@ -19,13 +23,13 @@ use youmubot_prelude::{Duration as ParseDuration, *};
#[only_in(guilds)] #[only_in(guilds)]
#[min_args(2)] #[min_args(2)]
#[owner_privilege] #[owner_privilege]
pub fn vote(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn vote(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
// Parse stuff first // Parse stuff first
let args = args.quoted(); let args = args.quoted();
let _duration = args.single::<ParseDuration>()?; let _duration = args.single::<ParseDuration>()?;
let duration = &_duration.0; let duration = &_duration.0;
if *duration < Duration::from_secs(2 * 60) || *duration > Duration::from_secs(60 * 60 * 24) { if *duration < Duration::from_secs(2) || *duration > Duration::from_secs(60 * 60 * 24) {
msg.reply(ctx, format!("😒 Invalid duration ({}). The voting time should be between **2 minutes** and **1 day**.", _duration))?; msg.reply(ctx, format!("😒 Invalid duration ({}). The voting time should be between **2 minutes** and **1 day**.", _duration)).await?;
return Ok(()); return Ok(());
} }
let question = args.single::<String>()?; let question = args.single::<String>()?;
@ -41,7 +45,8 @@ pub fn vote(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
msg.reply( msg.reply(
ctx, ctx,
"😒 Can't have a nice voting session if you only have one choice.", "😒 Can't have a nice voting session if you only have one choice.",
)?; )
.await?;
return Ok(()); return Ok(());
} }
if choices.len() > MAX_CHOICES { if choices.len() > MAX_CHOICES {
@ -52,7 +57,8 @@ pub fn vote(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
"😵 Too many choices... We only support {} choices at the moment!", "😵 Too many choices... We only support {} choices at the moment!",
MAX_CHOICES MAX_CHOICES
), ),
)?; )
.await?;
return Ok(()); return Ok(());
} }
@ -89,123 +95,116 @@ pub fn vote(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
.description(MessageBuilder::new().push_bold_line_safe(&question).push("\nThis question was asked by ").push(author.mention())) .description(MessageBuilder::new().push_bold_line_safe(&question).push("\nThis question was asked by ").push(author.mention()))
.fields(fields.into_iter()) .fields(fields.into_iter())
}) })
})?; }).await?;
msg.delete(&ctx)?; msg.delete(&ctx).await?;
drop(msg);
// React on all the choices // React on all the choices
choices choices
.iter() .iter()
.try_for_each(|(emote, _)| panel.react(&ctx, &emote[..]))?; .map(|(emote, _)| {
panel
.react(&ctx, ReactionType::try_from(&emote[..]).unwrap())
.map_ok(|_| ())
})
.collect::<stream::FuturesUnordered<_>>()
.try_collect::<()>()
.await?;
// A handler for votes. // A handler for votes.
struct VoteHandler { let user_reactions: Map<String, Set<UserId>> = choices
pub ctx: Context, .iter()
pub msg: Message, .map(|(emote, _)| (emote.clone(), Set::new()))
pub user_reactions: Map<String, Set<UserId>>, .collect();
pub panel: Message, // Collect reactions...
} let user_reactions = panel
.await_reactions(&ctx)
impl VoteHandler { .removed(true)
fn new(ctx: Context, msg: Message, panel: Message, choices: &[(String, String)]) -> Self { .timeout(*duration)
VoteHandler { .await
ctx, .fold(user_reactions, |mut set, reaction| async move {
msg, let (reaction, is_add) = match &*reaction {
user_reactions: choices ReactionAction::Added(r) => (r, true),
.iter() ReactionAction::Removed(r) => (r, false),
.map(|(emote, _)| (emote.clone(), Set::new())) };
.collect(),
panel,
}
}
}
impl ReactionHandler for VoteHandler {
fn handle_reaction(&mut self, reaction: &Reaction, is_add: bool) -> CommandResult {
if reaction.message_id != self.panel.id {
return Ok(());
}
if reaction.user(&self.ctx)?.bot {
return Ok(());
}
let users = if let ReactionType::Unicode(ref s) = reaction.emoji { let users = if let ReactionType::Unicode(ref s) = reaction.emoji {
if let Some(users) = self.user_reactions.get_mut(s.as_str()) { if let Some(users) = set.get_mut(s.as_str()) {
users users
} else { } else {
return Ok(()); return set;
} }
} else { } else {
return Ok(()); return set;
};
let user_id = match reaction.user_id {
Some(v) => v,
None => return set,
}; };
if is_add { if is_add {
users.insert(reaction.user_id); users.insert(user_id);
} else { } else {
users.remove(&reaction.user_id); users.remove(&user_id);
} }
Ok(()) set
} })
.await;
// Handle choices
let choice_map = choices.into_iter().collect::<Map<_, _>>();
let mut result: Vec<(String, Vec<UserId>)> = user_reactions
.into_iter()
.filter(|(_, users)| !users.is_empty())
.map(|(emote, users)| (emote, users.into_iter().collect()))
.collect();
result.sort_unstable_by(|(_, v), (_, w)| w.len().cmp(&v.len()));
if result.len() == 0 {
msg.reply(
&ctx,
MessageBuilder::new()
.push("no one answer your question ")
.push_bold_safe(&question)
.push(", sorry 😭")
.build(),
)
.await?;
return Ok(());
} }
ctx.data channel
.get_cloned::<ReactionWatcher>() .send_message(&ctx, |c| {
.handle_reactions_timed( c.content({
VoteHandler::new(ctx.clone(), msg.clone(), panel, &choices), let mut content = MessageBuilder::new();
*duration, content
move |vh| { .push("@here, ")
let (ctx, msg, user_reactions, panel) = .push(author.mention())
(vh.ctx, vh.msg, vh.user_reactions, vh.panel); .push(" previously asked ")
let choice_map = choices.into_iter().collect::<Map<_, _>>(); .push_bold_safe(&question)
let result: Vec<(String, Vec<UserId>)> = user_reactions .push(", and here are the results!");
.into_iter() result.into_iter().for_each(|(emote, votes)| {
.filter(|(_, users)| !users.is_empty()) content
.map(|(emote, users)| (emote, users.into_iter().collect())) .push("\n - ")
.collect(); .push_bold(format!("{}", votes.len()))
.push(" voted for ")
if result.len() == 0 { .push(&emote)
msg.reply( .push(" ")
&ctx, .push_bold_safe(choice_map.get(&emote).unwrap())
MessageBuilder::new() .push(": ")
.push("no one answer your question ") .push(
.push_bold_safe(&question) votes
.push(", sorry 😭") .into_iter()
.build(), .map(|v| v.mention())
) .collect::<Vec<_>>()
.ok(); .join(", "),
} else { );
channel });
.send_message(&ctx, |c| { content.build()
c.content({ })
let mut content = MessageBuilder::new(); })
content .await?;
.push("@here, ") panel.delete(&ctx).await?;
.push(author.mention())
.push(" previously asked ")
.push_bold_safe(&question)
.push(", and here are the results!");
result.into_iter().for_each(|(emote, votes)| {
content
.push("\n - ")
.push_bold(format!("{}", votes.len()))
.push(" voted for ")
.push(&emote)
.push(" ")
.push_bold_safe(choice_map.get(&emote).unwrap())
.push(": ")
.push(
votes
.into_iter()
.map(|v| v.mention())
.collect::<Vec<_>>()
.join(", "),
);
});
content.build()
})
})
.ok();
}
panel.delete(&ctx).ok();
},
);
Ok(()) Ok(())
// unimplemented!(); // unimplemented!();
@ -239,3 +238,4 @@ const REACTIONS: [&'static str; 90] = [
// Assertions // Assertions
static_assertions::const_assert!(MAX_CHOICES <= REACTIONS.len()); static_assertions::const_assert!(MAX_CHOICES <= REACTIONS.len());
static_assertions::const_assert!(MAX_CHOICES <= REACTIONS.len());

View file

@ -14,30 +14,25 @@ pub type Roles = DB<GuildMap<HashMap<RoleId, Role>>>;
/// For the admin commands: /// For the admin commands:
/// - Each server might have a `soft ban` role implemented. /// - Each server might have a `soft ban` role implemented.
/// - We allow periodical `soft ban` applications. /// - We allow periodical `soft ban` applications.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ServerSoftBans {
Implemented(ImplementedSoftBans),
Unimplemented,
}
impl ServerSoftBans {
// Create a new, implemented role.
pub fn new_implemented(role: RoleId) -> ServerSoftBans {
ServerSoftBans::Implemented(ImplementedSoftBans {
role,
periodical_bans: HashMap::new(),
})
}
}
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ImplementedSoftBans { pub struct ServerSoftBans {
/// The soft-ban role. /// The soft-ban role.
pub role: RoleId, pub role: RoleId,
/// List of all to-unban people. /// List of all to-unban people.
pub periodical_bans: HashMap<UserId, DateTime<Utc>>, pub periodical_bans: HashMap<UserId, DateTime<Utc>>,
} }
impl ServerSoftBans {
// Create a new, implemented role.
pub fn new(role: RoleId) -> Self {
Self {
role,
periodical_bans: HashMap::new(),
}
}
}
/// Role represents an assignable role. /// Role represents an assignable role.
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Role { pub struct Role {

View file

@ -15,24 +15,24 @@ use youmubot_prelude::*;
#[description = "🖼️ Find an image with a given tag on Danbooru[nsfw]!"] #[description = "🖼️ Find an image with a given tag on Danbooru[nsfw]!"]
#[min_args(1)] #[min_args(1)]
#[bucket("images")] #[bucket("images")]
pub fn nsfw(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult { pub async fn nsfw(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
message_command(ctx, msg, args, Rating::Explicit) message_command(ctx, msg, args, Rating::Explicit).await
} }
#[command] #[command]
#[description = "🖼️ Find an image with a given tag on Danbooru[safe]!"] #[description = "🖼️ Find an image with a given tag on Danbooru[safe]!"]
#[min_args(1)] #[min_args(1)]
#[bucket("images")] #[bucket("images")]
pub fn image(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult { pub async fn image(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
message_command(ctx, msg, args, Rating::Safe) message_command(ctx, msg, args, Rating::Safe).await
} }
#[check] #[check]
#[name = "nsfw"] #[name = "nsfw"]
fn nsfw_check(ctx: &mut Context, msg: &Message, _: &mut Args, _: &CommandOptions) -> CheckResult { async fn nsfw_check(ctx: &Context, msg: &Message, _: &mut Args, _: &CommandOptions) -> CheckResult {
let channel = msg.channel_id.to_channel(&ctx).unwrap(); let channel = msg.channel_id.to_channel(&ctx).await.unwrap();
if !(match channel { if !(match channel {
Channel::Guild(guild_channel) => guild_channel.read().nsfw, Channel::Guild(guild_channel) => guild_channel.nsfw,
_ => true, _ => true,
}) { }) {
CheckResult::Failure(Reason::User("😣 YOU FREAKING PERVERT!!!".to_owned())) CheckResult::Failure(Reason::User("😣 YOU FREAKING PERVERT!!!".to_owned()))
@ -41,22 +41,31 @@ fn nsfw_check(ctx: &mut Context, msg: &Message, _: &mut Args, _: &CommandOptions
} }
} }
fn message_command(ctx: &mut Context, msg: &Message, args: Args, rating: Rating) -> CommandResult { async fn message_command(
ctx: &Context,
msg: &Message,
args: Args,
rating: Rating,
) -> CommandResult {
let tags = args.remains().unwrap_or("touhou"); let tags = args.remains().unwrap_or("touhou");
let http = ctx.data.get_cloned::<HTTPClient>(); let image = get_image(
let image = get_image(&http, rating, tags)?; ctx.data.read().await.get::<HTTPClient>().unwrap(),
rating,
tags,
)
.await?;
match image { match image {
None => msg.reply(&ctx, "🖼️ No image found...\n💡 Tip: In danbooru, character names follow Japanese standards (last name before first name), so **Hakurei Reimu** might give you an image while **Reimu Hakurei** won't."), None => msg.reply(&ctx, "🖼️ No image found...\n💡 Tip: In danbooru, character names follow Japanese standards (last name before first name), so **Hakurei Reimu** might give you an image while **Reimu Hakurei** won't.").await,
Some(url) => msg.reply( Some(url) => msg.reply(
&ctx, &ctx,
format!("🖼️ Here's the image you requested!\n\n{}", url), format!("🖼️ Here's the image you requested!\n\n{}", url),
), ).await,
}?; }?;
Ok(()) Ok(())
} }
// Gets an image URL. // Gets an image URL.
fn get_image( async fn get_image(
client: &<HTTPClient as TypeMapKey>::Value, client: &<HTTPClient as TypeMapKey>::Value,
rating: Rating, rating: Rating,
tags: &str, tags: &str,
@ -72,7 +81,7 @@ fn get_image(
.query(&[("limit", "1"), ("random", "true")]) .query(&[("limit", "1"), ("random", "true")])
.build()?; .build()?;
println!("{:?}", req.url()); println!("{:?}", req.url());
let response: Vec<PostResponse> = client.execute(req)?.json()?; let response: Vec<PostResponse> = client.execute(req).await?.json().await?;
Ok(response Ok(response
.into_iter() .into_iter()
.next() .next()

View file

@ -28,7 +28,7 @@ struct Fun;
#[max_args(2)] #[max_args(2)]
#[usage = "[max-dice-faces = 6] / [message]"] #[usage = "[max-dice-faces = 6] / [message]"]
#[example = "100 / What's my score?"] #[example = "100 / What's my score?"]
fn roll(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { async fn roll(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let dice = if args.is_empty() { let dice = if args.is_empty() {
6 6
} else { } else {
@ -36,7 +36,8 @@ fn roll(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
}; };
if dice == 0 { if dice == 0 {
msg.reply(&ctx, "Give me a dice with 0 faces, what do you expect 😒")?; msg.reply(&ctx, "Give me a dice with 0 faces, what do you expect 😒")
.await?;
return Ok(()); return Ok(());
} }
@ -47,24 +48,30 @@ fn roll(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
}; };
match args.single_quoted::<String>() { match args.single_quoted::<String>() {
Ok(s) => msg.reply( Ok(s) => {
&ctx, msg.reply(
MessageBuilder::new() &ctx,
.push("you asked ") MessageBuilder::new()
.push_bold_safe(s) .push("you asked ")
.push(format!( .push_bold_safe(s)
", so I rolled a 🎲 of **{}** faces, and got **{}**!", .push(format!(
", so I rolled a 🎲 of **{}** faces, and got **{}**!",
dice, result
))
.build(),
)
.await
}
Err(_) if args.is_empty() => {
msg.reply(
&ctx,
format!(
"I rolled a 🎲 of **{}** faces, and got **{}**!",
dice, result dice, result
)) ),
.build(), )
), .await
Err(_) if args.is_empty() => msg.reply( }
&ctx,
format!(
"I rolled a 🎲 of **{}** faces, and got **{}**!",
dice, result
),
),
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
}?; }?;
@ -77,7 +84,7 @@ You may prefix the first choice with `?` to make it a question!
If no choices are given, Youmu defaults to `Yes!` and `No!`"#] If no choices are given, Youmu defaults to `Yes!` and `No!`"#]
#[usage = "[?question]/[choice #1]/[choice #2]/..."] #[usage = "[?question]/[choice #1]/[choice #2]/..."]
#[example = "?What for dinner/Pizza/Hamburger"] #[example = "?What for dinner/Pizza/Hamburger"]
fn pick(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { async fn pick(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let (question, choices) = { let (question, choices) = {
// Get a list of options. // Get a list of options.
let mut choices = args let mut choices = args
@ -114,24 +121,30 @@ fn pick(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
}; };
match question { match question {
None => msg.reply( None => {
&ctx, msg.reply(
MessageBuilder::new() &ctx,
.push("Youmu picks 👉") MessageBuilder::new()
.push_bold_safe(choice) .push("Youmu picks 👉")
.push("👈!") .push_bold_safe(choice)
.build(), .push("👈!")
), .build(),
Some(s) => msg.reply( )
&ctx, .await
MessageBuilder::new() }
.push("you asked ") Some(s) => {
.push_bold_safe(s) msg.reply(
.push(", and Youmu picks 👉") &ctx,
.push_bold_safe(choice) MessageBuilder::new()
.push("👈!") .push("you asked ")
.build(), .push_bold_safe(s)
), .push(", and Youmu picks 👉")
.push_bold_safe(choice)
.push("👈!")
.build(),
)
.await
}
}?; }?;
Ok(()) Ok(())
@ -142,7 +155,7 @@ fn pick(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
#[usage = "[user_mention = yourself]"] #[usage = "[user_mention = yourself]"]
#[example = "@user#1234"] #[example = "@user#1234"]
#[max_args(1)] #[max_args(1)]
fn name(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { async fn name(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user_id = if args.is_empty() { let user_id = if args.is_empty() {
msg.author.id msg.author.id
} else { } else {
@ -153,15 +166,15 @@ fn name(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
"your".to_owned() "your".to_owned()
} else { } else {
MessageBuilder::new() MessageBuilder::new()
.push_bold_safe(user_id.to_user(&ctx)?.tag()) .push_bold_safe(user_id.to_user(&ctx).await?.tag())
.push("'s") .push("'s")
.build() .build()
}; };
// Rule out a couple of cases // Rule out a couple of cases
if user_id == ctx.http.get_current_application_info()?.id { if user_id == ctx.http.get_current_application_info().await?.id {
// This is my own user_id // This is my own user_id
msg.reply(&ctx, "😠 My name is **Youmu Konpaku**!")?; msg.reply(&ctx, "😠 My name is **Youmu Konpaku**!").await?;
return Ok(()); return Ok(());
} }
@ -173,6 +186,7 @@ fn name(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
"{} Japanese🇯🇵 name is **{} {}**!", "{} Japanese🇯🇵 name is **{} {}**!",
user_mention, first_name, last_name user_mention, first_name, last_name
), ),
)?; )
.await?;
Ok(()) Ok(())
} }

View file

@ -20,26 +20,30 @@ pub use fun::FUN_GROUP;
pub fn setup( pub fn setup(
path: &std::path::Path, path: &std::path::Path,
client: &serenity::client::Client, client: &serenity::client::Client,
data: &mut youmubot_prelude::ShareMap, data: &mut TypeMap,
) -> serenity::framework::standard::CommandResult { ) -> serenity::framework::standard::CommandResult {
db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?; db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?;
db::Roles::insert_into(&mut *data, &path.join("roles.yaml"))?; db::Roles::insert_into(&mut *data, &path.join("roles.yaml"))?;
// Create handler threads // Create handler threads
std::thread::spawn(admin::watch_soft_bans(client)); tokio::spawn(admin::watch_soft_bans(
client.cache_and_http.clone(),
client.data.clone(),
));
Ok(()) Ok(())
} }
// A help command // A help command
#[help] #[help]
pub fn help( pub async fn help(
context: &mut Context, context: &Context,
msg: &Message, msg: &Message,
args: Args, args: Args,
help_options: &'static HelpOptions, help_options: &'static HelpOptions,
groups: &[&'static CommandGroup], groups: &[&'static CommandGroup],
owners: HashSet<UserId>, owners: HashSet<UserId>,
) -> CommandResult { ) -> CommandResult {
help_commands::with_embeds(context, msg, args, help_options, groups, owners) help_commands::with_embeds(context, msg, args, help_options, groups, owners).await;
Ok(())
} }

View file

@ -7,18 +7,11 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
serenity = "0.8" serenity = "0.9.0-rc.0"
dotenv = "0.15" dotenv = "0.15"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
chrono = "0.4.9" chrono = "0.4.9"
# rand = "0.7.2"
# static_assertions = "1.1.0"
# reqwest = "0.10.1"
# regex = "1"
# lazy_static = "1"
# youmubot-osu = { path = "../youmubot-osu" }
rayon = "1.1"
[dependencies.rustbreak] [dependencies.rustbreak]
version = "2.0.0-rc3" version = "2.0.0"
features = ["yaml_enc"] features = ["yaml_enc"]

View file

@ -1,15 +1,22 @@
use rustbreak::{deser::Yaml as Ron, FileDatabase}; use rustbreak::{deser::Yaml, FileDatabase, RustbreakError as DBError};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serenity::{framework::standard::CommandError as Error, model::id::GuildId, prelude::*}; use serenity::{
use std::collections::HashMap; model::id::GuildId,
use std::{cell::Cell, path::Path, sync::Arc}; prelude::{TypeMap, TypeMapKey},
};
use std::{collections::HashMap, path::Path};
/// GuildMap defines the guild-map type. /// GuildMap defines the guild-map type.
/// It is basically a HashMap from a GuildId to a data structure. /// It is basically a HashMap from a GuildId to a data structure.
pub type GuildMap<V> = HashMap<GuildId, V>; pub type GuildMap<V> = HashMap<GuildId, V>;
/// The generic DB type we will be using. /// The generic DB type we will be using.
pub struct DB<T>(std::marker::PhantomData<T>); pub struct DB<T>(std::marker::PhantomData<T>);
impl<T: std::any::Any> serenity::prelude::TypeMapKey for DB<T> {
type Value = Arc<FileDatabase<T, Ron>>; /// A short type abbreviation for a FileDatabase.
type Database<T> = FileDatabase<T, Yaml>;
impl<T: std::any::Any + Send + Sync> TypeMapKey for DB<T> {
type Value = Database<T>;
} }
impl<T: std::any::Any + Default + Send + Sync + Clone + Serialize + std::fmt::Debug> DB<T> impl<T: std::any::Any + Default + Send + Sync + Clone + Serialize + std::fmt::Debug> DB<T>
@ -17,66 +24,62 @@ where
for<'de> T: Deserialize<'de>, for<'de> T: Deserialize<'de>,
{ {
/// Insert into a ShareMap. /// Insert into a ShareMap.
pub fn insert_into(data: &mut ShareMap, path: impl AsRef<Path>) -> Result<(), Error> { pub fn insert_into(data: &mut TypeMap, path: impl AsRef<Path>) -> Result<(), DBError> {
let db = FileDatabase::<T, Ron>::from_path(path, T::default())?; let db = Database::<T>::load_from_path_or_default(path)?;
db.load().or_else(|e| { data.insert::<DB<T>>(db);
dbg!(e);
db.save()
})?;
data.insert::<DB<T>>(Arc::new(db));
Ok(()) Ok(())
} }
/// Open a previously inserted DB. /// Open a previously inserted DB.
pub fn open(data: &ShareMap) -> DBWriteGuard<T> { pub fn open(data: &TypeMap) -> DBWriteGuard<T> {
data.get::<Self>().expect("DB initialized").clone().into() data.get::<Self>().expect("DB initialized").into()
} }
} }
/// The write guard for our FileDatabase. /// The write guard for our FileDatabase.
/// It wraps the FileDatabase in a write-on-drop lock. /// It wraps the FileDatabase in a write-on-drop lock.
#[derive(Debug)] #[derive(Debug)]
pub struct DBWriteGuard<T> pub struct DBWriteGuard<'a, T>
where where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{ {
db: Arc<FileDatabase<T, Ron>>, db: &'a Database<T>,
needs_save: Cell<bool>, needs_save: bool,
} }
impl<T> From<Arc<FileDatabase<T, Ron>>> for DBWriteGuard<T> impl<'a, T> From<&'a Database<T>> for DBWriteGuard<'a, T>
where where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{ {
fn from(v: Arc<FileDatabase<T, Ron>>) -> Self { fn from(v: &'a Database<T>) -> Self {
DBWriteGuard { DBWriteGuard {
db: v, db: v,
needs_save: Cell::new(false), needs_save: false,
} }
} }
} }
impl<T> DBWriteGuard<T> impl<'a, T> DBWriteGuard<'a, T>
where where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{ {
/// Borrows the FileDatabase. /// Borrows the FileDatabase.
pub fn borrow(&self) -> Result<std::sync::RwLockReadGuard<T>, rustbreak::RustbreakError> { pub fn borrow(&self) -> Result<std::sync::RwLockReadGuard<T>, DBError> {
(*self).db.borrow_data() self.db.borrow_data()
} }
/// Borrows the FileDatabase for writing. /// Borrows the FileDatabase for writing.
pub fn borrow_mut(&self) -> Result<std::sync::RwLockWriteGuard<T>, rustbreak::RustbreakError> { pub fn borrow_mut(&mut self) -> Result<std::sync::RwLockWriteGuard<T>, DBError> {
self.needs_save.set(true); self.needs_save = true;
(*self).db.borrow_data_mut() self.db.borrow_data_mut()
} }
} }
impl<T> Drop for DBWriteGuard<T> impl<'a, T> Drop for DBWriteGuard<'a, T>
where where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if self.needs_save.get() { if self.needs_save {
if let Err(e) = self.db.save() { if let Err(e) = self.db.save() {
dbg!(e); dbg!(e);
} }

View file

@ -7,12 +7,11 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
serenity = "0.8" serenity = "0.9.0-rc.0"
chrono = "0.4.10" chrono = "0.4.10"
reqwest = "0.10.1" reqwest = "0.10.1"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
bitflags = "1" bitflags = "1"
rayon = "1.1"
lazy_static = "1" lazy_static = "1"
regex = "1" regex = "1"
oppai-rs = "0.2.0" oppai-rs = "0.2.0"

View file

@ -9,101 +9,159 @@ use crate::{
Client as Osu, Client as Osu,
}; };
use announcer::MemberToChannels; use announcer::MemberToChannels;
use rayon::prelude::*;
use serenity::{ use serenity::{
framework::standard::{CommandError as Error, CommandResult},
http::CacheHttp, http::CacheHttp,
model::id::{ChannelId, UserId}, model::id::{ChannelId, UserId},
CacheAndHttp, CacheAndHttp,
}; };
use std::sync::Arc; use std::{collections::HashMap, sync::Arc};
use youmubot_prelude::*; use youmubot_prelude::*;
/// osu! announcer's unique announcer key. /// osu! announcer's unique announcer key.
pub const ANNOUNCER_KEY: &'static str = "osu"; pub const ANNOUNCER_KEY: &'static str = "osu";
/// Announce osu! top scores. /// The announcer struct implementing youmubot_prelude::Announcer
pub fn updates(c: Arc<CacheAndHttp>, d: AppData, channels: MemberToChannels) -> CommandResult { pub struct Announcer;
let osu = d.get_cloned::<OsuClient>();
let cache = d.get_cloned::<BeatmapMetaCache>(); #[async_trait]
let oppai = d.get_cloned::<BeatmapCache>(); impl youmubot_prelude::Announcer for Announcer {
// For each user... async fn updates(
let mut data = OsuSavedUsers::open(&*d.read()).borrow()?.clone(); &mut self,
for (user_id, osu_user) in data.iter_mut() { c: Arc<CacheAndHttp>,
let channels = channels.channels_of(c.clone(), *user_id); d: AppData,
if channels.is_empty() { channels: MemberToChannels,
continue; // We don't wanna update an user without any active server ) -> Result<()> {
} // For each user...
osu_user.pp = match (&[Mode::Std, Mode::Taiko, Mode::Catch, Mode::Mania]) let data = OsuSavedUsers::open(&*d.read().await).borrow()?.clone();
.par_iter() let data = data
.map(|m| { .into_iter()
handle_user_mode( .map(|(user_id, osu_user)| {
c.clone(), let d = d.clone();
&osu, let channels = &channels;
&cache, let c = c.clone();
&oppai, async move {
&osu_user, let channels = channels.channels_of(c.clone(), user_id).await;
*user_id, if channels.is_empty() {
&channels[..], return (user_id, osu_user); // We don't wanna update an user without any active server
*m, }
d.clone(), let pp = match (&[Mode::Std, Mode::Taiko, Mode::Catch, Mode::Mania])
) .into_iter()
.map(|m| {
handle_user_mode(
c.clone(),
&osu_user,
user_id,
channels.clone(),
*m,
d.clone(),
)
})
.collect::<stream::FuturesOrdered<_>>()
.try_collect::<Vec<_>>()
.await
{
Ok(v) => v,
Err(e) => {
eprintln!("osu: Cannot update {}: {}", osu_user.id, e);
return (user_id, osu_user);
}
};
let last_update = chrono::Utc::now();
(
user_id,
OsuUser {
pp,
last_update,
..osu_user
},
)
}
}) })
.collect::<Result<_, _>>() .collect::<stream::FuturesUnordered<_>>()
{ .collect::<HashMap<_, _>>()
Ok(v) => v, .await;
Err(e) => { // Update users
eprintln!("osu: Cannot update {}: {}", osu_user.id, e.0); *OsuSavedUsers::open(&*d.read().await).borrow_mut()? = data;
continue; Ok(())
}
};
osu_user.last_update = chrono::Utc::now();
} }
// Update users
*OsuSavedUsers::open(&*d.read()).borrow_mut()? = data;
Ok(())
} }
/// Handles an user/mode scan, announces all possible new scores, return the new pp value. /// Handles an user/mode scan, announces all possible new scores, return the new pp value.
fn handle_user_mode( async fn handle_user_mode(
c: Arc<CacheAndHttp>, c: Arc<CacheAndHttp>,
osu: &Osu,
cache: &BeatmapMetaCache,
oppai: &BeatmapCache,
osu_user: &OsuUser, osu_user: &OsuUser,
user_id: UserId, user_id: UserId,
channels: &[ChannelId], channels: Vec<ChannelId>,
mode: Mode, mode: Mode,
d: AppData, d: AppData,
) -> Result<Option<f64>, Error> { ) -> Result<Option<f64>, Error> {
let scores = scan_user(osu, osu_user, mode)?; let (scores, user) = {
let user = osu let data = d.read().await;
.user(UserID::ID(osu_user.id), |f| f.mode(mode))? let osu = data.get::<OsuClient>().unwrap();
.ok_or(Error::from("user not found"))?; let scores = scan_user(osu, osu_user, mode).await?;
scores let user = osu
.into_par_iter() .user(UserID::ID(osu_user.id), |f| f.mode(mode))
.map(|(rank, score)| -> Result<_, Error> { .await?
let beatmap = cache.get_beatmap_default(score.beatmap_id)?; .ok_or(Error::msg("user not found"))?;
let content = oppai.get_beatmap(beatmap.beatmap_id)?; (scores, user)
Ok((rank, score, BeatmapWithMode(beatmap, mode), content)) };
}) let pp = user.pp;
.filter_map(|v| v.ok()) spawn_future(async move {
.for_each(|(rank, score, beatmap, content)| { scores
for channel in (&channels).iter() { .into_iter()
if let Err(e) = channel.send_message(c.http(), |c| { .map(|(rank, score)| {
c.content(format!("New top record from {}!", user_id.mention())) let d = d.clone();
.embed(|e| score_embed(&score, &beatmap, &content, &user, Some(rank), e)) async move {
}) { let data = d.read().await;
dbg!(e); let cache = data.get::<BeatmapMetaCache>().unwrap();
let oppai = data.get::<BeatmapCache>().unwrap();
let beatmap = cache.get_beatmap_default(score.beatmap_id).await?;
let content = oppai.get_beatmap(beatmap.beatmap_id).await?;
let r: Result<_> = Ok((rank, score, BeatmapWithMode(beatmap, mode), content));
r
} }
save_beatmap(&*d.read(), *channel, &beatmap).ok(); })
} .collect::<stream::FuturesOrdered<_>>()
}); .filter_map(|v| future::ready(v.ok()))
Ok(user.pp) .for_each(move |(rank, score, beatmap, content)| {
let channels = channels.clone();
let d = d.clone();
let c = c.clone();
let user = user.clone();
async move {
let data = d.read().await;
for channel in (&channels).iter() {
if let Err(e) = channel
.send_message(c.http(), |c| {
c.content(format!("New top record from {}!", user_id.mention()))
.embed(|e| {
score_embed(
&score,
&beatmap,
&content,
&user,
Some(rank),
e,
)
})
})
.await
{
dbg!(e);
}
save_beatmap(&*data, *channel, &beatmap).ok();
}
}
})
.await;
});
Ok(pp)
} }
fn scan_user(osu: &Osu, u: &OsuUser, mode: Mode) -> Result<Vec<(u8, Score)>, Error> { async fn scan_user(osu: &Osu, u: &OsuUser, mode: Mode) -> Result<Vec<(u8, Score)>, Error> {
let scores = osu.user_best(UserID::ID(u.id), |f| f.mode(mode).limit(25))?; let scores = osu
.user_best(UserID::ID(u.id), |f| f.mode(mode).limit(25))
.await?;
let scores = scores let scores = scores
.into_iter() .into_iter()
.enumerate() .enumerate()

View file

@ -3,16 +3,14 @@ use crate::{
Client, Client,
}; };
use dashmap::DashMap; use dashmap::DashMap;
use serenity::framework::standard::CommandError;
use std::sync::Arc; use std::sync::Arc;
use youmubot_prelude::TypeMapKey; use youmubot_prelude::*;
/// BeatmapMetaCache intercepts beatmap-by-id requests and caches them for later recalling. /// BeatmapMetaCache intercepts beatmap-by-id requests and caches them for later recalling.
/// Does not cache non-Ranked beatmaps. /// Does not cache non-Ranked beatmaps.
#[derive(Clone, Debug)]
pub struct BeatmapMetaCache { pub struct BeatmapMetaCache {
client: Client, client: Arc<Client>,
cache: Arc<DashMap<(u64, Mode), Beatmap>>, cache: DashMap<(u64, Mode), Beatmap>,
} }
impl TypeMapKey for BeatmapMetaCache { impl TypeMapKey for BeatmapMetaCache {
@ -21,13 +19,13 @@ impl TypeMapKey for BeatmapMetaCache {
impl BeatmapMetaCache { impl BeatmapMetaCache {
/// Create a new beatmap cache. /// Create a new beatmap cache.
pub fn new(client: Client) -> Self { pub fn new(client: Arc<Client>) -> Self {
BeatmapMetaCache { BeatmapMetaCache {
client, client,
cache: Arc::new(DashMap::new()), cache: DashMap::new(),
} }
} }
fn insert_if_possible(&self, id: u64, mode: Option<Mode>) -> Result<Beatmap, CommandError> { async fn insert_if_possible(&self, id: u64, mode: Option<Mode>) -> Result<Beatmap> {
let beatmap = self let beatmap = self
.client .client
.beatmaps(crate::BeatmapRequestKind::Beatmap(id), |f| { .beatmaps(crate::BeatmapRequestKind::Beatmap(id), |f| {
@ -36,35 +34,37 @@ impl BeatmapMetaCache {
} }
f f
}) })
.and_then(|v| { .await
v.into_iter() .and_then(|v| v.into_iter().next().ok_or(Error::msg("beatmap not found")))?;
.next()
.ok_or(CommandError::from("beatmap not found"))
})?;
if let ApprovalStatus::Ranked(_) = beatmap.approval { if let ApprovalStatus::Ranked(_) = beatmap.approval {
self.cache.insert((id, beatmap.mode), beatmap.clone()); self.cache.insert((id, beatmap.mode), beatmap.clone());
}; };
Ok(beatmap) Ok(beatmap)
} }
/// Get the given beatmap /// Get the given beatmap
pub fn get_beatmap(&self, id: u64, mode: Mode) -> Result<Beatmap, CommandError> { pub async fn get_beatmap(&self, id: u64, mode: Mode) -> Result<Beatmap> {
self.cache match self.cache.get(&(id, mode)).map(|v| v.clone()) {
.get(&(id, mode)) Some(v) => Ok(v),
.map(|b| Ok(b.clone())) None => self.insert_if_possible(id, Some(mode)).await,
.unwrap_or_else(|| self.insert_if_possible(id, Some(mode))) }
} }
/// Get a beatmap without a mode... /// Get a beatmap without a mode...
pub fn get_beatmap_default(&self, id: u64) -> Result<Beatmap, CommandError> { pub async fn get_beatmap_default(&self, id: u64) -> Result<Beatmap> {
(&[Mode::Std, Mode::Taiko, Mode::Catch, Mode::Mania]) Ok(
.iter() match (&[Mode::Std, Mode::Taiko, Mode::Catch, Mode::Mania])
.filter_map(|&mode| { .iter()
self.cache .filter_map(|&mode| {
.get(&(id, mode)) self.cache
.filter(|b| b.mode == mode) .get(&(id, mode))
.map(|b| Ok(b.clone())) .filter(|b| b.mode == mode)
}) .map(|b| b.clone())
.next() })
.unwrap_or_else(|| self.insert_if_possible(id, None)) .next()
{
Some(v) => v,
None => self.insert_if_possible(id, None).await?,
},
)
} }
} }

View file

@ -1,30 +1,26 @@
use super::db::OsuLastBeatmap; use super::db::OsuLastBeatmap;
use super::BeatmapWithMode; use super::BeatmapWithMode;
use serenity::{ use serenity::model::id::ChannelId;
framework::standard::{CommandError as Error, CommandResult}, use youmubot_prelude::*;
model::id::ChannelId,
prelude::*,
};
/// Save the beatmap into the server data storage. /// Save the beatmap into the server data storage.
pub(crate) fn save_beatmap( pub(crate) fn save_beatmap(
data: &ShareMap, data: &TypeMap,
channel_id: ChannelId, channel_id: ChannelId,
bm: &BeatmapWithMode, bm: &BeatmapWithMode,
) -> CommandResult { ) -> Result<()> {
let db = OsuLastBeatmap::open(data); OsuLastBeatmap::open(data)
let mut db = db.borrow_mut()?; .borrow_mut()?
.insert(channel_id, (bm.0.clone(), bm.mode()));
db.insert(channel_id, (bm.0.clone(), bm.mode()));
Ok(()) Ok(())
} }
/// Get the last beatmap requested from this channel. /// Get the last beatmap requested from this channel.
pub(crate) fn get_beatmap( pub(crate) fn get_beatmap(
data: &ShareMap, data: &TypeMap,
channel_id: ChannelId, channel_id: ChannelId,
) -> Result<Option<BeatmapWithMode>, Error> { ) -> Result<Option<BeatmapWithMode>> {
let db = OsuLastBeatmap::open(data); let db = OsuLastBeatmap::open(data);
let db = db.borrow()?; let db = db.borrow()?;

View file

@ -7,12 +7,7 @@ use crate::{
}; };
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use serenity::{ use serenity::{builder::CreateMessage, model::channel::Message, utils::MessageBuilder};
builder::CreateMessage,
framework::standard::{CommandError as Error, CommandResult},
model::channel::Message,
utils::MessageBuilder,
};
use std::str::FromStr; use std::str::FromStr;
use youmubot_prelude::*; use youmubot_prelude::*;
@ -26,47 +21,58 @@ lazy_static! {
r"(?:https?://)?osu\.ppy\.sh/beatmapsets/(?P<set_id>\d+)/?(?:\#(?P<mode>osu|taiko|fruits|mania)(?:/(?P<beatmap_id>\d+)|/?))?(?:\+(?P<mods>[A-Z]+))?" r"(?:https?://)?osu\.ppy\.sh/beatmapsets/(?P<set_id>\d+)/?(?:\#(?P<mode>osu|taiko|fruits|mania)(?:/(?P<beatmap_id>\d+)|/?))?(?:\+(?P<mods>[A-Z]+))?"
).unwrap(); ).unwrap();
static ref SHORT_LINK_REGEX: Regex = Regex::new( static ref SHORT_LINK_REGEX: Regex = Regex::new(
r"(?:^|\s)/b/(?P<id>\d+)(?:/(?P<mode>osu|taiko|fruits|mania))?(?:\+(?P<mods>[A-Z]+))?" r"(?:^|\s|\W)(?P<main>/b/(?P<id>\d+)(?:/(?P<mode>osu|taiko|fruits|mania))?(?:\+(?P<mods>[A-Z]+))?)"
).unwrap(); ).unwrap();
} }
pub fn hook(ctx: &mut Context, msg: &Message) -> () { pub fn hook<'a>(
if msg.author.bot { ctx: &'a Context,
return; msg: &'a Message,
} ) -> std::pin::Pin<Box<dyn future::Future<Output = Result<()>> + Send + 'a>> {
let mut v = move || -> CommandResult { Box::pin(async move {
let old_links = handle_old_links(ctx, &msg.content)?; if msg.author.bot {
let new_links = handle_new_links(ctx, &msg.content)?; return Ok(());
let short_links = handle_short_links(ctx, &msg, &msg.content)?;
let mut last_beatmap = None;
for l in old_links
.into_iter()
.chain(new_links.into_iter())
.chain(short_links.into_iter())
{
if let Err(v) = msg.channel_id.send_message(&ctx, |m| match l.embed {
EmbedType::Beatmap(b, info, mods) => {
let t = handle_beatmap(&b, info, l.link, l.mode, mods, m);
let mode = l.mode.unwrap_or(b.mode);
last_beatmap = Some(super::BeatmapWithMode(b, mode));
t
}
EmbedType::Beatmapset(b) => handle_beatmapset(b, l.link, l.mode, m),
}) {
println!("Error in osu! hook: {:?}", v)
}
} }
let (old_links, new_links, short_links) = (
handle_old_links(ctx, &msg.content),
handle_new_links(ctx, &msg.content),
handle_short_links(ctx, &msg, &msg.content),
);
let last_beatmap = stream::select(old_links, stream::select(new_links, short_links))
.then(|l| async move {
let mut bm: Option<super::BeatmapWithMode> = None;
msg.channel_id
.send_message(&ctx, |m| match l.embed {
EmbedType::Beatmap(b, info, mods) => {
let t = handle_beatmap(&b, info, l.link, l.mode, mods, m);
let mode = l.mode.unwrap_or(b.mode);
bm = Some(super::BeatmapWithMode(b, mode));
t
}
EmbedType::Beatmapset(b) => handle_beatmapset(b, l.link, l.mode, m),
})
.await?;
let r: Result<_> = Ok(bm);
r
})
.filter_map(|v| async move {
match v {
Ok(v) => v,
Err(e) => {
eprintln!("{}", e);
None
}
}
})
.fold(None, |_, v| async move { Some(v) })
.await;
// Save the beatmap for query later. // Save the beatmap for query later.
if let Some(t) = last_beatmap { if let Some(t) = last_beatmap {
if let Err(v) = super::cache::save_beatmap(&*ctx.data.read(), msg.channel_id, &t) { super::cache::save_beatmap(&*ctx.data.read().await, msg.channel_id, &t)?;
dbg!(v);
}
} }
Ok(()) Ok(())
}; })
if let Err(v) = v() {
println!("Error in osu! hook: {:?}", v)
}
} }
enum EmbedType { enum EmbedType {
@ -80,167 +86,216 @@ struct ToPrint<'a> {
mode: Option<Mode>, mode: Option<Mode>,
} }
fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result<Vec<ToPrint<'a>>, Error> { fn handle_old_links<'a>(
let osu = ctx.data.get_cloned::<OsuClient>(); ctx: &'a Context,
let mut to_prints: Vec<ToPrint<'a>> = Vec::new(); content: &'a str,
let cache = ctx.data.get_cloned::<BeatmapCache>(); ) -> impl stream::Stream<Item = ToPrint<'a>> + 'a {
for capture in OLD_LINK_REGEX.captures_iter(content) { OLD_LINK_REGEX
let req_type = capture.name("link_type").unwrap().as_str(); .captures_iter(content)
let req = match req_type { .map(move |capture| async move {
"b" => BeatmapRequestKind::Beatmap(capture["id"].parse()?), let data = ctx.data.read().await;
"s" => BeatmapRequestKind::Beatmapset(capture["id"].parse()?), let osu = data.get::<OsuClient>().unwrap();
_ => continue, let cache = data.get::<BeatmapCache>().unwrap();
}; let req_type = capture.name("link_type").unwrap().as_str();
let mode = capture let req = match req_type {
.name("mode") "b" => BeatmapRequestKind::Beatmap(capture["id"].parse()?),
.map(|v| v.as_str().parse()) "s" => BeatmapRequestKind::Beatmapset(capture["id"].parse()?),
.transpose()? _ => unreachable!(),
.and_then(|v| { };
Some(match v { let mode = capture
0 => Mode::Std, .name("mode")
1 => Mode::Taiko, .map(|v| v.as_str().parse())
2 => Mode::Catch, .transpose()?
3 => Mode::Mania, .and_then(|v| {
_ => return None, Some(match v {
0 => Mode::Std,
1 => Mode::Taiko,
2 => Mode::Catch,
3 => Mode::Mania,
_ => return None,
})
});
let beatmaps = osu
.beatmaps(req, |v| match mode {
Some(m) => v.mode(m, true),
None => v,
}) })
}); .await?;
let beatmaps = osu.beatmaps(req, |v| match mode { if beatmaps.is_empty() {
Some(m) => v.mode(m, true), return Ok(None);
None => v, }
})?; let r: Result<_> = Ok(match req_type {
match req_type { "b" => {
"b" => { let b = beatmaps.into_iter().next().unwrap();
for b in beatmaps.into_iter() {
// collect beatmap info // collect beatmap info
let mods = capture let mods = capture
.name("mods") .name("mods")
.map(|v| Mods::from_str(v.as_str()).ok()) .map(|v| Mods::from_str(v.as_str()).ok())
.flatten() .flatten()
.unwrap_or(Mods::NOMOD); .unwrap_or(Mods::NOMOD);
let info = mode.unwrap_or(b.mode).to_oppai_mode().and_then(|mode| { let info = match mode.unwrap_or(b.mode).to_oppai_mode() {
cache Some(mode) => cache
.get_beatmap(b.beatmap_id) .get_beatmap(b.beatmap_id)
.await
.and_then(|b| b.get_info_with(Some(mode), mods)) .and_then(|b| b.get_info_with(Some(mode), mods))
.ok() .ok(),
}); None => None,
to_prints.push(ToPrint { };
Some(ToPrint {
embed: EmbedType::Beatmap(b, info, mods), embed: EmbedType::Beatmap(b, info, mods),
link: capture.get(0).unwrap().as_str(), link: capture.get(0).unwrap().as_str(),
mode, mode,
}) })
} }
} "s" => Some(ToPrint {
"s" => to_prints.push(ToPrint { embed: EmbedType::Beatmapset(beatmaps),
embed: EmbedType::Beatmapset(beatmaps), link: capture.get(0).unwrap().as_str(),
link: capture.get(0).unwrap().as_str(), mode,
mode, }),
}), _ => None,
_ => (), });
} r
} })
Ok(to_prints) .collect::<stream::FuturesUnordered<_>>()
.filter_map(|v| {
future::ready(match v {
Ok(v) => v,
Err(e) => {
eprintln!("{}", e);
None
}
})
})
} }
fn handle_new_links<'a>(ctx: &mut Context, content: &'a str) -> Result<Vec<ToPrint<'a>>, Error> { fn handle_new_links<'a>(
let osu = ctx.data.get_cloned::<OsuClient>(); ctx: &'a Context,
let mut to_prints: Vec<ToPrint<'a>> = Vec::new(); content: &'a str,
let cache = ctx.data.get_cloned::<BeatmapCache>(); ) -> impl stream::Stream<Item = ToPrint<'a>> + 'a {
for capture in NEW_LINK_REGEX.captures_iter(content) { NEW_LINK_REGEX
let mode = capture .captures_iter(content)
.name("mode") .map(|capture| async move {
.and_then(|v| Mode::parse_from_new_site(v.as_str())); let data = ctx.data.read().await;
let link = capture.get(0).unwrap().as_str(); let osu = data.get::<OsuClient>().unwrap();
let req = match capture.name("beatmap_id") { let cache = data.get::<BeatmapCache>().unwrap();
Some(ref v) => BeatmapRequestKind::Beatmap(v.as_str().parse()?), let mode = capture
None => { .name("mode")
BeatmapRequestKind::Beatmapset(capture.name("set_id").unwrap().as_str().parse()?) .and_then(|v| Mode::parse_from_new_site(v.as_str()));
let link = capture.get(0).unwrap().as_str();
let req = match capture.name("beatmap_id") {
Some(ref v) => BeatmapRequestKind::Beatmap(v.as_str().parse()?),
None => BeatmapRequestKind::Beatmapset(
capture.name("set_id").unwrap().as_str().parse()?,
),
};
let beatmaps = osu
.beatmaps(req, |v| match mode {
Some(m) => v.mode(m, true),
None => v,
})
.await?;
if beatmaps.is_empty() {
return Ok(None);
} }
}; let r: Result<_> = Ok(match capture.name("beatmap_id") {
let beatmaps = osu.beatmaps(req, |v| match mode { Some(_) => {
Some(m) => v.mode(m, true), let beatmap = beatmaps.into_iter().next().unwrap();
None => v,
})?;
match capture.name("beatmap_id") {
Some(_) => {
for beatmap in beatmaps.into_iter() {
// collect beatmap info // collect beatmap info
let mods = capture let mods = capture
.name("mods") .name("mods")
.and_then(|v| Mods::from_str(v.as_str()).ok()) .and_then(|v| Mods::from_str(v.as_str()).ok())
.unwrap_or(Mods::NOMOD); .unwrap_or(Mods::NOMOD);
let info = mode let info = match mode.unwrap_or(beatmap.mode).to_oppai_mode() {
.unwrap_or(beatmap.mode) Some(mode) => cache
.to_oppai_mode() .get_beatmap(beatmap.beatmap_id)
.and_then(|mode| { .await
cache .and_then(|b| b.get_info_with(Some(mode), mods))
.get_beatmap(beatmap.beatmap_id) .ok(),
.and_then(|b| b.get_info_with(Some(mode), mods)) None => None,
.ok() };
}); Some(ToPrint {
to_prints.push(ToPrint {
embed: EmbedType::Beatmap(beatmap, info, mods), embed: EmbedType::Beatmap(beatmap, info, mods),
link, link,
mode, mode,
}) })
} }
} None => Some(ToPrint {
None => to_prints.push(ToPrint { embed: EmbedType::Beatmapset(beatmaps),
embed: EmbedType::Beatmapset(beatmaps), link,
link, mode,
mode, }),
}), });
} r
} })
Ok(to_prints) .collect::<stream::FuturesUnordered<_>>()
.filter_map(|v| {
future::ready(match v {
Ok(v) => v,
Err(e) => {
eprintln!("{}", e);
None
}
})
})
} }
fn handle_short_links<'a>( fn handle_short_links<'a>(
ctx: &mut Context, ctx: &'a Context,
msg: &Message, msg: &'a Message,
content: &'a str, content: &'a str,
) -> Result<Vec<ToPrint<'a>>, Error> { ) -> impl stream::Stream<Item = ToPrint<'a>> + 'a {
if let Some(guild_id) = msg.guild_id { SHORT_LINK_REGEX
if announcer::announcer_of(ctx, crate::discord::announcer::ANNOUNCER_KEY, guild_id)?
!= Some(msg.channel_id)
{
// Disable if we are not in the server's announcer channel
return Ok(vec![]);
}
}
let osu = ctx.data.get_cloned::<BeatmapMetaCache>();
let cache = ctx.data.get_cloned::<BeatmapCache>();
Ok(SHORT_LINK_REGEX
.captures_iter(content) .captures_iter(content)
.map(|capture| -> Result<_, Error> { .map(|capture| async move {
if let Some(guild_id) = msg.guild_id {
if announcer::announcer_of(ctx, crate::discord::announcer::ANNOUNCER_KEY, guild_id)
.await?
!= Some(msg.channel_id)
{
// Disable if we are not in the server's announcer channel
return Err(Error::msg("not in server announcer channel"));
}
}
let data = ctx.data.read().await;
let osu = data.get::<BeatmapMetaCache>().unwrap();
let cache = data.get::<BeatmapCache>().unwrap();
let mode = capture let mode = capture
.name("mode") .name("mode")
.and_then(|v| Mode::parse_from_new_site(v.as_str())); .and_then(|v| Mode::parse_from_new_site(v.as_str()));
let id: u64 = capture.name("id").unwrap().as_str().parse()?; let id: u64 = capture.name("id").unwrap().as_str().parse()?;
let beatmap = match mode { let beatmap = match mode {
Some(mode) => osu.get_beatmap(id, mode), Some(mode) => osu.get_beatmap(id, mode).await,
None => osu.get_beatmap_default(id), None => osu.get_beatmap_default(id).await,
}?; }?;
let mods = capture let mods = capture
.name("mods") .name("mods")
.and_then(|v| Mods::from_str(v.as_str()).ok()) .and_then(|v| Mods::from_str(v.as_str()).ok())
.unwrap_or(Mods::NOMOD); .unwrap_or(Mods::NOMOD);
let info = mode let info = match mode.unwrap_or(beatmap.mode).to_oppai_mode() {
.unwrap_or(beatmap.mode) Some(mode) => cache
.to_oppai_mode() .get_beatmap(beatmap.beatmap_id)
.and_then(|mode| { .await
cache .and_then(|b| b.get_info_with(Some(mode), mods))
.get_beatmap(beatmap.beatmap_id) .ok(),
.and_then(|b| b.get_info_with(Some(mode), mods)) None => None,
.ok() };
}); let r: Result<_> = Ok(ToPrint {
Ok(ToPrint {
embed: EmbedType::Beatmap(beatmap, info, mods), embed: EmbedType::Beatmap(beatmap, info, mods),
link: capture.get(0).unwrap().as_str(), link: capture.name("main").unwrap().as_str(),
mode, mode,
});
r
})
.collect::<stream::FuturesUnordered<_>>()
.filter_map(|v| {
future::ready(match v {
Ok(v) => Some(v),
Err(e) => {
eprintln!("{}", e);
None
}
}) })
}) })
.filter_map(|v| v.ok())
.collect())
} }
fn handle_beatmap<'a, 'b>( fn handle_beatmap<'a, 'b>(

View file

@ -2,10 +2,9 @@ use crate::{
discord::beatmap_cache::BeatmapMetaCache, discord::beatmap_cache::BeatmapMetaCache,
discord::oppai_cache::BeatmapCache, discord::oppai_cache::BeatmapCache,
models::{Beatmap, Mode, Mods, Score, User}, models::{Beatmap, Mode, Mods, Score, User},
request::{BeatmapRequestKind, UserID}, request::UserID,
Client as OsuHttpClient, Client as OsuHttpClient,
}; };
use rayon::prelude::*;
use serenity::{ use serenity::{
framework::standard::{ framework::standard::{
macros::{command, group}, macros::{command, group},
@ -14,7 +13,7 @@ use serenity::{
model::channel::Message, model::channel::Message,
utils::MessageBuilder, utils::MessageBuilder,
}; };
use std::str::FromStr; use std::{str::FromStr, sync::Arc};
use youmubot_prelude::*; use youmubot_prelude::*;
mod announcer; mod announcer;
@ -36,7 +35,7 @@ use server_rank::{LEADERBOARD_COMMAND, SERVER_RANK_COMMAND};
pub(crate) struct OsuClient; pub(crate) struct OsuClient;
impl TypeMapKey for OsuClient { impl TypeMapKey for OsuClient {
type Value = OsuHttpClient; type Value = Arc<OsuHttpClient>;
} }
/// Sets up the osu! command handling section. /// Sets up the osu! command handling section.
@ -52,7 +51,7 @@ impl TypeMapKey for OsuClient {
/// ///
pub fn setup( pub fn setup(
path: &std::path::Path, path: &std::path::Path,
data: &mut ShareMap, data: &mut TypeMap,
announcers: &mut AnnouncerHandler, announcers: &mut AnnouncerHandler,
) -> CommandResult { ) -> CommandResult {
// Databases // Databases
@ -61,11 +60,10 @@ pub fn setup(
OsuUserBests::insert_into(&mut *data, &path.join("osu_user_bests.yaml"))?; OsuUserBests::insert_into(&mut *data, &path.join("osu_user_bests.yaml"))?;
// API client // API client
let http_client = data.get_cloned::<HTTPClient>(); let http_client = data.get::<HTTPClient>().unwrap().clone();
let osu_client = OsuHttpClient::new( let osu_client = Arc::new(OsuHttpClient::new(
http_client.clone(),
std::env::var("OSU_API_KEY").expect("Please set OSU_API_KEY as osu! api key."), std::env::var("OSU_API_KEY").expect("Please set OSU_API_KEY as osu! api key."),
); ));
data.insert::<OsuClient>(osu_client.clone()); data.insert::<OsuClient>(osu_client.clone());
data.insert::<oppai_cache::BeatmapCache>(oppai_cache::BeatmapCache::new(http_client)); data.insert::<oppai_cache::BeatmapCache>(oppai_cache::BeatmapCache::new(http_client));
data.insert::<beatmap_cache::BeatmapMetaCache>(beatmap_cache::BeatmapMetaCache::new( data.insert::<beatmap_cache::BeatmapMetaCache>(beatmap_cache::BeatmapMetaCache::new(
@ -73,7 +71,7 @@ pub fn setup(
)); ));
// Announcer // Announcer
announcers.add(announcer::ANNOUNCER_KEY, announcer::updates); announcers.add(announcer::ANNOUNCER_KEY, announcer::Announcer);
Ok(()) Ok(())
} }
@ -101,8 +99,8 @@ struct Osu;
#[description = "Receive information about an user in osu!std mode."] #[description = "Receive information about an user in osu!std mode."]
#[usage = "[username or user_id = your saved username]"] #[usage = "[username or user_id = your saved username]"]
#[max_args(1)] #[max_args(1)]
pub fn std(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult { pub async fn std(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
get_user(ctx, msg, args, Mode::Std) get_user(ctx, msg, args, Mode::Std).await
} }
#[command] #[command]
@ -110,8 +108,8 @@ pub fn std(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult {
#[description = "Receive information about an user in osu!taiko mode."] #[description = "Receive information about an user in osu!taiko mode."]
#[usage = "[username or user_id = your saved username]"] #[usage = "[username or user_id = your saved username]"]
#[max_args(1)] #[max_args(1)]
pub fn taiko(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult { pub async fn taiko(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
get_user(ctx, msg, args, Mode::Taiko) get_user(ctx, msg, args, Mode::Taiko).await
} }
#[command] #[command]
@ -119,8 +117,8 @@ pub fn taiko(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult {
#[description = "Receive information about an user in osu!catch mode."] #[description = "Receive information about an user in osu!catch mode."]
#[usage = "[username or user_id = your saved username]"] #[usage = "[username or user_id = your saved username]"]
#[max_args(1)] #[max_args(1)]
pub fn catch(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult { pub async fn catch(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
get_user(ctx, msg, args, Mode::Catch) get_user(ctx, msg, args, Mode::Catch).await
} }
#[command] #[command]
@ -128,8 +126,8 @@ pub fn catch(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult {
#[description = "Receive information about an user in osu!mania mode."] #[description = "Receive information about an user in osu!mania mode."]
#[usage = "[username or user_id = your saved username]"] #[usage = "[username or user_id = your saved username]"]
#[max_args(1)] #[max_args(1)]
pub fn mania(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult { pub async fn mania(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
get_user(ctx, msg, args, Mode::Mania) get_user(ctx, msg, args, Mode::Mania).await
} }
pub(crate) struct BeatmapWithMode(pub Beatmap, pub Mode); pub(crate) struct BeatmapWithMode(pub Beatmap, pub Mode);
@ -150,17 +148,15 @@ impl AsRef<Beatmap> for BeatmapWithMode {
#[description = "Save the given username as your username."] #[description = "Save the given username as your username."]
#[usage = "[username or user_id]"] #[usage = "[username or user_id]"]
#[num_args(1)] #[num_args(1)]
pub fn save(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let osu = ctx.data.get_cloned::<OsuClient>(); let data = ctx.data.read().await;
let osu = data.get::<OsuClient>().unwrap();
let user = args.single::<String>()?; let user = args.single::<String>()?;
let user: Option<User> = osu.user(UserID::Auto(user), |f| f)?; let user: Option<User> = osu.user(UserID::Auto(user), |f| f).await?;
match user { match user {
Some(u) => { Some(u) => {
let db = OsuSavedUsers::open(&*ctx.data.read()); OsuSavedUsers::open(&*data).borrow_mut()?.insert(
let mut db = db.borrow_mut()?;
db.insert(
msg.author.id, msg.author.id,
OsuUser { OsuUser {
id: u.id, id: u.id,
@ -174,10 +170,11 @@ pub fn save(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
.push("user has been set to ") .push("user has been set to ")
.push_mono_safe(u.username) .push_mono_safe(u.username)
.build(), .build(),
)?; )
.await?;
} }
None => { None => {
msg.reply(&ctx, "user not found...")?; msg.reply(&ctx, "user not found...").await?;
} }
} }
Ok(()) Ok(())
@ -200,7 +197,7 @@ impl FromStr for ModeArg {
fn to_user_id_query( fn to_user_id_query(
s: Option<UsernameArg>, s: Option<UsernameArg>,
data: &ShareMap, data: &TypeMap,
msg: &Message, msg: &Message,
) -> Result<UserID, Error> { ) -> Result<UserID, Error> {
let id = match s { let id = match s {
@ -236,151 +233,161 @@ impl FromStr for Nth {
} }
} }
fn list_plays(plays: Vec<Score>, mode: Mode, ctx: Context, m: &Message) -> CommandResult { async fn list_plays<'a>(
let watcher = ctx.data.get_cloned::<ReactionWatcher>(); plays: Vec<Score>,
let osu = ctx.data.get_cloned::<BeatmapMetaCache>(); mode: Mode,
let beatmap_cache = ctx.data.get_cloned::<BeatmapCache>(); ctx: &'a Context,
m: &'a Message,
) -> CommandResult {
let plays = Arc::new(plays);
if plays.is_empty() { if plays.is_empty() {
m.reply(&ctx, "No plays found")?; m.reply(&ctx, "No plays found").await?;
return Ok(()); return Ok(());
} }
let mut beatmaps: Vec<Option<String>> = vec![None; plays.len()];
const ITEMS_PER_PAGE: usize = 5; const ITEMS_PER_PAGE: usize = 5;
let total_pages = (plays.len() + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE; let total_pages = (plays.len() + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE;
watcher.paginate_fn( paginate(
ctx, move |page, ctx, msg| {
m.channel_id, let plays = plays.clone();
move |page, e| { Box::pin(async move {
let page = page as usize; let data = ctx.data.read().await;
let start = page * ITEMS_PER_PAGE; let osu = data.get::<BeatmapMetaCache>().unwrap();
let end = plays.len().min(start + ITEMS_PER_PAGE); let beatmap_cache = data.get::<BeatmapCache>().unwrap();
if start >= end { let page = page as usize;
return (e, Err(Error::from("No more pages"))); let start = page * ITEMS_PER_PAGE;
} let end = plays.len().min(start + ITEMS_PER_PAGE);
if start >= end {
return Ok(false);
}
let plays = &plays[start..end]; let hourglass = msg.react(ctx, '⌛').await?;
let beatmaps: Vec<&mut String> = { let plays = &plays[start..end];
let b = &mut beatmaps[start..end]; let beatmaps = plays
b.par_iter_mut() .iter()
.enumerate() .map(|play| async move {
.map(|(i, v)| { let beatmap = osu.get_beatmap(play.beatmap_id, mode).await?;
v.get_or_insert_with(|| { let stars = {
if let Some(b) = osu.get_beatmap(plays[i].beatmap_id, mode).ok() { let b = beatmap_cache.get_beatmap(beatmap.beatmap_id).await?;
let stars = beatmap_cache mode.to_oppai_mode()
.get_beatmap(b.beatmap_id) .and_then(|mode| b.get_info_with(Some(mode), play.mods).ok())
.ok() .map(|info| info.stars as f64)
.and_then(|b| { .unwrap_or(beatmap.difficulty.stars)
mode.to_oppai_mode().and_then(|mode| { };
b.get_info_with(Some(mode), plays[i].mods).ok() let r: Result<_> = Ok(format!(
}) "[{:.1}*] {} - {} [{}] ({})",
}) stars,
.map(|info| info.stars as f64) beatmap.artist,
.unwrap_or(b.difficulty.stars); beatmap.title,
format!( beatmap.difficulty_name,
"[{:.1}*] {} - {} [{}] ({})", beatmap.short_link(Some(mode), Some(play.mods)),
stars, ));
b.artist, r
b.title,
b.difficulty_name,
b.short_link(Some(mode), Some(plays[i].mods)),
)
} else {
"FETCH_FAILED".to_owned()
}
})
}) })
.collect::<Vec<_>>() .collect::<stream::FuturesOrdered<_>>()
}; .map(|v| v.unwrap_or("FETCH_FAILED".to_owned()))
let pp = plays .collect::<Vec<String>>();
.iter() let pp = plays
.map(|p| { .iter()
p.pp.map(|pp| format!("{:.2}pp", pp)) .map(|p| async move {
.or_else(|| { match p.pp.map(|pp| format!("{:.2}pp", pp)) {
beatmap_cache.get_beatmap(p.beatmap_id).ok().and_then(|b| { Some(v) => Ok(v),
mode.to_oppai_mode().and_then(|op| { None => {
b.get_pp_from( let b = beatmap_cache.get_beatmap(p.beatmap_id).await?;
oppai_rs::Combo::NonFC { let r: Result<_> = Ok(mode
max_combo: p.max_combo as u32, .to_oppai_mode()
misses: p.count_miss as u32, .and_then(|op| {
}, b.get_pp_from(
p.accuracy(mode) as f32, oppai_rs::Combo::NonFC {
Some(op), max_combo: p.max_combo as u32,
p.mods, misses: p.count_miss as u32,
) },
.ok() p.accuracy(mode) as f32,
.map(|pp| format!("{:.2}pp [?]", pp)) Some(op),
}) p.mods,
}) )
}) .ok()
.unwrap_or("-".to_owned()) .map(|pp| format!("{:.2}pp [?]", pp))
}) })
.collect::<Vec<_>>(); .unwrap_or("-".to_owned()));
let pw = pp.iter().map(|v| v.len()).max().unwrap_or(2); r
/*mods width*/ }
let mw = plays }
.iter() })
.map(|v| v.mods.to_string().len()) .collect::<stream::FuturesOrdered<_>>()
.max() .map(|v| v.unwrap_or("-".to_owned()))
.unwrap() .collect::<Vec<String>>();
.max(4); let (beatmaps, pp) = future::join(beatmaps, pp).await;
/*beatmap names*/ let pw = pp.iter().map(|v| v.len()).max().unwrap_or(2);
let bw = beatmaps.iter().map(|v| v.len()).max().unwrap().max(7); /*mods width*/
let mw = plays
.iter()
.map(|v| v.mods.to_string().len())
.max()
.unwrap()
.max(4);
/*beatmap names*/
let bw = beatmaps.iter().map(|v| v.len()).max().unwrap().max(7);
let mut m = MessageBuilder::new(); let mut m = MessageBuilder::new();
// Table header // Table header
m.push_line(format!(
" # | {:pw$} | accuracy | rank | {:mw$} | {:bw$}",
"pp",
"mods",
"beatmap",
pw = pw,
mw = mw,
bw = bw
));
m.push_line(format!(
"------{:-<pw$}---------------------{:-<mw$}---{:-<bw$}",
"",
"",
"",
pw = pw,
mw = mw,
bw = bw
));
// Each row
for (id, (play, beatmap)) in plays.iter().zip(beatmaps.iter()).enumerate() {
m.push_line(format!( m.push_line(format!(
"{:>3} | {:>pw$} | {:>8} | {:^4} | {:mw$} | {:bw$}", " # | {:pw$} | accuracy | rank | {:mw$} | {:bw$}",
id + start + 1, "pp",
pp[id], "mods",
format!("{:.2}%", play.accuracy(mode)), "beatmap",
play.rank.to_string(),
play.mods.to_string(),
beatmap,
pw = pw, pw = pw,
mw = mw, mw = mw,
bw = bw bw = bw
)); ));
} m.push_line(format!(
// End "------{:-<pw$}---------------------{:-<mw$}---{:-<bw$}",
let table = m.build().replace("```", "\\`\\`\\`"); "",
let mut m = MessageBuilder::new(); "",
m.push_codeblock(table, None).push_line(format!( "",
"Page **{}/{}**", pw = pw,
page + 1, mw = mw,
total_pages bw = bw
)); ));
if let None = mode.to_oppai_mode() { // Each row
m.push_line("Note: star difficulty doesn't reflect mods applied."); for (id, (play, beatmap)) in plays.iter().zip(beatmaps.iter()).enumerate() {
} else { m.push_line(format!(
m.push_line("[?] means pp was predicted by oppai-rs."); "{:>3} | {:>pw$} | {:>8} | {:^4} | {:mw$} | {:bw$}",
} id + start + 1,
(e.content(m.build()), Ok(())) pp[id],
format!("{:.2}%", play.accuracy(mode)),
play.rank.to_string(),
play.mods.to_string(),
beatmap,
pw = pw,
mw = mw,
bw = bw
));
}
// End
let table = m.build().replace("```", "\\`\\`\\`");
let mut m = MessageBuilder::new();
m.push_codeblock(table, None).push_line(format!(
"Page **{}/{}**",
page + 1,
total_pages
));
if let None = mode.to_oppai_mode() {
m.push_line("Note: star difficulty doesn't reflect mods applied.");
} else {
m.push_line("[?] means pp was predicted by oppai-rs.");
}
msg.edit(ctx, |f| f.content(m.to_string())).await?;
hourglass.delete(ctx).await?;
Ok(true)
})
}, },
ctx,
m.channel_id,
std::time::Duration::from_secs(60), std::time::Duration::from_secs(60),
) )
.await?;
Ok(())
} }
#[command] #[command]
@ -388,44 +395,49 @@ fn list_plays(plays: Vec<Score>, mode: Mode, ctx: Context, m: &Message) -> Comma
#[usage = "#[the nth recent play = --all] / [mode (std, taiko, mania, catch) = std] / [username / user id = your saved id]"] #[usage = "#[the nth recent play = --all] / [mode (std, taiko, mania, catch) = std] / [username / user id = your saved id]"]
#[example = "#1 / taiko / natsukagami"] #[example = "#1 / taiko / natsukagami"]
#[max_args(3)] #[max_args(3)]
pub fn recent(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let nth = args.single::<Nth>().unwrap_or(Nth::All); let nth = args.single::<Nth>().unwrap_or(Nth::All);
let mode = args.single::<ModeArg>().unwrap_or(ModeArg(Mode::Std)).0; let mode = args.single::<ModeArg>().unwrap_or(ModeArg(Mode::Std)).0;
let user = to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?; let user = to_user_id_query(args.single::<UsernameArg>().ok(), &*data, msg)?;
let osu = ctx.data.get_cloned::<OsuClient>(); let osu = data.get::<OsuClient>().unwrap();
let meta_cache = ctx.data.get_cloned::<BeatmapMetaCache>(); let meta_cache = data.get::<BeatmapMetaCache>().unwrap();
let oppai = ctx.data.get_cloned::<BeatmapCache>(); let oppai = data.get::<BeatmapCache>().unwrap();
let user = osu let user = osu
.user(user, |f| f.mode(mode))? .user(user, |f| f.mode(mode))
.await?
.ok_or(Error::from("User not found"))?; .ok_or(Error::from("User not found"))?;
match nth { match nth {
Nth::Nth(nth) => { Nth::Nth(nth) => {
let recent_play = osu let recent_play = osu
.user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(nth))? .user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(nth))
.await?
.into_iter() .into_iter()
.last() .last()
.ok_or(Error::from("No such play"))?; .ok_or(Error::from("No such play"))?;
let beatmap = meta_cache let beatmap = meta_cache.get_beatmap(recent_play.beatmap_id, mode).await?;
.get_beatmap(recent_play.beatmap_id, mode) let content = oppai.get_beatmap(beatmap.beatmap_id).await?;
.unwrap();
let content = oppai.get_beatmap(beatmap.beatmap_id)?;
let beatmap_mode = BeatmapWithMode(beatmap, mode); let beatmap_mode = BeatmapWithMode(beatmap, mode);
msg.channel_id.send_message(&ctx, |m| { msg.channel_id
m.content(format!( .send_message(&ctx, |m| {
"{}: here is the play that you requested", m.content(format!(
msg.author "{}: here is the play that you requested",
)) msg.author
.embed(|m| score_embed(&recent_play, &beatmap_mode, &content, &user, None, m)) ))
})?; .embed(|m| score_embed(&recent_play, &beatmap_mode, &content, &user, None, m))
})
.await?;
// Save the beatmap... // Save the beatmap...
cache::save_beatmap(&*ctx.data.read(), msg.channel_id, &beatmap_mode)?; cache::save_beatmap(&*data, msg.channel_id, &beatmap_mode)?;
} }
Nth::All => { Nth::All => {
let plays = osu.user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(50))?; let plays = osu
list_plays(plays, mode, ctx.clone(), msg)?; .user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(50))
.await?;
list_plays(plays, mode, ctx, msg).await?;
} }
} }
Ok(()) Ok(())
@ -435,28 +447,33 @@ pub fn recent(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult
#[description = "Show information from the last queried beatmap."] #[description = "Show information from the last queried beatmap."]
#[usage = "[mods = no mod]"] #[usage = "[mods = no mod]"]
#[max_args(1)] #[max_args(1)]
pub fn last(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let b = cache::get_beatmap(&*ctx.data.read(), msg.channel_id)?; let data = ctx.data.read().await;
let b = cache::get_beatmap(&*data, msg.channel_id)?;
match b { match b {
Some(BeatmapWithMode(b, m)) => { Some(BeatmapWithMode(b, m)) => {
let mods = args.find::<Mods>().unwrap_or(Mods::NOMOD); let mods = args.find::<Mods>().unwrap_or(Mods::NOMOD);
let info = ctx let info = data
.data .get::<BeatmapCache>()
.get_cloned::<BeatmapCache>() .unwrap()
.get_beatmap(b.beatmap_id)? .get_beatmap(b.beatmap_id)
.await?
.get_info_with(m.to_oppai_mode(), mods) .get_info_with(m.to_oppai_mode(), mods)
.ok(); .ok();
msg.channel_id.send_message(&ctx, |f| { msg.channel_id
f.content(format!( .send_message(&ctx, |f| {
"{}: here is the beatmap you requested!", f.content(format!(
msg.author "{}: here is the beatmap you requested!",
)) msg.author
.embed(|c| beatmap_embed(&b, m, mods, info, c)) ))
})?; .embed(|c| beatmap_embed(&b, m, mods, info, c))
})
.await?;
} }
None => { None => {
msg.reply(&ctx, "No beatmap was queried on this channel.")?; msg.reply(&ctx, "No beatmap was queried on this channel.")
.await?;
} }
} }
@ -468,12 +485,14 @@ pub fn last(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
#[usage = "[username or tag = yourself]"] #[usage = "[username or tag = yourself]"]
#[description = "Check your own or someone else's best record on the last beatmap. Also stores the result if possible."] #[description = "Check your own or someone else's best record on the last beatmap. Also stores the result if possible."]
#[max_args(1)] #[max_args(1)]
pub fn check(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let bm = cache::get_beatmap(&*ctx.data.read(), msg.channel_id)?; let data = ctx.data.read().await;
let bm = cache::get_beatmap(&*data, msg.channel_id)?;
match bm { match bm {
None => { None => {
msg.reply(&ctx, "No beatmap queried on this channel.")?; msg.reply(&ctx, "No beatmap queried on this channel.")
.await?;
} }
Some(bm) => { Some(bm) => {
let b = &bm.0; let b = &bm.0;
@ -484,31 +503,36 @@ pub fn check(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult
None => Some(msg.author.id), None => Some(msg.author.id),
_ => None, _ => None,
}; };
let user = to_user_id_query(username_arg, &*ctx.data.read(), msg)?; let user = to_user_id_query(username_arg, &*data, msg)?;
let osu = ctx.data.get_cloned::<OsuClient>(); let osu = data.get::<OsuClient>().unwrap();
let oppai = ctx.data.get_cloned::<BeatmapCache>(); let oppai = data.get::<BeatmapCache>().unwrap();
let content = oppai.get_beatmap(b.beatmap_id)?; let content = oppai.get_beatmap(b.beatmap_id).await?;
let user = osu let user = osu
.user(user, |f| f)? .user(user, |f| f)
.await?
.ok_or(Error::from("User not found"))?; .ok_or(Error::from("User not found"))?;
let scores = osu.scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m))?; let scores = osu
.scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m))
.await?;
if scores.is_empty() { if scores.is_empty() {
msg.reply(&ctx, "No scores found")?; msg.reply(&ctx, "No scores found").await?;
} }
for score in scores.iter() { for score in scores.iter() {
msg.channel_id.send_message(&ctx, |c| { msg.channel_id
c.embed(|m| score_embed(score, &bm, &content, &user, None, m)) .send_message(&ctx, |c| {
})?; c.embed(|m| score_embed(&score, &bm, &content, &user, None, m))
})
.await?;
} }
if let Some(user_id) = user_id { if let Some(user_id) = user_id {
// Save to database // Save to database
OsuUserBests::open(&*ctx.data.read()) OsuUserBests::open(&*data)
.borrow_mut()? .borrow_mut()?
.entry((bm.0.beatmap_id, bm.1)) .entry((bm.0.beatmap_id, bm.1))
.or_default() .or_default()
@ -522,27 +546,32 @@ pub fn check(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult
#[command] #[command]
#[description = "Get the n-th top record of an user."] #[description = "Get the n-th top record of an user."]
#[usage = "#[n-th = --all] / [mode (std, taiko, catch, mania)] = std / [username or user_id = your saved user id]"] #[usage = "[mode (std, taiko, catch, mania)] = std / #[n-th = --all] / [username or user_id = your saved user id]"]
#[example = "#2 / taiko / natsukagami"] #[example = "taiko / #2 / natsukagami"]
#[max_args(3)] #[max_args(3)]
pub fn top(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let nth = args.single::<Nth>().unwrap_or(Nth::All); let nth = args.single::<Nth>().unwrap_or(Nth::All);
let mode = args let mode = args
.single::<ModeArg>() .single::<ModeArg>()
.map(|ModeArg(t)| t) .map(|ModeArg(t)| t)
.unwrap_or(Mode::Std); .unwrap_or(Mode::Std);
let user = to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?; let user = to_user_id_query(args.single::<UsernameArg>().ok(), &*data, msg)?;
let meta_cache = data.get::<BeatmapMetaCache>().unwrap();
let osu = data.get::<OsuClient>().unwrap();
let osu = ctx.data.get_cloned::<OsuClient>(); let oppai = data.get::<BeatmapCache>().unwrap();
let oppai = ctx.data.get_cloned::<BeatmapCache>();
let user = osu let user = osu
.user(user, |f| f.mode(mode))? .user(user, |f| f.mode(mode))
.await?
.ok_or(Error::from("User not found"))?; .ok_or(Error::from("User not found"))?;
match nth { match nth {
Nth::Nth(nth) => { Nth::Nth(nth) => {
let top_play = osu.user_best(UserID::ID(user.id), |f| f.mode(mode).limit(nth))?; let top_play = osu
.user_best(UserID::ID(user.id), |f| f.mode(mode).limit(nth))
.await?;
let rank = top_play.len() as u8; let rank = top_play.len() as u8;
@ -550,69 +579,76 @@ pub fn top(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
.into_iter() .into_iter()
.last() .last()
.ok_or(Error::from("No such play"))?; .ok_or(Error::from("No such play"))?;
let beatmap = osu let beatmap = meta_cache.get_beatmap(top_play.beatmap_id, mode).await?;
.beatmaps(BeatmapRequestKind::Beatmap(top_play.beatmap_id), |f| { let content = oppai.get_beatmap(beatmap.beatmap_id).await?;
f.mode(mode, true)
})?
.into_iter()
.next()
.unwrap();
let content = oppai.get_beatmap(beatmap.beatmap_id)?;
let beatmap = BeatmapWithMode(beatmap, mode); let beatmap = BeatmapWithMode(beatmap, mode);
msg.channel_id.send_message(&ctx, |m| { msg.channel_id
m.content(format!( .send_message(&ctx, |m| {
"{}: here is the play that you requested", m.content(format!(
msg.author "{}: here is the play that you requested",
)) msg.author
.embed(|m| score_embed(&top_play, &beatmap, &content, &user, Some(rank), m)) ))
})?; .embed(|m| score_embed(&top_play, &beatmap, &content, &user, Some(rank), m))
})
.await?;
// Save the beatmap... // Save the beatmap...
cache::save_beatmap(&*ctx.data.read(), msg.channel_id, &beatmap)?; cache::save_beatmap(&*data, msg.channel_id, &beatmap)?;
} }
Nth::All => { Nth::All => {
let plays = osu.user_best(UserID::ID(user.id), |f| f.mode(mode).limit(100))?; let plays = osu
list_plays(plays, mode, ctx.clone(), msg)?; .user_best(UserID::ID(user.id), |f| f.mode(mode).limit(100))
.await?;
list_plays(plays, mode, ctx, msg).await?;
} }
} }
Ok(()) Ok(())
} }
fn get_user(ctx: &mut Context, msg: &Message, mut args: Args, mode: Mode) -> CommandResult { async fn get_user(ctx: &Context, msg: &Message, mut args: Args, mode: Mode) -> CommandResult {
let user = to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?; let data = ctx.data.read().await;
let osu = ctx.data.get_cloned::<OsuClient>(); let user = to_user_id_query(args.single::<UsernameArg>().ok(), &*data, msg)?;
let cache = ctx.data.get_cloned::<BeatmapMetaCache>(); let osu = data.get::<OsuClient>().unwrap();
let user = osu.user(user, |f| f.mode(mode))?; let cache = data.get::<BeatmapMetaCache>().unwrap();
let oppai = ctx.data.get_cloned::<BeatmapCache>(); let user = osu.user(user, |f| f.mode(mode)).await?;
let oppai = data.get::<BeatmapCache>().unwrap();
match user { match user {
Some(u) => { Some(u) => {
let best = osu let best = match osu
.user_best(UserID::ID(u.id), |f| f.limit(1).mode(mode))? .user_best(UserID::ID(u.id), |f| f.limit(1).mode(mode))
.await?
.into_iter() .into_iter()
.next() .next()
.map(|m| -> Result<_, Error> { {
let beatmap = cache.get_beatmap(m.beatmap_id, mode)?; Some(m) => {
let info = mode let beatmap = cache.get_beatmap(m.beatmap_id, mode).await?;
.to_oppai_mode() let info = match mode.to_oppai_mode() {
.map(|mode| -> Result<_, Error> { Some(mode) => Some(
Ok(oppai oppai
.get_beatmap(m.beatmap_id)? .get_beatmap(m.beatmap_id)
.get_info_with(Some(mode), m.mods)?) .await?
}) .get_info_with(Some(mode), m.mods)?,
.transpose()?; ),
Ok((m, BeatmapWithMode(beatmap, mode), info)) None => None,
};
Some((m, BeatmapWithMode(beatmap, mode), info))
}
None => None,
};
msg.channel_id
.send_message(&ctx, |m| {
m.content(format!(
"{}: here is the user that you requested",
msg.author
))
.embed(|m| user_embed(u, best, m))
}) })
.transpose()?; .await?;
msg.channel_id.send_message(&ctx, |m| {
m.content(format!(
"{}: here is the user that you requested",
msg.author
))
.embed(|m| user_embed(u, best, m))
})
} }
None => msg.reply(&ctx, "🔍 user not found!"), None => {
}?; msg.reply(&ctx, "🔍 user not found!").await?;
}
};
Ok(()) Ok(())
} }

View file

@ -1,12 +1,11 @@
use serenity::framework::standard::CommandError;
use std::{ffi::CString, sync::Arc}; use std::{ffi::CString, sync::Arc};
use youmubot_prelude::TypeMapKey; use youmubot_prelude::*;
/// the information collected from a download/Oppai request. /// the information collected from a download/Oppai request.
#[derive(Clone, Debug)] #[derive(Debug)]
pub struct BeatmapContent { pub struct BeatmapContent {
id: u64, id: u64,
content: Arc<CString>, content: CString,
} }
/// the output of "one" oppai run. /// the output of "one" oppai run.
@ -24,7 +23,7 @@ impl BeatmapContent {
accuracy: f32, accuracy: f32,
mode: Option<oppai_rs::Mode>, mode: Option<oppai_rs::Mode>,
mods: impl Into<oppai_rs::Mods>, mods: impl Into<oppai_rs::Mods>,
) -> Result<f32, CommandError> { ) -> Result<f32> {
let mut oppai = oppai_rs::Oppai::new_from_content(&self.content[..])?; let mut oppai = oppai_rs::Oppai::new_from_content(&self.content[..])?;
oppai.combo(combo)?.accuracy(accuracy)?.mods(mods.into()); oppai.combo(combo)?.accuracy(accuracy)?.mods(mods.into());
if let Some(mode) = mode { if let Some(mode) = mode {
@ -38,7 +37,7 @@ impl BeatmapContent {
&self, &self,
mode: Option<oppai_rs::Mode>, mode: Option<oppai_rs::Mode>,
mods: impl Into<oppai_rs::Mods>, mods: impl Into<oppai_rs::Mods>,
) -> Result<BeatmapInfo, CommandError> { ) -> Result<BeatmapInfo> {
let mut oppai = oppai_rs::Oppai::new_from_content(&self.content[..])?; let mut oppai = oppai_rs::Oppai::new_from_content(&self.content[..])?;
if let Some(mode) = mode { if let Some(mode) = mode {
oppai.mode(mode)?; oppai.mode(mode)?;
@ -56,39 +55,47 @@ impl BeatmapContent {
} }
/// A central cache for the beatmaps. /// A central cache for the beatmaps.
#[derive(Clone, Debug)]
pub struct BeatmapCache { pub struct BeatmapCache {
client: reqwest::blocking::Client, client: ratelimit::Ratelimit<reqwest::Client>,
cache: Arc<dashmap::DashMap<u64, BeatmapContent>>, cache: dashmap::DashMap<u64, Arc<BeatmapContent>>,
} }
impl BeatmapCache { impl BeatmapCache {
/// Create a new cache. /// Create a new cache.
pub fn new(client: reqwest::blocking::Client) -> Self { pub fn new(client: reqwest::Client) -> Self {
let client = ratelimit::Ratelimit::new(client, 5, std::time::Duration::from_secs(1));
BeatmapCache { BeatmapCache {
client, client,
cache: Arc::new(dashmap::DashMap::new()), cache: dashmap::DashMap::new(),
} }
} }
fn download_beatmap(&self, id: u64) -> Result<BeatmapContent, CommandError> { async fn download_beatmap(&self, id: u64) -> Result<BeatmapContent> {
let content = self let content = self
.client .client
.borrow()
.await?
.get(&format!("https://osu.ppy.sh/osu/{}", id)) .get(&format!("https://osu.ppy.sh/osu/{}", id))
.send()? .send()
.bytes()?; .await?
.bytes()
.await?;
Ok(BeatmapContent { Ok(BeatmapContent {
id, id,
content: Arc::new(CString::new(content.into_iter().collect::<Vec<_>>())?), content: CString::new(content.into_iter().collect::<Vec<_>>())?,
}) })
} }
/// Get a beatmap from the cache. /// Get a beatmap from the cache.
pub fn get_beatmap(&self, id: u64) -> Result<BeatmapContent, CommandError> { pub async fn get_beatmap(
self.cache &self,
.entry(id) id: u64,
.or_try_insert_with(|| self.download_beatmap(id)) ) -> Result<impl std::ops::Deref<Target = BeatmapContent>> {
.map(|v| v.clone()) if !self.cache.contains_key(&id) {
self.cache
.insert(id, Arc::new(self.download_beatmap(id).await?));
}
Ok(self.cache.get(&id).unwrap().clone())
} }
} }

View file

@ -1,12 +1,14 @@
use super::{ use super::{
cache::get_beatmap, cache::get_beatmap,
db::{OsuSavedUsers, OsuUserBests}, db::{OsuSavedUsers, OsuUserBests},
ModeArg, ModeArg, OsuClient,
};
use crate::{
models::{Mode, Score},
request::UserID,
}; };
use crate::models::{Mode, Score};
use serenity::{ use serenity::{
builder::EditMessage, framework::standard::{macros::command, Args, CommandResult},
framework::standard::{macros::command, Args, CommandError as Error, CommandResult},
model::channel::Message, model::channel::Message,
utils::MessageBuilder, utils::MessageBuilder,
}; };
@ -17,15 +19,15 @@ use youmubot_prelude::*;
#[usage = "[mode (Std, Taiko, Catch, Mania) = Std]"] #[usage = "[mode (Std, Taiko, Catch, Mania) = Std]"]
#[max_args(1)] #[max_args(1)]
#[only_in(guilds)] #[only_in(guilds)]
pub fn server_rank(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn server_rank(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let mode = args.single::<ModeArg>().map(|v| v.0).unwrap_or(Mode::Std); let mode = args.single::<ModeArg>().map(|v| v.0).unwrap_or(Mode::Std);
let guild = m.guild_id.expect("Guild-only command"); let guild = m.guild_id.expect("Guild-only command");
let users = OsuSavedUsers::open(&*ctx.data.read()) let users = OsuSavedUsers::open(&*data).borrow()?.clone();
.borrow() let users = users
.expect("DB initialized") .into_iter()
.iter() .map(|(user_id, osu_user)| async move {
.filter_map(|(user_id, osu_user)| { guild.member(&ctx, user_id).await.ok().and_then(|member| {
guild.member(&ctx, user_id).ok().and_then(|member| {
osu_user osu_user
.pp .pp
.get(mode as usize) .get(mode as usize)
@ -34,7 +36,10 @@ pub fn server_rank(ctx: &mut Context, m: &Message, mut args: Args) -> CommandRes
.map(|pp| (pp, member.distinct(), osu_user.last_update.clone())) .map(|pp| (pp, member.distinct(), osu_user.last_update.clone()))
}) })
}) })
.collect::<Vec<_>>(); .collect::<stream::FuturesUnordered<_>>()
.filter_map(|v| future::ready(v))
.collect::<Vec<_>>()
.await;
let last_update = users.iter().map(|(_, _, a)| a).min().cloned(); let last_update = users.iter().map(|(_, _, a)| a).min().cloned();
let mut users = users let mut users = users
.into_iter() .into_iter()
@ -43,47 +48,55 @@ pub fn server_rank(ctx: &mut Context, m: &Message, mut args: Args) -> CommandRes
users.sort_by(|(a, _), (b, _)| (*b).partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); users.sort_by(|(a, _), (b, _)| (*b).partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
if users.is_empty() { if users.is_empty() {
m.reply(&ctx, "No saved users in the current server...")?; m.reply(&ctx, "No saved users in the current server...")
.await?;
return Ok(()); return Ok(());
} }
let users = std::sync::Arc::new(users);
let last_update = last_update.unwrap(); let last_update = last_update.unwrap();
const ITEMS_PER_PAGE: usize = 10; paginate(
ctx.data.get_cloned::<ReactionWatcher>().paginate_fn( move |page: u8, ctx: &Context, m: &mut Message| {
ctx.clone(), const ITEMS_PER_PAGE: usize = 10;
m.channel_id, let users = users.clone();
move |page: u8, e: &mut EditMessage| { Box::pin(async move {
let start = (page as usize) * ITEMS_PER_PAGE; let start = (page as usize) * ITEMS_PER_PAGE;
let end = (start + ITEMS_PER_PAGE).min(users.len()); let end = (start + ITEMS_PER_PAGE).min(users.len());
if start >= end { if start >= end {
return (e, Err(Error("No more items".to_owned()))); return Ok(false);
} }
let total_len = users.len(); let total_len = users.len();
let users = &users[start..end]; let users = &users[start..end];
let username_len = users.iter().map(|(_, u)| u.len()).max().unwrap().max(8); let username_len = users.iter().map(|(_, u)| u.len()).max().unwrap_or(8).max(8);
let mut content = MessageBuilder::new(); let mut content = MessageBuilder::new();
content
.push_line("```")
.push_line("Rank | pp | Username")
.push_line(format!("-----------------{:-<uw$}", "", uw = username_len));
for (id, (pp, member)) in users.iter().enumerate() {
content content
.push(format!( .push_line("```")
"{:>4} | {:>7.2} | ", .push_line("Rank | pp | Username")
format!("#{}", 1 + id + start), .push_line(format!("-----------------{:-<uw$}", "", uw = username_len));
pp for (id, (pp, member)) in users.iter().enumerate() {
)) content
.push_line_safe(member); .push(format!(
} "{:>4} | {:>7.2} | ",
content.push_line("```").push_line(format!( format!("#{}", 1 + id + start),
"Page **{}**/**{}**. Last updated: `{}`", pp
page + 1, ))
(total_len + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE, .push_line_safe(member);
last_update.to_rfc2822() }
)); content.push_line("```").push_line(format!(
(e.content(content.build()), Ok(())) "Page **{}**/**{}**. Last updated: `{}`",
page + 1,
(total_len + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE,
last_update.to_rfc2822()
));
m.edit(ctx, |f| f.content(content.to_string())).await?;
Ok(true)
})
}, },
ctx,
m.channel_id,
std::time::Duration::from_secs(60), std::time::Duration::from_secs(60),
)?; )
.await?;
Ok(()) Ok(())
} }
@ -93,48 +106,79 @@ pub fn server_rank(ctx: &mut Context, m: &Message, mut args: Args) -> CommandRes
#[description = "See the server's ranks on the last seen beatmap"] #[description = "See the server's ranks on the last seen beatmap"]
#[max_args(0)] #[max_args(0)]
#[only_in(guilds)] #[only_in(guilds)]
pub fn leaderboard(ctx: &mut Context, m: &Message, mut _args: Args) -> CommandResult { pub async fn leaderboard(ctx: &Context, m: &Message, mut _args: Args) -> CommandResult {
let bm = match get_beatmap(&*ctx.data.read(), m.channel_id)? { let data = ctx.data.read().await;
let mut osu_user_bests = OsuUserBests::open(&*data);
let bm = match get_beatmap(&*data, m.channel_id)? {
Some(bm) => bm, Some(bm) => bm,
None => { None => {
m.reply(&ctx, "No beatmap queried on this channel.")?; m.reply(&ctx, "No beatmap queried on this channel.").await?;
return Ok(()); return Ok(());
} }
}; };
// Run a check on the user once too!
{
let osu_users = OsuSavedUsers::open(&*data);
let user = osu_users.borrow()?.get(&m.author.id).map(|v| v.id);
if let Some(id) = user {
let osu = data.get::<OsuClient>().unwrap();
if let Ok(scores) = osu
.scores(bm.0.beatmap_id, |f| f.user(UserID::ID(id)))
.await
{
if !scores.is_empty() {
osu_user_bests
.borrow_mut()?
.entry((bm.0.beatmap_id, bm.1))
.or_default()
.insert(m.author.id, scores);
}
}
}
}
let guild = m.guild_id.expect("Guild-only command"); let guild = m.guild_id.expect("Guild-only command");
let scores = { let scores = {
let users = OsuUserBests::open(&*ctx.data.read()); const NO_SCORES: &'static str =
let users = users.borrow()?; "No scores have been recorded for this beatmap. Run `osu check` to scan for yours!";
let users = match users.get(&(bm.0.beatmap_id, bm.1)) {
let users = osu_user_bests
.borrow()?
.get(&(bm.0.beatmap_id, bm.1))
.cloned();
let users = match users {
None => { None => {
m.reply( m.reply(&ctx, NO_SCORES).await?;
&ctx,
"No scores have been recorded for this beatmap. Run `osu check` to scan for yours!",
)?;
return Ok(()); return Ok(());
} }
Some(v) if v.is_empty() => { Some(v) if v.is_empty() => {
m.reply( m.reply(&ctx, NO_SCORES).await?;
&ctx,
"No scores have been recorded for this beatmap. Run `osu check` to scan for yours!",
)?;
return Ok(()); return Ok(());
} }
Some(v) => v, Some(v) => v,
}; };
let mut scores: Vec<(f64, String, Score)> = users let mut scores: Vec<(f64, String, Score)> = users
.iter() .into_iter()
.filter_map(|(user_id, scores)| { .map(|(user_id, scores)| async move {
guild guild
.member(&ctx, user_id) .member(&ctx, user_id)
.await
.ok() .ok()
.and_then(|m| Some((m.distinct(), scores))) .and_then(|m| Some((m.distinct(), scores)))
}) })
.flat_map(|(user, scores)| scores.into_iter().map(move |v| (user.clone(), v.clone()))) .collect::<stream::FuturesUnordered<_>>()
.filter_map(|(user, score)| score.pp.map(|v| (v, user, score))) .filter_map(|v| future::ready(v))
.collect::<Vec<_>>(); .flat_map(|(user, scores)| {
scores
.into_iter()
.map(move |v| future::ready((user.clone(), v.clone())))
.collect::<stream::FuturesUnordered<_>>()
})
.filter_map(|(user, score)| future::ready(score.pp.map(|v| (v, user, score))))
.collect::<Vec<_>>()
.await;
scores scores
.sort_by(|(a, _, _), (b, _, _)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); .sort_by(|(a, _, _), (b, _, _)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
scores scores
@ -144,115 +188,121 @@ pub fn leaderboard(ctx: &mut Context, m: &Message, mut _args: Args) -> CommandRe
m.reply( m.reply(
&ctx, &ctx,
"No scores have been recorded for this beatmap. Run `osu check` to scan for yours!", "No scores have been recorded for this beatmap. Run `osu check` to scan for yours!",
)?; )
.await?;
return Ok(()); return Ok(());
} }
ctx.data.get_cloned::<ReactionWatcher>().paginate_fn( paginate(
ctx.clone(), move |page: u8, ctx: &Context, m: &mut Message| {
m.channel_id,
move |page: u8, e: &mut EditMessage| {
const ITEMS_PER_PAGE: usize = 5; const ITEMS_PER_PAGE: usize = 5;
let start = (page as usize) * ITEMS_PER_PAGE; let start = (page as usize) * ITEMS_PER_PAGE;
let end = (start + ITEMS_PER_PAGE).min(scores.len()); let end = (start + ITEMS_PER_PAGE).min(scores.len());
if start >= end { if start >= end {
return (e, Err(Error("No more items".to_owned()))); return Box::pin(future::ready(Ok(false)));
} }
let total_len = scores.len(); let total_len = scores.len();
let scores = &scores[start..end]; let scores = (&scores[start..end]).iter().cloned().collect::<Vec<_>>();
// username width let bm = (bm.0.clone(), bm.1.clone());
let uw = scores Box::pin(async move {
.iter() // username width
.map(|(_, u, _)| u.len()) let uw = scores
.max() .iter()
.unwrap_or(8) .map(|(_, u, _)| u.len())
.max(8); .max()
let accuracies = scores .unwrap_or(8)
.iter() .max(8);
.map(|(_, _, v)| format!("{:.2}%", v.accuracy(bm.1))) let accuracies = scores
.collect::<Vec<_>>(); .iter()
let aw = accuracies.iter().map(|v| v.len()).max().unwrap().max(3); .map(|(_, _, v)| format!("{:.2}%", v.accuracy(bm.1)))
let misses = scores .collect::<Vec<_>>();
.iter() let aw = accuracies.iter().map(|v| v.len()).max().unwrap().max(3);
.map(|(_, _, v)| format!("{}", v.count_miss)) let misses = scores
.collect::<Vec<_>>(); .iter()
let mw = misses.iter().map(|v| v.len()).max().unwrap().max(4); .map(|(_, _, v)| format!("{}", v.count_miss))
let ranks = scores .collect::<Vec<_>>();
.iter() let mw = misses.iter().map(|v| v.len()).max().unwrap().max(4);
.map(|(_, _, v)| v.rank.to_string()) let ranks = scores
.collect::<Vec<_>>(); .iter()
let rw = ranks.iter().map(|v| v.len()).max().unwrap().max(4); .map(|(_, _, v)| v.rank.to_string())
let pp = scores .collect::<Vec<_>>();
.iter() let rw = ranks.iter().map(|v| v.len()).max().unwrap().max(4);
.map(|(pp, _, _)| format!("{:.2}", pp)) let pp = scores
.collect::<Vec<_>>(); .iter()
let pw = pp.iter().map(|v| v.len()).max().unwrap_or(2); .map(|(pp, _, _)| format!("{:.2}", pp))
/*mods width*/ .collect::<Vec<_>>();
let mdw = scores let pw = pp.iter().map(|v| v.len()).max().unwrap_or(2);
.iter() /*mods width*/
.map(|(_, _, v)| v.mods.to_string().len()) let mdw = scores
.max() .iter()
.unwrap() .map(|(_, _, v)| v.mods.to_string().len())
.max(4); .max()
let mut content = MessageBuilder::new(); .unwrap()
content .max(4);
.push_line("```") let mut content = MessageBuilder::new();
.push_line(format!( content
"rank | {:>pw$} | {:mdw$} | {:rw$} | {:>aw$} | {:mw$} | {:uw$}", .push_line("```")
"pp", .push_line(format!(
"mods", "rank | {:>pw$} | {:mdw$} | {:rw$} | {:>aw$} | {:mw$} | {:uw$}",
"rank", "pp",
"acc", "mods",
"miss", "rank",
"user", "acc",
pw = pw, "miss",
mdw = mdw, "user",
rw = rw, pw = pw,
aw = aw, mdw = mdw,
mw = mw, rw = rw,
uw = uw, aw = aw,
)) mw = mw,
.push_line(format!( uw = uw,
"-------{:-<pw$}---{:-<mdw$}---{:-<rw$}---{:-<aw$}---{:-<mw$}---{:-<uw$}", ))
"", .push_line(format!(
"", "-------{:-<pw$}---{:-<mdw$}---{:-<rw$}---{:-<aw$}---{:-<mw$}---{:-<uw$}",
"", "",
"", "",
"", "",
"", "",
pw = pw, "",
mdw = mdw, "",
rw = rw, pw = pw,
aw = aw, mdw = mdw,
mw = mw, rw = rw,
uw = uw, aw = aw,
mw = mw,
uw = uw,
));
for (id, (_, member, p)) in scores.iter().enumerate() {
content.push_line_safe(format!(
"{:>4} | {:>pw$} | {:>mdw$} | {:>rw$} | {:>aw$} | {:>mw$} | {:uw$}",
format!("#{}", 1 + id + start),
pp[id],
p.mods.to_string(),
ranks[id],
accuracies[id],
misses[id],
member,
pw = pw,
mdw = mdw,
rw = rw,
aw = aw,
mw = mw,
uw = uw,
));
}
content.push_line("```").push_line(format!(
"Page **{}**/**{}**. Not seeing your scores? Run `osu check` to update.",
page + 1,
(total_len + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE,
)); ));
for (id, (_, member, p)) in scores.iter().enumerate() { m.edit(&ctx, |f| f.content(content.build())).await?;
content.push_line_safe(format!( Ok(true)
"{:>4} | {:>pw$} | {:>mdw$} | {:>rw$} | {:>aw$} | {:>mw$} | {:uw$}", })
format!("#{}", 1 + id + start),
pp[id],
p.mods.to_string(),
ranks[id],
accuracies[id],
misses[id],
member,
pw = pw,
mdw = mdw,
rw = rw,
aw = aw,
mw = mw,
uw = uw,
));
}
content.push_line("```").push_line(format!(
"Page **{}**/**{}**. Not seeing your scores? Run `osu check` to update.",
page + 1,
(total_len + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE,
));
(e.content(content.build()), Ok(()))
}, },
ctx,
m.channel_id,
std::time::Duration::from_secs(60), std::time::Duration::from_secs(60),
)?; )
.await?;
Ok(()) Ok(())
} }

View file

@ -8,16 +8,17 @@ mod test;
use models::*; use models::*;
use request::builders::*; use request::builders::*;
use request::*; use request::*;
use reqwest::blocking::{Client as HTTPClient, RequestBuilder, Response}; use reqwest::Client as HTTPClient;
use serenity::framework::standard::CommandError as Error; use std::convert::TryInto;
use std::{convert::TryInto, sync::Arc}; use youmubot_prelude::{ratelimit::Ratelimit, *};
/// The number of requests per minute to the osu! server.
const REQUESTS_PER_MINUTE: usize = 200;
/// Client is the client that will perform calls to the osu! api server. /// Client is the client that will perform calls to the osu! api server.
/// It's cheap to clone, so do it.
#[derive(Clone, Debug)]
pub struct Client { pub struct Client {
key: Arc<String>, client: Ratelimit<HTTPClient>,
client: HTTPClient, key: String,
} }
fn vec_try_into<U, T: std::convert::TryFrom<U>>(v: Vec<U>) -> Result<Vec<T>, T::Error> { fn vec_try_into<U, T: std::convert::TryFrom<U>>(v: Vec<U>) -> Result<Vec<T>, T::Error> {
@ -32,50 +33,55 @@ fn vec_try_into<U, T: std::convert::TryFrom<U>>(v: Vec<U>) -> Result<Vec<T>, T::
impl Client { impl Client {
/// Create a new client from the given API key. /// Create a new client from the given API key.
pub fn new(http_client: HTTPClient, key: String) -> Client { pub fn new(key: String) -> Client {
Client { let client = Ratelimit::new(
key: Arc::new(key), HTTPClient::new(),
client: http_client, REQUESTS_PER_MINUTE,
} std::time::Duration::from_secs(60),
);
Client { key, client }
} }
fn build_request(&self, r: RequestBuilder) -> Result<Response, Error> { pub(crate) async fn build_request(&self, url: &str) -> Result<reqwest::RequestBuilder> {
let v = r.query(&[("k", &*self.key)]).build()?; Ok(self
// dbg!(v.url()); .client
Ok(self.client.execute(v)?) .borrow()
.await?
.get(url)
.query(&[("k", &*self.key)]))
} }
pub fn beatmaps( pub async fn beatmaps(
&self, &self,
kind: BeatmapRequestKind, kind: BeatmapRequestKind,
f: impl FnOnce(&mut BeatmapRequestBuilder) -> &mut BeatmapRequestBuilder, f: impl FnOnce(&mut BeatmapRequestBuilder) -> &mut BeatmapRequestBuilder,
) -> Result<Vec<Beatmap>, Error> { ) -> Result<Vec<Beatmap>> {
let mut r = BeatmapRequestBuilder::new(kind); let mut r = BeatmapRequestBuilder::new(kind);
f(&mut r); f(&mut r);
let res: Vec<raw::Beatmap> = self.build_request(r.build(&self.client))?.json()?; let res: Vec<raw::Beatmap> = r.build(&self).await?.json().await?;
Ok(vec_try_into(res)?) Ok(vec_try_into(res)?)
} }
pub fn user( pub async fn user(
&self, &self,
user: UserID, user: UserID,
f: impl FnOnce(&mut UserRequestBuilder) -> &mut UserRequestBuilder, f: impl FnOnce(&mut UserRequestBuilder) -> &mut UserRequestBuilder,
) -> Result<Option<User>, Error> { ) -> Result<Option<User>, Error> {
let mut r = UserRequestBuilder::new(user); let mut r = UserRequestBuilder::new(user);
f(&mut r); f(&mut r);
let res: Vec<raw::User> = self.build_request(r.build(&self.client))?.json()?; let res: Vec<raw::User> = r.build(&self).await?.json().await?;
let res = vec_try_into(res)?; let res = vec_try_into(res)?;
Ok(res.into_iter().next()) Ok(res.into_iter().next())
} }
pub fn scores( pub async fn scores(
&self, &self,
beatmap_id: u64, beatmap_id: u64,
f: impl FnOnce(&mut ScoreRequestBuilder) -> &mut ScoreRequestBuilder, f: impl FnOnce(&mut ScoreRequestBuilder) -> &mut ScoreRequestBuilder,
) -> Result<Vec<Score>, Error> { ) -> Result<Vec<Score>, Error> {
let mut r = ScoreRequestBuilder::new(beatmap_id); let mut r = ScoreRequestBuilder::new(beatmap_id);
f(&mut r); f(&mut r);
let res: Vec<raw::Score> = self.build_request(r.build(&self.client))?.json()?; let res: Vec<raw::Score> = r.build(&self).await?.json().await?;
let mut res: Vec<Score> = vec_try_into(res)?; let mut res: Vec<Score> = vec_try_into(res)?;
// with a scores request you need to fill the beatmap ids yourself // with a scores request you need to fill the beatmap ids yourself
@ -85,23 +91,23 @@ impl Client {
Ok(res) Ok(res)
} }
pub fn user_best( pub async fn user_best(
&self, &self,
user: UserID, user: UserID,
f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder, f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder,
) -> Result<Vec<Score>, Error> { ) -> Result<Vec<Score>, Error> {
self.user_scores(UserScoreType::Best, user, f) self.user_scores(UserScoreType::Best, user, f).await
} }
pub fn user_recent( pub async fn user_recent(
&self, &self,
user: UserID, user: UserID,
f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder, f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder,
) -> Result<Vec<Score>, Error> { ) -> Result<Vec<Score>, Error> {
self.user_scores(UserScoreType::Recent, user, f) self.user_scores(UserScoreType::Recent, user, f).await
} }
fn user_scores( async fn user_scores(
&self, &self,
u: UserScoreType, u: UserScoreType,
user: UserID, user: UserID,
@ -109,7 +115,7 @@ impl Client {
) -> Result<Vec<Score>, Error> { ) -> Result<Vec<Score>, Error> {
let mut r = UserScoreRequestBuilder::new(u, user); let mut r = UserScoreRequestBuilder::new(u, user);
f(&mut r); f(&mut r);
let res: Vec<raw::Score> = self.build_request(r.build(&self.client))?.json()?; let res: Vec<raw::Score> = r.build(&self).await?.json().await?;
let res = vec_try_into(res)?; let res = vec_try_into(res)?;
Ok(res) Ok(res)
} }

View file

@ -1,6 +1,7 @@
use crate::models::{Mode, Mods}; use crate::models::{Mode, Mods};
use crate::Client;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use reqwest::blocking::{Client, RequestBuilder}; use youmubot_prelude::*;
trait ToQuery { trait ToQuery {
fn to_query(&self) -> Vec<(&'static str, String)>; fn to_query(&self) -> Vec<(&'static str, String)>;
@ -84,6 +85,8 @@ impl ToQuery for BeatmapRequestKind {
} }
pub mod builders { pub mod builders {
use reqwest::Response;
use super::*; use super::*;
/// A builder for a Beatmap request. /// A builder for a Beatmap request.
pub struct BeatmapRequestBuilder { pub struct BeatmapRequestBuilder {
@ -110,12 +113,15 @@ pub mod builders {
self self
} }
pub(crate) fn build(self, client: &Client) -> RequestBuilder { pub(crate) async fn build(self, client: &Client) -> Result<Response> {
client Ok(client
.get("https://osu.ppy.sh/api/get_beatmaps") .build_request("https://osu.ppy.sh/api/get_beatmaps")
.await?
.query(&self.kind.to_query()) .query(&self.kind.to_query())
.query(&self.since.map(|v| ("since", v)).to_query()) .query(&self.since.map(|v| ("since", v)).to_query())
.query(&self.mode.to_query()) .query(&self.mode.to_query())
.send()
.await?)
} }
} }
@ -144,9 +150,10 @@ pub mod builders {
self self
} }
pub(crate) fn build(&self, client: &Client) -> RequestBuilder { pub(crate) async fn build(&self, client: &Client) -> Result<Response> {
client Ok(client
.get("https://osu.ppy.sh/api/get_user") .build_request("https://osu.ppy.sh/api/get_user")
.await?
.query(&self.user.to_query()) .query(&self.user.to_query())
.query(&self.mode.to_query()) .query(&self.mode.to_query())
.query( .query(
@ -155,6 +162,8 @@ pub mod builders {
.map(|v| ("event_days", v.to_string())) .map(|v| ("event_days", v.to_string()))
.to_query(), .to_query(),
) )
.send()
.await?)
} }
} }
@ -197,14 +206,17 @@ pub mod builders {
self self
} }
pub(crate) fn build(&self, client: &Client) -> RequestBuilder { pub(crate) async fn build(&self, client: &Client) -> Result<Response> {
client Ok(client
.get("https://osu.ppy.sh/api/get_scores") .build_request("https://osu.ppy.sh/api/get_scores")
.await?
.query(&[("b", self.beatmap_id)]) .query(&[("b", self.beatmap_id)])
.query(&self.user.to_query()) .query(&self.user.to_query())
.query(&self.mode.to_query()) .query(&self.mode.to_query())
.query(&self.mods.to_query()) .query(&self.mods.to_query())
.query(&self.limit.map(|v| ("limit", v.to_string())).to_query()) .query(&self.limit.map(|v| ("limit", v.to_string())).to_query())
.send()
.await?)
} }
} }
@ -240,15 +252,18 @@ pub mod builders {
self self
} }
pub(crate) fn build(&self, client: &Client) -> RequestBuilder { pub(crate) async fn build(&self, client: &Client) -> Result<Response> {
client Ok(client
.get(match self.score_type { .build_request(match self.score_type {
UserScoreType::Best => "https://osu.ppy.sh/api/get_user_best", UserScoreType::Best => "https://osu.ppy.sh/api/get_user_best",
UserScoreType::Recent => "https://osu.ppy.sh/api/get_user_recent", UserScoreType::Recent => "https://osu.ppy.sh/api/get_user_recent",
}) })
.await?
.query(&self.user.to_query()) .query(&self.user.to_query())
.query(&self.mode.to_query()) .query(&self.mode.to_query())
.query(&self.limit.map(|v| ("limit", v.to_string())).to_query()) .query(&self.limit.map(|v| ("limit", v.to_string())).to_query())
.send()
.await?)
} }
} }
} }

View file

@ -7,9 +7,16 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
serenity = "0.8" anyhow = "1.0"
async-trait = "0.1"
futures-util = "0.3"
tokio = { version = "0.2", features = ["time"] }
youmubot-db = { path = "../youmubot-db" } youmubot-db = { path = "../youmubot-db" }
crossbeam-channel = "0.4"
reqwest = "0.10" reqwest = "0.10"
rayon = "1"
chrono = "0.4" chrono = "0.4"
flume = "0.9"
[dependencies.serenity]
version = "0.9.0-rc.0"
default-features = true
features = ["collector"]

View file

@ -1,10 +1,13 @@
use crate::{AppData, GetCloned}; use crate::{AppData, Result};
use crossbeam_channel::after; use async_trait::async_trait;
use rayon::prelude::*; use futures_util::{
future::{join_all, ready, FutureExt},
stream::{FuturesUnordered, StreamExt},
};
use serenity::{ use serenity::{
framework::standard::{ framework::standard::{
macros::{command, group}, macros::{command, group},
Args, CommandError as Error, CommandResult, Args, CommandResult,
}, },
http::CacheHttp, http::CacheHttp,
model::{ model::{
@ -15,11 +18,7 @@ use serenity::{
utils::MessageBuilder, utils::MessageBuilder,
CacheAndHttp, CacheAndHttp,
}; };
use std::{ use std::{collections::HashMap, sync::Arc};
collections::HashMap,
sync::Arc,
thread::{spawn, JoinHandle},
};
use youmubot_db::DB; use youmubot_db::DB;
/// A list of assigned channels for an announcer. /// A list of assigned channels for an announcer.
@ -33,30 +32,17 @@ pub(crate) type AnnouncerChannels = DB<HashMap<String, HashMap<GuildId, ChannelI
/// - An AppData, which can be used for interacting with internal databases. /// - An AppData, which can be used for interacting with internal databases.
/// - A function "channels", which takes an UserId and returns the list of ChannelIds, which any update related to that user should be /// - A function "channels", which takes an UserId and returns the list of ChannelIds, which any update related to that user should be
/// sent to. /// sent to.
#[async_trait]
pub trait Announcer: Send { pub trait Announcer: Send {
/// Look for updates and send them to respective channels. /// Look for updates and send them to respective channels.
/// ///
/// Errors returned from this function gets ignored and logged down. /// Errors returned from this function gets ignored and logged down.
fn updates( async fn updates(
&mut self, &mut self,
c: Arc<CacheAndHttp>, c: Arc<CacheAndHttp>,
d: AppData, d: AppData,
channels: MemberToChannels, channels: MemberToChannels,
) -> CommandResult; ) -> Result<()>;
}
impl<T> Announcer for T
where
T: FnMut(Arc<CacheAndHttp>, AppData, MemberToChannels) -> CommandResult + Send,
{
fn updates(
&mut self,
c: Arc<CacheAndHttp>,
d: AppData,
channels: MemberToChannels,
) -> CommandResult {
self(c, d, channels)
}
} }
/// A simple struct that allows looking up the relevant channels to an user. /// A simple struct that allows looking up the relevant channels to an user.
@ -64,18 +50,24 @@ pub struct MemberToChannels(Vec<(GuildId, ChannelId)>);
impl MemberToChannels { impl MemberToChannels {
/// Gets the channel list of an user related to that channel. /// Gets the channel list of an user related to that channel.
pub fn channels_of( pub async fn channels_of(
&self, &self,
http: impl CacheHttp + Clone + Sync, http: impl CacheHttp + Clone + Sync,
u: impl Into<UserId>, u: impl Into<UserId>,
) -> Vec<ChannelId> { ) -> Vec<ChannelId> {
let u = u.into(); let u: UserId = u.into();
self.0 self.0
.par_iter() .clone()
.filter_map(|(guild, channel)| { .into_iter()
guild.member(http.clone(), u).ok().map(|_| channel.clone()) .map(|(guild, channel): (GuildId, ChannelId)| {
guild
.member(http.clone(), u)
.map(move |v| v.ok().map(|_| channel.clone()))
}) })
.collect::<Vec<_>>() .collect::<FuturesUnordered<_>>()
.filter_map(|v| ready(v))
.collect()
.await
} }
} }
@ -85,7 +77,7 @@ impl MemberToChannels {
pub struct AnnouncerHandler { pub struct AnnouncerHandler {
cache_http: Arc<CacheAndHttp>, cache_http: Arc<CacheAndHttp>,
data: AppData, data: AppData,
announcers: HashMap<&'static str, Box<dyn Announcer>>, announcers: HashMap<&'static str, RwLock<Box<dyn Announcer + Send + Sync>>>,
} }
// Querying for the AnnouncerHandler in the internal data returns a vec of keys. // Querying for the AnnouncerHandler in the internal data returns a vec of keys.
@ -107,8 +99,15 @@ impl AnnouncerHandler {
/// Insert a new announcer into the handler. /// Insert a new announcer into the handler.
/// ///
/// The handler must take an unique key. If a duplicate is found, this method panics. /// The handler must take an unique key. If a duplicate is found, this method panics.
pub fn add(&mut self, key: &'static str, announcer: impl Announcer + 'static) -> &mut Self { pub fn add(
if let Some(_) = self.announcers.insert(key, Box::new(announcer)) { &mut self,
key: &'static str,
announcer: impl Announcer + Send + Sync + 'static,
) -> &mut Self {
if let Some(_) = self
.announcers
.insert(key, RwLock::new(Box::new(announcer)))
{
panic!( panic!(
"Announcer keys must be unique: another announcer with key `{}` was found", "Announcer keys must be unique: another announcer with key `{}` was found",
key key
@ -122,9 +121,8 @@ impl AnnouncerHandler {
/// Execution-related. /// Execution-related.
impl AnnouncerHandler { impl AnnouncerHandler {
/// Collect the list of guilds and their respective channels, by the key of the announcer. /// Collect the list of guilds and their respective channels, by the key of the announcer.
fn get_guilds(&self, key: &'static str) -> Result<Vec<(GuildId, ChannelId)>, Error> { async fn get_guilds(data: &AppData, key: &'static str) -> Result<Vec<(GuildId, ChannelId)>> {
let d = &self.data; let data = AnnouncerChannels::open(&*data.read().await)
let data = AnnouncerChannels::open(&*d.read())
.borrow()? .borrow()?
.get(key) .get(key)
.map(|m| m.iter().map(|(a, b)| (*a, *b)).collect()) .map(|m| m.iter().map(|(a, b)| (*a, *b)).collect())
@ -133,48 +131,55 @@ impl AnnouncerHandler {
} }
/// Run the announcing sequence on a certain announcer. /// Run the announcing sequence on a certain announcer.
fn announce(&mut self, key: &'static str) -> CommandResult { async fn announce(
let guilds: Vec<_> = self.get_guilds(key)?; data: AppData,
let channels = MemberToChannels(guilds); cache_http: Arc<CacheAndHttp>,
let cache_http = self.cache_http.clone(); key: &'static str,
let data = self.data.clone(); announcer: &'_ RwLock<Box<dyn Announcer + Send + Sync>>,
let announcer = self ) -> Result<()> {
.announcers let channels = MemberToChannels(Self::get_guilds(&data, key).await?);
.get_mut(&key) announcer
.expect("Key is from announcers"); .write()
announcer.updates(cache_http, data, channels)?; .await
Ok(()) .updates(cache_http, data, channels)
.await
} }
/// Start the AnnouncerHandler, moving it into another thread. /// Start the AnnouncerHandler, looping forever.
/// ///
/// It will run all the announcers in sequence every *cooldown* seconds. /// It will run all the announcers in sequence every *cooldown* seconds.
pub fn scan(mut self, cooldown: std::time::Duration) -> JoinHandle<()> { pub async fn scan(self, cooldown: std::time::Duration) -> () {
// First we store all the keys inside the database. // First we store all the keys inside the database.
let keys = self.announcers.keys().cloned().collect::<Vec<_>>(); let keys = self.announcers.keys().cloned().collect::<Vec<_>>();
self.data.write().insert::<Self>(keys.clone()); self.data.write().await.insert::<Self>(keys.clone());
spawn(move || loop { loop {
eprintln!("{}: announcer started scanning", chrono::Utc::now()); eprintln!("{}: announcer started scanning", chrono::Utc::now());
let after_timer = after(cooldown); // let after_timer = after(cooldown);
for key in &keys { let after = tokio::time::delay_for(cooldown);
join_all(self.announcers.iter().map(|(key, announcer)| {
eprintln!(" - scanning key `{}`", key); eprintln!(" - scanning key `{}`", key);
if let Err(e) = self.announce(key) { Self::announce(self.data.clone(), self.cache_http.clone(), *key, announcer).map(
dbg!(e); move |v| {
} if let Err(e) = v {
} eprintln!(" - key `{}`: {:?}", *key, e)
}
},
)
}))
.await;
eprintln!("{}: announcer finished scanning", chrono::Utc::now()); eprintln!("{}: announcer finished scanning", chrono::Utc::now());
after_timer.recv().ok(); after.await;
}) }
} }
} }
/// Gets the announcer of the given guild. /// Gets the announcer of the given guild.
pub fn announcer_of( pub async fn announcer_of(
ctx: &Context, ctx: &Context,
key: &'static str, key: &'static str,
guild: GuildId, guild: GuildId,
) -> Result<Option<ChannelId>, Error> { ) -> Result<Option<ChannelId>> {
Ok(AnnouncerChannels::open(&*ctx.data.read()) Ok(AnnouncerChannels::open(&*ctx.data.read().await)
.borrow()? .borrow()?
.get(key) .get(key)
.and_then(|channels| channels.get(&guild).cloned())) .and_then(|channels| channels.get(&guild).cloned()))
@ -184,20 +189,20 @@ pub fn announcer_of(
#[description = "List the registered announcers of this server"] #[description = "List the registered announcers of this server"]
#[num_args(0)] #[num_args(0)]
#[only_in(guilds)] #[only_in(guilds)]
pub fn list_announcers(ctx: &mut Context, m: &Message, _: Args) -> CommandResult { pub async fn list_announcers(ctx: &Context, m: &Message, _: Args) -> CommandResult {
let guild_id = m.guild_id.unwrap(); let guild_id = m.guild_id.unwrap();
let announcers = AnnouncerChannels::open(&*ctx.data.read()); let data = &*ctx.data.read().await;
let announcers = announcers.borrow()?; let announcers = AnnouncerChannels::open(data);
let channels = data.get::<AnnouncerHandler>().unwrap();
let channels = ctx let channels = channels
.data .iter()
.get_cloned::<AnnouncerHandler>() .filter_map(|&key| {
.into_iter() announcers.borrow().ok().and_then(|announcers| {
.filter_map(|key| { announcers
announcers .get(key)
.get(key) .and_then(|channels| channels.get(&guild_id))
.and_then(|channels| channels.get(&guild_id)) .map(|&ch| (key, ch))
.map(|&ch| (key, ch)) })
}) })
.map(|(key, ch)| format!(" - `{}`: activated on channel {}", key, ch.mention())) .map(|(key, ch)| format!(" - `{}`: activated on channel {}", key, ch.mention()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -208,7 +213,8 @@ pub fn list_announcers(ctx: &mut Context, m: &Message, _: Args) -> CommandResult
"Activated announcers on this server:\n{}", "Activated announcers on this server:\n{}",
channels.join("\n") channels.join("\n")
), ),
)?; )
.await?;
Ok(()) Ok(())
} }
@ -219,23 +225,24 @@ pub fn list_announcers(ctx: &mut Context, m: &Message, _: Args) -> CommandResult
#[required_permissions(MANAGE_CHANNELS)] #[required_permissions(MANAGE_CHANNELS)]
#[only_in(guilds)] #[only_in(guilds)]
#[num_args(1)] #[num_args(1)]
pub fn register_announcer(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let key = args.single::<String>()?; let key = args.single::<String>()?;
let keys = ctx.data.get_cloned::<AnnouncerHandler>(); let keys = data.get::<AnnouncerHandler>().unwrap();
if !keys.contains(&key.as_str()) { if !keys.contains(&&key[..]) {
m.reply( m.reply(
&ctx, &ctx,
format!( format!(
"Key not found. Available announcer keys are: `{}`", "Key not found. Available announcer keys are: `{}`",
keys.join(", ") keys.join(", ")
), ),
)?; )
.await?;
return Ok(()); return Ok(());
} }
let guild = m.guild(&ctx).expect("Guild-only command"); let guild = m.guild(&ctx).await.expect("Guild-only command");
let guild = guild.read(); let channel = m.channel_id.to_channel(&ctx).await?;
let channel = m.channel_id.to_channel(&ctx)?; AnnouncerChannels::open(&*data)
AnnouncerChannels::open(&*ctx.data.read())
.borrow_mut()? .borrow_mut()?
.entry(key.clone()) .entry(key.clone())
.or_default() .or_default()
@ -250,7 +257,8 @@ pub fn register_announcer(ctx: &mut Context, m: &Message, mut args: Args) -> Com
.push(" on channel ") .push(" on channel ")
.push_bold_safe(channel) .push_bold_safe(channel)
.build(), .build(),
)?; )
.await?;
Ok(()) Ok(())
} }
@ -260,9 +268,10 @@ pub fn register_announcer(ctx: &mut Context, m: &Message, mut args: Args) -> Com
#[required_permissions(MANAGE_CHANNELS)] #[required_permissions(MANAGE_CHANNELS)]
#[only_in(guilds)] #[only_in(guilds)]
#[num_args(1)] #[num_args(1)]
pub fn remove_announcer(ctx: &mut Context, m: &Message, mut args: Args) -> CommandResult { pub async fn remove_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let key = args.single::<String>()?; let key = args.single::<String>()?;
let keys = ctx.data.get_cloned::<AnnouncerHandler>(); let keys = data.get::<AnnouncerHandler>().unwrap();
if !keys.contains(&key.as_str()) { if !keys.contains(&key.as_str()) {
m.reply( m.reply(
&ctx, &ctx,
@ -270,12 +279,12 @@ pub fn remove_announcer(ctx: &mut Context, m: &Message, mut args: Args) -> Comma
"Key not found. Available announcer keys are: `{}`", "Key not found. Available announcer keys are: `{}`",
keys.join(", ") keys.join(", ")
), ),
)?; )
.await?;
return Ok(()); return Ok(());
} }
let guild = m.guild(&ctx).expect("Guild-only command"); let guild = m.guild(&ctx).await.expect("Guild-only command");
let guild = guild.read(); AnnouncerChannels::open(&*data)
AnnouncerChannels::open(&*ctx.data.read())
.borrow_mut()? .borrow_mut()?
.entry(key.clone()) .entry(key.clone())
.and_modify(|m| { .and_modify(|m| {
@ -289,7 +298,8 @@ pub fn remove_announcer(ctx: &mut Context, m: &Message, mut args: Args) -> Comma
.push(" has been de-activated for server ") .push(" has been de-activated for server ")
.push_bold_safe(&guild.name) .push_bold_safe(&guild.name)
.build(), .build(),
)?; )
.await?;
Ok(()) Ok(())
} }

View file

@ -2,14 +2,17 @@ pub use duration::Duration;
pub use username_arg::UsernameArg; pub use username_arg::UsernameArg;
mod duration { mod duration {
use crate::{Error, Result};
use std::fmt; use std::fmt;
use std::time::Duration as StdDuration; use std::time::Duration as StdDuration;
use String as Error;
// Parse a single duration unit const INVALID_DURATION: &str = "Not a valid duration";
fn parse_duration_string(s: &str) -> Result<StdDuration, Error> {
/// Parse a single duration unit
fn parse_duration_string(s: &str) -> Result<StdDuration> {
// We reject the empty case // We reject the empty case
if s == "" { if s == "" {
return Err(Error::from("empty strings are not valid durations")); return Err(Error::msg("empty strings are not valid durations"));
} }
struct ParseStep { struct ParseStep {
current_value: Option<u64>, current_value: Option<u64>,
@ -26,7 +29,7 @@ mod duration {
current_value: Some(v.unwrap_or(0) * 10 + ((item as u64) - ('0' as u64))), current_value: Some(v.unwrap_or(0) * 10 + ((item as u64) - ('0' as u64))),
..s ..s
}), }),
(_, None) => Err(Error::from("Not a valid duration")), (_, None) => Err(Error::msg(INVALID_DURATION)),
(item, Some(v)) => Ok(ParseStep { (item, Some(v)) => Ok(ParseStep {
current_value: None, current_value: None,
current_duration: s.current_duration current_duration: s.current_duration
@ -36,7 +39,7 @@ mod duration {
'h' => StdDuration::from_secs(60 * 60), 'h' => StdDuration::from_secs(60 * 60),
'd' => StdDuration::from_secs(60 * 60 * 24), 'd' => StdDuration::from_secs(60 * 60 * 24),
'w' => StdDuration::from_secs(60 * 60 * 24 * 7), 'w' => StdDuration::from_secs(60 * 60 * 24 * 7),
_ => return Err(Error::from("Not a valid duration")), _ => return Err(Error::msg(INVALID_DURATION)),
} * (v as u32), } * (v as u32),
}), }),
}, },
@ -44,7 +47,7 @@ mod duration {
.and_then(|v| match v.current_value { .and_then(|v| match v.current_value {
// All values should be consumed // All values should be consumed
None => Ok(v), None => Ok(v),
_ => Err(Error::from("Not a valid duration")), _ => Err(Error::msg(INVALID_DURATION)),
}) })
.map(|v| v.current_duration) .map(|v| v.current_duration)
} }

View file

@ -0,0 +1,24 @@
use crate::{async_trait, future, Context, Result};
use serenity::model::channel::Message;
/// Hook represents the asynchronous hook that is run on every message.
#[async_trait]
pub trait Hook: Send + Sync {
async fn call(&mut self, ctx: &Context, message: &Message) -> Result<()>;
}
#[async_trait]
impl<T> Hook for T
where
T: for<'a> FnMut(
&'a Context,
&'a Message,
)
-> std::pin::Pin<Box<dyn future::Future<Output = Result<()>> + 'a + Send>>
+ Send
+ Sync,
{
async fn call(&mut self, ctx: &Context, message: &Message) -> Result<()> {
self(ctx, message).await
}
}

View file

@ -1,54 +1,40 @@
/// Module `prelude` provides a sane set of default imports that can be used inside
/// a Youmubot source file.
pub use serenity::prelude::*; pub use serenity::prelude::*;
use std::sync::Arc; use std::sync::Arc;
pub mod announcer; pub mod announcer;
pub mod args; pub mod args;
pub mod hook;
pub mod pagination; pub mod pagination;
pub mod reaction_watch; pub mod ratelimit;
pub mod setup; pub mod setup;
pub use announcer::{Announcer, AnnouncerHandler}; pub use announcer::{Announcer, AnnouncerHandler};
pub use args::{Duration, UsernameArg}; pub use args::{Duration, UsernameArg};
pub use pagination::Pagination; pub use hook::Hook;
pub use reaction_watch::{ReactionHandler, ReactionWatcher}; pub use pagination::paginate;
/// Re-exporting async_trait helps with implementing Announcer.
pub use async_trait::async_trait;
/// Re-export the anyhow errors
pub use anyhow::{Error, Result};
/// Re-export useful future and stream utils
pub use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
/// Re-export the spawn function
pub use tokio::spawn as spawn_future;
/// The global app data. /// The global app data.
pub type AppData = Arc<RwLock<ShareMap>>; pub type AppData = Arc<RwLock<TypeMap>>;
/// The HTTP client. /// The HTTP client.
pub struct HTTPClient; pub struct HTTPClient;
impl TypeMapKey for HTTPClient { impl TypeMapKey for HTTPClient {
type Value = reqwest::blocking::Client; type Value = reqwest::Client;
}
/// The TypeMap trait that allows TypeMaps to quickly get a clonable item.
pub trait GetCloned {
/// Gets an item from the store, cloned.
fn get_cloned<T>(&self) -> T::Value
where
T: TypeMapKey,
T::Value: Clone + Send + Sync;
}
impl GetCloned for ShareMap {
fn get_cloned<T>(&self) -> T::Value
where
T: TypeMapKey,
T::Value: Clone + Send + Sync,
{
self.get::<T>().cloned().expect("Should be there")
}
}
impl GetCloned for AppData {
fn get_cloned<T>(&self) -> T::Value
where
T: TypeMapKey,
T::Value: Clone + Send + Sync,
{
self.read().get::<T>().cloned().expect("Should be there")
}
} }
pub mod prelude_commands { pub mod prelude_commands {
@ -70,8 +56,8 @@ pub mod prelude_commands {
#[command] #[command]
#[description = "pong!"] #[description = "pong!"]
fn ping(ctx: &mut Context, m: &Message) -> CommandResult { async fn ping(ctx: &Context, m: &Message) -> CommandResult {
m.reply(&ctx, "Pong!")?; m.reply(&ctx, "Pong!").await?;
Ok(()) Ok(())
} }
} }

View file

@ -1,157 +1,111 @@
use crate::{Context, ReactionHandler, ReactionWatcher}; use crate::{Context, Result};
use futures_util::{future::Future, StreamExt};
use serenity::{ use serenity::{
builder::EditMessage, collector::ReactionAction,
framework::standard::{CommandError, CommandResult},
model::{ model::{
channel::{Message, Reaction, ReactionType}, channel::{Message, ReactionType},
id::ChannelId, id::ChannelId,
}, },
}; };
use std::convert::TryFrom;
use tokio::time as tokio_time;
const ARROW_RIGHT: &'static str = "➡️"; const ARROW_RIGHT: &'static str = "➡️";
const ARROW_LEFT: &'static str = "⬅️"; const ARROW_LEFT: &'static str = "⬅️";
impl ReactionWatcher { #[async_trait::async_trait]
/// Start a pagination. pub trait Paginate {
/// async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result<bool>;
/// Takes a copy of Context (which you can `clone`), a pager (see "Pagination") and a target channel id.
/// Pagination will handle all events on adding/removing an "arrow" emoji (⬅️ and ➡️).
/// This is a blocking call - it will block the thread until duration is over.
pub fn paginate<T: Pagination + Send + 'static>(
&self,
ctx: Context,
channel: ChannelId,
pager: T,
duration: std::time::Duration,
) -> CommandResult {
let handler = PaginationHandler::new(pager, ctx, channel)?;
self.handle_reactions(handler, duration, |_| {});
Ok(())
}
/// A version of `paginate` that compiles for closures.
///
/// A workaround until https://github.com/rust-lang/rust/issues/36582 is solved.
pub fn paginate_fn<T>(
&self,
ctx: Context,
channel: ChannelId,
pager: T,
duration: std::time::Duration,
) -> CommandResult
where
T: for<'a> FnMut(u8, &'a mut EditMessage) -> (&'a mut EditMessage, CommandResult)
+ Send
+ 'static,
{
self.paginate(ctx, channel, pager, duration)
}
} }
/// Pagination allows the bot to display content in multiple pages. #[async_trait::async_trait]
/// impl<T> Paginate for T
/// You need to implement the "render_page" function, which takes a dummy content and
/// embed assigning function.
/// Pagination is automatically implemented for functions with the same signature as `render_page`.
///
/// Pages start at 0.
pub trait Pagination {
/// Render a page.
///
/// This would either create or edit a message, but you should not be worry about it.
fn render_page<'a>(
&mut self,
page: u8,
target: &'a mut EditMessage,
) -> (&'a mut EditMessage, CommandResult);
}
impl<T> Pagination for T
where where
T: for<'a> FnMut(u8, &'a mut EditMessage) -> (&'a mut EditMessage, CommandResult), T: for<'m> FnMut(
u8,
&'m Context,
&'m mut Message,
) -> std::pin::Pin<Box<dyn Future<Output = Result<bool>> + Send + 'm>>
+ Send,
{ {
fn render_page<'a>( async fn render(&mut self, page: u8, ctx: &Context, m: &mut Message) -> Result<bool> {
&mut self, self(page, ctx, m).await
page: u8,
target: &'a mut EditMessage,
) -> (&'a mut EditMessage, CommandResult) {
self(page, target)
} }
} }
struct PaginationHandler<T: Pagination> { // Paginate! with a pager function.
pager: T, /// If awaited, will block until everything is done.
message: Message, pub async fn paginate(
mut pager: impl for<'m> FnMut(
u8,
&'m Context,
&'m mut Message,
) -> std::pin::Pin<Box<dyn Future<Output = Result<bool>> + Send + 'm>>
+ Send,
ctx: &Context,
channel: ChannelId,
timeout: std::time::Duration,
) -> Result<()> {
let mut message = channel
.send_message(&ctx, |e| e.content("Youmu is loading the first page..."))
.await?;
// React to the message
message
.react(&ctx, ReactionType::try_from(ARROW_LEFT)?)
.await?;
message
.react(&ctx, ReactionType::try_from(ARROW_RIGHT)?)
.await?;
pager(0, ctx, &mut message).await?;
// Build a reaction collector
let mut reaction_collector = message.await_reactions(&ctx).removed(true).await;
let mut page = 0;
// Loop the handler function.
let res: Result<()> = loop {
match tokio_time::timeout(timeout, reaction_collector.next()).await {
Err(_) => break Ok(()),
Ok(None) => break Ok(()),
Ok(Some(reaction)) => {
page = match handle_reaction(page, &mut pager, ctx, &mut message, &reaction).await {
Ok(v) => v,
Err(e) => break Err(e),
};
}
}
};
message.react(&ctx, '🛑').await?;
res
}
// Handle the reaction and return a new page number.
async fn handle_reaction(
page: u8, page: u8,
ctx: Context, pager: &mut impl Paginate,
} ctx: &Context,
message: &mut Message,
impl<T: Pagination> PaginationHandler<T> { reaction: &ReactionAction,
pub fn new(pager: T, mut ctx: Context, channel: ChannelId) -> Result<Self, CommandError> { ) -> Result<u8> {
let message = channel.send_message(&mut ctx, |e| { let reaction = match reaction {
e.content("Youmu is loading the first page...") ReactionAction::Added(v) | ReactionAction::Removed(v) => v,
})?; };
// React to the message match &reaction.emoji {
message.react(&mut ctx, ARROW_LEFT)?; ReactionType::Unicode(ref s) => match s.as_str() {
message.react(&mut ctx, ARROW_RIGHT)?; ARROW_LEFT if page == 0 => Ok(page),
let mut p = Self { ARROW_LEFT => Ok(if pager.render(page - 1, ctx, message).await? {
pager, page - 1
message: message.clone(), } else {
page: 0, page
ctx, }),
}; ARROW_RIGHT => Ok(if pager.render(page + 1, ctx, message).await? {
p.call_pager()?; page + 1
Ok(p) } else {
} page
} }),
_ => Ok(page),
impl<T: Pagination> PaginationHandler<T> { },
/// Call the pager, log the error (if any). _ => Ok(page),
fn call_pager(&mut self) -> CommandResult {
let mut res: CommandResult = Ok(());
let mut msg = self.message.clone();
msg.edit(self.ctx.http.clone(), |e| {
let (e, r) = self.pager.render_page(self.page, e);
res = r;
e
})?;
self.message = msg;
res
}
}
impl<T: Pagination> Drop for PaginationHandler<T> {
fn drop(&mut self) {
self.message.react(&self.ctx, "🛑").ok();
}
}
impl<T: Pagination> ReactionHandler for PaginationHandler<T> {
fn handle_reaction(&mut self, reaction: &Reaction, _is_add: bool) -> CommandResult {
if reaction.message_id != self.message.id {
return Ok(());
}
match &reaction.emoji {
ReactionType::Unicode(ref s) => match s.as_str() {
ARROW_LEFT if self.page == 0 => return Ok(()),
ARROW_LEFT => {
self.page -= 1;
if let Err(e) = self.call_pager() {
self.page += 1;
return Err(e);
}
}
ARROW_RIGHT => {
self.page += 1;
if let Err(e) = self.call_pager() {
self.page -= 1;
return Err(e);
}
}
_ => (),
},
_ => (),
}
Ok(())
} }
} }

View file

@ -0,0 +1,67 @@
/// Provides a simple ratelimit lock (that only works in tokio)
// use tokio::time::
use std::time::Duration;
use crate::Result;
use flume::{bounded as channel, Receiver, Sender};
use std::ops::Deref;
/// Holds the underlying `T` in a rate-limited way.
pub struct Ratelimit<T> {
inner: T,
recv: Receiver<()>,
send: Sender<()>,
wait_time: Duration,
}
struct RatelimitGuard<'a, T> {
inner: &'a T,
send: &'a Sender<()>,
wait_time: &'a Duration,
}
impl<T> Ratelimit<T> {
/// Create a new ratelimit with at most `count` uses in `wait_time`.
pub fn new(inner: T, count: usize, wait_time: Duration) -> Self {
let (send, recv) = channel(count);
(0..count).for_each(|_| {
send.send(()).ok();
});
Self {
inner,
send,
recv,
wait_time,
}
}
/// Borrow the inner `T`. You can only hol this reference `count` times in `wait_time`.
/// The clock counts from the moment the ref is dropped.
pub async fn borrow<'a>(&'a self) -> Result<impl Deref<Target = T> + 'a> {
self.recv.recv_async().await?;
Ok(RatelimitGuard {
inner: &self.inner,
send: &self.send,
wait_time: &self.wait_time,
})
}
}
impl<'a, T> Deref for RatelimitGuard<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.inner
}
}
impl<'a, T> Drop for RatelimitGuard<'a, T> {
fn drop(&mut self) {
let send = self.send.clone();
let wait_time = self.wait_time.clone();
tokio::spawn(async move {
tokio::time::delay_for(wait_time).await;
send.send_async(()).await.ok();
});
}
}

View file

@ -1,105 +0,0 @@
use crossbeam_channel::{after, bounded, select, Sender};
use serenity::{framework::standard::CommandResult, model::channel::Reaction, prelude::*};
use std::sync::{Arc, Mutex};
/// Handles a reaction.
///
/// Every handler needs an expire time too.
pub trait ReactionHandler {
/// Handle a reaction. This is fired on EVERY reaction.
/// You do the filtering yourself.
///
/// If `is_added` is false, the reaction was removed instead of added.
fn handle_reaction(&mut self, reaction: &Reaction, is_added: bool) -> CommandResult;
}
impl<T> ReactionHandler for T
where
T: FnMut(&Reaction, bool) -> CommandResult,
{
fn handle_reaction(&mut self, reaction: &Reaction, is_added: bool) -> CommandResult {
self(reaction, is_added)
}
}
/// The store for a set of dynamic reaction handlers.
#[derive(Debug, Clone)]
pub struct ReactionWatcher {
channels: Arc<Mutex<Vec<Sender<(Arc<Reaction>, bool)>>>>,
}
impl TypeMapKey for ReactionWatcher {
type Value = ReactionWatcher;
}
impl ReactionWatcher {
/// Create a new ReactionWatcher.
pub fn new() -> Self {
Self {
channels: Arc::new(Mutex::new(vec![])),
}
}
/// Send a reaction.
/// If `is_added` is false, the reaction was removed.
pub fn send(&self, r: Reaction, is_added: bool) {
let r = Arc::new(r);
self.channels
.lock()
.expect("Poisoned!")
.retain(|e| e.send((r.clone(), is_added)).is_ok());
}
/// React! to a series of reaction
///
/// The reactions stop after `duration` of idle.
pub fn handle_reactions<H: ReactionHandler + Send + 'static>(
&self,
mut h: H,
duration: std::time::Duration,
callback: impl FnOnce(H) -> () + Send + 'static,
) {
let (send, reactions) = bounded(0);
{
self.channels.lock().expect("Poisoned!").push(send);
}
std::thread::spawn(move || {
loop {
let timeout = after(duration);
let r = select! {
recv(reactions) -> r => { let (r, is_added) = r.unwrap(); h.handle_reaction(&*r, is_added) },
recv(timeout) -> _ => break,
};
if let Err(v) = r {
dbg!(v);
}
}
callback(h)
});
}
/// React! to a series of reaction
///
/// The handler will stop after `duration` no matter what.
pub fn handle_reactions_timed<H: ReactionHandler + Send + 'static>(
&self,
mut h: H,
duration: std::time::Duration,
callback: impl FnOnce(H) -> () + Send + 'static,
) {
let (send, reactions) = bounded(0);
{
self.channels.lock().expect("Poisoned!").push(send);
}
std::thread::spawn(move || {
let timeout = after(duration);
loop {
let r = select! {
recv(reactions) -> r => { let (r, is_added) = r.unwrap(); h.handle_reaction(&*r, is_added) },
recv(timeout) -> _ => break,
};
if let Err(v) = r {
dbg!(v);
}
}
callback(h);
});
}
}

View file

@ -1,17 +1,14 @@
use serenity::{framework::standard::StandardFramework, prelude::*}; use serenity::prelude::*;
use std::path::Path; use std::path::Path;
/// Set up the prelude libraries. /// Set up the prelude libraries.
/// ///
/// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails. /// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails.
pub fn setup_prelude(db_path: &Path, data: &mut ShareMap, _: &mut StandardFramework) { pub fn setup_prelude(db_path: &Path, data: &mut TypeMap) {
// Setup the announcer DB. // Setup the announcer DB.
crate::announcer::AnnouncerChannels::insert_into(data, db_path.join("announcers.yaml")) crate::announcer::AnnouncerChannels::insert_into(data, db_path.join("announcers.yaml"))
.expect("Announcers DB set up"); .expect("Announcers DB set up");
// Set up the HTTP client. // Set up the HTTP client.
data.insert::<crate::HTTPClient>(reqwest::blocking::Client::new()); data.insert::<crate::HTTPClient>(reqwest::Client::new());
// Set up the reaction watcher.
data.insert::<crate::ReactionWatcher>(crate::ReactionWatcher::new());
} }

View file

@ -12,8 +12,10 @@ osu = ["youmubot-osu"]
codeforces = ["youmubot-cf"] codeforces = ["youmubot-cf"]
[dependencies] [dependencies]
serenity = "0.8" serenity = "0.9.0-rc.0"
tokio = "0.2"
dotenv = "0.15" dotenv = "0.15"
env_logger = "0.7"
youmubot-db = { path = "../youmubot-db" } youmubot-db = { path = "../youmubot-db" }
youmubot-prelude = { path = "../youmubot-prelude" } youmubot-prelude = { path = "../youmubot-prelude" }
youmubot-core = { path = "../youmubot-core" } youmubot-core = { path = "../youmubot-core" }

View file

@ -1,9 +1,10 @@
use dotenv; use dotenv;
use dotenv::var; use dotenv::var;
use serenity::{ use serenity::{
framework::standard::{DispatchError, StandardFramework}, client::bridge::gateway::GatewayIntents,
framework::standard::{macros::hook, CommandResult, DispatchError, StandardFramework},
model::{ model::{
channel::{Channel, Message, Reaction}, channel::{Channel, Message},
gateway, gateway,
id::{ChannelId, GuildId, UserId}, id::{ChannelId, GuildId, UserId},
permissions::Permissions, permissions::Permissions,
@ -12,51 +13,59 @@ use serenity::{
use youmubot_prelude::*; use youmubot_prelude::*;
struct Handler { struct Handler {
hooks: Vec<fn(&mut Context, &Message) -> ()>, hooks: Vec<RwLock<Box<dyn Hook>>>,
} }
impl Handler { impl Handler {
fn new() -> Handler { fn new() -> Handler {
Handler { hooks: vec![] } Handler { hooks: vec![] }
} }
fn push_hook<T: Hook + 'static>(&mut self, f: T) {
self.hooks.push(RwLock::new(Box::new(f)));
}
} }
#[async_trait]
impl EventHandler for Handler { impl EventHandler for Handler {
fn ready(&self, _: Context, ready: gateway::Ready) { async fn ready(&self, _: Context, ready: gateway::Ready) {
println!("{} is connected!", ready.user.name); println!("{} is connected!", ready.user.name);
} }
fn message(&self, mut ctx: Context, message: Message) { async fn message(&self, ctx: Context, message: Message) {
self.hooks.iter().for_each(|f| f(&mut ctx, &message)); self.hooks
} .iter()
.map(|hook| {
fn reaction_add(&self, ctx: Context, reaction: Reaction) { let ctx = ctx.clone();
ctx.data let message = message.clone();
.get_cloned::<ReactionWatcher>() hook.write()
.send(reaction, true); .then(|mut h| async move { h.call(&ctx, &message).await })
} })
.collect::<stream::FuturesUnordered<_>>()
fn reaction_remove(&self, ctx: Context, reaction: Reaction) { .for_each(|v| async move {
ctx.data if let Err(e) = v {
.get_cloned::<ReactionWatcher>() eprintln!("{}", e)
.send(reaction, false); }
})
.await;
} }
} }
/// Returns whether the user has "MANAGE_MESSAGES" permission in the channel. /// Returns whether the user has "MANAGE_MESSAGES" permission in the channel.
fn is_channel_mod(ctx: &mut Context, _: Option<GuildId>, ch: ChannelId, u: UserId) -> bool { async fn is_channel_mod(ctx: &Context, _: Option<GuildId>, ch: ChannelId, u: UserId) -> bool {
match ch.to_channel(&ctx) { match ch.to_channel(&ctx).await {
Ok(Channel::Guild(gc)) => { Ok(Channel::Guild(gc)) => gc
let gc = gc.read(); .permissions_for_user(&ctx, u)
gc.permissions_for_user(&ctx, u) .await
.map(|perms| perms.contains(Permissions::MANAGE_MESSAGES)) .map(|perms| perms.contains(Permissions::MANAGE_MESSAGES))
.unwrap_or(false) .unwrap_or(false),
}
_ => false, _ => false,
} }
} }
fn main() { #[tokio::main]
async fn main() {
env_logger::init();
// Setup dotenv // Setup dotenv
if let Ok(path) = dotenv::dotenv() { if let Ok(path) = dotenv::dotenv() {
println!("Loaded dotenv from {:?}", path); println!("Loaded dotenv from {:?}", path);
@ -65,34 +74,48 @@ fn main() {
let mut handler = Handler::new(); let mut handler = Handler::new();
// Set up hooks // Set up hooks
#[cfg(feature = "osu")] #[cfg(feature = "osu")]
handler.hooks.push(youmubot_osu::discord::hook); handler.push_hook(youmubot_osu::discord::hook);
#[cfg(feature = "codeforces")] #[cfg(feature = "codeforces")]
handler.hooks.push(youmubot_cf::codeforces_info_hook); handler.push_hook(youmubot_cf::InfoHook);
// Collect the token
let token = var("TOKEN").expect("Please set TOKEN as the Discord Bot's token to be used.");
// Set up base framework
let fw = setup_framework(&token[..]).await;
// Sets up a client // Sets up a client
let mut client = { let mut client = {
// Collect the token
let token = var("TOKEN").expect("Please set TOKEN as the Discord Bot's token to be used.");
// Attempt to connect and set up a framework // Attempt to connect and set up a framework
Client::new(token, handler).expect("Cannot connect") Client::new(token)
.framework(fw)
.event_handler(handler)
.intents(
GatewayIntents::GUILDS
| GatewayIntents::GUILD_BANS
| GatewayIntents::GUILD_MESSAGES
| GatewayIntents::GUILD_MESSAGE_REACTIONS
| GatewayIntents::GUILD_PRESENCES
| GatewayIntents::GUILD_MEMBERS
| GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::DIRECT_MESSAGE_REACTIONS,
)
.await
.unwrap()
}; };
// Set up base framework
let mut fw = setup_framework(&client);
// Set up announcer handler // Set up announcer handler
let mut announcers = AnnouncerHandler::new(&client); let mut announcers = AnnouncerHandler::new(&client);
// Setup each package starting from the prelude. // Setup each package starting from the prelude.
{ {
let mut data = client.data.write(); let mut data = client.data.write().await;
let db_path = var("DBPATH") let db_path = var("DBPATH")
.map(|v| std::path::PathBuf::from(v)) .map(|v| std::path::PathBuf::from(v))
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
println!("No DBPATH set up ({:?}), using `/data`", e); println!("No DBPATH set up ({:?}), using `/data`", e);
std::path::PathBuf::from("data") std::path::PathBuf::from("data")
}); });
youmubot_prelude::setup::setup_prelude(&db_path, &mut data, &mut fw); youmubot_prelude::setup::setup_prelude(&db_path, &mut data);
// Setup core // Setup core
#[cfg(feature = "core")] #[cfg(feature = "core")]
youmubot_core::setup(&db_path, &client, &mut data).expect("Setup db should succeed"); youmubot_core::setup(&db_path, &client, &mut data).expect("Setup db should succeed");
@ -102,7 +125,7 @@ fn main() {
.expect("osu! is initialized"); .expect("osu! is initialized");
// codeforces // codeforces
#[cfg(feature = "codeforces")] #[cfg(feature = "codeforces")]
youmubot_cf::setup(&db_path, &mut data, &mut announcers); youmubot_cf::setup(&db_path, &mut data, &mut announcers).await;
} }
#[cfg(feature = "core")] #[cfg(feature = "core")]
@ -112,11 +135,10 @@ fn main() {
#[cfg(feature = "codeforces")] #[cfg(feature = "codeforces")]
println!("codeforces enabled."); println!("codeforces enabled.");
client.with_framework(fw); tokio::spawn(announcers.scan(std::time::Duration::from_secs(120)));
announcers.scan(std::time::Duration::from_secs(120));
println!("Starting..."); println!("Starting...");
if let Err(v) = client.start() { if let Err(v) = client.start().await {
panic!(v) panic!(v)
} }
@ -124,71 +146,43 @@ fn main() {
} }
// Sets up a framework for a client // Sets up a framework for a client
fn setup_framework(client: &Client) -> StandardFramework { async fn setup_framework(token: &str) -> StandardFramework {
let http = serenity::http::Http::new_with_token(token);
// Collect owners // Collect owners
let owner = client let owner = http
.cache_and_http
.http
.get_current_application_info() .get_current_application_info()
.await
.expect("Should be able to get app info") .expect("Should be able to get app info")
.owner; .owner;
let fw = StandardFramework::new() let fw = StandardFramework::new()
.configure(|c| { .configure(|c| {
c.with_whitespace(false) c.with_whitespace(false)
.prefix(&var("PREFIX").unwrap_or("y!".to_owned())) .prefix(&var("PREFIX").unwrap_or("y!".to_owned()))
.delimiters(vec![" / ", "/ ", " /", "/"]) .delimiters(vec![" / ", "/ ", " /", "/"])
.owners([owner.id].iter().cloned().collect()) .owners([owner.id].iter().cloned().collect())
}) })
.help(&youmubot_core::HELP) .help(&youmubot_core::HELP)
.before(|_, msg, command_name| { .before(before_hook)
println!( .after(after_hook)
"Got command '{}' by user '{}'", .on_dispatch_error(on_dispatch_error)
command_name, msg.author.name .bucket("voting", |c| {
); c.check(|ctx, g, ch, u| Box::pin(async move { !is_channel_mod(ctx, g, ch, u).await }))
true .delay(120 /* 2 minutes */)
}) .time_span(120)
.after(|ctx, msg, command_name, error| match error { .limit(1)
Ok(()) => println!("Processed command '{}'", command_name), })
Err(why) => { .await
let reply = format!("Command '{}' returned error {:?}", command_name, why); .bucket("images", |c| c.time_span(60).limit(2))
if let Err(_) = msg.reply(&ctx, &reply) {} .await
println!("{}", reply) .bucket("community", |c| {
} c.check(|ctx, g, ch, u| Box::pin(async move { !is_channel_mod(ctx, g, ch, u).await }))
}) .delay(30)
.on_dispatch_error(|ctx, msg, error| { .time_span(30)
msg.reply( .limit(1)
&ctx, })
&match error { .await
DispatchError::Ratelimited(seconds) => format!( .group(&prelude_commands::PRELUDE_GROUP);
"⏳ You are being rate-limited! Try this again in **{} seconds**.",
seconds
),
DispatchError::NotEnoughArguments { min, given } => format!("😕 The command needs at least **{}** arguments, I only got **{}**!\nDid you know command arguments are separated with a slash (`/`)?", min, given),
DispatchError::TooManyArguments { max, given } => format!("😕 I can only handle at most **{}** arguments, but I got **{}**!", max, given),
DispatchError::OnlyForGuilds => format!("🔇 This command cannot be used in DMs."),
_ => return,
},
)
.unwrap(); // Invoke
})
// Set a function that's called whenever an attempted command-call's
// command could not be found.
.unrecognised_command(|_, _, unknown_command_name| {
println!("Could not find command named '{}'", unknown_command_name);
})
// Set a function that's called whenever a message is not a command.
.normal_message(|_, _| {
// println!("Message is not a command '{}'", message.content);
})
.bucket("voting", |c| {
c.check(|ctx, g, ch, u| !is_channel_mod(ctx, g, ch, u)).delay(120 /* 2 minutes */).time_span(120).limit(1)
})
.bucket("images", |c| c.time_span(60).limit(2))
.bucket("community", |c| {
c.check(|ctx, g, ch, u| !is_channel_mod(ctx, g, ch, u)).delay(30).time_span(30).limit(1)
})
.group(&prelude_commands::PRELUDE_GROUP);
// groups here // groups here
#[cfg(feature = "core")] #[cfg(feature = "core")]
let fw = fw let fw = fw
@ -201,3 +195,53 @@ fn setup_framework(client: &Client) -> StandardFramework {
let fw = fw.group(&youmubot_cf::CODEFORCES_GROUP); let fw = fw.group(&youmubot_cf::CODEFORCES_GROUP);
fw fw
} }
// Hooks!
#[hook]
async fn before_hook(_: &Context, msg: &Message, command_name: &str) -> bool {
println!(
"Got command '{}' by user '{}'",
command_name, msg.author.name
);
true
}
#[hook]
async fn after_hook(ctx: &Context, msg: &Message, command_name: &str, error: CommandResult) {
match error {
Ok(()) => println!("Processed command '{}'", command_name),
Err(why) => {
let reply = format!("Command '{}' returned error {:?}", command_name, why);
msg.reply(&ctx, &reply).await.ok();
println!("{}", reply)
}
}
}
#[hook]
async fn on_dispatch_error(ctx: &Context, msg: &Message, error: DispatchError) {
msg.reply(
&ctx,
&match error {
DispatchError::Ratelimited(seconds) => format!(
"⏳ You are being rate-limited! Try this again in **{}**.",
youmubot_prelude::Duration(seconds),
),
DispatchError::NotEnoughArguments { min, given } => {
format!(
"😕 The command needs at least **{}** arguments, I only got **{}**!",
min, given
) + "\nDid you know command arguments are separated with a slash (`/`)?"
}
DispatchError::TooManyArguments { max, given } => format!(
"😕 I can only handle at most **{}** arguments, but I got **{}**!",
max, given
),
DispatchError::OnlyForGuilds => format!("🔇 This command cannot be used in DMs."),
_ => return,
},
)
.await
.ok(); // Invoke
}