d270f938ad8eb5e00c9649866179c6c6eaa2cc55
[ton] / don / binary.py
1 import collections
2 import struct
3
4 from don import tags, _shared
5
def _binary_serialize_tag_only_type(o):
    """Serialize a value whose tag byte alone carries all its information.

    Used for VOID/TRUE/FALSE in ``_BINARY_SERIALIZERS``; the argument is
    ignored and the payload is always empty.
    """
    return bytes()
8
def _pack_format_string_to_binary_serializer(pfs):
    """Return a serializer that packs one value with the struct format *pfs*.

    The format string is applied on every call (not pre-compiled), so an
    invalid format only raises ``struct.error`` when the serializer is used.
    """
    return lambda value: struct.pack(pfs, value)
13
def _encoder_to_binary_serializer(e):
    """Wrap encoder *e* into a length-prefixed serializer.

    The serializer's output is the length of ``e(s)`` as a big-endian
    unsigned 32-bit integer, followed by the encoded bytes themselves.
    """
    def serialize_with_length_prefix(s):
        payload = e(s)
        prefix = struct.pack('!I', len(payload))
        return b''.join((prefix, payload))

    return serialize_with_length_prefix
19
def _binary_serialize_list(items):
    """Serialize a list as a header plus untagged item payloads.

    Header layout is ``'!BII'``: item tag, payload byte length, item count.
    The item tag comes from the first element (VOID for an empty list), and
    every element is serialized with that tag's serializer.
    """
    # TODO Enforce that items are all the same type
    # (currently later items are serialized with the first item's serializer).
    tagged = [tags._tag(item) for item in items]
    item_tag = tagged[0].tag if tagged else tags.VOID

    serialize_item = _BINARY_SERIALIZERS[item_tag]
    payload = b''.join([serialize_item(t.value) for t in tagged])

    header = struct.pack('!BII', item_tag, len(payload), len(tagged))
    return header + payload
35
def _binary_serialize_dict(d):
    """Serialize a mapping with string keys.

    Layout: ``'!II'`` header (payload byte length, entry count), then for
    each entry a length-prefixed UTF-8 key (no tag byte) followed by the
    fully tagged ``serialize(value)``.

    Raises:
        TypeError: if any key is not a ``str``.
    """
    key_serializer = _BINARY_SERIALIZERS[tags.UTF8]

    chunks = []
    item_length = 0
    for key, value in d.items():
        # An explicit check (not assert) so validation survives ``python -O``.
        if not isinstance(key, str):
            raise TypeError(
                'Dictionary keys must be str, got {}'.format(type(key).__name__))
        chunks.append(key_serializer(key))
        chunks.append(serialize(value))
        item_length += 1

    # Join once instead of repeated ``bytes +=``, which is quadratic.
    serialized = b''.join(chunks)
    return struct.pack('!II', len(serialized), item_length) + serialized
49
# Maps each tag to a function producing that tag's payload bytes (the tag
# byte itself is written by ``serialize``). Integer and float formats are
# big-endian ('!'), matching ``_TAGS_TO_PARSERS``.
_BINARY_SERIALIZERS = {
    tags.VOID: _binary_serialize_tag_only_type,
    tags.TRUE: _binary_serialize_tag_only_type,
    tags.FALSE: _binary_serialize_tag_only_type,
    tags.INT8: _pack_format_string_to_binary_serializer('!b'),
    tags.INT16: _pack_format_string_to_binary_serializer('!h'),
    tags.INT32: _pack_format_string_to_binary_serializer('!i'),
    # INT64 has a parser ('!q'); without this entry INT64 values could be
    # deserialized but not serialized.
    tags.INT64: _pack_format_string_to_binary_serializer('!q'),
    tags.FLOAT: _pack_format_string_to_binary_serializer('!f'),
    tags.DOUBLE: _pack_format_string_to_binary_serializer('!d'),
    tags.BINARY: _encoder_to_binary_serializer(lambda b: b),
    tags.UTF8: _encoder_to_binary_serializer(lambda s: s.encode('utf-8')),
    tags.UTF16: _encoder_to_binary_serializer(lambda s: s.encode('utf-16')),
    tags.UTF32: _encoder_to_binary_serializer(lambda s: s.encode('utf-32')),
    tags.LIST: _binary_serialize_list,
    tags.DICTIONARY: _binary_serialize_dict,
}
66
def serialize(o):
    """Serialize *o* to bytes: one tag byte followed by that tag's payload.

    *o* is first tagged with ``tags._tag``; the resulting tag selects the
    payload serializer from ``_BINARY_SERIALIZERS``.
    """
    tagged = tags._tag(o)
    header = struct.pack('!B', tagged.tag)
    payload = _BINARY_SERIALIZERS[tagged.tag](tagged.value)
    return header + payload
