960acef8bd02243ed01beabaa2c551eec5b9a1dd
[sandbox] / serial / serial / binary.py
1 import functools
2 import io
3 import struct
4
5 from . import tags
6
7 def _make_tag_only_serializer(tag, expected_value):
8     tag = bytes([tag])
9
10     def serializer(to):
11         assert to.instance == expected_value
12         return tag
13
14     return serializer
15
16 def _make_struct_serializer(fmt):
17     fmt = '!B' + fmt
18     packer = functools.partial(struct.pack, fmt)
19
20     def serializer(to):
21         return packer(to.tag, to.instance)
22
23     return serializer
24
25 def _make_string_serializer(encoder):
26     packer = functools.partial(struct.pack, '!BI')
27
28     def serializer(to):
29         encoded = encoder(to.instance)
30         return packer(to.tag, len(encoded)) + encoded
31
32     return serializer
33
34 def _serialize_tuple(to):
35     assert isinstance(to.instance, tuple)
36
37     payload = b''.join(serialize(item) for item in to.instance)
38
39     fmt = '!BI'
40
41     return struct.pack('!BI', tags.TUPLE, len(payload)) + payload
42
43
44 _TAGS_TO_SERIALIZERS = {
45     tags.NULL: _make_tag_only_serializer(tags.NULL, None),
46     tags.TRUE: _make_tag_only_serializer(tags.TRUE, True),
47     tags.FALSE: _make_tag_only_serializer(tags.FALSE, False),
48     tags.UINT8: _make_struct_serializer('B'),
49     tags.UINT16: _make_struct_serializer('H'),
50     tags.UINT32: _make_struct_serializer('I'),
51     tags.UINT64: _make_struct_serializer('Q'),
52     tags.INT8: _make_struct_serializer('b'),
53     tags.INT16: _make_struct_serializer('h'),
54     tags.INT32: _make_struct_serializer('i'),
55     tags.INT64: _make_struct_serializer('q'),
56     tags.BINARY: _make_string_serializer(lambda s: s),
57     tags.UTF8: _make_string_serializer(lambda s: s.encode('utf-8')),
58     tags.UTF16: _make_string_serializer(lambda s: s.encode('utf-16')),
59     tags.UTF32: _make_string_serializer(lambda s: s.encode('utf-32')),
60     tags.TUPLE: _serialize_tuple,
61 }
62
63 def serialize(to):
64     return _TAGS_TO_SERIALIZERS[to.tag](to)
65
66 def _make_tag_only_parser(tag, value):
67     def parser(b):
68         return 0, tags.TaggedObject(tag = tag, instance = value)
69
70     return parser
71
72 def _make_struct_deserializer(tag, fmt):
73     fmt = '!' + fmt
74     size = struct.calcsize(fmt)
75     unpacker = functools.partial(struct.unpack, fmt)
76
77     def parser(b):
78         b = b.read(size)
79         assert len(b) == size
80         return size, tags.TaggedObject(tag = tag, instance = unpacker(b)[0])
81
82     return parser
83
84 _LENGTH_FMT = '!I'
85 _LENGTH_FMT_SIZE = struct.calcsize(_LENGTH_FMT)
86
87 def _read_length_then_payload(b):
88     length_b = b.read(_LENGTH_FMT_SIZE)
89     assert len(length_b) == _LENGTH_FMT_SIZE
90     length = struct.unpack(_LENGTH_FMT, length_b)[0]
91
92     payload = b.read(length)
93     assert len(payload) == length
94     return _LENGTH_FMT_SIZE + length, payload
95
96 def _make_string_deserializer(tag, decoder):
97     fmt = '!I'
98     size = struct.calcsize(fmt)
99     unpacker = functools.partial(struct.unpack, fmt)
100
101     def parser(b):
102         bytes_read, payload = _read_length_then_payload(b)
103         return bytes_read, tags.TaggedObject(tag = tag, instance = decoder(payload))
104
105     return parser
106
107 def _deserialize_tuple(b):
108     bytes_read, payload = _read_length_then_payload(b)
109
110     payload_stream = io.BytesIO(payload)
111
112     total_bytes_read = 0
113     instance = []
114
115     while total_bytes_read < len(payload):
116         partial_bytes_read, item = _deserialize_partial(payload_stream)
117         total_bytes_read += partial_bytes_read
118         instance.append(item)
119
120     return bytes_read, tags.TaggedObject(tag = tags.TUPLE, instance = tuple(instance))
121
122 _TAGS_TO_PARSERS = {
123     tags.NULL: _make_tag_only_parser(tags.NULL, None),
124     tags.TRUE: _make_tag_only_parser(tags.TRUE, True),
125     tags.FALSE: _make_tag_only_parser(tags.FALSE, False),
126     tags.UINT8: _make_struct_deserializer(tags.UINT8, 'B'),
127     tags.UINT16: _make_struct_deserializer(tags.UINT16, 'H'),
128     tags.UINT32: _make_struct_deserializer(tags.UINT32, 'I'),
129     tags.UINT64: _make_struct_deserializer(tags.UINT64, 'Q'),
130     tags.INT8: _make_struct_deserializer(tags.INT8, 'b'),
131     tags.INT16: _make_struct_deserializer(tags.INT16, 'h'),
132     tags.INT32: _make_struct_deserializer(tags.INT32, 'i'),
133     tags.INT64: _make_struct_deserializer(tags.INT64, 'q'),
134     tags.BINARY: _make_string_deserializer(tags.BINARY, lambda b: b),
135     tags.UTF8: _make_string_deserializer(tags.UTF8, lambda b: b.decode('utf-8')),
136     tags.UTF16: _make_string_deserializer(tags.UTF16, lambda b: b.decode('utf-16')),
137     tags.UTF32: _make_string_deserializer(tags.UTF32, lambda b: b.decode('utf-32')),
138     tags.TUPLE: _deserialize_tuple,
139 }
140
141 def _deserialize_partial(b):
142     tag = b.read(1)
143     assert len(tag) == 1
144     bytes_read, to = _TAGS_TO_PARSERS[tag[0]](b)
145     return bytes_read + 1, to
146
147 def deserialize(b):
148     if isinstance(b, bytes):
149         b = io.BytesIO(b)
150
151     bytes_read, result = _deserialize_partial(b)
152
153     remainder = b.read()
154
155     if len(remainder) == 0:
156         return result
157
158     raise Exception('Unable to parse remainder: {}'.format(remainder))