diff --git a/cinder/db/sqlalchemy/api.py b/cinder/db/sqlalchemy/api.py index 967ccbfcad9..b1566ec565f 100644 --- a/cinder/db/sqlalchemy/api.py +++ b/cinder/db/sqlalchemy/api.py @@ -4567,11 +4567,35 @@ def is_orm_value(obj): sqlalchemy.sql.expression.ColumnElement)) +def _check_is_not_multitable(values, model): + """Check that we don't try to do multitable updates. + + Since PostgreSQL doesn't support multitable updates we want to always fail + if we have such a query in our code, even if with MySQL it would work. + """ + used_models = set() + for field in values: + if isinstance(field, sqlalchemy.orm.attributes.InstrumentedAttribute): + used_models.add(field.class_) + elif isinstance(field, six.string_types): + used_models.add(model) + else: + raise exception.ProgrammingError( + reason='DB Conditional update - Unknown field type, must be ' + 'string or ORM field.') + if len(used_models) > 1: + raise exception.ProgrammingError( + reason='DB Conditional update - Error in query, multitable ' + 'updates are not supported.') + + @require_context @_retry_on_deadlock def conditional_update(context, model, values, expected_values, filters=(), include_deleted='no', project_only=False, order=None): """Compare-and-swap conditional update SQLAlchemy implementation.""" + _check_is_not_multitable(values, model) + # Provided filters will become part of the where clause where_conds = list(filters) diff --git a/cinder/exception.py b/cinder/exception.py index 50bb43ccc6f..ecc256ead8c 100644 --- a/cinder/exception.py +++ b/cinder/exception.py @@ -149,6 +149,10 @@ class GlanceConnectionFailed(CinderException): message = _("Connection to glance failed: %(reason)s") +class ProgrammingError(CinderException): + message = _('Programming error in Cinder: %(reason)s') + + class NotAuthorized(CinderException): message = _("Not authorized.") code = 403 diff --git a/cinder/tests/unit/objects/test_base.py b/cinder/tests/unit/objects/test_base.py index acd9c24bea3..0eb805f6ac7 100644 --- a/cinder/tests/unit/objects/test_base.py +++ b/cinder/tests/unit/objects/test_base.py @@ -24,6 +24,7 @@ from sqlalchemy import sql from cinder import context from cinder import db from cinder.db.sqlalchemy import models +from cinder import exception from cinder import objects from cinder import test from cinder.tests.unit import fake_constants as fake @@ -664,6 +665,30 @@ class TestCinderObjectConditionalUpdate(test.TestCase): self.assertTrue(isinstance(arg, dict)) self.assertEqual(set(values.keys()), set(arg.keys())) + def test_conditional_update_multitable_fail(self): + volume = self._create_volume() + self.assertRaises(exception.ProgrammingError, + volume.conditional_update, + {'status': 'deleting', + objects.Snapshot.model.status: 'available'}, + {'status': 'available'}) + + def test_conditional_update_multitable_fail_fields_different_models(self): + volume = self._create_volume() + self.assertRaises(exception.ProgrammingError, + volume.conditional_update, + {objects.Backup.model.status: 'available', + objects.Snapshot.model.status: 'available'}) + + def test_conditional_update_not_multitable(self): + volume = self._create_volume() + with mock.patch('cinder.db.sqlalchemy.api._create_facade_lazily') as m: + res = volume.conditional_update( + {objects.Volume.model.status: 'deleting', + objects.Volume.model.size: 12}, reflect_changes=False) + self.assertTrue(res) + self.assertTrue(m.called) + class TestCinderDictObject(test_objects.BaseObjectsTestCase): @objects.base.CinderObjectRegistry.register_if(False)