Let WSGI know the length
... so that eventlet.wsgi will include Connection: close headers on Expect: 100-continue error responses from s3api and make aws clients less unhappy. eventlet.wsgi likes to check for hasattr('__len__', resp) when setting connection close on expect-100 errors, but uses 'content-length' in headers when deciding on chunked-transfer. When we know the length we can support either interface. Also we can imporove s3api to return error responses with the content-length known. Change-Id: Ic504841714bd607cb9733b2de5126092a79c1094
This commit is contained in:
parent
fd9ceecc50
commit
1ca073ce1d
@ -22,7 +22,7 @@ class BadResponseLength(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def enforce_byte_count(inner_iter, nbytes):
|
||||
class ByteEnforcer(object):
|
||||
"""
|
||||
Enforces that inner_iter yields exactly <nbytes> bytes before
|
||||
exhaustion.
|
||||
@ -31,25 +31,39 @@ def enforce_byte_count(inner_iter, nbytes):
|
||||
|
||||
:param inner_iter: iterable of bytestrings
|
||||
:param nbytes: number of bytes expected
|
||||
"""
|
||||
try:
|
||||
bytes_left = nbytes
|
||||
for chunk in inner_iter:
|
||||
if bytes_left >= len(chunk):
|
||||
yield chunk
|
||||
bytes_left -= len(chunk)
|
||||
else:
|
||||
yield chunk[:bytes_left]
|
||||
raise BadResponseLength(
|
||||
"Too many bytes; truncating after %d bytes "
|
||||
"with at least %d surplus bytes remaining" % (
|
||||
nbytes, len(chunk) - bytes_left))
|
||||
|
||||
if bytes_left:
|
||||
raise BadResponseLength('Expected another %d bytes' % (
|
||||
bytes_left,))
|
||||
finally:
|
||||
close_if_possible(inner_iter)
|
||||
N.B. since we require the nbytes param and require the inner_iter to yield
|
||||
exactly that many bytes we can support the __len__ interface for anyone
|
||||
happens to expect non chunked resp iterables to support that
|
||||
(e.g. eventlet's wsgi.server).
|
||||
"""
|
||||
|
||||
def __init__(self, inner_iter, nbytes):
|
||||
self.inner_iter = inner_iter
|
||||
self.nbytes = nbytes
|
||||
|
||||
def __len__(self):
|
||||
return self.nbytes
|
||||
|
||||
def __iter__(self):
|
||||
try:
|
||||
bytes_left = self.nbytes
|
||||
for chunk in self.inner_iter:
|
||||
if bytes_left >= len(chunk):
|
||||
yield chunk
|
||||
bytes_left -= len(chunk)
|
||||
else:
|
||||
yield chunk[:bytes_left]
|
||||
raise BadResponseLength(
|
||||
"Too many bytes; truncating after %d bytes "
|
||||
"with at least %d surplus bytes remaining" % (
|
||||
self.nbytes, len(chunk) - bytes_left))
|
||||
|
||||
if bytes_left:
|
||||
raise BadResponseLength('Expected another %d bytes' % (
|
||||
bytes_left,))
|
||||
finally:
|
||||
close_if_possible(self.inner_iter)
|
||||
|
||||
|
||||
class CatchErrorsContext(WSGIContext):
|
||||
@ -99,7 +113,7 @@ class CatchErrorsContext(WSGIContext):
|
||||
# and raise an exception to stop any more bytes from being
|
||||
# generated and also to kill the TCP connection.
|
||||
if env['REQUEST_METHOD'] == 'HEAD':
|
||||
resp = enforce_byte_count(resp, 0)
|
||||
resp = ByteEnforcer(resp, 0)
|
||||
|
||||
elif self._response_headers:
|
||||
content_lengths = [val for header, val in self._response_headers
|
||||
@ -110,7 +124,7 @@ class CatchErrorsContext(WSGIContext):
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
resp = enforce_byte_count(resp, content_length)
|
||||
resp = ByteEnforcer(resp, content_length)
|
||||
|
||||
# make sure the response has the trans_id
|
||||
if self._response_headers is None:
|
||||
|
@ -89,7 +89,7 @@ import os
|
||||
import time
|
||||
|
||||
from swift.common.constraints import valid_api_version
|
||||
from swift.common.middleware.catch_errors import enforce_byte_count
|
||||
from swift.common.middleware.catch_errors import ByteEnforcer
|
||||
from swift.common.request_helpers import get_log_info
|
||||
from swift.common.swob import Request
|
||||
from swift.common.utils import (get_logger, get_remote_client,
|
||||
@ -430,7 +430,7 @@ class ProxyLoggingMiddleware(object):
|
||||
if method == 'HEAD':
|
||||
content_length = 0
|
||||
if content_length is not None:
|
||||
iterator = enforce_byte_count(iterator, content_length)
|
||||
iterator = ByteEnforcer(iterator, content_length)
|
||||
|
||||
wire_status_int = int(start_response_args[0][0].split(' ', 1)[0])
|
||||
resp_headers = dict(start_response_args[0][1])
|
||||
|
@ -243,6 +243,9 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException):
|
||||
|
||||
swob.HTTPException.__init__(
|
||||
self, status=kwargs.pop('status', self._status),
|
||||
# we use an app_iter, so that we can add our trans_id to the resp
|
||||
# xml *after* we've been called - technically any non-None app_iter
|
||||
# would do, we override swob.Response._response_iter anyway.
|
||||
app_iter=self._body_iter(),
|
||||
content_type='application/xml', *args,
|
||||
**kwargs)
|
||||
@ -265,6 +268,9 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException):
|
||||
error_elem = Element('Error')
|
||||
SubElement(error_elem, 'Code').text = self._code
|
||||
SubElement(error_elem, 'Message').text = self._msg
|
||||
# N.B. swob.Response objects don't normally have an environ attribute
|
||||
# when they're created, but swob always gives this to us when we're
|
||||
# __call__'d
|
||||
if 'swift.trans_id' in self.environ:
|
||||
request_id = self.environ['swift.trans_id']
|
||||
SubElement(error_elem, 'RequestId').text = request_id
|
||||
@ -274,6 +280,13 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException):
|
||||
yield tostring(error_elem, use_s3ns=False,
|
||||
xml_declaration=self.xml_declaration)
|
||||
|
||||
def _response_iter(self, app_iter, body):
|
||||
# we don't actually want our _response_iter to be a generator, a list
|
||||
# of strings is much better for eventlet.wsgi.server connection
|
||||
# handling and request pipelining and ErrorResponses are small. FWIW
|
||||
# we now have self.environ, app_iter=self._body_iter() and body is None
|
||||
return super()._response_iter(list(app_iter), body)
|
||||
|
||||
def _dict_to_etree(self, parent, d):
|
||||
for key, value in d.items():
|
||||
tag = re.sub(r'\W', '', snake_to_camel(key))
|
||||
|
@ -15,8 +15,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from swift.common.swob import Response
|
||||
from swift.common.swob import Response, Request
|
||||
from swift.common.utils import HeaderKeyDict
|
||||
from swift.common.middleware.catch_errors import CatchErrorMiddleware
|
||||
from swift.common.middleware.s3api.s3response import S3Response, ErrorResponse
|
||||
from swift.common.middleware.s3api.utils import sysmeta_prefix
|
||||
|
||||
@ -124,6 +125,27 @@ class TestErrorResponse(unittest.TestCase):
|
||||
b"</Error>",
|
||||
resp.body)
|
||||
|
||||
def test_error_response_trans_id(self):
|
||||
req = Request.blank('/bucket/object')
|
||||
err = DummyErrorResponse(msg='my-msg', reason='my reason')
|
||||
app = CatchErrorMiddleware(err, {})
|
||||
with unittest.mock.patch(
|
||||
'swift.common.middleware.catch_errors.generate_trans_id',
|
||||
return_value='fake-trans-id'):
|
||||
resp = req.get_response(app)
|
||||
self.assertIn('swift.trans_id', req.environ)
|
||||
self.assertEqual(418, resp.status_int)
|
||||
self.assertIn('X-Trans-Id', resp.headers)
|
||||
self.assertEqual(
|
||||
b"<?xml version='1.0' encoding='UTF-8'?>\n"
|
||||
b"<Error>"
|
||||
b"<Code>DummyErrorResponse</Code>"
|
||||
b"<Message>my-msg</Message>"
|
||||
b"<RequestId>fake-trans-id</RequestId>"
|
||||
b"</Error>",
|
||||
resp.body)
|
||||
self.assertEqual(146, int(resp.headers['Content-Length']))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from swift.common.swob import Request
|
||||
from swift.common.swob import Request, HTTPOk
|
||||
from swift.common.middleware import catch_errors
|
||||
from swift.common.utils import get_logger
|
||||
|
||||
@ -137,6 +137,30 @@ class TestCatchErrors(unittest.TestCase):
|
||||
resp = app(req.environ, self.start_response)
|
||||
self.assertEqual(list(resp), [b'An error occurred'])
|
||||
|
||||
def test_has_len(self):
|
||||
# sanity
|
||||
app = HTTPOk(body='test-body')
|
||||
req = Request.blank('/')
|
||||
captured_status_length = []
|
||||
|
||||
def capture_start_resp(status, headers, exc_info=None):
|
||||
length = None
|
||||
for k, v in headers:
|
||||
if k == 'Content-Length':
|
||||
length = int(v)
|
||||
captured_status_length.append((status, length))
|
||||
iterable = app(req.environ, capture_start_resp)
|
||||
self.assertEqual(captured_status_length, [('200 OK', 9)])
|
||||
self.assertTrue(hasattr(iterable, '__len__'))
|
||||
# wrapped should work the same way
|
||||
app_resp = HTTPOk(body='test-body')
|
||||
app = catch_errors.CatchErrorMiddleware(app_resp, {})
|
||||
req = Request.blank('/')
|
||||
captured_status_length = []
|
||||
iterable = app(req.environ, capture_start_resp)
|
||||
self.assertEqual(captured_status_length, [('200 OK', 9)])
|
||||
self.assertTrue(hasattr(iterable, '__len__'))
|
||||
|
||||
def test_HEAD_with_content_length(self):
|
||||
def cannot_count_app(env, sr):
|
||||
sr("200 OK", [("Content-Length", "10")])
|
||||
|
Loading…
x
Reference in New Issue
Block a user