diff --git a/youmubot-cf/src/hook.rs b/youmubot-cf/src/hook.rs index 33ac2af..26504c8 100644 --- a/youmubot-cf/src/hook.rs +++ b/youmubot-cf/src/hook.rs @@ -1,5 +1,7 @@ +use crate::CFClient; use chrono::{TimeZone, Utc}; -use codeforces::{Contest, Problem}; +use codeforces::{Client, Contest, Problem}; +use dashmap::DashMap as HashMap; use lazy_static::lazy_static; use rayon::{iter::Either, prelude::*}; use regex::{Captures, Regex}; @@ -7,7 +9,7 @@ use serenity::{ builder::CreateEmbed, framework::standard::CommandError, model::channel::Message, utils::MessageBuilder, }; -use std::{collections::HashMap, sync::Arc}; +use std::{sync::Arc, time::Instant}; use youmubot_prelude::*; lazy_static! { @@ -27,95 +29,109 @@ enum ContestOrProblem { } /// Caches the contest list. -#[derive(Clone, Debug, Default)] -pub struct ContestCache(Arc>)>>>); +pub struct ContestCache { + contests: HashMap>)>, + all_list: RwLock<(Vec, Instant)>, + http: Arc, +} impl TypeMapKey for ContestCache { type Value = ContestCache; } impl ContestCache { - fn get( - &self, - http: &::Value, - contest_id: u64, - ) -> Result<(Contest, Option>), CommandError> { - let rl = self.0.read(); - match rl.get(&contest_id) { - Some(r @ (_, Some(_))) => Ok(r.clone()), - Some((c, None)) => match Contest::standings(http, contest_id, |f| f.limit(1, 1)) { - Ok((c, p, _)) => Ok({ - drop(rl); - let mut v = self.0.write(); - let v = v.entry(contest_id).or_insert((c, None)); - v.1 = Some(p); - v.clone() - }), - Err(_) => Ok((c.clone(), None)), - }, - None => { - drop(rl); - // Step 1: try to fetch it individually - match Contest::standings(http, contest_id, |f| f.limit(1, 1)) { - Ok((c, p, _)) => Ok(self - .0 - .write() - .entry(contest_id) - .or_insert((c, Some(p))) - .clone()), - Err(codeforces::Error::Codeforces(s)) if s.ends_with("has not started") => { - // Fetch the entire list - { - let mut m = self.0.write(); - let contests = Contest::list(http, contest_id > 100_000)?; - contests.into_iter().for_each(|c| { - m.entry(c.id).or_insert((c, None)); - }); - } - self.0 - .read() - .get(&contest_id) - .cloned() - .ok_or("No contest found".into()) - } - Err(e) => Err(e.into()), - } - // Step 2: try to fetch the entire list. + /// Creates a new, empty cache. + pub async fn new(http: Arc) -> Result { + let contests_list = Contest::list(&*http, true).await?; + Ok(Self { + contests: HashMap::new(), + all_list: RwLock::new((contests_list, Instant::now())), + http, + }) + } + + /// Gets a contest from the cache, fetching from upstream if possible. + pub async fn get(&self, contest_id: u64) -> Result<(Contest, Option>)> { + if let Some(v) = self.contests.get(&contest_id) { + if v.1.is_some() { + return Ok(v.clone()); } } + self.get_and_store_contest(contest_id).await + } + + async fn get_and_store_contest( + &self, + contest_id: u64, + ) -> Result<(Contest, Option>)> { + let (c, p) = match Contest::standings(&*self.http, contest_id, |f| f.limit(1, 1)).await { + Ok((c, p, _)) => (c, Some(p)), + Err(codeforces::Error::Codeforces(s)) if s.ends_with("has not started") => { + let c = self.get_from_list(contest_id).await?; + (c, None) + } + Err(v) => return Err(Error::from(v)), + }; + self.contests.insert(contest_id, (c, p)); + Ok(self.contests.get(&contest_id).unwrap().clone()) + } + + async fn get_from_list(&self, contest_id: u64) -> Result { + let last_updated = self.all_list.read().await.1.clone(); + if Instant::now() - last_updated > std::time::Duration::from_secs(60 * 60) { + // We update at most once an hour. + *self.all_list.write().await = + (Contest::list(&*self.http, true).await?, Instant::now()); + } + self.all_list + .read() + .await + .0 + .iter() + .find(|v| v.id == contest_id) + .cloned() + .ok_or_else(|| Error::msg("Contest not found")) } } /// Prints info whenever a problem or contest (or more) is sent on a channel. -pub fn codeforces_info_hook(ctx: &mut Context, m: &Message) { - if m.author.bot { - return; +pub struct InfoHook; + +#[async_trait] +impl Hook for InfoHook { + async fn call(&mut self, ctx: &Context, m: &Message) -> Result<()> { + if m.author.bot { + return Ok(()); + } + let data = ctx.data.read().await; + let http = data.get::().unwrap(); + let contest_cache = data.get::().unwrap(); + let matches = parse(&m.content[..], contest_cache) + .collect::>() + .await; + if !matches.is_empty() { + m.channel_id + .send_message(&ctx, |c| { + c.content("Here are the info of the given Codeforces links!") + .embed(|e| print_info_message(&matches[..], e)) + }) + .await?; + } + Ok(()) } - let http = ctx.data.get_cloned::(); - let contest_cache = ctx.data.get_cloned::(); +} + +fn parse<'a>( + content: &'a str, + contest_cache: &'a ContestCache, +) -> impl stream::Stream + 'a { let matches = CONTEST_LINK - .captures_iter(&m.content) - .chain(PROBLEMSET_LINK.captures_iter(&m.content)) - // .collect::>() - // .into_par_iter() - .filter_map( - |v| match parse_capture(http.clone(), contest_cache.clone(), v) { - Ok(v) => Some(v), - Err(e) => { - dbg!(e); - None - } - }, - ) - .collect::>(); - if !matches.is_empty() { - m.channel_id - .send_message(&ctx, |c| { - c.content("Here are the info of the given Codeforces links!") - .embed(|e| print_info_message(&matches[..], e)) - }) - .ok(); - } + .captures_iter(content) + .chain(PROBLEMSET_LINK.captures_iter(content)) + .map(|v| parse_capture(contest_cache, v)) + .collect::>() + .filter_map(|v| future::ready(v.ok())); + matches } fn print_info_message<'a>( @@ -190,9 +206,8 @@ fn print_info_message<'a>( e.description(m.build()) } -fn parse_capture<'a>( - http: ::Value, - contest_cache: ContestCache, +async fn parse_capture<'a>( + contest_cache: &ContestCache, cap: Captures<'a>, ) -> Result<(ContestOrProblem, &'a str), CommandError> { let contest_id: u64 = cap @@ -200,7 +215,7 @@ fn parse_capture<'a>( .ok_or(CommandError::from("Contest not captured"))? .as_str() .parse()?; - let (contest, problems) = contest_cache.get(&http, contest_id)?; + let (contest, problems) = contest_cache.get(contest_id).await?; match cap.name("problem") { Some(p) => { for problem in problems.ok_or(CommandError::from("Contest hasn't started"))? { diff --git a/youmubot-cf/src/lib.rs b/youmubot-cf/src/lib.rs index 8fd487d..7b441d7 100644 --- a/youmubot-cf/src/lib.rs +++ b/youmubot-cf/src/lib.rs @@ -7,7 +7,7 @@ use serenity::{ model::channel::Message, utils::MessageBuilder, }; -use std::{collections::HashMap, time::Duration}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use youmubot_prelude::*; mod announcer; @@ -18,15 +18,25 @@ mod hook; /// Live-commentating a Codeforces round. mod live; +/// The TypeMapKey holding the Client. +struct CFClient; + +impl TypeMapKey for CFClient { + type Value = Arc; +} + use db::{CfSavedUsers, CfUser}; pub use hook::codeforces_info_hook; /// Sets up the CF databases. -pub fn setup(path: &std::path::Path, data: &mut ShareMap, announcers: &mut AnnouncerHandler) { +pub async fn setup(path: &std::path::Path, data: &mut TypeMap, announcers: &mut AnnouncerHandler) { CfSavedUsers::insert_into(data, path.join("cf_saved_users.yaml")) .expect("Must be able to set up DB"); - data.insert::(hook::ContestCache::default()); + let http = data.get::().unwrap(); + let client = Arc::new(codeforces::Client::new(http.clone())); + data.insert::(hook::ContestCache::new(client.clone()).await.unwrap()); + data.insert::(client); announcers.add("codeforces", announcer::updates); }