// lychee_lib/ratelimit/host/host.rs

1use crate::{
2    ratelimit::{CacheableResponse, headers},
3    retry::RetryExt,
4};
5use dashmap::DashMap;
6use governor::{
7    Quota, RateLimiter,
8    clock::DefaultClock,
9    state::{InMemoryState, NotKeyed},
10};
11use http::StatusCode;
12use humantime_serde::re::humantime::format_duration;
13use log::warn;
14use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
15use std::{num::NonZeroU32, sync::Mutex};
16use std::{
17    sync::Arc,
18    time::{Duration, Instant},
19};
20use tokio::sync::Semaphore;
21
22use super::key::HostKey;
23use super::stats::HostStats;
24use crate::Uri;
25use crate::types::Result;
26use crate::{
27    ErrorKind,
28    ratelimit::{HostConfig, RateLimitConfig},
29};
30
/// Upper bound applied to any backoff duration requested through response
/// headers (see `Host::increase_backoff`), so a host cannot stall checking
/// for longer than one minute per request.
const MAXIMUM_BACKOFF: Duration = Duration::from_secs(60);
33
34/// Per-host cache for storing request results
35type HostCache = DashMap<Uri, CacheableResponse>;
36
37/// Represents a single host with its own rate limiting, concurrency control,
38/// HTTP client configuration, and request cache.
39///
40/// Each host maintains:
41/// - A token bucket rate limiter using governor
42/// - A semaphore for concurrency control
43/// - A dedicated HTTP client with host-specific headers and cookies
44/// - Statistics tracking for adaptive behavior
45/// - A per-host cache to prevent duplicate requests
46#[derive(Debug)]
47pub struct Host {
48    /// The hostname this instance manages
49    pub key: HostKey,
50
51    /// Rate limiter using token bucket algorithm
52    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
53
54    /// Controls maximum concurrent requests to this host
55    semaphore: Semaphore,
56
57    /// HTTP client configured for this specific host
58    client: ReqwestClient,
59
60    /// Request statistics and adaptive behavior tracking
61    stats: Mutex<HostStats>,
62
63    /// Current backoff duration for adaptive rate limiting
64    backoff_duration: Mutex<Duration>,
65
66    /// Per-host cache to prevent duplicate requests during a single link check invocation.
67    /// Note that this cache has no direct relation to the inter-process persistable [`crate::CacheStatus`].
68    cache: HostCache,
69
70    /// Keep track of currently active requests, to prevent duplicate concurrent requests
71    active_requests: DashMap<Uri, Arc<tokio::sync::Mutex<()>>>,
72}
73
74impl Host {
75    /// Create a new Host instance for the given hostname
76    #[must_use]
77    pub fn new(
78        key: HostKey,
79        host_config: &HostConfig,
80        global_config: &RateLimitConfig,
81        client: ReqwestClient,
82    ) -> Self {
83        const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
84        let interval = host_config.effective_request_interval(global_config);
85        let rate_limiter =
86            Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));
87
88        // Create semaphore for concurrency control
89        let max_concurrent = host_config.effective_concurrency(global_config);
90        let semaphore = Semaphore::new(max_concurrent);
91
92        Host {
93            key,
94            rate_limiter,
95            semaphore,
96            client,
97            stats: Mutex::new(HostStats::default()),
98            backoff_duration: Mutex::new(Duration::from_millis(0)),
99            cache: DashMap::new(),
100            active_requests: DashMap::new(),
101        }
102    }
103
104    /// Check if a URI is cached and returns the cached response if it is valid
105    /// and satisfies the `needs_body` requirement.
106    fn get_cached_status(&self, uri: &Uri, needs_body: bool) -> Option<CacheableResponse> {
107        let cached = self.cache.get(uri)?.clone();
108        if needs_body {
109            if cached.text.is_some() {
110                Some(cached)
111            } else {
112                None
113            }
114        } else {
115            Some(cached)
116        }
117    }
118
119    fn record_cache_hit(&self) {
120        self.stats.lock().unwrap().record_cache_hit();
121    }
122
123    fn record_cache_miss(&self) {
124        self.stats.lock().unwrap().record_cache_miss();
125    }
126
127    /// Cache a request result
128    fn cache_result(&self, uri: &Uri, response: CacheableResponse) {
129        // Do not cache responses that are potentially retried
130        if !response.status.should_retry() {
131            self.cache.insert(uri.clone(), response);
132        }
133    }
134
135    /// Execute a request with rate limiting, concurrency control, and caching
136    ///
137    /// # Errors
138    ///
139    /// Returns an error if the request fails or rate limiting is exceeded
140    ///
141    /// # Panics
142    ///
143    /// Panics if the statistics mutex is poisoned
144    pub(crate) async fn execute_request(
145        &self,
146        request: Request,
147        needs_body: bool,
148    ) -> Result<CacheableResponse> {
149        let mut url = request.url().clone();
150        url.set_fragment(None);
151        let uri = Uri::from(url);
152        let _uri_guard = self.lock_uri_mutex(uri.clone()).await;
153
154        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
155            self.record_cache_hit();
156            return Ok(cached);
157        }
158
159        self.record_cache_miss();
160        let _permit = self.acquire_semaphore().await;
161
162        self.await_backoff().await;
163
164        if let Some(rate_limiter) = &self.rate_limiter {
165            rate_limiter.until_ready().await;
166        }
167
168        self.perform_request(request, uri, needs_body).await
169    }
170
171    pub(crate) const fn get_client(&self) -> &ReqwestClient {
172        &self.client
173    }
174
175    async fn perform_request(
176        &self,
177        request: Request,
178        uri: Uri,
179        needs_body: bool,
180    ) -> Result<CacheableResponse> {
181        let start_time = Instant::now();
182        let response = match self.client.execute(request).await {
183            Ok(response) => response,
184            Err(e) => {
185                // Wrap network/HTTP errors to preserve the original error
186                return Err(ErrorKind::NetworkRequest(e));
187            }
188        };
189
190        self.update_stats(response.status(), start_time.elapsed());
191        self.update_backoff(response.status());
192        self.handle_rate_limit_headers(&response);
193
194        let response = CacheableResponse::from_response(response, needs_body).await?;
195        self.cache_result(&uri, response.clone());
196        Ok(response)
197    }
198
199    /// Await adaptive backoff if needed
200    async fn await_backoff(&self) {
201        let backoff_duration = {
202            let backoff = self.backoff_duration.lock().unwrap();
203            *backoff
204        };
205        if !backoff_duration.is_zero() {
206            log::debug!(
207                "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
208                self.key,
209                backoff_duration.as_millis()
210            );
211            tokio::time::sleep(backoff_duration).await;
212        }
213    }
214
215    /// Get a [`tokio::sync::OwnedMutexGuard<()>`]
216    /// to prevent concurrent requests to identical [`Uri`]s.
217    async fn lock_uri_mutex(&self, uri: Uri) -> tokio::sync::OwnedMutexGuard<()> {
218        let uri_mutex = self
219            .active_requests
220            .entry(uri)
221            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
222            .clone();
223
224        uri_mutex.lock_owned().await
225    }
226
227    /// Enforce the maximum concurrency of this host
228    async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
229        self.semaphore
230            .acquire()
231            .await
232            // SAFETY: this should not panic as we never close the semaphore
233            .expect("Semaphore was closed unexpectedly")
234    }
235
236    fn update_backoff(&self, status: StatusCode) {
237        let mut backoff = self.backoff_duration.lock().unwrap();
238        match status.as_u16() {
239            200..=299 => {
240                // Reset backoff on success
241                *backoff = Duration::from_millis(0);
242            }
243            429 => {
244                // Exponential backoff on rate limit, capped at 30 seconds
245                let new_backoff = std::cmp::min(
246                    if backoff.is_zero() {
247                        Duration::from_millis(500)
248                    } else {
249                        *backoff * 2
250                    },
251                    Duration::from_secs(30),
252                );
253                log::debug!(
254                    "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
255                    self.key,
256                    backoff.as_millis(),
257                    new_backoff.as_millis()
258                );
259                *backoff = new_backoff;
260            }
261            500..=599 => {
262                // Moderate backoff increase on server errors, capped at 10 seconds
263                *backoff = std::cmp::min(
264                    *backoff + Duration::from_millis(200),
265                    Duration::from_secs(10),
266                );
267            }
268            _ => {} // No backoff change for other status codes
269        }
270    }
271
272    fn update_stats(&self, status: StatusCode, request_time: Duration) {
273        self.stats
274            .lock()
275            .unwrap()
276            .record_response(status.as_u16(), request_time);
277    }
278
279    /// Parse rate limit headers from response and adjust behavior
280    fn handle_rate_limit_headers(&self, response: &ReqwestResponse) {
281        // Implement basic parsing here rather than using the rate-limits crate to keep dependencies minimal
282        let headers = response.headers();
283        self.handle_retry_after_header(headers);
284        self.handle_common_rate_limit_header_fields(headers);
285    }
286
287    /// Handle the common "X-RateLimit" header fields.
288    fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
289        if let (Some(remaining), Some(limit)) =
290            headers::parse_common_rate_limit_header_fields(headers)
291            && limit > 0
292        {
293            #[allow(clippy::cast_precision_loss)]
294            let usage_ratio = limit.saturating_sub(remaining) as f64 / limit as f64;
295
296            // If we've used more than 80% of our quota, apply preventive backoff
297            if usage_ratio > 0.8 {
298                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
299                let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
300                self.increase_backoff(duration);
301            }
302        }
303    }
304
305    /// Handle the "Retry-After" header
306    fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
307        if let Some(retry_after_value) = headers.get("retry-after") {
308            let duration = match headers::parse_retry_after(retry_after_value) {
309                Ok(e) => e,
310                Err(e) => {
311                    warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
312                    return;
313                }
314            };
315
316            self.increase_backoff(duration);
317        }
318    }
319
320    fn increase_backoff(&self, mut increased_backoff: Duration) {
321        if increased_backoff > MAXIMUM_BACKOFF {
322            warn!(
323                "Host {} sent an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
324                self.key,
325                format_duration(increased_backoff),
326                format_duration(MAXIMUM_BACKOFF)
327            );
328            increased_backoff = MAXIMUM_BACKOFF;
329        }
330
331        let mut backoff = self.backoff_duration.lock().unwrap();
332        *backoff = std::cmp::max(*backoff, increased_backoff);
333    }
334
335    /// Get host statistics
336    ///
337    /// # Panics
338    ///
339    /// Panics if the statistics mutex is poisoned
340    pub fn stats(&self) -> HostStats {
341        self.stats.lock().unwrap().clone()
342    }
343
344    /// Record a cache hit from the persistent disk cache.
345    /// Cache misses are tracked internally, so we don't expose such a method.
346    pub(crate) fn record_persistent_cache_hit(&self) {
347        self.record_cache_hit();
348    }
349
350    /// Get the current cache size (number of cached entries)
351    pub fn cache_size(&self) -> usize {
352        self.cache.len()
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::{HostConfig, RateLimitConfig};
    use reqwest::Client;

    /// A host built from default configuration keeps its key, exposes ten
    /// free permits, reports a success rate of 1.0 and has an empty cache.
    #[tokio::test]
    async fn test_host_creation() {
        let key = HostKey::from("example.com");
        let host = Host::new(
            key.clone(),
            &HostConfig::default(),
            &RateLimitConfig::default(),
            Client::default(),
        );

        assert_eq!(host.key, key);
        assert_eq!(host.semaphore.available_permits(), 10); // Default concurrency
        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
        assert_eq!(host.cache_size(), 0);
    }
}
375}