]> code.delx.au - proxy/blob - test_proxy.py
Initial commit v0.1
[proxy] / test_proxy.py
1 #!/usr/bin/env python3
2
3 import socket
4 import struct
5 import subprocess
6 import time
7 import threading
8 import unittest
9
10
def get_free_port():
    """Return a TCP port number the OS reports as currently unused.

    A throwaway socket is bound to port 0 so the kernel picks a free port;
    the socket is closed before returning so the caller can bind the port
    itself. Another process could still grab the port in between.
    """
    with socket.socket() as s:
        s.bind(("", 0))
        return s.getsockname()[1]
15
SOCKS_PORT = get_free_port()
# get_free_port() does not keep its probe socket open, so the OS is free to
# hand out the same number twice; keep drawing until the echo server gets a
# port distinct from the SOCKS server's.
ECHO_PORT = get_free_port()
while ECHO_PORT == SOCKS_PORT:
    ECHO_PORT = get_free_port()
# The echo port as a 2-byte big-endian field, the form in which the tests
# embed it in SOCKS5 connect requests.
ECHO_PORT_B = struct.pack(">H", ECHO_PORT)
19
20
class SocketHelper(object):
    """Mixin with client-side helpers for talking to the SOCKS server under test.

    Meant to be combined with unittest.TestCase (it calls self.assertEqual).
    The current client connection is kept in self.sock.
    """

    # Upper bound in seconds for any single blocking socket operation, so a
    # misbehaving server makes a test fail instead of hanging the suite.
    socket_timeout = 10.0

    def init_socket(self):
        """Open an IPv4 connection to the SOCKS server."""
        self.sock = socket.socket(socket.AF_INET)
        self.sock.settimeout(self.socket_timeout)
        self.sock.connect(("localhost", SOCKS_PORT))

    def init_ipv6_socket(self):
        """Open an IPv6 connection to the SOCKS server via the loopback address."""
        self.sock = socket.socket(socket.AF_INET6)
        self.sock.settimeout(self.socket_timeout)
        # Use the literal IPv6 loopback: "localhost" is not guaranteed to
        # resolve to an IPv6 address on every host.
        self.sock.connect(("::1", SOCKS_PORT))

    def destroy_socket(self):
        """Close the current client connection."""
        self.sock.close()

    def send(self, msg):
        """Send every byte of msg (sendall retries partial writes)."""
        self.sock.sendall(msg)

    def send_proxy_length(self, x):
        """Send the echo header: x zero-padded to 10 characters plus a newline."""
        self.send(("%s" % x).zfill(10).encode("ascii") + b"\n")

    def recv(self, expected_length):
        """Read a reply and assert it is exactly expected_length bytes long.

        TCP may deliver one reply in several pieces, so keep reading until at
        least expected_length bytes have arrived or the peer closes.
        """
        result = b""
        while len(result) < expected_length:
            chunk = self.sock.recv(16384)
            if not chunk:
                break
            result += chunk
        self.assertEqual(expected_length, len(result), str(result))
        return result

    def assertEnd(self):
        """Assert the server has closed the connection (EOF or reset)."""
        try:
            result = self.sock.recv(1)
        except ConnectionResetError:
            return
        self.assertEqual(0, len(result), str(result))

    def assertAuthSuccess(self):
        """Assert the method-selection reply is b"\\x05\\x00"."""
        result = self.recv(2)
        self.assertEqual(b"\x05\x00", result)

    def assertAuthFail(self):
        """Assert the method-selection reply is b"\\x05\\xff"."""
        result = self.recv(2)
        self.assertEqual(b"\x05\xff", result)

    def assertRequestSuccess(self):
        """Assert a request reply with reply code 0."""
        self.assertRequestResponse(0)

    def assertRequestFail(self, reply):
        """Assert a request reply with the given code, then connection close."""
        self.assertRequestResponse(reply)
        self.assertEnd()

    def assertRequestResponse(self, reply):
        """Assert a 10-byte request reply carrying the given reply code."""
        expected = (
            b"\x05"
            + struct.pack(">B", reply)
            + b"\x00\x01\x00\x00\x00\x00\x00\x00"
        )
        result = self.recv(10)
        self.assertEqual(expected, result)
