Split db into youmubot-db

This commit is contained in:
Natsu Kagami 2020-02-02 21:51:54 -05:00
parent 9287bdf5b7
commit c4916a24f7
16 changed files with 231 additions and 176 deletions

14
Cargo.lock generated
View file

@ -1684,13 +1684,25 @@ dependencies = [
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.3 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"reqwest 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)", "reqwest 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rustbreak 2.0.0-rc3 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
"serenity 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "serenity 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
"static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"youmubot-db 0.1.0",
"youmubot-osu 0.1.0", "youmubot-osu 0.1.0",
] ]
[[package]]
name = "youmubot-db"
version = "0.1.0"
dependencies = [
"chrono 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rustbreak 2.0.0-rc3 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
"serenity 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]] [[package]]
name = "youmubot-osu" name = "youmubot-osu"
version = "0.1.0" version = "0.1.0"

View file

@ -1,6 +1,7 @@
[workspace] [workspace]
members = [ members = [
"youmubot-db",
"youmubot-osu", "youmubot-osu",
"youmubot", "youmubot",
] ]

24
youmubot-db/Cargo.toml Normal file
View file

@ -0,0 +1,24 @@
[package]
name = "youmubot-db"
version = "0.1.0"
authors = ["Natsu Kagami <natsukagami@gmail.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serenity = "0.8"
dotenv = "0.15"
serde = { version = "1.0", features = ["derive"] }
chrono = "0.4.9"
# rand = "0.7.2"
# static_assertions = "1.1.0"
# reqwest = "0.10.1"
# regex = "1"
# lazy_static = "1"
# youmubot-osu = { path = "../youmubot-osu" }
rayon = "1.1"
[dependencies.rustbreak]
version = "2.0.0-rc3"
features = ["yaml_enc"]

73
youmubot-db/src/lib.rs Normal file
View file

@ -0,0 +1,73 @@
use rustbreak::{deser::Yaml as Ron, FileDatabase};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serenity::{framework::standard::CommandError as Error, model::id::GuildId, prelude::*};
use std::collections::HashMap;
use std::path::Path;
/// GuildMap defines the guild-map type.
/// It is basically a HashMap from a GuildId to a data structure.
pub type GuildMap<V> = HashMap<GuildId, V>;
/// The generic DB type we will be using.
pub struct DB<T>(std::marker::PhantomData<T>);
impl<T: std::any::Any> serenity::prelude::TypeMapKey for DB<T> {
type Value = FileDatabase<T, Ron>;
}
impl<T: std::any::Any + Default + Send + Sync + Clone + Serialize + std::fmt::Debug> DB<T>
where
for<'de> T: Deserialize<'de>,
{
/// Insert into a ShareMap.
pub fn insert_into(data: &mut ShareMap, path: impl AsRef<Path>) -> Result<(), Error> {
let db = FileDatabase::<T, Ron>::from_path(path, T::default())?;
db.load().or_else(|e| {
dbg!(e);
db.save()
})?;
data.insert::<DB<T>>(db);
Ok(())
}
/// Open a previously inserted DB.
pub fn open(data: &ShareMap) -> DBWriteGuard<'_, T> {
data.get::<Self>().expect("DB initialized").into()
}
}
/// The write guard for our FileDatabase.
/// It wraps the FileDatabase in a write-on-drop lock.
pub struct DBWriteGuard<'a, T>(&'a FileDatabase<T, Ron>)
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned;
impl<'a, T> From<&'a FileDatabase<T, Ron>> for DBWriteGuard<'a, T>
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{
fn from(v: &'a FileDatabase<T, Ron>) -> Self {
DBWriteGuard(v)
}
}
impl<'a, T> DBWriteGuard<'a, T>
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{
/// Borrows the FileDatabase.
pub fn borrow(&self) -> Result<std::sync::RwLockReadGuard<T>, rustbreak::RustbreakError> {
(*self).0.borrow_data()
}
/// Borrows the FileDatabase for writing.
pub fn borrow_mut(&self) -> Result<std::sync::RwLockWriteGuard<T>, rustbreak::RustbreakError> {
(*self).0.borrow_data_mut()
}
}
impl<'a, T> Drop for DBWriteGuard<'a, T>
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{
fn drop(&mut self) {
self.0.save().expect("Save succeed")
}
}

