code.delx.au - proxy/blob - test_proxy.py
Write server output to log file when running tests
[proxy] / test_proxy.py
1 #!/usr/bin/env python3
2
3 import socket
4 import struct
5 import subprocess
6 import time
7 import threading
8 import unittest
9
10
def get_free_port():
    """Return a TCP port number that the OS currently reports as unused.

    Binds a throwaway socket to port 0 so the kernel chooses a free port,
    reads the chosen port back, and closes the socket again. The port is only
    *probably* still free when the caller binds to it later; another process
    could take it in between.
    """
    # Close the probe socket deterministically instead of relying on garbage
    # collection to release the port.
    with socket.socket() as s:
        s.bind(("", 0))
        return s.getsockname()[1]
15
# Ports are chosen dynamically when the module is imported.
SOCKS_PORT = get_free_port()  # ./socks5server listens here (see SocksServer)
ECHO_PORT = get_free_port()  # the in-process EchoServer listens here
# ECHO_PORT as two big-endian bytes, ready to append to SOCKS5 request bytes.
ECHO_PORT_B = ECHO_PORT.to_bytes(2, "big")
19
20
class SocketHelper(object):
    """Mixin giving a unittest.TestCase a client connection to the SOCKS server.

    The connection lives in ``self.sock``. The ``assert*`` helpers use
    ``self.assertEqual``, so this class must be combined with
    unittest.TestCase (or anything providing ``assertEqual``).
    """

    def init_socket(self):
        """Open an IPv4 connection to the SOCKS server on SOCKS_PORT."""
        self.sock = socket.socket(socket.AF_INET)
        self.sock.connect(("localhost", SOCKS_PORT))

    def init_ipv6_socket(self):
        """Open an IPv6 connection to the SOCKS server on SOCKS_PORT."""
        self.sock = socket.socket(socket.AF_INET6)
        # Use the IPv6 loopback literal: "localhost" is not guaranteed to
        # resolve to an IPv6 address on every host.
        self.sock.connect(("::1", SOCKS_PORT))

    def destroy_socket(self):
        """Close the client connection."""
        self.sock.close()

    def send(self, msg):
        """Send ``msg`` and assert that a single send() accepted every byte."""
        sent = self.sock.send(msg)
        self.assertEqual(len(msg), sent)

    def send_proxy_length(self, x):
        """Send the echo-server header: ``x`` zero-padded to 10 digits, then a newline."""
        self.send(("%s" % x).zfill(10).encode("ascii") + b"\n")

    def recv(self, expected_length):
        """Receive a reply and assert that it is exactly ``expected_length`` bytes.

        TCP may deliver a reply in several segments, so keep reading until at
        least ``expected_length`` bytes have arrived or the peer closes the
        connection. Individual reads are not capped at ``expected_length``, so
        extra bytes that arrive together with the reply still fail the
        assertion.

        Returns:
            The received bytes.
        """
        result = self.sock.recv(16384)
        while result and len(result) < expected_length:
            chunk = self.sock.recv(16384)
            if not chunk:
                break  # peer closed before the full reply arrived
            result += chunk
        self.assertEqual(expected_length, len(result), str(result))
        return result

    def assertEnd(self):
        """Assert that the server closed the connection (EOF or reset)."""
        try:
            result = self.sock.recv(1)
        except ConnectionResetError:
            return
        self.assertEqual(0, len(result), str(result))

    def assertAuthSuccess(self):
        """Assert a method-selection reply choosing method 0x00 (no authentication)."""
        result = self.recv(2)
        self.assertEqual(b"\x05\x00", result)

    def assertAuthFail(self):
        """Assert a method-selection reply of 0xFF (no acceptable methods)."""
        result = self.recv(2)
        self.assertEqual(b"\x05\xff", result)

    def assertRequestSuccess(self):
        """Assert a request reply with REP code 0 (success)."""
        self.assertRequestResponse(0)

    def assertRequestFail(self, reply):
        """Assert a request reply with REP code ``reply``, then a closed connection."""
        self.assertRequestResponse(reply)
        self.assertEnd()

    def assertRequestResponse(self, reply):
        """Assert the 10-byte reply: VER 5, REP ``reply``, RSV 0, ATYP 1, then 0.0.0.0:0."""
        reply = struct.pack(">B", reply)
        expected = b"\x05" + reply + b"\x00\x01\x00\x00\x00\x00\x00\x00"
        result = self.recv(10)
        self.assertEqual(expected, result)
