use crate::{
2 ratelimit::{CacheableResponse, headers},
3 retry::RetryExt,
4};
5use dashmap::DashMap;
6use governor::{
7 Quota, RateLimiter,
8 clock::DefaultClock,
9 state::{InMemoryState, NotKeyed},
10};
11use http::StatusCode;
12use humantime_serde::re::humantime::format_duration;
13use log::warn;
14use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
15use std::{num::NonZeroU32, sync::Mutex};
16use std::{
17 sync::Arc,
18 time::{Duration, Instant},
19};
20use tokio::sync::Semaphore;
21
22use super::key::HostKey;
23use super::stats::HostStats;
24use crate::Uri;
25use crate::types::Result;
26use crate::{
27 ErrorKind,
28 ratelimit::{HostConfig, RateLimitConfig},
29};
30
/// Upper bound on any backoff duration a server can request (e.g. via a
/// `Retry-After` header), so a misbehaving host cannot stall checking
/// indefinitely.
const MAXIMUM_BACKOFF: Duration = Duration::from_secs(60);

/// Per-host response cache, keyed by URI (fragments are stripped before
/// lookup/insertion — see `execute_request`).
type HostCache = DashMap<Uri, CacheableResponse>;
36
/// Per-host state used to throttle, deduplicate, and cache requests against
/// a single host.
#[derive(Debug)]
pub struct Host {
    /// Identifier of the host this state belongs to.
    pub key: HostKey,

    /// Enforces a minimum interval between requests. `None` when
    /// `Quota::with_period` rejects the configured interval (see `new`).
    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,

    /// Caps the number of concurrently in-flight requests to this host.
    semaphore: Semaphore,

    /// HTTP client used for all requests to this host.
    client: ReqwestClient,

    /// Aggregated response/cache statistics. A sync `Mutex` is fine here:
    /// it is only locked for short, non-async critical sections.
    stats: Mutex<HostStats>,

    /// Current adaptive backoff delay applied before each request; grows on
    /// 429/5xx responses and resets on 2xx.
    backoff_duration: Mutex<Duration>,

    /// Completed responses cached per (fragment-less) URI.
    cache: HostCache,

