diff --git a/swift/common/request_helpers.py b/swift/common/request_helpers.py index 8b3940fe54..e01d1a0045 100644 --- a/swift/common/request_helpers.py +++ b/swift/common/request_helpers.py @@ -39,7 +39,7 @@ from swift.common.utils import split_path, validate_device_partition, \ close_if_possible, maybe_multipart_byteranges_to_document_iters, \ multipart_byteranges_to_document_iters, parse_content_type, \ parse_content_range, csv_append, list_from_csv, Spliterator, quote, \ - RESERVED, config_true_value, md5 + RESERVED, config_true_value, md5, CloseableChain from swift.common.wsgi import make_subrequest @@ -736,7 +736,7 @@ class SegmentedIterable(object): if self.peeked_chunk is not None: pc = self.peeked_chunk self.peeked_chunk = None - return itertools.chain([pc], self.app_iter) + return CloseableChain([pc], self.app_iter) else: return self.app_iter diff --git a/test/unit/__init__.py b/test/unit/__init__.py index 0ac52e9b1d..71356326f2 100644 --- a/test/unit/__init__.py +++ b/test/unit/__init__.py @@ -423,6 +423,25 @@ class FakeMemcache(object): return True +class FakeIterable(object): + def __init__(self, values): + self.next_call_count = 0 + self.close_call_count = 0 + self.values = iter(values) + + def __iter__(self): + return self + + def __next__(self): + self.next_call_count += 1 + return next(self.values) + + next = __next__ # py2 + + def close(self): + self.close_call_count += 1 + + def readuntil2crlfs(fd): rv = b'' lc = b'' diff --git a/test/unit/common/middleware/test_slo.py b/test/unit/common/middleware/test_slo.py index d02c755854..f4c45423f6 100644 --- a/test/unit/common/middleware/test_slo.py +++ b/test/unit/common/middleware/test_slo.py @@ -3257,7 +3257,7 @@ class TestSloGetManifest(SloTestCase): self.assertEqual(headers['X-Object-Meta-Fish'], 'Bass') self.assertEqual(body, b'') - def test_generator_closure(self): + def _do_test_generator_closure(self, leaks): # Test that the SLO WSGI iterable closes its internal .app_iter when # it receives a close() message. # @@ -3270,8 +3270,6 @@ class TestSloGetManifest(SloTestCase): # well; calling .close() on the generator is sufficient, but not # necessary. However, having this test is better than nothing for # preventing regressions. - leaks = [0] - class LeakTracker(object): def __init__(self, inner_iter): leaks[0] += 1 @@ -3313,13 +3311,31 @@ class TestSloGetManifest(SloTestCase): LeakTrackingSegmentedIterable): app_resp = self.slo(req.environ, start_response) self.assertEqual(status[0], '200 OK') # sanity check + return app_resp + + def test_generator_closure(self): + leaks = [0] + app_resp = self._do_test_generator_closure(leaks) body_iter = iter(app_resp) chunk = next(body_iter) self.assertEqual(chunk, b'aaaaa') # sanity check - app_resp.close() self.assertEqual(0, leaks[0]) + def test_generator_closure_iter_app_resp(self): + # verify that the result of iter(app_resp) has a close method that + # closes app_resp + leaks = [0] + app_resp = self._do_test_generator_closure(leaks) + body_iter = iter(app_resp) + chunk = next(body_iter) + self.assertEqual(chunk, b'aaaaa') # sanity check + close_method = getattr(body_iter, 'close', None) + self.assertIsNotNone(close_method) + self.assertTrue(callable(close_method)) + close_method() + self.assertEqual(0, leaks[0]) + def test_head_manifest_is_efficient(self): req = Request.blank( '/v1/AUTH_test/gettest/manifest-abcd', diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index dfdefd1d84..c59172098a 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -19,7 +19,7 @@ from __future__ import print_function import hashlib from test.unit import temptree, debug_logger, make_timestamp_iter, \ - with_tempdir, mock_timestamp_now + with_tempdir, mock_timestamp_now, FakeIterable import ctypes import contextlib @@ -8771,3 +8771,68 @@ class TestWatchdog(unittest.TestCase): self.assertEqual(exc.seconds, 5.0) self.assertEqual(None, w._next_expiration) w._evt.wait.assert_called_once_with(None) + + +class TestReiterate(unittest.TestCase): + def test_reiterate_consumes_first(self): + test_iter = FakeIterable([1, 2, 3]) + reiterated = utils.reiterate(test_iter) + self.assertEqual(1, test_iter.next_call_count) + self.assertEqual(1, next(reiterated)) + self.assertEqual(1, test_iter.next_call_count) + self.assertEqual(2, next(reiterated)) + self.assertEqual(2, test_iter.next_call_count) + self.assertEqual(3, next(reiterated)) + self.assertEqual(3, test_iter.next_call_count) + + def test_reiterate_closes(self): + test_iter = FakeIterable([1, 2, 3]) + self.assertEqual(0, test_iter.close_call_count) + reiterated = utils.reiterate(test_iter) + self.assertEqual(0, test_iter.close_call_count) + self.assertTrue(hasattr(reiterated, 'close')) + self.assertTrue(callable(reiterated.close)) + reiterated.close() + self.assertEqual(1, test_iter.close_call_count) + + # empty iter gets closed when reiterated + test_iter = FakeIterable([]) + self.assertEqual(0, test_iter.close_call_count) + reiterated = utils.reiterate(test_iter) + self.assertFalse(hasattr(reiterated, 'close')) + self.assertEqual(1, test_iter.close_call_count) + + def test_reiterate_list_or_tuple(self): + test_list = [1, 2] + reiterated = utils.reiterate(test_list) + self.assertIs(test_list, reiterated) + test_tuple = (1, 2) + reiterated = utils.reiterate(test_tuple) + self.assertIs(test_tuple, reiterated) + + +class TestCloseableChain(unittest.TestCase): + def test_closeable_chain_iterates(self): + test_iter1 = FakeIterable([1]) + test_iter2 = FakeIterable([2, 3]) + chain = utils.CloseableChain(test_iter1, test_iter2) + self.assertEqual([1, 2, 3], [x for x in chain]) + + chain = utils.CloseableChain([1, 2], [3]) + self.assertEqual([1, 2, 3], [x for x in chain]) + + def test_closeable_chain_closes(self): + test_iter1 = FakeIterable([1]) + test_iter2 = FakeIterable([2, 3]) + chain = utils.CloseableChain(test_iter1, test_iter2) + self.assertEqual(0, test_iter1.close_call_count) + self.assertEqual(0, test_iter2.close_call_count) + chain.close() + self.assertEqual(1, test_iter1.close_call_count) + self.assertEqual(1, test_iter2.close_call_count) + + # check that close is safe to call even when component iters have no + # close + chain = utils.CloseableChain([1, 2], [3]) + chain.close() + self.assertEqual([1, 2, 3], [x for x in chain])