72
class TestAuthNegotiation(SocketHelper, unittest.TestCase):
    """Tests for the SOCKS5 greeting / method-selection exchange.

    Every greeting is VER (5), NMETHODS, then NMETHODS method bytes; 0x00 is
    "no authentication" and 0x80 is a method the server should not pick.
    """

    def run(self, result=None):
        # A fresh socks5server process wraps every individual test.
        with SocksServer():
            super().run(result)

    def setUp(self):
        self.init_socket()

    def tearDown(self):
        self.destroy_socket()

    def test_one_method_success(self):
        greeting = b"\x05\x01\x00"  # only no-auth offered
        self.send(greeting)
        self.assertAuthSuccess()

    def test_two_methods_success_first(self):
        greeting = b"\x05\x02\x00\x80"  # no-auth listed first
        self.send(greeting)
        self.assertAuthSuccess()

    def test_two_methods_success_second(self):
        greeting = b"\x05\x02\x80\x00"  # no-auth listed second
        self.send(greeting)
        self.assertAuthSuccess()

    def test_no_methods_fail(self):
        greeting = b"\x05\x00"  # NMETHODS == 0
        self.send(greeting)
        self.assertEnd()

    def test_no_matching_methods_fail(self):
        greeting = b"\x05\x01\x80"  # only an unsupported method
        self.send(greeting)
        self.assertAuthFail()

    def test_invalid_version_fail(self):
        greeting = b"\x04\x01\x00"  # SOCKS version 4 instead of 5
        self.send(greeting)
        self.assertEnd()
