Commit c85c31e

feat: Add support for Proto and Enum types
1 parent 451fd97 commit c85c31e

File tree: 17 files changed, +1320 -61 lines changed


google/cloud/bigtable/data/_async/client.py

Lines changed: 9 additions & 0 deletions

@@ -657,6 +657,7 @@ async def execute_query(
             DeadlineExceeded,
             ServiceUnavailable,
         ),
+        column_info: dict[str, Any] | None = None,
     ) -> "ExecuteQueryIteratorAsync":
         """
         Executes an SQL query on an instance.
@@ -705,6 +706,13 @@ async def execute_query(
                 If None, defaults to prepare_operation_timeout.
             prepare_retryable_errors: a list of errors that will be retried if encountered during prepareQuery.
                 Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
+            column_info: Optional dictionary mapping column names to custom objects used
+                to deserialize the corresponding column values. This is particularly
+                useful for data types like protobuf, where the deserialization logic
+                lives in user code. When a custom object is provided for a column, it is
+                used to deserialize the column data received from the backend. If not
+                provided, data remains serialized as bytes for Proto Messages and as
+                integers for Proto Enums.
         Returns:
             ExecuteQueryIteratorAsync: an asynchronous iterator that yields rows returned by the query
         Raises:
@@ -771,6 +779,7 @@ async def execute_query(
             attempt_timeout,
             operation_timeout,
             retryable_excs=retryable_excs,
+            column_info=column_info,
         )
 
     @CrossSync.convert(sync_name="__enter__")

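For context, a minimal usage sketch of the new parameter with the async client (the project, instance, table, and the user_pb2 module with its Song message are illustrative assumptions, not part of this commit):

# Hypothetical example: user_pb2 stands in for any user-compiled protobuf module.
import user_pb2
from google.cloud.bigtable.data import BigtableDataClientAsync


async def query_with_proto_column():
    async with BigtableDataClientAsync(project="my-project") as client:
        result = await client.execute_query(
            "SELECT id, song FROM my_table",
            instance_id="my-instance",
            # Map the column name to a default Message instance so the returned
            # bytes are deserialized into user_pb2.Song objects.
            column_info={"song": user_pb2.Song()},
        )
        async for row in result:
            # Without column_info, row["song"] would be raw serialized bytes.
            print(row["id"], row["song"])
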
google/cloud/bigtable/data/_sync_autogen/client.py

Lines changed: 9 additions & 0 deletions

@@ -485,6 +485,7 @@ def execute_query(
             DeadlineExceeded,
             ServiceUnavailable,
         ),
+        column_info: dict[str, Any] | None = None,
     ) -> "ExecuteQueryIterator":
         """Executes an SQL query on an instance.
         Returns an iterator to asynchronously stream back columns from selected rows.
@@ -532,6 +533,13 @@ def execute_query(
                 If None, defaults to prepare_operation_timeout.
             prepare_retryable_errors: a list of errors that will be retried if encountered during prepareQuery.
                 Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
+            column_info: Optional dictionary mapping column names to custom objects used
+                to deserialize the corresponding column values. This is particularly
+                useful for data types like protobuf, where the deserialization logic
+                lives in user code. When a custom object is provided for a column, it is
+                used to deserialize the column data received from the backend. If not
+                provided, data remains serialized as bytes for Proto Messages and as
+                integers for Proto Enums.
         Returns:
             ExecuteQueryIterator: an asynchronous iterator that yields rows returned by the query
         Raises:
@@ -592,6 +600,7 @@ def execute_query(
             attempt_timeout,
             operation_timeout,
             retryable_excs=retryable_excs,
+            column_info=column_info,
         )
 
     def __enter__(self):

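The synchronous client mirrors this; a sketch using an EnumTypeWrapper so enum columns come back as symbolic names rather than integers (the names and the user_pb2 module are again illustrative):

import user_pb2  # hypothetical compiled protobuf module exposing a Genre enum
from google.cloud.bigtable.data import BigtableDataClient

with BigtableDataClient(project="my-project") as client:
    result = client.execute_query(
        "SELECT id, genre FROM my_table",
        instance_id="my-instance",
        # An EnumTypeWrapper value makes enum columns resolve to their names.
        column_info={"genre": user_pb2.Genre},
    )
    for row in result:
        print(row["id"], row["genre"])  # e.g. "ROCK" instead of 3
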
google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py

Lines changed: 7 additions & 1 deletion

@@ -87,6 +87,7 @@ def __init__(
         operation_timeout: float,
         req_metadata: Sequence[Tuple[str, str]] = (),
         retryable_excs: Sequence[type[Exception]] = (),
+        column_info: Dict[str, Any] | None = None,
     ) -> None:
         """
         Collects responses from ExecuteQuery requests and parses them into QueryResultRows.
@@ -107,6 +108,8 @@ def __init__(
                 Failed requests will be retried within the budget
             req_metadata: metadata used while sending the gRPC request
             retryable_excs: a list of errors that will be retried if encountered.
+            column_info: dict with mappings between column names and additional column information
+                for protobuf deserialization.
         Raises:
             {NO_LOOP}
             :class:`ValueError <exceptions.ValueError>` as a safeguard if data is processed in an unexpected state
@@ -135,6 +138,7 @@ def __init__(
             exception_factory=_retry_exception_factory,
         )
         self._req_metadata = req_metadata
+        self._column_info = column_info
         try:
             self._register_instance_task = CrossSync.create_task(
                 self._client._register_instance,
@@ -202,7 +206,9 @@ async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]:
                 raise ValueError(
                     "Error parsing response before finalizing metadata"
                 )
-            results = self._reader.consume(batches_to_parse, self.metadata)
+            results = self._reader.consume(
+                batches_to_parse, self.metadata, self._column_info
+            )
             if results is None:
                 continue
 

google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py

Lines changed: 105 additions & 11 deletions

@@ -11,8 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 from typing import Any, Callable, Dict, Type
+
+from google.protobuf.message import Message
+from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
 from google.cloud.bigtable.data.execute_query.values import Struct
 from google.cloud.bigtable.data.execute_query.metadata import SqlType
 from google.cloud.bigtable_v2 import Value as PBValue
@@ -30,24 +34,36 @@
     SqlType.Struct: "array_value",
     SqlType.Array: "array_value",
     SqlType.Map: "array_value",
+    SqlType.Proto: "bytes_value",
+    SqlType.Enum: "int_value",
 }
 
 
-def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> Any:
+def _parse_array_type(
+    value: PBValue,
+    metadata_type: SqlType.Array,
+    column_name: str | None,
+    column_info: dict[str, Any] | None = None,
+) -> Any:
     """
     used for parsing an array represented as a protobuf to a python list.
     """
     return list(
         map(
             lambda val: _parse_pb_value_to_python_value(
-                val, metadata_type.element_type
+                val, metadata_type.element_type, column_name, column_info
             ),
             value.array_value.values,
         )
     )
 
 
-def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any:
+def _parse_map_type(
+    value: PBValue,
+    metadata_type: SqlType.Map,
+    column_name: str | None,
+    column_info: dict[str, Any] | None = None,
+) -> Any:
     """
     used for parsing a map represented as a protobuf to a python dict.
 
@@ -64,10 +80,16 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any:
             map(
                 lambda map_entry: (
                     _parse_pb_value_to_python_value(
-                        map_entry.array_value.values[0], metadata_type.key_type
+                        map_entry.array_value.values[0],
+                        metadata_type.key_type,
+                        f"{column_name}.key" if column_name is not None else None,
+                        column_info,
                     ),
                     _parse_pb_value_to_python_value(
-                        map_entry.array_value.values[1], metadata_type.value_type
+                        map_entry.array_value.values[1],
+                        metadata_type.value_type,
+                        f"{column_name}.value" if column_name is not None else None,
+                        column_info,
                     ),
                 ),
                 value.array_value.values,
@@ -77,7 +99,12 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any:
         raise ValueError("Invalid map entry - less or more than two values.")
 
 
-def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct:
+def _parse_struct_type(
+    value: PBValue,
+    metadata_type: SqlType.Struct,
+    column_name: str | None,
+    column_info: dict[str, Any] | None = None,
+) -> Struct:
     """
     used for parsing a struct represented as a protobuf to a
     google.cloud.bigtable.data.execute_query.Struct
@@ -88,29 +115,96 @@ def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct:
     struct = Struct()
     for value, field in zip(value.array_value.values, metadata_type.fields):
         field_name, field_type = field
-        struct.add_field(field_name, _parse_pb_value_to_python_value(value, field_type))
+        nested_column_name: str | None
+        if column_name is None:
+            nested_column_name = None
+        else:
+            # qualify the column name for nested lookups
+            nested_column_name = (
+                f"{column_name}.{field_name}" if field_name else column_name
+            )
+        struct.add_field(
+            field_name,
+            _parse_pb_value_to_python_value(
+                value, field_type, nested_column_name, column_info
+            ),
+        )
 
     return struct
 
 
 def _parse_timestamp_type(
-    value: PBValue, metadata_type: SqlType.Timestamp
+    value: PBValue,
+    metadata_type: SqlType.Timestamp,
+    column_name: str | None,
+    column_info: dict[str, Any] | None = None,
 ) -> DatetimeWithNanoseconds:
     """
     used for parsing a timestamp represented as a protobuf to DatetimeWithNanoseconds
     """
     return DatetimeWithNanoseconds.from_timestamp_pb(value.timestamp_value)
 
 
-_TYPE_PARSERS: Dict[Type[SqlType.Type], Callable[[PBValue, Any], Any]] = {
+def _parse_proto_type(
+    value: PBValue,
+    metadata_type: SqlType.Proto,
+    column_name: str | None,
+    column_info: dict[str, Any] | None = None,
+) -> Message | bytes:
+    """
+    Parses a serialized protobuf message into a Message object.
+    """
+    if (
+        column_name is not None
+        and column_info is not None
+        and column_info.get(column_name) is not None
+    ):
+        default_proto_message = column_info.get(column_name)
+        if isinstance(default_proto_message, Message):
+            proto_message = type(default_proto_message)()
+            proto_message.ParseFromString(value.bytes_value)
+            return proto_message
+    return value.bytes_value
 
 
+def _parse_enum_type(
+    value: PBValue,
+    metadata_type: SqlType.Enum,
+    column_name: str | None,
+    column_info: dict[str, Any] | None = None,
+) -> int | Any:
+    """
+    Parses an integer value into a Protobuf enum.
+    """
+    if (
+        column_name is not None
+        and column_info is not None
+        and column_info.get(column_name) is not None
+    ):
+        proto_enum = column_info.get(column_name)
+        if isinstance(proto_enum, EnumTypeWrapper):
+            return proto_enum.Name(value.int_value)
+    return value.int_value
+
+
+_TYPE_PARSERS: Dict[
+    Type[SqlType.Type], Callable[[PBValue, Any, str | None, dict[str, Any] | None], Any]
+] = {
     SqlType.Timestamp: _parse_timestamp_type,
     SqlType.Struct: _parse_struct_type,
     SqlType.Array: _parse_array_type,
     SqlType.Map: _parse_map_type,
+    SqlType.Proto: _parse_proto_type,
+    SqlType.Enum: _parse_enum_type,
 }
 
 
-def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type) -> Any:
+def _parse_pb_value_to_python_value(
+    value: PBValue,
+    metadata_type: SqlType.Type,
+    column_name: str | None,
+    column_info: dict[str, Any] | None = None,
+) -> Any:
     """
     used for converting the value represented as a protobufs to a python object.
     """
@@ -126,7 +220,7 @@ def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type)
 
     if kind in _TYPE_PARSERS:
         parser = _TYPE_PARSERS[kind]
-        return parser(value, metadata_type)
+        return parser(value, metadata_type, column_name, column_info)
     elif kind in _REQUIRED_PROTO_FIELDS:
         field_name = _REQUIRED_PROTO_FIELDS[kind]
         return getattr(value, field_name)

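A self-contained sketch of the fallback behaviour of the two new parsers, using well-known protobuf types instead of user-compiled ones (the column names are made up; metadata_type is not consulted by either parser, so None is passed for brevity):

from google.protobuf import struct_pb2, timestamp_pb2
from google.cloud.bigtable_v2 import Value as PBValue
from google.cloud.bigtable.data.execute_query._query_result_parsing_utils import (
    _parse_enum_type,
    _parse_proto_type,
)

ts = timestamp_pb2.Timestamp(seconds=1700000000)
proto_value = PBValue(bytes_value=ts.SerializeToString())
enum_value = PBValue(int_value=0)

# With a Message instance registered for the column, the bytes are deserialized...
parsed = _parse_proto_type(proto_value, None, "ts_col", {"ts_col": timestamp_pb2.Timestamp()})
assert parsed.seconds == 1700000000

# ...and an EnumTypeWrapper maps the integer to its symbolic name.
assert _parse_enum_type(enum_value, None, "null_col", {"null_col": struct_pb2.NullValue}) == "NULL_VALUE"

# Without a matching column_info entry, the raw wire values pass through unchanged.
assert _parse_proto_type(proto_value, None, "ts_col", None) == ts.SerializeToString()
assert _parse_enum_type(enum_value, None, "null_col", None) == 0
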
google/cloud/bigtable/data/execute_query/_reader.py

Lines changed: 24 additions & 5 deletions

@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 from typing import (
+    Any,
     List,
     TypeVar,
     Generic,
@@ -54,7 +56,10 @@ class _Reader(ABC, Generic[T]):
 
     @abstractmethod
     def consume(
-        self, batches_to_consume: List[bytes], metadata: Metadata
+        self,
+        batches_to_consume: List[bytes],
+        metadata: Metadata,
+        column_info: dict[str, Any] | None = None,
     ) -> Optional[Iterable[T]]:
         """This method receives a list of batches of bytes to be parsed as ProtoRows messages.
         It then uses the metadata to group the values in the parsed messages into rows. Returns
@@ -64,6 +69,8 @@ def consume(
                 :meth:`google.cloud.bigtable.byte_cursor._ByteCursor.consume`
                 method.
             metadata: metadata used to transform values to rows
+            column_info: (Optional) dict with mappings between column names and additional column information
+                for protobuf deserialization.
 
         Returns:
             Iterable[T] or None: Iterable if gathered values can form one or more instances of T,
@@ -89,7 +96,10 @@ def _parse_proto_rows(self, bytes_to_parse: bytes) -> Iterable[PBValue]:
         return proto_rows.values
 
     def _construct_query_result_row(
-        self, values: Sequence[PBValue], metadata: Metadata
+        self,
+        values: Sequence[PBValue],
+        metadata: Metadata,
+        column_info: dict[str, Any] | None = None,
    ) -> QueryResultRow:
         result = QueryResultRow()
         columns = metadata.columns
@@ -99,20 +109,29 @@ def _construct_query_result_row(
         ), "This function should be called only when count of values matches count of columns."
 
         for column, value in zip(columns, values):
-            parsed_value = _parse_pb_value_to_python_value(value, column.column_type)
+            parsed_value = _parse_pb_value_to_python_value(
+                value, column.column_type, column.column_name, column_info
+            )
             result.add_field(column.column_name, parsed_value)
         return result
 
     def consume(
-        self, batches_to_consume: List[bytes], metadata: Metadata
+        self,
+        batches_to_consume: List[bytes],
+        metadata: Metadata,
+        column_info: dict[str, Any] | None = None,
     ) -> Optional[Iterable[QueryResultRow]]:
         num_columns = len(metadata.columns)
         rows = []
         for batch_bytes in batches_to_consume:
            values = self._parse_proto_rows(batch_bytes)
            for row_data in batched(values, n=num_columns):
                if len(row_data) == num_columns:
-                    rows.append(self._construct_query_result_row(row_data, metadata))
+                    rows.append(
+                        self._construct_query_result_row(
+                            row_data, metadata, column_info
+                        )
+                    )
                else:
                    raise ValueError(
                        "Unexpected error, recieved bad number of values. "

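For reference, a rough sketch of what consume() receives (the values are illustrative): each batch is a serialized ProtoRows message whose flat value list is grouped into rows of num_columns cells, and column_info is forwarded to _parse_pb_value_to_python_value for every cell.

from google.cloud.bigtable_v2 import ProtoRows, Value as PBValue

# One batch holding two rows of a two-column result set.
batch = ProtoRows(
    values=[
        PBValue(string_value="alice"), PBValue(int_value=1),
        PBValue(string_value="bob"), PBValue(int_value=2),
    ]
).SerializeToString()
# The concrete reader's consume([batch], metadata, column_info) parses the batch,
# groups the four values into two rows of two cells each, and converts every cell
# via _parse_pb_value_to_python_value(value, column_type, column_name, column_info).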