diff --git a/Cargo.lock b/Cargo.lock index 19c0331..dae9310 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3303,6 +3303,7 @@ dependencies = [ "poise", "reqwest", "serenity", + "thiserror", "tokio", "youmubot-db", "youmubot-db-sql", diff --git a/youmubot-core/src/admin/mod.rs b/youmubot-core/src/admin/mod.rs index 8df2c39..2c766e2 100644 --- a/youmubot-core/src/admin/mod.rs +++ b/youmubot-core/src/admin/mod.rs @@ -5,10 +5,7 @@ use serenity::{ macros::{command, group}, Args, CommandResult, }, - model::{ - channel::{Channel, Message}, - id::UserId, - }, + model::channel::{Channel, Message}, }; use soft_ban::{SOFT_BAN_COMMAND, SOFT_BAN_INIT_COMMAND}; use youmubot_prelude::*; @@ -69,7 +66,7 @@ async fn clean(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { #[max_args(2)] #[only_in("guilds")] async fn ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let user = args.single::()?.to_user(&ctx).await?; + let user = args.single::()?.0.to_user(&ctx).await?; let reason = args.single::().map(|v| format!("`{}`", v)).ok(); let dmds = args.single::().unwrap_or(0); @@ -105,7 +102,7 @@ async fn ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { #[num_args(1)] #[only_in("guilds")] async fn kick(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let user = args.single::()?.to_user(&ctx).await?; + let user = args.single::()?.0.to_user(&ctx).await?; msg.reply(&ctx, format!("🔫 Kicking user {}.", user.tag())) .await?; diff --git a/youmubot-core/src/admin/soft_ban.rs b/youmubot-core/src/admin/soft_ban.rs index cb41974..5a8239e 100644 --- a/youmubot-core/src/admin/soft_ban.rs +++ b/youmubot-core/src/admin/soft_ban.rs @@ -3,10 +3,7 @@ use chrono::offset::Utc; use futures_util::{stream, TryStreamExt}; use serenity::{ framework::standard::{macros::command, Args, CommandResult}, - model::{ - channel::Message, - id::{GuildId, RoleId, UserId}, - }, + model::{channel::Message, id}, }; use youmubot_prelude::*; @@ -19,7 +16,7 @@ use youmubot_prelude::*; #[max_args(2)] #[only_in("guilds")] pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let user = args.single::()?.to_user(&ctx).await?; + let user = args.single::()?.0.to_user(&ctx).await?; let data = ctx.data.read().await; let duration = if args.is_empty() { None @@ -81,7 +78,7 @@ pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandRe #[num_args(1)] #[only_in("guilds")] pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let role_id = args.single::()?; + let role_id = args.single::()?.0; let data = ctx.data.read().await; let guild = msg.guild_id.unwrap().to_partial_guild(&ctx).await?; // Check whether the role_id is the one we wanted @@ -152,10 +149,10 @@ pub async fn watch_soft_bans(cache_http: impl CacheHttp, data: AppData) { async fn lift_soft_ban_for( cache_http: impl CacheHttp, - server_id: GuildId, + server_id: id::GuildId, server_name: &str, - ban_role: RoleId, - user_id: UserId, + ban_role: id::RoleId, + user_id: id::UserId, ) -> Result<()> { let m = server_id.member(&cache_http, user_id).await?; println!( diff --git a/youmubot-core/src/fun/mod.rs b/youmubot-core/src/fun/mod.rs index aacf08f..7c838b2 100644 --- a/youmubot-core/src/fun/mod.rs +++ b/youmubot-core/src/fun/mod.rs @@ -7,7 +7,7 @@ use serenity::{ macros::{command, group}, Args, CommandResult, }, - model::{channel::Message, id::UserId}, + model::channel::Message, utils::MessageBuilder, }; use youmubot_prelude::*; @@ -159,7 +159,7 @@ async fn name(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let user_id = if args.is_empty() { msg.author.id } else { - args.single::()? + args.single::()?.0 }; let user_mention = if user_id == msg.author.id { diff --git a/youmubot-osu/src/discord/app_commands.rs b/youmubot-osu/src/discord/app_commands.rs index 1023d00..4062e0a 100644 --- a/youmubot-osu/src/discord/app_commands.rs +++ b/youmubot-osu/src/discord/app_commands.rs @@ -25,7 +25,7 @@ async fn check + Sync>( ctx.author(), None, osu_id, - member, + member.map(|m| m.user.id), mods, style, ) diff --git a/youmubot-osu/src/discord/mod.rs b/youmubot-osu/src/discord/mod.rs index b1687d7..43bb3c3 100644 --- a/youmubot-osu/src/discord/mod.rs +++ b/youmubot-osu/src/discord/mod.rs @@ -8,14 +8,14 @@ use crate::{ }; use rand::seq::IteratorRandom; use serenity::{ - all::{ChannelId, Member, Mention}, + all::ChannelId, builder::{CreateMessage, EditMessage}, collector, framework::standard::{ macros::{command, group}, Args, CommandResult, }, - model::channel::Message, + model::{channel::Message, id}, utils::MessageBuilder, }; use std::{str::FromStr, sync::Arc}; @@ -329,14 +329,7 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult pub async fn forcesave(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; let osu = data.get::().unwrap(); - let target = match args.single::()? { - Mention::User(id) => id, - m => { - msg.reply(&ctx, format!("Expected user_id, got {}", m)) - .await?; - return Ok(()); - } - }; + let target = args.single::()?.0; let username = args.quoted().trimmed().single::()?; let user: Option = osu @@ -619,61 +612,29 @@ pub async fn last(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let data = ctx.data.read().await; let env = data.get::().unwrap(); - let bm = load_beatmap(&env, Some(msg), msg.channel_id).await; - match bm { - None => { - msg.reply(&ctx, "No beatmap queried on this channel.") - .await?; - } - Some((bm, mods_def)) => { - let mods = args.find::().ok().or(mods_def).unwrap_or(Mods::NOMOD); - let b = &bm.0; - let m = bm.1; - let style = args - .single::() - .unwrap_or(ScoreListStyle::Grid); - let username_arg = args.single::().ok(); - let user_id = match username_arg.as_ref() { - Some(UsernameArg::Tagged(v)) => Some(*v), - None => Some(msg.author.id), - _ => None, - }; - let user = to_user_id_query(username_arg, &env, &msg.author).await?; - - let osu = data.get::().unwrap(); - - let user = osu - .user(user, |f| f) - .await? - .ok_or_else(|| Error::msg("User not found"))?; - let mut scores = osu - .scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m)) - .await? - .into_iter() - .filter(|s| s.mods.contains(mods)) - .collect::>(); - scores.sort_by(|a, b| b.pp.unwrap().partial_cmp(&a.pp.unwrap()).unwrap()); - - if scores.is_empty() { - msg.reply(&ctx, "No scores found").await?; - return Ok(()); - } - - if let Some(user_id) = user_id { - // Save to database - data.get::() - .unwrap() - .save(user_id, m, scores.clone()) - .await - .pls_ok(); - } - - display::scores::display_scores(style, scores, m, ctx, msg.clone(), msg.channel_id) - .await?; - } - } + let mods = args.find::().ok(); + let style = args.single::().ok(); + let username_arg = args.single::().ok(); + let (osu_id, member) = match username_arg { + Some(UsernameArg::Tagged(user_id)) => (None, Some(user_id)), + Some(UsernameArg::Raw(s)) => (Some(s), None), + None => (None, None), + }; + check_impl( + env, + ctx, + msg.clone(), + msg.channel_id, + &msg.author, + Some(msg), + osu_id, + member, + mods, + style, + ) + .await?; Ok(()) } @@ -685,49 +646,48 @@ pub async fn check_impl( sender: &serenity::all::User, msg: Option<&Message>, osu_id: Option, - member: Option, + member: Option, mods: Option, style: Option, ) -> CommandResult { let bm = load_beatmap(&env, msg, channel_id).await; - match bm { + let BeatmapWithMode(b, m) = match bm { + Some((bm, _)) => bm, None => { reply .reply(&ctx, "No beatmap queried on this channel.") .await?; + return Ok(()); } - Some((bm, mods_def)) => { - let mods = mods.unwrap_or_default(); - let b = &bm.0; - let m = bm.1; - let style = style.unwrap_or_default(); - let username_arg = member - .map(|m| UsernameArg::Tagged(m.user.id)) - .or(osu_id.map(|id| UsernameArg::Raw(id))); - let user = to_user_id_query(username_arg, env, sender).await?; + }; - let user = env - .client - .user(user, |f| f) - .await? - .ok_or_else(|| Error::msg("User not found"))?; - let mut scores = env - .client - .scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m)) - .await? - .into_iter() - .filter(|s| s.mods.contains(mods)) - .collect::>(); - scores.sort_by(|a, b| b.pp.unwrap().partial_cmp(&a.pp.unwrap()).unwrap()); + let mods = mods.unwrap_or_default(); + let style = style.unwrap_or_default(); + let username_arg = member + .map(|m| UsernameArg::Tagged(m)) + .or(osu_id.map(|id| UsernameArg::Raw(id))); + let user = to_user_id_query(username_arg, env, sender).await?; - if scores.is_empty() { - reply.reply(&ctx, "No scores found").await?; - return Ok(()); - } - display::scores::display_scores(style, scores, m, ctx, reply, channel_id).await?; - } + let user = env + .client + .user(user, |f| f) + .await? + .ok_or_else(|| Error::msg("User not found"))?; + let mut scores = env + .client + .scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m)) + .await? + .into_iter() + .filter(|s| s.mods.contains(mods)) + .collect::>(); + scores.sort_by(|a, b| b.pp.unwrap().partial_cmp(&a.pp.unwrap()).unwrap()); + + if scores.is_empty() { + reply.reply(&ctx, "No scores found").await?; + return Ok(()); } + display::scores::display_scores(style, scores, m, ctx, reply, channel_id).await?; Ok(()) } diff --git a/youmubot-prelude/Cargo.toml b/youmubot-prelude/Cargo.toml index 2bbc8f0..f229ca0 100644 --- a/youmubot-prelude/Cargo.toml +++ b/youmubot-prelude/Cargo.toml @@ -17,6 +17,7 @@ reqwest = { version = "0.11.10", features = ["json"] } chrono = "0.4.19" flume = "0.10.13" dashmap = "5.3.4" +thiserror = "1" poise = "0.6" [dependencies.serenity] diff --git a/youmubot-prelude/src/args.rs b/youmubot-prelude/src/args.rs index f4a02e0..f9aa5c7 100644 --- a/youmubot-prelude/src/args.rs +++ b/youmubot-prelude/src/args.rs @@ -1,4 +1,5 @@ pub use duration::Duration; +pub use ids::*; pub use username_arg::UsernameArg; mod duration { @@ -181,6 +182,73 @@ mod duration { } } +mod ids { + use serenity::{model::id, utils}; + use std::str::FromStr; + + use super::ParseError; + + /// An `UserId` parsed the old way. + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct UserId(pub id::UserId); + + impl FromStr for UserId { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + utils::parse_user_mention(s) + .map(UserId) + .ok_or(ParseError::InvalidId) + } + } + + impl AsRef for UserId { + fn as_ref(&self) -> &id::UserId { + &self.0 + } + } + + /// An `ChannelId` parsed the old way. + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct ChannelId(pub id::ChannelId); + + impl FromStr for ChannelId { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + utils::parse_channel_mention(s) + .map(ChannelId) + .ok_or(ParseError::InvalidId) + } + } + + impl AsRef for ChannelId { + fn as_ref(&self) -> &id::ChannelId { + &self.0 + } + } + + /// An `RoleId` parsed the old way. + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct RoleId(pub id::RoleId); + + impl FromStr for RoleId { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + utils::parse_role_mention(s) + .map(RoleId) + .ok_or(ParseError::InvalidId) + } + } + + impl AsRef for RoleId { + fn as_ref(&self) -> &id::RoleId { + &self.0 + } + } +} + mod username_arg { use serenity::model::id::UserId; use std::str::FromStr; @@ -193,8 +261,8 @@ mod username_arg { impl FromStr for UsernameArg { type Err = String; fn from_str(s: &str) -> Result { - match s.parse::() { - Ok(v) => Ok(UsernameArg::Tagged(v)), + match s.parse::() { + Ok(v) => Ok(UsernameArg::Tagged(v.0)), Err(_) if !s.is_empty() => Ok(UsernameArg::Raw(s.to_owned())), Err(_) => Err("username arg cannot be empty".to_owned()), } @@ -208,3 +276,9 @@ mod username_arg { } } } + +#[derive(Debug, thiserror::Error)] +pub enum ParseError { + #[error("invalid id format")] + InvalidId, +} diff --git a/youmubot-prelude/src/lib.rs b/youmubot-prelude/src/lib.rs index e4950a7..9af9c3e 100644 --- a/youmubot-prelude/src/lib.rs +++ b/youmubot-prelude/src/lib.rs @@ -15,7 +15,7 @@ pub mod replyable; pub mod setup; pub use announcer::{Announcer, AnnouncerHandler}; -pub use args::{Duration, UsernameArg}; +pub use args::{ChannelId, Duration, RoleId, UserId, UsernameArg}; pub use flags::Flags; pub use hook::Hook; pub use member_cache::MemberCache;