    /// Per-URI mutexes that serialize identical in-flight requests so the
    /// same URI is never fetched concurrently.
    /// NOTE(review): entries are never removed, so this map grows with the
    /// number of distinct URIs seen — confirm this is acceptable for the
    /// lifetime of a run.
    active_requests: DashMap<Uri, Arc<tokio::sync::Mutex<()>>>,
}
73
impl Host {
    /// Builds per-host throttling state from the host-specific config,
    /// falling back to the global config for unset values.
    #[must_use]
    pub fn new(
        key: HostKey,
        host_config: &HostConfig,
        global_config: &RateLimitConfig,
        client: ReqwestClient,
    ) -> Self {
        // Allow at most one request per interval tick — no bursting.
        const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
        let interval = host_config.effective_request_interval(global_config);
        // `Quota::with_period` returns `None` for a zero interval, in which
        // case no per-request rate limiter is installed at all.
        let rate_limiter =
            Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));

        let max_concurrent = host_config.effective_concurrency(global_config);
        let semaphore = Semaphore::new(max_concurrent);

        Host {
            key,
            rate_limiter,
            semaphore,
            client,
            stats: Mutex::new(HostStats::default()),
            backoff_duration: Mutex::new(Duration::from_millis(0)),
            cache: DashMap::new(),
            active_requests: DashMap::new(),
        }
    }

    /// Looks up a cached response for `uri`.
    ///
    /// When the caller needs the response body (`needs_body`), only cache
    /// entries that retained their text qualify; a status-only entry is
    /// treated as a miss so the body can be re-fetched.
    fn get_cached_status(&self, uri: &Uri, needs_body: bool) -> Option<CacheableResponse> {
        let cached = self.cache.get(uri)?.clone();
        if needs_body {
            if cached.text.is_some() {
                Some(cached)
            } else {
                None
            }
        } else {
            Some(cached)
        }
    }

    /// Records a cache hit in this host's statistics.
    fn record_cache_hit(&self) {
        self.stats.lock().unwrap().record_cache_hit();
    }

    /// Records a cache miss in this host's statistics.
    fn record_cache_miss(&self) {
        self.stats.lock().unwrap().record_cache_miss();
    }

    /// Caches `response` for `uri`, but only when the status is final:
    /// retryable statuses are deliberately not cached so a later attempt
    /// can succeed.
    fn cache_result(&self, uri: &Uri, response: CacheableResponse) {
        if !response.status.should_retry() {
            self.cache.insert(uri.clone(), response);
        }
    }

    /// Executes `request` subject to this host's throttling pipeline.
    ///
    /// Ordering is load-bearing:
    /// 1. take the per-URI lock so duplicate in-flight requests wait and
    ///    then hit the cache instead of re-fetching;
    /// 2. check the cache (fragment stripped from the URL first, so
    ///    `page#a` and `page#b` share one entry);
    /// 3. acquire a concurrency permit;
    /// 4. sleep out any adaptive backoff;
    /// 5. wait for the rate limiter, then perform the request.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::NetworkRequest` when the underlying request fails.
    pub(crate) async fn execute_request(
        &self,
        request: Request,
        needs_body: bool,
    ) -> Result<CacheableResponse> {
        let mut url = request.url().clone();
        url.set_fragment(None);
        let uri = Uri::from(url);
        // Held until this function returns; serializes same-URI requests.
        let _uri_guard = self.lock_uri_mutex(uri.clone()).await;

        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
            self.record_cache_hit();
            return Ok(cached);
        }

        self.record_cache_miss();
        let _permit = self.acquire_semaphore().await;

        self.await_backoff().await;

        if let Some(rate_limiter) = &self.rate_limiter {
            rate_limiter.until_ready().await;
        }

        self.perform_request(request, uri, needs_body).await
    }

    /// Returns the HTTP client associated with this host.
    pub(crate) const fn get_client(&self) -> &ReqwestClient {
        &self.client
    }

    /// Sends the request, then updates stats, adaptive backoff, and the
    /// cache from the response.
    async fn perform_request(
        &self,
        request: Request,
        uri: Uri,
        needs_body: bool,
    ) -> Result<CacheableResponse> {
        let start_time = Instant::now();
        let response = match self.client.execute(request).await {
            Ok(response) => response,
            Err(e) => {
                return Err(ErrorKind::NetworkRequest(e));
            }
        };

        self.update_stats(response.status(), start_time.elapsed());
        self.update_backoff(response.status());
        self.handle_rate_limit_headers(&response);

        let response = CacheableResponse::from_response(response, needs_body).await?;
        self.cache_result(&uri, response.clone());
        Ok(response)
    }

    /// Sleeps for the current backoff duration, if any.
    async fn await_backoff(&self) {
        // Copy the duration out first so the sync mutex guard is dropped
        // before the `.await` below.
        let backoff_duration = {
            let backoff = self.backoff_duration.lock().unwrap();
            *backoff
        };
        if !backoff_duration.is_zero() {
            log::debug!(
                "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
                self.key,
                backoff_duration.as_millis()
            );
            tokio::time::sleep(backoff_duration).await;
        }
    }

    /// Returns an owned guard on the per-URI mutex, creating the mutex on
    /// first use. `lock_owned` is used so the guard is not tied to the
    /// `Arc` borrow's lifetime.
    async fn lock_uri_mutex(&self, uri: Uri) -> tokio::sync::OwnedMutexGuard<()> {
        let uri_mutex = self
            .active_requests
            .entry(uri)
            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
            .clone();

        uri_mutex.lock_owned().await
    }

    /// Waits for a concurrency permit.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore has been closed, which never happens here:
    /// this type never calls `close()`.
    async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
        self.semaphore
            .acquire()
            .await
            .expect("Semaphore was closed unexpectedly")
    }

    /// Adjusts the adaptive backoff based on the response status:
    /// 2xx resets it, 429 doubles it (500ms seed, 30s cap), and 5xx adds a
    /// small linear penalty (200ms step, 10s cap).
    fn update_backoff(&self, status: StatusCode) {
        let mut backoff = self.backoff_duration.lock().unwrap();
        match status.as_u16() {
            200..=299 => {
                *backoff = Duration::from_millis(0);
            }
            429 => {
                let new_backoff = std::cmp::min(
                    if backoff.is_zero() {
                        Duration::from_millis(500)
                    } else {
                        *backoff * 2
                    },
                    Duration::from_secs(30),
                );
                log::debug!(
                    "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
                    self.key,
                    backoff.as_millis(),
                    new_backoff.as_millis()
                );
                *backoff = new_backoff;
            }
            500..=599 => {
                *backoff = std::cmp::min(
                    *backoff + Duration::from_millis(200),
                    Duration::from_secs(10),
                );
            }
            // 1xx/3xx/4xx (other than 429) leave the backoff unchanged.
            _ => {}
        }
    }

    /// Records a response's status and elapsed time in the statistics.
    fn update_stats(&self, status: StatusCode, request_time: Duration) {
        self.stats
            .lock()
            .unwrap()
            .record_response(status.as_u16(), request_time);
    }

    /// Inspects rate-limit-related response headers and raises the backoff
    /// accordingly.
    fn handle_rate_limit_headers(&self, response: &ReqwestResponse) {
        let headers = response.headers();
        self.handle_retry_after_header(headers);
        self.handle_common_rate_limit_header_fields(headers);
    }

    /// Applies a small proactive backoff when the server-reported quota
    /// usage exceeds 80%, scaling linearly up to 200ms at full usage.
    fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
        if let (Some(remaining), Some(limit)) =
            headers::parse_common_rate_limit_header_fields(headers)
            && limit > 0
        {
            #[allow(clippy::cast_precision_loss)]
            let usage_ratio = limit.saturating_sub(remaining) as f64 / limit as f64;

            if usage_ratio > 0.8 {
                // Maps (0.8, 1.0] linearly onto (0ms, 200ms].
                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
                self.increase_backoff(duration);
            }
        }
    }

    /// Honors a `Retry-After` response header (RFC 7231) by raising the
    /// backoff; unparseable values are logged and ignored.
    fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
        if let Some(retry_after_value) = headers.get("retry-after") {
            let duration = match headers::parse_retry_after(retry_after_value) {
                Ok(e) => e,
                Err(e) => {
                    warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
                    return;
                }
            };

            self.increase_backoff(duration);
        }
    }

    /// Raises the backoff to at least `increased_backoff` (never lowers
    /// it), capping server-requested values at `MAXIMUM_BACKOFF`.
    fn increase_backoff(&self, mut increased_backoff: Duration) {
        if increased_backoff > MAXIMUM_BACKOFF {
            warn!(
                "Host {} sent an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
                self.key,
                format_duration(increased_backoff),
                format_duration(MAXIMUM_BACKOFF)
            );
            increased_backoff = MAXIMUM_BACKOFF;
        }

        let mut backoff = self.backoff_duration.lock().unwrap();
        *backoff = std::cmp::max(*backoff, increased_backoff);
    }

    /// Returns a snapshot (clone) of this host's statistics.
    pub fn stats(&self) -> HostStats {
        self.stats.lock().unwrap().clone()
    }

    /// Counts a hit served from the persistent (on-disk) cache in this
    /// host's statistics.
    pub(crate) fn record_persistent_cache_hit(&self) {
        self.record_cache_hit();
    }

    /// Returns the number of cached responses for this host.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }
}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::{HostConfig, RateLimitConfig};
    use reqwest::Client;

    /// A freshly built `Host` starts empty: it carries over its key,
    /// exposes the full concurrency budget, holds no cached entries, and
    /// reports a perfect success rate (nothing recorded yet).
    #[tokio::test]
    async fn test_host_creation() {
        let host_key = HostKey::from("example.com");
        let per_host = HostConfig::default();
        let global = RateLimitConfig::default();

        let host = Host::new(host_key.clone(), &per_host, &global, Client::default());

        assert_eq!(host.key, host_key);
        assert_eq!(host.cache_size(), 0);
        // The default configuration grants 10 concurrent permits.
        assert_eq!(host.semaphore.available_permits(), 10);
        // With no responses recorded, the success rate is exactly 1.0.
        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
    }
}