1#![allow(
11 clippy::module_name_repetitions,
12 clippy::struct_excessive_bools,
13 clippy::default_trait_access,
14 clippy::used_underscore_binding
15)]
16use std::{collections::HashSet, sync::Arc, time::Duration};
17
18use http::{
19 StatusCode,
20 header::{HeaderMap, HeaderValue},
21};
22use log::debug;
23use octocrab::Octocrab;
24use regex::RegexSet;
25use reqwest::{header, redirect, tls};
26use reqwest_cookie_store::CookieStoreMutex;
27use secrecy::{ExposeSecret, SecretString};
28use typed_builder::TypedBuilder;
29
30use crate::{
31 BaseInfo, BasicAuthCredentials, ErrorKind, Request, Response, Result, Status, Uri,
32 chain::RequestChain,
33 checker::{file::FileChecker, mail::MailChecker, website::WebsiteChecker},
34 filter::Filter,
35 ratelimit::{ClientMap, HostConfigs, HostKey, HostPool, RateLimitConfig},
36 remap::Remaps,
37 types::{DEFAULT_ACCEPTED_STATUS_CODES, redirect_history::RedirectHistory},
38};
39
/// Default maximum number of redirects, 5.
///
/// Used as the builder default for `ClientBuilder::max_redirects`.
pub const DEFAULT_MAX_REDIRECTS: usize = 5;
/// Default number of retries, 3.
///
/// Used as the builder default for `ClientBuilder::max_retries`.
pub const DEFAULT_MAX_RETRIES: u64 = 3;
/// Default wait time between retries in seconds, 1.
///
/// Used as the builder default for `ClientBuilder::retry_wait_time`.
pub const DEFAULT_RETRY_WAIT_TIME_SECS: u64 = 1;
/// Default timeout in seconds, 20.
///
/// NOTE(review): not referenced by the builder in this file (its `timeout`
/// field defaults to `None`); presumably applied by callers — confirm.
pub const DEFAULT_TIMEOUT_SECS: u64 = 20;
/// Default user agent: `lychee/` followed by the crate version.
pub const DEFAULT_USER_AGENT: &str = concat!("lychee/", env!("CARGO_PKG_VERSION"));

