Set 1, challenge 3
[sandbox] / fw / fw.py
1 import collections
2 from wsgiref.util import setup_testing_defaults
3
# Immutable record describing one incoming HTTP request. ``request()`` builds
# instances with defaults; ``application()`` fills it from a WSGI environment.
Request = collections.namedtuple(
    'Request',
    'method path content_type content_length content headers environment',
)
16
def _environment_to_headers(environment):
    """Collect the ``HTTP_*`` entries of a WSGI environment as request headers.

    The ``HTTP_`` prefix is stripped and the remainder lower-cased, so
    ``HTTP_USER_AGENT`` becomes ``'user_agent'`` (underscores are kept as-is).
    Keys without the prefix are skipped. Entries keep the environment's
    iteration order.

    Args:
        environment: Mapping of WSGI/CGI variable names to values.

    Returns:
        collections.OrderedDict mapping header name to value.
    """
    prefix = 'HTTP_'
    result = collections.OrderedDict()

    for key, value in environment.items():
        if key.startswith(prefix):
            result[key[len(prefix):].lower()] = value

    return result
26
def request(method, path, **kwargs):
    """Construct a ``Request`` for *method* and *path*, defaulting other fields.

    Recognised keywords and their defaults: ``content_type='text/plain'``,
    ``content_length=0``, ``content=''``, ``headers=OrderedDict()`` and
    ``environment={}``. A fresh default container is created on every call.
    Any other keyword is ignored.
    """
    defaults = (
        ('content_type', 'text/plain'),
        ('content_length', 0),
        ('content', ''),
        ('headers', collections.OrderedDict()),
        ('environment', {}),
    )
    fields = {name: kwargs.get(name, fallback) for name, fallback in defaults}
    return Request(method = method, path = path, **fields)
37
# Immutable record describing an outgoing HTTP response. ``status`` is the
# integer status code and ``encoding`` is the codec used to encode ``str``
# content; ``response()`` builds instances with defaults.
Response = collections.namedtuple('Response', 'status headers content encoding')
47
def response(**kwargs):
    """Construct a ``Response``, filling unspecified fields with defaults.

    Recognised keywords: ``status`` (default 200), ``headers``, ``content``
    (default ``''``) and ``encoding`` (default ``'utf-8'``). Any other keyword
    is ignored.

    When ``headers`` is not given, a single plain-text ``Content-type`` header
    is used whose charset names ``encoding``. That way the declared charset
    matches the codec later used to encode ``str`` content.
    """
    encoding = kwargs.get('encoding', 'utf-8')

    # Only substitute the default when the caller did not pass headers at all;
    # an explicit value (even an empty list) is used as-is.
    if 'headers' in kwargs:
        headers = kwargs['headers']
    else:
        headers = [('Content-type', 'text/plain; charset={}'.format(encoding))]

    return Response(
        status = kwargs.get('status', 200),
        headers = headers,
        content = kwargs.get('content', ''),
        encoding = encoding,
    )
55
# Status lines from https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
# Entries here take precedence over the standard library's reason phrases.
_STATUS_CODES_TO_STRINGS = {
    200: '200 OK',
    404: '404 Not Found',
    405: '405 Method Not Allowed',
}

def _status_code_to_string(status_code):
    """Return the WSGI status line (e.g. ``'404 Not Found'``) for *status_code*.

    Codes listed in ``_STATUS_CODES_TO_STRINGS`` use that text. Any other code
    known to ``http.HTTPStatus`` uses its standard reason phrase.

    Raises:
        KeyError: If *status_code* is unknown to both tables.
    """
    from http import HTTPStatus

    try:
        return _STATUS_CODES_TO_STRINGS[status_code]
    except KeyError:
        pass

    try:
        phrase = HTTPStatus(status_code).phrase
    except ValueError:
        # Preserve the original contract: unknown codes raise KeyError.
        raise KeyError(status_code) from None

    return '{} {}'.format(status_code, phrase)
65
def _wrap_content(content_iterator, encoding):
    """Yield a response body as a sequence of byte strings for WSGI.

    Args:
        content_iterator: A ``str``; a bytes-like object (``bytes``,
            ``bytearray`` or ``memoryview``); ``None`` for an empty body; or an
            iterable whose items are ``str`` or byte chunks.
        encoding: Codec used to encode any ``str`` content.

    Yields:
        Body chunks. ``str`` is encoded with *encoding*, ``bytearray`` and
        ``memoryview`` are copied to ``bytes``, and other items are yielded
        unchanged.
    """
    def to_bytes(chunk):
        if isinstance(chunk, str):
            return chunk.encode(encoding)
        # Iterating a bytearray/memoryview would produce ints, and WSGI body
        # chunks must be bytestrings, so convert them explicitly.
        if isinstance(chunk, (bytearray, memoryview)):
            return bytes(chunk)
        return chunk

    if content_iterator is None:
        return

    # A single string or bytes-like object is one chunk, not an iterable of chunks.
    if isinstance(content_iterator, (str, bytes, bytearray, memoryview)):
        yield to_bytes(content_iterator)
        return

    for content_item in content_iterator:
        yield to_bytes(content_item)