70
# Big-endian signed integer struct formats, keyed by their width in bytes.
# Widths are derived from the formats so the two can never disagree.
_BYTE_SIZES_TO_UNPACK_FORMATS = {
    struct.calcsize(fmt): fmt for fmt in ('!b', '!h', '!i', '!q')
}
77
def make_integer_parser(size_in_bytes):
    """Build a parser for a big-endian signed integer *size_in_bytes* wide.

    The width must be a key of ``_BYTE_SIZES_TO_UNPACK_FORMATS``; any other
    width raises ``KeyError`` here, when the parser is built. The returned
    parser reads that many bytes from the front of its input and returns a
    successful ParseResult with the rest of the input as ``remaining``.
    """
    fmt = _BYTE_SIZES_TO_UNPACK_FORMATS[size_in_bytes]

    def parse_integer(source):
        head, tail = source[:size_in_bytes], source[size_in_bytes:]
        (value,) = struct.unpack(fmt, head)
        return _shared.ParseResult(success=True, value=value, remaining=tail)

    return parse_integer
88
def binary64_parser(source):
    """Parse an 8-byte big-endian IEEE 754 double from the front of *source*."""
    payload, rest = source[:8], source[8:]
    (number,) = struct.unpack('!d', payload)
    return _shared.ParseResult(success=True, value=number, remaining=rest)
95
def make_string_parser(decoder):
    """Build a parser for a length-prefixed byte string.

    Wire format: a big-endian unsigned 32-bit byte count, then that many
    bytes. The payload bytes are passed to *decoder* (identity for BINARY,
    ``bytes.decode`` for the text tags in ``_TAGS_TO_PARSERS``).

    Raises (from the returned parser):
        struct.error: if fewer than 4 bytes are available for the length.
        ValueError: if fewer bytes follow than the declared length.
    """
    def string_parser(source):
        length = struct.unpack('!I', source[:4])[0]
        payload = source[4:4 + length]

        # Slicing silently truncates; without this check a cut-off input
        # would decode to a shorter string and be reported as success.
        if len(payload) < length:
            raise ValueError(
                'Truncated string: expected {} bytes, found {}'.format(
                    length, len(payload)))

        return _shared.ParseResult(
            success=True,
            value=decoder(payload),
            remaining=source[4 + length:],
        )

    return string_parser
107
def _list_parser(source):
    """Parse a LIST payload into a Python list.

    Wire format: item tag (1 byte), then payload byte length and item count
    (big-endian unsigned 32-bit each), then the concatenated untagged item
    payloads, each parsed with the item tag's parser.

    Returns:
        ParseResult whose value is a list of the parsed items and whose
        remaining is the input after the list payload.

    Raises:
        ValueError: if the payload is shorter than its declared byte length,
            an item fails to parse, or the payload does not hold exactly the
            declared number of items.
    """
    tag = source[0]
    parser = _TAGS_TO_PARSERS[tag]

    byte_length, items_length = struct.unpack('!II', source[1:9])
    body = source[9:9 + byte_length]
    remaining = source[9 + byte_length:]

    if len(body) < byte_length:
        raise ValueError(
            'Truncated list: expected {} bytes, found {}'.format(
                byte_length, len(body)))

    items = []
    # Driven by the declared count rather than the bytes left: tag-only item
    # types (VOID/TRUE/FALSE) serialize to zero bytes, so a byte-driven loop
    # would never produce them.
    # NOTE(review): items_length comes from the input; a huge count of
    # zero-width items yields a correspondingly huge list.
    for _ in range(items_length):
        parse_result = parser(body)
        if not parse_result.success:
            # Previously this case looped forever without consuming input.
            raise ValueError('Failed to parse list item {}'.format(len(items)))
        items.append(parse_result.value)
        body = parse_result.remaining

    if len(body) > 0:
        raise ValueError(
            'List declared {} items but {} bytes remain after them'.format(
                items_length, len(body)))

    # Materialized (not a lazy generator) so results compare by value and
    # malformed input is reported here rather than when first iterated.
    return _shared.ParseResult(
        success=True,
        value=items,
        remaining=remaining,
    )
