Skip to content

Commit 6fa3008

Browse files
fix: refactor channel refresh (#1174)
1 parent 71efcb6 commit 6fa3008

File tree

13 files changed

+712
-157
lines changed

13 files changed

+712
-157
lines changed
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
from __future__ import annotations
16+
17+
from typing import Callable
18+
19+
from google.cloud.bigtable.data._cross_sync import CrossSync
20+
21+
from grpc import ChannelConnectivity
22+
23+
if CrossSync.is_async:
24+
from grpc.aio import Channel
25+
else:
26+
from grpc import Channel
27+
28+
__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._swappable_channel"
29+
30+
31+
@CrossSync.convert_class(sync_name="_WrappedChannel", rm_aio=True)
32+
class _AsyncWrappedChannel(Channel):
33+
"""
34+
A wrapper around a gRPC channel. All methods are passed
35+
through to the underlying channel.
36+
"""
37+
38+
def __init__(self, channel: Channel):
39+
self._channel = channel
40+
41+
def unary_unary(self, *args, **kwargs):
42+
return self._channel.unary_unary(*args, **kwargs)
43+
44+
def unary_stream(self, *args, **kwargs):
45+
return self._channel.unary_stream(*args, **kwargs)
46+
47+
def stream_unary(self, *args, **kwargs):
48+
return self._channel.stream_unary(*args, **kwargs)
49+
50+
def stream_stream(self, *args, **kwargs):
51+
return self._channel.stream_stream(*args, **kwargs)
52+
53+
async def channel_ready(self):
54+
return await self._channel.channel_ready()
55+
56+
@CrossSync.convert(
57+
sync_name="__enter__", replace_symbols={"__aenter__": "__enter__"}
58+
)
59+
async def __aenter__(self):
60+
await self._channel.__aenter__()
61+
return self
62+
63+
@CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"})
64+
async def __aexit__(self, exc_type, exc_val, exc_tb):
65+
return await self._channel.__aexit__(exc_type, exc_val, exc_tb)
66+
67+
def get_state(self, try_to_connect: bool = False) -> ChannelConnectivity:
68+
return self._channel.get_state(try_to_connect=try_to_connect)
69+
70+
async def wait_for_state_change(self, last_observed_state):
71+
return await self._channel.wait_for_state_change(last_observed_state)
72+
73+
def __getattr__(self, name):
74+
return getattr(self._channel, name)
75+
76+
async def close(self, grace=None):
77+
if CrossSync.is_async:
78+
return await self._channel.close(grace=grace)
79+
else:
80+
# grace not supported by sync version
81+
return self._channel.close()
82+
83+
if not CrossSync.is_async:
84+
# add required sync methods
85+
86+
def subscribe(self, callback, try_to_connect=False):
87+
return self._channel.subscribe(callback, try_to_connect)
88+
89+
def unsubscribe(self, callback):
90+
return self._channel.unsubscribe(callback)
91+
92+
93+
@CrossSync.convert_class(
94+
sync_name="SwappableChannel",
95+
replace_symbols={"_AsyncWrappedChannel": "_WrappedChannel"},
96+
)
97+
class AsyncSwappableChannel(_AsyncWrappedChannel):
98+
"""
99+
Provides a grpc channel wrapper, that allows the internal channel to be swapped out
100+
101+
Args:
102+
- channel_fn: a nullary function that returns a new channel instance.
103+
It should be a partial with all channel configuration arguments built-in
104+
"""
105+
106+
def __init__(self, channel_fn: Callable[[], Channel]):
107+
self._channel_fn = channel_fn
108+
self._channel = channel_fn()
109+
110+
def create_channel(self) -> Channel:
111+
"""
112+
Create a fresh channel using the stored `channel_fn` partial
113+
"""
114+
new_channel = self._channel_fn()
115+
if CrossSync.is_async:
116+
# copy over interceptors
117+
# this is needed because of how gapic attaches the LoggingClientAIOInterceptor
118+
# sync channels add interceptors by wrapping, so this step isn't needed
119+
new_channel._unary_unary_interceptors = (
120+
self._channel._unary_unary_interceptors
121+
)
122+
new_channel._unary_stream_interceptors = (
123+
self._channel._unary_stream_interceptors
124+
)
125+
new_channel._stream_unary_interceptors = (
126+
self._channel._stream_unary_interceptors
127+
)
128+
new_channel._stream_stream_interceptors = (
129+
self._channel._stream_stream_interceptors
130+
)
131+
return new_channel
132+
133+
def swap_channel(self, new_channel: Channel) -> Channel:
134+
"""
135+
Replace the wrapped channel with a new instance. Typically created using `create_channel`
136+
"""
137+
old_channel = self._channel
138+
self._channel = new_channel
139+
return old_channel

google/cloud/bigtable/data/_async/client.py

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,22 @@
9292
from google.cloud.bigtable_v2.services.bigtable.transports import (
9393
BigtableGrpcAsyncIOTransport as TransportType,
9494
)
95+
from google.cloud.bigtable_v2.services.bigtable import (
96+
BigtableAsyncClient as GapicClient,
97+
)
9598
from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE
99+
from google.cloud.bigtable.data._async._swappable_channel import (
100+
AsyncSwappableChannel,
101+
)
96102
else:
97103
from typing import Iterable # noqa: F401
98104
from grpc import insecure_channel
99-
from grpc import intercept_channel
100105
from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore
106+
from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore
101107
from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE
108+
from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( # noqa: F401
109+
SwappableChannel,
110+
)
102111

103112

104113
if TYPE_CHECKING:
@@ -182,7 +191,6 @@ def __init__(
182191
client_options = cast(
183192
Optional[client_options_lib.ClientOptions], client_options
184193
)
185-
custom_channel = None
186194
self._emulator_host = os.getenv(BIGTABLE_EMULATOR)
187195
if self._emulator_host is not None:
188196
warnings.warn(
@@ -191,24 +199,24 @@ def __init__(
191199
stacklevel=2,
192200
)
193201
# use insecure channel if emulator is set
194-
custom_channel = insecure_channel(self._emulator_host)
195202
if credentials is None:
196203
credentials = google.auth.credentials.AnonymousCredentials()
197204
if project is None:
198205
project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT
206+
199207
# initialize client
200208
ClientWithProject.__init__(
201209
self,
202210
credentials=credentials,
203211
project=project,
204212
client_options=client_options,
205213
)
206-
self._gapic_client = CrossSync.GapicClient(
214+
self._gapic_client = GapicClient(
207215
credentials=credentials,
208216
client_options=client_options,
209217
client_info=self.client_info,
210218
transport=lambda *args, **kwargs: TransportType(
211-
*args, **kwargs, channel=custom_channel
219+
*args, **kwargs, channel=self._build_grpc_channel
212220
),
213221
)
214222
if (
@@ -234,7 +242,7 @@ def __init__(
234242
self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {}
235243
self._channel_init_time = time.monotonic()
236244
self._channel_refresh_task: CrossSync.Task[None] | None = None
237-
self._executor = (
245+
self._executor: concurrent.futures.ThreadPoolExecutor | None = (
238246
concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None
239247
)
240248
if self._emulator_host is None:
@@ -249,6 +257,29 @@ def __init__(
249257
stacklevel=2,
250258
)
251259

260+
@CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
261+
def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
262+
"""
263+
This method is called by the gapic transport to create a grpc channel.
264+
265+
The init arguments passed down are captured in a partial used by AsyncSwappableChannel
266+
to create new channel instances in the future, as part of the channel refresh logic
267+
268+
Emulators always use an inseucre channel
269+
270+
Args:
271+
- *args: positional arguments passed by the gapic layer to create a new channel with
272+
- **kwargs: keyword arguments passed by the gapic layer to create a new channel with
273+
Returns:
274+
a custom wrapped swappable channel
275+
"""
276+
if self._emulator_host is not None:
277+
# emulators use insecure channel
278+
create_channel_fn = partial(insecure_channel, self._emulator_host)
279+
else:
280+
create_channel_fn = partial(TransportType.create_channel, *args, **kwargs)
281+
return AsyncSwappableChannel(create_channel_fn)
282+
252283
@property
253284
def universe_domain(self) -> str:
254285
"""Return the universe domain used by the client instance.
@@ -364,7 +395,12 @@ async def _ping_and_warm_instances(
364395
)
365396
return [r or None for r in result_list]
366397

367-
@CrossSync.convert
398+
def _invalidate_channel_stubs(self):
399+
"""Helper to reset the cached stubs. Needed when changing out the grpc channel"""
400+
self.transport._stubs = {}
401+
self.transport._prep_wrapped_messages(self.client_info)
402+
403+
@CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
368404
async def _manage_channel(
369405
self,
370406
refresh_interval_min: float = 60 * 35,
@@ -389,13 +425,17 @@ async def _manage_channel(
389425
grace_period: time to allow previous channel to serve existing
390426
requests before closing, in seconds
391427
"""
428+
if not isinstance(self.transport.grpc_channel, AsyncSwappableChannel):
429+
warnings.warn("Channel does not support auto-refresh.")
430+
return
431+
super_channel: AsyncSwappableChannel = self.transport.grpc_channel
392432
first_refresh = self._channel_init_time + random.uniform(
393433
refresh_interval_min, refresh_interval_max
394434
)
395435
next_sleep = max(first_refresh - time.monotonic(), 0)
396436
if next_sleep > 0:
397437
# warm the current channel immediately
398-
await self._ping_and_warm_instances(channel=self.transport.grpc_channel)
438+
await self._ping_and_warm_instances(channel=super_channel)
399439
# continuously refresh the channel every `refresh_interval` seconds
400440
while not self._is_closed.is_set():
401441
await CrossSync.event_wait(
@@ -408,32 +448,19 @@ async def _manage_channel(
408448
break
409449
start_timestamp = time.monotonic()
410450
# prepare new channel for use
411-
# TODO: refactor to avoid using internal references: https://github.com/googleapis/python-bigtable/issues/1094
412-
old_channel = self.transport.grpc_channel
413-
new_channel = self.transport.create_channel()
414-
if CrossSync.is_async:
415-
new_channel._unary_unary_interceptors.append(
416-
self.transport._interceptor
417-
)
418-
else:
419-
new_channel = intercept_channel(
420-
new_channel, self.transport._interceptor
421-
)
451+
new_channel = super_channel.create_channel()
422452
await self._ping_and_warm_instances(channel=new_channel)
423453
# cycle channel out of use, with long grace window before closure
424-
self.transport._grpc_channel = new_channel
425-
self.transport._logged_channel = new_channel
426-
# invalidate caches
427-
self.transport._stubs = {}
428-
self.transport._prep_wrapped_messages(self.client_info)
454+
old_channel = super_channel.swap_channel(new_channel)
455+
self._invalidate_channel_stubs()
429456
# give old_channel a chance to complete existing rpcs
430457
if CrossSync.is_async:
431458
await old_channel.close(grace_period)
432459
else:
433460
if grace_period:
434461
self._is_closed.wait(grace_period) # type: ignore
435462
old_channel.close() # type: ignore
436-
# subtract thed time spent waiting for the channel to be replaced
463+
# subtract the time spent waiting for the channel to be replaced
437464
next_refresh = random.uniform(refresh_interval_min, refresh_interval_max)
438465
next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0)
439466

@@ -895,24 +922,32 @@ def __init__(
895922
self.table_name = self.client._gapic_client.table_path(
896923
self.client.project, instance_id, table_id
897924
)
898-
self.app_profile_id = app_profile_id
925+
self.app_profile_id: str | None = app_profile_id
899926

900-
self.default_operation_timeout = default_operation_timeout
901-
self.default_attempt_timeout = default_attempt_timeout
902-
self.default_read_rows_operation_timeout = default_read_rows_operation_timeout
903-
self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout
904-
self.default_mutate_rows_operation_timeout = (
927+
self.default_operation_timeout: float = default_operation_timeout
928+
self.default_attempt_timeout: float | None = default_attempt_timeout
929+
self.default_read_rows_operation_timeout: float = (
930+
default_read_rows_operation_timeout
931+
)
932+
self.default_read_rows_attempt_timeout: float | None = (
933+
default_read_rows_attempt_timeout
934+
)
935+
self.default_mutate_rows_operation_timeout: float = (
905936
default_mutate_rows_operation_timeout
906937
)
907-
self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout
938+
self.default_mutate_rows_attempt_timeout: float | None = (
939+
default_mutate_rows_attempt_timeout
940+
)
908941

909-
self.default_read_rows_retryable_errors = (
942+
self.default_read_rows_retryable_errors: Sequence[type[Exception]] = (
910943
default_read_rows_retryable_errors or ()
911944
)
912-
self.default_mutate_rows_retryable_errors = (
945+
self.default_mutate_rows_retryable_errors: Sequence[type[Exception]] = (
913946
default_mutate_rows_retryable_errors or ()
914947
)
915-
self.default_retryable_errors = default_retryable_errors or ()
948+
self.default_retryable_errors: Sequence[type[Exception]] = (
949+
default_retryable_errors or ()
950+
)
916951

917952
try:
918953
self._register_instance_future = CrossSync.create_task(

0 commit comments

Comments
 (0)