From f95ee71758c53ff28c03834d1a1698b44b8bc7dc Mon Sep 17 00:00:00 2001
From: Thang Pham <thang.g.pham@gmail.com>
Date: Mon, 23 Nov 2015 20:12:03 -0800
Subject: [PATCH] Implement refresh() for cinder objects

The following patch implements the refresh functionality
in cinder objects.  Instead of calling get_by_id() to get
the latest object attributes, refresh() can be called.
With this change, delete_volume() was updated to use
volume.refresh().

Change-Id: If0573f1c44c2e67c9a8cbd88dda45310a02e3272
Partial-Implements: blueprint cinder-objects
---
 cinder/objects/base.py                        | 18 ++++++++++++
 cinder/tests/unit/fake_snapshot.py            |  2 +-
 cinder/tests/unit/objects/test_backup.py      | 20 +++++++++++++
 cinder/tests/unit/objects/test_base.py        | 28 +++++++++++++++++++
 cinder/tests/unit/objects/test_cgsnapshot.py  | 20 +++++++++++++
 .../unit/objects/test_consistencygroup.py     | 20 +++++++++++++
 cinder/tests/unit/objects/test_service.py     | 20 +++++++++++++
 cinder/tests/unit/objects/test_snapshot.py    | 20 +++++++++++++
 cinder/tests/unit/objects/test_volume.py      | 21 ++++++++++++++
 .../unit/objects/test_volume_attachment.py    | 20 +++++++++++++
 cinder/tests/unit/objects/test_volume_type.py | 20 +++++++++++++
 cinder/volume/manager.py                      | 11 ++++----
 12 files changed, 213 insertions(+), 7 deletions(-)

diff --git a/cinder/objects/base.py b/cinder/objects/base.py
index 61498370d8a..4e789f60a89 100644
--- a/cinder/objects/base.py
+++ b/cinder/objects/base.py
@@ -205,6 +205,24 @@ class CinderObject(base.VersionedObject):
             self.obj_reset_changes(values.keys())
         return result
 