107
108
class TestRequestNegotiation(SocketHelper, unittest.TestCase):
    """Tests for the SOCKS5 request phase, run after a successful no-auth greeting.

    A request is VER, CMD, RSV, ATYP, DST.ADDR, DST.PORT (big-endian). ATYP is
    1 for IPv4, 3 for a length-prefixed domain name, or 4 for IPv6. The number
    passed to assertRequestFail is the expected REP code. RFC 1928 defines
    1 as general failure, 4 as host unreachable, 7 as command not supported,
    and 8 as address type not supported.
    """

    def run(self, result=None):
        # A fresh socks5server process wraps every individual test.
        with SocksServer():
            unittest.TestCase.run(self, result)

    def setUp(self):
        self.init_socket()
        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

    def tearDown(self):
        self.destroy_socket()

    def test_invalid_version(self):
        # VER 4 instead of 5.
        self.send(b"\x04\x01\x00\x01\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(1)

    def test_invalid_command(self):
        # CMD 2 (BIND) instead of 1 (CONNECT).
        self.send(b"\x05\x02\x00\x01\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(7)

    def test_invalid_reserved_section(self):
        # RSV must be 0.
        self.send(b"\x05\x01\x01\x01\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(1)

    def test_invalid_address_type(self):
        # ATYP 9 does not exist.
        self.send(b"\x05\x01\x00\x09\x7f\x00\x00\x01\x00\x01")
        self.assertRequestFail(8)

    def test_ipv4_success(self):
        self.send(b"\x05\x01\x00\x01\x7f\x00\x00\x01" + ECHO_PORT_B)
        self.assertRequestSuccess()

    def test_ipv4_bad_port(self):
        # 127.0.0.1:65535, where nothing is expected to listen.
        self.send(b"\x05\x01\x00\x01\x7f\x00\x00\x01\xff\xff")
        self.assertRequestFail(4)

    def test_ipv4_bad_host(self):
        # 127.0.0.0 is not a usable host address.
        self.send(b"\x05\x01\x00\x01\x7f\x00\x00\x00\xff\xff")
        self.assertRequestFail(4)

    def test_dns_success(self):
        # \x09 is len(b"localhost").
        self.send(b"\x05\x01\x00\x03\x09localhost" + ECHO_PORT_B)
        self.assertRequestSuccess()

    def test_dns_remote_success(self):
        # example.com port 80 (\x00\x50); \x0b is len(b"example.com").
        # NOTE(review): presumably needs working DNS and outbound network access.
        self.send(b"\x05\x01\x00\x03\x0bexample.com\x00\x50")
        self.assertRequestSuccess()

    def test_dns_bad_port(self):
        self.send(b"\x05\x01\x00\x03\x09localhost\xff\xff")
        self.assertRequestFail(4)

    def test_dns_bad_host(self):
        # The .invalid TLD never resolves.
        self.send(b"\x05\x01\x00\x03\x09f.invalid" + ECHO_PORT_B)
        self.assertRequestFail(4)

    def test_dns_invalid_host(self):
        # Zero-length domain name.
        self.send(b"\x05\x01\x00\x03\x00" + ECHO_PORT_B)
        self.assertEnd()

    def test_ipv6_success(self):
        # ::1
        self.send(b"\x05\x01\x00\x04" + (b"\x00" * 15) + b"\x01" + ECHO_PORT_B)
        self.assertRequestSuccess()

    def test_ipv6_bad_port(self):
        self.send(b"\x05\x01\x00\x04" + (b"\x00" * 15) + b"\x01" + b"\xff\xff")
        self.assertRequestFail(4)

    def test_ipv6_bad_host(self):
        # fe80::1 link-local address with no interface scope.
        self.send(b"\x05\x01\x00\x04" + b"\xfe\x80" + (b"\x00" * 13) + b"\x01" + ECHO_PORT_B)
        self.assertRequestFail(4)
181
182
class ProxyPacketHelper(object):
    """Data-transfer tests shared by the TestIPv4Proxy/TestDNSProxy/TestIPv6Proxy cases.

    They assume setUp has already established a SOCKS5 CONNECT to the
    EchoServer. Each test first sends the echo server's length header (see
    SocketHelper.send_proxy_length) and then checks that the payload comes
    back through the proxy, followed by the connection closing.
    """

    def test_one_packet(self):
        self.send_proxy_length(3)
        self.send(b"foo")
        result = self.recv(3)
        self.assertEqual(b"foo", result)
        self.assertEnd()

    def test_no_received_data(self):
        # A zero-length header makes the echo server hang up immediately.
        self.send_proxy_length(0)
        self.send(b"foo")
        self.assertEnd()

    def test_two_packets(self):
        self.send_proxy_length(6)

        self.send(b"foo")
        result = self.recv(3)
        self.assertEqual(b"foo", result)

        self.send(b"bar")
        result = self.recv(3)
        self.assertEqual(b"bar", result)

        self.assertEnd()

    def test_large_packet(self):
        msg = b"1234" * 1024
        self.send_proxy_length(len(msg))
        self.send(msg)
        # recv() may return fewer bytes than requested, so accumulate the echo
        # until the whole payload is back (or the peer closes) and compare once.
        received = b""
        while len(received) < len(msg):
            part = self.sock.recv(len(msg) - len(received))
            if not part:
                break
            received += part
        self.assertEqual(msg, received)
        self.assertEnd()
219
220
class TestIPv4Proxy(SocketHelper, ProxyPacketHelper, unittest.TestCase):
    """Runs the ProxyPacketHelper data tests through a CONNECT to an IPv4 address."""

    def run(self, result=None):
        # A fresh socks5server process wraps every individual test.
        with SocksServer():
            super().run(result)

    def setUp(self):
        self.init_socket()

        # Greeting that offers only "no authentication".
        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

        # CONNECT, ATYP=1 (IPv4), 127.0.0.1, echo server port.
        loopback_v4 = b"\x7f\x00\x00\x01"
        self.send(b"\x05\x01\x00\x01" + loopback_v4 + ECHO_PORT_B)
        self.assertRequestSuccess()

    def tearDown(self):
        self.destroy_socket()
237
class TestDNSProxy(SocketHelper, ProxyPacketHelper, unittest.TestCase):
    """Runs the ProxyPacketHelper data tests through a CONNECT to a domain name."""

    def run(self, result=None):
        # A fresh socks5server process wraps every individual test.
        with SocksServer():
            super().run(result)

    def setUp(self):
        self.init_socket()

        # Greeting that offers only "no authentication".
        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

        # CONNECT, ATYP=3 (domain), length-prefixed "localhost", echo server port.
        hostname = b"localhost"
        self.send(b"\x05\x01\x00\x03" + bytes([len(hostname)]) + hostname + ECHO_PORT_B)
        self.assertRequestSuccess()

    def tearDown(self):
        self.destroy_socket()
254
class TestIPv6Proxy(SocketHelper, ProxyPacketHelper, unittest.TestCase):
    """Runs the ProxyPacketHelper data tests through a CONNECT to an IPv6 address."""

    def run(self, result=None):
        # A fresh socks5server process wraps every individual test.
        with SocksServer():
            super().run(result)

    def setUp(self):
        self.init_socket()

        # Greeting that offers only "no authentication".
        self.send(b"\x05\x01\x00")
        self.assertAuthSuccess()

        # CONNECT, ATYP=4 (IPv6), ::1, echo server port.
        loopback_v6 = b"\x00" * 15 + b"\x01"
        self.send(b"\x05\x01\x00\x04" + loopback_v6 + ECHO_PORT_B)
        self.assertRequestSuccess()

    def tearDown(self):
        self.destroy_socket()
271
class TestPermissions(SocketHelper, unittest.TestCase):
    """Tests for the server's client allow-list (ALLOW_ALL / ALLOW_HOSTn variables).

    Unlike the other test cases this class does not wrap ``run``; each test
    starts its own SocksServer with the environment it needs.
    """

    def _assert_greeting_accepted(self, connect):
        """Connect using ``connect`` and assert the server answers the greeting."""
        try:
            connect()
            self.send(b"\x05\x01\x00")
            self.assertAuthSuccess()
        finally:
            self.destroy_socket()

    def assert_connection_allowed(self):
        """Assert that an IPv4 client is served."""
        self._assert_greeting_accepted(self.init_socket)

    def assert_ipv6_connection_allowed(self):
        """Assert that an IPv6 client is served."""
        self._assert_greeting_accepted(self.init_ipv6_socket)

    def assert_connection_blocked(self):
        """Assert that the server drops an IPv4 client instead of answering.

        A dropped connection can surface as a reset (on send or recv) or as a
        clean EOF, depending on whether the client's greeting reaches the
        server before it closes the socket; all of these count as blocked.
        """
        try:
            self.init_socket()
            self.send(b"\x05\x01\x00")
            reply = self.sock.recv(2)
        except (ConnectionResetError, BrokenPipeError):
            return
        finally:
            self.destroy_socket()
        self.assertEqual(b"", reply, "Expected the server to close the connection")

    def test_allow_all_connections(self):
        with SocksServer({"ALLOW_ALL": "1"}):
            self.assert_connection_allowed()

    def test_block_all_connections(self):
        with SocksServer({}):
            self.assert_connection_blocked()

    def test_allow_ipv4_host(self):
        with SocksServer({"ALLOW_HOST1": "127.0.0.1"}):
            self.assert_connection_allowed()

    def test_allow_ipv6_host(self):
        with SocksServer({"ALLOW_HOST1": "::1"}):
            self.assert_ipv6_connection_allowed()

    def test_allow_multiple_hosts(self):
        with SocksServer({"ALLOW_HOST1": "foo.invalid", "ALLOW_HOST2": "localhost"}):
            self.assert_connection_allowed()
319
320
class EchoServer(object):
    """Threaded TCP echo server on ECHO_PORT, used as the proxy's upstream target.

    Each client first sends an 11-byte header: a 10-digit zero-padded byte
    count plus a newline (the format SocketHelper.send_proxy_length writes).
    The server then echoes back that many bytes, reading at most 16 bytes at
    a time, and closes the connection. Clients are served one at a time.
    """

    def __enter__(self):
        self.run = True

        self.sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Also accept IPv4 clients (as IPv4-mapped addresses): the default
        # value of IPV6_V6ONLY differs between operating systems.
        self.sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        self.sock.bind(("", ECHO_PORT))
        self.sock.listen(5)
        # Wake accept() periodically so the loop notices self.run turning
        # False even where shutdown() cannot interrupt a blocked accept().
        self.sock.settimeout(0.1)

        self.thread = threading.Thread(target=self.run_echo_server, daemon=True)
        self.thread.start()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.run = False
        try:
            # Interrupts a blocked accept() immediately where supported.
            self.sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # not supported on every platform; the accept timeout covers it
        self.thread.join()
        self.sock.close()

    def run_echo_server(self):
        """Accept and serve clients until ``self.run`` becomes False."""
        while self.run:
            try:
                client, _addr = self.sock.accept()
            except socket.timeout:
                continue  # re-check self.run
            except OSError:
                if not self.run:
                    return  # listening socket was shut down by __exit__
                raise
            try:
                self.handle_echo_client(client)
            except (OSError, ValueError):
                # Deliberate best-effort: a vanished or malformed client must
                # not stop the echo server for the remaining tests.
                pass
            finally:
                client.close()

    def handle_echo_client(self, client):
        """Serve one connection: read the length header, then echo that many bytes.

        Raises:
            ValueError: if the header is not a decimal number.
        """
        header_length = 10 + 1
        header = b""
        # The header may arrive in several segments; read all of it.
        while len(header) < header_length:
            chunk = client.recv(header_length - len(header))
            if not chunk:
                break
            header += chunk
        if not header:
            return
        num_bytes = int(header)

        while num_bytes > 0:
            # force the test app to handle many packets by using small ones
            data = client.recv(16)
            if not data:
                break
            num_bytes -= len(data)
            client.sendall(data)
367
class SocksServer(object):
    """Context manager that runs ./socks5server for the duration of a with-block.

    The server is told to listen on SOCKS_PORT. Its stdout and stderr are
    appended to out.log.

    Args:
        extra_env: Extra environment variables for the server, such as the
            ALLOW_ALL / ALLOW_HOSTn settings the tests use. Defaults to
            ``{"ALLOW_ALL": "1"}``; pass ``{}`` to configure no permissions.
    """

    def __init__(self, extra_env=None):
        if extra_env is None:
            extra_env = {"ALLOW_ALL": "1"}
        # Popen(env=...) replaces the whole environment, so the child sees
        # only these variables.
        self.env = {"LISTEN_PORT": str(SOCKS_PORT)}
        self.env.update(extra_env)
        self.log_output = None
        self.process = None

    def __enter__(self):
        # Opened here rather than in __init__ so an instance that is never
        # entered leaks no file handle and an instance can be entered again.
        self.log_output = open("out.log", "a")
        try:
            self.process = subprocess.Popen(
                args=["./socks5server"],
                stdout=self.log_output,
                stderr=self.log_output,
                env=self.env,
            )
            self.wait_for_port()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so clean up here.
            self._stop()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._stop()

    def _stop(self):
        """Terminate and reap the server process (if any) and close the log file."""
        if self.process is not None:
            self.process.terminate()
            self.process.wait()
            self.process = None
        if self.log_output is not None:
            self.log_output.close()
            self.log_output = None

    def wait_for_port(self, timeout=10):
        """Block until SOCKS_PORT accepts TCP connections.

        Args:
            timeout: Seconds to wait before giving up.

        Raises:
            RuntimeError: if the server process exits before listening.
            TimeoutError: if the port is still closed after ``timeout`` seconds.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.process is not None and self.process.poll() is not None:
                raise RuntimeError(
                    "socks5server exited with code %s before listening"
                    % self.process.returncode
                )
            # A socket whose connect() failed cannot portably be reused, so
            # every probe uses a fresh one.
            with socket.socket(socket.AF_INET) as probe:
                try:
                    probe.connect(("localhost", SOCKS_PORT))
                    return
                except ConnectionRefusedError:
                    pass
            time.sleep(0.01)
        raise TimeoutError(
            "socks5server did not listen on port %s within %s seconds"
            % (SOCKS_PORT, timeout)
        )
401
402
if __name__ == "__main__":
    # One echo server is shared by the entire test run; the SOCKS server is
    # started per test by the test cases themselves.
    echo_server = EchoServer()
    with echo_server:
        unittest.main()