/// Connect timeout passed to every `reqwest` client built by `ClientBuilder`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// TCP keepalive duration passed to every `reqwest` client built by
/// `ClientBuilder`.
const TCP_KEEPALIVE: Duration = Duration::from_secs(60);
59
60#[derive(TypedBuilder, Debug, Clone)]
64#[builder(field_defaults(default, setter(into)))]
65pub struct ClientBuilder {
66 github_token: Option<SecretString>,
75
76 remaps: Option<Remaps>,
91
92 fallback_extensions: Vec<String>,
96
97 #[builder(default = None)]
111 index_files: Option<Vec<String>>,
112
113 includes: Option<RegexSet>,
119
120 excludes: Option<RegexSet>,
123
124 exclude_all_private: bool,
131
132 exclude_private_ips: bool,
159
160 exclude_link_local_ips: bool,
180
181 exclude_loopback_ips: bool,
197
198 include_mail: bool,
200
201 #[builder(default = DEFAULT_MAX_REDIRECTS)]
205 max_redirects: usize,
206
207 #[builder(default = DEFAULT_MAX_RETRIES)]
211 max_retries: u64,
212
213 min_tls_version: Option<tls::Version>,
215
216 #[builder(default_code = "String::from(DEFAULT_USER_AGENT)")]
226 user_agent: String,
227
228 allow_insecure: bool,
238
239 schemes: HashSet<String>,
244
245 custom_headers: HeaderMap,
253
254 #[builder(default = reqwest::Method::GET)]
256 method: reqwest::Method,
257
258 #[builder(default = DEFAULT_ACCEPTED_STATUS_CODES.clone())]
262 accepted: HashSet<StatusCode>,
263
264 timeout: Option<Duration>,
266
267 base: BaseInfo,
272
273 #[builder(default_code = "Duration::from_secs(DEFAULT_RETRY_WAIT_TIME_SECS as u64)")]
286 retry_wait_time: Duration,
287
288 require_https: bool,
294
295 cookie_jar: Option<Arc<CookieStoreMutex>>,
299
300 include_fragments: bool,
302
303 include_wikilinks: bool,
306
307 plugin_request_chain: RequestChain,
312
313 rate_limit_config: RateLimitConfig,
315
316 hosts: HostConfigs,
318}
319
320impl Default for ClientBuilder {
321 #[inline]
322 fn default() -> Self {
323 Self::builder().build()
324 }
325}
326
327impl ClientBuilder {
328 pub fn client(self) -> Result<Client> {
343 let redirect_history = RedirectHistory::new();
344 let reqwest_client = self
345 .build_client(&redirect_history)?
346 .build()
347 .map_err(ErrorKind::BuildRequestClient)?;
348
349 let client_map = self.build_host_clients(&redirect_history)?;
350
351 let host_pool = HostPool::new(
352 self.rate_limit_config,
353 self.hosts,
354 reqwest_client,
355 client_map,
356 );
357
358 let github_client = match self.github_token.as_ref().map(ExposeSecret::expose_secret) {
359 Some(token) if !token.is_empty() => Some(
360 Octocrab::builder()
361 .personal_token(token.to_string())
362 .build()
363 .map_err(|e: octocrab::Error| ErrorKind::BuildGithubClient(Box::new(e)))?,
366 ),
367 _ => None,
368 };
369
370 let filter = Filter {
371 includes: self.includes.map(Into::into),
372 excludes: self.excludes.map(Into::into),
373 schemes: self.schemes,
374 exclude_private_ips: self.exclude_all_private || self.exclude_private_ips,
377 exclude_link_local_ips: self.exclude_all_private || self.exclude_link_local_ips,
378 exclude_loopback_ips: self.exclude_all_private || self.exclude_loopback_ips,
379 include_mail: self.include_mail,
380 };
381
382 let website_checker = WebsiteChecker::new(
383 self.method,
384 self.retry_wait_time,
385 redirect_history.clone(),
386 self.max_retries,
387 self.accepted,
388 github_client,
389 self.require_https,
390 self.plugin_request_chain,
391 self.include_fragments,
392 Arc::new(host_pool),
393 );
394
395 Ok(Client {
396 remaps: self.remaps,
397 filter,
398 email_checker: MailChecker::new(self.timeout),
399 website_checker,
400 file_checker: FileChecker::new(
401 &self.base,
402 self.fallback_extensions,
403 self.index_files,
404 self.include_fragments,
405 self.include_wikilinks,
406 )?,
407 })
408 }
409
410 fn build_host_clients(&self, redirect_history: &RedirectHistory) -> Result<ClientMap> {
412 self.hosts
413 .iter()
414 .map(|(host, config)| {
415 let mut headers = self.default_headers()?;
416 headers.extend(config.headers.clone());
417 let client = self
418 .build_client(redirect_history)?
419 .default_headers(headers)
420 .build()
421 .map_err(ErrorKind::BuildRequestClient)?;
422 Ok((HostKey::from(host.as_str()), client))
423 })
424 .collect()
425 }
426
427 fn build_client(&self, redirect_history: &RedirectHistory) -> Result<reqwest::ClientBuilder> {
429 let mut builder = reqwest::ClientBuilder::new()
430 .gzip(true)
431 .default_headers(self.default_headers()?)
432 .danger_accept_invalid_certs(self.allow_insecure)
433 .connect_timeout(CONNECT_TIMEOUT)
434 .tcp_keepalive(TCP_KEEPALIVE)
435 .redirect(redirect_policy(
436 redirect_history.clone(),
437 self.max_redirects,
438 ));
439
440 if let Some(cookie_jar) = self.cookie_jar.clone() {
441 builder = builder.cookie_provider(cookie_jar);
442 }
443
444 if let Some(min_tls) = self.min_tls_version {
445 builder = builder.min_tls_version(min_tls);
446 }
447
448 if let Some(timeout) = self.timeout {
449 builder = builder.timeout(timeout);
450 }
451
452 Ok(builder)
453 }
454
455 fn default_headers(&self) -> Result<HeaderMap> {
456 let user_agent = self.user_agent.clone();
457 let mut headers = self.custom_headers.clone();
458
459 if let Some(prev_user_agent) =
460 headers.insert(header::USER_AGENT, HeaderValue::try_from(&user_agent)?)
461 {
462 debug!(
463 "Found user-agent in headers: {}. Overriding it with {user_agent}.",
464 prev_user_agent.to_str().unwrap_or("�"),
465 );
466 }
467
468 headers.insert(
469 header::TRANSFER_ENCODING,
470 HeaderValue::from_static("chunked"),
471 );
472
473 Ok(headers)
474 }
475}
476
477fn redirect_policy(redirect_history: RedirectHistory, max_redirects: usize) -> redirect::Policy {
480 redirect::Policy::custom(move |attempt| {
481 if attempt.previous().len() > max_redirects {
482 attempt.stop()
483 } else {
484 redirect_history.record_redirects(&attempt);
485 debug!("Following redirect to {}", attempt.url());
486 attempt.follow()
487 }
488 })
489}
490
/// Link checker that remaps, filters and checks URIs.
///
/// Instances are created through [`ClientBuilder::client`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Optional remapping rules applied to a URI before anything else.
    remaps: Option<Remaps>,

    /// Decides whether a URI is excluded from checking.
    filter: Filter,

    /// Checker used for every URI that is not a `tel`, file or mail URI.
    website_checker: WebsiteChecker,

    /// Checker used for file URIs.
    file_checker: FileChecker,

    /// Checker used for mail URIs.
    email_checker: MailChecker,
}
512
513impl Client {
514 #[must_use]
516 pub fn host_pool(&self) -> Arc<HostPool> {
517 self.website_checker.host_pool()
518 }
519
520 #[allow(clippy::missing_panics_doc)]
532 pub async fn check<T, E>(&self, request: T) -> Result<Response>
533 where
534 Request: TryFrom<T, Error = E>,
535 ErrorKind: From<E>,
536 {
537 let Request {
538 ref mut uri,
539 credentials,
540 source,
541 span,
542 ..
543 } = request.try_into()?;
544
545 self.remap(uri)?;
546
547 if self.is_excluded(uri) {
548 return Ok(Response::new(
549 uri.clone(),
550 Status::Excluded,
551 source.into(),
552 span,
553 None,
554 ));
555 }
556
557 let start = std::time::Instant::now(); let status = match uri.scheme() {
560 _ if uri.is_tel() => Status::Excluded, _ if uri.is_file() => self.check_file(uri).await,
562 _ if uri.is_mail() => self.check_mail(uri).await,
563 _ => self.check_website(uri, credentials).await?,
564 };
565
566 Ok(Response::new(
567 uri.clone(),
568 status,
569 source.into(),
570 span,
571 Some(start.elapsed()),
572 ))
573 }
574
575 pub async fn check_file(&self, uri: &Uri) -> Status {
577 self.file_checker.check(uri).await
578 }
579
580 pub fn remap(&self, uri: &mut Uri) -> Result<()> {
586 if let Some(ref remaps) = self.remaps {
587 uri.url = remaps.remap(&uri.url)?;
588 }
589 Ok(())
590 }
591
592 #[must_use]
594 pub fn is_excluded(&self, uri: &Uri) -> bool {
595 self.filter.is_excluded(uri)
596 }
597
598 pub async fn check_website(
608 &self,
609 uri: &Uri,
610 credentials: Option<BasicAuthCredentials>,
611 ) -> Result<Status> {
612 self.website_checker.check_website(uri, credentials).await
613 }
614
615 pub async fn check_mail(&self, uri: &Uri) -> Status {
617 self.email_checker.check_mail(uri).await
618 }
619}
620
621pub async fn check<T, E>(request: T) -> Result<Response>
634where
635 Request: TryFrom<T, Error = E>,
636 ErrorKind: From<E>,
637{
638 let client = ClientBuilder::builder().build().client()?;
639 client.check(request).await
640}
641
// Tests for `ClientBuilder` and `Client`.
//
// Several of these tests talk to real hosts (github.com, youtube.com,
// badssl.com, crates.io, example.com, authenticationtest.com), the others use
// local mock servers.
#[cfg(test)]
mod tests {
    use std::{
        fs::File,
        time::{Duration, Instant},
    };

