diff --git a/Cargo.lock b/Cargo.lock index a0b61f6..5295786 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1684,13 +1684,25 @@ dependencies = [ "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.3 (registry+https://github.com/rust-lang/crates.io-index)", "reqwest 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)", - "rustbreak 2.0.0-rc3 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", "serenity 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "youmubot-db 0.1.0", "youmubot-osu 0.1.0", ] +[[package]] +name = "youmubot-db" +version = "0.1.0" +dependencies = [ + "chrono 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)", + "dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "rustbreak 2.0.0-rc3 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", + "serenity 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "youmubot-osu" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 10fc90c..1be4f54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ + "youmubot-db", "youmubot-osu", "youmubot", ] diff --git a/youmubot-db/Cargo.toml b/youmubot-db/Cargo.toml new file mode 100644 index 0000000..3008027 --- /dev/null +++ b/youmubot-db/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "youmubot-db" +version = "0.1.0" +authors = ["Natsu Kagami "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serenity = "0.8" +dotenv = "0.15" +serde = { version = "1.0", features = ["derive"] } +chrono = "0.4.9" +# rand = "0.7.2" +# static_assertions = "1.1.0" +# reqwest = "0.10.1" +# regex = "1" +# lazy_static = "1" +# youmubot-osu = { path = "../youmubot-osu" } +rayon = "1.1" + +[dependencies.rustbreak] +version = "2.0.0-rc3" +features = ["yaml_enc"] diff --git a/youmubot-db/src/lib.rs b/youmubot-db/src/lib.rs new file mode 100644 index 0000000..aa82582 --- /dev/null +++ b/youmubot-db/src/lib.rs @@ -0,0 +1,73 @@ +use rustbreak::{deser::Yaml as Ron, FileDatabase}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serenity::{framework::standard::CommandError as Error, model::id::GuildId, prelude::*}; +use std::collections::HashMap; +use std::path::Path; + +/// GuildMap defines the guild-map type. +/// It is basically a HashMap from a GuildId to a data structure. +pub type GuildMap = HashMap; +/// The generic DB type we will be using. +pub struct DB(std::marker::PhantomData); +impl serenity::prelude::TypeMapKey for DB { + type Value = FileDatabase; +} + +impl DB +where + for<'de> T: Deserialize<'de>, +{ + /// Insert into a ShareMap. + pub fn insert_into(data: &mut ShareMap, path: impl AsRef) -> Result<(), Error> { + let db = FileDatabase::::from_path(path, T::default())?; + db.load().or_else(|e| { + dbg!(e); + db.save() + })?; + data.insert::>(db); + Ok(()) + } + + /// Open a previously inserted DB. + pub fn open(data: &ShareMap) -> DBWriteGuard<'_, T> { + data.get::().expect("DB initialized").into() + } +} + +/// The write guard for our FileDatabase. +/// It wraps the FileDatabase in a write-on-drop lock. +pub struct DBWriteGuard<'a, T>(&'a FileDatabase) +where + T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned; + +impl<'a, T> From<&'a FileDatabase> for DBWriteGuard<'a, T> +where + T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, +{ + fn from(v: &'a FileDatabase) -> Self { + DBWriteGuard(v) + } +} + +impl<'a, T> DBWriteGuard<'a, T> +where + T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, +{ + /// Borrows the FileDatabase. + pub fn borrow(&self) -> Result, rustbreak::RustbreakError> { + (*self).0.borrow_data() + } + /// Borrows the FileDatabase for writing. + pub fn borrow_mut(&self) -> Result, rustbreak::RustbreakError> { + (*self).0.borrow_data_mut() + } +} + +impl<'a, T> Drop for DBWriteGuard<'a, T> +where + T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, +{ + fn drop(&mut self) { + self.0.save().expect("Save succeed") + } +} diff --git a/youmubot/Cargo.toml b/youmubot/Cargo.toml index 5df5fbb..3114212 100644 --- a/youmubot/Cargo.toml +++ b/youmubot/Cargo.toml @@ -18,7 +18,5 @@ regex = "1" lazy_static = "1" youmubot-osu = { path = "../youmubot-osu" } rayon = "1.1" +youmubot-db = { path = "../youmubot-db" } -[dependencies.rustbreak] -version = "2.0.0-rc3" -features = ["yaml_enc"] diff --git a/youmubot/src/commands/admin/soft_ban.rs b/youmubot/src/commands/admin/soft_ban.rs index 0b3aeff..f142773 100644 --- a/youmubot/src/commands/admin/soft_ban.rs +++ b/youmubot/src/commands/admin/soft_ban.rs @@ -1,6 +1,6 @@ use crate::{ commands::args, - db::{DBWriteGuard, ServerSoftBans, SoftBans}, + db::{ServerSoftBans, SoftBans}, }; use chrono::offset::Utc; use serenity::prelude::*; @@ -33,13 +33,10 @@ pub fn soft_ban(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResu }; let guild = msg.guild_id.ok_or(Error::from("Command is guild only"))?; - let data = ctx.data.read(); - let data = data - .get::() - .ok_or(Error::from("DB initialized")) - .map(|v| DBWriteGuard::from(v))?; - let mut data = data.borrow_mut()?; - let mut server_ban = data.get_mut(&guild).and_then(|v| match v { + let db = ctx.data.read(); + let db = SoftBans::open(&*db); + let mut db = db.borrow_mut()?; + let mut server_ban = db.get_mut(&guild).and_then(|v| match v { ServerSoftBans::Unimplemented => None, ServerSoftBans::Implemented(ref mut v) => Some(v), }); @@ -98,11 +95,8 @@ pub fn soft_ban_init(ctx: &mut Context, msg: &Message, mut args: Args) -> Comman ))); } // Check if we already set up - let data = ctx.data.read(); - let db: DBWriteGuard<_> = data - .get::() - .ok_or(Error::from("DB uninitialized"))? - .into(); + let db = ctx.data.read(); + let db = SoftBans::open(&*db); let mut db = db.borrow_mut()?; let server = db .get(&guild.id) @@ -135,12 +129,9 @@ pub fn watch_soft_bans(client: &mut serenity::Client) -> impl FnOnce() -> () + ' // Scope so that locks are released { // Poll the data for any changes. - let data = data.read(); - let db: DBWriteGuard<_> = data - .get::() - .expect("DB wrongly initialized") - .into(); - let mut db = db.borrow_mut().expect("cannot unpack DB"); + let db = data.read(); + let db = SoftBans::open(&*db); + let mut db = db.borrow_mut().expect("Borrowable"); let now = Utc::now(); for (server_id, soft_bans) in db.iter_mut() { let server_name: String = match server_id.to_partial_guild(cache_http) { diff --git a/youmubot/src/commands/announcer.rs b/youmubot/src/commands/announcer.rs index 7a767e5..d36d940 100644 --- a/youmubot/src/commands/announcer.rs +++ b/youmubot/src/commands/announcer.rs @@ -1,9 +1,9 @@ -use crate::db::{AnnouncerChannels, DBWriteGuard}; +use crate::db::AnnouncerChannels; +use crate::prelude::*; use serenity::{ framework::standard::{CommandError as Error, CommandResult}, http::{CacheHttp, Http}, model::id::{ChannelId, GuildId, UserId}, - prelude::ShareMap, }; use std::{ collections::HashSet, @@ -14,33 +14,30 @@ pub trait Announcer { fn announcer_key() -> &'static str; fn send_messages( c: &Http, - d: &ShareMap, + d: AppData, channels: impl Fn(UserId) -> Vec + Sync, ) -> CommandResult; - fn set_channel(d: &ShareMap, guild: GuildId, channel: ChannelId) -> CommandResult { - let data: DBWriteGuard<_> = d.get::().expect("DB initialized").into(); - let mut data = data.borrow_mut()?; - data.entry(Self::announcer_key().to_owned()) + fn set_channel(d: AppData, guild: GuildId, channel: ChannelId) -> CommandResult { + AnnouncerChannels::open(&*d.read()) + .borrow_mut()? + .entry(Self::announcer_key().to_owned()) .or_default() .insert(guild, channel); Ok(()) } - fn get_guilds(d: &ShareMap) -> Result, Error> { - let data = d - .get::() - .expect("DB initialized") - .read(|v| { - v.get(Self::announcer_key()) - .map(|m| m.iter().map(|(a, b)| (*a, *b)).collect()) - .unwrap_or_else(|| vec![]) - })?; + fn get_guilds(d: AppData) -> Result, Error> { + let data = AnnouncerChannels::open(&*d.read()) + .borrow()? + .get(Self::announcer_key()) + .map(|m| m.iter().map(|(a, b)| (*a, *b)).collect()) + .unwrap_or_else(|| vec![]); Ok(data) } - fn announce(c: &Http, d: &ShareMap) -> CommandResult { - let guilds: Vec<_> = Self::get_guilds(d)?; + fn announce(c: impl AsRef, d: AppData) -> CommandResult { + let guilds: Vec<_> = Self::get_guilds(d.clone())?; let member_sets = { let mut v = Vec::with_capacity(guilds.len()); for (guild, channel) in guilds.into_iter() { @@ -72,7 +69,7 @@ pub trait Announcer { let c = client.cache_and_http.clone(); let data = client.data.clone(); spawn(move || loop { - if let Err(e) = Self::announce(c.http(), &*data.read()) { + if let Err(e) = Self::announce(c.http(), data.clone()) { dbg!(e); } std::thread::sleep(cooldown); diff --git a/youmubot/src/commands/fun/images.rs b/youmubot/src/commands/fun/images.rs index 32bb9f1..0898a29 100644 --- a/youmubot/src/commands/fun/images.rs +++ b/youmubot/src/commands/fun/images.rs @@ -1,8 +1,6 @@ -use crate::http::HTTP; -use reqwest::blocking::Client as HTTPClient; +use crate::prelude::*; use serde::Deserialize; use serenity::framework::standard::CommandError as Error; -use serenity::prelude::*; use serenity::{ framework::standard::{ macros::{check, command}, @@ -45,9 +43,8 @@ fn nsfw_check(ctx: &mut Context, msg: &Message, _: &mut Args, _: &CommandOptions fn message_command(ctx: &mut Context, msg: &Message, args: Args, rating: Rating) -> CommandResult { let tags = args.remains().unwrap_or("touhou"); - let http = ctx.data.read(); - let http = http.get::().unwrap(); - let image = get_image(http, rating, tags)?; + let http = ctx.data.get_cloned::(); + let image = get_image(&http, rating, tags)?; match image { None => msg.reply(&ctx, "🖼️ No image found...\n💡 Tip: In danbooru, character names follow Japanese standards (last name before first name), so **Hakurei Reimu** might give you an image while **Reimu Hakurei** won't."), Some(url) => msg.reply( @@ -59,7 +56,11 @@ fn message_command(ctx: &mut Context, msg: &Message, args: Args, rating: Rating) } // Gets an image URL. -fn get_image(client: &HTTPClient, rating: Rating, tags: &str) -> Result, Error> { +fn get_image( + client: &reqwest::blocking::Client, + rating: Rating, + tags: &str, +) -> Result, Error> { // Fix the tags: change whitespaces to + let tags = tags.split_whitespace().collect::>().join("_"); let req = client diff --git a/youmubot/src/commands/osu/announcer.rs b/youmubot/src/commands/osu/announcer.rs index d832191..a0c80d7 100644 --- a/youmubot/src/commands/osu/announcer.rs +++ b/youmubot/src/commands/osu/announcer.rs @@ -2,22 +2,18 @@ use super::{embeds::score_embed, BeatmapWithMode}; use crate::{ commands::announcer::Announcer, db::{OsuSavedUsers, OsuUser}, - http::Osu, + prelude::*, }; use rayon::prelude::*; use serenity::{ framework::standard::{CommandError as Error, CommandResult}, http::Http, - model::{ - id::{ChannelId, UserId}, - misc::Mentionable, - }, - prelude::ShareMap, + model::id::{ChannelId, UserId}, }; use youmubot_osu::{ models::{Mode, Score}, request::{BeatmapRequestKind, UserID}, - Client as OsuClient, + Client as Osu, }; /// Announce osu! top scores. @@ -29,15 +25,12 @@ impl Announcer for OsuAnnouncer { } fn send_messages( c: &Http, - d: &ShareMap, + d: AppData, channels: impl Fn(UserId) -> Vec + Sync, ) -> CommandResult { - let osu = d.get::().expect("osu!client").clone(); + let osu = d.get_cloned::(); // For each user... - let mut data = d - .get::() - .expect("DB initialized") - .read(|f| f.clone())?; + let mut data = OsuSavedUsers::open(&*d.read()).borrow()?.clone(); for (user_id, osu_user) in data.iter_mut() { let mut user = None; for mode in &[Mode::Std, Mode::Taiko, Mode::Mania, Mode::Catch] { @@ -86,15 +79,13 @@ impl Announcer for OsuAnnouncer { osu_user.last_update = chrono::Utc::now(); } // Update users - let f = d.get::().expect("DB initialized"); - f.write(|f| *f = data)?; - f.save()?; + *OsuSavedUsers::open(&*d.read()).borrow_mut()? = data; Ok(()) } } impl OsuAnnouncer { - fn scan_user(osu: &OsuClient, u: &OsuUser, mode: Mode) -> Result, Error> { + fn scan_user(osu: &Osu, u: &OsuUser, mode: Mode) -> Result, Error> { let scores = osu.user_best(UserID::ID(u.id), |f| f.mode(mode).limit(25))?; let scores = scores .into_iter() diff --git a/youmubot/src/commands/osu/cache.rs b/youmubot/src/commands/osu/cache.rs index e52355b..8e3af26 100644 --- a/youmubot/src/commands/osu/cache.rs +++ b/youmubot/src/commands/osu/cache.rs @@ -1,5 +1,5 @@ use super::BeatmapWithMode; -use crate::db::{DBWriteGuard, OsuLastBeatmap}; +use crate::db::OsuLastBeatmap; use serenity::{ framework::standard::{CommandError as Error, CommandResult}, model::id::ChannelId, @@ -12,10 +12,7 @@ pub(crate) fn save_beatmap( channel_id: ChannelId, bm: &BeatmapWithMode, ) -> CommandResult { - let db: DBWriteGuard<_> = data - .get::() - .expect("DB is implemented") - .into(); + let db = OsuLastBeatmap::open(data); let mut db = db.borrow_mut()?; db.insert(channel_id, (bm.0.clone(), bm.mode())); @@ -28,8 +25,8 @@ pub(crate) fn get_beatmap( data: &ShareMap, channel_id: ChannelId, ) -> Result, Error> { - let db = data.get::().expect("DB is implemented"); - let db = db.borrow_data()?; + let db = OsuLastBeatmap::open(data); + let db = db.borrow()?; Ok(db .get(&channel_id) diff --git a/youmubot/src/commands/osu/hook.rs b/youmubot/src/commands/osu/hook.rs index c087580..2b3bf03 100644 --- a/youmubot/src/commands/osu/hook.rs +++ b/youmubot/src/commands/osu/hook.rs @@ -1,11 +1,10 @@ -use crate::http; +use crate::prelude::*; use lazy_static::lazy_static; use regex::Regex; use serenity::{ builder::CreateMessage, framework::standard::{CommandError as Error, CommandResult}, model::channel::Message, - prelude::*, utils::MessageBuilder, }; use youmubot_osu::{ @@ -71,7 +70,7 @@ struct ToPrint<'a> { } fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result>, Error> { - let osu = ctx.data.read().get::().unwrap().clone(); + let osu = ctx.data.get_cloned::(); let mut to_prints: Vec> = Vec::new(); for capture in OLD_LINK_REGEX.captures_iter(content) { let req_type = capture.name("link_type").unwrap().as_str(); @@ -121,7 +120,7 @@ fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result(ctx: &mut Context, content: &'a str) -> Result>, Error> { - let osu = ctx.data.read().get::().unwrap().clone(); + let osu = ctx.data.get_cloned::(); let mut to_prints: Vec> = Vec::new(); for capture in NEW_LINK_REGEX.captures_iter(content) { let mode = capture.name("mode").and_then(|v| { diff --git a/youmubot/src/commands/osu/mod.rs b/youmubot/src/commands/osu/mod.rs index 571c6d6..159acfe 100644 --- a/youmubot/src/commands/osu/mod.rs +++ b/youmubot/src/commands/osu/mod.rs @@ -1,19 +1,17 @@ -use crate::db::{DBWriteGuard, OsuSavedUsers, OsuUser}; -use crate::http; +use crate::db::{OsuSavedUsers, OsuUser}; +use crate::prelude::*; use serenity::{ framework::standard::{ macros::{command, group}, Args, CommandError as Error, CommandResult, }, model::{channel::Message, id::UserId}, - prelude::*, utils::MessageBuilder, }; use std::str::FromStr; use youmubot_osu::{ models::{Beatmap, Mode, User}, request::{BeatmapRequestKind, UserID}, - Client as OsuClient, }; mod announcer; @@ -91,17 +89,14 @@ impl AsRef for BeatmapWithMode { #[usage = "[username or user_id]"] #[num_args(1)] pub fn save(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { - let osu = ctx.data.read().get::().unwrap().clone(); + let osu = ctx.data.get_cloned::(); let user = args.single::()?; let user: Option = osu.user(UserID::Auto(user), |f| f)?; match user { Some(u) => { let db = ctx.data.read(); - let db: DBWriteGuard<_> = db - .get::() - .ok_or(Error::from("DB uninitialized"))? - .into(); + let db = OsuSavedUsers::open(&db); let mut db = db.borrow_mut()?; db.insert( @@ -153,10 +148,8 @@ impl UsernameArg { Some(UsernameArg::Tagged(r)) => r, None => msg.author.id, }; - let db: DBWriteGuard<_> = data - .get::() - .ok_or(Error::from("DB uninitialized"))? - .into(); + + let db = OsuSavedUsers::open(data); let db = db.borrow()?; db.get(&id) .cloned() @@ -201,7 +194,7 @@ pub fn recent(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult let user = UsernameArg::to_user_id_query(args.single::().ok(), &*ctx.data.read(), msg)?; - let osu: OsuClient = ctx.data.read().get::().unwrap().clone(); + let osu = ctx.data.get_cloned::(); let user = osu .user(user, |f| f.mode(mode))? .ok_or(Error::from("User not found"))?; @@ -277,7 +270,7 @@ pub fn check(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult msg, )?; - let osu = ctx.data.read().get::().unwrap().clone(); + let osu = ctx.data.get_cloned::(); let user = osu .user(user, |f| f)? @@ -314,7 +307,7 @@ pub fn top(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { let user = UsernameArg::to_user_id_query(args.single::().ok(), &*ctx.data.read(), msg)?; - let osu: OsuClient = ctx.data.read().get::().unwrap().clone(); + let osu = ctx.data.get_cloned::(); let user = osu .user(user, |f| f.mode(mode))? .ok_or(Error::from("User not found"))?; @@ -352,7 +345,7 @@ pub fn top(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { fn get_user(ctx: &mut Context, msg: &Message, mut args: Args, mode: Mode) -> CommandResult { let user = UsernameArg::to_user_id_query(args.single::().ok(), &*ctx.data.read(), msg)?; - let osu = ctx.data.read().get::().unwrap().clone(); + let osu = ctx.data.get_cloned::(); let user = osu.user(user, |f| f.mode(mode))?; match user { Some(u) => { diff --git a/youmubot/src/db/mod.rs b/youmubot/src/db.rs similarity index 52% rename from youmubot/src/db/mod.rs rename to youmubot/src/db.rs index e8f4c98..c6b7a1d 100644 --- a/youmubot/src/db/mod.rs +++ b/youmubot/src/db.rs @@ -1,41 +1,17 @@ use chrono::{DateTime, Utc}; use dotenv::var; -use rustbreak::{deser::Yaml as Ron, FileDatabase}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use serde::{Deserialize, Serialize}; use serenity::{ client::Client, framework::standard::CommandError as Error, - model::id::{ChannelId, GuildId, RoleId, UserId}, - prelude::*, + model::id::{ChannelId, RoleId, UserId}, }; use std::collections::HashMap; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; +use youmubot_db::{GuildMap, DB}; use youmubot_osu::models::{Beatmap, Mode}; -/// GuildMap defines the guild-map type. -/// It is basically a HashMap from a GuildId to a data structure. -pub type GuildMap = HashMap; -/// The generic DB type we will be using. -pub struct DB(std::marker::PhantomData); -impl serenity::prelude::TypeMapKey for DB { - type Value = FileDatabase; -} - -impl DB -where - for<'de> T: Deserialize<'de>, -{ - fn insert_into(data: &mut ShareMap, path: impl AsRef) -> Result<(), Error> { - let db = FileDatabase::::from_path(path, T::default())?; - db.load().or_else(|e| { - dbg!(e); - db.save() - })?; - data.insert::>(db); - Ok(()) - } -} - /// A map from announcer keys to guild IDs and to channels. pub type AnnouncerChannels = DB>>; @@ -63,40 +39,6 @@ pub fn setup_db(client: &mut Client) -> Result<(), Error> { Ok(()) } -pub struct DBWriteGuard<'a, T>(&'a FileDatabase) -where - T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned; - -impl<'a, T> From<&'a FileDatabase> for DBWriteGuard<'a, T> -where - T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, -{ - fn from(v: &'a FileDatabase) -> Self { - DBWriteGuard(v) - } -} - -impl<'a, T> DBWriteGuard<'a, T> -where - T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, -{ - pub fn borrow(&self) -> Result, rustbreak::RustbreakError> { - (*self).0.borrow_data() - } - pub fn borrow_mut(&self) -> Result, rustbreak::RustbreakError> { - (*self).0.borrow_data_mut() - } -} - -impl<'a, T> Drop for DBWriteGuard<'a, T> -where - T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned, -{ - fn drop(&mut self) { - self.0.save().expect("Save succeed") - } -} - /// For the admin commands: /// - Each server might have a `soft ban` role implemented. /// - We allow periodical `soft ban` applications. diff --git a/youmubot/src/http.rs b/youmubot/src/http.rs deleted file mode 100644 index 9f319c4..0000000 --- a/youmubot/src/http.rs +++ /dev/null @@ -1,14 +0,0 @@ -use serenity::prelude::TypeMapKey; -use youmubot_osu::Client as OsuClient; - -pub(crate) struct HTTP; - -impl TypeMapKey for HTTP { - type Value = reqwest::blocking::Client; -} - -pub(crate) struct Osu; - -impl TypeMapKey for Osu { - type Value = OsuClient; -} diff --git a/youmubot/src/main.rs b/youmubot/src/main.rs index 8b5b504..5227392 100644 --- a/youmubot/src/main.rs +++ b/youmubot/src/main.rs @@ -4,16 +4,16 @@ use reqwest; use serenity::{ framework::standard::{DispatchError, StandardFramework}, model::{channel::Message, gateway}, - prelude::*, }; -use youmubot_osu::Client as OsuClient; +use youmubot_osu::Client as OsuApiClient; mod commands; mod db; -mod http; +mod prelude; use commands::osu::OsuAnnouncer; use commands::Announcer; +use prelude::*; const MESSAGE_HOOKS: [fn(&mut Context, &Message) -> (); 1] = [commands::osu::hook]; @@ -49,8 +49,8 @@ fn main() { { let mut data = client.data.write(); let http_client = reqwest::blocking::Client::new(); - data.insert::(http_client.clone()); - data.insert::(OsuClient::new( + data.insert::(http_client.clone()); + data.insert::(OsuApiClient::new( http_client.clone(), var("OSU_API_KEY").expect("Please set OSU_API_KEY as osu! api key."), )); diff --git a/youmubot/src/prelude.rs b/youmubot/src/prelude.rs new file mode 100644 index 0000000..dc98a94 --- /dev/null +++ b/youmubot/src/prelude.rs @@ -0,0 +1,50 @@ +use std::sync::Arc; +use youmubot_osu::Client as OsuHttpClient; + +pub use serenity::prelude::*; + +/// The global app data. +pub type AppData = Arc>; + +/// The HTTP client. +pub(crate) struct HTTPClient; + +impl TypeMapKey for HTTPClient { + type Value = reqwest::blocking::Client; +} + +/// The osu! client. +pub(crate) struct OsuClient; + +impl TypeMapKey for OsuClient { + type Value = OsuHttpClient; +} + +/// The TypeMap trait that allows TypeMaps to quickly get a clonable item. +pub trait GetCloned { + /// Gets an item from the store, cloned. + fn get_cloned(&self) -> T::Value + where + T: TypeMapKey, + T::Value: Clone + Send + Sync; +} + +impl GetCloned for ShareMap { + fn get_cloned(&self) -> T::Value + where + T: TypeMapKey, + T::Value: Clone + Send + Sync, + { + self.get::().cloned().expect("Should be there") + } +} + +impl GetCloned for AppData { + fn get_cloned(&self) -> T::Value + where + T: TypeMapKey, + T::Value: Clone + Send + Sync, + { + self.read().get::().cloned().expect("Should be there") + } +}