82
def _parse_content_length(raw_value):
    """Return *raw_value* as a non-negative int, or None if empty or invalid.

    A malformed or negative CONTENT_LENGTH is treated like an empty one, so
    the request body is never read with an unknown size.
    """
    try:
        value = int(raw_value)
    except (TypeError, ValueError):
        return None
    return value if value >= 0 else None

def application(request_handler):
    """Adapt a ``Request``-to-``Response`` handler into a WSGI application.

    Args:
        request_handler: Callable that receives a ``Request`` and returns an
            object with ``status``, ``headers``, ``content`` and ``encoding``
            attributes (as built by ``response()``).

    Returns:
        A WSGI callable ``(environment, start_response)`` that returns an
        iterable of ``bytes`` body chunks.
    """
    def wrapped_request_handler(environment, start_response):
        # NOTE: setup_testing_defaults fills missing keys into *environment* in place.
        setup_testing_defaults(environment)

        # PEP 3333 allows CONTENT_TYPE and CONTENT_LENGTH to be empty or
        # absent, so don't index them directly.
        content_type = environment.get('CONTENT_TYPE', '')
        content_length = _parse_content_length(environment.get('CONTENT_LENGTH', ''))

        if content_length is None:
            # No usable length: leave wsgi.input untouched.
            content_length = 0
            content = ''
        else:
            content = environment['wsgi.input'].read(content_length)

        result = request_handler(request(
            environment['REQUEST_METHOD'],
            environment['PATH_INFO'],
            content_type = content_type,
            content_length = content_length,
            content = content,
            headers = _environment_to_headers(environment),
            environment = environment,
        ))

        start_response(
            _status_code_to_string(result.status),
            result.headers,
        )

        return _wrap_content(result.content, result.encoding)

    return wrapped_request_handler
122
def _route_matcher(route):
    """Return a predicate matching request paths equal to *route*.

    The predicate returns ``(True, ())`` on a match, where the empty tuple is
    the extra positional arguments for the handler. Otherwise it returns
    ``(False, None)``.
    """
    def matcher(path):
        is_match = route == path
        return (True, ()) if is_match else (False, None)

    return matcher
131
def path_router(*routes_to_handlers, **kwargs):
    """Build a request handler that dispatches on ``request.path``.

    Args:
        *routes_to_handlers: ``(route, handler)`` pairs, tried in order. A
            route matches when it equals the request path. The first match
            wins, and its handler is called as ``handler(request, *args)``
            with the args produced by the route's matcher.
        not_found_handler: Optional keyword argument. Called with the request
            when no route matches. The default returns a 404 response listing
            the defined routes.

    Returns:
        A callable that takes a request and returns the selected handler's result.

    Raises:
        TypeError: If any other keyword argument is supplied.
    """
    matchers_to_handlers = [
        (_route_matcher(route), handler) for route, handler in routes_to_handlers
    ]
    defined_routes = [route for route, _ in routes_to_handlers]

    def default_not_found_handler(request):
        content = 'FILE {} NOT FOUND\n'.format(request.path)
        content += 'The following routes are defined for this router:\n'
        content += '\n'.join(defined_routes)
        return response(status = 404, content = content)

    not_found_handler = kwargs.pop('not_found_handler', default_not_found_handler)

    # Test for presence, not key truthiness: ``any(kwargs)`` let an
    # empty-string key through. TypeError subclasses Exception, so existing
    # ``except Exception`` callers keep working.
    if kwargs:
        raise TypeError('Unexpected keyword argument "{}"'.format(next(iter(kwargs))))

    def route(request):
        for matcher, handler in matchers_to_handlers:
            matched, args = matcher(request.path)
            if matched:
                return handler(request, *args)

        return not_found_handler(request)

    return route
