Refactor out link parsers

This commit is contained in:
Natsu Kagami 2024-07-13 22:14:32 +02:00
parent 48635cad78
commit 1d250b8ea7
Signed by: nki
GPG key ID: 55A032EB38B49ADB
3 changed files with 212 additions and 282 deletions

View file

@ -1,14 +1,12 @@
use std::str::FromStr;
use std::sync::Arc;
use futures_util::stream::FuturesOrdered;
use lazy_static::lazy_static;
use pagination::paginate_from_fn;
use regex::Regex;
use serenity::{
all::EditMessage, builder::CreateMessage, model::channel::Message, utils::MessageBuilder,
};
use stream::Stream;
use youmubot_prelude::*;
use crate::discord::embeds::score_embed;
@ -19,24 +17,7 @@ use crate::{
};
use super::embeds::beatmap_embed;
lazy_static! {
// Beatmap(set) hooks
// Old-site links: `osu.ppy.sh/s/<id>` or `osu.ppy.sh/b/<id>` (named group `link_type` is `s` or `b`),
// optionally followed by `&m=<digit>` / `?m=<digit>` (group `mode`) and a `+MODS` suffix of
// uppercase letters (group `mods`). The scheme prefix is optional.
pub(crate) static ref OLD_LINK_REGEX: Regex = Regex::new(
r"(?:https?://)?osu\.ppy\.sh/(?P<link_type>s|b)/(?P<id>\d+)(?:[\&\?]m=(?P<mode>\d))?(?:\+(?P<mods>[A-Z]+))?"
).unwrap();
// New-site links: `osu.ppy.sh/beatmapsets/<set_id>`, optionally followed by
// `#<mode>` or `#<mode>/<beatmap_id>` where mode is one of osu|taiko|fruits|mania,
// and an optional `+MODS` suffix.
pub(crate) static ref NEW_LINK_REGEX: Regex = Regex::new(
r"(?:https?://)?osu\.ppy\.sh/beatmapsets/(?P<set_id>\d+)/?(?:\#(?P<mode>osu|taiko|fruits|mania)(?:/(?P<beatmap_id>\d+)|/?))?(?:\+(?P<mods>[A-Z]+))?"
).unwrap();
// Short links: a bare `/b/<id>` with optional `/<mode>` and `+MODS`, which must sit at the start
// of the text or directly after a whitespace / non-word character. Group `main` covers the link
// itself without that leading character.
pub(crate) static ref SHORT_LINK_REGEX: Regex = Regex::new(
r"(?:^|\s|\W)(?P<main>/b/(?P<id>\d+)(?:/(?P<mode>osu|taiko|fruits|mania))?(?:\+(?P<mods>[A-Z]+))?)"
).unwrap();
// Score hook
// Score links: `osu.ppy.sh/scores/<score_id>`.
pub(crate) static ref SCORE_LINK_REGEX: Regex = Regex::new(
r"(?:https?://)?osu\.ppy\.sh/scores/(?P<score_id>\d+)"
).unwrap();
}
use super::link_parser::*;
/// React to /scores/{id} links.
pub fn score_hook<'a>(
@ -242,12 +223,29 @@ pub fn hook<'a>(
if msg.author.bot {
return Ok(());
}
let (old_links, new_links, short_links) = (
handle_old_links(ctx, &msg.content),
handle_new_links(ctx, &msg.content),
handle_short_links(ctx, msg, &msg.content),
let env = ctx.data.read().await.get::<OsuEnv>().unwrap().clone();
let (old_links, new_links) = (
parse_old_links(&env, &msg.content),
parse_new_links(&env, &msg.content),
);
stream::select(old_links, stream::select(new_links, short_links))
let to_join: Box<dyn Stream<Item = _> + Unpin + Send> = {
let use_short_link = if let Some(guild_id) = msg.guild_id {
announcer::announcer_of(ctx, crate::discord::announcer::ANNOUNCER_KEY, guild_id)
.await?
== Some(msg.channel_id)
} else {
false
};
if use_short_link {
Box::new(stream::select(
old_links,
stream::select(new_links, parse_short_links(&env, &msg.content)),
))
} else {
Box::new(stream::select(old_links, new_links))
}
};
to_join
.then(|l| async move {
match l.embed {
EmbedType::Beatmap(b, info, mods) => {
@ -277,225 +275,6 @@ pub fn hook<'a>(
})
}
/// What a parsed link resolved to, i.e. the payload the hook will render.
enum EmbedType {
/// A single beatmap, the PP information computed for it, and the mods taken from the link
/// (`Mods::NOMOD` when the link carried none).
Beatmap(Box<Beatmap>, BeatmapInfoWithPP, Mods),
/// All beatmaps returned for a beatmapset link.
Beatmapset(Vec<Beatmap>),
}
/// One resolved link, ready to be turned into a message by the hook.
struct ToPrint<'a> {
/// The resolved beatmap or beatmapset.
embed: EmbedType,
/// The matched link text, borrowed from the message content it was found in.
link: &'a str,
/// The game mode named by the link, if it named one.
mode: Option<Mode>,
}
/// Resolves every old-site beatmap (`/b/<id>`) and beatmapset (`/s/<id>`) link found in
/// `content` into a [`ToPrint`].
///
/// Each match is resolved by its own future and all of them are collected into a
/// `FuturesUnordered`, so items are yielded as the lookups finish rather than in
/// the order the links appear. A link whose lookup fails has its error printed to
/// stderr and is skipped; a beatmapset link that resolves to no beatmaps is skipped
/// silently.
fn handle_old_links<'a>(
    ctx: &'a Context,
    content: &'a str,
) -> impl stream::Stream<Item = ToPrint<'a>> + 'a {
    OLD_LINK_REGEX
        .captures_iter(content)
        .map(move |capture| async move {
            // Clone the environment out of the type map (as the sibling link handlers do)
            // so the `ctx.data` read guard is released at the end of this statement
            // instead of being held across the API awaits below.
            let env = ctx.data.read().await.get::<OsuEnv>().unwrap().clone();
            let req_type = capture.name("link_type").unwrap().as_str();
            let link = capture.get(0).unwrap().as_str();
            // `m=<digit>`: 0..=3 select a mode, any other digit means "no explicit mode".
            let mode = capture
                .name("mode")
                .map(|v| v.as_str().parse())
                .transpose()?
                .and_then(|v| match v {
                    0 => Some(Mode::Std),
                    1 => Some(Mode::Taiko),
                    2 => Some(Mode::Catch),
                    3 => Some(Mode::Mania),
                    _ => None,
                });
            let embed = match req_type {
                "b" => {
                    let beatmap = match mode {
                        Some(mode) => {
                            env.beatmaps
                                .get_beatmap(capture["id"].parse()?, mode)
                                .await?
                        }
                        None => {
                            env.beatmaps
                                .get_beatmap_default(capture["id"].parse()?)
                                .await?
                        }
                    };
                    // collect beatmap info; unparsable mods fall back to NOMOD
                    let mods = capture
                        .name("mods")
                        .and_then(|v| Mods::from_str(v.as_str()).pls_ok())
                        .unwrap_or(Mods::NOMOD);
                    let info = {
                        // Without an explicit mode, compute PP for the beatmap's own mode.
                        let mode = mode.unwrap_or(beatmap.mode);
                        env.oppai
                            .get_beatmap(beatmap.beatmap_id)
                            .await
                            .and_then(|b| b.get_possible_pp_with(mode, mods))?
                    };
                    EmbedType::Beatmap(Box::new(beatmap), info, mods)
                }
                "s" => {
                    let beatmaps = env.beatmaps.get_beatmapset(capture["id"].parse()?).await?;
                    if beatmaps.is_empty() {
                        // Nothing to show for this set.
                        return Ok(None);
                    }
                    EmbedType::Beatmapset(beatmaps)
                }
                // The regex only admits `s` or `b` for `link_type`.
                _ => unreachable!(),
            };
            let r: Result<_> = Ok(Some(ToPrint { embed, link, mode }));
            r
        })
        .collect::<stream::FuturesUnordered<_>>()
        .filter_map(|v| {
            future::ready(v.unwrap_or_else(|e| {
                eprintln!("{}", e);
                None
            }))
        })
}
/// Resolves every `osu.ppy.sh/beatmapsets/...` link found in `content` into a [`ToPrint`].
///
/// A link that names a beatmap id becomes an `EmbedType::Beatmap` (with PP info for the
/// link's mode, or the beatmap's own mode when the link has none); a link with only a
/// set id becomes an `EmbedType::Beatmapset`. Failed lookups are printed to stderr and
/// dropped, and a set that resolves to no beatmaps is dropped silently.
fn handle_new_links<'a>(
    ctx: &'a Context,
    content: &'a str,
) -> impl stream::Stream<Item = ToPrint<'a>> + 'a {
    NEW_LINK_REGEX
        .captures_iter(content)
        .map(|caps| async move {
            let env = ctx.data.read().await.get::<OsuEnv>().unwrap().clone();
            let mode = caps
                .name("mode")
                .and_then(|m| Mode::parse_from_new_site(m.as_str()));
            let link = caps.get(0).unwrap().as_str();
            let embed = match caps.name("beatmap_id") {
                Some(raw_id) => {
                    let id: u64 = raw_id.as_str().parse()?;
                    let fetched = match mode {
                        Some(mode) => env.beatmaps.get_beatmap(id, mode).await,
                        None => env.beatmaps.get_beatmap_default(id).await,
                    };
                    let beatmap = Box::new(fetched?);
                    // collect beatmap info
                    let mods = caps
                        .name("mods")
                        .and_then(|m| Mods::from_str(m.as_str()).pls_ok())
                        .unwrap_or(Mods::NOMOD);
                    let pp_mode = mode.unwrap_or(beatmap.mode);
                    let info = env
                        .oppai
                        .get_beatmap(beatmap.beatmap_id)
                        .await
                        .and_then(|b| b.get_possible_pp_with(pp_mode, mods))?;
                    EmbedType::Beatmap(beatmap, info, mods)
                }
                None => {
                    let set_id: u64 = caps.name("set_id").unwrap().as_str().parse()?;
                    let set = env.beatmaps.get_beatmapset(set_id).await?;
                    if set.is_empty() {
                        return Ok(None);
                    }
                    EmbedType::Beatmapset(set)
                }
            };
            let outcome: Result<_> = Ok(Some(ToPrint { embed, link, mode }));
            outcome
        })
        .collect::<stream::FuturesUnordered<_>>()
        .filter_map(|res| {
            future::ready(res.unwrap_or_else(|e| {
                eprintln!("{}", e);
                None
            }))
        })
}
/// Resolves every short `/b/<id>` link found in `content` into a beatmap [`ToPrint`].
///
/// When `msg` was sent in a guild, short links are only honoured if the message's
/// channel is that guild's announcer channel; otherwise the link is rejected with an
/// error. Every error (including that rejection) is printed to stderr and the link is
/// dropped from the stream.
fn handle_short_links<'a>(
    ctx: &'a Context,
    msg: &'a Message,
    content: &'a str,
) -> impl stream::Stream<Item = ToPrint<'a>> + 'a {
    SHORT_LINK_REGEX
        .captures_iter(content)
        .map(|caps| async move {
            if let Some(guild_id) = msg.guild_id {
                let announcer_channel = announcer::announcer_of(
                    ctx,
                    crate::discord::announcer::ANNOUNCER_KEY,
                    guild_id,
                )
                .await?;
                if announcer_channel != Some(msg.channel_id) {
                    // Disable if we are not in the server's announcer channel
                    return Err(Error::msg("not in server announcer channel"));
                }
            }
            let env = ctx.data.read().await.get::<OsuEnv>().unwrap().clone();
            let mode = caps
                .name("mode")
                .and_then(|m| Mode::parse_from_new_site(m.as_str()));
            let id: u64 = caps.name("id").unwrap().as_str().parse()?;
            let fetched = match mode {
                Some(mode) => env.beatmaps.get_beatmap(id, mode).await,
                None => env.beatmaps.get_beatmap_default(id).await,
            };
            let beatmap = fetched?;
            let mods = caps
                .name("mods")
                .and_then(|m| Mods::from_str(m.as_str()).pls_ok())
                .unwrap_or(Mods::NOMOD);
            // Fall back to the beatmap's own mode when the link did not name one.
            let pp_mode = mode.unwrap_or(beatmap.mode);
            let info = env
                .oppai
                .get_beatmap(beatmap.beatmap_id)
                .await
                .and_then(|b| b.get_possible_pp_with(pp_mode, mods))?;
            let outcome: Result<_> = Ok(ToPrint {
                embed: EmbedType::Beatmap(Box::new(beatmap), info, mods),
                link: caps.name("main").unwrap().as_str(),
                mode,
            });
            outcome
        })
        .collect::<stream::FuturesUnordered<_>>()
        .filter_map(|res| future::ready(res.map_err(|e| eprintln!("{}", e)).ok()))
}
async fn handle_beatmap<'a, 'b>(
ctx: &Context,
beatmap: &Beatmap,

View file

@ -0,0 +1,160 @@
use std::str::FromStr;
use crate::models::*;
use lazy_static::lazy_static;
use regex::Regex;
use stream::Stream;
use youmubot_prelude::*;
use super::{oppai_cache::BeatmapInfoWithPP, OsuEnv};
/// What a parsed link resolved to.
pub enum EmbedType {
/// A single beatmap, the PP information computed for it, and the mods that were applied
/// (`Mods::NOMOD` when the link carried none).
Beatmap(Box<Beatmap>, BeatmapInfoWithPP, Mods),
/// All beatmaps returned for a beatmapset id.
Beatmapset(Vec<Beatmap>),
}
/// One link found in a piece of text together with what it resolved to.
pub struct ToPrint<'a> {
/// The resolved beatmap or beatmapset.
pub embed: EmbedType,
/// The matched link text, borrowed from the content that was parsed.
pub link: &'a str,
/// The game mode named by the link, if it named one.
pub mode: Option<Mode>,
}
lazy_static! {
// Beatmap(set) hooks
// Old-site links: `osu.ppy.sh/s/<id>` or `osu.ppy.sh/b/<id>` (named group `link_type` is `s` or `b`),
// optionally followed by `&m=<0-3>` / `?m=<0-3>` (group `mode`) and a `+MODS` suffix of
// uppercase letters (group `mods`). The scheme prefix is optional.
static ref OLD_LINK_REGEX: Regex = Regex::new(
r"(?:https?://)?osu\.ppy\.sh/(?P<link_type>s|b)/(?P<id>\d+)(?:[\&\?]m=(?P<mode>[0123]))?(?:\+(?P<mods>[A-Z]+))?"
).unwrap();
// New-site links: `osu.ppy.sh/beatmapsets/<set_id>`, optionally followed by
// `#<mode>` or `#<mode>/<beatmap_id>` where mode is one of osu|taiko|fruits|mania,
// and an optional `+MODS` suffix.
static ref NEW_LINK_REGEX: Regex = Regex::new(
r"(?:https?://)?osu\.ppy\.sh/beatmapsets/(?P<set_id>\d+)/?(?:\#(?P<mode>osu|taiko|fruits|mania)(?:/(?P<beatmap_id>\d+)|/?))?(?:\+(?P<mods>[A-Z]+))?"
).unwrap();
// Short links: a bare `/b/<id>` with optional `/<mode>` and `+MODS`, which must sit at the start
// of the text or directly after a whitespace / non-word character. Group `main` covers the link
// itself without that leading character.
static ref SHORT_LINK_REGEX: Regex = Regex::new(
r"(?:^|\s|\W)(?P<main>/b/(?P<id>\d+)(?:/(?P<mode>osu|taiko|fruits|mania))?(?:\+(?P<mods>[A-Z]+))?)"
).unwrap();
// Score hook
// Score links: `osu.ppy.sh/scores/<score_id>`. This is the only pattern visible outside this module.
pub(crate) static ref SCORE_LINK_REGEX: Regex = Regex::new(
r"(?:https?://)?osu\.ppy\.sh/scores/(?P<score_id>\d+)"
).unwrap();
}
/// Finds every old-site `osu.ppy.sh/{s,b}/<id>` link in `content` and resolves it.
///
/// `b` links resolve to a single beatmap (honouring the optional `m=` mode and `+MODS`
/// suffix); `s` links resolve to a beatmapset. Links that fail to resolve are passed
/// through `pls_ok()` and dropped from the stream.
pub fn parse_old_links<'a>(
    env: &'a OsuEnv,
    content: &'a str,
) -> impl Stream<Item = ToPrint<'a>> + 'a {
    OLD_LINK_REGEX
        .captures_iter(content)
        .map(move |caps| async move {
            let link = caps.get(0).unwrap().as_str();
            let kind = caps.name("link_type").unwrap().as_str();
            // The regex restricts the mode digit to 0..=3 before it reaches `Mode::from`.
            let mode = match caps.name("mode") {
                Some(m) => Some(Mode::from(m.as_str().parse::<u8>()?)),
                None => None,
            };
            let resolved = match kind {
                "b" => {
                    // collect beatmap info
                    let mods = caps
                        .name("mods")
                        .and_then(|m| Mods::from_str(m.as_str()).pls_ok())
                        .unwrap_or(Mods::NOMOD);
                    EmbedType::from_beatmap_id(env, caps["id"].parse()?, mode, mods).await
                }
                "s" => EmbedType::from_beatmapset_id(env, caps["id"].parse()?).await,
                // `link_type` can only be `s` or `b`.
                _ => unreachable!(),
            };
            let embed = resolved?;
            Ok(ToPrint { embed, link, mode })
        })
        .collect::<stream::FuturesUnordered<_>>()
        .filter_map(|res: Result<ToPrint>| future::ready(res.pls_ok()))
}
/// Finds every `osu.ppy.sh/beatmapsets/...` link in `content` and resolves it.
///
/// A link that names a beatmap id resolves to a single beatmap (honouring the optional
/// mode and `+MODS` suffix); a link with only a set id resolves to the beatmapset.
/// Links that fail to parse or resolve are passed through `pls_ok()` and dropped from
/// the stream.
pub fn parse_new_links<'a>(
    env: &'a OsuEnv,
    content: &'a str,
) -> impl Stream<Item = ToPrint<'a>> + 'a {
    NEW_LINK_REGEX
        .captures_iter(content)
        .map(|capture| async move {
            let mode = capture
                .name("mode")
                .and_then(|v| Mode::parse_from_new_site(v.as_str()));
            let link = capture.get(0).unwrap().as_str();
            // `\d+` is unbounded, so a user-supplied id can overflow `u64`; propagate the
            // parse error (dropping just this link) instead of panicking.
            let beatmap_id = capture
                .name("beatmap_id")
                .map(|v| v.as_str().parse::<u64>())
                .transpose()?;
            let embed = match beatmap_id {
                Some(beatmap_id) => {
                    // Unparsable mods fall back to NOMOD.
                    let mods = capture
                        .name("mods")
                        .and_then(|v| Mods::from_str(v.as_str()).pls_ok())
                        .unwrap_or(Mods::NOMOD);
                    EmbedType::from_beatmap_id(env, beatmap_id, mode, mods).await
                }
                None => {
                    EmbedType::from_beatmapset_id(
                        env,
                        capture.name("set_id").unwrap().as_str().parse()?,
                    )
                    .await
                }
            }?;
            Ok(ToPrint { embed, link, mode })
        })
        .collect::<stream::FuturesUnordered<_>>()
        .filter_map(|v: Result<ToPrint>| future::ready(v.pls_ok()))
}
/// Finds every short `/b/<id>[/<mode>][+MODS]` link in `content` and resolves each one
/// to a single beatmap.
///
/// The reported `link` is the `main` capture, i.e. the link without the separator
/// character that precedes it. Links that fail to parse or resolve are passed through
/// `pls_ok()` and dropped from the stream.
pub fn parse_short_links<'a>(
    env: &'a OsuEnv,
    content: &'a str,
) -> impl Stream<Item = ToPrint<'a>> + 'a {
    SHORT_LINK_REGEX
        .captures_iter(content)
        .map(|caps| async move {
            let link = caps.name("main").unwrap().as_str();
            let mode = caps
                .name("mode")
                .and_then(|m| Mode::parse_from_new_site(m.as_str()));
            let id: u64 = caps.name("id").unwrap().as_str().parse()?;
            // Unparsable mods fall back to NOMOD.
            let mods = caps
                .name("mods")
                .and_then(|m| Mods::from_str(m.as_str()).pls_ok())
                .unwrap_or(Mods::NOMOD);
            let embed = EmbedType::from_beatmap_id(env, id, mode, mods).await?;
            Ok(ToPrint { embed, link, mode })
        })
        .collect::<stream::FuturesUnordered<_>>()
        .filter_map(|res: Result<ToPrint>| future::ready(res.pls_ok()))
}
impl EmbedType {
    /// Builds an [`EmbedType::Beatmap`] for `beatmap_id`.
    ///
    /// The beatmap is fetched for `mode` when one is given, otherwise through the
    /// default lookup. PP information is then computed with `mods` for the requested
    /// mode, falling back to the beatmap's own mode.
    ///
    /// # Errors
    /// Returns the error of whichever lookup or PP computation fails first.
    async fn from_beatmap_id(
        env: &OsuEnv,
        beatmap_id: u64,
        mode: Option<Mode>,
        mods: Mods,
    ) -> Result<Self> {
        let fetched = match mode {
            Some(requested) => env.beatmaps.get_beatmap(beatmap_id, requested).await,
            None => env.beatmaps.get_beatmap_default(beatmap_id).await,
        };
        let bm = fetched?;
        let pp_mode = mode.unwrap_or(bm.mode);
        let info = env
            .oppai
            .get_beatmap(bm.beatmap_id)
            .await
            .and_then(|cached| cached.get_possible_pp_with(pp_mode, mods))?;
        Ok(Self::Beatmap(Box::new(bm), info, mods))
    }

