diff --git a/cinder/common/sqlalchemyutils.py b/cinder/common/sqlalchemyutils.py index 161d5ca3e11..c24628cb57c 100644 --- a/cinder/common/sqlalchemyutils.py +++ b/cinder/common/sqlalchemyutils.py @@ -17,10 +17,13 @@ # under the License. """Implementation of paginate query.""" +import datetime from oslo_log import log as logging from six.moves import range import sqlalchemy +import sqlalchemy.sql as sa_sql +from sqlalchemy.sql import type_api from cinder.db import api from cinder import exception @@ -29,7 +32,32 @@ from cinder.i18n import _, _LW LOG = logging.getLogger(__name__) +_TYPE_SCHEMA = { + 'datetime': datetime.datetime(1900, 1, 1), + 'big_integer': 0, + 'integer': 0, + 'string': '' +} + +def _get_default_column_value(model, column_name): + """Return the default value of the columns from DB table. + + In postgreDB case, if no right default values are being set, an + psycopg2.DataError will be thrown. + """ + attr = getattr(model, column_name) + # Return the default value directly if the model contains. Otherwise return + # a default value which is not None. + if attr.default and isinstance(attr.default, type_api.TypeEngine): + return attr.default.arg + + attr_type = attr.type + return _TYPE_SCHEMA[attr_type.__visit_name__] + + +# TODO(wangxiyuan): Use oslo_db.sqlalchemy.utils.paginate_query once it is +# stable and afforded by the minimum version in requirement.txt. # copied from glance/db/sqlalchemy/api.py def paginate_query(query, model, limit, sort_keys, marker=None, sort_dir=None, sort_dirs=None, offset=None): @@ -58,6 +86,8 @@ def paginate_query(query, model, limit, sort_keys, marker=None, results after this value. :param sort_dir: direction in which results should be sorted (asc, desc) :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys + :param offset: the number of items to skip from the marker or from the + first element. :rtype: sqlalchemy.orm.query.Query :return: The query with sorting/pagination added. @@ -100,6 +130,8 @@ def paginate_query(query, model, limit, sort_keys, marker=None, marker_values = [] for sort_key in sort_keys: v = getattr(marker, sort_key) + if v is None: + v = _get_default_column_value(model, sort_key) marker_values.append(v) # Build up an array of sort criteria as in the docstring @@ -108,13 +140,21 @@ def paginate_query(query, model, limit, sort_keys, marker=None, crit_attrs = [] for j in range(0, i): model_attr = getattr(model, sort_keys[j]) - crit_attrs.append((model_attr == marker_values[j])) + default = _get_default_column_value(model, sort_keys[j]) + attr = sa_sql.expression.case([(model_attr.isnot(None), + model_attr), ], + else_=default) + crit_attrs.append((attr == marker_values[j])) model_attr = getattr(model, sort_keys[i]) + default = _get_default_column_value(model, sort_keys[i]) + attr = sa_sql.expression.case([(model_attr.isnot(None), + model_attr), ], + else_=default) if sort_dirs[i] == 'desc': - crit_attrs.append((model_attr < marker_values[i])) + crit_attrs.append((attr < marker_values[i])) elif sort_dirs[i] == 'asc': - crit_attrs.append((model_attr > marker_values[i])) + crit_attrs.append((attr > marker_values[i])) else: raise ValueError(_("Unknown sort direction, " "must be 'desc' or 'asc'")) diff --git a/cinder/tests/unit/test_paginate_query.py b/cinder/tests/unit/test_paginate_query.py new file mode 100644 index 00000000000..05a6c985666 --- /dev/null +++ b/cinder/tests/unit/test_paginate_query.py @@ -0,0 +1,41 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from cinder.common import sqlalchemyutils +from cinder import context +from cinder.db.sqlalchemy import api as db_api +from cinder.db.sqlalchemy import models +from cinder import test +from cinder.tests.unit import fake_constants as fake + + +class TestPaginateQuery(test.TestCase): + def setUp(self): + super(TestPaginateQuery, self).setUp() + self.ctxt = context.RequestContext(fake.USER_ID, fake.PROJECT_ID, + auth_token=True, + is_admin=True) + self.query = db_api._volume_get_query(self.ctxt) + self.model = models.Volume + + def test_paginate_query_marker_null(self): + marker_object = self.model() + self.assertIsNone(marker_object.display_name) + self.assertIsNone(marker_object.updated_at) + + marker_object.size = 1 + # There is no error raised here. + sqlalchemyutils.paginate_query(self.query, self.model, 10, + sort_keys=['display_name', + 'updated_at', + 'size'], + marker=marker_object, + sort_dirs=['desc', 'asc', 'desc'])