spin_outbound_networking_config/
blocked_networks.rs

1use std::{
2    net::{IpAddr, SocketAddr},
3    sync::Arc,
4};
5
6use ip_network::IpNetwork;
7use ip_network_table::IpNetworkTable;
8
9/// A cheaply-clonable set of blocked networks
10#[derive(Clone, Default)]
11pub struct BlockedNetworks {
12    /// A set of IP networks to be blocked
13    networks: Arc<IpNetworkTable<()>>,
14    /// If true, block all non-globally-routable networks, in addition to `networks`
15    ///
16    /// See: [`ip_network::Ipv4Network::is_global`] / [`ip_network::Ipv6Network::is_global`]
17    block_private: bool,
18}
19
20impl BlockedNetworks {
21    /// Creates a new `BlockedNetworks` instance with the given networks and private network blocking option.
22    pub fn new(block_networks: impl AsRef<[IpNetwork]>, block_private_networks: bool) -> Self {
23        let mut networks = IpNetworkTable::new();
24        for network in IpNetwork::collapse_addresses(block_networks.as_ref()) {
25            // Omit redundant blocked_networks if block_private_networks = true
26            if block_private_networks && !network.is_global() {
27                continue;
28            }
29            networks.insert(network, ());
30        }
31        Self {
32            networks: networks.into(),
33            block_private: block_private_networks,
34        }
35    }
36
37    /// Returns true iff no networks are blocked.
38    pub fn is_empty(&self) -> bool {
39        !self.block_private && self.networks.is_empty()
40    }
41
42    /// Returns true iff the given address is blocked.
43    pub fn is_blocked(&self, addr: &impl IpAddrLike) -> bool {
44        let ip_addr = addr.as_ip_addr();
45        if self.block_private && !IpNetwork::from(ip_addr).is_global() {
46            return true;
47        }
48        if self.networks.longest_match(ip_addr).is_some() {
49            return true;
50        }
51        // Convert IPv4-compatible IPv6 addresses to IPv4 and check again to prevent bypass
52        if let IpAddr::V6(ipv6) = ip_addr {
53            if let Some(ipv4_compat) = ipv6.to_ipv4() {
54                return self.is_blocked(&IpAddr::V4(ipv4_compat));
55            }
56        }
57        false
58    }
59
60    /// Removes and returns any addresses with blocked IPs from the given Vec.
61    pub fn remove_blocked<T: IpAddrLike>(&self, addrs: &mut Vec<T>) -> Vec<T> {
62        if self.is_empty() {
63            return vec![];
64        }
65        let (blocked, allowed) = std::mem::take(addrs)
66            .into_iter()
67            .partition(|addr| self.is_blocked(addr));
68        *addrs = allowed;
69        blocked
70    }
71}
72
/// `IpAddrLike` can be implemented to make an "IP-address-like" type compatible
/// with [`BlockedNetworks`].
pub trait IpAddrLike {
    /// Returns the IP address represented by this value (for a socket address, its
    /// IP part).
    fn as_ip_addr(&self) -> IpAddr;
}
78
79impl IpAddrLike for IpAddr {
80    fn as_ip_addr(&self) -> IpAddr {
81        *self
82    }
83}
84
85impl IpAddrLike for SocketAddr {
86    fn as_ip_addr(&self) -> IpAddr {
87        self.ip()
88    }
89}
90
91/// Helpers for testing purposes
92pub mod test {
93    use super::*;
94
95    /// Converts a string to an `IpNetwork`, panicking on failure.
96    pub fn cidr(net: &str) -> IpNetwork {
97        IpNetwork::from_str_truncate(net)
98            .unwrap_or_else(|err| panic!("invalid cidr {net:?}: {err:?}"))
99    }
100
101    /// Converts a string to an `IpAddr`, panicking on failure.
102    pub fn ip(addr: &str) -> IpAddr {
103        addr.parse()
104            .unwrap_or_else(|err| panic!("invalid ip addr {addr:?}: {err:?}"))
105    }
106}
107
#[cfg(test)]
pub mod tests {
    use super::test::*;
    use super::*;

