Simplified tests, added deserialization for unsigned integers
[sandbox] / serial / binary.py
1 import collections
2 import functools
3 import io
4 import struct
5
6 TAG_NULL = 0x00
7 TAG_TRUE = 0x01
8 TAG_FALSE = 0x02
9 TAG_UINT8 = 0x03
10 TAG_UINT16 = 0x04
11 TAG_UINT32 = 0x05
12 TAG_UINT64 = 0x06
13
14 TaggedObject = collections.namedtuple(
15     'TaggedObject',
16     [
17         'tag',
18         'instance',
19     ],
20 )
21
22 def _make_tag_only_serializer(tag, expected_value):
23     tag = bytes([tag])
24
25     def serializer(to):
26         assert to.instance == expected_value
27         return tag
28
29     return serializer
30
31 def _make_struct_serializer(fmt):
32     fmt = '!B' + fmt
33     packer = functools.partial(struct.pack, fmt)
34
35     def serializer(to):
36         return packer(to.tag, to.instance)
37
38     return serializer
39
40 _TAGS_TO_SERIALIZERS = {
41     TAG_NULL: _make_tag_only_serializer(TAG_NULL, None),
42     TAG_TRUE: _make_tag_only_serializer(TAG_TRUE, True),
43     TAG_FALSE: _make_tag_only_serializer(TAG_FALSE, False),
44     TAG_UINT8: _make_struct_serializer('B'),
45     TAG_UINT16: _make_struct_serializer('H'),
46     TAG_UINT32: _make_struct_serializer('I'),
47     TAG_UINT64: _make_struct_serializer('Q'),
48 }
49
50 def serialize(to):
51     return _TAGS_TO_SERIALIZERS[to.tag](to)
52
53 def _make_tag_only_parser(tag, value):
54     def parser(b):
55         return TaggedObject(tag = tag, instance = value)
56
57     return parser
58
59 def _make_struct_deserializer(tag, fmt):
60     fmt = '!' + fmt
61     size = struct.calcsize(fmt)
62     unpacker = functools.partial(struct.unpack, fmt)
63
64     def parser(b):
65         b = b.read(size)
66         assert len(b) == size
67         return TaggedObject(tag = tag, instance = unpacker(b)[0])
68
69     return parser
70
71 _TAGS_TO_PARSERS = {
72     TAG_NULL: _make_tag_only_parser(TAG_NULL, None),
73     TAG_TRUE: _make_tag_only_parser(TAG_TRUE, True),
74     TAG_FALSE: _make_tag_only_parser(TAG_FALSE, False),
75     TAG_UINT8: _make_struct_deserializer(TAG_UINT8, 'B'),
76     TAG_UINT16: _make_struct_deserializer(TAG_UINT16, 'H'),
77     TAG_UINT32: _make_struct_deserializer(TAG_UINT32, 'I'),
78     TAG_UINT64: _make_struct_deserializer(TAG_UINT64, 'Q'),
79 }
80
81 def deserialize(b):
82     if isinstance(b, bytes):
83         b = io.BytesIO(b)
84
85     tag = b.read(1)[0]
86
87     result = _TAGS_TO_PARSERS[tag](b)
88
89     remainder = b.read()
90
91     if len(remainder) == 0:
92         return result
93
94     raise Exception('Unable to parse remainder: {}'.format(remainder))