+    def refresh(self):
+        # To refresh we need to have a model and for the model to have an id
+        # field
+        if 'id' not in self.fields:
+            msg = (_('VersionedObject %s cannot retrieve object by id.') %
+                   (self.obj_name()))
+            raise NotImplementedError(msg)
+
+        current = self.get_by_id(self._context, self.id)
+
+        for field in self.fields:
+            # Only update attributes that are already set.  We do not want to
+            # unexpectedly trigger a lazy-load.
+            if self.obj_attr_is_set(field):
+                if self[field] != current[field]:
+                    self[field] = current[field]
+        self.obj_reset_changes()
+
 
 class CinderObjectDictCompat(base.VersionedObjectDictCompat):
     """Mix-in to provide dictionary key access compat.
diff --git a/cinder/tests/unit/fake_snapshot.py b/cinder/tests/unit/fake_snapshot.py
index 1478beca7b7..136e3269aee 100644
--- a/cinder/tests/unit/fake_snapshot.py
+++ b/cinder/tests/unit/fake_snapshot.py
@@ -23,7 +23,7 @@ def fake_db_snapshot(**updates):
         'volume_id': 'fake_id',
         'status': "creating",
         'progress': '0%',
-        'volume_size': '1',
+        'volume_size': 1,
         'display_name': 'fake_name',
         'display_description': 'fake_description',
         'metadata': {},
diff --git a/cinder/tests/unit/objects/test_backup.py b/cinder/tests/unit/objects/test_backup.py
index 99eb6e83ab1..b1e13720ef4 100644
--- a/cinder/tests/unit/objects/test_backup.py
+++ b/cinder/tests/unit/objects/test_backup.py
@@ -140,6 +140,26 @@ class TestBackup(test_objects.BaseObjectsTestCase):
                           objects.Backup.decode_record,
                           export_string)
 
+    @mock.patch('cinder.db.sqlalchemy.api.backup_get')
+    def test_refresh(self, backup_get):
+        db_backup1 = fake_backup.copy()
+        db_backup2 = db_backup1.copy()
+        db_backup2['display_name'] = 'foobar'
+
+        # On the second backup_get, return the backup with an updated
+        # display_name
+        backup_get.side_effect = [db_backup1, db_backup2]
+        backup = objects.Backup.get_by_id(self.context, '1')
+        self._compare(self, db_backup1, backup)
+
+        # display_name was updated, so a backup refresh should have a new value
+        # for that field
+        backup.refresh()
+        self._compare(self, db_backup2, backup)
+        backup_get.assert_has_calls([mock.call(self.context, '1'),
+                                     mock.call.__nonzero__(),
+                                     mock.call(self.context, '1')])
+
 
 class TestBackupList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.db.backup_get_all', return_value=[fake_backup])
diff --git a/cinder/tests/unit/objects/test_base.py b/cinder/tests/unit/objects/test_base.py
index 6e95124c334..1aa6662e827 100644
--- a/cinder/tests/unit/objects/test_base.py
+++ b/cinder/tests/unit/objects/test_base.py
@@ -13,6 +13,7 @@
 #    under the License.
 
 import datetime
+import mock
 import uuid
 
 from iso8601 import iso8601
@@ -82,6 +83,33 @@ class TestCinderObject(test_objects.BaseObjectsTestCase):
         self.assertDictEqual({'scheduled_at': now},
                              self.obj.cinder_obj_get_changes())
 
+    def test_refresh(self):
+        @objects.base.CinderObjectRegistry.register_if(False)
+        class MyTestObject(objects.base.CinderObject,
+                           objects.base.CinderObjectDictCompat,
+                           objects.base.CinderComparableObject):
+            fields = {'id': fields.UUIDField(),
+                      'name': fields.StringField()}
+
+        test_obj = MyTestObject(id='1', name='foo')
+        refresh_obj = MyTestObject(id='1', name='bar')
+        with mock.patch(
+                'cinder.objects.base.CinderObject.get_by_id') as get_by_id:
+            get_by_id.return_value = refresh_obj
+
+            test_obj.refresh()
+            self._compare(self, refresh_obj, test_obj)
+
+    def test_refresh_no_id_field(self):
+        @objects.base.CinderObjectRegistry.register_if(False)
+        class MyTestObjectNoId(objects.base.CinderObject,
+                               objects.base.CinderObjectDictCompat,
+                               objects.base.CinderComparableObject):
+            fields = {'uuid': fields.UUIDField()}
+
+        test_obj = MyTestObjectNoId(uuid='1', name='foo')
+        self.assertRaises(NotImplementedError, test_obj.refresh)
+
 
 class TestCinderComparableObject(test_objects.BaseObjectsTestCase):
     def test_comparable_objects(self):
diff --git a/cinder/tests/unit/objects/test_cgsnapshot.py b/cinder/tests/unit/objects/test_cgsnapshot.py
index f36e53d4c0b..5d8cc97f5d3 100644
--- a/cinder/tests/unit/objects/test_cgsnapshot.py
+++ b/cinder/tests/unit/objects/test_cgsnapshot.py
@@ -107,6 +107,26 @@ class TestCGSnapshot(test_objects.BaseObjectsTestCase):
         snapshotlist_get_for_cgs.assert_called_once_with(
             self.context, cgsnapshot.id)
 
+    @mock.patch('cinder.db.sqlalchemy.api.cgsnapshot_get')
+    def test_refresh(self, cgsnapshot_get):
+        db_cgsnapshot1 = fake_cgsnapshot.copy()
+        db_cgsnapshot2 = db_cgsnapshot1.copy()
+        db_cgsnapshot2['description'] = 'foobar'
+
+        # On the second cgsnapshot_get, return the CGSnapshot with an updated
+        # description
+        cgsnapshot_get.side_effect = [db_cgsnapshot1, db_cgsnapshot2]
+        cgsnapshot = objects.CGSnapshot.get_by_id(self.context, '1')
+        self._compare(self, db_cgsnapshot1, cgsnapshot)
+
+        # description was updated, so a CGSnapshot refresh should have a new
+        # value for that field
+        cgsnapshot.refresh()
+        self._compare(self, db_cgsnapshot2, cgsnapshot)
+        cgsnapshot_get.assert_has_calls([mock.call(self.context, '1'),
+                                         mock.call.__nonzero__(),
+                                         mock.call(self.context, '1')])
+
 
 class TestCGSnapshotList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.db.cgsnapshot_get_all',
diff --git a/cinder/tests/unit/objects/test_consistencygroup.py b/cinder/tests/unit/objects/test_consistencygroup.py
index 6caad643820..9c66bb9cfaf 100644
--- a/cinder/tests/unit/objects/test_consistencygroup.py
+++ b/cinder/tests/unit/objects/test_consistencygroup.py
@@ -83,6 +83,26 @@ class TestConsistencyGroup(test_objects.BaseObjectsTestCase):
         admin_context = consistencygroup_destroy.call_args[0][0]
         self.assertTrue(admin_context.is_admin)
 
+    @mock.patch('cinder.db.sqlalchemy.api.consistencygroup_get')
+    def test_refresh(self, consistencygroup_get):
+        db_cg1 = fake_consistencygroup.copy()
+        db_cg2 = db_cg1.copy()
+        db_cg2['description'] = 'foobar'
+
+        # On the second consistencygroup_get, return the ConsistencyGroup with
+        # an updated description
+        consistencygroup_get.side_effect = [db_cg1, db_cg2]
+        cg = objects.ConsistencyGroup.get_by_id(self.context, '1')
+        self._compare(self, db_cg1, cg)
+
+        # description was updated, so a ConsistencyGroup refresh should have a
+        # new value for that field
+        cg.refresh()
+        self._compare(self, db_cg2, cg)
+        consistencygroup_get.assert_has_calls([mock.call(self.context, '1'),
+                                               mock.call.__nonzero__(),
+                                               mock.call(self.context, '1')])
+
 
 class TestConsistencyGroupList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.db.consistencygroup_get_all',
diff --git a/cinder/tests/unit/objects/test_service.py b/cinder/tests/unit/objects/test_service.py
index 937e7ba01fc..dc10236aca1 100644
--- a/cinder/tests/unit/objects/test_service.py
+++ b/cinder/tests/unit/objects/test_service.py
@@ -77,6 +77,26 @@ class TestService(test_objects.BaseObjectsTestCase):
             service.destroy()
             service_destroy.assert_called_once_with(elevated_ctx(), 123)
 
+    @mock.patch('cinder.db.sqlalchemy.api.service_get')
+    def test_refresh(self, service_get):
+        db_service1 = fake_service.fake_db_service()
+        db_service2 = db_service1.copy()
+        db_service2['availability_zone'] = 'foobar'
+
+        # On the second service_get, return the service with an updated
+        # availability_zone
+        service_get.side_effect = [db_service1, db_service2]
+        service = objects.Service.get_by_id(self.context, 123)
+        self._compare(self, db_service1, service)
+
+        # availability_zone was updated, so a service refresh should have a
+        # new value for that field
+        service.refresh()
+        self._compare(self, db_service2, service)
+        service_get.assert_has_calls([mock.call(self.context, 123),
+                                      mock.call.__nonzero__(),
+                                      mock.call(self.context, 123)])
+
 
 class TestServiceList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.db.service_get_all')
diff --git a/cinder/tests/unit/objects/test_snapshot.py b/cinder/tests/unit/objects/test_snapshot.py
index d3933171ed9..021e1b83f08 100644
--- a/cinder/tests/unit/objects/test_snapshot.py
+++ b/cinder/tests/unit/objects/test_snapshot.py
@@ -167,6 +167,26 @@ class TestSnapshot(test_objects.BaseObjectsTestCase):
                                                   self.project_id,
                                                   volume_type_id)
 
+    @mock.patch('cinder.db.sqlalchemy.api.snapshot_get')
+    def test_refresh(self, snapshot_get):
+        db_snapshot1 = fake_snapshot.fake_db_snapshot()
+        db_snapshot2 = db_snapshot1.copy()
+        db_snapshot2['display_name'] = 'foobar'
+
+        # On the second snapshot_get, return the snapshot with an updated
+        # display_name
+        snapshot_get.side_effect = [db_snapshot1, db_snapshot2]
+        snapshot = objects.Snapshot.get_by_id(self.context, '1')
+        self._compare(self, db_snapshot1, snapshot)
+
+        # display_name was updated, so a snapshot refresh should have a new
+        # value for that field
+        snapshot.refresh()
+        self._compare(self, db_snapshot2, snapshot)
+        snapshot_get.assert_has_calls([mock.call(self.context, '1'),
+                                       mock.call.__nonzero__(),
+                                       mock.call(self.context, '1')])
+
 
 class TestSnapshotList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.objects.volume.Volume.get_by_id')
diff --git a/cinder/tests/unit/objects/test_volume.py b/cinder/tests/unit/objects/test_volume.py
index aa199e17a64..2beea8639a7 100644
--- a/cinder/tests/unit/objects/test_volume.py
+++ b/cinder/tests/unit/objects/test_volume.py
@@ -272,6 +272,27 @@ class TestVolume(test_objects.BaseObjectsTestCase):
         self._compare(self, db_consistencygroup, volume.consistencygroup)
         self._compare(self, db_snapshots, volume.snapshots)
 
+    @mock.patch('cinder.db.volume_glance_metadata_get', return_value={})
+    @mock.patch('cinder.db.sqlalchemy.api.volume_get')
+    def test_refresh(self, volume_get, volume_metadata_get):
+        db_volume1 = fake_volume.fake_db_volume()
+        db_volume2 = db_volume1.copy()
+        db_volume2['display_name'] = 'foobar'
+
+        # On the second volume_get, return the volume with an updated
+        # display_name
+        volume_get.side_effect = [db_volume1, db_volume2]
+        volume = objects.Volume.get_by_id(self.context, '1')
+        self._compare(self, db_volume1, volume)
+
+        # display_name was updated, so a volume refresh should have a new value
+        # for that field
+        volume.refresh()
+        self._compare(self, db_volume2, volume)
+        volume_get.assert_has_calls([mock.call(self.context, '1'),
+                                     mock.call.__nonzero__(),
+                                     mock.call(self.context, '1')])
+
 
 class TestVolumeList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.db.volume_get_all')
diff --git a/cinder/tests/unit/objects/test_volume_attachment.py b/cinder/tests/unit/objects/test_volume_attachment.py
index 87c7ffbcfbc..ed6325eeff1 100644
--- a/cinder/tests/unit/objects/test_volume_attachment.py
+++ b/cinder/tests/unit/objects/test_volume_attachment.py
@@ -36,6 +36,26 @@ class TestVolumeAttachment(test_objects.BaseObjectsTestCase):
         volume_attachment_update.assert_called_once_with(
             self.context, attachment.id, {'attach_status': 'attaching'})
 
+    @mock.patch('cinder.db.sqlalchemy.api.volume_attachment_get')
+    def test_refresh(self, attachment_get):
+        db_attachment1 = fake_volume.fake_db_volume_attachment()
+        db_attachment2 = db_attachment1.copy()
+        db_attachment2['mountpoint'] = '/dev/sdc'
+
+        # On the second volume_attachment_get, return the volume attachment
+        # with an updated mountpoint
+        attachment_get.side_effect = [db_attachment1, db_attachment2]
+        attachment = objects.VolumeAttachment.get_by_id(self.context, '1')
+        self._compare(self, db_attachment1, attachment)
+
+        # mountpoint was updated, so a volume attachment refresh should have a
+        # new value for that field
+        attachment.refresh()
+        self._compare(self, db_attachment2, attachment)
+        attachment_get.assert_has_calls([mock.call(self.context, '1'),
+                                         mock.call.__nonzero__(),
+                                         mock.call(self.context, '1')])
+
 
 class TestVolumeAttachmentList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.db.volume_attachment_get_used_by_volume_id')
diff --git a/cinder/tests/unit/objects/test_volume_type.py b/cinder/tests/unit/objects/test_volume_type.py
index ac3b09ecefc..8006adc10f0 100644
--- a/cinder/tests/unit/objects/test_volume_type.py
+++ b/cinder/tests/unit/objects/test_volume_type.py
@@ -70,6 +70,26 @@ class TestVolumeType(test_objects.BaseObjectsTestCase):
         admin_context = volume_type_destroy.call_args[0][0]
         self.assertTrue(admin_context.is_admin)
 
+    @mock.patch('cinder.db.sqlalchemy.api._volume_type_get_full')
+    def test_refresh(self, volume_type_get):
+        db_type1 = fake_volume.fake_db_volume_type()
+        db_type2 = db_type1.copy()
+        db_type2['description'] = 'foobar'
+
+        # On the second _volume_type_get_full, return the volume type with an
+        # updated description
+        volume_type_get.side_effect = [db_type1, db_type2]
+        volume_type = objects.VolumeType.get_by_id(self.context, '1')
+        self._compare(self, db_type1, volume_type)
+
+        # description was updated, so a volume type refresh should have a new
+        # value for that field
+        volume_type.refresh()
+        self._compare(self, db_type2, volume_type)
+        volume_type_get.assert_has_calls([mock.call(self.context, '1'),
+                                          mock.call.__nonzero__(),
+                                          mock.call(self.context, '1')])
+
 
 class TestVolumeTypeList(test_objects.BaseObjectsTestCase):
     @mock.patch('cinder.volume.volume_types.get_all_types')
diff --git a/cinder/volume/manager.py b/cinder/volume/manager.py
index 3f867e53dfa..e94b07e4bf6 100644
--- a/cinder/volume/manager.py
+++ b/cinder/volume/manager.py
@@ -598,13 +598,12 @@ class VolumeManager(manager.SchedulerDependentManager):
 
         context = context.elevated()
 
-        # FIXME(thangp): Remove this in v2.0 of RPC API.
-        if volume is not None:
-            volume_id = volume.id
-
         try:
-            # TODO(thangp): Replace with volume.refresh() when it is available
-            volume = objects.Volume.get_by_id(context, volume_id)
+            # FIXME(thangp): Remove this in v2.0 of RPC API.
+            if volume is None:
+                volume = objects.Volume.get_by_id(context, volume_id)
+            else:
+                volume.refresh()
         except exception.VolumeNotFound:
             # NOTE(thingee): It could be possible for a volume to
             # be deleted when resuming deletes from init_host().