import functools
-def force_generator(generator_function):
- @functools.wraps(generator_function)
- def forced_generator(*args, **kwargs):
- return list(generator_function(*args, **kwargs))
+def force_generator(to_type):
+ def decorator(generator_function):
+ @functools.wraps(generator_function)
+ def forced_generator(*args, **kwargs):
+ return to_type(generator_function(*args, **kwargs))
- return forced_generator
+ return forced_generator
+
+ return decorator
if __name__ == '__main__':
import unittest
class ForceGeneratorTests(unittest.TestCase):
def test_forces_generator(self):
- forced_range = force_generator(range)
+ forced_range = force_generator(list)(range)
self.assertEqual(
forced_range(10),