160
# Request methods accepted as keyword arguments by ``method_router``: the
# HTTP/1.1 methods of RFC 2616 section 9
# (https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html) plus PATCH (RFC 5789).
_REQUEST_METHODS = ['OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'TRACE', 'CONNECT', 'PATCH']

def method_router(**kwargs):
    """Build a request handler that dispatches on ``request.method``.

    Keyword arguments:
        <METHOD>: For each upper-case name in ``_REQUEST_METHODS``, the handler
            to call as ``handler(request)`` for requests with that method.
        auto_options: If true (the default) and no ``OPTIONS`` handler is
            given, OPTIONS requests get a 200 response whose ``Allow`` header
            lists the allowed methods.
        method_not_allowed_handler: Called with the request when its method
            has no handler. The default returns a 405 response listing the
            allowed methods.

    Returns:
        A callable that takes a request and returns the selected handler's result.

    Raises:
        TypeError: If any other keyword argument is supplied.
    """
    allowed_methods = {}

    # Walk the canonical list rather than kwargs so methods are registered in
    # a fixed order, which is the order reported in Allow / 405 responses.
    for request_method in _REQUEST_METHODS:
        if request_method in kwargs:
            allowed_methods[request_method] = kwargs.pop(request_method)

    auto_options = kwargs.pop('auto_options', True)

    def default_options_handler(request):
        return response(
            status = 200,
            headers = [
                ('Allow', ', '.join(allowed_methods.keys())),
                ('Content-Length', '0'),
            ],
        )

    if auto_options and 'OPTIONS' not in allowed_methods:
        allowed_methods['OPTIONS'] = default_options_handler

    def default_method_not_allowed_handler(request):
        content = 'METHOD "{}" NOT ALLOWED\n'.format(request.method)
        content += 'The following methods are allowed for this resource:\n'
        content += '\n'.join(allowed_methods.keys())
        return response(status = 405, content = content)

    method_not_allowed_handler = kwargs.pop(
        'method_not_allowed_handler',
        default_method_not_allowed_handler,
    )

    # Test for presence, not key truthiness: ``any(kwargs)`` let an
    # empty-string key through. TypeError subclasses Exception, so existing
    # ``except Exception`` callers keep working.
    if kwargs:
        raise TypeError('Unexpected keyword argument "{}"'.format(next(iter(kwargs))))

    def route(request):
        handler = allowed_methods.get(request.method, method_not_allowed_handler)
        return handler(request)

    return route
203
204
if __name__ == '__main__':
    import unittest
    from unittest import mock

    class PathRouterTests(unittest.TestCase):
        def test_routes_to_handler(self):
            path = '/path'
            expected_request = request('GET', path)
            expected_response = response(content = 'Expected')
            handler = mock.MagicMock(return_value = expected_response)

            router = path_router(
                ('/path',   handler),
            )
            actual_response = router(expected_request)

            handler.assert_called_once_with(expected_request)
            self.assertEqual(expected_response, actual_response)

        def test_routes_to_first_matching_handler(self):
            path = '/path'
            expected_request = request('GET', path)
            expected_response = response(content = 'Expected')
            expected_handler = mock.MagicMock(return_value = expected_response)
            unexpected_handler = mock.MagicMock()

            router = path_router(
                ('/path',   expected_handler),
                ('/path',   unexpected_handler),
            )
            actual_response = router(expected_request)

            expected_handler.assert_called_once_with(expected_request)
            # The later route for the same path must be shadowed entirely.
            unexpected_handler.assert_not_called()
            self.assertEqual(expected_response, actual_response)

        def test_falls_back_to_not_found_handler(self):
            expected_request = request('GET', '/missing')
            expected_response = response(status = 404, content = 'Missing')
            route_handler = mock.MagicMock()
            not_found_handler = mock.MagicMock(return_value = expected_response)

            router = path_router(
                ('/path',   route_handler),
                not_found_handler = not_found_handler,
            )
            actual_response = router(expected_request)

            route_handler.assert_not_called()
            not_found_handler.assert_called_once_with(expected_request)
            self.assertEqual(expected_response, actual_response)

        def test_rejects_unexpected_keyword_argument(self):
            with self.assertRaises(Exception):
                path_router(('/path', mock.MagicMock()), bogus = True)

    class MethodRouterTests(unittest.TestCase):
        def test_routes_to_method_handler(self):
            expected_request = request('GET', '/path')
            expected_response = response(content = 'Expected')
            handler = mock.MagicMock(return_value = expected_response)

            router = method_router(GET = handler)
            actual_response = router(expected_request)

            handler.assert_called_once_with(expected_request)
            self.assertEqual(expected_response, actual_response)

        def test_falls_back_to_method_not_allowed_handler(self):
            expected_request = request('POST', '/path')
            expected_response = response(status = 405, content = 'Not allowed')
            handler = mock.MagicMock()
            not_allowed_handler = mock.MagicMock(return_value = expected_response)

            router = method_router(
                GET = handler,
                method_not_allowed_handler = not_allowed_handler,
            )
            actual_response = router(expected_request)

            handler.assert_not_called()
            not_allowed_handler.assert_called_once_with(expected_request)
            self.assertEqual(expected_response, actual_response)

    unittest.main()