diff --git a/youmubot-osu/src/discord/hook.rs b/youmubot-osu/src/discord/hook.rs index 3c99f0e..240a8f1 100644 --- a/youmubot-osu/src/discord/hook.rs +++ b/youmubot-osu/src/discord/hook.rs @@ -7,12 +7,7 @@ use crate::{ }; use lazy_static::lazy_static; use regex::Regex; -use serenity::{ - builder::CreateMessage, - framework::standard::{CommandError as Error, CommandResult}, - model::channel::Message, - utils::MessageBuilder, -}; +use serenity::{builder::CreateMessage, model::channel::Message, utils::MessageBuilder}; use std::str::FromStr; use youmubot_prelude::*; @@ -30,43 +25,49 @@ lazy_static! { ).unwrap(); } -pub fn hook(ctx: &mut Context, msg: &Message) -> () { +pub async fn hook(ctx: &Context, msg: &Message) -> Result<()> { if msg.author.bot { - return; + return Ok(()); } - let mut v = move || -> CommandResult { - let old_links = handle_old_links(ctx, &msg.content)?; - let new_links = handle_new_links(ctx, &msg.content)?; - let short_links = handle_short_links(ctx, &msg, &msg.content)?; - let mut last_beatmap = None; - for l in old_links - .into_iter() - .chain(new_links.into_iter()) - .chain(short_links.into_iter()) - { - if let Err(v) = msg.channel_id.send_message(&ctx, |m| match l.embed { - EmbedType::Beatmap(b, info, mods) => { - let t = handle_beatmap(&b, info, l.link, l.mode, mods, m); - let mode = l.mode.unwrap_or(b.mode); - last_beatmap = Some(super::BeatmapWithMode(b, mode)); - t + let (old_links, new_links, short_links) = ( + handle_old_links(ctx, &msg.content), + handle_new_links(ctx, &msg.content), + handle_short_links(ctx, &msg, &msg.content), + ); + let last_beatmap = stream::select(old_links, stream::select(new_links, short_links)) + .then(|l| async move { + let mut bm: Option = None; + msg.channel_id + .send_message(&ctx, |m| match l.embed { + EmbedType::Beatmap(b, info, mods) => { + let t = handle_beatmap(&b, info, l.link, l.mode, mods, m); + let mode = l.mode.unwrap_or(b.mode); + bm = Some(super::BeatmapWithMode(b, mode)); + t + } + EmbedType::Beatmapset(b) => handle_beatmapset(b, l.link, l.mode, m), + }) + .await?; + let r: Result<_> = Ok(bm); + r + }) + .filter_map(|v| async move { + match v { + Ok(v) => v, + Err(e) => { + eprintln!("{}", e); + None } - EmbedType::Beatmapset(b) => handle_beatmapset(b, l.link, l.mode, m), - }) { - println!("Error in osu! hook: {:?}", v) } - } - // Save the beatmap for query later. - if let Some(t) = last_beatmap { - if let Err(v) = super::cache::save_beatmap(&*ctx.data.read(), msg.channel_id, &t) { - dbg!(v); - } - } - Ok(()) - }; - if let Err(v) = v() { - println!("Error in osu! hook: {:?}", v) + }) + .fold(None, |_, v| async move { Some(v) }) + .await; + + // Save the beatmap for query later. + if let Some(t) = last_beatmap { + super::cache::save_beatmap(&*ctx.data.read().await, msg.channel_id, &t)?; } + Ok(()) } enum EmbedType { @@ -80,37 +81,47 @@ struct ToPrint<'a> { mode: Option, } -fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result>, Error> { - let osu = ctx.data.get_cloned::(); - let mut to_prints: Vec> = Vec::new(); - let cache = ctx.data.get_cloned::(); - for capture in OLD_LINK_REGEX.captures_iter(content) { - let req_type = capture.name("link_type").unwrap().as_str(); - let req = match req_type { - "b" => BeatmapRequestKind::Beatmap(capture["id"].parse()?), - "s" => BeatmapRequestKind::Beatmapset(capture["id"].parse()?), - _ => continue, - }; - let mode = capture - .name("mode") - .map(|v| v.as_str().parse()) - .transpose()? - .and_then(|v| { - Some(match v { - 0 => Mode::Std, - 1 => Mode::Taiko, - 2 => Mode::Catch, - 3 => Mode::Mania, - _ => return None, +fn handle_old_links<'a>( + ctx: &'a Context, + content: &'a str, +) -> impl stream::Stream> + 'a { + OLD_LINK_REGEX + .captures_iter(content) + .map(move |capture| async move { + let data = ctx.data.read().await; + let osu = data.get::().unwrap(); + let cache = data.get::().unwrap(); + let req_type = capture.name("link_type").unwrap().as_str(); + let req = match req_type { + "b" => BeatmapRequestKind::Beatmap(capture["id"].parse()?), + "s" => BeatmapRequestKind::Beatmapset(capture["id"].parse()?), + _ => unreachable!(), + }; + let mode = capture + .name("mode") + .map(|v| v.as_str().parse()) + .transpose()? + .and_then(|v| { + Some(match v { + 0 => Mode::Std, + 1 => Mode::Taiko, + 2 => Mode::Catch, + 3 => Mode::Mania, + _ => return None, + }) + }); + let beatmaps = osu + .beatmaps(req, |v| match mode { + Some(m) => v.mode(m, true), + None => v, }) - }); - let beatmaps = osu.beatmaps(req, |v| match mode { - Some(m) => v.mode(m, true), - None => v, - })?; - match req_type { - "b" => { - for b in beatmaps.into_iter() { + .await?; + if beatmaps.is_empty() { + return Ok(None); + } + let r: Result<_> = Ok(match req_type { + "b" => { + let b = beatmaps.into_iter().next().unwrap(); // collect beatmap info let mods = capture .name("mods") @@ -123,46 +134,65 @@ fn handle_old_links<'a>(ctx: &mut Context, content: &'a str) -> Result to_prints.push(ToPrint { - embed: EmbedType::Beatmapset(beatmaps), - link: capture.get(0).unwrap().as_str(), - mode, - }), - _ => (), - } - } - Ok(to_prints) + "s" => Some(ToPrint { + embed: EmbedType::Beatmapset(beatmaps), + link: capture.get(0).unwrap().as_str(), + mode, + }), + _ => None, + }); + r + }) + .collect::>() + .filter_map(|v| { + future::ready(match v { + Ok(v) => v, + Err(e) => { + eprintln!("{}", e); + None + } + }) + }) } -fn handle_new_links<'a>(ctx: &mut Context, content: &'a str) -> Result>, Error> { - let osu = ctx.data.get_cloned::(); - let mut to_prints: Vec> = Vec::new(); - let cache = ctx.data.get_cloned::(); - for capture in NEW_LINK_REGEX.captures_iter(content) { - let mode = capture - .name("mode") - .and_then(|v| Mode::parse_from_new_site(v.as_str())); - let link = capture.get(0).unwrap().as_str(); - let req = match capture.name("beatmap_id") { - Some(ref v) => BeatmapRequestKind::Beatmap(v.as_str().parse()?), - None => { - BeatmapRequestKind::Beatmapset(capture.name("set_id").unwrap().as_str().parse()?) +fn handle_new_links<'a>( + ctx: &'a Context, + content: &'a str, +) -> impl stream::Stream> + 'a { + NEW_LINK_REGEX + .captures_iter(content) + .map(|capture| async move { + let data = ctx.data.read().await; + let osu = data.get::().unwrap(); + let cache = data.get::().unwrap(); + let mode = capture + .name("mode") + .and_then(|v| Mode::parse_from_new_site(v.as_str())); + let link = capture.get(0).unwrap().as_str(); + let req = match capture.name("beatmap_id") { + Some(ref v) => BeatmapRequestKind::Beatmap(v.as_str().parse()?), + None => BeatmapRequestKind::Beatmapset( + capture.name("set_id").unwrap().as_str().parse()?, + ), + }; + let beatmaps = osu + .beatmaps(req, |v| match mode { + Some(m) => v.mode(m, true), + None => v, + }) + .await?; + if beatmaps.is_empty() { + return Ok(None); } - }; - let beatmaps = osu.beatmaps(req, |v| match mode { - Some(m) => v.mode(m, true), - None => v, - })?; - match capture.name("beatmap_id") { - Some(_) => { - for beatmap in beatmaps.into_iter() { + let r: Result<_> = Ok(match capture.name("beatmap_id") { + Some(_) => { + let beatmap = beatmaps.into_iter().next().unwrap(); // collect beatmap info let mods = capture .name("mods") @@ -177,48 +207,59 @@ fn handle_new_links<'a>(ctx: &mut Context, content: &'a str) -> Result to_prints.push(ToPrint { - embed: EmbedType::Beatmapset(beatmaps), - link, - mode, - }), - } - } - Ok(to_prints) + None => Some(ToPrint { + embed: EmbedType::Beatmapset(beatmaps), + link, + mode, + }), + }); + r + }) + .collect::>() + .filter_map(|v| { + future::ready(match v { + Ok(v) => v, + Err(e) => { + eprintln!("{}", e); + None + } + }) + }) } fn handle_short_links<'a>( - ctx: &mut Context, - msg: &Message, + ctx: &'a Context, + msg: &'a Message, content: &'a str, -) -> Result>, Error> { - if let Some(guild_id) = msg.guild_id { - if announcer::announcer_of(ctx, crate::discord::announcer::ANNOUNCER_KEY, guild_id)? - != Some(msg.channel_id) - { - // Disable if we are not in the server's announcer channel - return Ok(vec![]); - } - } - let osu = ctx.data.get_cloned::(); - let cache = ctx.data.get_cloned::(); - Ok(SHORT_LINK_REGEX +) -> impl stream::Stream> + 'a { + SHORT_LINK_REGEX .captures_iter(content) - .map(|capture| -> Result<_, Error> { + .map(|capture| async move { + if let Some(guild_id) = msg.guild_id { + if announcer::announcer_of(ctx, crate::discord::announcer::ANNOUNCER_KEY, guild_id) + .await? + != Some(msg.channel_id) + { + // Disable if we are not in the server's announcer channel + return Err(Error::msg("not in server announcer channel")); + } + } + let data = ctx.data.read().await; + let osu = data.get::().unwrap(); + let cache = data.get::().unwrap(); let mode = capture .name("mode") .and_then(|v| Mode::parse_from_new_site(v.as_str())); let id: u64 = capture.name("id").unwrap().as_str().parse()?; let beatmap = match mode { - Some(mode) => osu.get_beatmap(id, mode), - None => osu.get_beatmap_default(id), + Some(mode) => osu.get_beatmap(id, mode).await, + None => osu.get_beatmap_default(id).await, }?; let mods = capture .name("mods") @@ -233,14 +274,23 @@ fn handle_short_links<'a>( .and_then(|b| b.get_info_with(Some(mode), mods)) .ok() }); - Ok(ToPrint { + let r: Result<_> = Ok(ToPrint { embed: EmbedType::Beatmap(beatmap, info, mods), link: capture.get(0).unwrap().as_str(), mode, + }); + r + }) + .collect::>() + .filter_map(|v| { + future::ready(match v { + Ok(v) => Some(v), + Err(e) => { + eprintln!("{}", e); + None + } }) }) - .filter_map(|v| v.ok()) - .collect()) } fn handle_beatmap<'a, 'b>(