spin_factor_outbound_networking/
blocked_networks.rs1use std::{
2 net::{IpAddr, SocketAddr},
3 sync::Arc,
4};
5
6use ip_network::IpNetwork;
7use ip_network_table::IpNetworkTable;
8
9#[derive(Clone, Default)]
11pub struct BlockedNetworks {
12 networks: Arc<IpNetworkTable<()>>,
14 block_private: bool,
18}
19
20impl BlockedNetworks {
21 pub(crate) fn new(
22 block_networks: impl AsRef<[IpNetwork]>,
23 block_private_networks: bool,
24 ) -> Self {
25 let mut networks = IpNetworkTable::new();
26 for network in IpNetwork::collapse_addresses(block_networks.as_ref()) {
27 if block_private_networks && !network.is_global() {
29 continue;
30 }
31 networks.insert(network, ());
32 }
33 Self {
34 networks: networks.into(),
35 block_private: block_private_networks,
36 }
37 }
38
39 pub fn is_empty(&self) -> bool {
41 !self.block_private && self.networks.is_empty()
42 }
43
44 pub fn is_blocked(&self, addr: &impl IpAddrLike) -> bool {
46 let ip_addr = addr.as_ip_addr();
47 if self.block_private && !IpNetwork::from(ip_addr).is_global() {
48 return true;
49 }
50 if self.networks.longest_match(ip_addr).is_some() {
51 return true;
52 }
53 if let IpAddr::V6(ipv6) = ip_addr {
55 if let Some(ipv4_compat) = ipv6.to_ipv4() {
56 return self.is_blocked(&IpAddr::V4(ipv4_compat));
57 }
58 }
59 false
60 }
61
62 pub fn remove_blocked<T: IpAddrLike>(&self, addrs: &mut Vec<T>) -> Vec<T> {
64 if self.is_empty() {
65 return vec![];
66 }
67 let (blocked, allowed) = std::mem::take(addrs)
68 .into_iter()
69 .partition(|addr| self.is_blocked(addr));
70 *addrs = allowed;
71 blocked
72 }
73}
74
/// Anything from which an [`IpAddr`] can be extracted for blocklist checks.
///
/// Implemented for [`IpAddr`] itself and for [`SocketAddr`] (the port is ignored),
/// so [`BlockedNetworks`] can filter lists of either type.
pub trait IpAddrLike {
    /// Returns the IP address this value refers to.
    fn as_ip_addr(&self) -> IpAddr;
}
80
81impl IpAddrLike for IpAddr {
82 fn as_ip_addr(&self) -> IpAddr {
83 *self
84 }
85}
86
87impl IpAddrLike for SocketAddr {
88 fn as_ip_addr(&self) -> IpAddr {
89 self.ip()
90 }
91}
92
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    #[test]
    fn test_is_empty() {
        assert!(BlockedNetworks::default().is_empty());
        assert!(!BlockedNetworks::new([cidr("1.1.1.1/32")], false).is_empty());
        assert!(!BlockedNetworks::new([], true).is_empty());
        assert!(!BlockedNetworks::new([cidr("1.1.1.1/32")], true).is_empty());
    }

    #[test]
    fn test_is_blocked_networks() {
        let blocked = BlockedNetworks::new([cidr("123.123.0.0/16"), cidr("2001::/96")], false);
        assert!(blocked.is_blocked(&ip("123.123.123.123")));
        assert!(blocked.is_blocked(&ip("2001::1000")));
        // IPv4-mapped form of a blocked IPv4 address.
        assert!(blocked.is_blocked(&ip("::ffff:123.123.123.123")));
        // IPv4-compatible form; `Ipv6Addr::to_ipv4` also unwraps this encoding.
        assert!(blocked.is_blocked(&ip("::123.123.123.123")));
        assert!(!blocked.is_blocked(&ip("123.100.100.100")));
        assert!(!blocked.is_blocked(&ip("2002::1000")));
    }

    #[test]
    fn test_is_blocked_private() {
        let redundant_private_cidr = cidr("10.0.0.0/8");
        let blocked = BlockedNetworks::new([redundant_private_cidr], true);
        for private in [
            "0.0.0.0",
            "10.10.10.10",
            "100.64.1.1",
            "127.0.0.1",
            "169.254.0.1",
            "192.0.0.1",
            "::1",
            "::ffff:10.10.10.10",
            "fc00::f00d",
        ] {
            assert!(blocked.is_blocked(&ip(private)), "{private}");
        }
        assert!(!blocked.is_blocked(&ip("123.123.123.123")));
        assert!(!blocked.is_blocked(&ip("2600::beef")));
    }

    #[test]
    fn test_remove_blocked_socket_addrs() {
        let blocked_networks =
            BlockedNetworks::new([cidr("123.123.0.0/16"), cidr("2600:f00d::/32")], true);

        let allowed: Vec<SocketAddr> = ["123.200.0.1:443", "[2600:beef::1000]:80"]
            .iter()
            .map(|addr| addr.parse().unwrap())
            .collect();
        let blocked: Vec<SocketAddr> = [
            "127.0.0.1:3000",
            "123.123.123.123:443",
            "[::1]:8080",
            "[2600:f00d::4]:80",
        ]
        .iter()
        .map(|addr| addr.parse().unwrap())
        .collect();

        let mut addrs = [allowed.clone(), blocked.clone()].concat();
        let actual_blocked = blocked_networks.remove_blocked(&mut addrs);

        assert_eq!(addrs, allowed);
        assert_eq!(actual_blocked, blocked);
    }

    /// With no blocking configured, `remove_blocked` must leave the input untouched,
    /// even for addresses that would be blocked as private.
    #[test]
    fn test_remove_blocked_noop_when_empty() {
        let original: Vec<IpAddr> = ["127.0.0.1", "123.123.123.123", "::1"]
            .iter()
            .copied()
            .map(ip)
            .collect();
        let mut addrs = original.clone();

        let actual_blocked = BlockedNetworks::default().remove_blocked(&mut addrs);

        assert!(actual_blocked.is_empty());
        assert_eq!(addrs, original);
    }

    /// Parses a CIDR string (host bits are truncated), panicking on invalid input.
    pub(crate) fn cidr(net: &str) -> IpNetwork {
        IpNetwork::from_str_truncate(net)
            .unwrap_or_else(|err| panic!("invalid cidr {net:?}: {err:?}"))
    }

    /// Parses an IP address string, panicking on invalid input.
    pub(crate) fn ip(addr: &str) -> IpAddr {
        addr.parse()
            .unwrap_or_else(|err| panic!("invalid ip addr {addr:?}: {err:?}"))
    }
}