72
class TestAuthNegotiation(SocketHelper, unittest.TestCase):
    """SOCKS5 greeting (method negotiation) tests, one fresh server per test."""

    def run(self, result=None):
        # Wrap each test in its own SOCKS server, and hand the TestResult
        # back to the caller as unittest.TestCase.run itself does.
        with SocksServer():
            return unittest.TestCase.run(self, result)

    def setUp(self):
        self.init_socket()
        # Cleanups run even if setUp fails later, unlike tearDown.
        self.addCleanup(self.destroy_socket)

    def test_one_method_success(self):
        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

    def test_two_methods_success_first(self):
        self.send(b"\x05\x02\x00\x80")
        self.assertAuthSuccess()

    def test_two_methods_success_second(self):
        self.send(b"\x05\x02\x80\x00")
        self.assertAuthSuccess()

    def test_no_methods_fail(self):
        self.send(b"\x05\x00")
        self.assertEnd()

    def test_no_matching_methods_fail(self):
        self.send(b"\x05\x01\x80")
        self.assertAuthFail()

    def test_invalid_version_fail(self):
        self.send(b"\x04\x01\x00")
        self.assertEnd()
107
108
class TestRequestNegotiation(SocketHelper, unittest.TestCase):
    """SOCKS5 request-phase tests, run after a successful no-auth greeting."""

    def run(self, result=None):
        # Wrap each test in its own SOCKS server, and hand the TestResult
        # back to the caller as unittest.TestCase.run itself does.
        with SocksServer():
            return unittest.TestCase.run(self, result)

    def setUp(self):
        self.init_socket()
        # Registered before the handshake so the socket is closed even if
        # the greeting below fails (tearDown would not run in that case).
        self.addCleanup(self.destroy_socket)
        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

    def test_invalid_version(self):
        self.send(b"\x04\x01\x00\x01\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(1)

    def test_invalid_command(self):
        self.send(b"\x05\x02\x00\x01\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(7)

    def test_invalid_reserved_section(self):
        self.send(b"\x05\x01\x01\x01\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(1)

    def test_invalid_address_type(self):
        self.send(b"\x05\x01\x00\x09\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(8)

    def test_ipv4_success(self):
        self.send(b"\x05\x01\x00\x01\x7f\x00\x00\x01" + ECHO_PORT_B)
        self.assertRequestSuccess()

    def test_ipv4_bad_port(self):
        self.send(b"\x05\x01\x00\x01\x7f\x00\x00\x01\xff\xff")
        self.assertRequestFail(4)

    def test_ipv4_bad_host(self):
        self.send(b"\x05\x01\x00\x01\x7f\x00\x00\x00\xff\xff")
        self.assertRequestFail(4)

    def test_dns_success(self):
        self.send(b"\x05\x01\x00\x03\x09localhost" + ECHO_PORT_B)
        self.assertRequestSuccess()

    def test_dns_bad_port(self):
        self.send(b"\x05\x01\x00\x03\x09localhost\xff\xff")
        self.assertRequestFail(4)

    def test_dns_bad_host(self):
        self.send(b"\x05\x01\x00\x03\x09f.invalid" + ECHO_PORT_B)
        self.assertRequestFail(4)

    def test_dns_invalid_host(self):
        # Zero-length host name.
        self.send(b"\x05\x01\x00\x03\x00" + ECHO_PORT_B)
        self.assertEnd()

    def test_ipv6_success(self):
        # ::1
        self.send(b"\x05\x01\x00\x04" + (b"\x00" * 15) + b"\x01" + ECHO_PORT_B)
        self.assertRequestSuccess()

    def test_ipv6_bad_port(self):
        self.send(b"\x05\x01\x00\x04" + (b"\x00" * 15) + b"\x01" + b"\xff\xff")
        self.assertRequestFail(4)

    def test_ipv6_bad_host(self):
        # fe80::1
        self.send(
            b"\x05\x01\x00\x04" + b"\xfe\x80" + (b"\x00" * 13) + b"\x01" + ECHO_PORT_B
        )
        self.assertRequestFail(4)
177
178
class ProxyPacketHelper(object):
    """Mixin of data-relay tests run over an already-established tunnel.

    The concrete test class must connect through the proxy to the echo
    server in setUp and provide the SocketHelper methods used here.
    """

    def test_one_packet(self):
        self.send_proxy_length(3)
        self.send(b"foo")
        result = self.recv(3)
        self.assertEqual(b"foo", result)
        self.assertEnd()

    def test_no_received_data(self):
        # Announce zero bytes: the echo server returns without echoing and
        # closes, which the proxy should pass on as end-of-stream.
        self.send_proxy_length(0)
        self.send(b"foo")
        self.assertEnd()

    def test_two_packets(self):
        self.send_proxy_length(6)

        self.send(b"foo")
        result = self.recv(3)
        self.assertEqual(b"foo", result)

        self.send(b"bar")
        result = self.recv(3)
        self.assertEqual(b"bar", result)

        self.assertEnd()

    def test_large_packet(self):
        msg = b"1234" * 1024
        self.send_proxy_length(len(msg))
        self.send(msg)
        # The relayed stream may be split at any byte offset, so gather the
        # whole payload before comparing instead of expecting every recv()
        # to return an aligned 4-byte "1234".
        received = b""
        while len(received) < len(msg):
            part = self.sock.recv(len(msg) - len(received))
            if not part:
                break
            received += part
        self.assertEqual(msg, received)
        self.assertEnd()
215
216
class TestIPv4Proxy(SocketHelper, ProxyPacketHelper, unittest.TestCase):
    """Relay tests through a tunnel opened with an IPv4 (127.0.0.1) request."""

    def run(self, result=None):
        # Wrap each test in its own SOCKS server, and hand the TestResult
        # back to the caller as unittest.TestCase.run itself does.
        with SocksServer():
            return unittest.TestCase.run(self, result)

    def setUp(self):
        self.init_socket()
        # Registered before the handshake so the socket is closed even if
        # it fails (tearDown would not run in that case).
        self.addCleanup(self.destroy_socket)

        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

        self.send(b"\x05\x01\x00\x01\x7f\x00\x00\x01" + ECHO_PORT_B)
        self.assertRequestSuccess()
233
class TestDNSProxy(SocketHelper, ProxyPacketHelper, unittest.TestCase):
    """Relay tests through a tunnel opened with a domain-name ("localhost") request."""

    def run(self, result=None):
        # Wrap each test in its own SOCKS server, and hand the TestResult
        # back to the caller as unittest.TestCase.run itself does.
        with SocksServer():
            return unittest.TestCase.run(self, result)

    def setUp(self):
        self.init_socket()
        # Registered before the handshake so the socket is closed even if
        # it fails (tearDown would not run in that case).
        self.addCleanup(self.destroy_socket)

        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

        self.send(b"\x05\x01\x00\x03\x09localhost" + ECHO_PORT_B)
        self.assertRequestSuccess()
250
class TestIPv6Proxy(SocketHelper, ProxyPacketHelper, unittest.TestCase):
    """Relay tests through a tunnel opened with an IPv6 (::1) request."""

    def run(self, result=None):
        # Wrap each test in its own SOCKS server, and hand the TestResult
        # back to the caller as unittest.TestCase.run itself does.
        with SocksServer():
            return unittest.TestCase.run(self, result)

    def setUp(self):
        self.init_socket()
        # Registered before the handshake so the socket is closed even if
        # it fails (tearDown would not run in that case).
        self.addCleanup(self.destroy_socket)

        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

        self.send(b"\x05\x01\x00\x04" + (b"\x00" * 15) + b"\x01" + ECHO_PORT_B)
        self.assertRequestSuccess()
267
class TestPermissions(SocketHelper, unittest.TestCase):
    """Access-control tests; each test starts its own SOCKS server with a
    specific set of ALLOW_* environment variables."""

    def _assert_greeting_answered(self, connect):
        """Connect via connect(), send a no-auth greeting, expect acceptance."""
        try:
            connect()
            self.send(b"\x05\x01\x00")
            self.assertAuthSuccess()
        finally:
            self.destroy_socket()

    def assert_connection_allowed(self):
        """Assert an IPv4 client gets its greeting answered."""
        self._assert_greeting_answered(self.init_socket)

    def assert_ipv6_connection_allowed(self):
        """Assert an IPv6 client gets its greeting answered."""
        self._assert_greeting_answered(self.init_ipv6_socket)

    def assert_connection_blocked(self):
        """Assert the server drops the connection instead of answering.

        A dropped connection may surface as a reset, a broken pipe, or a
        clean end-of-stream depending on timing; all of them count as
        blocked, while any reply bytes fail the test.
        """
        try:
            try:
                self.init_socket()
                self.send(b"\x05\x01\x00")
                result = self.sock.recv(2)
            except (ConnectionResetError, BrokenPipeError):
                return
            self.assertEqual(b"", result, "Expected the connection to be closed")
        finally:
            self.destroy_socket()

    def test_allow_all_connections(self):
        with SocksServer({"ALLOW_ALL": "1"}):
            self.assert_connection_allowed()

    def test_block_all_connections(self):
        with SocksServer({}):
            self.assert_connection_blocked()

    def test_allow_ipv4_host(self):
        with SocksServer({"ALLOW_HOST1": "127.0.0.1"}):
            self.assert_connection_allowed()

    def test_allow_ipv6_host(self):
        with SocksServer({"ALLOW_HOST1": "::1"}):
            self.assert_ipv6_connection_allowed()

    def test_allow_multiple_hosts(self):
        with SocksServer({"ALLOW_HOST1": "foo.invalid", "ALLOW_HOST2": "localhost"}):
            self.assert_connection_allowed()
315
316
class EchoServer(object):
    """Background TCP server (context manager) that echoes a counted payload.

    Each client first sends an 11-byte header that int() parses as a byte
    count, then the payload; up to that many bytes are echoed back as they
    arrive. Clients are served one at a time on a single background thread.
    """

    # Header size read from each client: 10 digits plus a newline.
    HEADER_LENGTH = 10 + 1

    def __enter__(self):
        self.run = True

        self.sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        try:
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.sock.bind(("", ECHO_PORT))
            self.sock.listen(5)
        except OSError:
            # __exit__ is not called when __enter__ raises.
            self.sock.close()
            raise

        # Daemon thread, so a stuck accept() cannot keep the interpreter alive.
        self.thread = threading.Thread(target=self.run_echo_server, daemon=True)
        self.thread.start()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.run = False
        try:
            # Shutting down the listening socket is what wakes the accept()
            # blocked in run_echo_server.
            self.sock.shutdown(socket.SHUT_RDWR)
            self.thread.join()
        finally:
            self.sock.close()

    def run_echo_server(self):
        """Accept and serve clients until __exit__ clears self.run."""
        while self.run:
            try:
                client, _addr = self.sock.accept()
            except OSError:
                if not self.run:
                    return  # listening socket was shut down by __exit__
                raise
            with client:
                try:
                    self.handle_echo_client(client)
                except (OSError, ValueError):
                    # One misbehaving client (reset connection, garbled
                    # header) must not stop the server for later tests.
                    pass

    def handle_echo_client(self, client):
        """Read the length header, then echo back that many payload bytes."""
        # recv() may return the header in pieces; read exactly HEADER_LENGTH
        # bytes so no payload byte is mistaken for part of it.
        header = b""
        while len(header) < self.HEADER_LENGTH:
            chunk = client.recv(self.HEADER_LENGTH - len(header))
            if not chunk:
                break
            header += chunk
        if not header:
            return
        num_bytes = int(header)

        while num_bytes > 0:
            # force the test app to handle many packets by using small ones
            data = client.recv(16)
            if not data:
                break
            num_bytes -= len(data)
            client.sendall(data)
363
class SocksServer(object):
    """Context manager running ./socks5server on SOCKS_PORT inside a with-block.

    extra_env is merged into the child's otherwise minimal environment
    (which only sets LISTEN_PORT). When omitted it defaults to
    {"ALLOW_ALL": "1"}; pass {} to start the server with no ALLOW_* settings.
    """

    # Seconds to wait for the server to start accepting connections.
    startup_timeout = 10

    def __init__(self, extra_env=None):
        # None sentinel instead of a shared mutable default dict.
        if extra_env is None:
            extra_env = {"ALLOW_ALL": "1"}
        self.env = {}
        self.env["LISTEN_PORT"] = str(SOCKS_PORT)
        self.env.update(extra_env)

    def __enter__(self):
        # Output is discarded; DEVNULL avoids holding an open file handle
        # that leaks if this object is never entered or exited.
        self.process = subprocess.Popen(
            args=["./socks5server"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            env=self.env,
        )
        try:
            self.wait_for_port()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so stop the
            # child here instead of leaking it.
            self._stop()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._stop()

    def _stop(self):
        """Terminate the server process and wait for it to exit."""
        self.process.terminate()
        self.process.wait()

    def wait_for_port(self):
        """Block until the server accepts TCP connections on SOCKS_PORT.

        Raises:
            RuntimeError: the process exited first, or the port did not open
                within startup_timeout seconds.
        """
        deadline = time.monotonic() + self.startup_timeout
        while time.monotonic() < deadline:
            returncode = self.process.poll()
            if returncode is not None:
                raise RuntimeError(
                    "./socks5server exited with status %s before listening on port %s"
                    % (returncode, SOCKS_PORT)
                )
            # A fresh socket per attempt: a socket whose connect() failed
            # cannot portably be reused for another connect().
            with socket.socket(socket.AF_INET) as s:
                try:
                    s.connect(("localhost", SOCKS_PORT))
                    return
                except ConnectionRefusedError:
                    pass
            time.sleep(0.01)
        raise RuntimeError(
            "./socks5server did not start listening on port %s within %s seconds"
            % (SOCKS_PORT, self.startup_timeout)
        )
397
398
# The echo server is shared by every test in this module. Starting it from
# unittest's module-level fixtures, rather than only under __main__, keeps it
# running when the tests are collected by another runner such as
# "python -m unittest".
_echo_server = None


def setUpModule():
    """Start the shared echo server before the first test in this module."""
    global _echo_server
    _echo_server = EchoServer().__enter__()


def tearDownModule():
    """Stop the shared echo server after the last test in this module."""
    global _echo_server
    if _echo_server is not None:
        _echo_server.__exit__(None, None, None)
        _echo_server = None


if __name__ == "__main__":
    unittest.main()