Re-structure how environment is passed around to allow setting up poise

This commit is contained in:
Natsu Kagami 2024-02-25 23:12:34 +01:00
parent 0795a07a2c
commit d5fb2cce69
Signed by: nki
GPG key ID: 55A032EB38B49ADB
19 changed files with 417 additions and 124 deletions

90
Cargo.lock generated
View file

@ -440,6 +440,41 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "darling"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 2.0.48",
]
[[package]]
name = "darling_macro"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
dependencies = [
"darling_core",
"quote",
"syn 2.0.48",
]
[[package]] [[package]]
name = "dashmap" name = "dashmap"
version = "5.5.3" version = "5.5.3"
@ -481,6 +516,17 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "derivative"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.7" version = "0.10.7"
@ -1013,6 +1059,12 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]] [[package]]
name = "idna" name = "idna"
version = "0.5.0" version = "0.5.0"
@ -1543,6 +1595,35 @@ version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb"
[[package]]
name = "poise"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1819d5a45e3590ef33754abce46432570c54a120798bdbf893112b4211fa09a6"
dependencies = [
"async-trait",
"derivative",
"futures-util",
"parking_lot",
"poise_macros",
"regex",
"serenity",
"tokio",
"tracing",
]
[[package]]
name = "poise_macros"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fa2c123c961e78315cd3deac7663177f12be4460f5440dbf62a7ed37b1effea"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.48",
]
[[package]] [[package]]
name = "powerfmt" name = "powerfmt"
version = "0.2.0" version = "0.2.0"
@ -2346,6 +2427,12 @@ dependencies = [
"unicode-normalization", "unicode-normalization",
] ]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.5.0" version = "2.5.0"
@ -3106,6 +3193,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"dotenv", "dotenv",
"env_logger", "env_logger",
"poise",
"serenity", "serenity",
"tokio", "tokio",
"youmubot-cf", "youmubot-cf",
@ -3183,6 +3271,7 @@ dependencies = [
"dashmap", "dashmap",
"lazy_static", "lazy_static",
"osuparse", "osuparse",
"poise",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest",
@ -3208,6 +3297,7 @@ dependencies = [
"dashmap", "dashmap",
"flume 0.10.14", "flume 0.10.14",
"futures-util", "futures-util",
"poise",
"reqwest", "reqwest",
"serenity", "serenity",
"tokio", "tokio",

View file

@ -5,7 +5,7 @@ use serenity::model::{
channel::ReactionType, channel::ReactionType,
id::{MessageId, RoleId, UserId}, id::{MessageId, RoleId, UserId},
}; };
use std::collections::HashMap; use std::{collections::HashMap, sync::Arc};
use youmubot_db::{GuildMap, DB}; use youmubot_db::{GuildMap, DB};
use youmubot_prelude::*; use youmubot_prelude::*;
@ -55,7 +55,7 @@ pub fn load_role_list(
let v2 = Roles::load_from_path(path.as_ref()); let v2 = Roles::load_from_path(path.as_ref());
let v2 = match v2 { let v2 = match v2 {
Ok(v2) => { Ok(v2) => {
map.insert::<Roles>(v2); map.insert::<Roles>(Arc::new(v2));
return Ok(()); return Ok(());
} }
Err(v2) => v2, Err(v2) => v2,

View file

@ -19,7 +19,6 @@ pub use fun::FUN_GROUP;
/// Sets up all databases in the client. /// Sets up all databases in the client.
pub fn setup( pub fn setup(
path: &std::path::Path, path: &std::path::Path,
client: &serenity::client::Client,
data: &mut TypeMap, data: &mut TypeMap,
) -> serenity::framework::standard::CommandResult { ) -> serenity::framework::standard::CommandResult {
db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?; db::SoftBans::insert_into(&mut *data, &path.join("soft_bans.yaml"))?;
@ -29,18 +28,21 @@ pub fn setup(
&path.join("roles.yaml"), &path.join("roles.yaml"),
)?; )?;
// Create handler threads
tokio::spawn(admin::watch_soft_bans(
CacheAndHttp::from_client(client),
client.data.clone(),
));
// Start reaction handlers // Start reaction handlers
data.insert::<community::ReactionWatchers>(community::ReactionWatchers::new(&*data)?); data.insert::<community::ReactionWatchers>(community::ReactionWatchers::new(&*data)?);
Ok(()) Ok(())
} }
pub fn ready_hook(ctx: &Context) -> CommandResult {
// Create handler threads
tokio::spawn(admin::watch_soft_bans(
CacheAndHttp::from_context(ctx),
ctx.data.clone(),
));
Ok(())
}
// A help command // A help command
#[help] #[help]
pub async fn help( pub async fn help(

View file

@ -4,19 +4,20 @@ use serenity::{
model::id::GuildId, model::id::GuildId,
prelude::{TypeMap, TypeMapKey}, prelude::{TypeMap, TypeMapKey},
}; };
use std::{collections::HashMap, path::Path}; use std::{collections::HashMap, path::Path, sync::Arc};
/// GuildMap defines the guild-map type. /// GuildMap defines the guild-map type.
/// It is basically a HashMap from a GuildId to a data structure. /// It is basically a HashMap from a GuildId to a data structure.
pub type GuildMap<V> = HashMap<GuildId, V>; pub type GuildMap<V> = HashMap<GuildId, V>;
/// The generic DB type we will be using. /// The generic DB type we will be using.
pub struct DB<T>(std::marker::PhantomData<T>); #[derive(Debug, Clone)]
pub struct DB<T>(pub Arc<Database<T>>);
/// A short type abbreviation for a FileDatabase. /// A short type abbreviation for a FileDatabase.
type Database<T> = FileDatabase<T, Yaml>; type Database<T> = FileDatabase<T, Yaml>;
impl<T: std::any::Any + Send + Sync> TypeMapKey for DB<T> { impl<T: std::any::Any + Send + Sync> TypeMapKey for DB<T> {
type Value = Database<T>; type Value = Arc<Database<T>>;
} }
impl<T: std::any::Any + Default + Send + Sync + Clone + Serialize + std::fmt::Debug> DB<T> impl<T: std::any::Any + Default + Send + Sync + Clone + Serialize + std::fmt::Debug> DB<T>
@ -29,15 +30,15 @@ where
} }
/// Insert into a ShareMap. /// Insert into a ShareMap.
pub fn insert_into(data: &mut TypeMap, path: impl AsRef<Path>) -> Result<(), DBError> { pub fn insert_into(data: &mut TypeMap, path: impl AsRef<Path>) -> Result<Self, DBError> {
let db = Database::<T>::load_from_path_or_default(path)?; let db = Arc::new(Database::<T>::load_from_path_or_default(path)?);
data.insert::<DB<T>>(db); data.insert::<DB<T>>(db.clone());
Ok(()) Ok(Self(db))
} }
/// Open a previously inserted DB. /// Open a previously inserted DB.
pub fn open(data: &TypeMap) -> DBWriteGuard<T> { pub fn open(data: &TypeMap) -> DBWriteGuard<T> {
data.get::<Self>().expect("DB initialized").into() data.get::<Self>().expect("DB initialized").as_ref().into()
} }
} }

