Skip to content

Commit 0a62c14

Browse files
committed
fix: fix instance registration cleanup on early iterator termination
1 parent 145110f commit 0a62c14

File tree

8 files changed

+377
-109
lines changed

8 files changed

+377
-109
lines changed

google/cloud/bigtable/data/_async/client.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,8 @@ async def _manage_channel(
476476
async def _register_instance(
477477
self,
478478
instance_id: str,
479-
owner: _DataApiTargetAsync | ExecuteQueryIteratorAsync,
479+
app_profile_id: Optional[str],
480+
owner_id: int,
480481
) -> None:
481482
"""
482483
Registers an instance with the client, and warms the channel for the instance
@@ -486,13 +487,15 @@ async def _register_instance(
486487
487488
Args:
488489
instance_id: id of the instance to register.
489-
owner: table that owns the instance. Owners will be tracked in
490+
app_profile_id: id of the app profile calling the instance.
491+
owner_id: integer id of the object owning the instance. Owners will be tracked in
490492
_instance_owners, and instances will only be unregistered when all
491-
owners call _remove_instance_registration
493+
owners call _remove_instance_registration. Can be obtained by calling
494+
id(owner)
492495
"""
493496
instance_name = self._gapic_client.instance_path(self.project, instance_id)
494-
instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id)
495-
self._instance_owners.setdefault(instance_key, set()).add(id(owner))
497+
instance_key = _WarmedInstanceKey(instance_name, app_profile_id)
498+
self._instance_owners.setdefault(instance_key, set()).add(owner_id)
496499
if instance_key not in self._active_instances:
497500
self._active_instances.add(instance_key)
498501
if self._channel_refresh_task:
@@ -510,10 +513,11 @@ async def _register_instance(
510513
"_DataApiTargetAsync": "_DataApiTarget",
511514
}
512515
)
513-
async def _remove_instance_registration(
516+
def _remove_instance_registration(
514517
self,
515518
instance_id: str,
516-
owner: _DataApiTargetAsync | ExecuteQueryIteratorAsync,
519+
app_profile_id: Optional[str],
520+
owner_id: int,
517521
) -> bool:
518522
"""
519523
Removes an instance from the client's registered instances, to prevent
@@ -523,17 +527,17 @@ async def _remove_instance_registration(
523527
524528
Args:
525529
instance_id: id of the instance to remove
526-
owner: table that owns the instance. Owners will be tracked in
527-
_instance_owners, and instances will only be unregistered when all
528-
owners call _remove_instance_registration
530+
app_profile_id: id of the app profile calling the instance.
531+
owner_id: integer id of the object owning the instance. Can be
532+
obtained by calling id(owner)
529533
Returns:
530534
bool: True if instance was removed, else False
531535
"""
532536
instance_name = self._gapic_client.instance_path(self.project, instance_id)
533-
instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id)
537+
instance_key = _WarmedInstanceKey(instance_name, app_profile_id)
534538
owner_list = self._instance_owners.get(instance_key, set())
535539
try:
536-
owner_list.remove(id(owner))
540+
owner_list.remove(owner_id)
537541
if len(owner_list) == 0:
538542
self._active_instances.remove(instance_key)
539543
return True
@@ -1014,7 +1018,8 @@ def __init__(
10141018
self._register_instance_future = CrossSync.create_task(
10151019
self.client._register_instance,
10161020
self.instance_id,
1017-
self,
1021+
self.app_profile_id,
1022+
id(self),
10181023
sync_executor=self.client._executor,
10191024
)
10201025
except RuntimeError as e:
@@ -1725,7 +1730,9 @@ async def close(self):
17251730
"""
17261731
if self._register_instance_future:
17271732
self._register_instance_future.cancel()
1728-
await self.client._remove_instance_registration(self.instance_id, self)
1733+
self.client._remove_instance_registration(
1734+
self.instance_id, self.app_profile_id, id(self)
1735+
)
17291736

17301737
@CrossSync.convert(sync_name="__enter__")
17311738
async def __aenter__(self):

google/cloud/bigtable/data/_sync_autogen/client.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def _manage_channel(
354354
next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0)
355355

356356
def _register_instance(
357-
self, instance_id: str, owner: _DataApiTarget | ExecuteQueryIterator
357+
self, instance_id: str, app_profile_id: Optional[str], owner_id: int
358358
) -> None:
359359
"""Registers an instance with the client, and warms the channel for the instance
360360
The client will periodically refresh grpc channel used to make
@@ -363,12 +363,14 @@ def _register_instance(
363363
364364
Args:
365365
instance_id: id of the instance to register.
366-
owner: table that owns the instance. Owners will be tracked in
366+
app_profile_id: id of the app profile calling the instance.
367+
owner_id: integer id of the object owning the instance. Owners will be tracked in
367368
_instance_owners, and instances will only be unregistered when all
368-
owners call _remove_instance_registration"""
369+
owners call _remove_instance_registration. Can be obtained by calling
370+
id(owner)"""
369371
instance_name = self._gapic_client.instance_path(self.project, instance_id)
370-
instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id)
371-
self._instance_owners.setdefault(instance_key, set()).add(id(owner))
372+
instance_key = _WarmedInstanceKey(instance_name, app_profile_id)
373+
self._instance_owners.setdefault(instance_key, set()).add(owner_id)
372374
if instance_key not in self._active_instances:
373375
self._active_instances.add(instance_key)
374376
if self._channel_refresh_task:
@@ -377,7 +379,7 @@ def _register_instance(
377379
self._start_background_channel_refresh()
378380

379381
def _remove_instance_registration(
380-
self, instance_id: str, owner: _DataApiTarget | ExecuteQueryIterator
382+
self, instance_id: str, app_profile_id: Optional[str], owner_id: int
381383
) -> bool:
382384
"""Removes an instance from the client's registered instances, to prevent
383385
warming new channels for the instance
@@ -386,16 +388,16 @@ def _remove_instance_registration(
386388
387389
Args:
388390
instance_id: id of the instance to remove
389-
owner: table that owns the instance. Owners will be tracked in
390-
_instance_owners, and instances will only be unregistered when all
391-
owners call _remove_instance_registration
391+
app_profile_id: id of the app profile calling the instance.
392+
owner_id: integer id of the object owning the instance. Can be
393+
obtained by calling id(owner)
392394
Returns:
393395
bool: True if instance was removed, else False"""
394396
instance_name = self._gapic_client.instance_path(self.project, instance_id)
395-
instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id)
397+
instance_key = _WarmedInstanceKey(instance_name, app_profile_id)
396398
owner_list = self._instance_owners.get(instance_key, set())
397399
try:
398-
owner_list.remove(id(owner))
400+
owner_list.remove(owner_id)
399401
if len(owner_list) == 0:
400402
self._active_instances.remove(instance_key)
401403
return True
@@ -806,7 +808,8 @@ def __init__(
806808
self._register_instance_future = CrossSync._Sync_Impl.create_task(
807809
self.client._register_instance,
808810
self.instance_id,
809-
self,
811+
self.app_profile_id,
812+
id(self),
810813
sync_executor=self.client._executor,
811814
)
812815
except RuntimeError as e:
@@ -1460,7 +1463,9 @@ def close(self):
14601463
"""Called to close the Table instance and release any resources held by it."""
14611464
if self._register_instance_future:
14621465
self._register_instance_future.cancel()
1463-
self.client._remove_instance_registration(self.instance_id, self)
1466+
self.client._remove_instance_registration(
1467+
self.instance_id, self.app_profile_id, id(self)
1468+
)
14641469

14651470
def __enter__(self):
14661471
"""Implement async context manager protocol

google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __init__(
127127
self.has_received_token = False
128128
self._result_generator = self._next_impl()
129129
self._register_instance_task = None
130+
self._fully_consumed = False
130131
self._is_closed = False
131132
self._request_body = request_body
132133
self._attempt_timeout_gen = _attempt_timeout_generator(
@@ -145,7 +146,8 @@ def __init__(
145146
self._register_instance_task = CrossSync.create_task(
146147
self._client._register_instance,
147148
self._instance_id,
148-
self,
149+
self.app_profile_id,
150+
id(self),
149151
sync_executor=self._client._executor,
150152
)
151153
except RuntimeError as e:
@@ -193,39 +195,42 @@ async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]:
193195
Generator wrapping the response stream which parses the stream results
194196
and returns full `QueryResultRow`s.
195197
"""
196-
async for response in self._stream:
197-
try:
198-
# we've received a resume token, so we can finalize the metadata
199-
if self._final_metadata is None and _has_resume_token(response):
200-
self._finalize_metadata()
201-
202-
batches_to_parse = self._byte_cursor.consume(response)
203-
if not batches_to_parse:
204-
continue
205-
# metadata must be set at this point since there must be a resume_token
206-
# for byte_cursor to yield data
207-
if not self.metadata:
208-
raise ValueError(
209-
"Error parsing response before finalizing metadata"
198+
try:
199+
async for response in self._stream:
200+
try:
201+
# we've received a resume token, so we can finalize the metadata
202+
if self._final_metadata is None and _has_resume_token(response):
203+
self._finalize_metadata()
204+
205+
batches_to_parse = self._byte_cursor.consume(response)
206+
if not batches_to_parse:
207+
continue
208+
# metadata must be set at this point since there must be a resume_token
209+
# for byte_cursor to yield data
210+
if not self.metadata:
211+
raise ValueError(
212+
"Error parsing response before finalizing metadata"
213+
)
214+
results = self._reader.consume(
215+
batches_to_parse, self.metadata, self._column_info
210216
)
211-
results = self._reader.consume(
212-
batches_to_parse, self.metadata, self._column_info
213-
)
214-
if results is None:
215-
continue
216-
217-
except ValueError as e:
218-
raise InvalidExecuteQueryResponse(
219-
"Invalid ExecuteQuery response received"
220-
) from e
221-
222-
for result in results:
223-
yield result
224-
# this means the stream has finished with no responses. In that case we know the
225-
# latest_prepare_reponses was used successfully so we can finalize the metadata
226-
if self._final_metadata is None:
227-
self._finalize_metadata()
228-
await self.close()
217+
if results is None:
218+
continue
219+
220+
except ValueError as e:
221+
raise InvalidExecuteQueryResponse(
222+
"Invalid ExecuteQuery response received"
223+
) from e
224+
225+
for result in results:
226+
yield result
227+
# this means the stream has finished with no responses. In that case we know the
228+
# latest_prepare_reponses was used successfully so we can finalize the metadata
229+
if self._final_metadata is None:
230+
self._finalize_metadata()
231+
self._fully_consumed = True
232+
finally:
233+
self._close_internal()
229234

230235
@CrossSync.convert(sync_name="__next__", replace_symbols={"__anext__": "__next__"})
231236
async def __anext__(self) -> QueryResultRow:
@@ -285,15 +290,26 @@ def metadata(self) -> Metadata:
285290
@CrossSync.convert
286291
async def close(self) -> None:
287292
"""
288-
Cancel all background tasks. Should be called all rows were processed.
293+
Cancel all background tasks. Should be called after all rows were processed.
294+
295+
Cleanup runs automatically when the iterator is exhausted or terminated early.
289296
290297
:raises: :class:`ValueError <exceptions.ValueError>` if called in an invalid state
291298
"""
299+
# this doesn't need to be async anymore but we wrap the sync api to avoid a breaking
300+
# change
301+
self._close_internal()
302+
303+
def _close_internal(self) -> None:
292304
if self._is_closed:
293305
return
294-
if not self._byte_cursor.empty():
306+
# Throw an error if the iterator has been successfully consumed but there is
307+
# still buffered data
308+
if self._fully_consumed and not self._byte_cursor.empty():
295309
raise ValueError("Unexpected buffered data at end of executeQuery reqest")
296310
self._is_closed = True
297311
if self._register_instance_task is not None:
298312
self._register_instance_task.cancel()
299-
await self._client._remove_instance_registration(self._instance_id, self)
313+
self._client._remove_instance_registration(
314+
self._instance_id, self.app_profile_id, id(self)
315+
)

google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
self.has_received_token = False
103103
self._result_generator = self._next_impl()
104104
self._register_instance_task = None
105+
self._fully_consumed = False
105106
self._is_closed = False
106107
self._request_body = request_body
107108
self._attempt_timeout_gen = _attempt_timeout_generator(
@@ -120,7 +121,8 @@ def __init__(
120121
self._register_instance_task = CrossSync._Sync_Impl.create_task(
121122
self._client._register_instance,
122123
self._instance_id,
123-
self,
124+
self.app_profile_id,
125+
id(self),
124126
sync_executor=self._client._executor,
125127
)
126128
except RuntimeError as e:
@@ -159,31 +161,34 @@ def _make_request_with_resume_token(self):
159161
def _next_impl(self) -> CrossSync._Sync_Impl.Iterator[QueryResultRow]:
160162
"""Generator wrapping the response stream which parses the stream results
161163
and returns full `QueryResultRow`s."""
162-
for response in self._stream:
163-
try:
164-
if self._final_metadata is None and _has_resume_token(response):
165-
self._finalize_metadata()
166-
batches_to_parse = self._byte_cursor.consume(response)
167-
if not batches_to_parse:
168-
continue
169-
if not self.metadata:
170-
raise ValueError(
171-
"Error parsing response before finalizing metadata"
164+
try:
165+
for response in self._stream:
166+
try:
167+
if self._final_metadata is None and _has_resume_token(response):
168+
self._finalize_metadata()
169+
batches_to_parse = self._byte_cursor.consume(response)
170+
if not batches_to_parse:
171+
continue
172+
if not self.metadata:
173+
raise ValueError(
174+
"Error parsing response before finalizing metadata"
175+
)
176+
results = self._reader.consume(
177+
batches_to_parse, self.metadata, self._column_info
172178
)
173-
results = self._reader.consume(
174-
batches_to_parse, self.metadata, self._column_info
175-
)
176-
if results is None:
177-
continue
178-
except ValueError as e:
179-
raise InvalidExecuteQueryResponse(
180-
"Invalid ExecuteQuery response received"
181-
) from e
182-
for result in results:
183-
yield result
184-
if self._final_metadata is None:
185-
self._finalize_metadata()
186-
self.close()
179+
if results is None:
180+
continue
181+
except ValueError as e:
182+
raise InvalidExecuteQueryResponse(
183+
"Invalid ExecuteQuery response received"
184+
) from e
185+
for result in results:
186+
yield result
187+
if self._final_metadata is None:
188+
self._finalize_metadata()
189+
self._fully_consumed = True
190+
finally:
191+
self._close_internal()
187192

188193
def __next__(self) -> QueryResultRow:
189194
"""Yields QueryResultRows representing the results of the query.
@@ -233,15 +238,22 @@ def metadata(self) -> Metadata:
233238
return self._final_metadata
234239

235240
def close(self) -> None:
236-
"""Cancel all background tasks. Should be called all rows were processed.
241+
"""Cancel all background tasks. Should be called after all rows were processed.
242+
243+
Cleanup runs automatically when the iterator is exhausted or terminated early.
237244
238245
:raises: :class:`ValueError <exceptions.ValueError>` if called in an invalid state
239246
"""
247+
self._close_internal()
248+
249+
def _close_internal(self) -> None:
240250
if self._is_closed:
241251
return
242-
if not self._byte_cursor.empty():
252+
if self._fully_consumed and (not self._byte_cursor.empty()):
243253
raise ValueError("Unexpected buffered data at end of executeQuery reqest")
244254
self._is_closed = True
245255
if self._register_instance_task is not None:
246256
self._register_instance_task.cancel()
247-
self._client._remove_instance_registration(self._instance_id, self)
257+
self._client._remove_instance_registration(
258+
self._instance_id, self.app_profile_id, id(self)
259+
)

0 commit comments

Comments
 (0)