137
def dictionary_parser(source):
    """Parse a DICTIONARY payload into an ``OrderedDict``.

    Wire format: payload byte length and entry count (big-endian unsigned
    32-bit each), then the entries. Each entry is a length-prefixed UTF-8 key
    (no tag byte) followed by a tagged value.

    Returns:
        ParseResult whose value is an OrderedDict in wire order and whose
        remaining is the input after the dictionary payload.

    Raises:
        ValueError: if the payload is shorter than its declared byte length
            or holds a different number of entries than declared.
    """
    key_parser = _TAGS_TO_PARSERS[tags.UTF8]

    byte_length, item_length = struct.unpack('!II', source[:8])
    body = source[8:8 + byte_length]
    remaining = source[8 + byte_length:]

    # Explicit checks (not assert) so validation of input survives ``-O``.
    if len(body) < byte_length:
        raise ValueError(
            'Truncated dictionary: expected {} bytes, found {}'.format(
                byte_length, len(body)))

    entries = collections.OrderedDict()
    count = 0

    # Every key has a 4-byte length prefix, so each entry consumes input and
    # this byte-driven loop terminates.
    while len(body) > 0:
        key_result = key_parser(body)
        value_result = _object_parser(key_result.remaining)
        entries[key_result.value] = value_result.value
        body = value_result.remaining
        count += 1

    if count != item_length:
        raise ValueError(
            'Dictionary declared {} entries but contained {}'.format(
                item_length, count))

    return _shared.ParseResult(
        success=True,
        value=entries,
        remaining=remaining,
    )
166
167
def _binary32_parser(source):
    """Parse a 4-byte big-endian IEEE 754 single-precision float."""
    return _shared.ParseResult(
        success=True,
        value=struct.unpack('!f', source[:4])[0],
        remaining=source[4:],
    )


# Maps each tag byte to the parser for the payload that follows it.
# Formats mirror ``_BINARY_SERIALIZERS`` so every serializable tag can be
# read back.
_TAGS_TO_PARSERS = {
    tags.VOID: lambda r: _shared.ParseResult(True, None, r),
    tags.TRUE: lambda r: _shared.ParseResult(True, True, r),
    tags.FALSE: lambda r: _shared.ParseResult(True, False, r),
    tags.INT8: make_integer_parser(1),
    tags.INT16: make_integer_parser(2),
    tags.INT32: make_integer_parser(4),
    tags.INT64: make_integer_parser(8),
    # FLOAT is serialized with '!f'; without this entry it could not be read.
    tags.FLOAT: _binary32_parser,
    tags.DOUBLE: binary64_parser,
    tags.BINARY: make_string_parser(lambda b : b),
    tags.UTF8: make_string_parser(lambda b : b.decode('utf-8')),
    tags.UTF16: make_string_parser(lambda b : b.decode('utf-16')),
    tags.UTF32: make_string_parser(lambda b : b.decode('utf-32')),
    tags.LIST: _list_parser,
    tags.DICTIONARY: dictionary_parser,
}
184
def _object_parser(source):
    """Parse one tagged value: a tag byte, then that tag's payload.

    Raises IndexError on empty input and KeyError on an unknown tag.
    """
    tag, payload = source[0], source[1:]
    return _TAGS_TO_PARSERS[tag](payload)
187
def _parse(parser, source):
    """Run *parser* over all of *source* and return the parsed value.

    Raises:
        ValueError: if the parser reports failure, or if it leaves unparsed
            trailing bytes. ValueError subclasses Exception, so callers that
            caught the generic Exception raised here before still work.
    """
    result = parser(source)

    # Reported separately: a failed parse is not a trailing-bytes problem.
    if not result.success:
        raise ValueError('Parser reported failure')

    if result.remaining != b'':
        raise ValueError('Unparsed trailing bytes: {}'.format(result.remaining))

    return result.value
195
def deserialize(b):
    """Decode a complete binary-encoded value from the bytes *b*.

    The input must hold exactly one tagged value with nothing after it;
    ``_parse`` raises otherwise.
    """
    top_level_parser = _object_parser
    return _parse(top_level_parser, b)