Skip to content

Commit 3f70b2f

Browse files
authored
perf(spanner): optimize query result decoding (#17375)
Work in progress. Optimizes the decoding and reading of (large) result sets for Spanner. <img width="1643" height="943" alt="image" src="https://github.com/user-attachments/assets/81997b8d-f77f-4523-acb2-e44fdccb939b" />
1 parent b23bfa4 commit 3f70b2f

5 files changed

Lines changed: 292 additions & 74 deletions

File tree

packages/google-cloud-spanner/google/cloud/spanner_v1/_async/streamed.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,39 @@ def _merge_values(self, values):
129129
decoders = self._decoders
130130
width = len(self.fields)
131131
index = len(self._current_row)
132-
for value in values:
133-
if self._lazy_decode:
134-
self._current_row.append(value)
135-
else:
136-
self._current_row.append(_parse_nullable(value, decoders[index]))
137-
index += 1
138-
if index == width:
139-
self._rows.append(self._current_row)
140-
self._current_row = []
141-
index = 0
132+
current_row = self._current_row
133+
rows = self._rows
134+
135+
current_row_append = current_row.append
136+
rows_append = rows.append
137+
138+
if self._lazy_decode:
139+
for value in values:
140+
current_row_append(value)
141+
index += 1
142+
if index == width:
143+
rows_append(current_row)
144+
current_row = []
145+
current_row_append = current_row.append
146+
index = 0
147+
else:
148+
for value in values:
149+
# Note: We manually check value.HasField("null_value") here instead of
150+
# wrapping every decoder in _parse_nullable to avoid the overhead of
151+
# an extra Python function call layer for every cell value decoded in this loop.
152+
# If the nullable check logic is updated in _parse_nullable, update this check.
153+
if value.HasField("null_value"):
154+
current_row_append(None)
155+
else:
156+
current_row_append(decoders[index](value))
157+
index += 1
158+
if index == width:
159+
rows_append(current_row)
160+
current_row = []
161+
current_row_append = current_row.append
162+
index = 0
163+
164+
self._current_row = current_row
142165

143166
@CrossSync.convert
144167
async def _consume_next(self):

packages/google-cloud-spanner/google/cloud/spanner_v1/_helpers.py

Lines changed: 88 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
import decimal
2020
import logging
2121
import math
22+
import operator
2223
import threading
2324
import time
2425
import uuid
2526
from contextlib import contextmanager
2627

2728
from google.api_core import datetime_helpers
2829
from google.api_core.exceptions import Aborted
29-
from google.cloud._helpers import _date_from_iso8601_date
3030
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
3131
from google.protobuf.message import DecodeError, Message
3232
from google.protobuf.struct_pb2 import ListValue, Value
@@ -465,6 +465,12 @@ def _parse_value_pb(value_pb, field_type, field_name, column_info=None):
465465
return _parse_nullable(value_pb, decoder)
466466

467467

468+
_date_fromisoformat = datetime.date.fromisoformat
469+
_Decimal = decimal.Decimal
470+
_json_from_str = JsonObject.from_str
471+
_uuid_UUID = uuid.UUID
472+
473+
468474
def _get_type_decoder(field_type, field_name, column_info=None):
469475
"""Returns a function that converts a Value protobuf to cell data.
470476
@@ -489,28 +495,30 @@ def _get_type_decoder(field_type, field_name, column_info=None):
489495
"""
490496

491497
type_code = field_type.code
498+
# Note: STRING and BOOL use operator.attrgetter because direct attribute extraction
499+
# is faster in Python. Other types require type transformation, so they use lambdas.
492500
if type_code == TypeCode.STRING:
493-
return _parse_string
501+
return operator.attrgetter("string_value")
494502
elif type_code == TypeCode.BYTES:
495-
return _parse_bytes
503+
return lambda value_pb: value_pb.string_value.encode("utf8")
496504
elif type_code == TypeCode.BOOL:
497-
return _parse_bool
505+
return operator.attrgetter("bool_value")
498506
elif type_code == TypeCode.INT64:
499-
return _parse_int64
507+
return lambda value_pb: int(value_pb.string_value)
500508
elif type_code == TypeCode.FLOAT64:
501509
return _parse_float
502510
elif type_code == TypeCode.FLOAT32:
503511
return _parse_float
504512
elif type_code == TypeCode.DATE:
505-
return _parse_date
513+
return lambda value_pb: _date_fromisoformat(value_pb.string_value)
506514
elif type_code == TypeCode.TIMESTAMP:
507515
return _parse_timestamp
508516
elif type_code == TypeCode.NUMERIC:
509-
return _parse_numeric
517+
return lambda value_pb: _Decimal(value_pb.string_value)
510518
elif type_code == TypeCode.JSON:
511-
return _parse_json
519+
return lambda value_pb: _json_from_str(value_pb.string_value)
512520
elif type_code == TypeCode.UUID:
513-
return _parse_uuid
521+
return lambda value_pb: _uuid_UUID(value_pb.string_value)
514522
elif type_code == TypeCode.PROTO:
515523
return lambda value_pb: _parse_proto(value_pb, column_info, field_name)
516524
elif type_code == TypeCode.ENUM:
@@ -553,48 +561,81 @@ def _parse_list_value_pbs(rows, row_type):
553561
return result
554562

555563

556-
def _parse_string(value_pb) -> str:
557-
return value_pb.string_value
558-
559-
560-
def _parse_bytes(value_pb):
561-
return value_pb.string_value.encode("utf8")
562-
563-
564-
def _parse_bool(value_pb) -> bool:
565-
return value_pb.bool_value
566-
567-
568-
def _parse_int64(value_pb) -> int:
569-
return int(value_pb.string_value)
570-
571-
572564
def _parse_float(value_pb) -> float:
573-
if value_pb.HasField("string_value"):
574-
return float(value_pb.string_value)
575-
else:
576-
return value_pb.number_value
577-
578-
579-
def _parse_date(value_pb):
580-
return _date_from_iso8601_date(value_pb.string_value)
565+
# Note: Storing val = value_pb.string_value and doing a truthiness check is faster
566+
# than calling value_pb.HasField("string_value") because it avoids the C-extension
567+
# method lookup/call overhead and accesses the attribute only once.
568+
val = value_pb.string_value
569+
return float(val) if val else value_pb.number_value
570+
571+
572+
_POWERS_OF_10 = (
573+
1,
574+
10,
575+
100,
576+
1000,
577+
10000,
578+
100000,
579+
1000000,
580+
10000000,
581+
100000000,
582+
1000000000,
583+
)
581584

582585

583586
def _parse_timestamp(value_pb):
584-
DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
585-
return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
586-
587-
588-
def _parse_numeric(value_pb):
589-
return decimal.Decimal(value_pb.string_value)
590-
591-
592-
def _parse_json(value_pb):
593-
return JsonObject.from_str(value_pb.string_value)
594-
595-
596-
def _parse_uuid(value_pb):
597-
return uuid.UUID(value_pb.string_value)
587+
val = value_pb.string_value
588+
try:
589+
if len(val) < 20 or val[10] != "T":
590+
raise ValueError()
591+
no_fraction = val[:19]
592+
bare = datetime.datetime.fromisoformat(no_fraction)
593+
if val[19] == ".":
594+
if val.endswith("Z"):
595+
offset = "Z"
596+
fraction = val[20:-1]
597+
elif val[-6] in ("+", "-"):
598+
offset = val[-6:]
599+
fraction = val[20:-6]
600+
else:
601+
raise ValueError()
602+
if not fraction or len(fraction) > 9 or not fraction.isdigit():
603+
raise ValueError()
604+
scale = 9 - len(fraction)
605+
nanos = int(fraction) * _POWERS_OF_10[scale]
606+
else:
607+
nanos = 0
608+
if val.endswith("Z"):
609+
offset = "Z"
610+
elif val[-6] in ("+", "-"):
611+
offset = val[-6:]
612+
else:
613+
raise ValueError()
614+
615+
if offset != "Z":
616+
sign = offset[0]
617+
hours = int(offset[1:3])
618+
minutes = int(offset[4:6])
619+
if offset[3] != ":":
620+
raise ValueError()
621+
delta = datetime.timedelta(hours=hours, minutes=minutes)
622+
if sign == "-":
623+
delta = -delta
624+
tzinfo = datetime.timezone(delta)
625+
bare = bare.replace(tzinfo=tzinfo).astimezone(datetime.timezone.utc)
626+
627+
return datetime_helpers.DatetimeWithNanoseconds(
628+
bare.year,
629+
bare.month,
630+
bare.day,
631+
bare.hour,
632+
bare.minute,
633+
bare.second,
634+
nanosecond=nanos,
635+
tzinfo=datetime.timezone.utc,
636+
)
637+
except (IndexError, ValueError) as e:
638+
raise ValueError("Timestamp: {} does not match pattern".format(val)) from e
598639

599640

600641
def _parse_proto(value_pb, column_info, field_name):

packages/google-cloud-spanner/google/cloud/spanner_v1/data_types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def serialize(self):
9999
return json.dumps(self, sort_keys=True, separators=(",", ":"))
100100

101101

102+
_INTERVAL_PATTERN = re.compile(
103+
r"^P(-?\d+Y)?(-?\d+M)?(-?\d+D)?(T(-?\d+H)?(-?\d+M)?(-?((\d+([.,]\d{1,9})?)|([.,]\d{1,9}))S)?)?$"
104+
)
105+
106+
102107
@dataclass
103108
class Interval:
104109
"""Represents a Spanner INTERVAL type.
@@ -187,8 +192,7 @@ def __str__(self) -> str:
187192
@classmethod
188193
def from_str(cls, s: str) -> "Interval":
189194
"""Parse an ISO8601 duration format string into an Interval."""
190-
pattern = r"^P(-?\d+Y)?(-?\d+M)?(-?\d+D)?(T(-?\d+H)?(-?\d+M)?(-?((\d+([.,]\d{1,9})?)|([.,]\d{1,9}))S)?)?$"
191-
match = re.match(pattern, s)
195+
match = _INTERVAL_PATTERN.match(s)
192196
if not match or len(s) == 1:
193197
raise ValueError(f"Invalid interval format: {s}")
194198

packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ class StreamedResultSet(object):
3535
instances.
3636
3737
:type source: :class:`~google.cloud.spanner_v1.snapshot.Snapshot`
38-
:param source: Deprecated. Snapshot from which the result set was fetched.
39-
"""
38+
:param source: Deprecated. Snapshot from which the result set was fetched."""
4039

4140
def __init__(
4241
self,
@@ -117,16 +116,36 @@ def _merge_values(self, values):
117116
decoders = self._decoders
118117
width = len(self.fields)
119118
index = len(self._current_row)
120-
for value in values:
121-
if self._lazy_decode:
122-
self._current_row.append(value)
123-
else:
124-
self._current_row.append(_parse_nullable(value, decoders[index]))
125-
index += 1
126-
if index == width:
127-
self._rows.append(self._current_row)
128-
self._current_row = []
129-
index = 0
119+
current_row = self._current_row
120+
rows = self._rows
121+
current_row_append = current_row.append
122+
rows_append = rows.append
123+
if self._lazy_decode:
124+
for value in values:
125+
current_row_append(value)
126+
index += 1
127+
if index == width:
128+
rows_append(current_row)
129+
current_row = []
130+
current_row_append = current_row.append
131+
index = 0
132+
else:
133+
for value in values:
134+
# Note: We manually check value.HasField("null_value") here instead of
135+
# wrapping every decoder in _parse_nullable to avoid the overhead of
136+
# an extra Python function call layer for every cell value decoded in this loop.
137+
# If the nullable check logic is updated in _parse_nullable, update this check.
138+
if value.HasField("null_value"):
139+
current_row_append(None)
140+
else:
141+
current_row_append(decoders[index](value))
142+
index += 1
143+
if index == width:
144+
rows_append(current_row)
145+
current_row = []
146+
current_row_append = current_row.append
147+
index = 0
148+
self._current_row = current_row
130149

131150
def _consume_next(self):
132151
"""Consume the next partial result set from the stream.

0 commit comments

Comments
 (0)