From 2756c463d5b09274be833fd0f718331b29cad47d Mon Sep 17 00:00:00 2001 From: Natsu Kagami Date: Mon, 12 May 2025 23:59:13 +0200 Subject: [PATCH] Introduce a Scores stream so we can lazily load top score requests --- youmubot-osu/src/discord/announcer.rs | 11 +- youmubot-osu/src/discord/commands.rs | 33 +-- youmubot-osu/src/discord/display.rs | 116 +++++----- youmubot-osu/src/discord/interaction.rs | 7 +- youmubot-osu/src/discord/mod.rs | 130 +++++------ youmubot-osu/src/discord/server_rank.rs | 4 +- youmubot-osu/src/lib.rs | 29 +-- .../src/{request.rs => request/mod.rs} | 125 +++++------ youmubot-osu/src/request/scores.rs | 206 ++++++++++++++++++ 9 files changed, 427 insertions(+), 234 deletions(-) rename youmubot-osu/src/{request.rs => request/mod.rs} (76%) create mode 100644 youmubot-osu/src/request/scores.rs diff --git a/youmubot-osu/src/discord/announcer.rs b/youmubot-osu/src/discord/announcer.rs index 5bb2c4c..1a0277c 100644 --- a/youmubot-osu/src/discord/announcer.rs +++ b/youmubot-osu/src/discord/announcer.rs @@ -22,6 +22,7 @@ use youmubot_prelude::*; use crate::discord::calculate_weighted_map_age; use crate::discord::db::OsuUserMode; +use crate::scores::Scores; use crate::{ discord::cache::save_beatmap, discord::oppai_cache::BeatmapContent, @@ -212,8 +213,8 @@ impl Announcer { }; let top_scores = env .client - .user_best(user_id.clone(), |f| f.mode(mode)) - .try_collect::>(); + .user_best(user_id.clone(), move |f| f.mode(mode)) + .and_then(|v| v.get_all()); let (user, top_scores) = try_join!(user, top_scores)?; let mut user = user.unwrap(); // if top scores exist, user would too @@ -264,14 +265,14 @@ impl<'a> CollectedScore<'a> { user: &'a User, event: UserEventRank, ) -> Result> { - let scores = osu + let mut scores = osu .scores(event.beatmap_id, |f| { f.user(UserID::ID(user.id)).mode(event.mode) }) .await?; let score = match scores - .into_iter() .find(|s| (s.date - event.date).abs() < chrono::TimeDelta::seconds(5)) + .await? { Some(v) => v, None => { @@ -283,7 +284,7 @@ impl<'a> CollectedScore<'a> { }; Ok(Self { user, - score, + score: score.clone(), mode: event.mode, kind: ScoreType::world(event.rank), }) diff --git a/youmubot-osu/src/discord/commands.rs b/youmubot-osu/src/discord/commands.rs index b38a98a..82a829c 100644 --- a/youmubot-osu/src/discord/commands.rs +++ b/youmubot-osu/src/discord/commands.rs @@ -4,7 +4,6 @@ use super::*; use cache::save_beatmap; use display::display_beatmapset; use embeds::ScoreEmbedBuilder; -use futures::TryStream; use link_parser::EmbedType; use poise::{ChoiceParameter, CreateReply}; use serenity::all::{CreateAttachment, User}; @@ -63,7 +62,9 @@ async fn top( ctx.defer().await?; let osu_client = &env.client; let mode = args.mode; - let plays = osu_client.user_best(UserID::ID(args.user.id), |f| f.mode(mode)); + let plays = osu_client + .user_best(UserID::ID(args.user.id), |f| f.mode(mode)) + .await?; handle_listing(ctx, plays, args, |nth, b| b.top_record(nth), "top").await } @@ -134,9 +135,11 @@ async fn recent( let osu_client = &env.client; let mode = args.mode; - let plays = osu_client.user_recent(UserID::ID(args.user.id), |f| { - f.mode(mode).include_fails(include_fails).limit(50) - }); + let plays = osu_client + .user_recent(UserID::ID(args.user.id), |f| { + f.mode(mode).include_fails(include_fails) + }) + .await?; handle_listing(ctx, plays, args, |_, b| b, "recent").await } @@ -166,7 +169,9 @@ async fn pinned( let osu_client = &env.client; let mode = args.mode; - let plays = osu_client.user_pins(UserID::ID(args.user.id), |f| f.mode(mode)); + let plays = osu_client + .user_pins(UserID::ID(args.user.id), |f| f.mode(mode)) + .await?; handle_listing(ctx, plays, args, |_, b| b, "pinned").await } @@ -250,7 +255,7 @@ pub async fn forcesave( async fn handle_listing( ctx: CmdContext<'_, U>, - plays: impl TryStream, + mut plays: impl Scores, listing_args: ListingArgs, transform: impl for<'a> Fn(u8, ScoreEmbedBuilder<'a>) -> ScoreEmbedBuilder<'a>, listing_kind: &'static str, @@ -265,12 +270,8 @@ async fn handle_listing( match nth { Nth::Nth(nth) => { - let play = std::pin::pin!(plays.into_stream()) - .skip(nth as usize) - .next() - .await; - let play = if let Some(play) = play { - play? + let play = if let Some(play) = plays.get(nth as usize).await? { + play } else { return Err(Error::msg("no such play"))?; }; @@ -307,7 +308,7 @@ async fn handle_listing( let reply = ctx.clone().reply(&header).await?; style .display_scores( - plays.try_collect::>(), + plays, ctx.clone().serenity_context(), ctx.guild_id(), (reply, ctx).with_header(header), @@ -489,7 +490,7 @@ async fn check( style .display_scores( - future::ok(scores), + scores, ctx.clone().serenity_context(), ctx.guild_id(), (msg, ctx).with_header(header), @@ -612,7 +613,7 @@ async fn leaderboard( let reply = ctx.reply(header).await?; style .display_scores( - future::ok(scores.into_iter().map(|s| s.score).collect()), + scores.into_iter().map(|s| s.score).collect::>(), ctx.serenity_context(), Some(guild.id), (reply, ctx), diff --git a/youmubot-osu/src/discord/display.rs b/youmubot-osu/src/discord/display.rs index c87bdd4..a9eee4e 100644 --- a/youmubot-osu/src/discord/display.rs +++ b/youmubot-osu/src/discord/display.rs @@ -2,14 +2,12 @@ pub use beatmapset::display_beatmapset; pub use scores::ScoreListStyle; mod scores { - use std::future::Future; - use poise::ChoiceParameter; use serenity::all::GuildId; use youmubot_prelude::*; - use crate::models::Score; + use crate::scores::Scores; #[derive(Debug, Clone, Copy, PartialEq, Eq, ChoiceParameter)] /// The style for the scores list to be displayed. @@ -43,7 +41,7 @@ mod scores { impl ScoreListStyle { pub async fn display_scores( self, - scores: impl Future>>, + scores: impl Scores, ctx: &Context, guild_id: Option, m: impl CanEdit, @@ -57,8 +55,6 @@ mod scores { } mod grid { - use std::future::Future; - use pagination::paginate_with_first_message; use serenity::all::{CreateActionRow, GuildId}; @@ -66,17 +62,16 @@ mod scores { use crate::discord::interaction::score_components; use crate::discord::{cache::save_beatmap, BeatmapWithMode, OsuEnv}; - use crate::models::Score; + use crate::scores::Scores; pub async fn display_scores_grid( - scores: impl Future>>, + scores: impl Scores, ctx: &Context, guild_id: Option, mut on: impl CanEdit, ) -> Result<()> { let env = ctx.data.read().await.get::().unwrap().clone(); let channel_id = on.get_message().await?.channel_id; - let scores = scores.await?; if scores.is_empty() { on.apply_edit(CreateReply::default().content("No plays found")) .await?; @@ -98,15 +93,22 @@ mod scores { Ok(()) } - pub struct Paginate { + pub struct Paginate { env: OsuEnv, - scores: Vec, + scores: T, guild_id: Option, channel_id: serenity::all::ChannelId, } + impl Paginate { + fn pages_fake(&self) -> usize { + let size = self.scores.length_fetched(); + size.count() + if size.is_total() { 0 } else { 1 } + } + } + #[async_trait] - impl pagination::Paginate for Paginate { + impl pagination::Paginate for Paginate { async fn render( &mut self, page: u8, @@ -114,7 +116,10 @@ mod scores { ) -> Result> { let env = &self.env; let page = page as usize; - let score = &self.scores[page]; + let Some(score) = self.scores.get(page).await? else { + return Ok(None); + }; + let score = score.clone(); let beatmap = env .beatmaps @@ -137,8 +142,12 @@ mod scores { Ok(Some( CreateReply::default() .embed({ - crate::discord::embeds::score_embed(score, &bm, &content, &user) - .footer(format!("Page {}/{}", page + 1, self.scores.len())) + crate::discord::embeds::score_embed(&score, &bm, &content, &user) + .footer(format!( + "Page {} / {}", + page + 1, + self.scores.length_fetched() + )) .build() }) .components( @@ -151,14 +160,13 @@ mod scores { } fn len(&self) -> Option { - Some(self.scores.len()) + Some(self.pages_fake()) } } } pub mod table { use std::borrow::Cow; - use std::future::Future; use pagination::paginate_with_first_message; use serenity::all::{CreateActionRow, CreateAttachment}; @@ -169,29 +177,28 @@ mod scores { use crate::discord::oppai_cache::Stats; use crate::discord::{time_before_now, Beatmap, BeatmapInfo, OsuEnv}; - use crate::models::Score; + use crate::scores::Scores; pub async fn display_scores_as_file( - scores: impl Future>>, + scores: impl Scores, ctx: &Context, mut on: impl CanEdit, ) -> Result<()> { let header = on.headers().unwrap_or("").to_owned(); let content = format!("{}\n\nPreparing file...", header); - let preparing = on.apply_edit(CreateReply::default().content(content)); - let (_, scores) = future::try_join(preparing, scores).await?; - if scores.is_empty() { - on.apply_edit(CreateReply::default().content("No plays found")) - .await?; - return Ok(()); - } + on.apply_edit(CreateReply::default().content(content)) + .await?; - let p = Paginate { + let mut p = Paginate { env: ctx.data.read().await.get::().unwrap().clone(), header: header.clone(), scores, }; - let content = p.to_table(0, p.scores.len()).await; + let Some(content) = p.to_table(0, usize::max_value()).await? else { + on.apply_edit(CreateReply::default().content("No plays found")) + .await?; + return Ok(()); + }; on.apply_edit( CreateReply::default() .content(header) @@ -202,11 +209,10 @@ mod scores { } pub async fn display_scores_table( - scores: impl Future>>, + scores: impl Scores, ctx: &Context, mut on: impl CanEdit, ) -> Result<()> { - let scores = scores.await?; if scores.is_empty() { on.apply_edit(CreateReply::default().content("No plays found")) .await?; @@ -227,19 +233,18 @@ mod scores { Ok(()) } - pub struct Paginate { + pub struct Paginate { env: OsuEnv, header: String, - scores: Vec, + scores: T, } - impl Paginate { - fn total_pages(&self) -> usize { - (self.scores.len() + ITEMS_PER_PAGE - 1) / ITEMS_PER_PAGE - } - - async fn to_table(&self, start: usize, end: usize) -> String { - let scores = &self.scores[start..end]; + impl Paginate { + async fn to_table(&mut self, start: usize, end: usize) -> Result> { + let scores = self.scores.get_range(start..end).await?; + if scores.is_empty() { + return Ok(None); + } let meta_cache = &self.env.beatmaps; let oppai = &self.env.oppai; @@ -348,14 +353,18 @@ mod scores { }) .collect::>(); - table_formatting(&SCORE_HEADERS, &SCORE_ALIGNS, score_arr) + Ok(Some(table_formatting( + &SCORE_HEADERS, + &SCORE_ALIGNS, + score_arr, + ))) } } const ITEMS_PER_PAGE: usize = 5; #[async_trait] - impl pagination::Paginate for Paginate { + impl pagination::Paginate for Paginate { async fn render( &mut self, page: u8, @@ -363,23 +372,20 @@ mod scores { ) -> Result> { let page = page as usize; let start = page * ITEMS_PER_PAGE; - let end = self.scores.len().min(start + ITEMS_PER_PAGE); - if start >= end { + let end = start + ITEMS_PER_PAGE; + + let Some(score_table) = self.to_table(start, end).await? else { return Ok(None); - } - let plays = &self.scores[start..end]; - - let has_oppai = plays.iter().any(|p| p.pp.is_none()); - - let score_table = self.to_table(start, end).await; + }; let mut content = serenity::utils::MessageBuilder::new(); content .push_line(&self.header) .push_line(score_table) - .push_line(format!("Page **{}/{}**", page + 1, self.total_pages())); - if has_oppai { - content.push_line("[?] means pp was predicted by oppai-rs."); - }; + .push_line(format!( + "Page **{} / {}**", + page + 1, + self.scores.length_fetched().as_pages(ITEMS_PER_PAGE) + )); let content = content.build(); Ok(Some( @@ -388,7 +394,9 @@ mod scores { } fn len(&self) -> Option { - Some(self.total_pages()) + let size = self.scores.length_fetched(); + let pages = size.count().div_ceil(ITEMS_PER_PAGE); + Some(pages + if size.is_total() { 0 } else { 1 }) } } } diff --git a/youmubot-osu/src/discord/interaction.rs b/youmubot-osu/src/discord/interaction.rs index a2bdd0d..b1900db 100644 --- a/youmubot-osu/src/discord/interaction.rs +++ b/youmubot-osu/src/discord/interaction.rs @@ -109,12 +109,7 @@ pub fn handle_check_button<'a>( let guild_id = comp.guild_id; ScoreListStyle::Grid - .display_scores( - future::ok(scores), - &ctx, - guild_id, - (comp, ctx).with_header(header), - ) + .display_scores(scores, &ctx, guild_id, (comp, ctx).with_header(header)) .await .pls_ok(); Ok(()) diff --git a/youmubot-osu/src/discord/mod.rs b/youmubot-osu/src/discord/mod.rs index 098c2f7..067bd9a 100644 --- a/youmubot-osu/src/discord/mod.rs +++ b/youmubot-osu/src/discord/mod.rs @@ -35,8 +35,9 @@ use crate::{ }, models::{Beatmap, Mode, Mods, Score, User}, mods::UnparsedMods, - request::{BeatmapRequestKind, UserID, SCORE_COUNT_LIMIT}, - OsuClient as OsuHttpClient, UserHeader, + request::{BeatmapRequestKind, UserID}, + scores::Scores, + OsuClient as OsuHttpClient, UserHeader, MAX_TOP_SCORES_INDEX, }; mod announcer; @@ -304,7 +305,8 @@ pub(crate) async fn find_save_requirements( ] { let scores = client .user_best(UserID::ID(u.id), |f| f.mode(*mode)) - .try_collect::>() + .await? + .get_all() .await?; if let Some(v) = scores.into_iter().choose(&mut rand::thread_rng()) { return Ok(Some((v, *mode))); @@ -351,10 +353,12 @@ pub(crate) async fn handle_save_respond( ) -> Result<()> { let osu_client = &env.client; async fn check(client: &OsuHttpClient, u: &User, mode: Mode, map_id: u64) -> Result { - client - .user_recent(UserID::ID(u.id), |f| f.mode(mode).limit(1)) - .try_any(|s| future::ready(s.beatmap_id == map_id)) - .await + Ok(client + .user_recent(UserID::ID(u.id), |f| f.mode(mode)) + .await? + .get(0) + .await? + .is_some_and(|s| s.beatmap_id == map_id)) } let msg_id = reply.get_message().await?.id; let recv = InteractionCollector::create(&ctx, msg_id).await?; @@ -498,13 +502,17 @@ pub(crate) struct UserExtras { impl UserExtras { // Collect UserExtras from the given user. pub async fn from_user(env: &OsuEnv, user: &User, mode: Mode) -> Result { - let scores = env - .client - .user_best(UserID::ID(user.id), |f| f.mode(mode)) - .try_collect::>() - .await - .pls_ok() - .unwrap_or_else(std::vec::Vec::new); + let scores = { + match env + .client + .user_best(UserID::ID(user.id), |f| f.mode(mode)) + .await + .pls_ok() + { + Some(v) => v.get_all().await.pls_ok().unwrap_or_else(Vec::new), + None => Vec::new(), + } + }; let (length, age) = join!( calculate_weighted_map_length(&scores, &env.beatmaps, mode), @@ -589,7 +597,7 @@ impl ListingArgs { sender: serenity::all::UserId, ) -> Result { let nth = index - .filter(|&v| 1 <= v && v <= SCORE_COUNT_LIMIT as u8) + .filter(|&v| 1 <= v && v <= MAX_TOP_SCORES_INDEX as u8) .map(|v| v - 1) .map(Nth::Nth) .unwrap_or_default(); @@ -632,7 +640,7 @@ async fn user_header_or_default_id( Some(UsernameArg::Raw(r)) => { let user = env .client - .user(&UserID::Username(r), |f| f) + .user(&UserID::Username(Arc::new(r)), |f| f) .await? .ok_or(Error::msg("User not found"))?; (user.preferred_mode, user.into()) @@ -678,30 +686,40 @@ pub async fn recent(ctx: &Context, msg: &Message, mut args: Args) -> CommandResu } = ListingArgs::parse(&env, msg, &mut args, ScoreListStyle::Table).await?; let osu_client = &env.client; - let plays = osu_client.user_recent(UserID::ID(user.id), |f| f.mode(mode)); + let mut plays = osu_client + .user_recent(UserID::ID(user.id), |f| f.mode(mode)) + .await?; match nth { Nth::All => { let header = format!("Here are the recent plays by {}!", user.mention()); let reply = msg.reply(ctx, &header).await?; style - .display_scores( - plays.try_collect::>(), - ctx, - reply.guild_id, - (reply, ctx).with_header(header), - ) + .display_scores(plays, ctx, reply.guild_id, (reply, ctx).with_header(header)) .await?; } Nth::Nth(nth) => { - let plays = std::pin::pin!(plays.into_stream()); - let (play, rest) = plays.skip(nth as usize).into_future().await; - let play = play.ok_or(Error::msg("No such play"))??; - let attempts = rest - .try_take_while(|p| { - future::ok(p.beatmap_id == play.beatmap_id && p.mods == play.mods) - }) - .count() - .await; + let play = plays + .get(nth as usize) + .await? + .ok_or(Error::msg("No such play"))? + .clone(); + let attempts = { + let mut count = 0usize; + while plays + .get(nth as usize + count + 1) + .await + .ok() + .flatten() + .is_some_and(|p| { + p.beatmap_id == play.beatmap_id + && p.mode == play.mode + && p.mods == play.mods + }) + { + count += 1; + } + count + }; let beatmap = env.beatmaps.get_beatmap(play.beatmap_id, mode).await?; let content = env.oppai.get_beatmap(beatmap.beatmap_id).await?; let beatmap_mode = BeatmapWithMode(beatmap, Some(mode)); @@ -751,26 +769,22 @@ pub async fn pins(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult let osu_client = &env.client; - let plays = osu_client.user_pins(UserID::ID(user.id), |f| f.mode(mode)); + let mut plays = osu_client + .user_pins(UserID::ID(user.id), |f| f.mode(mode)) + .await?; match nth { Nth::All => { let header = format!("Here are the pinned plays by `{}`!", user.username); let reply = msg.reply(ctx, &header).await?; style - .display_scores( - plays.try_collect::>(), - ctx, - reply.guild_id, - (reply, ctx).with_header(header), - ) + .display_scores(plays, ctx, reply.guild_id, (reply, ctx).with_header(header)) .await?; } Nth::Nth(nth) => { - let play = std::pin::pin!(plays.into_stream()) - .skip(nth as usize) - .next() - .await - .ok_or(Error::msg("No such play"))??; + let play = plays + .get(nth as usize) + .await? + .ok_or(Error::msg("No such play"))?; let beatmap = env.beatmaps.get_beatmap(play.beatmap_id, mode).await?; let content = env.oppai.get_beatmap(beatmap.beatmap_id).await?; let beatmap_mode = BeatmapWithMode(beatmap, Some(mode)); @@ -1016,12 +1030,7 @@ pub async fn check(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul ); let reply = msg.reply(&ctx, &header).await?; style - .display_scores( - future::ok(scores), - ctx, - msg.guild_id, - (reply, ctx).with_header(header), - ) + .display_scores(scores, ctx, msg.guild_id, (reply, ctx).with_header(header)) .await?; Ok(()) @@ -1045,6 +1054,7 @@ pub(crate) async fn do_check( let mods = mods.clone().and_then(|t| t.to_mods(m).ok()); osu_client .scores(b.beatmap_id, |f| f.user(UserID::ID(user.id)).mode(m)) + .and_then(|v| v.get_all()) .map_ok(move |mut v| { v.retain(|s| mods.as_ref().is_none_or(|m| s.mods.contains(&m))); v @@ -1088,15 +1098,16 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult } = ListingArgs::parse(&env, msg, &mut args, ScoreListStyle::default()).await?; let osu_client = &env.client; - let plays = osu_client.user_best(UserID::ID(user.id), |f| f.mode(mode)); + let mut plays = osu_client + .user_best(UserID::ID(user.id), |f| f.mode(mode)) + .await?; match nth { Nth::Nth(nth) => { - let play = std::pin::pin!(plays.into_stream()) - .skip(nth as usize) - .next() - .await - .ok_or(Error::msg("No such play"))??; + let play = plays + .get(nth as usize) + .await? + .ok_or(Error::msg("No such play"))?; let beatmap = env.beatmaps.get_beatmap(play.beatmap_id, mode).await?; let content = env.oppai.get_beatmap(beatmap.beatmap_id).await?; @@ -1126,12 +1137,7 @@ pub async fn top(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult let header = format!("Here are the top plays by {}!", user.mention()); let reply = msg.reply(&ctx, &header).await?; style - .display_scores( - plays.try_collect::>(), - ctx, - msg.guild_id, - (reply, ctx).with_header(header), - ) + .display_scores(plays, ctx, msg.guild_id, (reply, ctx).with_header(header)) .await?; } } diff --git a/youmubot-osu/src/discord/server_rank.rs b/youmubot-osu/src/discord/server_rank.rs index 459c9e7..476e1e9 100644 --- a/youmubot-osu/src/discord/server_rank.rs +++ b/youmubot-osu/src/discord/server_rank.rs @@ -30,6 +30,7 @@ use crate::{ }, models::Mode, request::UserID, + scores::Scores, Beatmap, Score, }; @@ -438,7 +439,7 @@ pub async fn show_leaderboard(ctx: &Context, msg: &Message, mut args: Args) -> C let reply = msg.reply(&ctx, header).await?; style .display_scores( - future::ok(scores.into_iter().map(|s| s.score).collect()), + scores.into_iter().map(|s| s.score).collect::>(), ctx, Some(guild), (reply, ctx), @@ -503,6 +504,7 @@ async fn get_leaderboard( .scores(b.beatmap_id, move |f| { f.user(UserID::ID(osu_id)).mode(mode_override) }) + .and_then(|v| v.get_all()) .map(move |r| Some((b, op, mem.clone(), r.ok()?))) }) }) diff --git a/youmubot-osu/src/lib.rs b/youmubot-osu/src/lib.rs index f85985e..f5432d3 100644 --- a/youmubot-osu/src/lib.rs +++ b/youmubot-osu/src/lib.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use std::convert::TryInto; use std::sync::Arc; -use futures::TryStream; use futures_util::lock::Mutex; use models::*; use request::builders::*; @@ -13,6 +12,8 @@ pub mod discord; pub mod models; pub mod request; +pub const MAX_TOP_SCORES_INDEX: usize = 200; + /// Client is the client that will perform calls to the osu! api server. #[derive(Clone)] pub struct OsuClient { @@ -87,45 +88,45 @@ impl OsuClient { &self, beatmap_id: u64, f: impl FnOnce(&mut ScoreRequestBuilder) -> &mut ScoreRequestBuilder, - ) -> Result, Error> { + ) -> Result { let mut r = ScoreRequestBuilder::new(beatmap_id); f(&mut r); r.build(self).await } - pub fn user_best( + pub async fn user_best( &self, user: UserID, f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder, - ) -> impl TryStream { - self.user_scores(UserScoreType::Best, user, f) + ) -> Result { + self.user_scores(UserScoreType::Best, user, f).await } - pub fn user_recent( + pub async fn user_recent( &self, user: UserID, f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder, - ) -> impl TryStream { - self.user_scores(UserScoreType::Recent, user, f) + ) -> Result { + self.user_scores(UserScoreType::Recent, user, f).await } - pub fn user_pins( + pub async fn user_pins( &self, user: UserID, f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder, - ) -> impl TryStream { - self.user_scores(UserScoreType::Pin, user, f) + ) -> Result { + self.user_scores(UserScoreType::Pin, user, f).await } - fn user_scores( + async fn user_scores( &self, u: UserScoreType, user: UserID, f: impl FnOnce(&mut UserScoreRequestBuilder) -> &mut UserScoreRequestBuilder, - ) -> impl TryStream { + ) -> Result { let mut r = UserScoreRequestBuilder::new(u, user); f(&mut r); - r.build(self.clone()) + r.build(self.clone()).await } pub async fn score(&self, score_id: u64) -> Result, Error> { diff --git a/youmubot-osu/src/request.rs b/youmubot-osu/src/request/mod.rs similarity index 76% rename from youmubot-osu/src/request.rs rename to youmubot-osu/src/request/mod.rs index e7ea872..a13eee0 100644 --- a/youmubot-osu/src/request.rs +++ b/youmubot-osu/src/request/mod.rs @@ -1,16 +1,18 @@ use core::fmt; +use std::sync::Arc; use crate::models::{Mode, Mods}; use crate::OsuClient; use rosu_v2::error::OsuError; use youmubot_prelude::*; -/// Maximum number of scores returned by the osu! api. -pub const SCORE_COUNT_LIMIT: usize = 200; +pub(crate) mod scores; + +pub use scores::Scores; #[derive(Clone, Debug)] pub enum UserID { - Username(String), + Username(Arc), ID(u64), } @@ -26,7 +28,7 @@ impl fmt::Display for UserID { impl From for rosu_v2::prelude::UserId { fn from(value: UserID) -> Self { match value { - UserID::Username(s) => rosu_v2::request::UserId::Name(s.into()), + UserID::Username(s) => rosu_v2::request::UserId::Name(s[..].into()), UserID::ID(id) => rosu_v2::request::UserId::Id(id as u32), } } @@ -37,7 +39,7 @@ impl UserID { let s = s.into(); match s.parse::() { Ok(id) => UserID::ID(id), - Err(_) => UserID::Username(s), + Err(_) => UserID::Username(Arc::new(s)), } } } @@ -57,11 +59,11 @@ fn handle_not_found(v: Result) -> Result, OsuError> { } pub mod builders { - use futures_util::TryStream; use rosu_v2::model::mods::GameModsIntermode; - use crate::models; + use crate::models::{self, Score}; + use super::scores::{FetchScores, ScoresFetcher}; use super::OsuClient; use super::*; /// A builder for a Beatmap request. @@ -170,7 +172,6 @@ pub mod builders { user: Option, mode: Option, mods: Option, - limit: Option, } impl ScoreRequestBuilder { @@ -180,7 +181,6 @@ pub mod builders { user: None, mode: None, mods: None, - limit: None, } } @@ -199,23 +199,21 @@ pub mod builders { self } - pub fn limit(&mut self, limit: u8) -> &mut Self { - self.limit = Some(limit) - .filter(|&v| v <= SCORE_COUNT_LIMIT as u8) - .or(self.limit); - self - } - - pub(crate) async fn build(self, osu: &OsuClient) -> Result> { - let scores = handle_not_found(match self.user { + async fn fetch_scores( + &self, + osu: &crate::OsuClient, + _offset: usize, + ) -> Result> { + let scores = handle_not_found(match &self.user { Some(user) => { - let mut r = osu.rosu.beatmap_user_scores(self.beatmap_id as u32, user); + let mut r = osu + .rosu + .beatmap_user_scores(self.beatmap_id as u32, user.clone()); if let Some(mode) = self.mode { r = r.mode(mode.into()); } - match self.mods { + match &self.mods { Some(mods) => r.await.map(|mut ss| { - // let mods = GameModsIntermode::from(mods.inner); ss.retain(|s| { Mods::from_gamemods(s.mods.clone(), s.set_on_lazer).contains(&mods) }); @@ -226,21 +224,25 @@ pub mod builders { } None => { let mut r = osu.rosu.beatmap_scores(self.beatmap_id as u32).global(); - if let Some(mode) = self.mode { - r = r.mode(mode.into()); + if let Some(mode) = &self.mode { + r = r.mode(mode.clone().into()); } - if let Some(mods) = self.mods { - r = r.mods(GameModsIntermode::from(mods.inner)); - } - if let Some(limit) = self.limit { - r = r.limit(limit as u32); + if let Some(mods) = &self.mods { + r = r.mods(GameModsIntermode::from(mods.inner.clone())); } + // r = r.limit(limit); // can't do this just yet because of offset not working r.await } })? .ok_or_else(|| error!("beatmap or user not found"))?; Ok(scores.into_iter().map(|v| v.into()).collect()) } + + pub(crate) async fn build(self, osu: &OsuClient) -> Result { + // user queries always return all scores, so no need to consider offset. + // otherwise, it's not working anyway... + Ok(self.fetch_scores(osu, 0).await?) + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -255,7 +257,6 @@ pub mod builders { score_type: UserScoreType, user: UserID, mode: Option, - limit: Option, include_fails: bool, } @@ -265,7 +266,6 @@ pub mod builders { score_type, user, mode: None, - limit: None, include_fails: true, } } @@ -275,43 +275,19 @@ pub mod builders { self } - pub fn limit(&mut self, limit: usize) -> &mut Self { - self.limit = if limit > SCORE_COUNT_LIMIT { - self.limit - } else { - Some(limit) - }; - self - } - pub fn include_fails(&mut self, include_fails: bool) -> &mut Self { self.include_fails = include_fails; self } - async fn with_offset( - self, - offset: Option, - client: OsuClient, - ) -> Result, Option)>> { - const MAXIMUM_LIMIT: usize = 100; - let offset = if let Some(offset) = offset { - offset - } else { - return Ok(None); - }; - let count = match self.limit { - Some(limit) => (limit - offset).min(MAXIMUM_LIMIT), - None => MAXIMUM_LIMIT, - }; - if count == 0 { - return Ok(None); - } + const SCORES_PER_PAGE: usize = 100; + + async fn with_offset(&self, client: &OsuClient, offset: usize) -> Result> { let scores = handle_not_found({ let mut r = client .rosu .user_scores(self.user.clone()) - .limit(count) + .limit(Self::SCORES_PER_PAGE) .offset(offset); r = match self.score_type { UserScoreType::Recent => r.recent().include_fails(self.include_fails), @@ -324,28 +300,25 @@ pub mod builders { r.await })? .ok_or_else(|| error!("user not found"))?; - let count = scores.len(); - Ok(Some(( - scores.into_iter().map(|v| v.into()).collect(), - if count == MAXIMUM_LIMIT { - Some(offset + MAXIMUM_LIMIT) - } else { - None - }, - ))) + Ok(scores.into_iter().map(|v| v.into()).collect()) } - pub(crate) fn build( - self, - client: OsuClient, - ) -> impl TryStream { - futures::stream::try_unfold(Some(0), move |off| { - self.clone().with_offset(off, client.clone()) - }) - .map_ok(|v| futures::stream::iter(v).map(|v| Ok(v) as Result<_>)) - .try_flatten() + pub(crate) async fn build(self, client: OsuClient) -> Result { + ScoresFetcher::new(client, self).await } } + + impl FetchScores for UserScoreRequestBuilder { + async fn fetch_scores( + &self, + client: &crate::OsuClient, + offset: usize, + ) -> Result> { + self.with_offset(client, offset).await + } + + const SCORES_PER_PAGE: usize = Self::SCORES_PER_PAGE; + } } pub struct UserBestRequest { diff --git a/youmubot-osu/src/request/scores.rs b/youmubot-osu/src/request/scores.rs new file mode 100644 index 0000000..fe7ce4b --- /dev/null +++ b/youmubot-osu/src/request/scores.rs @@ -0,0 +1,206 @@ +use std::{fmt::Display, future::Future, ops::Range}; + +use youmubot_prelude::*; + +use crate::{models::Score, OsuClient}; + +pub const MAX_SCORE_PER_PAGE: usize = 1000; + +/// Fetch scores given an offset. +/// Implemented for score requests. +pub trait FetchScores: Send { + /// Scores per page. + const SCORES_PER_PAGE: usize = MAX_SCORE_PER_PAGE; + /// Fetch scores given an offset. + fn fetch_scores( + &self, + client: &crate::OsuClient, + offset: usize, + ) -> impl Future>> + Send; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Size { + /// There might be more + AtLeast(usize), + /// All + Total(usize), +} + +impl Display for Size { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.count())?; + if !self.is_total() { + write!(f, "+")?; + } + Ok(()) + } +} + +impl Size { + pub fn count(&self) -> usize { + match self { + Size::AtLeast(cnt) => *cnt, + Size::Total(cnt) => *cnt, + } + } + + pub fn is_total(&self) -> bool { + match self { + Size::AtLeast(_) => false, + Size::Total(_) => true, + } + } + + pub fn as_pages(self, per_page: usize) -> Size { + match self { + Size::AtLeast(a) => Size::AtLeast(a.div_ceil(per_page)), + Size::Total(a) => Size::Total(a.div_ceil(per_page)), + } + } +} + +/// A scores stream. +pub trait Scores: Send { + /// Total length of the pages. + fn length_fetched(&self) -> Size; + + /// Whether the scores set is empty. + fn is_empty(&self) -> bool; + + /// Get the index-th score. + fn get(&mut self, index: usize) -> impl Future>> + Send; + + /// Get all scores. + fn get_all(self) -> impl Future>> + Send; + + /// Get the scores between the given range. + fn get_range(&mut self, range: Range) -> impl Future> + Send; + + /// Find a score that matches the predicate `f`. + fn find bool + Send>( + &mut self, + f: F, + ) -> impl Future>> + Send; +} + +impl Scores for Vec { + fn length_fetched(&self) -> Size { + Size::Total(self.len()) + } + + fn is_empty(&self) -> bool { + self.is_empty() + } + + fn get(&mut self, index: usize) -> impl Future>> + Send { + future::ok(self[..].get(index)) + } + + fn get_all(self) -> impl Future>> + Send { + future::ok(self) + } + + fn get_range(&mut self, range: Range) -> impl Future> + Send { + future::ok(&self[range]) + } + + async fn find bool + Send>(&mut self, mut f: F) -> Result> { + Ok(self.iter().find(|v| f(*v))) + } +} + +/// A scores stream with a fetcher. +pub(super) struct ScoresFetcher { + fetcher: T, + client: OsuClient, + scores: Vec, + more_exists: bool, +} + +impl ScoresFetcher { + /// Create a new Scores stream. + pub async fn new(client: OsuClient, fetcher: T) -> Result { + let mut s = Self { + fetcher, + client, + scores: Vec::new(), + more_exists: true, + }; + // fetch the first page immediately. + s.fetch_next_page().await?; + Ok(s) + } +} + +impl Scores for ScoresFetcher { + /// Total length of the pages. + fn length_fetched(&self) -> Size { + let count = self.len(); + if self.more_exists { + Size::AtLeast(count) + } else { + Size::Total(count) + } + } + + fn is_empty(&self) -> bool { + self.scores.is_empty() + } + + /// Get the index-th score. + async fn get(&mut self, index: usize) -> Result> { + Ok(self.get_range(index..(index + 1)).await?.get(0)) + } + + /// Get all scores. + async fn get_all(mut self) -> Result> { + let _ = self.get_range(0..usize::max_value()).await?; + Ok(self.scores) + } + + /// Get the scores between the given range. + async fn get_range(&mut self, range: Range) -> Result<&[Score]> { + while self.len() < range.end { + if !self.fetch_next_page().await? { + break; + } + } + Ok(&self.scores[range.start.min(self.len())..range.end.min(self.len())]) + } + + async fn find bool + Send>(&mut self, mut f: F) -> Result> { + let mut from = 0usize; + let index = loop { + if from == self.len() && !self.fetch_next_page().await? { + break None; + } + if f(&self.scores[from]) { + break Some(from); + } + from += 1; + }; + Ok(index.map(|v| &self.scores[v])) + } +} + +impl ScoresFetcher { + async fn fetch_next_page(&mut self) -> Result { + if !self.more_exists { + return Ok(false); + } + let offset = self.len(); + let scores = self.fetcher.fetch_scores(&self.client, offset).await?; + if scores.len() < T::SCORES_PER_PAGE { + self.more_exists = false; + } + if scores.is_empty() { + return Ok(false); + } + self.scores.extend(scores); + Ok(true) + } + fn len(&self) -> usize { + self.scores.len() + } +}