View file

@ -20,6 +20,7 @@ rosu-v2 = { git = "https://github.com/natsukagami/rosu-v2", rev = "6f6731cb2f0d2
time = "0.3" time = "0.3"
serde = { version = "1.0.137", features = ["derive"] } serde = { version = "1.0.137", features = ["derive"] }
serenity = "0.12" serenity = "0.12"
poise = "0.6"
zip = "0.6.2" zip = "0.6.2"
rand = "0.8" rand = "0.8"

View file

@ -0,0 +1,9 @@
use youmubot_prelude::*;
#[poise::command(slash_command)]
pub async fn example<T: AsRef<crate::Env> + Sync>(
context: poise::Context<'_, T, Error>,
arg: String,
) -> Result<(), Error> {
todo!()
}

View file

@ -8,11 +8,18 @@ use youmubot_prelude::*;
/// BeatmapMetaCache intercepts beatmap-by-id requests and caches them for later recalling. /// BeatmapMetaCache intercepts beatmap-by-id requests and caches them for later recalling.
/// Does not cache non-Ranked beatmaps. /// Does not cache non-Ranked beatmaps.
#[derive(Clone)]
pub struct BeatmapMetaCache { pub struct BeatmapMetaCache {
client: Arc<Client>, client: Arc<Client>,
pool: Pool, pool: Pool,
} }
impl std::fmt::Debug for BeatmapMetaCache {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<BeatmapMetaCache>")
}
}
impl TypeMapKey for BeatmapMetaCache { impl TypeMapKey for BeatmapMetaCache {
type Value = BeatmapMetaCache; type Value = BeatmapMetaCache;
} }

