X-Git-Url: https://code.kerkeslager.com/?p=fur;a=blobdiff_plain;f=util.py;h=678c516dd0c4b9d9166887d5b6b4560bebb59a97;hp=d73990d0db13fa1ca60cd81dbcfe5aa695e6a9e8;hb=edc05c8d2d465653c02c350592eff62c542a37ed;hpb=4ba4fcfbb2712a22a9f3211182c9ec6cee9dd0f8 diff --git a/util.py b/util.py index d73990d..678c516 100644 --- a/util.py +++ b/util.py @@ -1,18 +1,21 @@ import functools -def force_generator(generator_function): - @functools.wraps(generator_function) - def forced_generator(*args, **kwargs): - return list(generator_function(*args, **kwargs)) +def force_generator(to_type): + def decorator(generator_function): + @functools.wraps(generator_function) + def forced_generator(*args, **kwargs): + return to_type(generator_function(*args, **kwargs)) - return forced_generator + return forced_generator + + return decorator if __name__ == '__main__': import unittest class ForceGeneratorTests(unittest.TestCase): def test_forces_generator(self): - forced_range = force_generator(range) + forced_range = force_generator(list)(range) self.assertEqual( forced_range(10),