resolvematrix/src/server.rs
Jade Ellis 22e5fc1075
All checks were successful
Checks / Prek / Pre-commit & Formatting (push) Successful in 2m1s
Checks / Prek / Clippy and Cargo Tests (push) Successful in 4m23s
style: Fix clippy lints in tests
2026-04-12 15:47:38 +01:00

1170 lines
41 KiB
Rust

use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use hickory_resolver::TokioResolver;
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use thiserror::Error;
/// Error type for Matrix server resolution.
///
/// The `#[error(...)]` attribute on each variant supplies its display message,
/// and the `#[from]` fields let the wrapped error types be converted into this
/// enum (which is what allows `?` to be used on them in this module).
#[derive(Debug, Error)]
pub enum ResolveServerError {
/// An IP or socket address string could not be parsed.
#[error("Failed to parse address: {0}")]
AddrParse(#[from] std::net::AddrParseError),
/// An error reported by `reqwest`, e.g. while building an HTTP client.
#[error("HTTP client error: {0}")]
Http(#[from] reqwest::Error),
/// An error reported by the `hickory_resolver` DNS resolver.
#[error("DNS resolution error: {0}")]
Dns(#[from] hickory_resolver::ResolveError),
/// A port string could not be parsed as an integer.
#[error("Invalid port number: {0}")]
InvalidPort(#[from] std::num::ParseIntError),
/// The `.well-known` response was not in the expected shape.
#[error("Malformed .well-known response")]
MalformedWellKnown,
/// Any other failure, described by a free-form message.
#[error("Unexpected error: {0}")]
Other(String),
}
/// Represents the resolved destination for a Matrix server.
///
/// Destinations compare and hash by value, so they can be asserted on directly
/// in tests and used as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ResolvedDestination {
    /// A literal IP address and port (e.g., 1.2.3.4:8448)
    Literal(SocketAddr),
    /// A named host and port (e.g., "matrix.org", "8448").
    ///
    /// The port is kept as a string; consumers parse it when a numeric port is needed.
    Named(String, String),
}
/// Result of a Matrix server resolution.
///
/// Contains the resolved destination (IP/Port or Hostname/Port) and the
/// hostname to use for SNI/Host headers.
#[derive(Debug, Clone)]
pub struct Resolution {
/// The actual destination to connect to.
pub destination: ResolvedDestination,
/// The hostname to use for TLS SNI and HTTP Host header.
///
/// This may carry a `:port` suffix: when the server name was given with an
/// explicit port, the whole original name is stored here.
pub host: String,
}
impl Resolution {
/// Get the base URL for making requests to this resolution.
/// Uses the host field for proper SNI.
#[must_use]
pub fn base_url(&self) -> String {
match &self.destination {
ResolvedDestination::Literal(addr) => format!("https://{addr}"),
ResolvedDestination::Named(_dest_host, dest_port) => {
let port: u16 = dest_port.parse().unwrap_or(8448);
if self.host.contains(':') {
format!("https://{}", self.host)
} else {
format!("https://{}:{}", self.host, port)
}
}
}
}
/// Get the hostname (without port) from the host field for DNS mapping.
fn sni_hostname(&self) -> String {
if let Some(colon_pos) = self.host.find(':') {
self.host[..colon_pos].to_string()
} else {
self.host.clone()
}
}
/// Get the destination address for DNS resolution mapping.
async fn destination_addr(&self, resolver: &TokioResolver) -> Option<SocketAddr> {
match &self.destination {
ResolvedDestination::Literal(addr) => Some(*addr),
ResolvedDestination::Named(dest_host, dest_port) => {
let port: u16 = dest_port.parse().ok()?;
// Try to parse as IP first
if let Ok(ip) = dest_host.parse::<IpAddr>() {
return Some(SocketAddr::new(ip, port));
}
// Resolve via DNS
match resolver.lookup_ip(dest_host.as_str()).await {
Ok(lookup) => {
let ip = lookup.iter().next()?;
Some(SocketAddr::new(ip, port))
}
Err(_) => None,
}
}
}
}
}
impl ResolvedDestination {
/// Get the destination hostname
pub fn hostname(&self) -> String {
match &self {
ResolvedDestination::Literal(addr) => addr.ip().to_string(),
ResolvedDestination::Named(dest_host, _dest_port) => dest_host.clone(),
}
}
/// Get the destination port
pub fn port(&self) -> u16 {
match &self {
ResolvedDestination::Literal(addr) => addr.port(),
ResolvedDestination::Named(_dest_host, dest_port) => {
dest_port.parse::<u16>().unwrap_or(8448)
}
}
}
/// Return the host:port formatted string of the resolved destination server (not SNI host)
pub fn host_port(&self) -> String {
match &self {
ResolvedDestination::Literal(addr) => addr.to_string(),
ResolvedDestination::Named(host, port) => format!("{host}:{port}"),
}
}
}
/// Simple cache entry with expiry time.
#[derive(Clone, Debug)]
pub struct CacheEntry {
/// The cached resolution.
resolution: Resolution,
/// Instant from which the entry is considered expired.
expires_at: Instant,
/// Whether an expired entry should be reported as an expired override rather than a plain miss.
is_override: bool, // If true, this is a Matrix resolution that should be refetched when expired
}
/// Result of a cache lookup.
///
/// Produced by `Cache::lookup` and consumed by the DNS resolver to decide
/// between using a cached destination, re-running Matrix resolution, or
/// falling back to standard DNS.
#[derive(Debug)]
enum CacheLookup {
/// Valid cached entry found
Valid(Resolution),
/// Expired Matrix override - should refetch via Matrix resolution
ExpiredOverride(String), // Returns the hostname that needs refetching
/// No entry found or expired non-override
Miss,
}
/// Simple cache for Matrix server resolutions with TTL-based expiry.
///
/// Both maps live behind `Arc`, so cloning a `Cache` yields another handle to
/// the same underlying data.
#[derive(Clone)]
pub(crate) struct Cache {
/// Cached resolutions keyed by server name.
inner: Arc<RwLock<HashMap<String, CacheEntry>>>,
/// Maps a resolution's SNI hostname back to the server name it was cached under.
hostname_map: Arc<RwLock<HashMap<String, String>>>, // hostname -> server_name
/// Lifetime given to every entry when it is inserted.
ttl: Duration,
}
impl Cache {
    /// Create an empty cache whose entries live for `ttl` after insertion.
    fn new(ttl: Duration) -> Self {
        Self {
            inner: Arc::new(RwLock::new(HashMap::new())),
            hostname_map: Arc::new(RwLock::new(HashMap::new())),
            ttl,
        }
    }

    /// Return the unexpired resolution cached for `server_name`, if any.
    ///
    /// An expired entry is removed as a side effect. A poisoned lock is treated
    /// as a cache miss.
    fn get(&self, server_name: &str) -> Option<Resolution> {
        // Fast path: a read lock is enough to serve a valid entry.
        if let Ok(cache) = self.inner.read()
            && let Some(entry) = cache.get(server_name)
            && Instant::now() < entry.expires_at
        {
            return Some(entry.resolution.clone());
        }

        // Slow path: re-check under the write lock, because another thread may
        // have refreshed the entry between the two lock acquisitions.
        if let Ok(mut cache) = self.inner.write()
            && let Some(entry) = cache.get(server_name)
        {
            if Instant::now() < entry.expires_at {
                return Some(entry.resolution.clone());
            }
            cache.remove(server_name);
        }
        None
    }

    /// Look up `hostname` for the DNS resolver.
    ///
    /// `hostname` is first tried as a cache key directly, then through the
    /// SNI-hostname -> server-name mapping.
    ///
    /// Locking discipline: at most one of the two locks is held at any time.
    /// `set` follows the same rule, so the two maps can never be acquired in
    /// conflicting orders by different threads.
    fn lookup(&self, hostname: &str) -> CacheLookup {
        // Direct lookup under the read lock. `Some(flag)` means "present but expired".
        let expired_override = if let Ok(cache) = self.inner.read() {
            match cache.get(hostname) {
                Some(entry) if Instant::now() < entry.expires_at => {
                    return CacheLookup::Valid(entry.resolution.clone());
                }
                Some(entry) => Some(entry.is_override),
                None => None,
            }
        } else {
            None
        };

        if let Some(is_override) = expired_override {
            // Only evict if the entry is still expired once we hold the write
            // lock; a concurrent `set` may have replaced it with a fresh one.
            if let Ok(mut cache) = self.inner.write()
                && let Some(entry) = cache.get(hostname)
            {
                if Instant::now() < entry.expires_at {
                    return CacheLookup::Valid(entry.resolution.clone());
                }
                cache.remove(hostname);
            }
            return if is_override {
                CacheLookup::ExpiredOverride(hostname.to_string())
            } else {
                CacheLookup::Miss
            };
        }

        // Hostname mapping: copy the server name out so the `hostname_map` guard
        // is released before `get` touches `inner`.
        let mapped = self
            .hostname_map
            .read()
            .ok()
            .and_then(|map| map.get(hostname).cloned());
        let Some(server_name) = mapped else {
            return CacheLookup::Miss;
        };

        match self.get(&server_name) {
            Some(resolution) => CacheLookup::Valid(resolution),
            // The mapping exists but the server_name entry is expired/missing:
            // treat it as an expired override.
            None => CacheLookup::ExpiredOverride(server_name),
        }
    }

    /// Cache `resolution` under `server_name` and record the SNI hostname mapping.
    ///
    /// The mapping is only recorded when the SNI hostname differs from
    /// `server_name`, and only if the entry itself could be stored.
    fn set(&self, server_name: String, resolution: &Resolution) {
        let sni_hostname = resolution.sni_hostname();
        let alias = (sni_hostname != server_name).then_some(sni_hostname);
        let entry = CacheEntry {
            resolution: resolution.clone(),
            expires_at: Instant::now() + self.ttl,
            is_override: true, // All Matrix resolutions are overrides
        };

        {
            let Ok(mut cache) = self.inner.write() else {
                return;
            };
            cache.insert(server_name.clone(), entry);
            // `inner` is released here, before `hostname_map` is taken.
        }

        // Add hostname mapping for DNS lookups
        if let Some(alias) = alias
            && let Ok(mut hostname_map) = self.hostname_map.write()
        {
            hostname_map.insert(alias, server_name);
        }
    }

    /// Remove a single entry from the cache, returning the previously existing entry if there was one
    fn remove_entry(&self, server_name: &str) -> Option<CacheEntry> {
        self.inner
            .write()
            .ok()
            .and_then(|mut cache| cache.remove(server_name))
    }

    /// Clear all cache entries. Returns nothing.
    ///
    /// The hostname mapping is left in place.
    fn clear(&self) {
        if let Ok(mut cache) = self.inner.write() {
            cache.clear();
        }
    }
}
#[derive(Clone)]
/// A custom DNS resolver for `reqwest` that handles Matrix server name resolution.
///
/// This resolver integrates with the `MatrixResolver` cache and logic to ensure that
/// HTTP requests made by `reqwest` are routed to the correct IP address and port
/// as discovered by the Matrix server discovery process.
///
/// It exists to ensure that the correct SNI is used. The request URL carries the
/// domain expected for SNI, and the `MatrixDnsResolver` resolves that domain to the
/// correct destination.
pub struct MatrixDnsResolver {
/// DNS resolver used for destination host names and for the standard-DNS fallback.
resolver: Arc<TokioResolver>,
/// Handle to the resolution cache consulted on every lookup.
cache: Cache,
/// Used to re-run Matrix resolution when a cached override has expired.
matrix_resolver: Arc<MatrixResolver>,
}
impl MatrixDnsResolver {
pub(crate) fn new(
resolver: Arc<TokioResolver>,
cache: Cache,
matrix_resolver: Arc<MatrixResolver>,
) -> Self {
Self {
resolver,
cache,
matrix_resolver,
}
}
}
impl reqwest::dns::Resolve for MatrixDnsResolver {
fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
let name_str = name.as_str().to_string();
let resolver = self.resolver.clone();
let cache = self.cache.clone();
let matrix_resolver = self.matrix_resolver.clone();
Box::pin(async move {
// Check cache and determine what to do
match cache.lookup(&name_str) {
CacheLookup::Valid(resolution) => {
// Valid cached entry - use it
if let Some(addr) = resolution.destination_addr(&resolver).await {
tracing::trace!("DNS cache hit for {name_str} -> {addr}");
return Ok(Box::new(std::iter::once(addr))
as Box<dyn Iterator<Item = SocketAddr> + Send>);
}
}
CacheLookup::ExpiredOverride(server_name) => {
// Expired Matrix override - refetch via Matrix resolution
tracing::trace!("DNS cache expired override for {name_str}, refetching");
match matrix_resolver.resolve_server(&server_name).await {
Ok(resolution) => {
if let Some(addr) = resolution.destination_addr(&resolver).await {
return Ok(Box::new(std::iter::once(addr))
as Box<dyn Iterator<Item = SocketAddr> + Send>);
} else {
// Something funky, they should re-resolve the server
}
}
Err(e) => {
tracing::warn!("Failed to refetch Matrix server {server_name}: {e:?}",);
}
}
}
CacheLookup::Miss => {
// No override - use standard DNS
}
}
// Fallback: standard DNS lookup
tracing::trace!("DNS fallback for {name_str}, using standard DNS");
match resolver.lookup_ip(&name_str).await {
Ok(lookup) => {
let addrs: Vec<SocketAddr> = lookup
.iter()
.map(|ip| SocketAddr::new(ip, 8448)) // Default Matrix port
.collect();
Ok(Box::new(addrs.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
}
Err(e) => Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
}
})
}
}
/// The main resolver struct for Matrix server resolution.
pub struct MatrixResolver {
/// HTTP client used to fetch `.well-known/matrix/server` documents.
client: Client,
/// DNS resolver used for SRV lookups; also handed to the DNS resolvers of created clients.
resolver: Arc<TokioResolver>,
/// Cache of completed resolutions; also handed to the DNS resolvers of created clients.
cache: Cache,
}
impl MatrixResolver {
/// Build a `MatrixResolver` whose cache entries live for the default 5 minutes.
///
/// # Errors
///
/// Fails if either the DNS resolver or the HTTP client cannot be set up.
pub fn new() -> Result<Self, ResolveServerError> {
    let default_ttl = Duration::from_secs(300);
    Self::new_with_ttl(default_ttl)
}
/// Create a new `MatrixResolver` with a custom cache TTL.
///
/// # Errors
///
/// Returns an error if the DNS resolver or HTTP client cannot be initialized.
pub fn new_with_ttl(cache_ttl: Duration) -> Result<Self, ResolveServerError> {
let resolver = Arc::new(hickory_resolver::Resolver::builder_tokio()?.build());
let client = Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()?;
let cache = Cache::new(cache_ttl);
Ok(MatrixResolver {
client,
resolver,
cache,
})
}
/// Finish `builder` into a client that can be reused for every Matrix server.
///
/// The builder is given a DNS resolver that shares this resolver's cache and
/// DNS backend, so a single client can serve all federation requests.
///
/// # Errors
///
/// Fails if `reqwest` cannot build the client.
pub fn create_client_with_builder(
    self: &Arc<Self>,
    builder: reqwest::ClientBuilder,
) -> Result<Client, ResolveServerError> {
    let dns_resolver = Arc::new(MatrixDnsResolver::new(
        Arc::clone(&self.resolver),
        self.cache.clone(),
        Arc::clone(self),
    ));
    let client = builder.dns_resolver(dns_resolver).build()?;
    Ok(client)
}
/// Create a standard reqwest client that can be reused for all Matrix servers.
///
/// # Errors
///
/// Returns an error if the client cannot be built.
pub fn create_client(self: &Arc<Self>) -> Result<Client, ResolveServerError> {
let builder = Client::builder().timeout(std::time::Duration::from_secs(10));
self.create_client_with_builder(builder)
}
/// Resolve a Matrix server name, consulting the cache before doing any work.
///
/// Build request URLs from the result with `resolution.base_url()`, and send
/// them with a client created by this resolver so that SRV-derived
/// destinations are honoured.
///
/// # Errors
///
/// Returns an error if resolution fails (e.g. DNS failure, invalid response).
///
/// # Example
///
/// ```rust,no_run
/// # use resolvematrix::server::MatrixResolver;
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let resolver = MatrixResolver::new()?;
/// let resolution = resolver.resolve_server("matrix.org").await?;
///
/// assert_eq!(resolution.host, "matrix-federation.matrix.org");
/// # Ok(())
/// # }
/// ```
pub async fn resolve_server(
    &self,
    server_name: &str,
) -> Result<Resolution, ResolveServerError> {
    match self.cache.get(server_name) {
        Some(cached) => {
            tracing::trace!("Cache hit for {server_name}");
            Ok(cached)
        }
        None => {
            // Not cached (or expired): resolve for real and remember the outcome.
            let fresh = self.resolve_actual_dest(server_name).await?;
            self.cache.set(server_name.to_owned(), &fresh);
            Ok(fresh)
        }
    }
}
/// Resolve the actual destination according to Matrix spec.
/// <https://spec.matrix.org/latest/server-server-api/#resolving-server-names>
///
/// Steps, tried in order:
/// 1. `dest` is an IP literal (optionally with a port).
/// 2. `dest` is a hostname with an explicit port.
/// 3. `.well-known` delegation (IP literal, host with port, host with SRV, host with default port).
/// 4. SRV lookup on `dest` itself.
/// 5. `dest` with the default port 8448.
///
/// # Errors
///
/// Returns [`ResolveServerError::InvalidPort`] when `dest` carries an explicit
/// port that is not a valid `u16`, and propagates any error from the SRV queries.
#[tracing::instrument(
    name = "actual",
    level = "debug",
    skip(self),
    fields(dest = %dest)
)]
async fn resolve_actual_dest(&self, dest: &str) -> Result<Resolution, ResolveServerError> {
    // 1. If the hostname is an IP literal
    if let Some((ip, port)) = get_ip_with_port(dest) {
        tracing::info!(
            ip = %ip,
            port = port,
            step = "ip_literal",
            "Resolved IP literal with port"
        );
        let socket = SocketAddr::new(ip, port.unwrap_or(8448));
        return Ok(Resolution {
            destination: ResolvedDestination::Literal(socket),
            host: dest.to_owned(),
        });
    }

    // 2. Hostname with explicit port
    if let Some((host_part, port_str)) = dest.split_once(':') {
        // Validate up front: an unparsable port would otherwise be carried along
        // as text and silently replaced by 8448 wherever it is parsed later.
        port_str.parse::<u16>()?;
        tracing::info!(
            host = %host_part,
            port = %port_str,
            step = "explicit_port",
            "Resolved hostname with explicit port"
        );
        return Ok(Resolution {
            destination: ResolvedDestination::Named(host_part.to_owned(), port_str.to_owned()),
            host: dest.to_owned(),
        });
    }

    // 3. Well-known delegation
    if let Some(res) = self.resolve_well_known(dest).await {
        tracing::info!(?res, step = "well_known", "Resolved .well-known delegation");
        let resolution = match res {
            WellKnownServerResult::Ip(ip, port) => {
                tracing::info!(
                    ip = %ip,
                    port = port.unwrap_or(8448),
                    step = "well_known_ip_literal",
                    "Resolved .well-known IP literal"
                );
                let socket = SocketAddr::new(ip, port.unwrap_or(8448));
                Resolution {
                    destination: ResolvedDestination::Literal(socket),
                    host: dest.to_owned(),
                }
            }
            WellKnownServerResult::Domain(domain, None) => {
                // 3.3/3.4: Hostname, no port in .well-known
                match self.query_srv_record(&domain).await? {
                    Some((srv_host, srv_port)) => {
                        tracing::info!(
                            srv_host = %srv_host,
                            srv_port = srv_port,
                            step = "well_known_host_srv",
                            "Resolved SRV from .well-known hostname without port"
                        );
                        Resolution {
                            destination: ResolvedDestination::Named(
                                srv_host,
                                srv_port.to_string(),
                            ),
                            host: domain,
                        }
                    }
                    None => {
                        // 3.5: No SRV, fallback to A/AAAA/CNAME + 8448
                        tracing::trace!(
                            delegated = %domain,
                            step = "well_known_fallback",
                            "Fallback to .well-known host with default port"
                        );
                        Resolution {
                            destination: ResolvedDestination::Named(
                                domain.clone(),
                                "8448".to_owned(),
                            ),
                            host: domain,
                        }
                    }
                }
            }
            WellKnownServerResult::Domain(domain, Some(port)) => {
                tracing::info!(
                    domain = %domain,
                    port = port,
                    step = "well_known_domain",
                    "Resolved .well-known domain with port"
                );
                Resolution {
                    destination: ResolvedDestination::Named(domain.clone(), port.to_string()),
                    host: domain,
                }
            }
        };
        return Ok(resolution);
    }

    // 4. SRV lookup on original hostname
    if let Some((srv_host, srv_port)) = self.query_srv_record(dest).await? {
        tracing::trace!(
            srv_host = %srv_host,
            srv_port = srv_port,
            step = "srv_lookup",
            "Resolved SRV record on original hostname"
        );
        return Ok(Resolution {
            destination: ResolvedDestination::Named(srv_host, srv_port.to_string()),
            host: dest.to_owned(),
        });
    }

    // 5. Fallback: A/AAAA/CNAME + 8448
    tracing::trace!(
        host = %dest,
        step = "fallback",
        "Fallback to original hostname with default port"
    );
    Ok(Resolution {
        destination: ResolvedDestination::Named(dest.to_owned(), "8448".to_owned()),
        host: dest.to_owned(),
    })
}
/// Fetch and interpret `https://{hostname}/.well-known/matrix/server`.
///
/// Returns `None` when the request fails, the status is not 200 OK, or the
/// body is not JSON containing a string `m.server` field. Otherwise the
/// `m.server` value is classified as an IP literal or a domain, each with an
/// optional port.
#[tracing::instrument(
    level = "trace",
    skip(self),
    fields(hostname = %hostname)
)]
async fn resolve_well_known(&self, hostname: &str) -> Option<WellKnownServerResult> {
    /// The only part of the `.well-known` document this resolver reads.
    #[derive(Deserialize)]
    struct WellKnown {
        #[serde(rename = "m.server")]
        m_server: String,
    }

    let url = format!("https://{hostname}/.well-known/matrix/server");
    tracing::trace!(url = %url, "Fetching .well-known matrix server");

    let resp = self.client.get(&url).send().await.ok()?;
    if resp.status() != StatusCode::OK {
        return None;
    }

    let delegated = match resp.json::<WellKnown>().await {
        Ok(wk) => wk.m_server,
        Err(e) => {
            tracing::warn!(
                error = %e,
                url = %url,
                "Failed to parse .well-known matrix server JSON"
            );
            return None;
        }
    };

    match get_ip_with_port(&delegated) {
        Some((ip, port)) => {
            tracing::trace!(
                ip = %ip,
                port = ?port,
                "Parsed .well-known matrix server IP and port"
            );
            Some(WellKnownServerResult::Ip(ip, port))
        }
        None => {
            let (host, port) = parse_server_name(&delegated);
            tracing::trace!(
                well_known_host = %host,
                well_known_port = ?port,
                "Parsed .well-known matrix server domain"
            );
            Some(WellKnownServerResult::Domain(host, port))
        }
    }
}
/// Query SRV records for a hostname, returning (target, port) if found.
///
/// `_matrix-fed._tcp.{hostname}` is tried first, then `_matrix._tcp.{hostname}`.
/// Within an answer the record with the lowest priority value is chosen, as
/// RFC 2782 requires; ties keep the first record in answer order (weights are
/// not evaluated). The trailing root dot is stripped from the target.
///
/// A failed or empty lookup for one name is not an error: the next name is
/// tried, and `Ok(None)` is returned when neither yields a record.
#[tracing::instrument(
    level = "trace",
    skip(self),
    fields(hostname = %hostname)
)]
async fn query_srv_record(
    &self,
    hostname: &str,
) -> Result<Option<(String, u16)>, ResolveServerError> {
    let srv_names = [
        format!("_matrix-fed._tcp.{hostname}"),
        format!("_matrix._tcp.{hostname}"),
    ];
    for srv in &srv_names {
        tracing::trace!(srv = %srv, "Querying SRV record");
        let Ok(result) = self.resolver.srv_lookup(srv).await else {
            continue;
        };
        // Lower priority value = more preferred target.
        if let Some(record) = result.iter().min_by_key(|record| record.priority()) {
            let target = record.target().to_utf8();
            let port = record.port();
            return Ok(Some((target.trim_end_matches('.').to_owned(), port)));
        }
    }
    tracing::trace!(hostname = %hostname, "No SRV records found for hostname");
    Ok(None)
}
/// Evict the cache entry stored under `hostname`, handing back the evicted
/// entry when there was one.
#[tracing::instrument(
    level = "trace",
    skip(self),
    fields(hostname = %hostname)
)]
pub fn remove_cache_entry(&self, hostname: &str) -> Option<CacheEntry> {
    let removed = self.cache.remove_entry(hostname);
    removed
}
/// Drop every cached resolution.
#[tracing::instrument(level = "trace", skip(self))]
pub fn clear_cache(&self) {
    self.cache.clear();
}
}
/// Interpretation of the `m.server` value from a `.well-known` response.
#[derive(Debug)]
enum WellKnownServerResult {
/// `m.server` was an IP literal, with its port if one was given.
Ip(IpAddr, Option<u16>),
/// `m.server` was not an IP literal: the host part, with its port if one was given.
Domain(String, Option<u16>),
}
/// Parses a Matrix server name into (hostname, Option<port>)
#[tracing::instrument(
name = "parse_server_name",
level = "trace",
fields(server_name = %server_name)
)]
fn parse_server_name(server_name: &str) -> (String, Option<u16>) {
if let Some((host, port)) = server_name.rsplit_once(':')
&& let Ok(port) = u16::from_str(port)
{
return (host.to_string(), Some(port));
}
(server_name.to_string(), None)
}
/// If the string is an IP literal (with optional port), returns (`IpAddr`, port).
///
/// Accepted forms:
/// * `1.2.3.4:8448` / `[::1]:8448` -> address with `Some(port)`,
/// * `1.2.3.4` / `::1` -> address with `None`,
/// * `[::1]` (bracketed IPv6 without a port, the way IPv6 literals are written
///   in Matrix server names) -> address with `None`.
///
/// Anything else returns `None`.
#[tracing::instrument(
    name = "get_ip_with_port",
    level = "trace",
    fields(input = %s)
)]
fn get_ip_with_port(s: &str) -> Option<(IpAddr, Option<u16>)> {
    // Try SocketAddr first (IP:port)
    if let Ok(sock) = SocketAddr::from_str(s) {
        tracing::trace!(
            ip = %sock.ip(),
            port = sock.port(),
            "Parsed SocketAddr from input"
        );
        return Some((sock.ip(), Some(sock.port())));
    }

    // Try IP only, either bare or as a bracketed IPv6 literal.
    let unbracketed = s
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'));
    let parsed = match unbracketed {
        // Brackets are only meaningful around IPv6; "[1.2.3.4]" is not an IP literal.
        Some(inner) => IpAddr::from_str(inner).ok().filter(IpAddr::is_ipv6),
        None => IpAddr::from_str(s).ok(),
    };
    if let Some(ip) = parsed {
        tracing::trace!(
            ip = %ip,
            port = 8448,
            "Parsed IpAddr from input, using default port"
        );
        return Some((ip, None));
    }

    tracing::debug!(input = %s, "Input is not an IP literal");
    None
}
#[cfg(test)]
mod tests {
use assertables::{assert_none, assert_some};
use rstest::rstest;
use tracing::debug;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use super::*;
/// IP literals are recognised with and without a port; hostnames are not.
#[test]
fn test_get_ip_with_port() {
    let loopback_v4 = IpAddr::from([127, 0, 0, 1]);
    let loopback_v6 = IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1]);
    let cases = [
        ("127.0.0.1:8080", Some((loopback_v4, Some(8080)))),
        ("[::1]:8080", Some((loopback_v6, Some(8080)))),
        ("127.0.0.1", Some((loopback_v4, None))),
        ("::1", Some((loopback_v6, None))),
        ("example.com", None),
    ];
    for (input, expected) in cases {
        assert_eq!(get_ip_with_port(input), expected);
    }
}
/// Inputs that are not IP literals (including malformed ports) yield `None`.
#[test]
fn test_get_ip_with_port_invalid() {
    let rejected = [
        "invalid",
        "127.0.0.1:invalid",
        "::1:invalid",
        "127.0.0.1:8080:invalid",
        "::1:8080:invalid",
    ];
    for input in rejected {
        assert_eq!(get_ip_with_port(input), None);
    }
}
/// Smoke test: resolves two server names and prints the results with `dbg!`.
///
/// NOTE(review): these look like live public servers, so this test presumably
/// needs network access - confirm before running it in an offline environment.
#[tokio::test]
async fn test_resolve() {
init_tracing();
let resolver = MatrixResolver::new().unwrap();
// `unwrap` makes a failed resolution fail the test; the value is only printed.
let _ = dbg!(resolver.resolve_server("matrix.org").await.unwrap());
let _ = dbg!(resolver.resolve_server("ellis.link").await.unwrap());
}
/// Installs a test-writer `tracing` subscriber for the current test run.
///
/// The result of `try_init` is discarded, so calling this from several tests
/// is harmless.
fn init_tracing() {
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_test_writer()
        .with_target(false);
    let _ = tracing_subscriber::registry().with(fmt_layer).try_init();
}
/// Body of the `/_matrix/federation/v1/version` response as read by the tests.
// dead_code: the field is only populated by deserialization and shown via `Debug`.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct ServerVersionEndpoint {
pub server: ServerVersionServer,
}
/// The `server` object nested inside [`ServerVersionEndpoint`].
// dead_code: the fields are only populated by deserialization and shown via `Debug`.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct ServerVersionServer {
pub name: String,
pub version: String,
}
/// End-to-end check per server: resolve it, then fetch its federation version
/// endpoint through a client built by the resolver and require a 200 response.
#[rstest]
#[case::maunium_net("maunium.net")]
#[case::timedout_uk_port("timedout.uk:69")]
#[case::nexy7574_co_uk("nexy7574.co.uk")]
#[case::matrix_org("matrix.org")]
#[case::matrixrooms_info("matrixrooms.info")]
#[case::resolvematrix_2_port("2.s.resolvematrix.dev:7652")]
#[case::resolvematrix_3b("3b.s.resolvematrix.dev")]
#[case::resolvematrix_3c("3c.s.resolvematrix.dev")]
#[case::resolvematrix_3d("3d.s.resolvematrix.dev")]
#[case::resolvematrix_4("4.s.resolvematrix.dev")]
#[case::resolvematrix_5("5.s.resolvematrix.dev")]
#[case::resolvematrix_3c_msc4040("3c.msc4040.s.resolvematrix.dev")]
#[case::resolvematrix_4_msc4040("4.msc4040.s.resolvematrix.dev")]
#[tokio::test]
async fn test_server_resolver(#[case] server_name: &str) {
    init_tracing();
    let resolver = Arc::new(MatrixResolver::new().unwrap());
    tracing::info!("Testing {server_name}");

    let resolution = resolver.resolve_server(server_name).await.unwrap();

    // Client whose DNS lookups go through the resolver's cache.
    let client = resolver
        .create_client_with_builder(
            Client::builder()
                .tls_danger_accept_invalid_certs(true)
                .timeout(std::time::Duration::from_secs(10)),
        )
        .unwrap();

    let url = format!("{}/_matrix/federation/v1/version", resolution.base_url());
    debug!(?resolution, ?url, "Resolved server");

    let request = client.get(&url).build().unwrap();
    let resp = match client.execute(request).await {
        Ok(resp) => resp,
        Err(e) => {
            tracing::warn!("Failed to fetch federation version from {server_name}: {e:?}");
            panic!();
        }
    };

    let status = resp.status();
    let json: Option<ServerVersionEndpoint> = resp.json().await.ok();
    tracing::debug!(%status, "Response");
    if status != StatusCode::OK {
        tracing::warn!("Server {server_name} returned non-200 status: {status}.");
        panic!();
    }
    tracing::info!(
        "✓ Successfully fetched federation version from {server_name}: {json:?}"
    );
}
/// `parse_server_name` splits off a trailing numeric port and leaves port-less names intact.
#[rstest]
#[case::no_port("matrix.org", "matrix.org", None)]
#[case::with_port("matrix.org:8448", "matrix.org", Some(8448))]
#[case::high_port("server.com:9999", "server.com", Some(9999))]
#[case::low_port("localhost:80", "localhost", Some(80))]
#[case::ipv4_with_port("192.168.1.1:8008", "192.168.1.1", Some(8008))]
fn test_parse_server_name(
    #[case] input: &str,
    #[case] expected_host: &str,
    #[case] expected_port: Option<u16>,
) {
    let parsed = parse_server_name(input);
    assert_eq!(parsed, (expected_host.to_string(), expected_port));
}
/// Case-by-case check of which inputs `get_ip_with_port` treats as IP literals.
#[rstest]
#[case::ipv4_with_port("127.0.0.1:8080", Some((IpAddr::from([127, 0, 0, 1]), Some(8080))))]
#[case::ipv4_no_port("127.0.0.1", Some((IpAddr::from([127, 0, 0, 1]), None)))]
#[case::ipv6_with_port("[::1]:8080", Some((IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1]), Some(8080))))]
#[case::ipv6_no_port("::1", Some((IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1]), None)))]
#[case::hostname("example.com", None)]
#[case::hostname_with_port("example.com:8448", None)]
#[case::invalid("not-an-ip", None)]
fn test_get_ip_with_port_parameterized(
    #[case] input: &str,
    #[case] expected: Option<(IpAddr, Option<u16>)>,
) {
    let actual = get_ip_with_port(input);
    assert_eq!(actual, expected);
}
/// Servers that delegate via `.well-known` must resolve to a usable destination.
#[rstest]
#[case::maunium("maunium.net")]
#[case::nexy("nexy7574.co.uk")]
#[tokio::test]
async fn test_well_known_resolution(#[case] server_name: &str) {
    init_tracing();
    let resolver = MatrixResolver::new().unwrap();
    let resolution = resolver.resolve_server(server_name).await;
    assert!(
        resolution.is_ok(),
        "Failed to resolve {server_name}: {:?}",
        resolution.err()
    );

    let resolved = resolution.unwrap();
    tracing::info!(
        "Resolved {server_name} to destination: {:?}, host: {}",
        resolved.destination,
        resolved.host
    );

    // Whatever shape the destination takes, it must carry a non-zero port.
    let port_num: u16 = match &resolved.destination {
        ResolvedDestination::Literal(addr) => addr.port(),
        ResolvedDestination::Named(host, port) => {
            assert!(!host.is_empty(), "Host should not be empty");
            assert!(!port.is_empty(), "Port should not be empty");
            port.parse().expect("Port should be a valid number")
        }
    };
    assert!(port_num > 0, "Port should be greater than 0");
}
/// A port given explicitly in the server name must survive resolution unchanged.
#[rstest]
#[case::standard_port("matrix.org:8448")]
#[case::custom_port("timedout.uk:69")]
#[case::high_port("test.server:9999")]
#[tokio::test]
async fn test_explicit_port_resolution(#[case] server_name: &str) {
    init_tracing();
    let resolver = MatrixResolver::new().unwrap();
    let resolution = resolver.resolve_server(server_name).await;
    assert!(
        resolution.is_ok(),
        "Failed to resolve {server_name}: {:?}",
        resolution.err()
    );

    // Everything after the first ':' of the input is the port we expect back.
    let expected_port = server_name.split(':').nth(1).unwrap();
    match resolution.unwrap().destination {
        ResolvedDestination::Named(_, port) => assert_eq!(
            port, expected_port,
            "Port should match the explicit port in server name"
        ),
        ResolvedDestination::Literal(addr) => assert_eq!(
            addr.port(),
            expected_port.parse::<u16>().unwrap(),
            "Port should match the explicit port in server name"
        ),
    }
}
/// Test IP literal resolution
#[rstest]
#[case::ipv4_default("192.168.1.1")]
#[case::ipv4_custom_port("192.168.1.1:8008")]
#[case::ipv6_default("::1")]
#[case::ipv6_custom_port("[::1]:8008")]
#[tokio::test]
async fn test_ip_literal_resolution(#[case] server_name: &str) {
init_tracing();
let resolver = MatrixResolver::new().unwrap();
let resolution = resolver.resolve_server(server_name).await;
assert!(
resolution.is_ok(),
"Failed to resolve {server_name}: {:?}",
resolution.err()
);
let resolved = resolution.unwrap();
// IP literals should always resolve to Literal variant
match &resolved.destination {
ResolvedDestination::Literal(addr) => {
assert!(addr.port() > 0, "Port should be greater than 0");
// If no port specified, should default to 8448
if !server_name.contains(':') {
assert_eq!(addr.port(), 8448, "Should default to port 8448");
}
}
ResolvedDestination::Named(..) => {
panic!("IP literal should resolve to Literal variant")
}
}
}
/// Checks the resolved destination host and port for each `resolvematrix.dev`
/// fixture; the trailing comment on each case names the resolution path it covers.
#[rstest]
#[case::resolvematrix_2_port("2.s.resolvematrix.dev:7652", "2.s.resolvematrix.dev", 7652)] // Explicit port
#[case::resolvematrix_3b("3b.s.resolvematrix.dev", "wk.3b.s.resolvematrix.dev", 7753)] // Delegated explicit port
#[case::resolvematrix_3c("3c.s.resolvematrix.dev", "srv.wk.3c.s.resolvematrix.dev", 7754)] // Delegated `matrix` SRV
#[case::resolvematrix_3d("3d.s.resolvematrix.dev", "wk.3d.s.resolvematrix.dev", 8448)] // Delegated default port
#[case::resolvematrix_4("4.s.resolvematrix.dev", "srv.4.s.resolvematrix.dev", 7855)] // `matrix` SRV
#[case::resolvematrix_5("5.s.resolvematrix.dev", "5.s.resolvematrix.dev", 8448)] // Default port
#[case::resolvematrix_3c_msc4040(
    "3c.msc4040.s.resolvematrix.dev",
    "srv.wk.3c.msc4040.s.resolvematrix.dev",
    7053
)] // Delegated `matrix-fed` SRV
#[case::resolvematrix_4_msc4040(
    "4.msc4040.s.resolvematrix.dev",
    "srv.4.msc4040.s.resolvematrix.dev",
    7054
)] // `matrix-fed` SRV
#[tokio::test]
async fn test_resolvematrix(
    #[case] input: &str,
    #[case] expected_host: &str,
    #[case] expected_port: u16,
) {
    let resolver = Arc::new(MatrixResolver::new().unwrap());
    tracing::info!("Testing {input}");

    let destination = resolver.resolve_server(input).await.unwrap().destination;
    let actual_host = destination.hostname();
    let actual_port = destination.port();
    assert_eq!(actual_host, expected_host);
    assert_eq!(actual_port, expected_port);
}
/// One client built by the resolver must be able to talk to several
/// differently-resolved servers in turn.
#[tokio::test]
async fn test_client_reuse() {
    init_tracing();
    let resolver = Arc::new(MatrixResolver::new().unwrap());

    // A single client, shared by every request below.
    let client = resolver
        .create_client_with_builder(
            Client::builder()
                .tls_danger_accept_invalid_certs(true)
                .timeout(std::time::Duration::from_secs(10)),
        )
        .unwrap();

    for server_name in ["matrix.org", "nexy7574.co.uk", "matrixrooms.info"] {
        tracing::info!("Testing {server_name} with reused client");

        let resolution = resolver.resolve_server(server_name).await.unwrap();
        let url = format!("{}/_matrix/federation/v1/version", resolution.base_url());
        debug!(?resolution, ?url, "Resolved server");

        let resp = match client.get(&url).send().await {
            Ok(resp) => resp,
            Err(e) => {
                tracing::warn!("Failed to fetch from {server_name}: {e:?}");
                panic!("Request failed");
            }
        };
        let status = resp.status();
        tracing::info!("✓ {server_name} returned status {status}");
        assert_eq!(status, StatusCode::OK);
    }
}
/// Removing one cache entry returns it, makes it unavailable afterwards, and
/// leaves other entries untouched.
#[rstest]
#[tokio::test]
async fn test_cache_remove_entry() {
    init_tracing();

    // Fixture: a resolution pointing at `name` on the default port.
    let named_resolution = |name: &str| Resolution {
        destination: ResolvedDestination::Named(name.to_string(), "8448".to_string()),
        host: String::from(name),
    };

    let cache = Cache::new(Duration::from_secs(300));
    let server1_name = "matrix.org";
    let server2_name = "example.com";
    let server1_resolution = named_resolution(server1_name);
    let server2_resolution = named_resolution(server2_name);
    cache.set(String::from(server1_name), &server1_resolution);
    cache.set(String::from(server2_name), &server2_resolution);

    // The removed entry is handed back and matches what was stored.
    let server1_removed = cache.remove_entry(server1_name);
    assert_some!(&server1_removed);
    let removed = server1_removed.unwrap();
    assert_eq!(removed.resolution.host, server1_resolution.host);
    assert_eq!(
        removed.resolution.base_url(),
        server1_resolution.base_url()
    );

    // A second removal finds nothing.
    assert_none!(cache.remove_entry(server1_name));

    // The other entry is still there.
    assert_some!(cache.get(server2_name));
}
/// Clearing the cache makes every previously stored entry unavailable.
#[rstest]
#[tokio::test]
async fn test_cache_clear() {
    init_tracing();

    let cache = Cache::new(Duration::from_secs(300));
    let server_names = ["matrix.org", "example.com"];
    for name in server_names {
        let resolution = Resolution {
            destination: ResolvedDestination::Named(name.to_string(), "8448".to_string()),
            host: String::from(name),
        };
        cache.set(String::from(name), &resolution);
    }

    cache.clear();

    // Neither server may be served from the cache any more.
    for name in server_names {
        assert_none!(cache.get(name));
    }
}
}