View file

@ -9,6 +9,7 @@ use serenity::model::id::{ChannelId, UserId};
use youmubot_prelude::*; use youmubot_prelude::*;
/// Save the user IDs. /// Save the user IDs.
#[derive(Debug, Clone)]
pub struct OsuSavedUsers { pub struct OsuSavedUsers {
pool: Pool, pool: Pool,
} }
@ -60,6 +61,7 @@ impl OsuSavedUsers {
} }
/// Save each channel's last requested beatmap. /// Save each channel's last requested beatmap.
#[derive(Debug, Clone)]
pub struct OsuLastBeatmap(Pool); pub struct OsuLastBeatmap(Pool);
impl TypeMapKey for OsuLastBeatmap { impl TypeMapKey for OsuLastBeatmap {
@ -99,6 +101,7 @@ impl OsuLastBeatmap {
} }
/// Save each channel's last requested beatmap. /// Save each channel's last requested beatmap.
#[derive(Debug, Clone)]
pub struct OsuUserBests(Pool); pub struct OsuUserBests(Pool);
impl TypeMapKey for OsuUserBests { impl TypeMapKey for OsuUserBests {

View file

@ -21,6 +21,7 @@ use std::{str::FromStr, sync::Arc};
use youmubot_prelude::*; use youmubot_prelude::*;
mod announcer; mod announcer;
pub mod app_commands;
pub(crate) mod beatmap_cache; pub(crate) mod beatmap_cache;
mod cache; mod cache;
mod db; mod db;
@ -43,6 +44,33 @@ impl TypeMapKey for OsuClient {
type Value = Arc<OsuHttpClient>; type Value = Arc<OsuHttpClient>;
} }
/// The environment for osu! app commands.
#[derive(Clone)]
pub struct Env {
pub(crate) prelude: youmubot_prelude::Env,
// databases
pub(crate) saved_users: OsuSavedUsers,
pub(crate) last_beatmaps: OsuLastBeatmap,
pub(crate) user_bests: OsuUserBests,
// clients
pub(crate) client: Arc<OsuHttpClient>,
pub(crate) oppai: oppai_cache::BeatmapCache,
pub(crate) beatmaps: beatmap_cache::BeatmapMetaCache,
}
impl std::fmt::Debug for Env {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<osu::Env>")
}
}
impl TypeMapKey for Env {
type Value = Env;
}
/// The command context for osu! app commands.
pub(crate) type CmdContext<'a> = youmubot_prelude::CmdContext<'a, Env>;
/// Sets up the osu! command handling section. /// Sets up the osu! command handling section.
/// ///
/// This automatically enables: /// This automatically enables:
@ -55,50 +83,60 @@ impl TypeMapKey for OsuClient {
/// - Hooks. Hooks are completely opt-in. /// - Hooks. Hooks are completely opt-in.
/// ///
pub async fn setup( pub async fn setup(
_path: &std::path::Path,
data: &mut TypeMap, data: &mut TypeMap,
// dependencies
prelude: youmubot_prelude::Env,
announcers: &mut AnnouncerHandler, announcers: &mut AnnouncerHandler,
) -> CommandResult { ) -> Result<Env> {
let sql_client = data.get::<SQLClient>().unwrap().clone();
// Databases // Databases
data.insert::<OsuSavedUsers>(OsuSavedUsers::new(sql_client.clone())); let saved_users = OsuSavedUsers::new(prelude.sql.clone());
data.insert::<OsuLastBeatmap>(OsuLastBeatmap::new(sql_client.clone())); let last_beatmaps = OsuLastBeatmap::new(prelude.sql.clone());
data.insert::<OsuUserBests>(OsuUserBests::new(sql_client.clone())); let user_bests = OsuUserBests::new(prelude.sql.clone());
// API client // API client
let http_client = data.get::<HTTPClient>().unwrap().clone(); let osu_client = Arc::new(
let mk_osu_client = || async { OsuHttpClient::new(
Arc::new( std::env::var("OSU_API_CLIENT_ID")
OsuHttpClient::new( .expect("Please set OSU_API_CLIENT_ID as osu! api v2 client ID.")
std::env::var("OSU_API_CLIENT_ID") .parse()
.expect("Please set OSU_API_CLIENT_ID as osu! api v2 client ID.") .expect("client_id should be u64"),
.parse() std::env::var("OSU_API_CLIENT_SECRET")
.expect("client_id should be u64"), .expect("Please set OSU_API_CLIENT_SECRET as osu! api v2 client secret."),
std::env::var("OSU_API_CLIENT_SECRET")
.expect("Please set OSU_API_CLIENT_SECRET as osu! api v2 client secret."),
)
.await
.expect("osu! should be initialized"),
) )
}; .await
let osu_client = mk_osu_client().await; .expect("osu! should be initialized"),
data.insert::<OsuClient>(osu_client.clone()); );
data.insert::<oppai_cache::BeatmapCache>(oppai_cache::BeatmapCache::new( let oppai_cache = oppai_cache::BeatmapCache::new(prelude.http.clone(), prelude.sql.clone());
http_client.clone(), let beatmap_cache =
sql_client.clone(), beatmap_cache::BeatmapMetaCache::new(osu_client.clone(), prelude.sql.clone());
));
data.insert::<beatmap_cache::BeatmapMetaCache>(beatmap_cache::BeatmapMetaCache::new(
osu_client.clone(),
sql_client,
));
// Announcer // Announcer
let osu_client = mk_osu_client().await;
announcers.add( announcers.add(
announcer::ANNOUNCER_KEY, announcer::ANNOUNCER_KEY,
announcer::Announcer::new(osu_client), announcer::Announcer::new(osu_client.clone()),
); );
Ok(())
// Legacy data
data.insert::<OsuLastBeatmap>(last_beatmaps.clone());
data.insert::<OsuSavedUsers>(saved_users.clone());
data.insert::<OsuUserBests>(user_bests.clone());
data.insert::<OsuClient>(osu_client.clone());
data.insert::<oppai_cache::BeatmapCache>(oppai_cache.clone());
data.insert::<beatmap_cache::BeatmapMetaCache>(beatmap_cache.clone());
let env = Env {
prelude,
saved_users,
last_beatmaps,
user_bests,
client: osu_client,
oppai: oppai_cache,
beatmaps: beatmap_cache,
};
data.insert::<Env>(env.clone());
Ok(env)
} }
#[group] #[group]

View file

@ -304,6 +304,7 @@ impl BeatmapContent {
} }
/// A central cache for the beatmaps. /// A central cache for the beatmaps.
#[derive(Debug, Clone)]
pub struct BeatmapCache { pub struct BeatmapCache {
client: ratelimit::Ratelimit<reqwest::Client>, client: ratelimit::Ratelimit<reqwest::Client>,
pool: Pool, pool: Pool,

View file

@ -6,11 +6,13 @@ use models::*;
use request::builders::*; use request::builders::*;
use request::*; use request::*;
use std::convert::TryInto; use std::convert::TryInto;
use std::sync::Arc;
use youmubot_prelude::*; use youmubot_prelude::*;
/// Client is the client that will perform calls to the osu! api server. /// Client is the client that will perform calls to the osu! api server.
#[derive(Clone)]
pub struct Client { pub struct Client {
rosu: rosu_v2::Osu, rosu: Arc<rosu_v2::Osu>,
} }
pub fn vec_try_into<U, T: std::convert::TryFrom<U>>(v: Vec<U>) -> Result<Vec<T>, T::Error> { pub fn vec_try_into<U, T: std::convert::TryFrom<U>>(v: Vec<U>) -> Result<Vec<T>, T::Error> {
@ -31,7 +33,9 @@ impl Client {
.client_secret(client_secret) .client_secret(client_secret)
.build() .build()
.await?; .await?;
Ok(Client { rosu }) Ok(Client {
rosu: Arc::new(rosu),
})
} }
pub async fn beatmaps( pub async fn beatmaps(

View file

@ -17,6 +17,7 @@ reqwest = { version = "0.11.10", features = ["json"] }
chrono = "0.4.19" chrono = "0.4.19"
flume = "0.10.13" flume = "0.10.13"
dashmap = "5.3.4" dashmap = "5.3.4"
poise = "0.6"
[dependencies.serenity] [dependencies.serenity]
version = "0.12" version = "0.12"

View file

@ -18,7 +18,7 @@ use serenity::{
prelude::*, prelude::*,
utils::MessageBuilder, utils::MessageBuilder,
}; };
use std::{collections::HashMap, sync::Arc}; use std::{arch::x86_64::_bittestandcomplement, collections::HashMap, sync::Arc};
use youmubot_db::DB; use youmubot_db::DB;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -28,6 +28,10 @@ impl CacheAndHttp {
pub fn from_client(client: &Client) -> Self { pub fn from_client(client: &Client) -> Self {
Self(client.cache.clone(), client.http.clone()) Self(client.cache.clone(), client.http.clone())
} }
pub fn from_context(context: &Context) -> Self {
Self(context.cache.clone(), context.http.clone())
}
} }
impl CacheHttp for CacheAndHttp { impl CacheHttp for CacheAndHttp {
@ -94,23 +98,14 @@ impl MemberToChannels {
/// ///
/// This struct manages the list of all Announcers, firing them in a certain interval. /// This struct manages the list of all Announcers, firing them in a certain interval.
pub struct AnnouncerHandler { pub struct AnnouncerHandler {
cache_http: CacheAndHttp,
data: AppData,
announcers: HashMap<&'static str, RwLock<Box<dyn Announcer + Send + Sync>>>, announcers: HashMap<&'static str, RwLock<Box<dyn Announcer + Send + Sync>>>,
} }
// Querying for the AnnouncerHandler in the internal data returns a vec of keys.
impl TypeMapKey for AnnouncerHandler {
type Value = Vec<&'static str>;
}
/// Announcer-managing related. /// Announcer-managing related.
impl AnnouncerHandler { impl AnnouncerHandler {
/// Create a new instance of the handler. /// Create a new instance of the handler.
pub fn new(client: &serenity::Client) -> Self { pub fn new() -> Self {
Self { Self {
cache_http: CacheAndHttp(client.cache.clone(), client.http.clone()),
data: client.data.clone(),
announcers: HashMap::new(), announcers: HashMap::new(),
} }
} }
@ -136,10 +131,30 @@ impl AnnouncerHandler {
self self
} }
} }
pub fn run(self, client: &Client) -> AnnouncerRunner {
let runner = AnnouncerRunner {
cache_http: CacheAndHttp::from_client(client),
data: client.data.clone(),
announcers: self.announcers,
};
runner
}
}
pub struct AnnouncerRunner {
cache_http: CacheAndHttp,
data: AppData,
announcers: HashMap<&'static str, RwLock<Box<dyn Announcer + Send + Sync>>>,
}
// Querying for the AnnouncerRunner in the internal data returns a vec of keys.
impl TypeMapKey for AnnouncerRunner {
type Value = Vec<&'static str>;
} }
/// Execution-related. /// Execution-related.
impl AnnouncerHandler { impl AnnouncerRunner {
/// Collect the list of guilds and their respective channels, by the key of the announcer. /// Collect the list of guilds and their respective channels, by the key of the announcer.
async fn get_guilds(data: &AppData, key: &'static str) -> Result<Vec<(GuildId, ChannelId)>> { async fn get_guilds(data: &AppData, key: &'static str) -> Result<Vec<(GuildId, ChannelId)>> {
let data = AnnouncerChannels::open(&*data.read().await) let data = AnnouncerChannels::open(&*data.read().await)
@ -214,7 +229,7 @@ pub async fn list_announcers(ctx: &Context, m: &Message, _: Args) -> CommandResu
let guild_id = m.guild_id.unwrap(); let guild_id = m.guild_id.unwrap();
let data = &*ctx.data.read().await; let data = &*ctx.data.read().await;
let announcers = AnnouncerChannels::open(data); let announcers = AnnouncerChannels::open(data);
let channels = data.get::<AnnouncerHandler>().unwrap(); let channels = data.get::<AnnouncerRunner>().unwrap();
let channels = channels let channels = channels
.iter() .iter()
.filter_map(|&key| { .filter_map(|&key| {
@ -249,7 +264,7 @@ pub async fn list_announcers(ctx: &Context, m: &Message, _: Args) -> CommandResu
pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await; let data = ctx.data.read().await;
let key = args.single::<String>()?; let key = args.single::<String>()?;
let keys = data.get::<AnnouncerHandler>().unwrap(); let keys = data.get::<AnnouncerRunner>().unwrap();
if !keys.contains(&&key[..]) { if !keys.contains(&&key[..]) {
m.reply( m.reply(
&ctx, &ctx,
@ -296,7 +311,7 @@ pub async fn register_announcer(ctx: &Context, m: &Message, mut args: Args) -> C
pub async fn remove_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult { pub async fn remove_announcer(ctx: &Context, m: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await; let data = ctx.data.read().await;
let key = args.single::<String>()?; let key = args.single::<String>()?;
let keys = data.get::<AnnouncerHandler>().unwrap(); let keys = data.get::<AnnouncerRunner>().unwrap();
if !keys.contains(&key.as_str()) { if !keys.contains(&key.as_str()) {
m.reply( m.reply(
&ctx, &ctx,

View file

@ -1,3 +1,4 @@
use announcer::AnnouncerChannels;
/// Module `prelude` provides a sane set of default imports that can be used inside /// Module `prelude` provides a sane set of default imports that can be used inside
/// a Youmubot source file. /// a Youmubot source file.
pub use serenity::prelude::*; pub use serenity::prelude::*;
@ -38,6 +39,20 @@ pub type AppData = Arc<RwLock<TypeMap>>;
/// The HTTP client. /// The HTTP client.
pub struct HTTPClient; pub struct HTTPClient;
/// The global context type for app commands
pub type CmdContext<'a, Env> = poise::Context<'a, Env, anyhow::Error>;
/// The created base environment.
#[derive(Debug, Clone)]
pub struct Env {
// clients
pub http: reqwest::Client,
pub sql: youmubot_db_sql::Pool,
pub members: Arc<MemberCache>,
// databases
// pub(crate) announcer_channels: announcer::AnnouncerChannels,
}
impl TypeMapKey for HTTPClient { impl TypeMapKey for HTTPClient {
type Value = reqwest::Client; type Value = reqwest::Client;
} }

View file

@ -7,6 +7,7 @@ use flume::{bounded as channel, Receiver, Sender};
use std::ops::Deref; use std::ops::Deref;
/// Holds the underlying `T` in a rate-limited way. /// Holds the underlying `T` in a rate-limited way.
#[derive(Debug, Clone)]
pub struct Ratelimit<T> { pub struct Ratelimit<T> {
inner: T, inner: T,
recv: Receiver<()>, recv: Receiver<()>,

View file

@ -1,6 +1,8 @@
use serenity::prelude::*; use serenity::prelude::*;
use std::{path::Path, time::Duration}; use std::{path::Path, time::Duration};
use crate::Env;
/// Set up the prelude libraries. /// Set up the prelude libraries.
/// ///
/// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails. /// Panics on failure: Youmubot should *NOT* attempt to continue when this function fails.
@ -8,7 +10,7 @@ pub async fn setup_prelude(
db_path: impl AsRef<Path>, db_path: impl AsRef<Path>,
sql_path: impl AsRef<Path>, sql_path: impl AsRef<Path>,
data: &mut TypeMap, data: &mut TypeMap,
) { ) -> Env {
// Setup the announcer DB. // Setup the announcer DB.
crate::announcer::AnnouncerChannels::insert_into( crate::announcer::AnnouncerChannels::insert_into(
data, data,
@ -22,17 +24,25 @@ pub async fn setup_prelude(
.expect("SQL database set up"); .expect("SQL database set up");
// Set up the HTTP client. // Set up the HTTP client.
data.insert::<crate::HTTPClient>( let http_client = reqwest::ClientBuilder::new()
reqwest::ClientBuilder::new() .connect_timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(5)) .timeout(Duration::from_secs(60))
.timeout(Duration::from_secs(60)) .build()
.build() .expect("Build be able to build HTTP client");
.expect("Build be able to build HTTP client"), data.insert::<crate::HTTPClient>(http_client.clone());
);
// Set up the member cache. // Set up the member cache.
data.insert::<crate::MemberCache>(std::sync::Arc::new(crate::MemberCache::default())); let member_cache = std::sync::Arc::new(crate::MemberCache::default());
data.insert::<crate::MemberCache>(member_cache.clone());
// Set up the SQL client. // Set up the SQL client.
data.insert::<crate::SQLClient>(sql_pool); data.insert::<crate::SQLClient>(sql_pool.clone());
let env = Env {
http: http_client,
sql: sql_pool,
members: member_cache,
};
env
} }

View file

@ -13,6 +13,7 @@ codeforces = ["youmubot-cf"]
[dependencies] [dependencies]
serenity = "0.12" serenity = "0.12"
poise = "0.6"
tokio = { version = "1.19.2", features = ["rt-multi-thread"] } tokio = { version = "1.19.2", features = ["rt-multi-thread"] }
dotenv = "0.15.0" dotenv = "0.15.0"
env_logger = "0.9.0" env_logger = "0.9.0"

View file

@ -25,6 +25,11 @@ impl Framework for ComposedFramework {
.await .await
} }
} }
async fn init(&mut self, client: &Client) {
for f in self.frameworks.iter_mut() {
f.init(&client).await
}
}
} }
impl ComposedFramework { impl ComposedFramework {
/// Dispatch to all inner frameworks in a loop. Returns a `Pin<Box<Future>>` because rust. /// Dispatch to all inner frameworks in a loop. Returns a `Pin<Box<Future>>` because rust.

View file

@ -16,16 +16,44 @@ mod compose_framework;
struct Handler { struct Handler {
hooks: Vec<RwLock<Box<dyn Hook>>>, hooks: Vec<RwLock<Box<dyn Hook>>>,
ready_hooks: Vec<fn(&Context) -> CommandResult>,
} }
impl Handler { impl Handler {
fn new() -> Handler { fn new() -> Handler {
Handler { hooks: vec![] } Handler {
hooks: vec![],
ready_hooks: vec![],
}
} }
fn push_hook<T: Hook + 'static>(&mut self, f: T) { fn push_hook<T: Hook + 'static>(&mut self, f: T) {
self.hooks.push(RwLock::new(Box::new(f))); self.hooks.push(RwLock::new(Box::new(f)));
} }
fn push_ready_hook(&mut self, f: fn(&Context) -> CommandResult) {
self.ready_hooks.push(f);
}
}
/// Environment to be passed into the framework
#[derive(Debug, Clone)]
struct Env {
prelude: youmubot_prelude::Env,
#[cfg(feature = "osu")]
osu: youmubot_osu::discord::Env,
}
impl AsRef<youmubot_prelude::Env> for Env {
fn as_ref(&self) -> &youmubot_prelude::Env {
&self.prelude
}
}
impl AsRef<youmubot_osu::discord::Env> for Env {
fn as_ref(&self) -> &youmubot_osu::discord::Env {
&self.osu
}
} }
#[async_trait] #[async_trait]
@ -41,6 +69,10 @@ impl EventHandler for Handler {
.init(&ctx) .init(&ctx)
.await; .await;
println!("{} is connected!", ready.user.name); println!("{} is connected!", ready.user.name);
for f in &self.ready_hooks {
f(&ctx).pls_ok();
}
} }
async fn message(&self, ctx: Context, message: Message) { async fn message(&self, ctx: Context, message: Message) {
@ -82,20 +114,105 @@ async fn main() {
} }
let mut handler = Handler::new(); let mut handler = Handler::new();
#[cfg(feature = "core")]
handler.push_ready_hook(youmubot_core::ready_hook);
// Set up hooks // Set up hooks
#[cfg(feature = "osu")] #[cfg(feature = "osu")]
handler.push_hook(youmubot_osu::discord::hook); {
#[cfg(feature = "osu")] handler.push_hook(youmubot_osu::discord::hook);
handler.push_hook(youmubot_osu::discord::dot_osu_hook); handler.push_hook(youmubot_osu::discord::dot_osu_hook);
}
#[cfg(feature = "codeforces")] #[cfg(feature = "codeforces")]
handler.push_hook(youmubot_cf::InfoHook); handler.push_hook(youmubot_cf::InfoHook);
// Collect the token // Collect the token
let token = var("TOKEN").expect("Please set TOKEN as the Discord Bot's token to be used."); let token = var("TOKEN").expect("Please set TOKEN as the Discord Bot's token to be used.");
// Data to be put into context
let mut data = TypeMap::new();
// Set up announcer handler
let mut announcers = AnnouncerHandler::new();
// Setup each package starting from the prelude.
let env = {
let db_path = var("DBPATH")
.map(std::path::PathBuf::from)
.unwrap_or_else(|e| {
println!("No DBPATH set up ({:?}), using `/data`", e);
std::path::PathBuf::from("/data")
});
let sql_path = var("SQLPATH")
.map(std::path::PathBuf::from)
.unwrap_or_else(|e| {
let res = db_path.join("youmubot.db");
println!("No SQLPATH set up ({:?}), using `{:?}`", e, res);
res
});
let prelude = youmubot_prelude::setup::setup_prelude(&db_path, sql_path, &mut data).await;
// Setup core
#[cfg(feature = "core")]
youmubot_core::setup(&db_path, &mut data).expect("Setup db should succeed");
// osu!
#[cfg(feature = "osu")]
let osu = youmubot_osu::discord::setup(&mut data, prelude.clone(), &mut announcers)
.await
.expect("osu! is initialized");
// codeforces
#[cfg(feature = "codeforces")]
youmubot_cf::setup(&db_path, &mut data, &mut announcers).await;
Env {
prelude,
#[cfg(feature = "osu")]
osu,
}
};
#[cfg(feature = "core")]
println!("Core enabled.");
#[cfg(feature = "osu")]
println!("osu! enabled.");
#[cfg(feature = "codeforces")]
println!("codeforces enabled.");
// Set up base framework // Set up base framework
let fw = setup_framework(&token[..]).await; let fw = setup_framework(&token[..]).await;
let composed = ComposedFramework::new(vec![Box::new(fw)]); // Poise for application commands
let poise_fw = poise::Framework::builder()
.setup(|_, _, _| Box::pin(async { Ok(env) as Result<_> }))
.options(poise::FrameworkOptions {
prefix_options: poise::PrefixFrameworkOptions {
prefix: None,
mention_as_prefix: true,
execute_untracked_edits: true,
execute_self_messages: false,
ignore_thread_creation: true,
case_insensitive_commands: true,
..Default::default()
},
on_error: |err| {
Box::pin(async move {
if let poise::FrameworkError::Command { error, ctx, .. } = err {
let reply = format!(
"Command '{}' returned error {:?}",
ctx.invoked_command_name(),
error
);
ctx.reply(&reply).await.pls_ok();
println!("{}", reply)
} else {
eprintln!("Poise error: {:?}", err)
}
})
},
commands: vec![poise_register()],
..Default::default()
})
.build();
let composed = ComposedFramework::new(vec![Box::new(fw), Box::new(poise_fw)]);
// Sets up a client // Sets up a client
let mut client = { let mut client = {
@ -110,60 +227,20 @@ async fn main() {
| GatewayIntents::DIRECT_MESSAGES | GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::DIRECT_MESSAGE_REACTIONS; | GatewayIntents::DIRECT_MESSAGE_REACTIONS;
Client::builder(token, intents) Client::builder(token, intents)
.type_map(data)
.framework(composed) .framework(composed)
.event_handler(handler) .event_handler(handler)
.await .await
.unwrap() .unwrap()
}; };
// Set up announcer handler let announcers = announcers.run(&client);
let mut announcers = AnnouncerHandler::new(&client);
// Setup each package starting from the prelude.
{
let mut data = client.data.write().await;
let db_path = var("DBPATH")
.map(std::path::PathBuf::from)
.unwrap_or_else(|e| {
println!("No DBPATH set up ({:?}), using `/data`", e);
std::path::PathBuf::from("/data")
});
let sql_path = var("SQLPATH")
.map(std::path::PathBuf::from)
.unwrap_or_else(|e| {
let res = db_path.join("youmubot.db");
println!("No SQLPATH set up ({:?}), using `{:?}`", e, res);
res
});
youmubot_prelude::setup::setup_prelude(&db_path, sql_path, &mut data).await;
// Setup core
#[cfg(feature = "core")]
youmubot_core::setup(&db_path, &client, &mut data).expect("Setup db should succeed");
// osu!
#[cfg(feature = "osu")]
youmubot_osu::discord::setup(&db_path, &mut data, &mut announcers)
.await
.expect("osu! is initialized");
// codeforces
#[cfg(feature = "codeforces")]
youmubot_cf::setup(&db_path, &mut data, &mut announcers).await;
}
#[cfg(feature = "core")]
println!("Core enabled.");
#[cfg(feature = "osu")]
println!("osu! enabled.");
#[cfg(feature = "codeforces")]
println!("codeforces enabled.");
tokio::spawn(announcers.scan(std::time::Duration::from_secs(300))); tokio::spawn(announcers.scan(std::time::Duration::from_secs(300)));
println!("Starting..."); println!("Starting...");
if let Err(v) = client.start().await { if let Err(v) = client.start().await {
panic!("{}", v) panic!("{}", v)
} }
println!("Hello, world!");
} }
// Sets up a framework for a client // Sets up a framework for a client
@ -229,6 +306,18 @@ async fn setup_framework(token: &str) -> StandardFramework {
fw fw
} }
// Poise command to register
#[poise::command(
prefix_command,
rename = "register",
required_permissions = "MANAGE_GUILD"
)]
async fn poise_register(ctx: CmdContext<'_, Env>) -> Result<()> {
// TODO: make this work for guild owners too
poise::builtins::register_application_commands_buttons(ctx).await?;
Ok(())
}
// Hooks! // Hooks!
#[hook] #[hook]