diff --git a/youmubot-cf/src/lib.rs b/youmubot-cf/src/lib.rs index 61c45a8..6404a33 100644 --- a/youmubot-cf/src/lib.rs +++ b/youmubot-cf/src/lib.rs @@ -13,6 +13,7 @@ use serenity::{ use db::{CfSavedUsers, CfUser}; pub use hook::InfoHook; +use youmubot_prelude::announcer::AnnouncerHandler; use youmubot_prelude::table_format::table_formatting_unsafe; use youmubot_prelude::table_format::Align::{Left, Right}; use youmubot_prelude::{ diff --git a/youmubot-core/src/lib.rs b/youmubot-core/src/lib.rs index afeb39d..ce32b17 100644 --- a/youmubot-core/src/lib.rs +++ b/youmubot-core/src/lib.rs @@ -1,10 +1,15 @@ +use std::collections::HashSet; + use serenity::{ framework::standard::{ help_commands, macros::help, Args, CommandGroup, CommandResult, HelpOptions, }, model::{channel::Message, id::UserId}, }; -use std::collections::HashSet; + +pub use admin::ADMIN_GROUP; +pub use community::COMMUNITY_GROUP; +pub use fun::FUN_GROUP; use youmubot_prelude::{announcer::CacheAndHttp, *}; pub mod admin; @@ -12,14 +17,9 @@ pub mod community; mod db; pub mod fun; -pub use admin::ADMIN_GROUP; -pub use community::COMMUNITY_GROUP; -pub use fun::FUN_GROUP; - /// Sets up all databases in the client. pub fn setup( path: &std::path::Path, - client: &serenity::client::Client, data: &mut TypeMap, ) -> serenity::framework::standard::CommandResult { db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?; @@ -29,18 +29,21 @@ pub fn setup( &path.join("roles.yaml"), )?; - // Create handler threads - tokio::spawn(admin::watch_soft_bans( - CacheAndHttp::from_client(client), - client.data.clone(), - )); - // Start reaction handlers data.insert::(community::ReactionWatchers::new(&*data)?); Ok(()) } +pub fn ready_hook(ctx: &Context) -> CommandResult { + // Create handler threads + tokio::spawn(admin::watch_soft_bans( + CacheAndHttp::from_context(ctx), + ctx.data.clone(), + )); + Ok(()) +} + // A help command #[help] pub async fn help( diff --git a/youmubot-osu/src/discord/announcer.rs b/youmubot-osu/src/discord/announcer.rs index ffceb7e..094ef44 100644 --- a/youmubot-osu/src/discord/announcer.rs +++ b/youmubot-osu/src/discord/announcer.rs @@ -15,16 +15,15 @@ use youmubot_prelude::stream::TryStreamExt; use youmubot_prelude::*; use crate::{ - discord::beatmap_cache::BeatmapMetaCache, discord::cache::save_beatmap, - discord::oppai_cache::{BeatmapCache, BeatmapContent}, + discord::oppai_cache::BeatmapContent, models::{Mode, Score, User, UserEventRank}, request::UserID, Client as Osu, }; use super::db::{OsuSavedUsers, OsuUser}; -use super::{calculate_weighted_map_length, OsuClient}; +use super::{calculate_weighted_map_length, OsuEnv}; use super::{embeds::score_embed, BeatmapWithMode}; /// osu! announcer's unique announcer key. @@ -51,9 +50,8 @@ impl youmubot_prelude::Announcer for Announcer { ) -> Result<()> { // For each user... let users = { - let data = d.read().await; - let data = data.get::().unwrap(); - data.all().await? + let env = d.read().await.get::().unwrap().clone(); + env.saved_users.all().await? }; let now = chrono::Utc::now(); users @@ -198,13 +196,12 @@ impl Announcer { } async fn std_weighted_map_length(ctx: &Context, u: &OsuUser) -> Result { - let data = ctx.data.read().await; - let client = data.get::().unwrap().clone(); - let cache = data.get::().unwrap(); - let scores = client + let env = ctx.data.read().await.get::().unwrap().clone(); + let scores = env + .client .user_best(UserID::ID(u.id), |f| f.mode(Mode::Std).limit(100)) .await?; - calculate_weighted_map_length(&scores, cache, Mode::Std).await + calculate_weighted_map_length(&scores, &env.beatmaps, Mode::Std).await } } @@ -282,11 +279,12 @@ impl<'a> CollectedScore<'a> { } async fn get_beatmap(&self, ctx: &Context) -> Result<(BeatmapWithMode, BeatmapContent)> { - let data = ctx.data.read().await; - let cache = data.get::().unwrap(); - let oppai = data.get::().unwrap(); - let beatmap = cache.get_beatmap_default(self.score.beatmap_id).await?; - let content = oppai.get_beatmap(beatmap.beatmap_id).await?; + let env = ctx.data.read().await.get::().unwrap().clone(); + let beatmap = env + .beatmaps + .get_beatmap_default(self.score.beatmap_id) + .await?; + let content = env.oppai.get_beatmap(beatmap.beatmap_id).await?; Ok((BeatmapWithMode(beatmap, self.mode), content)) } @@ -341,9 +339,10 @@ impl<'a> CollectedScore<'a> { }), ) .await?; - save_beatmap(&*ctx.data.read().await, channel, bm) - .await - .pls_ok(); + + let env = ctx.data.read().await.get::().unwrap().clone(); + + save_beatmap(&env, channel, bm).await.pls_ok(); Ok(m) } } diff --git a/youmubot-osu/src/discord/beatmap_cache.rs b/youmubot-osu/src/discord/beatmap_cache.rs index b523844..0a89c1f 100644 --- a/youmubot-osu/src/discord/beatmap_cache.rs +++ b/youmubot-osu/src/discord/beatmap_cache.rs @@ -1,18 +1,27 @@ +use std::sync::Arc; + +use youmubot_db_sql::{models::osu as models, Pool}; +use youmubot_prelude::*; + use crate::{ models::{ApprovalStatus, Beatmap, Mode}, Client, }; -use std::sync::Arc; -use youmubot_db_sql::{models::osu as models, Pool}; -use youmubot_prelude::*; /// BeatmapMetaCache intercepts beatmap-by-id requests and caches them for later recalling. /// Does not cache non-Ranked beatmaps. +#[derive(Clone)] pub struct BeatmapMetaCache { client: Arc, pool: Pool, } +impl std::fmt::Debug for BeatmapMetaCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + impl TypeMapKey for BeatmapMetaCache { type Value = BeatmapMetaCache; } diff --git a/youmubot-osu/src/discord/cache.rs b/youmubot-osu/src/discord/cache.rs index 2778eca..30e5aba 100644 --- a/youmubot-osu/src/discord/cache.rs +++ b/youmubot-osu/src/discord/cache.rs @@ -1,29 +1,26 @@ -use super::db::OsuLastBeatmap; -use super::BeatmapWithMode; use serenity::model::id::ChannelId; + use youmubot_prelude::*; +use super::{BeatmapWithMode, OsuEnv}; + /// Save the beatmap into the server data storage. pub(crate) async fn save_beatmap( - data: &TypeMap, + env: &OsuEnv, channel_id: ChannelId, bm: &BeatmapWithMode, ) -> Result<()> { - data.get::() - .unwrap() - .save(channel_id, &bm.0, bm.1) - .await?; + env.last_beatmaps.save(channel_id, &bm.0, bm.1).await?; Ok(()) } /// Get the last beatmap requested from this channel. pub(crate) async fn get_beatmap( - data: &TypeMap, + env: &OsuEnv, channel_id: ChannelId, ) -> Result> { - data.get::() - .unwrap() + env.last_beatmaps .by_channel(channel_id) .await .map(|v| v.map(|(bm, mode)| BeatmapWithMode(bm, mode))) diff --git a/youmubot-osu/src/discord/db.rs b/youmubot-osu/src/discord/db.rs index ad54c8d..9ff7737 100644 --- a/youmubot-osu/src/discord/db.rs +++ b/youmubot-osu/src/discord/db.rs @@ -1,14 +1,16 @@ use std::borrow::Cow; use chrono::{DateTime, Utc}; -use youmubot_db_sql::{models::osu as models, models::osu_user as model, Pool}; - -use crate::models::{Beatmap, Mode, Score}; use serde::{Deserialize, Serialize}; use serenity::model::id::{ChannelId, UserId}; + +use youmubot_db_sql::{models::osu as models, models::osu_user as model, Pool}; use youmubot_prelude::*; +use crate::models::{Beatmap, Mode, Score}; + /// Save the user IDs. +#[derive(Debug, Clone)] pub struct OsuSavedUsers { pool: Pool, } @@ -60,6 +62,7 @@ impl OsuSavedUsers { } /// Save each channel's last requested beatmap. +#[derive(Debug, Clone)] pub struct OsuLastBeatmap(Pool); impl TypeMapKey for OsuLastBeatmap { @@ -99,6 +102,7 @@ impl OsuLastBeatmap { } /// Save each channel's last requested beatmap. +#[derive(Debug, Clone)] pub struct OsuUserBests(Pool); impl TypeMapKey for OsuUserBests { @@ -188,14 +192,16 @@ impl From for OsuUser { #[allow(dead_code)] mod legacy { - use chrono::{DateTime, Utc}; + use std::collections::HashMap; - use crate::models::{Beatmap, Mode, Score}; + use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serenity::model::id::{ChannelId, UserId}; - use std::collections::HashMap; + use youmubot_db::DB; + use crate::models::{Beatmap, Mode, Score}; + pub type OsuSavedUsers = DB>; /// An osu! saved user. diff --git a/youmubot-osu/src/discord/display.rs b/youmubot-osu/src/discord/display.rs index 68755ef..1faebed 100644 --- a/youmubot-osu/src/discord/display.rs +++ b/youmubot-osu/src/discord/display.rs @@ -54,9 +54,7 @@ mod scores { use youmubot_prelude::*; - use crate::discord::{ - cache::save_beatmap, BeatmapCache, BeatmapMetaCache, BeatmapWithMode, - }; + use crate::discord::{cache::save_beatmap, BeatmapWithMode, OsuEnv}; use crate::models::{Mode, Score}; pub async fn display_scores_grid<'a>( @@ -88,19 +86,17 @@ mod scores { #[async_trait] impl pagination::Paginate for Paginate { async fn render(&mut self, page: u8, ctx: &Context, msg: &mut Message) -> Result { - let data = ctx.data.read().await; - let client = data.get::().unwrap(); - let osu = data.get::().unwrap(); - let beatmap_cache = data.get::().unwrap(); + let env = ctx.data.read().await.get::().unwrap().clone(); let page = page as usize; let score = &self.scores[page]; let hourglass = msg.react(ctx, '⌛').await?; let mode = self.mode; - let beatmap = osu.get_beatmap(score.beatmap_id, mode).await?; - let content = beatmap_cache.get_beatmap(beatmap.beatmap_id).await?; + let beatmap = env.beatmaps.get_beatmap(score.beatmap_id, mode).await?; + let content = env.oppai.get_beatmap(beatmap.beatmap_id).await?; let bm = BeatmapWithMode(beatmap, mode); - let user = client + let user = env + .client .user(crate::request::UserID::ID(score.user_id), |f| f) .await? .ok_or_else(|| Error::msg("user not found"))?; @@ -114,7 +110,7 @@ mod scores { }), ) .await?; - save_beatmap(&*ctx.data.read().await, msg.channel_id, &bm).await?; + save_beatmap(&env, msg.channel_id, &bm).await?; // End hourglass.delete(ctx).await?; @@ -138,7 +134,7 @@ mod scores { use youmubot_prelude::*; use crate::discord::oppai_cache::Accuracy; - use crate::discord::{Beatmap, BeatmapCache, BeatmapInfo, BeatmapMetaCache}; + use crate::discord::{Beatmap, BeatmapInfo, OsuEnv}; use crate::models::{Mode, Score}; pub async fn display_scores_table<'a>( @@ -178,9 +174,10 @@ mod scores { #[async_trait] impl pagination::Paginate for Paginate { async fn render(&mut self, page: u8, ctx: &Context, msg: &mut Message) -> Result { - let data = ctx.data.read().await; - let osu = data.get::().unwrap(); - let beatmap_cache = data.get::().unwrap(); + let env = ctx.data.read().await.get::().unwrap().clone(); + + let meta_cache = &env.beatmaps; + let oppai = &env.oppai; let page = page as usize; let start = page * ITEMS_PER_PAGE; let end = self.scores.len().min(start + ITEMS_PER_PAGE); @@ -194,9 +191,9 @@ mod scores { let beatmaps = plays .iter() .map(|play| async move { - let beatmap = osu.get_beatmap(play.beatmap_id, mode).await?; + let beatmap = meta_cache.get_beatmap(play.beatmap_id, mode).await?; let info = { - let b = beatmap_cache.get_beatmap(beatmap.beatmap_id).await?; + let b = oppai.get_beatmap(beatmap.beatmap_id).await?; b.get_info_with(mode, play.mods).ok() }; Ok((beatmap, info)) as Result<(Beatmap, Option)> @@ -211,7 +208,7 @@ mod scores { match p.pp.map(|pp| format!("{:.2}", pp)) { Some(v) => Ok(v), None => { - let b = beatmap_cache.get_beatmap(p.beatmap_id).await?; + let b = oppai.get_beatmap(p.beatmap_id).await?; let r: Result<_> = Ok({ b.get_pp_from( mode, @@ -335,10 +332,9 @@ mod beatmapset { use youmubot_prelude::*; + use crate::discord::OsuEnv; use crate::{ - discord::{ - cache::save_beatmap, oppai_cache::BeatmapInfoWithPP, BeatmapCache, BeatmapWithMode, - }, + discord::{cache::save_beatmap, oppai_cache::BeatmapInfoWithPP, BeatmapWithMode}, models::{Beatmap, Mode, Mods}, }; @@ -386,9 +382,9 @@ mod beatmapset { impl Paginate { async fn get_beatmap_info(&self, ctx: &Context, b: &Beatmap) -> Result { - let data = ctx.data.read().await; - let cache = data.get::().unwrap(); - cache + let env = ctx.data.read().await.get::().unwrap().clone(); + + env.oppai .get_beatmap(b.beatmap_id) .await .and_then(move |v| v.get_possible_pp_with(self.mode.unwrap_or(b.mode), self.mods)) @@ -401,15 +397,10 @@ mod beatmapset { Some(self.maps.len()) } - async fn render( - &mut self, - page: u8, - ctx: &Context, - m: &mut serenity::model::channel::Message, - ) -> Result { + async fn render(&mut self, page: u8, ctx: &Context, msg: &mut Message) -> Result { let page = page as usize; if page == self.maps.len() { - m.edit( + msg.edit( ctx, EditMessage::new().embed(crate::discord::embeds::beatmapset_embed( &self.maps[..], @@ -432,8 +423,8 @@ mod beatmapset { info } }; - m.edit(ctx, - EditMessage::new().content(self.message.as_str()).embed( + msg.edit(ctx, + EditMessage::new().content(self.message.as_str()).embed( crate::discord::embeds::beatmap_embed( map, self.mode.unwrap_or(map.mode), @@ -451,9 +442,10 @@ mod beatmapset { ), ) .await?; + let env = ctx.data.read().await.get::().unwrap().clone(); save_beatmap( - &*ctx.data.read().await, - m.channel_id, + &env, + msg.channel_id, &BeatmapWithMode(map.clone(), self.mode.unwrap_or(map.mode)), ) .await diff --git a/youmubot-osu/src/discord/hook.rs b/youmubot-osu/src/discord/hook.rs index da33442..65f6b57 100644 --- a/youmubot-osu/src/discord/hook.rs +++ b/youmubot-osu/src/discord/hook.rs @@ -1,14 +1,17 @@ -use crate::{ - discord::beatmap_cache::BeatmapMetaCache, - discord::oppai_cache::{BeatmapCache, BeatmapInfoWithPP}, - models::{Beatmap, Mode, Mods}, -}; +use std::str::FromStr; + use lazy_static::lazy_static; use regex::Regex; use serenity::{builder::CreateMessage, model::channel::Message, utils::MessageBuilder}; -use std::str::FromStr; + use youmubot_prelude::*; +use crate::discord::OsuEnv; +use crate::{ + discord::oppai_cache::BeatmapInfoWithPP, + models::{Beatmap, Mode, Mods}, +}; + use super::embeds::beatmap_embed; lazy_static! { @@ -43,9 +46,9 @@ pub fn dot_osu_hook<'a>( let url = attachment.url.clone(); async move { - let data = ctx.data.read().await; - let oppai = data.get::().unwrap(); - let (beatmap, _) = oppai.download_beatmap_from_url(&url).await.ok()?; + let env = ctx.data.read().await.get::().unwrap().clone(); + + let (beatmap, _) = env.oppai.download_beatmap_from_url(&url).await.ok()?; crate::discord::embeds::beatmap_offline_embed( &beatmap, Mode::from(beatmap.content.mode as u8), /*For now*/ @@ -68,9 +71,9 @@ pub fn dot_osu_hook<'a>( .map(|attachment| { let url = attachment.url.clone(); async move { - let data = ctx.data.read().await; - let oppai = data.get::().unwrap(); - let beatmaps = oppai.download_osz_from_url(&url).await.pls_ok()?; + let env = ctx.data.read().await.get::().unwrap().clone(); + + let beatmaps = env.oppai.download_osz_from_url(&url).await.pls_ok()?; Some( beatmaps .into_iter() @@ -133,13 +136,12 @@ pub fn hook<'a>( .pls_ok(); let mode = l.mode.unwrap_or(b.mode); let bm = super::BeatmapWithMode(*b, mode); - crate::discord::cache::save_beatmap( - &*ctx.data.read().await, - msg.channel_id, - &bm, - ) - .await - .pls_ok(); + + let env = ctx.data.read().await.get::().unwrap().clone(); + + crate::discord::cache::save_beatmap(&env, msg.channel_id, &bm) + .await + .pls_ok(); } EmbedType::Beatmapset(b) => { handle_beatmapset(ctx, b, l.link, l.mode, msg) @@ -174,8 +176,7 @@ fn handle_old_links<'a>( .captures_iter(content) .map(move |capture| async move { let data = ctx.data.read().await; - let cache = data.get::().unwrap(); - let osu = data.get::().unwrap(); + let env = data.get::().unwrap(); let req_type = capture.name("link_type").unwrap().as_str(); let mode = capture .name("mode") @@ -192,10 +193,18 @@ fn handle_old_links<'a>( }); let beatmaps = match req_type { "b" => vec![match mode { - Some(mode) => osu.get_beatmap(capture["id"].parse()?, mode).await?, - None => osu.get_beatmap_default(capture["id"].parse()?).await?, + Some(mode) => { + env.beatmaps + .get_beatmap(capture["id"].parse()?, mode) + .await? + } + None => { + env.beatmaps + .get_beatmap_default(capture["id"].parse()?) + .await? + } }], - "s" => osu.get_beatmapset(capture["id"].parse()?).await?, + "s" => env.beatmaps.get_beatmapset(capture["id"].parse()?).await?, _ => unreachable!(), }; if beatmaps.is_empty() { @@ -211,7 +220,7 @@ fn handle_old_links<'a>( .unwrap_or(Mods::NOMOD); let info = { let mode = mode.unwrap_or(b.mode); - cache + env.oppai .get_beatmap(b.beatmap_id) .await .and_then(|b| b.get_possible_pp_with(mode, mods))? @@ -233,13 +242,10 @@ fn handle_old_links<'a>( }) .collect::>() .filter_map(|v| { - future::ready(match v { - Ok(v) => v, - Err(e) => { - eprintln!("{}", e); - None - } - }) + future::ready(v.unwrap_or_else(|e| { + eprintln!("{}", e); + None + })) }) } @@ -250,20 +256,23 @@ fn handle_new_links<'a>( NEW_LINK_REGEX .captures_iter(content) .map(|capture| async move { - let data = ctx.data.read().await; - let osu = data.get::().unwrap(); - let cache = data.get::().unwrap(); + let env = ctx.data.read().await.get::().unwrap().clone(); let mode = capture .name("mode") .and_then(|v| Mode::parse_from_new_site(v.as_str())); let link = capture.get(0).unwrap().as_str(); let beatmaps = match capture.name("beatmap_id") { Some(ref v) => vec![match mode { - Some(mode) => osu.get_beatmap(v.as_str().parse()?, mode).await?, - None => osu.get_beatmap_default(v.as_str().parse()?).await?, + Some(mode) => env.beatmaps.get_beatmap(v.as_str().parse()?, mode).await?, + None => { + env.beatmaps + .get_beatmap_default(v.as_str().parse()?) + .await? + } }], None => { - osu.get_beatmapset(capture.name("set_id").unwrap().as_str().parse()?) + env.beatmaps + .get_beatmapset(capture.name("set_id").unwrap().as_str().parse()?) .await? } }; @@ -280,7 +289,7 @@ fn handle_new_links<'a>( .unwrap_or(Mods::NOMOD); let info = { let mode = mode.unwrap_or(beatmap.mode); - cache + env.oppai .get_beatmap(beatmap.beatmap_id) .await .and_then(|b| b.get_possible_pp_with(mode, mods))? @@ -328,16 +337,14 @@ fn handle_short_links<'a>( return Err(Error::msg("not in server announcer channel")); } } - let data = ctx.data.read().await; - let osu = data.get::().unwrap(); - let cache = data.get::().unwrap(); + let env = ctx.data.read().await.get::().unwrap().clone(); let mode = capture .name("mode") .and_then(|v| Mode::parse_from_new_site(v.as_str())); let id: u64 = capture.name("id").unwrap().as_str().parse()?; let beatmap = match mode { - Some(mode) => osu.get_beatmap(id, mode).await, - None => osu.get_beatmap_default(id).await, + Some(mode) => env.beatmaps.get_beatmap(id, mode).await, + None => env.beatmaps.get_beatmap_default(id).await, }?; let mods = capture .name("mods") @@ -345,7 +352,7 @@ fn handle_short_links<'a>( .unwrap_or(Mods::NOMOD); let info = { let mode = mode.unwrap_or(beatmap.mode); - cache + env.oppai .get_beatmap(beatmap.beatmap_id) .await .and_then(|b| b.get_possible_pp_with(mode, mods))? diff --git a/youmubot-osu/src/discord/mod.rs b/youmubot-osu/src/discord/mod.rs index ba777f7..30eb979 100644 --- a/youmubot-osu/src/discord/mod.rs +++ b/youmubot-osu/src/discord/mod.rs @@ -1,11 +1,5 @@ -use crate::{ - discord::beatmap_cache::BeatmapMetaCache, - discord::display::ScoreListStyle, - discord::oppai_cache::{BeatmapCache, BeatmapInfo}, - models::{self, Beatmap, Mode, Mods, Score, User}, - request::{BeatmapRequestKind, UserID}, - Client as OsuHttpClient, -}; +use std::{str::FromStr, sync::Arc}; + use rand::seq::IteratorRandom; use serenity::{ builder::{CreateMessage, EditMessage}, @@ -17,9 +11,24 @@ use serenity::{ model::channel::Message, utils::MessageBuilder, }; -use std::{str::FromStr, sync::Arc}; + +use db::{OsuLastBeatmap, OsuSavedUsers, OsuUser, OsuUserBests}; +use embeds::{beatmap_embed, score_embed, user_embed}; +use hook::SHORT_LINK_REGEX; +pub use hook::{dot_osu_hook, hook}; +use server_rank::{SERVER_RANK_COMMAND, SHOW_LEADERBOARD_COMMAND}; +use youmubot_prelude::announcer::AnnouncerHandler; use youmubot_prelude::{stream::FuturesUnordered, *}; +use crate::{ + discord::beatmap_cache::BeatmapMetaCache, + discord::display::ScoreListStyle, + discord::oppai_cache::{BeatmapCache, BeatmapInfo}, + models::{Beatmap, Mode, Mods, Score, User}, + request::{BeatmapRequestKind, UserID}, + Client as OsuHttpClient, +}; + mod announcer; pub(crate) mod beatmap_cache; mod cache; @@ -30,12 +39,6 @@ mod hook; pub(crate) mod oppai_cache; mod server_rank; -use db::{OsuLastBeatmap, OsuSavedUsers, OsuUser, OsuUserBests}; -use embeds::{beatmap_embed, score_embed, user_embed}; -use hook::SHORT_LINK_REGEX; -pub use hook::{dot_osu_hook, hook}; -use server_rank::{SERVER_RANK_COMMAND, SHOW_LEADERBOARD_COMMAND}; - /// The osu! client. pub(crate) struct OsuClient; @@ -43,6 +46,30 @@ impl TypeMapKey for OsuClient { type Value = Arc; } +/// The environment for osu! app commands. +#[derive(Clone)] +pub struct OsuEnv { + pub(crate) prelude: Env, + // databases + pub(crate) saved_users: OsuSavedUsers, + pub(crate) last_beatmaps: OsuLastBeatmap, + pub(crate) user_bests: OsuUserBests, + // clients + pub(crate) client: Arc, + pub(crate) oppai: BeatmapCache, + pub(crate) beatmaps: BeatmapMetaCache, +} + +impl std::fmt::Debug for OsuEnv { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl TypeMapKey for OsuEnv { + type Value = OsuEnv; +} + /// Sets up the osu! command handling section. /// /// This automatically enables: @@ -55,50 +82,58 @@ impl TypeMapKey for OsuClient { /// - Hooks. Hooks are completely opt-in. /// pub async fn setup( - _path: &std::path::Path, data: &mut TypeMap, + prelude: youmubot_prelude::Env, announcers: &mut AnnouncerHandler, -) -> CommandResult { - let sql_client = data.get::().unwrap().clone(); +) -> Result { // Databases - data.insert::(OsuSavedUsers::new(sql_client.clone())); - data.insert::(OsuLastBeatmap::new(sql_client.clone())); - data.insert::(OsuUserBests::new(sql_client.clone())); + let saved_users = OsuSavedUsers::new(prelude.sql.clone()); + let last_beatmaps = OsuLastBeatmap::new(prelude.sql.clone()); + let user_bests = OsuUserBests::new(prelude.sql.clone()); // API client - let http_client = data.get::().unwrap().clone(); - let mk_osu_client = || async { - Arc::new( - OsuHttpClient::new( - std::env::var("OSU_API_CLIENT_ID") - .expect("Please set OSU_API_CLIENT_ID as osu! api v2 client ID.") - .parse() - .expect("client_id should be u64"), - std::env::var("OSU_API_CLIENT_SECRET") - .expect("Please set OSU_API_CLIENT_SECRET as osu! api v2 client secret."), - ) - .await - .expect("osu! should be initialized"), + let osu_client = Arc::new( + OsuHttpClient::new( + std::env::var("OSU_API_CLIENT_ID") + .expect("Please set OSU_API_CLIENT_ID as osu! api v2 client ID.") + .parse() + .expect("client_id should be u64"), + std::env::var("OSU_API_CLIENT_SECRET") + .expect("Please set OSU_API_CLIENT_SECRET as osu! api v2 client secret."), ) - }; - let osu_client = mk_osu_client().await; - data.insert::(osu_client.clone()); - data.insert::(oppai_cache::BeatmapCache::new( - http_client.clone(), - sql_client.clone(), - )); - data.insert::(beatmap_cache::BeatmapMetaCache::new( - osu_client.clone(), - sql_client, - )); + .await + .expect("osu! should be initialized"), + ); + let oppai_cache = BeatmapCache::new(prelude.http.clone(), prelude.sql.clone()); + let beatmap_cache = BeatmapMetaCache::new(osu_client.clone(), prelude.sql.clone()); // Announcer - let osu_client = mk_osu_client().await; announcers.add( announcer::ANNOUNCER_KEY, - announcer::Announcer::new(osu_client), + announcer::Announcer::new(osu_client.clone()), ); - Ok(()) + + // Legacy data + data.insert::(last_beatmaps.clone()); + data.insert::(saved_users.clone()); + data.insert::(user_bests.clone()); + data.insert::(osu_client.clone()); + data.insert::(oppai_cache.clone()); + data.insert::(beatmap_cache.clone()); + + let env = OsuEnv { + prelude, + saved_users, + last_beatmaps, + user_bests, + client: osu_client, + oppai: oppai_cache, + beatmaps: beatmap_cache, + }; + + data.insert::(env.clone()); + + Ok(env) } #[group] @@ -128,7 +163,8 @@ struct Osu; #[usage = "[username or user_id = your saved username]"] #[max_args(1)] pub async fn std(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - get_user(ctx, msg, args, Mode::Std).await + let env = ctx.data.read().await.get::().unwrap().clone(); + get_user(ctx, &env, msg, args, Mode::Std).await } #[command] @@ -137,7 +173,8 @@ pub async fn std(ctx: &Context, msg: &Message, args: Args) -> CommandResult { #[usage = "[username or user_id = your saved username]"] #[max_args(1)] pub async fn taiko(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - get_user(ctx, msg, args, Mode::Taiko).await + let env = ctx.data.read().await.get::().unwrap().clone(); + get_user(ctx, &env, msg, args, Mode::Taiko).await } #[command] @@ -146,7 +183,8 @@ pub async fn taiko(ctx: &Context, msg: &Message, args: Args) -> CommandResult { #[usage = "[username or user_id = your saved username]"] #[max_args(1)] pub async fn catch(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - get_user(ctx, msg, args, Mode::Catch).await + let env = ctx.data.read().await.get::().unwrap().clone(); + get_user(ctx, &env, msg, args, Mode::Catch).await } #[command] @@ -155,7 +193,8 @@ pub async fn catch(ctx: &Context, msg: &Message, args: Args) -> CommandResult { #[usage = "[username or user_id = your saved username]"] #[max_args(1)] pub async fn mania(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - get_user(ctx, msg, args, Mode::Mania).await + let env = ctx.data.read().await.get::().unwrap().clone(); + get_user(ctx, &env, msg, args, Mode::Mania).await } pub(crate) struct BeatmapWithMode(pub Beatmap, pub Mode); @@ -177,18 +216,18 @@ impl AsRef for BeatmapWithMode { #[usage = "[username or user_id]"] #[num_args(1)] pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let data = ctx.data.read().await; - let osu = data.get::().unwrap(); + let env = ctx.data.read().await.get::().unwrap().clone(); + let osu_client = &env.client; let user = args.single::()?; - let u = match osu.user(UserID::from_string(user), |f| f).await? { + let u = match osu_client.user(UserID::from_string(user), |f| f).await? { Some(u) => u, None => { msg.reply(&ctx, "user not found...").await?; return Ok(()); } }; - async fn find_score(client: &OsuHttpClient, u: &User) -> Result> { + async fn find_score(client: &OsuHttpClient, u: &User) -> Result> { for mode in &[Mode::Std, Mode::Taiko, Mode::Catch, Mode::Mania] { let scores = client .user_best(UserID::ID(u.id), |f| f.mode(*mode)) @@ -199,7 +238,7 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult } Ok(None) } - let (score, mode) = match find_score(osu, &u).await? { + let (score, mode) = match find_score(osu_client, &u).await? { Some(v) => v, None => { msg.reply( @@ -220,19 +259,27 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .any(|s| s.beatmap_id == map_id)) } - let reply = msg.reply(&ctx, format!("To set your osu username, please make your most recent play be the following map: `/b/{}` in **{}** mode! It does **not** have to be a pass, and **NF** can be used! React to this message with 👌 within 5 minutes when you're done!", score.beatmap_id, mode.as_str_new_site())); - let beatmap = osu - .beatmaps( - crate::request::BeatmapRequestKind::Beatmap(score.beatmap_id), - |f| f.mode(mode, true), - ) + let reply = msg.reply( + &ctx, + format!( + "To set your osu username, please make your most recent play \ + be the following map: `/b/{}` in **{}** mode! \ + It does **not** have to be a pass, and **NF** can be used! \ + React to this message with 👌 within 5 minutes when you're done!", + score.beatmap_id, + mode.as_str_new_site() + ), + ); + let beatmap = osu_client + .beatmaps(BeatmapRequestKind::Beatmap(score.beatmap_id), |f| { + f.mode(mode, true) + }) .await? .into_iter() .next() .unwrap(); - let info = data - .get::() - .unwrap() + let info = env + .oppai .get_beatmap(beatmap.beatmap_id) .await? .get_possible_pp_with(mode, Mods::NOMOD)?; @@ -254,7 +301,7 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .next() .await; if let Some(ur) = user_reaction { - if check(osu, &u, score.beatmap_id).await? { + if check(osu_client, &u, score.beatmap_id).await? { break true; } ur.delete(&ctx).await?; @@ -268,7 +315,7 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult } let username = u.username.clone(); - add_user(msg.author.id, u, &data).await?; + add_user(msg.author.id, u, &env).await?; msg.reply( &ctx, MessageBuilder::new() @@ -287,17 +334,19 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult #[delimiters(" ")] #[num_args(2)] pub async fn forcesave(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let data = ctx.data.read().await; - let osu = data.get::().unwrap(); + let env = ctx.data.read().await.get::().unwrap().clone(); + + let osu_client = &env.client; + let target = args.single::()?.0; let username = args.quoted().trimmed().single::()?; - let user: Option = osu + let user: Option = osu_client .user(UserID::from_string(username.clone()), |f| f) .await?; match user { Some(u) => { - add_user(target, u, &data).await?; + add_user(target, u, &env).await?; msg.reply( &ctx, MessageBuilder::new() @@ -314,11 +363,7 @@ pub async fn forcesave(ctx: &Context, msg: &Message, mut args: Args) -> CommandR Ok(()) } -async fn add_user( - target: serenity::model::id::UserId, - user: models::User, - data: &TypeMap, -) -> Result<()> { +async fn add_user(target: serenity::model::id::UserId, user: User, env: &OsuEnv) -> Result<()> { let u = OsuUser { user_id: target, username: user.username.into(), @@ -328,7 +373,7 @@ async fn add_user( pp: [None, None, None, None], std_weighted_map_length: None, }; - data.get::().unwrap().new_user(u).await?; + env.saved_users.new_user(u).await?; Ok(()) } @@ -349,7 +394,7 @@ impl FromStr for ModeArg { async fn to_user_id_query( s: Option, - data: &TypeMap, + env: &OsuEnv, msg: &Message, ) -> Result { let id = match s { @@ -358,8 +403,7 @@ async fn to_user_id_query( None => msg.author.id, }; - data.get::() - .unwrap() + env.saved_users .by_user_id(id) .await? .map(|u| UserID::ID(u.id)) @@ -393,34 +437,37 @@ impl FromStr for Nth { #[delimiters("/", " ")] #[max_args(4)] pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let data = ctx.data.read().await; + let env = ctx.data.read().await.get::().unwrap().clone(); + let nth = args.single::().unwrap_or(Nth::All); let style = args.single::().unwrap_or_default(); let mode = args.single::().unwrap_or(ModeArg(Mode::Std)).0; let user = to_user_id_query( args.quoted().trimmed().single::().ok(), - &data, + &env, msg, ) .await?; - let osu = data.get::().unwrap(); - let meta_cache = data.get::().unwrap(); - let oppai = data.get::().unwrap(); - let user = osu + let osu_client = &env.client; + + let user = osu_client .user(user, |f| f.mode(mode)) .await? .ok_or_else(|| Error::msg("User not found"))?; match nth { Nth::Nth(nth) => { - let recent_play = osu + let recent_play = osu_client .user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(nth)) .await? .into_iter() .last() .ok_or_else(|| Error::msg("No such play"))?; - let beatmap = meta_cache.get_beatmap(recent_play.beatmap_id, mode).await?; - let content = oppai.get_beatmap(beatmap.beatmap_id).await?; + let beatmap = env + .beatmaps + .get_beatmap(recent_play.beatmap_id, mode) + .await?; + let content = env.oppai.get_beatmap(beatmap.beatmap_id).await?; let beatmap_mode = BeatmapWithMode(beatmap, mode); msg.channel_id @@ -434,10 +481,10 @@ pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResu .await?; // Save the beatmap... - cache::save_beatmap(&data, msg.channel_id, &beatmap_mode).await?; + cache::save_beatmap(&env, msg.channel_id, &beatmap_mode).await?; } Nth::All => { - let plays = osu + let plays = osu_client .user_recent(UserID::ID(user.id), |f| f.mode(mode).limit(50)) .await?; style.display_scores(plays, mode, ctx, msg).await?; @@ -447,9 +494,9 @@ pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResu } /// Get beatmapset. -struct OptBeatmapset; +struct OptBeatmapSet; -impl FromStr for OptBeatmapset { +impl FromStr for OptBeatmapSet { type Err = Error; fn from_str(s: &str) -> Result { @@ -462,11 +509,9 @@ impl FromStr for OptBeatmapset { /// Load the mentioned beatmap from the given message. pub(crate) async fn load_beatmap( - ctx: &Context, + env: &OsuEnv, msg: &Message, ) -> Option<(BeatmapWithMode, Option)> { - let data = ctx.data.read().await; - if let Some(replied) = &msg.referenced_message { // Try to look for a mention of the replied message. let beatmap_id = SHORT_LINK_REGEX.captures(&replied.content).or_else(|| { @@ -489,8 +534,8 @@ pub(crate) async fn load_beatmap( let mods = caps .name("mods") .and_then(|m| m.as_str().parse::().ok()); - let osu = data.get::().unwrap(); - let bms = osu + let osu_client = &env.client; + let bms = osu_client .beatmaps(BeatmapRequestKind::Beatmap(id), |f| f.maybe_mode(mode)) .await .ok() @@ -499,7 +544,7 @@ pub(crate) async fn load_beatmap( let bm_mode = beatmap.mode; let bm = BeatmapWithMode(beatmap, mode.unwrap_or(bm_mode)); // Store the beatmap in history - cache::save_beatmap(&data, msg.channel_id, &bm) + cache::save_beatmap(&env, msg.channel_id, &bm) .await .pls_ok(); @@ -508,7 +553,7 @@ pub(crate) async fn load_beatmap( } } - let b = cache::get_beatmap(&data, msg.channel_id) + let b = cache::get_beatmap(&env, msg.channel_id) .await .ok() .flatten(); @@ -522,16 +567,16 @@ pub(crate) async fn load_beatmap( #[delimiters(" ")] #[max_args(2)] pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let data = ctx.data.read().await; - let b = load_beatmap(ctx, msg).await; - let beatmapset = args.find::().is_ok(); + let env = ctx.data.read().await.get::().unwrap().clone(); + + let b = load_beatmap(&env, msg).await; + let beatmapset = args.find::().is_ok(); match b { Some((BeatmapWithMode(b, m), mods_def)) => { let mods = args.find::().ok().or(mods_def).unwrap_or(Mods::NOMOD); if beatmapset { - let beatmap_cache = data.get::().unwrap(); - let beatmapset = beatmap_cache.get_beatmapset(b.beatmapset_id).await?; + let beatmapset = env.beatmaps.get_beatmapset(b.beatmapset_id).await?; display::display_beatmapset( ctx, beatmapset, @@ -543,9 +588,8 @@ pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .await?; return Ok(()); } - let info = data - .get::() - .unwrap() + let info = env + .oppai .get_beatmap(b.beatmap_id) .await? .get_possible_pp_with(m, mods)?; @@ -574,8 +618,8 @@ pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult #[description = "Check your own or someone else's best record on the last beatmap. Also stores the result if possible."] #[max_args(3)] pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let data = ctx.data.read().await; - let bm = load_beatmap(ctx, msg).await; + let env = ctx.data.read().await.get::().unwrap().clone(); + let bm = load_beatmap(&env, msg).await; let bm = match bm { Some((bm, _)) => bm, @@ -598,15 +642,15 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul None => Some(msg.author.id), _ => None, }; - let user = to_user_id_query(username_arg, &data, msg).await?; + let user = to_user_id_query(username_arg, &env, msg).await?; - let osu = data.get::().unwrap(); + let osu_client = env.client; - let user = osu + let user = osu_client .user(user, |f| f) .await? .ok_or_else(|| Error::msg("User not found"))?; - let mut scores = osu + let mut scores = osu_client .scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m)) .await? .into_iter() @@ -625,8 +669,7 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul if let Some(user_id) = user_id { // Save to database - data.get::() - .unwrap() + env.user_bests .save(user_id, m, scores.clone()) .await .pls_ok(); @@ -644,7 +687,7 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul #[example = "#2 / taiko / natsukagami"] #[max_args(4)] pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let data = ctx.data.read().await; + let env = ctx.data.read().await.get::().unwrap().clone(); let nth = args.single::().unwrap_or(Nth::All); let style = args.single::().unwrap_or_default(); let mode = args @@ -652,19 +695,16 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .map(|ModeArg(t)| t) .unwrap_or(Mode::Std); - let user = to_user_id_query(args.single::().ok(), &data, msg).await?; - let meta_cache = data.get::().unwrap(); - let osu = data.get::().unwrap(); - - let oppai = data.get::().unwrap(); - let user = osu + let user = to_user_id_query(args.single::().ok(), &env, msg).await?; + let osu_client = &env.client; + let user = osu_client .user(user, |f| f.mode(mode)) .await? .ok_or_else(|| Error::msg("User not found"))?; match nth { Nth::Nth(nth) => { - let top_play = osu + let top_play = osu_client .user_best(UserID::ID(user.id), |f| f.mode(mode).limit(nth)) .await?; @@ -674,8 +714,8 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .into_iter() .last() .ok_or_else(|| Error::msg("No such play"))?; - let beatmap = meta_cache.get_beatmap(top_play.beatmap_id, mode).await?; - let content = oppai.get_beatmap(beatmap.beatmap_id).await?; + let beatmap = env.beatmaps.get_beatmap(top_play.beatmap_id, mode).await?; + let content = env.oppai.get_beatmap(beatmap.beatmap_id).await?; let beatmap = BeatmapWithMode(beatmap, mode); msg.channel_id @@ -694,10 +734,10 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult .await?; // Save the beatmap... - cache::save_beatmap(&data, msg.channel_id, &beatmap).await?; + cache::save_beatmap(&env, msg.channel_id, &beatmap).await?; } Nth::All => { - let plays = osu + let plays = osu_client .user_best(UserID::ID(user.id), |f| f.mode(mode).limit(100)) .await?; style.display_scores(plays, mode, ctx, msg).await?; @@ -712,34 +752,39 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult #[usage = "[--oppai to clear oppai cache as well]"] #[max_args(1)] pub async fn clean_cache(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - let data = ctx.data.read().await; - let meta_cache = data.get::().unwrap(); - meta_cache.clear().await?; + let env = ctx.data.read().await.get::().unwrap().clone(); + env.beatmaps.clear().await?; + if args.remains() == Some("--oppai") { - let oppai = data.get::().unwrap(); - oppai.clear().await?; + env.oppai.clear().await?; } msg.reply_ping(ctx, "Beatmap cache cleared!").await?; Ok(()) } -async fn get_user(ctx: &Context, msg: &Message, mut args: Args, mode: Mode) -> CommandResult { - let data = ctx.data.read().await; - let user = to_user_id_query(args.single::().ok(), &data, msg).await?; - let osu = data.get::().unwrap(); - let cache = data.get::().unwrap(); - let user = osu.user(user, |f| f.mode(mode)).await?; - let oppai = data.get::().unwrap(); +async fn get_user( + ctx: &Context, + env: &OsuEnv, + msg: &Message, + mut args: Args, + mode: Mode, +) -> CommandResult { + let user = to_user_id_query(args.single::().ok(), &env, msg).await?; + let osu_client = &env.client; + let meta_cache = &env.beatmaps; + let user = osu_client.user(user, |f| f.mode(mode)).await?; + match user { Some(u) => { - let bests = osu + let bests = osu_client .user_best(UserID::ID(u.id), |f| f.limit(100).mode(mode)) .await?; - let map_length = calculate_weighted_map_length(&bests, cache, mode).await?; + let map_length = calculate_weighted_map_length(&bests, meta_cache, mode).await?; let best = match bests.into_iter().next() { Some(m) => { - let beatmap = cache.get_beatmap(m.beatmap_id, mode).await?; - let info = oppai + let beatmap = meta_cache.get_beatmap(m.beatmap_id, mode).await?; + let info = env + .oppai .get_beatmap(m.beatmap_id) .await? .get_info_with(mode, m.mods)?; diff --git a/youmubot-osu/src/discord/oppai_cache.rs b/youmubot-osu/src/discord/oppai_cache.rs index 54d86ed..1a788e9 100644 --- a/youmubot-osu/src/discord/oppai_cache.rs +++ b/youmubot-osu/src/discord/oppai_cache.rs @@ -1,15 +1,18 @@ -use crate::{models::Mode, mods::Mods}; +use std::io::Read; +use std::sync::Arc; + use osuparse::MetadataSection; use rosu_pp::catch::CatchDifficultyAttributes; use rosu_pp::mania::ManiaDifficultyAttributes; use rosu_pp::osu::OsuDifficultyAttributes; use rosu_pp::taiko::TaikoDifficultyAttributes; use rosu_pp::{AttributeProvider, Beatmap, CatchPP, DifficultyAttributes, ManiaPP, OsuPP, TaikoPP}; -use std::io::Read; -use std::sync::Arc; + use youmubot_db_sql::{models::osu as models, Pool}; use youmubot_prelude::*; +use crate::{models::Mode, mods::Mods}; + /// the information collected from a download/Oppai request. #[derive(Debug)] pub struct BeatmapContent { @@ -37,7 +40,8 @@ impl BeatmapInfo { #[derive(Clone, Copy, Debug)] pub enum Accuracy { - ByCount(u64, u64, u64, u64), // 300 / 100 / 50 / misses + ByCount(u64, u64, u64, u64), + // 300 / 100 / 50 / misses #[allow(dead_code)] ByValue(f64, u64), } @@ -159,6 +163,7 @@ impl<'a> PPCalc<'a> for OsuPP<'a> { self.calculate().difficulty } } + impl<'a> PPCalc<'a> for TaikoPP<'a> { type Attrs = TaikoDifficultyAttributes; @@ -193,6 +198,7 @@ impl<'a> PPCalc<'a> for TaikoPP<'a> { self.calculate().difficulty } } + impl<'a> PPCalc<'a> for CatchPP<'a> { type Attrs = CatchDifficultyAttributes; @@ -227,6 +233,7 @@ impl<'a> PPCalc<'a> for CatchPP<'a> { self.calculate().difficulty } } + impl<'a> PPCalc<'a> for ManiaPP<'a> { type Attrs = ManiaDifficultyAttributes; @@ -304,6 +311,7 @@ impl BeatmapContent { } /// A central cache for the beatmaps. +#[derive(Debug, Clone)] pub struct BeatmapCache { client: ratelimit::Ratelimit, pool: Pool, diff --git a/youmubot-osu/src/discord/server_rank.rs b/youmubot-osu/src/discord/server_rank.rs index 2d3307a..5d1e991 100644 --- a/youmubot-osu/src/discord/server_rank.rs +++ b/youmubot-osu/src/discord/server_rank.rs @@ -15,15 +15,12 @@ use youmubot_prelude::{ }; use crate::{ - discord::{ - display::ScoreListStyle, - oppai_cache::{Accuracy, BeatmapCache}, - }, + discord::{display::ScoreListStyle, oppai_cache::Accuracy}, models::{Mode, Mods}, request::UserID, }; -use super::{db::OsuSavedUsers, ModeArg, OsuClient}; +use super::{ModeArg, OsuEnv}; #[derive(Debug, Clone, Copy)] enum RankQuery { @@ -50,21 +47,23 @@ impl FromStr for RankQuery { #[max_args(1)] #[only_in(guilds)] pub async fn server_rank(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { - let data = ctx.data.read().await; + let env = ctx.data.read().await.get::().unwrap().clone(); let mode = args .single::() .unwrap_or(RankQuery::Mode(Mode::Std)); let guild = m.guild_id.expect("Guild-only command"); - let member_cache = data.get::().unwrap(); - let osu_users = data - .get::() - .unwrap() + + let osu_users = env + .saved_users .all() .await? .into_iter() .map(|v| (v.user_id, v)) .collect::>(); - let users = member_cache + + let users = env + .prelude + .members .query_members(&ctx, guild) .await? .iter() @@ -102,7 +101,7 @@ pub async fn server_rank(ctx: &Context, m: &Message, mut args: Args) -> CommandR return Ok(()); } - let users = std::sync::Arc::new(users); + let users = Arc::new(users); let last_update = last_update.unwrap(); paginate_reply_fn( move |page: u8, ctx: &Context, m: &mut Message| { @@ -197,7 +196,7 @@ impl Default for OrderBy { } } -impl std::str::FromStr for OrderBy { +impl FromStr for OrderBy { type Err = Error; fn from_str(s: &str) -> Result { @@ -215,55 +214,57 @@ impl std::str::FromStr for OrderBy { #[description = "See the server's ranks on the last seen beatmap"] #[max_args(2)] #[only_in(guilds)] -pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { +pub async fn show_leaderboard(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let order = args.single::().unwrap_or_default(); let style = args.single::().unwrap_or_default(); - let data = ctx.data.read().await; - let member_cache = data.get::().unwrap(); + let env = ctx.data.read().await.get::().unwrap().clone(); - let (bm, _) = match super::load_beatmap(ctx, m).await { + let (bm, _) = match super::load_beatmap(&env, msg).await { Some((bm, mods_def)) => { let mods = args.find::().ok().or(mods_def).unwrap_or(Mods::NOMOD); (bm, mods) } None => { - m.reply(&ctx, "No beatmap queried on this channel.").await?; + msg.reply(&ctx, "No beatmap queried on this channel.") + .await?; return Ok(()); } }; - let osu = data.get::().unwrap().clone(); + let osu_client = env.client.clone(); // Get oppai map. let mode = bm.1; - let oppai = data.get::().unwrap(); + let oppai = env.oppai; let oppai_map = oppai.get_beatmap(bm.0.beatmap_id).await?; - let guild = m.guild_id.expect("Guild-only command"); + let guild = msg.guild_id.expect("Guild-only command"); let scores = { const NO_SCORES: &str = "No scores have been recorded for this beatmap."; // Signal that we are running. - let running_reaction = m.react(&ctx, '⌛').await?; + let running_reaction = msg.react(&ctx, '⌛').await?; - let osu_users = data - .get::() - .unwrap() + let osu_users = env + .saved_users .all() .await? .into_iter() .map(|v| (v.user_id, v)) .collect::>(); - let mut scores = member_cache + let mut scores = env + .prelude + .members .query_members(&ctx, guild) .await? .iter() .filter_map(|m| osu_users.get(&m.user.id).map(|ou| (m.distinct(), ou.id))) .map(|(mem, osu_id)| { - osu.scores(bm.0.beatmap_id, move |f| { - f.user(UserID::ID(osu_id)).mode(bm.1) - }) - .map(|r| Some((mem, r.ok()?))) + osu_client + .scores(bm.0.beatmap_id, move |f| { + f.user(UserID::ID(osu_id)).mode(bm.1) + }) + .map(|r| Some((mem, r.ok()?))) }) .collect::>() .filter_map(future::ready) @@ -300,7 +301,7 @@ pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> Com running_reaction.delete(&ctx).await?; if scores.is_empty() { - m.reply(&ctx, NO_SCORES).await?; + msg.reply(&ctx, NO_SCORES).await?; return Ok(()); } match order { @@ -315,7 +316,7 @@ pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> Com }; if scores.is_empty() { - m.reply( + msg.reply( &ctx, "No scores have been recorded for this beatmap. Run `osu check` to scan for yours!", ) @@ -329,7 +330,7 @@ pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> Com scores.into_iter().map(|(_, _, a)| a).collect(), mode, ctx, - m, + msg, ) .await?; return Ok(()); @@ -409,7 +410,7 @@ pub async fn show_leaderboard(ctx: &Context, m: &Message, mut args: Args) -> Com }) }, ctx, - m, + msg, std::time::Duration::from_secs(60), ) .await?; diff --git a/youmubot-osu/src/lib.rs b/youmubot-osu/src/lib.rs index 1dbe4e0..ec87269 100644 --- a/youmubot-osu/src/lib.rs +++ b/youmubot-osu/src/lib.rs @@ -1,16 +1,19 @@ -pub mod discord; -pub mod models; -pub mod request; +use std::convert::TryInto; +use std::sync::Arc; use models::*; use request::builders::*; use request::*; -use std::convert::TryInto; use youmubot_prelude::*; +pub mod discord; +pub mod models; +pub mod request; + /// Client is the client that will perform calls to the osu! api server. +#[derive(Clone)] pub struct Client { - rosu: rosu_v2::Osu, + rosu: Arc, } pub fn vec_try_into>(v: Vec) -> Result, T::Error> { @@ -31,7 +34,9 @@ impl Client { .client_secret(client_secret) .build() .await?; - Ok(Client { rosu }) + Ok(Client { + rosu: Arc::new(rosu), + }) } pub async fn beatmaps( diff --git a/youmubot-prelude/src/announcer.rs b/youmubot-prelude/src/announcer.rs index fc51038..5c86f83 100644 --- a/youmubot-prelude/src/announcer.rs +++ b/youmubot-prelude/src/announcer.rs @@ -1,4 +1,5 @@ -use crate::{AppData, MemberCache, Result}; +use std::{collections::HashMap, sync::Arc}; + use async_trait::async_trait; use futures_util::{ future::{join_all, ready, FutureExt}, @@ -18,9 +19,11 @@ use serenity::{ prelude::*, utils::MessageBuilder, }; -use std::{collections::HashMap, sync::Arc}; + use youmubot_db::DB; +use crate::{AppData, MemberCache, Result}; + #[derive(Debug, Clone)] pub struct CacheAndHttp(Arc, Arc); @@ -28,15 +31,19 @@ impl CacheAndHttp { pub fn from_client(client: &Client) -> Self { Self(client.cache.clone(), client.http.clone()) } + + pub fn from_context(context: &Context) -> Self { + Self(context.cache.clone(), context.http.clone()) + } } impl CacheHttp for CacheAndHttp { - fn cache(&self) -> Option<&Arc> { - Some(&self.0) - } fn http(&self) -> &Http { &self.1 } + fn cache(&self) -> Option<&Arc> { + Some(&self.0) + } } /// A list of assigned channels for an announcer. @@ -94,23 +101,14 @@ impl MemberToChannels { /// /// This struct manages the list of all Announcers, firing them in a certain interval. pub struct AnnouncerHandler { - cache_http: CacheAndHttp, - data: AppData, announcers: HashMap<&'static str, RwLock>>, } -// Querying for the AnnouncerHandler in the internal data returns a vec of keys. -impl TypeMapKey for AnnouncerHandler { - type Value = Vec<&'static str>; -} - /// Announcer-managing related. impl AnnouncerHandler { /// Create a new instance of the handler. - pub fn new(client: &serenity::Client) -> Self { + pub fn new() -> Self { Self { - cache_http: CacheAndHttp(client.cache.clone(), client.http.clone()), - data: client.data.clone(), announcers: HashMap::new(), } } @@ -136,10 +134,30 @@ impl AnnouncerHandler { self } } + + pub fn run(self, client: &Client) -> AnnouncerRunner { + let runner = AnnouncerRunner { + cache_http: CacheAndHttp::from_client(client), + data: client.data.clone(), + announcers: self.announcers, + }; + runner + } +} + +pub struct AnnouncerRunner { + cache_http: CacheAndHttp, + data: AppData, + announcers: HashMap<&'static str, RwLock>>, +} + +// Querying for the AnnouncerRunner in the internal data returns a vec of keys. +impl TypeMapKey for AnnouncerRunner { + type Value = Vec<&'static str>; } /// Execution-related. -impl AnnouncerHandler { +impl AnnouncerRunner { /// Collect the list of guilds and their respective channels, by the key of the announcer. async fn get_guilds(data: &AppData, key: &'static str) -> Result> { let data = AnnouncerChannels::open(&*data.read().await) @@ -214,7 +232,7 @@ pub async fn list_announcers(ctx: &Context, m: &Message, _: Args) -> CommandResu let guild_id = m.guild_id.unwrap(); let data = &*ctx.data.read().await; let announcers = AnnouncerChannels::open(data); - let channels = data.get::().unwrap(); + let channels = data.get::().unwrap(); let channels = channels .iter() .filter_map(|&key| { @@ -249,7 +267,7 @@ pub async fn list_announcers(ctx: &Context, m: &Message, _: Args) -> CommandResu pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; let key = args.single::()?; - let keys = data.get::().unwrap(); + let keys = data.get::().unwrap(); if !keys.contains(&&key[..]) { m.reply( &ctx, @@ -296,7 +314,7 @@ pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> C pub async fn remove_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; let key = args.single::()?; - let keys = data.get::().unwrap(); + let keys = data.get::().unwrap(); if !keys.contains(&key.as_str()) { m.reply( &ctx, diff --git a/youmubot-prelude/src/lib.rs b/youmubot-prelude/src/lib.rs index 6ca7c07..c6f295f 100644 --- a/youmubot-prelude/src/lib.rs +++ b/youmubot-prelude/src/lib.rs @@ -1,7 +1,24 @@ +use std::sync::Arc; + +/// Re-export the anyhow errors +pub use anyhow::{anyhow as error, bail, Error, Result}; +/// Re-exporting async_trait helps with implementing Announcer. +pub use async_trait::async_trait; +/// Re-export useful future and stream utils +pub use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; /// Module `prelude` provides a sane set of default imports that can be used inside /// a Youmubot source file. pub use serenity::prelude::*; -use std::sync::Arc; +/// Re-export the spawn function +pub use tokio::spawn as spawn_future; + +pub use announcer::{Announcer, AnnouncerRunner}; +pub use args::{ChannelId, Duration, RoleId, UserId, UsernameArg}; +pub use debugging_ok::OkPrint; +pub use flags::Flags; +pub use hook::Hook; +pub use member_cache::MemberCache; +pub use pagination::{paginate, paginate_fn, paginate_reply, paginate_reply_fn, Paginate}; pub mod announcer; pub mod args; @@ -13,26 +30,6 @@ pub mod ratelimit; pub mod setup; pub mod table_format; -pub use announcer::{Announcer, AnnouncerHandler}; -pub use args::{ChannelId, Duration, RoleId, UserId, UsernameArg}; -pub use flags::Flags; -pub use hook::Hook; -pub use member_cache::MemberCache; -pub use pagination::{paginate, paginate_fn, paginate_reply, paginate_reply_fn, Paginate}; - -/// Re-exporting async_trait helps with implementing Announcer. -pub use async_trait::async_trait; - -/// Re-export the anyhow errors -pub use anyhow::{anyhow as error, bail, Error, Result}; -pub use debugging_ok::OkPrint; - -/// Re-export useful future and stream utils -pub use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; - -/// Re-export the spawn function -pub use tokio::spawn as spawn_future; - /// The global app data. pub type AppData = Arc>; @@ -50,8 +47,18 @@ impl TypeMapKey for SQLClient { type Value = youmubot_db_sql::Pool; } +/// The created base environment. +#[derive(Debug, Clone)] +pub struct Env { + // clients + pub http: reqwest::Client, + pub sql: youmubot_db_sql::Pool, + pub members: Arc, + // databases + // pub(crate) announcer_channels: announcer::AnnouncerChannels, +} + pub mod prelude_commands { - use crate::announcer::ANNOUNCERCOMMANDS_GROUP; use serenity::{ framework::standard::{ macros::{command, group}, @@ -61,6 +68,8 @@ pub mod prelude_commands { prelude::Context, }; + use crate::announcer::ANNOUNCERCOMMANDS_GROUP; + #[group("Prelude")] #[description = "All the commands that makes the base of Youmu"] #[commands(ping)] diff --git a/youmubot-prelude/src/ratelimit.rs b/youmubot-prelude/src/ratelimit.rs index 47b6cc3..122f96e 100644 --- a/youmubot-prelude/src/ratelimit.rs +++ b/youmubot-prelude/src/ratelimit.rs @@ -1,12 +1,14 @@ +use std::ops::Deref; /// Provides a simple ratelimit lock (that only works in tokio) // use tokio::time:: use std::time::Duration; -use crate::Result; use flume::{bounded as channel, Receiver, Sender}; -use std::ops::Deref; + +use crate::Result; /// Holds the underlying `T` in a rate-limited way. +#[derive(Debug, Clone)] pub struct Ratelimit { inner: T, recv: Receiver<()>, diff --git a/youmubot-prelude/src/setup.rs b/youmubot-prelude/src/setup.rs index 3c5c873..62494b9 100644 --- a/youmubot-prelude/src/setup.rs +++ b/youmubot-prelude/src/setup.rs @@ -1,6 +1,9 @@ -use serenity::prelude::*; use std::{path::Path, time::Duration}; +use serenity::prelude::*; + +use crate::Env; + /// Set up the prelude libraries. /// /// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails. @@ -8,8 +11,8 @@ pub async fn setup_prelude( db_path: impl AsRef, sql_path: impl AsRef, data: &mut TypeMap, -) { - // Setup the announcer DB. +) -> Env { + // Set up the announcer DB. crate::announcer::AnnouncerChannels::insert_into( data, db_path.as_ref().join("announcers.yaml"), @@ -22,17 +25,25 @@ pub async fn setup_prelude( .expect("SQL database set up"); // Set up the HTTP client. - data.insert::( - reqwest::ClientBuilder::new() - .connect_timeout(Duration::from_secs(5)) - .timeout(Duration::from_secs(60)) - .build() - .expect("Build be able to build HTTP client"), - ); + let http_client = reqwest::ClientBuilder::new() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(60)) + .build() + .expect("Build be able to build HTTP client"); + data.insert::(http_client.clone()); // Set up the member cache. - data.insert::(std::sync::Arc::new(crate::MemberCache::default())); + let member_cache = std::sync::Arc::new(crate::MemberCache::default()); + data.insert::(member_cache.clone()); // Set up the SQL client. - data.insert::(sql_pool); + data.insert::(sql_pool.clone()); + + let env = Env { + http: http_client, + sql: sql_pool, + members: member_cache, + }; + + env } diff --git a/youmubot/src/main.rs b/youmubot/src/main.rs index 1a18019..a719c5f 100644 --- a/youmubot/src/main.rs +++ b/youmubot/src/main.rs @@ -9,37 +9,58 @@ use serenity::{ permissions::Permissions, }, }; + +use youmubot_prelude::announcer::AnnouncerHandler; use youmubot_prelude::*; struct Handler { hooks: Vec>>, + ready_hooks: Vec CommandResult>, } impl Handler { fn new() -> Handler { - Handler { hooks: vec![] } + Handler { + hooks: vec![], + ready_hooks: vec![], + } } fn push_hook(&mut self, f: T) { self.hooks.push(RwLock::new(Box::new(f))); } + + fn push_ready_hook(&mut self, f: fn(&Context) -> CommandResult) { + self.ready_hooks.push(f); + } +} + +/// Environment to be passed into the framework +#[derive(Debug, Clone)] +struct Env { + prelude: youmubot_prelude::Env, + #[cfg(feature = "osu")] + osu: youmubot_osu::discord::OsuEnv, +} + +impl AsRef for Env { + fn as_ref(&self) -> &youmubot_prelude::Env { + &self.prelude + } +} + +impl AsRef for Env { + fn as_ref(&self) -> &youmubot_osu::discord::OsuEnv { + &self.osu + } +} + +impl TypeMapKey for Env { + type Value = Env; } #[async_trait] impl EventHandler for Handler { - async fn ready(&self, ctx: Context, ready: gateway::Ready) { - // Start ReactionWatchers for community. - #[cfg(feature = "core")] - ctx.data - .read() - .await - .get::() - .unwrap() - .init(&ctx) - .await; - println!("{} is connected!", ready.user.name); - } - async fn message(&self, ctx: Context, message: Message) { self.hooks .iter() @@ -57,6 +78,23 @@ impl EventHandler for Handler { }) .await; } + + async fn ready(&self, ctx: Context, ready: gateway::Ready) { + // Start ReactionWatchers for community. + #[cfg(feature = "core")] + ctx.data + .read() + .await + .get::() + .unwrap() + .init(&ctx) + .await; + println!("{} is connected!", ready.user.name); + + for f in &self.ready_hooks { + f(&ctx).pls_ok(); + } + } } /// Returns whether the user has "MANAGE_MESSAGES" permission in the channel. @@ -79,16 +117,70 @@ async fn main() { } let mut handler = Handler::new(); + #[cfg(feature = "core")] + handler.push_ready_hook(youmubot_core::ready_hook); // Set up hooks #[cfg(feature = "osu")] - handler.push_hook(youmubot_osu::discord::hook); - #[cfg(feature = "osu")] - handler.push_hook(youmubot_osu::discord::dot_osu_hook); + { + handler.push_hook(youmubot_osu::discord::hook); + handler.push_hook(youmubot_osu::discord::dot_osu_hook); + } #[cfg(feature = "codeforces")] handler.push_hook(youmubot_cf::InfoHook); // Collect the token let token = var("TOKEN").expect("Please set TOKEN as the Discord Bot's token to be used."); + + // Data to be put into context + let mut data = TypeMap::new(); + + // Set up announcer handler + let mut announcers = AnnouncerHandler::new(); + + // Setup each package starting from the prelude. + let env = { + let db_path = var("DBPATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|e| { + println!("No DBPATH set up ({:?}), using `/data`", e); + std::path::PathBuf::from("/data") + }); + let sql_path = var("SQLPATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|e| { + let res = db_path.join("youmubot.db"); + println!("No SQLPATH set up ({:?}), using `{:?}`", e, res); + res + }); + let prelude = setup::setup_prelude(&db_path, sql_path, &mut data).await; + // Setup core + #[cfg(feature = "core")] + youmubot_core::setup(&db_path, &mut data).expect("Setup db should succeed"); + // osu! + #[cfg(feature = "osu")] + let osu = youmubot_osu::discord::setup(&mut data, prelude.clone(), &mut announcers) + .await + .expect("osu! is initialized"); + // codeforces + #[cfg(feature = "codeforces")] + youmubot_cf::setup(&db_path, &mut data, &mut announcers).await; + + Env { + prelude, + #[cfg(feature = "osu")] + osu, + } + }; + + data.insert::(env); + + #[cfg(feature = "core")] + println!("Core enabled."); + #[cfg(feature = "osu")] + println!("osu! enabled."); + #[cfg(feature = "codeforces")] + println!("codeforces enabled."); + // Set up base framework let fw = setup_framework(&token[..]).await; @@ -105,60 +197,20 @@ async fn main() { | GatewayIntents::DIRECT_MESSAGES | GatewayIntents::DIRECT_MESSAGE_REACTIONS; Client::builder(token, intents) + .type_map(data) .framework(fw) .event_handler(handler) .await .unwrap() }; - // Set up announcer handler - let mut announcers = AnnouncerHandler::new(&client); - - // Setup each package starting from the prelude. - { - let mut data = client.data.write().await; - let db_path = var("DBPATH") - .map(std::path::PathBuf::from) - .unwrap_or_else(|e| { - println!("No DBPATH set up ({:?}), using `/data`", e); - std::path::PathBuf::from("/data") - }); - let sql_path = var("SQLPATH") - .map(std::path::PathBuf::from) - .unwrap_or_else(|e| { - let res = db_path.join("youmubot.db"); - println!("No SQLPATH set up ({:?}), using `{:?}`", e, res); - res - }); - youmubot_prelude::setup::setup_prelude(&db_path, sql_path, &mut data).await; - // Setup core - #[cfg(feature = "core")] - youmubot_core::setup(&db_path, &client, &mut data).expect("Setup db should succeed"); - // osu! - #[cfg(feature = "osu")] - youmubot_osu::discord::setup(&db_path, &mut data, &mut announcers) - .await - .expect("osu! is initialized"); - // codeforces - #[cfg(feature = "codeforces")] - youmubot_cf::setup(&db_path, &mut data, &mut announcers).await; - } - - #[cfg(feature = "core")] - println!("Core enabled."); - #[cfg(feature = "osu")] - println!("osu! enabled."); - #[cfg(feature = "codeforces")] - println!("codeforces enabled."); - + let announcers = announcers.run(&client); tokio::spawn(announcers.scan(std::time::Duration::from_secs(300))); println!("Starting..."); if let Err(v) = client.start().await { panic!("{}", v) } - - println!("Hello, world!"); } // Sets up a framework for a client