    #[test]
    fn test_is_empty() {
        assert!(BlockedNetworks::default().is_empty());
        assert!(!BlockedNetworks::new([cidr("1.1.1.1/32")], false).is_empty());
        assert!(!BlockedNetworks::new([], true).is_empty());
        assert!(!BlockedNetworks::new([cidr("1.1.1.1/32")], true).is_empty());
    }

    #[test]
    fn test_is_blocked_networks() {
        let blocked = BlockedNetworks::new([cidr("123.123.0.0/16"), cidr("2001::/96")], false);
        assert!(blocked.is_blocked(&ip("123.123.123.123")));
        assert!(blocked.is_blocked(&ip("2001::1000")));
        assert!(blocked.is_blocked(&ip("::ffff:123.123.123.123")));
        assert!(!blocked.is_blocked(&ip("123.100.100.100")));
        assert!(!blocked.is_blocked(&ip("2002::1000")));
    }

    /// Both IPv6 encodings of an IPv4 address (IPv4-compatible `::a.b.c.d` and
    /// IPv4-mapped `::ffff:a.b.c.d`) must be re-checked against IPv4 rules.
    #[test]
    fn test_is_blocked_ipv4_embedded_in_ipv6() {
        let blocked = BlockedNetworks::new([cidr("123.123.0.0/16")], false);
        assert!(blocked.is_blocked(&ip("::123.123.1.1")));
        assert!(blocked.is_blocked(&ip("::ffff:123.123.1.1")));
        assert!(!blocked.is_blocked(&ip("::123.100.100.100")));
        assert!(!blocked.is_blocked(&ip("::ffff:123.100.100.100")));
    }

    #[test]
    fn test_is_blocked_private() {
        let redundant_private_cidr = cidr("10.0.0.0/8");
        let blocked = BlockedNetworks::new([redundant_private_cidr], true);
        for private in [
            "0.0.0.0",
            "10.10.10.10",
            "100.64.1.1",
            "127.0.0.1",
            "169.254.0.1",
            "192.0.0.1",
            "::1",
            "::ffff:10.10.10.10",
            "fc00::f00d",
        ] {
            assert!(blocked.is_blocked(&ip(private)), "{private}");
        }
        // Public addresses not blocked
        assert!(!blocked.is_blocked(&ip("123.123.123.123")));
        assert!(!blocked.is_blocked(&ip("2600::beef")));
    }

    #[test]
    fn test_remove_blocked_socket_addrs() {
        let blocked_networks =
            BlockedNetworks::new([cidr("123.123.0.0/16"), cidr("2600:f00d::/32")], true);

        let allowed: Vec<SocketAddr> = ["123.200.0.1:443", "[2600:beef::1000]:80"]
            .iter()
            .map(|addr| addr.parse().unwrap())
            .collect();
        let blocked: Vec<SocketAddr> = [
            "127.0.0.1:3000",
            "123.123.123.123:443",
            "[::1]:8080",
            "[2600:f00d::4]:80",
        ]
        .iter()
        .map(|addr| addr.parse().unwrap())
        .collect();

        let mut addrs = [allowed.clone(), blocked.clone()].concat();
        let actual_blocked = blocked_networks.remove_blocked(&mut addrs);

        assert_eq!(addrs, allowed);
        assert_eq!(actual_blocked, blocked);
    }

    /// With no blocking rules at all, `remove_blocked` takes its fast path: nothing is
    /// removed and the input is left untouched, even for a loopback address.
    #[test]
    fn test_remove_blocked_when_empty() {
        let original: Vec<SocketAddr> = vec!["127.0.0.1:3000".parse().unwrap()];
        let mut addrs = original.clone();
        let removed = BlockedNetworks::default().remove_blocked(&mut addrs);
        assert!(removed.is_empty());
        assert_eq!(addrs, original);
    }
}