    /// Builds an [`EmbedType::Beatmapset`] from every beatmap returned for `beatmapset_id`.
    ///
    /// NOTE(review): the returned list is wrapped as-is, with no emptiness check —
    /// confirm that callers cope with an empty `Beatmapset`.
    ///
    /// # Errors
    /// Returns the error of the beatmapset lookup.
    async fn from_beatmapset_id(env: &OsuEnv, beatmapset_id: u64) -> Result<Self> {
        let beatmaps = env.beatmaps.get_beatmapset(beatmapset_id).await?;
        Ok(Self::Beatmapset(beatmaps))
    }
}

View file

@ -15,7 +15,6 @@ use serenity::{
use db::{OsuLastBeatmap, OsuSavedUsers, OsuUser, OsuUserBests};
use embeds::{beatmap_embed, score_embed, user_embed};
use hook::SHORT_LINK_REGEX;
pub use hook::{dot_osu_hook, hook, score_hook};
use server_rank::{SERVER_RANK_COMMAND, SHOW_LEADERBOARD_COMMAND};
use youmubot_prelude::announcer::AnnouncerHandler;
@ -37,6 +36,7 @@ mod db;
pub(crate) mod display;
pub(crate) mod embeds;
mod hook;
mod link_parser;
pub(crate) mod oppai_cache;
mod server_rank;
@ -548,43 +548,34 @@ pub(crate) async fn load_beatmap(
env: &OsuEnv,
msg: &Message,
) -> Option<(BeatmapWithMode, Option<Mods>)> {
use link_parser::{parse_short_links, EmbedType};
if let Some(replied) = &msg.referenced_message {
// Try to look for a mention of the replied message.
let beatmap_id = SHORT_LINK_REGEX.captures(&replied.content).or_else(|| {
replied.embeds.iter().find_map(|e| {
e.description
.as_ref()
.and_then(|v| SHORT_LINK_REGEX.captures(v))
.or_else(|| {
e.fields
.iter()
.find_map(|f| SHORT_LINK_REGEX.captures(&f.value))
})
})
});
if let Some(caps) = beatmap_id {
let id: u64 = caps.name("id").unwrap().as_str().parse().unwrap();
let mode = caps
.name("mode")
.and_then(|m| Mode::parse_from_new_site(m.as_str()));
let mods = caps
.name("mods")
.and_then(|m| m.as_str().parse::<Mods>().ok());
let osu_client = &env.client;
let bms = osu_client
.beatmaps(BeatmapRequestKind::Beatmap(id), |f| f.maybe_mode(mode))
.await
.ok()
.and_then(|v| v.into_iter().next());
if let Some(beatmap) = bms {
let bm_mode = beatmap.mode;
let bm = BeatmapWithMode(beatmap, mode.unwrap_or(bm_mode));
// Store the beatmap in history
cache::save_beatmap(&env, msg.channel_id, &bm)
.await
.pls_ok();
return Some((bm, mods));
async fn try_content(
env: &OsuEnv,
content: &str,
) -> Option<(BeatmapWithMode, Option<Mods>)> {
let tp = parse_short_links(env, content).next().await?;
match tp.embed {
EmbedType::Beatmap(b, _, mods) => {
let mode = tp.mode.unwrap_or(b.mode);
Some((BeatmapWithMode(*b, mode), Some(mods)))
}
_ => None,
}
}
if let Some(v) = try_content(env, &replied.content).await {
return Some(v);
}
for embed in &replied.embeds {
if let Some(desc) = &embed.description {
if let Some(v) = try_content(env, desc).await {
return Some(v);
}
}
for field in &embed.fields {
if let Some(v) = try_content(env, &field.value).await {
return Some(v);
}
}
}
}