Let WSGI know the length

... so that eventlet.wsgi will include Connection: close headers on
Expect: 100-continue error responses from s3api and make aws clients
less unhappy.

eventlet.wsgi likes to check for hasattr(resp, '__len__') when setting
connection close on expect-100 errors, but uses 'content-length' in
headers when deciding on chunked-transfer.

When we know the length we can support either interface.  Also we can
improve s3api to return error responses with the content-length known.

Change-Id: Ic504841714bd607cb9733b2de5126092a79c1094
This commit is contained in:
Clay Gerrard 2025-03-18 18:12:31 -05:00
parent fd9ceecc50
commit 1ca073ce1d
5 changed files with 98 additions and 25 deletions

View File

@ -22,7 +22,7 @@ class BadResponseLength(Exception):
pass
def enforce_byte_count(inner_iter, nbytes):
class ByteEnforcer(object):
"""
Enforces that inner_iter yields exactly <nbytes> bytes before
exhaustion.
@ -31,25 +31,39 @@ def enforce_byte_count(inner_iter, nbytes):
:param inner_iter: iterable of bytestrings
:param nbytes: number of bytes expected
"""
try:
bytes_left = nbytes
for chunk in inner_iter:
if bytes_left >= len(chunk):
yield chunk
bytes_left -= len(chunk)
else:
yield chunk[:bytes_left]
raise BadResponseLength(
"Too many bytes; truncating after %d bytes "
"with at least %d surplus bytes remaining" % (
nbytes, len(chunk) - bytes_left))
if bytes_left:
raise BadResponseLength('Expected another %d bytes' % (
bytes_left,))
finally:
close_if_possible(inner_iter)
N.B. since we require the nbytes param and require the inner_iter to yield
exactly that many bytes we can support the __len__ interface for anyone
who happens to expect non chunked resp iterables to support that
(e.g. eventlet's wsgi.server).
"""
def __init__(self, inner_iter, nbytes):
self.inner_iter = inner_iter
self.nbytes = nbytes
def __len__(self):
return self.nbytes
def __iter__(self):
    """
    Yield chunks from the wrapped iterable, enforcing the byte count.

    Chunks pass through untouched while they fit in the remaining budget.
    A chunk that would exceed the budget is truncated to what is left and
    then BadResponseLength is raised; BadResponseLength is also raised if
    the wrapped iterable is exhausted before enough bytes were seen.

    The wrapped iterable is handed to close_if_possible when iteration
    ends, however it ends.

    :raises BadResponseLength: on too many or too few bytes
    """
    try:
        remaining = self.nbytes
        for piece in self.inner_iter:
            size = len(piece)
            if size > remaining:
                # send what we promised, then blow up
                yield piece[:remaining]
                raise BadResponseLength(
                    "Too many bytes; truncating after %d bytes "
                    "with at least %d surplus bytes remaining" % (
                        self.nbytes, size - remaining))
            yield piece
            remaining -= size
        if remaining:
            raise BadResponseLength('Expected another %d bytes' % (
                remaining,))
    finally:
        close_if_possible(self.inner_iter)
class CatchErrorsContext(WSGIContext):
@ -99,7 +113,7 @@ class CatchErrorsContext(WSGIContext):
# and raise an exception to stop any more bytes from being
# generated and also to kill the TCP connection.
if env['REQUEST_METHOD'] == 'HEAD':
resp = enforce_byte_count(resp, 0)
resp = ByteEnforcer(resp, 0)
elif self._response_headers:
content_lengths = [val for header, val in self._response_headers
@ -110,7 +124,7 @@ class CatchErrorsContext(WSGIContext):
except ValueError:
pass
else:
resp = enforce_byte_count(resp, content_length)
resp = ByteEnforcer(resp, content_length)
# make sure the response has the trans_id
if self._response_headers is None:

View File

@ -89,7 +89,7 @@ import os
import time
from swift.common.constraints import valid_api_version
from swift.common.middleware.catch_errors import enforce_byte_count
from swift.common.middleware.catch_errors import ByteEnforcer
from swift.common.request_helpers import get_log_info
from swift.common.swob import Request
from swift.common.utils import (get_logger, get_remote_client,
@ -430,7 +430,7 @@ class ProxyLoggingMiddleware(object):
if method == 'HEAD':
content_length = 0
if content_length is not None:
iterator = enforce_byte_count(iterator, content_length)
iterator = ByteEnforcer(iterator, content_length)
wire_status_int = int(start_response_args[0][0].split(' ', 1)[0])
resp_headers = dict(start_response_args[0][1])

View File

@ -243,6 +243,9 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException):
swob.HTTPException.__init__(
self, status=kwargs.pop('status', self._status),
# we use an app_iter, so that we can add our trans_id to the resp
# xml *after* we've been called - technically any non-None app_iter
# would do, we override swob.Response._response_iter anyway.
app_iter=self._body_iter(),
content_type='application/xml', *args,
**kwargs)
@ -265,6 +268,9 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException):
error_elem = Element('Error')
SubElement(error_elem, 'Code').text = self._code
SubElement(error_elem, 'Message').text = self._msg
# N.B. swob.Response objects don't normally have an environ attribute
# when they're created, but swob always gives this to us when we're
# __call__'d
if 'swift.trans_id' in self.environ:
request_id = self.environ['swift.trans_id']
SubElement(error_elem, 'RequestId').text = request_id
@ -274,6 +280,13 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException):
yield tostring(error_elem, use_s3ns=False,
xml_declaration=self.xml_declaration)
def _response_iter(self, app_iter, body):
    """
    Hand the parent implementation a list rather than a generator.

    :param app_iter: the iterable producing the error body
    :param body: passed through unchanged to the parent implementation
    :returns: whatever the parent ``_response_iter`` returns
    """
    # N.B. a generator is not what we want to return here; a list of
    # strings is much better for eventlet.wsgi.server connection handling
    # and request pipelining, and ErrorResponses are small.  FWIW by now
    # we have self.environ, app_iter=self._body_iter() and body is None
    chunks = list(app_iter)
    return super()._response_iter(chunks, body)
def _dict_to_etree(self, parent, d):
for key, value in d.items():
tag = re.sub(r'\W', '', snake_to_camel(key))

View File

@ -15,8 +15,9 @@
import unittest
from swift.common.swob import Response
from swift.common.swob import Response, Request
from swift.common.utils import HeaderKeyDict
from swift.common.middleware.catch_errors import CatchErrorMiddleware
from swift.common.middleware.s3api.s3response import S3Response, ErrorResponse
from swift.common.middleware.s3api.utils import sysmeta_prefix
@ -124,6 +125,27 @@ class TestErrorResponse(unittest.TestCase):
b"</Error>",
resp.body)
def test_error_response_trans_id(self):
    """
    An ErrorResponse served through catch_errors carries the trans id.

    The generated transaction id must show up in the request environ, in
    the X-Trans-Id header and in the RequestId element of the XML body,
    and the response must advertise a matching Content-Length.
    """
    # ``import unittest`` alone does not load the ``unittest.mock``
    # submodule; import it explicitly so this test does not depend on
    # some other module having imported it first.
    from unittest import mock

    req = Request.blank('/bucket/object')
    err = DummyErrorResponse(msg='my-msg', reason='my reason')
    app = CatchErrorMiddleware(err, {})
    with mock.patch(
            'swift.common.middleware.catch_errors.generate_trans_id',
            return_value='fake-trans-id'):
        resp = req.get_response(app)
    self.assertIn('swift.trans_id', req.environ)
    self.assertEqual(418, resp.status_int)
    self.assertIn('X-Trans-Id', resp.headers)
    self.assertEqual(
        b"<?xml version='1.0' encoding='UTF-8'?>\n"
        b"<Error>"
        b"<Code>DummyErrorResponse</Code>"
        b"<Message>my-msg</Message>"
        b"<RequestId>fake-trans-id</RequestId>"
        b"</Error>",
        resp.body)
    # 146 is the length in bytes of the expected body above
    self.assertEqual(146, int(resp.headers['Content-Length']))
if __name__ == '__main__':
unittest.main()

View File

@ -15,7 +15,7 @@
import unittest
from swift.common.swob import Request
from swift.common.swob import Request, HTTPOk
from swift.common.middleware import catch_errors
from swift.common.utils import get_logger
@ -137,6 +137,30 @@ class TestCatchErrors(unittest.TestCase):
resp = app(req.environ, self.start_response)
self.assertEqual(list(resp), [b'An error occurred'])
def test_has_len(self):
    """
    The app iterable keeps ``__len__`` when wrapped by catch_errors.

    Checks a bare swob response first, then the same response behind
    CatchErrorMiddleware; both must report status and Content-Length to
    start_response and return an iterable that has ``__len__``.
    """
    seen = []

    def capture_start_resp(status, headers, exc_info=None):
        # record the status with the last Content-Length seen (or None)
        lengths = [int(v) for k, v in headers if k == 'Content-Length']
        seen.append((status, lengths[-1] if lengths else None))

    # sanity
    bare_app = HTTPOk(body='test-body')
    bare_req = Request.blank('/')
    body_iter = bare_app(bare_req.environ, capture_start_resp)
    self.assertEqual(seen, [('200 OK', 9)])
    self.assertTrue(hasattr(body_iter, '__len__'))

    # wrapped should work the same way
    del seen[:]
    inner_resp = HTTPOk(body='test-body')
    wrapped_app = catch_errors.CatchErrorMiddleware(inner_resp, {})
    wrapped_req = Request.blank('/')
    body_iter = wrapped_app(wrapped_req.environ, capture_start_resp)
    self.assertEqual(seen, [('200 OK', 9)])
    self.assertTrue(hasattr(body_iter, '__len__'))
def test_HEAD_with_content_length(self):
def cannot_count_app(env, sr):
sr("200 OK", [("Content-Length", "10")])