Binary object serialization
[sandbox] / serial / serial / binary.py
1 import functools
2 import io
3 import struct
4
5 from . import tags
6
7 def _make_tag_only_serializer(tag, expected_value):
8     tag = bytes([tag])
9
10     def serializer(to):
11         assert to.instance == expected_value
12         return tag
13
14     return serializer
15
16 def _make_struct_serializer(fmt):
17     fmt = '!B' + fmt
18     packer = functools.partial(struct.pack, fmt)
19
20     def serializer(to):
21         return packer(to.tag, to.instance)
22
23     return serializer
24
25 def _make_string_serializer(encoder):
26     packer = functools.partial(struct.pack, '!BI')
27
28     def serializer(to):
29         encoded = encoder(to.instance)
30         return packer(to.tag, len(encoded)) + encoded
31
32     return serializer
33
34 def _serialize_tuple(to):
35     assert isinstance(to.instance, tuple)
36
37     payload = b''.join(serialize(item) for item in to.instance)
38
39     fmt = '!BI'
40
41     return struct.pack('!BI', tags.TUPLE, len(payload)) + payload
42
43 def _serialize_list(to):
44     assert isinstance(to.instance, list)
45
46     # TODO Actually handle this case somehow
47     assert len(to.instance) > 0
48
49     # TODO Do this a better way
50     serialized_items = [serialize(i) for i in to.instance]
51     list_tag = serialized_items[0][0]
52
53     def check_and_strip_prefix(b):
54         item_tag = b[0]
55         assert list_tag == item_tag
56         return b[1:]
57
58     payload = b''.join(check_and_strip_prefix(si) for si in serialized_items)
59
60     fmt = '!BBI'
61
62     return struct.pack(fmt, tags.LIST, list_tag, len(payload)) + payload
63
64 def _serialize_object(to):
65     assert isinstance(to.instance, list)
66
67     # TODO Actually handle this case somehow
68     assert len(to.instance) > 0
69
70     # TODO Do this a better way
71     serialized_kvps = [(serialize(k), serialize(v)) for k,v in to.instance]
72     key_type_tag = serialized_kvps[0][0][0]
73
74     def check_and_strip_prefix(b):
75         item_tag = b[0]
76         assert key_type_tag == item_tag
77         return b[1:]
78
79     payload = b''.join(check_and_strip_prefix(k) + v for k,v in serialized_kvps)
80
81     fmt = '!BBI'
82
83     return struct.pack(fmt, tags.OBJECT, key_type_tag, len(payload)) + payload
84
85 _TAGS_TO_SERIALIZERS = {
86     tags.NULL: _make_tag_only_serializer(tags.NULL, None),
87     tags.TRUE: _make_tag_only_serializer(tags.TRUE, True),
88     tags.FALSE: _make_tag_only_serializer(tags.FALSE, False),
89     tags.UINT8: _make_struct_serializer('B'),
90     tags.UINT16: _make_struct_serializer('H'),
91     tags.UINT32: _make_struct_serializer('I'),
92     tags.UINT64: _make_struct_serializer('Q'),
93     tags.INT8: _make_struct_serializer('b'),
94     tags.INT16: _make_struct_serializer('h'),
95     tags.INT32: _make_struct_serializer('i'),
96     tags.INT64: _make_struct_serializer('q'),
97     tags.BINARY: _make_string_serializer(lambda s: s),
98     tags.UTF8: _make_string_serializer(lambda s: s.encode('utf-8')),
99     tags.UTF16: _make_string_serializer(lambda s: s.encode('utf-16')),
100     tags.UTF32: _make_string_serializer(lambda s: s.encode('utf-32')),
101     tags.TUPLE: _serialize_tuple,
102     tags.LIST: _serialize_list,
103     tags.OBJECT: _serialize_object,
104 }
105
106 def serialize(to):
107     return _TAGS_TO_SERIALIZERS[to.tag](to)
108
109 def _make_tag_only_parser(tag, value):
110     def parser(b):
111         return 0, tags.TaggedObject(tag = tag, instance = value)
112
113     return parser
114
115 def _make_struct_deserializer(tag, fmt):
116     fmt = '!' + fmt
117     size = struct.calcsize(fmt)
118     unpacker = functools.partial(struct.unpack, fmt)
119
120     def parser(b):
121         b = b.read(size)
122         assert len(b) == size
123         return size, tags.TaggedObject(tag = tag, instance = unpacker(b)[0])
124
125     return parser
126
127 _LENGTH_FMT = '!I'
128 _LENGTH_FMT_SIZE = struct.calcsize(_LENGTH_FMT)
129
130 def _read_length_then_payload(b):
131     length_b = b.read(_LENGTH_FMT_SIZE)
132     assert len(length_b) == _LENGTH_FMT_SIZE
133     length = struct.unpack(_LENGTH_FMT, length_b)[0]
134
135     payload = b.read(length)
136     assert len(payload) == length
137     return _LENGTH_FMT_SIZE + length, payload
138
139 def _make_string_deserializer(tag, decoder):
140     fmt = '!I'
141     size = struct.calcsize(fmt)
142     unpacker = functools.partial(struct.unpack, fmt)
143
144     def parser(b):
145         bytes_read, payload = _read_length_then_payload(b)
146         return bytes_read, tags.TaggedObject(tag = tag, instance = decoder(payload))
147
148     return parser
149
150 def _deserialize_tuple(b):
151     bytes_read, payload = _read_length_then_payload(b)
152
153     payload_stream = io.BytesIO(payload)
154
155     total_bytes_read = 0
156     instance = []
157
158     while total_bytes_read < len(payload):
159         partial_bytes_read, item = _deserialize_partial(payload_stream)
160         total_bytes_read += partial_bytes_read
161         instance.append(item)
162
163     return bytes_read, tags.TaggedObject(tag = tags.TUPLE, instance = tuple(instance))
164
165 def _deserialize_list(b):
166     list_tag_bytes = b.read(1)
167     assert len(list_tag_bytes) == 1
168     list_tag = list_tag_bytes[0]
169
170     bytes_read, payload = _read_length_then_payload(b)
171
172     payload_stream = io.BytesIO(payload)
173
174     total_bytes_read = 0
175     instance = []
176
177     while total_bytes_read < len(payload):
178         partial_bytes_read, item = _TAGS_TO_PARSERS[list_tag](payload_stream)
179         total_bytes_read += partial_bytes_read
180         instance.append(item)
181
182     # TODO Return tags = (tags.LIST, list_tag) to function like a generic type
183     return bytes_read, tags.TaggedObject(tag = tags.LIST, instance = instance)
184
185 def _deserialize_object(b):
186     raise Exception('Not implemented')
187
188 _TAGS_TO_PARSERS = {
189     tags.NULL: _make_tag_only_parser(tags.NULL, None),
190     tags.TRUE: _make_tag_only_parser(tags.TRUE, True),
191     tags.FALSE: _make_tag_only_parser(tags.FALSE, False),
192     tags.UINT8: _make_struct_deserializer(tags.UINT8, 'B'),
193     tags.UINT16: _make_struct_deserializer(tags.UINT16, 'H'),
194     tags.UINT32: _make_struct_deserializer(tags.UINT32, 'I'),
195     tags.UINT64: _make_struct_deserializer(tags.UINT64, 'Q'),
196     tags.INT8: _make_struct_deserializer(tags.INT8, 'b'),
197     tags.INT16: _make_struct_deserializer(tags.INT16, 'h'),
198     tags.INT32: _make_struct_deserializer(tags.INT32, 'i'),
199     tags.INT64: _make_struct_deserializer(tags.INT64, 'q'),
200     tags.BINARY: _make_string_deserializer(tags.BINARY, lambda b: b),
201     tags.UTF8: _make_string_deserializer(tags.UTF8, lambda b: b.decode('utf-8')),
202     tags.UTF16: _make_string_deserializer(tags.UTF16, lambda b: b.decode('utf-16')),
203     tags.UTF32: _make_string_deserializer(tags.UTF32, lambda b: b.decode('utf-32')),
204     tags.TUPLE: _deserialize_tuple,
205     tags.LIST: _deserialize_list,
206     tags.OBJECT: _deserialize_object,
207 }
208
209 def _deserialize_partial(b):
210     tag = b.read(1)
211     assert len(tag) == 1
212     bytes_read, to = _TAGS_TO_PARSERS[tag[0]](b)
213     return bytes_read + 1, to
214
215 def deserialize(b):
216     if isinstance(b, bytes):
217         b = io.BytesIO(b)
218
219     bytes_read, result = _deserialize_partial(b)
220
221     remainder = b.read()
222
223     if len(remainder) == 0:
224         return result
225
226     raise Exception('Unable to parse remainder: {}'.format(remainder))