diff --git a/Cargo.lock b/Cargo.lock index fa25bcc..ec8b79c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -440,6 +440,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.48", +] + +[[package]] +name = "darling_macro" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.48", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -481,6 +516,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -1013,6 +1059,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1543,6 +1595,35 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +[[package]] +name = "poise" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1819d5a45e3590ef33754abce46432570c54a120798bdbf893112b4211fa09a6" +dependencies = [ + "async-trait", + "derivative", + "futures-util", + "parking_lot", + "poise_macros", + "regex", + "serenity", + "tokio", + "tracing", +] + +[[package]] +name = "poise_macros" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fa2c123c961e78315cd3deac7663177f12be4460f5440dbf62a7ed37b1effea" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -2346,6 +2427,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.5.0" @@ -3106,6 +3193,7 @@ version = "0.1.0" dependencies = [ "dotenv", "env_logger", + "poise", "serenity", "tokio", "youmubot-cf", @@ -3183,6 +3271,7 @@ dependencies = [ "dashmap", "lazy_static", "osuparse", + "poise", "rand", "regex", "reqwest", @@ -3208,6 +3297,7 @@ dependencies = [ "dashmap", "flume 0.10.14", "futures-util", + "poise", "reqwest", "serenity", "tokio", diff --git a/youmubot-core/src/db.rs b/youmubot-core/src/db.rs index 0b4d9c1..1fc0d73 100644 --- a/youmubot-core/src/db.rs +++ b/youmubot-core/src/db.rs @@ -5,7 +5,7 @@ use serenity::model::{ channel::ReactionType, id::{MessageId, RoleId, UserId}, }; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use youmubot_db::{GuildMap, DB}; use youmubot_prelude::*; @@ -55,7 +55,7 @@ pub fn load_role_list( let v2 = Roles::load_from_path(path.as_ref()); let v2 = match v2 { Ok(v2) => { - map.insert::(v2); + map.insert::(Arc::new(v2)); return Ok(()); } Err(v2) => v2, diff --git a/youmubot-core/src/lib.rs b/youmubot-core/src/lib.rs index afeb39d..6a1cae1 100644 --- a/youmubot-core/src/lib.rs +++ b/youmubot-core/src/lib.rs @@ -19,7 +19,6 @@ pub use fun::FUN_GROUP; /// Sets up all databases in the client. pub fn setup( path: &std::path::Path, - client: &serenity::client::Client, data: &mut TypeMap, ) -> serenity::framework::standard::CommandResult { db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?; @@ -29,18 +28,21 @@ pub fn setup( &path.join("roles.yaml"), )?; - // Create handler threads - tokio::spawn(admin::watch_soft_bans( - CacheAndHttp::from_client(client), - client.data.clone(), - )); - // Start reaction handlers data.insert::(community::ReactionWatchers::new(&*data)?); Ok(()) } +pub fn ready_hook(ctx: &Context) -> CommandResult { + // Create handler threads + tokio::spawn(admin::watch_soft_bans( + CacheAndHttp::from_context(ctx), + ctx.data.clone(), + )); + Ok(()) +} + // A help command #[help] pub async fn help( diff --git a/youmubot-db/src/lib.rs b/youmubot-db/src/lib.rs index 34f09a3..d848446 100644 --- a/youmubot-db/src/lib.rs +++ b/youmubot-db/src/lib.rs @@ -4,19 +4,20 @@ use serenity::{ model::id::GuildId, prelude::{TypeMap, TypeMapKey}, }; -use std::{collections::HashMap, path::Path}; +use std::{collections::HashMap, path::Path, sync::Arc}; /// GuildMap defines the guild-map type. /// It is basically a HashMap from a GuildId to a data structure. pub type GuildMap = HashMap; /// The generic DB type we will be using. -pub struct DB(std::marker::PhantomData); +#[derive(Debug, Clone)] +pub struct DB(pub Arc>); /// A short type abbreviation for a FileDatabase. type Database = FileDatabase; impl TypeMapKey for DB { - type Value = Database; + type Value = Arc>; } impl DB @@ -29,15 +30,15 @@ where } /// Insert into a ShareMap. - pub fn insert_into(data: &mut TypeMap, path: impl AsRef) -> Result<(), DBError> { - let db = Database::::load_from_path_or_default(path)?; - data.insert::>(db); - Ok(()) + pub fn insert_into(data: &mut TypeMap, path: impl AsRef) -> Result { + let db = Arc::new(Database::::load_from_path_or_default(path)?); + data.insert::>(db.clone()); + Ok(Self(db)) } /// Open a previously inserted DB. pub fn open(data: &TypeMap) -> DBWriteGuard { - data.get::().expect("DB initialized").into() + data.get::().expect("DB initialized").as_ref().into() } } diff --git a/youmubot-osu/Cargo.toml b/youmubot-osu/Cargo.toml index 53b17fe..f42ad13 100644 --- a/youmubot-osu/Cargo.toml +++ b/youmubot-osu/Cargo.toml @@ -20,6 +20,7 @@ rosu-v2 = { git = "https://github.com/natsukagami/rosu-v2", rev = "6f6731cb2f0d2 time = "0.3" serde = { version = "1.0.137", features = ["derive"] } serenity = "0.12" +poise = "0.6" zip = "0.6.2" rand = "0.8" diff --git a/youmubot-osu/src/discord/app_commands.rs b/youmubot-osu/src/discord/app_commands.rs new file mode 100644 index 0000000..04402c6 --- /dev/null +++ b/youmubot-osu/src/discord/app_commands.rs @@ -0,0 +1,9 @@ +use youmubot_prelude::*; + +#[poise::command(slash_command)] +pub async fn example + Sync>( + context: poise::Context<'_, T, Error>, + arg: String, +) -> Result<(), Error> { + todo!() +} diff --git a/youmubot-osu/src/discord/beatmap_cache.rs b/youmubot-osu/src/discord/beatmap_cache.rs index b523844..30b27bb 100644 --- a/youmubot-osu/src/discord/beatmap_cache.rs +++ b/youmubot-osu/src/discord/beatmap_cache.rs @@ -8,11 +8,18 @@ use youmubot_prelude::*; /// BeatmapMetaCache intercepts beatmap-by-id requests and caches them for later recalling. /// Does not cache non-Ranked beatmaps. +#[derive(Clone)] pub struct BeatmapMetaCache { client: Arc, pool: Pool, } +impl std::fmt::Debug for BeatmapMetaCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + impl TypeMapKey for BeatmapMetaCache { type Value = BeatmapMetaCache; } diff --git a/youmubot-osu/src/discord/db.rs b/youmubot-osu/src/discord/db.rs index 9f8c2cf..e36df02 100644 --- a/youmubot-osu/src/discord/db.rs +++ b/youmubot-osu/src/discord/db.rs @@ -9,6 +9,7 @@ use serenity::model::id::{ChannelId, UserId}; use youmubot_prelude::*; /// Save the user IDs. +#[derive(Debug, Clone)] pub struct OsuSavedUsers { pool: Pool, } @@ -60,6 +61,7 @@ impl OsuSavedUsers { } /// Save each channel's last requested beatmap. +#[derive(Debug, Clone)] pub struct OsuLastBeatmap(Pool); impl TypeMapKey for OsuLastBeatmap { @@ -99,6 +101,7 @@ impl OsuLastBeatmap { } /// Save each channel's last requested beatmap. +#[derive(Debug, Clone)] pub struct OsuUserBests(Pool); impl TypeMapKey for OsuUserBests { diff --git a/youmubot-osu/src/discord/mod.rs b/youmubot-osu/src/discord/mod.rs index 04ad7d7..f540ba3 100644 --- a/youmubot-osu/src/discord/mod.rs +++ b/youmubot-osu/src/discord/mod.rs @@ -21,6 +21,7 @@ use std::{str::FromStr, sync::Arc}; use youmubot_prelude::*; mod announcer; +pub mod app_commands; pub(crate) mod beatmap_cache; mod cache; mod db; @@ -43,6 +44,33 @@ impl TypeMapKey for OsuClient { type Value = Arc; } +/// The environment for osu! app commands. +#[derive(Clone)] +pub struct Env { + pub(crate) prelude: youmubot_prelude::Env, + // databases + pub(crate) saved_users: OsuSavedUsers, + pub(crate) last_beatmaps: OsuLastBeatmap, + pub(crate) user_bests: OsuUserBests, + // clients + pub(crate) client: Arc, + pub(crate) oppai: oppai_cache::BeatmapCache, + pub(crate) beatmaps: beatmap_cache::BeatmapMetaCache, +} + +impl std::fmt::Debug for Env { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl TypeMapKey for Env { + type Value = Env; +} + +/// The command context for osu! app commands. +pub(crate) type CmdContext<'a> = youmubot_prelude::CmdContext<'a, Env>; + /// Sets up the osu! command handling section. /// /// This automatically enables: @@ -55,50 +83,60 @@ impl TypeMapKey for OsuClient { /// - Hooks. Hooks are completely opt-in. /// pub async fn setup( - _path: &std::path::Path, data: &mut TypeMap, + // dependencies + prelude: youmubot_prelude::Env, announcers: &mut AnnouncerHandler, -) -> CommandResult { - let sql_client = data.get::().unwrap().clone(); +) -> Result { // Databases - data.insert::(OsuSavedUsers::new(sql_client.clone())); - data.insert::(OsuLastBeatmap::new(sql_client.clone())); - data.insert::(OsuUserBests::new(sql_client.clone())); + let saved_users = OsuSavedUsers::new(prelude.sql.clone()); + let last_beatmaps = OsuLastBeatmap::new(prelude.sql.clone()); + let user_bests = OsuUserBests::new(prelude.sql.clone()); // API client - let http_client = data.get::().unwrap().clone(); - let mk_osu_client = || async { - Arc::new( - OsuHttpClient::new( - std::env::var("OSU_API_CLIENT_ID") - .expect("Please set OSU_API_CLIENT_ID as osu! api v2 client ID.") - .parse() - .expect("client_id should be u64"), - std::env::var("OSU_API_CLIENT_SECRET") - .expect("Please set OSU_API_CLIENT_SECRET as osu! api v2 client secret."), - ) - .await - .expect("osu! should be initialized"), + let osu_client = Arc::new( + OsuHttpClient::new( + std::env::var("OSU_API_CLIENT_ID") + .expect("Please set OSU_API_CLIENT_ID as osu! api v2 client ID.") + .parse() + .expect("client_id should be u64"), + std::env::var("OSU_API_CLIENT_SECRET") + .expect("Please set OSU_API_CLIENT_SECRET as osu! api v2 client secret."), ) - }; - let osu_client = mk_osu_client().await; - data.insert::(osu_client.clone()); - data.insert::(oppai_cache::BeatmapCache::new( - http_client.clone(), - sql_client.clone(), - )); - data.insert::(beatmap_cache::BeatmapMetaCache::new( - osu_client.clone(), - sql_client, - )); + .await + .expect("osu! should be initialized"), + ); + let oppai_cache = oppai_cache::BeatmapCache::new(prelude.http.clone(), prelude.sql.clone()); + let beatmap_cache = + beatmap_cache::BeatmapMetaCache::new(osu_client.clone(), prelude.sql.clone()); // Announcer - let osu_client = mk_osu_client().await; announcers.add( announcer::ANNOUNCER_KEY, - announcer::Announcer::new(osu_client), + announcer::Announcer::new(osu_client.clone()), ); - Ok(()) + + // Legacy data + data.insert::(last_beatmaps.clone()); + data.insert::(saved_users.clone()); + data.insert::(user_bests.clone()); + data.insert::(osu_client.clone()); + data.insert::(oppai_cache.clone()); + data.insert::(beatmap_cache.clone()); + + let env = Env { + prelude, + saved_users, + last_beatmaps, + user_bests, + client: osu_client, + oppai: oppai_cache, + beatmaps: beatmap_cache, + }; + + data.insert::(env.clone()); + + Ok(env) } #[group] diff --git a/youmubot-osu/src/discord/oppai_cache.rs b/youmubot-osu/src/discord/oppai_cache.rs index 54d86ed..3afe453 100644 --- a/youmubot-osu/src/discord/oppai_cache.rs +++ b/youmubot-osu/src/discord/oppai_cache.rs @@ -304,6 +304,7 @@ impl BeatmapContent { } /// A central cache for the beatmaps. +#[derive(Debug, Clone)] pub struct BeatmapCache { client: ratelimit::Ratelimit, pool: Pool, diff --git a/youmubot-osu/src/lib.rs b/youmubot-osu/src/lib.rs index 1dbe4e0..7e4be8b 100644 --- a/youmubot-osu/src/lib.rs +++ b/youmubot-osu/src/lib.rs @@ -6,11 +6,13 @@ use models::*; use request::builders::*; use request::*; use std::convert::TryInto; +use std::sync::Arc; use youmubot_prelude::*; /// Client is the client that will perform calls to the osu! api server. +#[derive(Clone)] pub struct Client { - rosu: rosu_v2::Osu, + rosu: Arc, } pub fn vec_try_into>(v: Vec) -> Result, T::Error> { @@ -31,7 +33,9 @@ impl Client { .client_secret(client_secret) .build() .await?; - Ok(Client { rosu }) + Ok(Client { + rosu: Arc::new(rosu), + }) } pub async fn beatmaps( diff --git a/youmubot-prelude/Cargo.toml b/youmubot-prelude/Cargo.toml index e970398..2bbc8f0 100644 --- a/youmubot-prelude/Cargo.toml +++ b/youmubot-prelude/Cargo.toml @@ -17,6 +17,7 @@ reqwest = { version = "0.11.10", features = ["json"] } chrono = "0.4.19" flume = "0.10.13" dashmap = "5.3.4" +poise = "0.6" [dependencies.serenity] version = "0.12" diff --git a/youmubot-prelude/src/announcer.rs b/youmubot-prelude/src/announcer.rs index fc51038..34100ee 100644 --- a/youmubot-prelude/src/announcer.rs +++ b/youmubot-prelude/src/announcer.rs @@ -18,7 +18,7 @@ use serenity::{ prelude::*, utils::MessageBuilder, }; -use std::{collections::HashMap, sync::Arc}; +use std::{arch::x86_64::_bittestandcomplement, collections::HashMap, sync::Arc}; use youmubot_db::DB; #[derive(Debug, Clone)] @@ -28,6 +28,10 @@ impl CacheAndHttp { pub fn from_client(client: &Client) -> Self { Self(client.cache.clone(), client.http.clone()) } + + pub fn from_context(context: &Context) -> Self { + Self(context.cache.clone(), context.http.clone()) + } } impl CacheHttp for CacheAndHttp { @@ -94,23 +98,14 @@ impl MemberToChannels { /// /// This struct manages the list of all Announcers, firing them in a certain interval. pub struct AnnouncerHandler { - cache_http: CacheAndHttp, - data: AppData, announcers: HashMap<&'static str, RwLock>>, } -// Querying for the AnnouncerHandler in the internal data returns a vec of keys. -impl TypeMapKey for AnnouncerHandler { - type Value = Vec<&'static str>; -} - /// Announcer-managing related. impl AnnouncerHandler { /// Create a new instance of the handler. - pub fn new(client: &serenity::Client) -> Self { + pub fn new() -> Self { Self { - cache_http: CacheAndHttp(client.cache.clone(), client.http.clone()), - data: client.data.clone(), announcers: HashMap::new(), } } @@ -136,10 +131,30 @@ impl AnnouncerHandler { self } } + + pub fn run(self, client: &Client) -> AnnouncerRunner { + let runner = AnnouncerRunner { + cache_http: CacheAndHttp::from_client(client), + data: client.data.clone(), + announcers: self.announcers, + }; + runner + } +} + +pub struct AnnouncerRunner { + cache_http: CacheAndHttp, + data: AppData, + announcers: HashMap<&'static str, RwLock>>, +} + +// Querying for the AnnouncerRunner in the internal data returns a vec of keys. +impl TypeMapKey for AnnouncerRunner { + type Value = Vec<&'static str>; } /// Execution-related. -impl AnnouncerHandler { +impl AnnouncerRunner { /// Collect the list of guilds and their respective channels, by the key of the announcer. async fn get_guilds(data: &AppData, key: &'static str) -> Result> { let data = AnnouncerChannels::open(&*data.read().await) @@ -214,7 +229,7 @@ pub async fn list_announcers(ctx: &Context, m: &Message, _: Args) -> CommandResu let guild_id = m.guild_id.unwrap(); let data = &*ctx.data.read().await; let announcers = AnnouncerChannels::open(data); - let channels = data.get::().unwrap(); + let channels = data.get::().unwrap(); let channels = channels .iter() .filter_map(|&key| { @@ -249,7 +264,7 @@ pub async fn list_announcers(ctx: &Context, m: &Message, _: Args) -> CommandResu pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; let key = args.single::()?; - let keys = data.get::().unwrap(); + let keys = data.get::().unwrap(); if !keys.contains(&&key[..]) { m.reply( &ctx, @@ -296,7 +311,7 @@ pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> C pub async fn remove_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; let key = args.single::()?; - let keys = data.get::().unwrap(); + let keys = data.get::().unwrap(); if !keys.contains(&key.as_str()) { m.reply( &ctx, diff --git a/youmubot-prelude/src/lib.rs b/youmubot-prelude/src/lib.rs index ea555b1..c5a653b 100644 --- a/youmubot-prelude/src/lib.rs +++ b/youmubot-prelude/src/lib.rs @@ -1,3 +1,4 @@ +use announcer::AnnouncerChannels; /// Module `prelude` provides a sane set of default imports that can be used inside /// a Youmubot source file. pub use serenity::prelude::*; @@ -38,6 +39,20 @@ pub type AppData = Arc>; /// The HTTP client. pub struct HTTPClient; +/// The global context type for app commands +pub type CmdContext<'a, Env> = poise::Context<'a, Env, anyhow::Error>; + +/// The created base environment. +#[derive(Debug, Clone)] +pub struct Env { + // clients + pub http: reqwest::Client, + pub sql: youmubot_db_sql::Pool, + pub members: Arc, + // databases + // pub(crate) announcer_channels: announcer::AnnouncerChannels, +} + impl TypeMapKey for HTTPClient { type Value = reqwest::Client; } diff --git a/youmubot-prelude/src/ratelimit.rs b/youmubot-prelude/src/ratelimit.rs index 47b6cc3..a28dc7e 100644 --- a/youmubot-prelude/src/ratelimit.rs +++ b/youmubot-prelude/src/ratelimit.rs @@ -7,6 +7,7 @@ use flume::{bounded as channel, Receiver, Sender}; use std::ops::Deref; /// Holds the underlying `T` in a rate-limited way. +#[derive(Debug, Clone)] pub struct Ratelimit { inner: T, recv: Receiver<()>, diff --git a/youmubot-prelude/src/setup.rs b/youmubot-prelude/src/setup.rs index 3c5c873..d678b33 100644 --- a/youmubot-prelude/src/setup.rs +++ b/youmubot-prelude/src/setup.rs @@ -1,6 +1,8 @@ use serenity::prelude::*; use std::{path::Path, time::Duration}; +use crate::Env; + /// Set up the prelude libraries. /// /// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails. @@ -8,7 +10,7 @@ pub async fn setup_prelude( db_path: impl AsRef, sql_path: impl AsRef, data: &mut TypeMap, -) { +) -> Env { // Setup the announcer DB. crate::announcer::AnnouncerChannels::insert_into( data, @@ -22,17 +24,25 @@ pub async fn setup_prelude( .expect("SQL database set up"); // Set up the HTTP client. - data.insert::( - reqwest::ClientBuilder::new() - .connect_timeout(Duration::from_secs(5)) - .timeout(Duration::from_secs(60)) - .build() - .expect("Build be able to build HTTP client"), - ); + let http_client = reqwest::ClientBuilder::new() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(60)) + .build() + .expect("Build be able to build HTTP client"); + data.insert::(http_client.clone()); // Set up the member cache. - data.insert::(std::sync::Arc::new(crate::MemberCache::default())); + let member_cache = std::sync::Arc::new(crate::MemberCache::default()); + data.insert::(member_cache.clone()); // Set up the SQL client. - data.insert::(sql_pool); + data.insert::(sql_pool.clone()); + + let env = Env { + http: http_client, + sql: sql_pool, + members: member_cache, + }; + + env } diff --git a/youmubot/Cargo.toml b/youmubot/Cargo.toml index b48d8b1..7dbec7f 100644 --- a/youmubot/Cargo.toml +++ b/youmubot/Cargo.toml @@ -13,6 +13,7 @@ codeforces = ["youmubot-cf"] [dependencies] serenity = "0.12" +poise = "0.6" tokio = { version = "1.19.2", features = ["rt-multi-thread"] } dotenv = "0.15.0" env_logger = "0.9.0" diff --git a/youmubot/src/compose_framework.rs b/youmubot/src/compose_framework.rs index ec338d4..6024e62 100644 --- a/youmubot/src/compose_framework.rs +++ b/youmubot/src/compose_framework.rs @@ -25,6 +25,11 @@ impl Framework for ComposedFramework { .await } } + async fn init(&mut self, client: &Client) { + for f in self.frameworks.iter_mut() { + f.init(&client).await + } + } } impl ComposedFramework { /// Dispatch to all inner frameworks in a loop. Returns a `Pin>` because rust. diff --git a/youmubot/src/main.rs b/youmubot/src/main.rs index 5e28d41..b41620b 100644 --- a/youmubot/src/main.rs +++ b/youmubot/src/main.rs @@ -16,16 +16,44 @@ mod compose_framework; struct Handler { hooks: Vec>>, + ready_hooks: Vec CommandResult>, } impl Handler { fn new() -> Handler { - Handler { hooks: vec![] } + Handler { + hooks: vec![], + ready_hooks: vec![], + } } fn push_hook(&mut self, f: T) { self.hooks.push(RwLock::new(Box::new(f))); } + + fn push_ready_hook(&mut self, f: fn(&Context) -> CommandResult) { + self.ready_hooks.push(f); + } +} + +/// Environment to be passed into the framework +#[derive(Debug, Clone)] +struct Env { + prelude: youmubot_prelude::Env, + #[cfg(feature = "osu")] + osu: youmubot_osu::discord::Env, +} + +impl AsRef for Env { + fn as_ref(&self) -> &youmubot_prelude::Env { + &self.prelude + } +} + +impl AsRef for Env { + fn as_ref(&self) -> &youmubot_osu::discord::Env { + &self.osu + } } #[async_trait] @@ -41,6 +69,10 @@ impl EventHandler for Handler { .init(&ctx) .await; println!("{} is connected!", ready.user.name); + + for f in &self.ready_hooks { + f(&ctx).pls_ok(); + } } async fn message(&self, ctx: Context, message: Message) { @@ -82,20 +114,105 @@ async fn main() { } let mut handler = Handler::new(); + #[cfg(feature = "core")] + handler.push_ready_hook(youmubot_core::ready_hook); // Set up hooks #[cfg(feature = "osu")] - handler.push_hook(youmubot_osu::discord::hook); - #[cfg(feature = "osu")] - handler.push_hook(youmubot_osu::discord::dot_osu_hook); + { + handler.push_hook(youmubot_osu::discord::hook); + handler.push_hook(youmubot_osu::discord::dot_osu_hook); + } #[cfg(feature = "codeforces")] handler.push_hook(youmubot_cf::InfoHook); // Collect the token let token = var("TOKEN").expect("Please set TOKEN as the Discord Bot's token to be used."); + + // Data to be put into context + let mut data = TypeMap::new(); + + // Set up announcer handler + let mut announcers = AnnouncerHandler::new(); + + // Setup each package starting from the prelude. + let env = { + let db_path = var("DBPATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|e| { + println!("No DBPATH set up ({:?}), using `/data`", e); + std::path::PathBuf::from("/data") + }); + let sql_path = var("SQLPATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|e| { + let res = db_path.join("youmubot.db"); + println!("No SQLPATH set up ({:?}), using `{:?}`", e, res); + res + }); + let prelude = youmubot_prelude::setup::setup_prelude(&db_path, sql_path, &mut data).await; + // Setup core + #[cfg(feature = "core")] + youmubot_core::setup(&db_path, &mut data).expect("Setup db should succeed"); + // osu! + #[cfg(feature = "osu")] + let osu = youmubot_osu::discord::setup(&mut data, prelude.clone(), &mut announcers) + .await + .expect("osu! is initialized"); + // codeforces + #[cfg(feature = "codeforces")] + youmubot_cf::setup(&db_path, &mut data, &mut announcers).await; + + Env { + prelude, + #[cfg(feature = "osu")] + osu, + } + }; + + #[cfg(feature = "core")] + println!("Core enabled."); + #[cfg(feature = "osu")] + println!("osu! enabled."); + #[cfg(feature = "codeforces")] + println!("codeforces enabled."); + // Set up base framework let fw = setup_framework(&token[..]).await; - let composed = ComposedFramework::new(vec![Box::new(fw)]); + // Poise for application commands + let poise_fw = poise::Framework::builder() + .setup(|_, _, _| Box::pin(async { Ok(env) as Result<_> })) + .options(poise::FrameworkOptions { + prefix_options: poise::PrefixFrameworkOptions { + prefix: None, + mention_as_prefix: true, + execute_untracked_edits: true, + execute_self_messages: false, + ignore_thread_creation: true, + case_insensitive_commands: true, + ..Default::default() + }, + on_error: |err| { + Box::pin(async move { + if let poise::FrameworkError::Command { error, ctx, .. } = err { + let reply = format!( + "Command '{}' returned error {:?}", + ctx.invoked_command_name(), + error + ); + ctx.reply(&reply).await.pls_ok(); + println!("{}", reply) + } else { + eprintln!("Poise error: {:?}", err) + } + }) + }, + commands: vec![poise_register()], + ..Default::default() + }) + .build(); + + let composed = ComposedFramework::new(vec![Box::new(fw), Box::new(poise_fw)]); // Sets up a client let mut client = { @@ -110,60 +227,20 @@ async fn main() { | GatewayIntents::DIRECT_MESSAGES | GatewayIntents::DIRECT_MESSAGE_REACTIONS; Client::builder(token, intents) + .type_map(data) .framework(composed) .event_handler(handler) .await .unwrap() }; - // Set up announcer handler - let mut announcers = AnnouncerHandler::new(&client); - - // Setup each package starting from the prelude. - { - let mut data = client.data.write().await; - let db_path = var("DBPATH") - .map(std::path::PathBuf::from) - .unwrap_or_else(|e| { - println!("No DBPATH set up ({:?}), using `/data`", e); - std::path::PathBuf::from("/data") - }); - let sql_path = var("SQLPATH") - .map(std::path::PathBuf::from) - .unwrap_or_else(|e| { - let res = db_path.join("youmubot.db"); - println!("No SQLPATH set up ({:?}), using `{:?}`", e, res); - res - }); - youmubot_prelude::setup::setup_prelude(&db_path, sql_path, &mut data).await; - // Setup core - #[cfg(feature = "core")] - youmubot_core::setup(&db_path, &client, &mut data).expect("Setup db should succeed"); - // osu! - #[cfg(feature = "osu")] - youmubot_osu::discord::setup(&db_path, &mut data, &mut announcers) - .await - .expect("osu! is initialized"); - // codeforces - #[cfg(feature = "codeforces")] - youmubot_cf::setup(&db_path, &mut data, &mut announcers).await; - } - - #[cfg(feature = "core")] - println!("Core enabled."); - #[cfg(feature = "osu")] - println!("osu! enabled."); - #[cfg(feature = "codeforces")] - println!("codeforces enabled."); - + let announcers = announcers.run(&client); tokio::spawn(announcers.scan(std::time::Duration::from_secs(300))); println!("Starting..."); if let Err(v) = client.start().await { panic!("{}", v) } - - println!("Hello, world!"); } // Sets up a framework for a client @@ -229,6 +306,18 @@ async fn setup_framework(token: &str) -> StandardFramework { fw } +// Poise command to register +#[poise::command( + prefix_command, + rename = "register", + required_permissions = "MANAGE_GUILD" +)] +async fn poise_register(ctx: CmdContext<'_, Env>) -> Result<()> { + // TODO: make this work for guild owners too + poise::builtins::register_application_commands_buttons(ctx).await?; + Ok(()) +} + // Hooks! #[hook]