Fix User/Role/Channel id parsing (#35)

* Import thiserror to make nicer errors

* Create wrappers for id that exposes old parsing behavior (from mentions)

* Use id wrappers when parsing parameters
This commit is contained in:
Natsu Kagami 2024-02-28 18:02:42 +00:00 committed by GitHub
parent 431f295b41
commit b4cf6ce94f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 91 additions and 21 deletions

1
Cargo.lock generated
View file

@ -3210,6 +3210,7 @@ dependencies = [
"futures-util",
"reqwest",
"serenity",
"thiserror",
"tokio",
"youmubot-db",
"youmubot-db-sql",

View file

@ -5,10 +5,7 @@ use serenity::{
macros::{command, group},
Args, CommandResult,
},
model::{
channel::{Channel, Message},
id::UserId,
},
model::channel::{Channel, Message},
};
use soft_ban::{SOFT_BAN_COMMAND, SOFT_BAN_INIT_COMMAND};
use youmubot_prelude::*;
@ -69,7 +66,7 @@ async fn clean(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
#[max_args(2)]
#[only_in("guilds")]
async fn ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user = args.single::<UserId>()?.to_user(&ctx).await?;
let user = args.single::<UserId>()?.0.to_user(&ctx).await?;
let reason = args.single::<String>().map(|v| format!("`{}`", v)).ok();
let dmds = args.single::<u8>().unwrap_or(0);
@ -105,7 +102,7 @@ async fn ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
#[num_args(1)]
#[only_in("guilds")]
async fn kick(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user = args.single::<UserId>()?.to_user(&ctx).await?;
let user = args.single::<UserId>()?.0.to_user(&ctx).await?;
msg.reply(&ctx, format!("🔫 Kicking user {}.", user.tag()))
.await?;

View file

@ -3,10 +3,7 @@ use chrono::offset::Utc;
use futures_util::{stream, TryStreamExt};
use serenity::{
framework::standard::{macros::command, Args, CommandResult},
model::{
channel::Message,
id::{GuildId, RoleId, UserId},
},
model::{channel::Message, id},
};
use youmubot_prelude::*;
@ -19,7 +16,7 @@ use youmubot_prelude::*;
#[max_args(2)]
#[only_in("guilds")]
pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user = args.single::<UserId>()?.to_user(&ctx).await?;
let user = args.single::<UserId>()?.0.to_user(&ctx).await?;
let data = ctx.data.read().await;
let duration = if args.is_empty() {
None
@ -81,7 +78,7 @@ pub async fn soft_ban(ctx: &Context, msg: &Message, mut args: Args) -> CommandRe
#[num_args(1)]
#[only_in("guilds")]
pub async fn soft_ban_init(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let role_id = args.single::<RoleId>()?;
let role_id = args.single::<RoleId>()?.0;
let data = ctx.data.read().await;
let guild = msg.guild_id.unwrap().to_partial_guild(&ctx).await?;
// Check whether the role_id is the one we wanted
@ -152,10 +149,10 @@ pub async fn watch_soft_bans(cache_http: impl CacheHttp, data: AppData) {
async fn lift_soft_ban_for(
cache_http: impl CacheHttp,
server_id: GuildId,
server_id: id::GuildId,
server_name: &str,
ban_role: RoleId,
user_id: UserId,
ban_role: id::RoleId,
user_id: id::UserId,
) -> Result<()> {
let m = server_id.member(&cache_http, user_id).await?;
println!(

View file

@ -7,7 +7,7 @@ use serenity::{
macros::{command, group},
Args, CommandResult,
},
model::{channel::Message, id::UserId},
model::channel::Message,
utils::MessageBuilder,
};
use youmubot_prelude::*;
@ -159,7 +159,7 @@ async fn name(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let user_id = if args.is_empty() {
msg.author.id
} else {
args.single::<UserId>()?
args.single::<UserId>()?.0
};
let user_mention = if user_id == msg.author.id {

View file

@ -289,7 +289,7 @@ pub async fn save(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult
pub async fn forcesave(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let data = ctx.data.read().await;
let osu = data.get::<OsuClient>().unwrap();
let target = args.single::<serenity::model::id::UserId>()?;
let target = args.single::<UserId>()?.0;
let username = args.quoted().trimmed().single::<String>()?;
let user: Option<User> = osu

View file

@ -17,6 +17,7 @@ reqwest = { version = "0.11.10", features = ["json"] }
chrono = "0.4.19"
flume = "0.10.13"
dashmap = "5.3.4"
thiserror = "1"
[dependencies.serenity]
version = "0.12"

View file

@ -1,4 +1,5 @@
pub use duration::Duration;
pub use ids::*;
pub use username_arg::UsernameArg;
mod duration {
@ -181,6 +182,73 @@ mod duration {
}
}
mod ids {
use serenity::{model::id, utils};
use std::str::FromStr;
use super::ParseError;
/// An `UserId` parsed the old way.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UserId(pub id::UserId);
impl FromStr for UserId {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
utils::parse_user_mention(s)
.map(UserId)
.ok_or(ParseError::InvalidId)
}
}
impl AsRef<id::UserId> for UserId {
fn as_ref(&self) -> &id::UserId {
&self.0
}
}
/// An `ChannelId` parsed the old way.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelId(pub id::ChannelId);
impl FromStr for ChannelId {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
utils::parse_channel_mention(s)
.map(ChannelId)
.ok_or(ParseError::InvalidId)
}
}
impl AsRef<id::ChannelId> for ChannelId {
fn as_ref(&self) -> &id::ChannelId {
&self.0
}
}
/// An `RoleId` parsed the old way.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RoleId(pub id::RoleId);
impl FromStr for RoleId {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
utils::parse_role_mention(s)
.map(RoleId)
.ok_or(ParseError::InvalidId)
}
}
impl AsRef<id::RoleId> for RoleId {
fn as_ref(&self) -> &id::RoleId {
&self.0
}
}
}
mod username_arg {
use serenity::model::id::UserId;
use std::str::FromStr;
@ -193,8 +261,8 @@ mod username_arg {
impl FromStr for UsernameArg {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.parse::<UserId>() {
Ok(v) => Ok(UsernameArg::Tagged(v)),
match s.parse::<super::UserId>() {
Ok(v) => Ok(UsernameArg::Tagged(v.0)),
Err(_) if !s.is_empty() => Ok(UsernameArg::Raw(s.to_owned())),
Err(_) => Err("username arg cannot be empty".to_owned()),
}
@ -208,3 +276,9 @@ mod username_arg {
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
#[error("invalid id format")]
InvalidId,
}

View file

@ -13,7 +13,7 @@ pub mod ratelimit;
pub mod setup;
pub use announcer::{Announcer, AnnouncerHandler};
pub use args::{Duration, UsernameArg};
pub use args::{ChannelId, Duration, RoleId, UserId, UsernameArg};
pub use flags::Flags;
pub use hook::Hook;
pub use member_cache::MemberCache;