View file

@ -18,7 +18,5 @@ regex = "1"
lazy_static = "1" lazy_static = "1"
youmubot-osu = { path = "../youmubot-osu" } youmubot-osu = { path = "../youmubot-osu" }
rayon = "1.1" rayon = "1.1"
youmubot-db = { path = "../youmubot-db" }
[dependencies.rustbreak]
version = "2.0.0-rc3"
features = ["yaml_enc"]

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
commands::args, commands::args,
db::{DBWriteGuard, ServerSoftBans, SoftBans}, db::{ServerSoftBans, SoftBans},
}; };
use chrono::offset::Utc; use chrono::offset::Utc;
use serenity::prelude::*; use serenity::prelude::*;
@ -33,13 +33,10 @@ pub fn soft_ban(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResu
}; };
let guild = msg.guild_id.ok_or(Error::from("Command is guild only"))?; let guild = msg.guild_id.ok_or(Error::from("Command is guild only"))?;
let data = ctx.data.read(); let db = ctx.data.read();
let data = data let db = SoftBans::open(&*db);
.get::<SoftBans>() let mut db = db.borrow_mut()?;
.ok_or(Error::from("DB initialized")) let mut server_ban = db.get_mut(&guild).and_then(|v| match v {
.map(|v| DBWriteGuard::from(v))?;
let mut data = data.borrow_mut()?;
let mut server_ban = data.get_mut(&guild).and_then(|v| match v {
ServerSoftBans::Unimplemented => None, ServerSoftBans::Unimplemented => None,
ServerSoftBans::Implemented(ref mut v) => Some(v), ServerSoftBans::Implemented(ref mut v) => Some(v),
}); });
@ -98,11 +95,8 @@ pub fn soft_ban_init(ctx: &mut Context, msg: &Message, mut args: Args) -> Comman
))); )));
} }
// Check if we already set up // Check if we already set up
let data = ctx.data.read(); let db = ctx.data.read();
let db: DBWriteGuard<_> = data let db = SoftBans::open(&*db);
.get::<SoftBans>()
.ok_or(Error::from("DB uninitialized"))?
.into();
let mut db = db.borrow_mut()?; let mut db = db.borrow_mut()?;
let server = db let server = db
.get(&guild.id) .get(&guild.id)
@ -135,12 +129,9 @@ pub fn watch_soft_bans(client: &mut serenity::Client) -> impl FnOnce() -> () + '
// Scope so that locks are released // Scope so that locks are released
{ {
// Poll the data for any changes. // Poll the data for any changes.
let data = data.read(); let db = data.read();
let db: DBWriteGuard<_> = data let db = SoftBans::open(&*db);
.get::<SoftBans>() let mut db = db.borrow_mut().expect("Borrowable");
.expect("DB wrongly initialized")
.into();
let mut db = db.borrow_mut().expect("cannot unpack DB");
let now = Utc::now(); let now = Utc::now();
for (server_id, soft_bans) in db.iter_mut() { for (server_id, soft_bans) in db.iter_mut() {
let server_name: String = match server_id.to_partial_guild(cache_http) { let server_name: String = match server_id.to_partial_guild(cache_http) {

View file

@ -1,9 +1,9 @@
use crate::db::{AnnouncerChannels, DBWriteGuard}; use crate::db::AnnouncerChannels;
use crate::prelude::*;
use serenity::{ use serenity::{
framework::standard::{CommandError as Error, CommandResult}, framework::standard::{CommandError as Error, CommandResult},
http::{CacheHttp, Http}, http::{CacheHttp, Http},
model::id::{ChannelId, GuildId, UserId}, model::id::{ChannelId, GuildId, UserId},
prelude::ShareMap,
}; };
use std::{ use std::{
collections::HashSet, collections::HashSet,
@ -14,33 +14,30 @@ pub trait Announcer {
fn announcer_key() -> &'static str; fn announcer_key() -> &'static str;
fn send_messages( fn send_messages(
c: &Http, c: &Http,
d: &ShareMap, d: AppData,
channels: impl Fn(UserId) -> Vec<ChannelId> + Sync, channels: impl Fn(UserId) -> Vec<ChannelId> + Sync,
) -> CommandResult; ) -> CommandResult;
fn set_channel(d: &ShareMap, guild: GuildId, channel: ChannelId) -> CommandResult { fn set_channel(d: AppData, guild: GuildId, channel: ChannelId) -> CommandResult {
let data: DBWriteGuard<_> = d.get::<AnnouncerChannels>().expect("DB initialized").into(); AnnouncerChannels::open(&*d.read())
let mut data = data.borrow_mut()?; .borrow_mut()?
data.entry(Self::announcer_key().to_owned()) .entry(Self::announcer_key().to_owned())
.or_default() .or_default()
.insert(guild, channel); .insert(guild, channel);
Ok(()) Ok(())
} }
fn get_guilds(d: &ShareMap) -> Result<Vec<(GuildId, ChannelId)>, Error> { fn get_guilds(d: AppData) -> Result<Vec<(GuildId, ChannelId)>, Error> {
let data = d let data = AnnouncerChannels::open(&*d.read())
.get::<AnnouncerChannels>() .borrow()?
.expect("DB initialized") .get(Self::announcer_key())
.read(|v| { .map(|m| m.iter().map(|(a, b)| (*a, *b)).collect())
v.get(Self::announcer_key()) .unwrap_or_else(|| vec![]);
.map(|m| m.iter().map(|(a, b)| (*a, *b)).collect())
.unwrap_or_else(|| vec![])
})?;
Ok(data) Ok(data)
} }
fn announce(c: &Http, d: &ShareMap) -> CommandResult { fn announce(c: impl AsRef<Http>, d: AppData) -> CommandResult {
let guilds: Vec<_> = Self::get_guilds(d)?; let guilds: Vec<_> = Self::get_guilds(d.clone())?;
let member_sets = { let member_sets = {
let mut v = Vec::with_capacity(guilds.len()); let mut v = Vec::with_capacity(guilds.len());
for (guild, channel) in guilds.into_iter() { for (guild, channel) in guilds.into_iter() {
@ -72,7 +69,7 @@ pub trait Announcer {
let c = client.cache_and_http.clone(); let c = client.cache_and_http.clone();
let data = client.data.clone(); let data = client.data.clone();
spawn(move || loop { spawn(move || loop {
if let Err(e) = Self::announce(c.http(), &*data.read()) { if let Err(e) = Self::announce(c.http(), data.clone()) {
dbg!(e); dbg!(e);
} }
std::thread::sleep(cooldown); std::thread::sleep(cooldown);

View file

@ -1,8 +1,6 @@
use crate::http::HTTP; use crate::prelude::*;
use reqwest::blocking::Client as HTTPClient;
use serde::Deserialize; use serde::Deserialize;
use serenity::framework::standard::CommandError as Error; use serenity::framework::standard::CommandError as Error;
use serenity::prelude::*;
use serenity::{ use serenity::{
framework::standard::{ framework::standard::{
macros::{check, command}, macros::{check, command},
@ -45,9 +43,8 @@ fn nsfw_check(ctx: &mut Context, msg: &Message, _: &mut Args, _: &CommandOptions
fn message_command(ctx: &mut Context, msg: &Message, args: Args, rating: Rating) -> CommandResult { fn message_command(ctx: &mut Context, msg: &Message, args: Args, rating: Rating) -> CommandResult {
let tags = args.remains().unwrap_or("touhou"); let tags = args.remains().unwrap_or("touhou");
let http = ctx.data.read(); let http = ctx.data.get_cloned::<HTTPClient>();
let http = http.get::<HTTP>().unwrap(); let image = get_image(&http, rating, tags)?;
let image = get_image(http, rating, tags)?;
match image { match image {
None => msg.reply(&ctx, "🖼️ No image found...\n💡 Tip: In danbooru, character names follow Japanese standards (last name before first name), so **Hakurei Reimu** might give you an image while **Reimu Hakurei** won't."), None => msg.reply(&ctx, "🖼️ No image found...\n💡 Tip: In danbooru, character names follow Japanese standards (last name before first name), so **Hakurei Reimu** might give you an image while **Reimu Hakurei** won't."),
Some(url) => msg.reply( Some(url) => msg.reply(
@ -59,7 +56,11 @@ fn message_command(ctx: &mut Context, msg: &Message, args: Args, rating: Rating)
} }
// Gets an image URL. // Gets an image URL.
fn get_image(client: &HTTPClient, rating: Rating, tags: &str) -> Result<Option<String>, Error> { fn get_image(
client: &reqwest::blocking::Client,
rating: Rating,
tags: &str,
) -> Result<Option<String>, Error> {
// Fix the tags: change whitespaces to + // Fix the tags: change whitespaces to +
let tags = tags.split_whitespace().collect::<Vec<_>>().join("_"); let tags = tags.split_whitespace().collect::<Vec<_>>().join("_");
let req = client let req = client

View file

@ -2,22 +2,18 @@ use super::{embeds::score_embed, BeatmapWithMode};
use crate::{ use crate::{
commands::announcer::Announcer, commands::announcer::Announcer,
db::{OsuSavedUsers, OsuUser}, db::{OsuSavedUsers, OsuUser},
http::Osu, prelude::*,
}; };
use rayon::prelude::*; use rayon::prelude::*;
use serenity::{ use serenity::{
framework::standard::{CommandError as Error, CommandResult}, framework::standard::{CommandError as Error, CommandResult},
http::Http, http::Http,
model::{ model::id::{ChannelId, UserId},
id::{ChannelId, UserId},
misc::Mentionable,
},
prelude::ShareMap,
}; };
use youmubot_osu::{ use youmubot_osu::{
models::{Mode, Score}, models::{Mode, Score},
request::{BeatmapRequestKind, UserID}, request::{BeatmapRequestKind, UserID},
Client as OsuClient, Client as Osu,
}; };
/// Announce osu! top scores. /// Announce osu! top scores.
@ -29,15 +25,12 @@ impl Announcer for OsuAnnouncer {
} }
fn send_messages( fn send_messages(
c: &Http, c: &Http,
d: &ShareMap, d: AppData,
channels: impl Fn(UserId) -> Vec<ChannelId> + Sync, channels: impl Fn(UserId) -> Vec<ChannelId> + Sync,
) -> CommandResult { ) -> CommandResult {
let osu = d.get::<Osu>().expect("osu!client").clone(); let osu = d.get_cloned::<OsuClient>();
// For each user... // For each user...
let mut data = d let mut data = OsuSavedUsers::open(&*d.read()).borrow()?.clone();
.get::<OsuSavedUsers>()
.expect("DB initialized")
.read(|f| f.clone())?;
for (user_id, osu_user) in data.iter_mut() { for (user_id, osu_user) in data.iter_mut() {
let mut user = None; let mut user = None;
for mode in &[Mode::Std, Mode::Taiko, Mode::Mania, Mode::Catch] { for mode in &[Mode::Std, Mode::Taiko, Mode::Mania, Mode::Catch] {
@ -86,15 +79,13 @@ impl Announcer for OsuAnnouncer {
osu_user.last_update = chrono::Utc::now(); osu_user.last_update = chrono::Utc::now();
} }
// Update users // Update users
let f = d.get::<OsuSavedUsers>().expect("DB initialized"); *OsuSavedUsers::open(&*d.read()).borrow_mut()? = data;
f.write(|f| *f = data)?;
f.save()?;
Ok(()) Ok(())
} }
} }
impl OsuAnnouncer { impl OsuAnnouncer {
fn scan_user(osu: &OsuClient, u: &OsuUser, mode: Mode) -> Result<Vec<(u8, Score)>, Error> { fn scan_user(osu: &Osu, u: &OsuUser, mode: Mode) -> Result<Vec<(u8, Score)>, Error> {
let scores = osu.user_best(UserID::ID(u.id), |f| f.mode(mode).limit(25))?; let scores = osu.user_best(UserID::ID(u.id), |f| f.mode(mode).limit(25))?;
let scores = scores let scores = scores
.into_iter() .into_iter()

View file

@ -1,5 +1,5 @@
use super::BeatmapWithMode; use super::BeatmapWithMode;
use crate::db::{DBWriteGuard, OsuLastBeatmap}; use crate::db::OsuLastBeatmap;
use serenity::{ use serenity::{
framework::standard::{CommandError as Error, CommandResult}, framework::standard::{CommandError as Error, CommandResult},
model::id::ChannelId, model::id::ChannelId,
@ -12,10 +12,7 @@ pub(crate) fn save_beatmap(
channel_id: ChannelId, channel_id: ChannelId,
bm: &BeatmapWithMode, bm: &BeatmapWithMode,
) -> CommandResult { ) -> CommandResult {
let db: DBWriteGuard<_> = data let db = OsuLastBeatmap::open(data);
.get::<OsuLastBeatmap>()
.expect("DB is implemented")
.into();
let mut db = db.borrow_mut()?; let mut db = db.borrow_mut()?;
db.insert(channel_id, (bm.0.clone(), bm.mode())); db.insert(channel_id, (bm.0.clone(), bm.mode()));
@ -28,8 +25,8 @@ pub(crate) fn get_beatmap(
data: &ShareMap, data: &ShareMap,
channel_id: ChannelId, channel_id: ChannelId,
) -> Result<Option<BeatmapWithMode>, Error> { ) -> Result<Option<BeatmapWithMode>, Error> {
let db = data.get::<OsuLastBeatmap>().expect("DB is implemented"); let db = OsuLastBeatmap::open(data);
let db = db.borrow_data()?; let db = db.borrow()?;
Ok(db Ok(db
.get(&channel_id) .get(&channel_id)

View file

@ -1,11 +1,10 @@
use crate::http; use crate::prelude::*;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use serenity::{ use serenity::{
builder::CreateMessage, builder::CreateMessage,
framework::standard::{CommandError as Error, CommandResult}, framework::standard::{CommandError as Error, CommandResult},
model::channel::Message, model::channel::Message,
prelude::*,
utils::MessageBuilder, utils::MessageBuilder,
}; };
use youmubot_osu::{ use youmubot_osu::{
@ -71,7 +70,7 @@ struct ToPrint<'a> {
} }
fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result<Vec<ToPrint<'a>>, Error> { fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result<Vec<ToPrint<'a>>, Error> {
let osu = ctx.data.read().get::<http::Osu>().unwrap().clone(); let osu = ctx.data.get_cloned::<OsuClient>();
let mut to_prints: Vec<ToPrint<'a>> = Vec::new(); let mut to_prints: Vec<ToPrint<'a>> = Vec::new();
for capture in OLD_LINK_REGEX.captures_iter(content) { for capture in OLD_LINK_REGEX.captures_iter(content) {
let req_type = capture.name("link_type").unwrap().as_str(); let req_type = capture.name("link_type").unwrap().as_str();
@ -121,7 +120,7 @@ fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result<Vec<ToPri
} }
fn handle_new_links<'a>(ctx: &mut Context, content: &'a str) -> Result<Vec<ToPrint<'a>>, Error> { fn handle_new_links<'a>(ctx: &mut Context, content: &'a str) -> Result<Vec<ToPrint<'a>>, Error> {
let osu = ctx.data.read().get::<http::Osu>().unwrap().clone(); let osu = ctx.data.get_cloned::<OsuClient>();
let mut to_prints: Vec<ToPrint<'a>> = Vec::new(); let mut to_prints: Vec<ToPrint<'a>> = Vec::new();
for capture in NEW_LINK_REGEX.captures_iter(content) { for capture in NEW_LINK_REGEX.captures_iter(content) {
let mode = capture.name("mode").and_then(|v| { let mode = capture.name("mode").and_then(|v| {

View file

@ -1,19 +1,17 @@
use crate::db::{DBWriteGuard, OsuSavedUsers, OsuUser}; use crate::db::{OsuSavedUsers, OsuUser};
use crate::http; use crate::prelude::*;
use serenity::{ use serenity::{
framework::standard::{ framework::standard::{
macros::{command, group}, macros::{command, group},
Args, CommandError as Error, CommandResult, Args, CommandError as Error, CommandResult,
}, },
model::{channel::Message, id::UserId}, model::{channel::Message, id::UserId},
prelude::*,
utils::MessageBuilder, utils::MessageBuilder,
}; };
use std::str::FromStr; use std::str::FromStr;
use youmubot_osu::{ use youmubot_osu::{
models::{Beatmap, Mode, User}, models::{Beatmap, Mode, User},
request::{BeatmapRequestKind, UserID}, request::{BeatmapRequestKind, UserID},
Client as OsuClient,
}; };
mod announcer; mod announcer;
@ -91,17 +89,14 @@ impl AsRef<Beatmap> for BeatmapWithMode {
#[usage = "[username or user_id]"] #[usage = "[username or user_id]"]
#[num_args(1)] #[num_args(1)]
pub fn save(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { pub fn save(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
let osu = ctx.data.read().get::<http::Osu>().unwrap().clone(); let osu = ctx.data.get_cloned::<OsuClient>();
let user = args.single::<String>()?; let user = args.single::<String>()?;
let user: Option<User> = osu.user(UserID::Auto(user), |f| f)?; let user: Option<User> = osu.user(UserID::Auto(user), |f| f)?;
match user { match user {
Some(u) => { Some(u) => {
let db = ctx.data.read(); let db = ctx.data.read();
let db: DBWriteGuard<_> = db let db = OsuSavedUsers::open(&db);
.get::<OsuSavedUsers>()
.ok_or(Error::from("DB uninitialized"))?
.into();
let mut db = db.borrow_mut()?; let mut db = db.borrow_mut()?;
db.insert( db.insert(
@ -153,10 +148,8 @@ impl UsernameArg {
Some(UsernameArg::Tagged(r)) => r, Some(UsernameArg::Tagged(r)) => r,
None => msg.author.id, None => msg.author.id,
}; };
let db: DBWriteGuard<_> = data
.get::<OsuSavedUsers>() let db = OsuSavedUsers::open(data);
.ok_or(Error::from("DB uninitialized"))?
.into();
let db = db.borrow()?; let db = db.borrow()?;
db.get(&id) db.get(&id)
.cloned() .cloned()
@ -201,7 +194,7 @@ pub fn recent(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult
let user = let user =
UsernameArg::to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?; UsernameArg::to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?;
let osu: OsuClient = ctx.data.read().get::<http::Osu>().unwrap().clone(); let osu = ctx.data.get_cloned::<OsuClient>();
let user = osu let user = osu
.user(user, |f| f.mode(mode))? .user(user, |f| f.mode(mode))?
.ok_or(Error::from("User not found"))?; .ok_or(Error::from("User not found"))?;
@ -277,7 +270,7 @@ pub fn check(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult
msg, msg,
)?; )?;
let osu = ctx.data.read().get::<http::Osu>().unwrap().clone(); let osu = ctx.data.get_cloned::<OsuClient>();
let user = osu let user = osu
.user(user, |f| f)? .user(user, |f| f)?
@ -314,7 +307,7 @@ pub fn top(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
let user = let user =
UsernameArg::to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?; UsernameArg::to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?;
let osu: OsuClient = ctx.data.read().get::<http::Osu>().unwrap().clone(); let osu = ctx.data.get_cloned::<OsuClient>();
let user = osu let user = osu
.user(user, |f| f.mode(mode))? .user(user, |f| f.mode(mode))?
.ok_or(Error::from("User not found"))?; .ok_or(Error::from("User not found"))?;
@ -352,7 +345,7 @@ pub fn top(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult {
fn get_user(ctx: &mut Context, msg: &Message, mut args: Args, mode: Mode) -> CommandResult { fn get_user(ctx: &mut Context, msg: &Message, mut args: Args, mode: Mode) -> CommandResult {
let user = let user =
UsernameArg::to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?; UsernameArg::to_user_id_query(args.single::<UsernameArg>().ok(), &*ctx.data.read(), msg)?;
let osu = ctx.data.read().get::<http::Osu>().unwrap().clone(); let osu = ctx.data.get_cloned::<OsuClient>();
let user = osu.user(user, |f| f.mode(mode))?; let user = osu.user(user, |f| f.mode(mode))?;
match user { match user {
Some(u) => { Some(u) => {

View file

@ -1,41 +1,17 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use dotenv::var; use dotenv::var;
use rustbreak::{deser::Yaml as Ron, FileDatabase};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serenity::{ use serenity::{
client::Client, client::Client,
framework::standard::CommandError as Error, framework::standard::CommandError as Error,
model::id::{ChannelId, GuildId, RoleId, UserId}, model::id::{ChannelId, RoleId, UserId},
prelude::*,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::path::{Path, PathBuf}; use std::path::PathBuf;
use youmubot_db::{GuildMap, DB};
use youmubot_osu::models::{Beatmap, Mode}; use youmubot_osu::models::{Beatmap, Mode};
/// GuildMap defines the guild-map type.
/// It is basically a HashMap from a GuildId to a data structure.
pub type GuildMap<V> = HashMap<GuildId, V>;
/// The generic DB type we will be using.
pub struct DB<T>(std::marker::PhantomData<T>);
impl<T: std::any::Any> serenity::prelude::TypeMapKey for DB<T> {
type Value = FileDatabase<T, Ron>;
}
impl<T: std::any::Any + Default + Send + Sync + Clone + Serialize + std::fmt::Debug> DB<T>
where
for<'de> T: Deserialize<'de>,
{
fn insert_into(data: &mut ShareMap, path: impl AsRef<Path>) -> Result<(), Error> {
let db = FileDatabase::<T, Ron>::from_path(path, T::default())?;
db.load().or_else(|e| {
dbg!(e);
db.save()
})?;
data.insert::<DB<T>>(db);
Ok(())
}
}
/// A map from announcer keys to guild IDs and to channels. /// A map from announcer keys to guild IDs and to channels.
pub type AnnouncerChannels = DB<HashMap<String, GuildMap<ChannelId>>>; pub type AnnouncerChannels = DB<HashMap<String, GuildMap<ChannelId>>>;
@ -63,40 +39,6 @@ pub fn setup_db(client: &mut Client) -> Result<(), Error> {
Ok(()) Ok(())
} }
pub struct DBWriteGuard<'a, T>(&'a FileDatabase<T, Ron>)
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned;
impl<'a, T> From<&'a FileDatabase<T, Ron>> for DBWriteGuard<'a, T>
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{
fn from(v: &'a FileDatabase<T, Ron>) -> Self {
DBWriteGuard(v)
}
}
impl<'a, T> DBWriteGuard<'a, T>
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{
pub fn borrow(&self) -> Result<std::sync::RwLockReadGuard<T>, rustbreak::RustbreakError> {
(*self).0.borrow_data()
}
pub fn borrow_mut(&self) -> Result<std::sync::RwLockWriteGuard<T>, rustbreak::RustbreakError> {
(*self).0.borrow_data_mut()
}
}
impl<'a, T> Drop for DBWriteGuard<'a, T>
where
T: Send + Sync + Clone + std::fmt::Debug + Serialize + DeserializeOwned,
{
fn drop(&mut self) {
self.0.save().expect("Save succeed")
}
}
/// For the admin commands: /// For the admin commands:
/// - Each server might have a `soft ban` role implemented. /// - Each server might have a `soft ban` role implemented.
/// - We allow periodical `soft ban` applications. /// - We allow periodical `soft ban` applications.

View file

@ -1,14 +0,0 @@
use serenity::prelude::TypeMapKey;
use youmubot_osu::Client as OsuClient;
pub(crate) struct HTTP;
impl TypeMapKey for HTTP {
type Value = reqwest::blocking::Client;
}
pub(crate) struct Osu;
impl TypeMapKey for Osu {
type Value = OsuClient;
}

View file

@ -4,16 +4,16 @@ use reqwest;
use serenity::{ use serenity::{
framework::standard::{DispatchError, StandardFramework}, framework::standard::{DispatchError, StandardFramework},
model::{channel::Message, gateway}, model::{channel::Message, gateway},
prelude::*,
}; };
use youmubot_osu::Client as OsuClient; use youmubot_osu::Client as OsuApiClient;
mod commands; mod commands;
mod db; mod db;
mod http; mod prelude;
use commands::osu::OsuAnnouncer; use commands::osu::OsuAnnouncer;
use commands::Announcer; use commands::Announcer;
use prelude::*;
const MESSAGE_HOOKS: [fn(&mut Context, &Message) -> (); 1] = [commands::osu::hook]; const MESSAGE_HOOKS: [fn(&mut Context, &Message) -> (); 1] = [commands::osu::hook];
@ -49,8 +49,8 @@ fn main() {
{ {
let mut data = client.data.write(); let mut data = client.data.write();
let http_client = reqwest::blocking::Client::new(); let http_client = reqwest::blocking::Client::new();
data.insert::<http::HTTP>(http_client.clone()); data.insert::<HTTPClient>(http_client.clone());
data.insert::<http::Osu>(OsuClient::new( data.insert::<OsuClient>(OsuApiClient::new(
http_client.clone(), http_client.clone(),
var("OSU_API_KEY").expect("Please set OSU_API_KEY as osu! api key."), var("OSU_API_KEY").expect("Please set OSU_API_KEY as osu! api key."),
)); ));

50
youmubot/src/prelude.rs Normal file
View file

@ -0,0 +1,50 @@
use std::sync::Arc;
use youmubot_osu::Client as OsuHttpClient;
pub use serenity::prelude::*;
/// The global app data.
pub type AppData = Arc<RwLock<ShareMap>>;
/// The HTTP client.
pub(crate) struct HTTPClient;
impl TypeMapKey for HTTPClient {
type Value = reqwest::blocking::Client;
}
/// The osu! client.
pub(crate) struct OsuClient;
impl TypeMapKey for OsuClient {
type Value = OsuHttpClient;
}
/// The TypeMap trait that allows TypeMaps to quickly get a clonable item.
pub trait GetCloned {
/// Gets an item from the store, cloned.
fn get_cloned<T>(&self) -> T::Value
where
T: TypeMapKey,
T::Value: Clone + Send + Sync;
}
impl GetCloned for ShareMap {
fn get_cloned<T>(&self) -> T::Value
where
T: TypeMapKey,
T::Value: Clone + Send + Sync,
{
self.get::<T>().cloned().expect("Should be there")
}
}
impl GetCloned for AppData {
fn get_cloned<T>(&self) -> T::Value
where
T: TypeMapKey,
T::Value: Clone + Send + Sync,
{
self.read().get::<T>().cloned().expect("Should be there")
}
}