    use async_trait::async_trait;
    use http::{StatusCode, header::HeaderMap};
    use reqwest::header;
    use tempfile::tempdir;
    use test_utils::get_mock_client_response;
    use test_utils::mock_server;
    use test_utils::redirecting_mock_server;
    use wiremock::{
        Mock,
        matchers::{method, path},
    };

    use super::ClientBuilder;
    use crate::{
        ErrorKind, Redirect, Redirects, Request, Status, Uri,
        chain::{ChainResult, Handler, RequestChain},
    };

    // A mock server answering 404 must produce an error status.
    #[tokio::test]
    async fn test_nonexistent() {
        let mock_server = mock_server!(StatusCode::NOT_FOUND);
        let res = get_mock_client_response!(mock_server.uri()).await;

        assert!(res.status().is_error());
    }

    // A loopback URL with a path is expected to produce an error status.
    #[tokio::test]
    async fn test_nonexistent_with_path() {
        let res = get_mock_client_response!("http://127.0.0.1/invalid").await;
        assert!(res.status().is_error());
    }

    // An existing GitHub repository is reported as success.
    #[tokio::test]
    async fn test_github() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/lychee").await;
        assert!(res.status().is_success());
    }

    // A repository URL expected not to exist is reported as an error.
    #[tokio::test]
    async fn test_github_nonexistent_repo() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/not-lychee").await;
        assert!(res.status().is_error());
    }

    // A file URL expected not to exist inside an existing repository is
    // reported as an error.
    #[tokio::test]
    async fn test_github_nonexistent_file() {
        let res = get_mock_client_response!(
            "https://github.com/lycheeverse/lychee/blob/master/NON_EXISTENT_FILE.md",
        )
        .await;
        assert!(res.status().is_error());
    }

    // A valid video URL succeeds, a URL with an invalid video id fails.
    #[tokio::test]
    async fn test_youtube() {
        let res = get_mock_client_response!("https://www.youtube.com/watch?v=NlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_success());

        let res = get_mock_client_response!("https://www.youtube.com/watch?v=invalidNlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_error());
    }

    // Without credentials the endpoint answers 401; with basic-auth
    // credentials the check ends in a redirected 200.
    #[tokio::test]
    async fn test_basic_auth() {
        let mut r: Request = "https://authenticationtest.com/HTTPAuth/"
            .try_into()
            .unwrap();

        let res = get_mock_client_response!(r.clone()).await;
        assert_eq!(res.status().code(), Some(401.try_into().unwrap()));

        r.credentials = Some(crate::BasicAuthCredentials {
            username: "user".into(),
            password: "pass".into(),
        });

        let res = get_mock_client_response!(r).await;
        assert!(matches!(
            res.status(),
            Status::Redirected(StatusCode::OK, _)
        ));
    }

    // A mock server answering 200 is reported as success.
    #[tokio::test]
    async fn test_non_github() {
        let mock_server = mock_server!(StatusCode::OK);
        let res = get_mock_client_response!(mock_server.uri()).await;

        assert!(res.status().is_success());
    }

    // An expired certificate is an error by default and accepted once
    // `allow_insecure` is enabled.
    #[tokio::test]
    async fn test_invalid_ssl() {
        let res = get_mock_client_response!("https://expired.badssl.com/").await;

        assert!(res.status().is_error());

        // Same URL again, this time ignoring the certificate error.
        let res = ClientBuilder::builder()
            .allow_insecure(true)
            .build()
            .client()
            .unwrap()
            .check("https://expired.badssl.com/")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    // A `file://` URI pointing at an existing temporary file succeeds.
    #[tokio::test]
    async fn test_file() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("temp");
        File::create(file).unwrap();
        let uri = format!("file://{}", dir.path().join("temp").to_str().unwrap());

        let res = get_mock_client_response!(uri).await;
        assert!(res.status().is_success());
    }

    // A client configured with an explicit `Accept` header can check
    // crates.io successfully.
    #[tokio::test]
    async fn test_custom_headers() {
        let mut custom = HeaderMap::new();
        custom.insert(header::ACCEPT, "text/html".parse().unwrap());
        let res = ClientBuilder::builder()
            .custom_headers(custom)
            .build()
            .client()
            .unwrap()
            .check("https://crates.io/crates/lychee")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    // `include_mail` is left at its default here, and mail URIs are excluded.
    #[tokio::test]
    async fn test_exclude_mail_by_default() {
        let client = ClientBuilder::builder()
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    // Mail URIs are excluded with `include_mail(false)` and not excluded with
    // `include_mail(true)`.
    #[tokio::test]
    async fn test_include_mail() {
        let client = ClientBuilder::builder()
            .include_mail(false)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));

        let client = ClientBuilder::builder()
            .include_mail(true)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(!client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    // Despite the test name, this asserts that `tel:` URIs are excluded by a
    // default client.
    #[tokio::test]
    async fn test_include_tel() {
        let client = ClientBuilder::builder().build().client().unwrap();
        assert!(client.is_excluded(&Uri {
            url: "tel:1234567890".try_into().unwrap()
        }));
    }

    // A plain-HTTP URL succeeds by default and is an error once
    // `require_https` is enabled.
    #[tokio::test]
    async fn test_require_https() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://example.com").await.unwrap();
        assert!(res.status().is_success());

        // Same request, but with HTTPS required.
        let client = ClientBuilder::builder()
            .require_https(true)
            .build()
            .client()
            .unwrap();
        let res = client.check("http://example.com").await.unwrap();
        assert!(res.status().is_error());
    }

    // The mock server delays its response longer than the client timeout,
    // so the check must end in a timeout status. Retries are disabled.
    #[tokio::test]
    async fn test_timeout() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);

        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(0u64)
            .build()
            .client()
            .unwrap();

        let res = client.check(mock_server.uri()).await.unwrap();
        assert!(res.status().is_timeout());
    }

    // Every attempt times out (10 ms timeout vs. 20 ms server delay); with
    // three retries and an initial wait of 50 ms the whole check must take
    // between 350 and 550 ms.
    #[tokio::test]
    async fn test_exponential_backoff() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);

        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        // Warm-up request without timeout or retries.
        // NOTE(review): presumably keeps one-time initialisation cost out of
        // the timed section below — confirm.
        let warm_up_client = ClientBuilder::builder()
            .max_retries(0_u64)
            .build()
            .client()
            .unwrap();
        let _res = warm_up_client.check(mock_server.uri()).await.unwrap();

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(3_u64)
            .retry_wait_time(Duration::from_millis(50))
            .build()
            .client()
            .unwrap();

        let start = Instant::now();
        let res = client.check(mock_server.uri()).await.unwrap();
        let end = start.elapsed();

        assert!(res.status().is_error());

        // NOTE(review): assuming the wait doubles on each retry, the waits add
        // up to 50 + 100 + 200 = 350 ms plus four 10 ms timeouts — confirm
        // against the website checker's retry logic.
        assert!((350..=550).contains(&end.as_millis()));
    }

    // A malformed URL must yield an `Unsupported` status instead of a panic.
    #[tokio::test]
    async fn test_avoid_reqwest_panic() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://\"").await.unwrap();

        assert!(matches!(
            res.status(),
            Status::Unsupported(ErrorKind::BuildRequestClient(_))
        ));
        assert!(res.status().is_unsupported());
    }

    // The mock endpoint redirects to itself forever. The mock expects exactly
    // one initial request plus `redirect_count` followed redirects, and the
    // final status is the rejected redirect status code.
    #[tokio::test]
    async fn test_max_redirects() {
        let mock_server = wiremock::MockServer::start().await;

        let redirect_uri = format!("{}/redirect", &mock_server.uri());
        let redirect = wiremock::ResponseTemplate::new(StatusCode::PERMANENT_REDIRECT)
            .insert_header("Location", redirect_uri.as_str());

        let redirect_count = 15usize;
        let initial_invocation = 1;

        // Redirect endlessly to the same URL.
        Mock::given(method("GET"))
            .and(path("/redirect"))
            .respond_with(move |_: &_| redirect.clone())
            .expect(initial_invocation + redirect_count as u64)
            .mount(&mock_server)
            .await;

        let res = ClientBuilder::builder()
            .max_redirects(redirect_count)
            .build()
            .client()
            .unwrap()
            .check(redirect_uri.clone())
            .await
            .unwrap();

        assert_eq!(
            res.status(),
            &Status::Error(ErrorKind::RejectedStatusCode(
                StatusCode::PERMANENT_REDIRECT
            ))
        );
    }

    // A single followed redirect is reported as `Redirected`, including the
    // recorded redirect chain.
    #[tokio::test]
    async fn test_redirects() {
        redirecting_mock_server!(async |redirect_url: Url, ok_url| {
            let res = ClientBuilder::builder()
                .max_redirects(1_usize)
                .build()
                .client()
                .unwrap()
                .check(Uri::from((redirect_url).clone()))
                .await
                .unwrap();

            let mut redirects = Redirects::new(redirect_url);
            redirects.push(Redirect {
                url: ok_url,
                code: StatusCode::PERMANENT_REDIRECT,
            });
            assert_eq!(res.status(), &Status::Redirected(StatusCode::OK, redirects));
        })
        .await;
    }

    // URLs with the schemes listed below are reported as unsupported.
    #[tokio::test]
    async fn test_unsupported_scheme() {
        let examples = vec![
            "ftp://example.com",
            "gopher://example.com",
            "slack://example.com",
        ];

        for example in examples {
            let client = ClientBuilder::builder().build().client().unwrap();
            let res = client.check(example).await.unwrap();
            assert!(res.status().is_unsupported());
        }
    }

    // A request-chain handler that returns `Done` decides the final status of
    // the check.
    #[tokio::test]
    async fn test_chain() {
        use reqwest::Request;

        #[derive(Debug)]
        struct ExampleHandler();

        #[async_trait]
        impl Handler<Request, Status> for ExampleHandler {
            async fn handle(&mut self, _: Request) -> ChainResult<Request, Status> {
                ChainResult::Done(Status::Excluded)
            }
        }

        let chain = RequestChain::new(vec![Box::new(ExampleHandler {})]);

        let client = ClientBuilder::builder()
            .plugin_request_chain(chain)
            .build()
            .client()
            .unwrap();

        let result = client.check("http://example.com");
        let res = result.await.unwrap();
        assert_eq!(res.status(), &Status::Excluded);
    }
}