from dataclasses import Field
from dataclasses import fields
from typing import IO
from typing import Literal
from typing import TypeVar
from typing import assert_never
from typing import overload
from kio._utils import cache
from kio.static.protocol import Entity
from . import readers
from ._implicit_defaults import get_tagged_field_default
from ._introspect import EntityField
from ._introspect import EntityTupleField
from ._introspect import PrimitiveField
from ._introspect import PrimitiveTupleField
from ._introspect import classify_field
from ._introspect import get_field_tag
from ._introspect import get_schema_field_type
from ._introspect import is_optional
from ._shared import NullableEntityMarker
from .readers import read_int8
def get_reader(
kafka_type: str,
flexible: bool,
optional: bool,
) -> readers.Reader:
match (kafka_type, flexible, optional):
case ("int8", _, False):
return readers.read_int8
case ("int16", _, False):
return readers.read_int16
case ("int32", _, False):
return readers.read_int32
case ("int64", _, False):
return readers.read_int64
case ("uint8", _, False):
return readers.read_uint8
case ("uint16", _, False):
return readers.read_uint16
case ("uint32", _, False):
return readers.read_uint32
case ("uint64", _, False):
return readers.read_uint64
case ("float64", _, False):
return readers.read_float64
case ("string", True, False):
return readers.read_compact_string
case ("string", True, True):
return readers.read_compact_string_nullable
case ("string", False, False):
return readers.read_legacy_string
case ("string", False, True):
return readers.read_nullable_legacy_string
case ("bytes" | "records", True, False):
return readers.read_compact_string_as_bytes
case ("bytes" | "records", True, True):
return readers.read_compact_string_as_bytes_nullable
case ("bytes" | "records", False, False):
return readers.read_legacy_bytes
case ("bytes" | "records", False, True):
return readers.read_nullable_legacy_bytes
case ("uuid", _, _):
return readers.read_uuid
case ("bool", _, False):
return readers.read_boolean
case ("error_code", _, False):
return readers.read_error_code
case ("timedelta_i32", _, False):
return readers.read_timedelta_i32
case ("timedelta_i64", _, False):
return readers.read_timedelta_i64
case ("datetime_i64", _, False):
return readers.read_datetime_i64
case ("datetime_i64", _, True):
return readers.read_nullable_datetime_i64
raise NotImplementedError(
f"Failed identifying reader for {kafka_type!r} field {flexible=} {optional=}"
)
T = TypeVar("T")
def get_field_reader(
entity_type: type[Entity],
field: Field[T],
is_request_header: bool,
is_tagged_field: bool,
) -> readers.Reader[T]:
# RequestHeader.client_id is special-cased by Apache Kafka® to always use the legacy
# string format.
# https://github.com/apache/kafka/blob/trunk/clients/src/main/resources/common/message/RequestHeader.json#L34-L38
if is_request_header and field.name == "client_id":
return readers.read_nullable_legacy_string # type: ignore[return-value]
flexible = entity_type.__flexible__
field_class = classify_field(field)
match field_class:
case PrimitiveField():
inner_type_reader = get_reader(
kafka_type=get_schema_field_type(field),
flexible=flexible,
optional=is_optional(field) and not is_tagged_field,
)
case PrimitiveTupleField():
inner_type_reader = get_reader(
kafka_type=get_schema_field_type(field),
flexible=flexible,
optional=is_optional(field),
)
case EntityField(field_type):
inner_type_reader = (
entity_reader(field_type, nullable=True)
if is_optional(field)
else entity_reader(field_type, nullable=False)
)
case EntityTupleField(field_type):
inner_type_reader = entity_reader(field_type)
case no_match:
assert_never(no_match)
if field_class.is_array:
array_reader = (
readers.compact_array_reader if flexible else readers.legacy_array_reader
)
# mypy fails to bind T to Sequence[object] here.
return array_reader(inner_type_reader) # type: ignore[return-value]
return inner_type_reader
E = TypeVar("E", bound=Entity)
@overload
def entity_reader(
entity_type: type[E],
nullable: Literal[False] = ...,
) -> readers.Reader[E]: ...
@overload
def entity_reader(
entity_type: type[E],
nullable: Literal[True],
) -> readers.Reader[E | None]: ...
[docs]
@cache
def entity_reader(
entity_type: type[E],
nullable: bool = False,
) -> readers.Reader[E | None]:
field_readers = {}
tagged_field_readers = {}
is_request_header = entity_type.__name__ == "RequestHeader"
for field in fields(entity_type):
tag = get_field_tag(field)
field_reader = get_field_reader(
entity_type=entity_type,
field=field,
is_request_header=is_request_header,
is_tagged_field=tag is not None,
)
if tag is not None:
tagged_field_readers[tag] = (
field,
field_reader,
get_tagged_field_default(field),
)
else:
field_readers[field] = field_reader
# Assert we don't find tags for non-flexible models.
if tagged_field_readers and not entity_type.__flexible__:
raise ValueError("Found tagged fields on a non-flexible model")
def read_entity(buffer: IO[bytes]) -> E:
# Read regular fields.
kwargs = {
field.name: field_reader(buffer)
for field, field_reader in field_readers.items()
}
# For non-flexible entities we're done here.
if not entity_type.__flexible__:
return entity_type(**kwargs)
# Read tagged fields.
tagged_field_values = {}
num_tagged_fields = readers.read_unsigned_varint(buffer)
for _ in range(num_tagged_fields):
field_tag = readers.read_unsigned_varint(buffer)
readers.read_unsigned_varint(buffer) # field length
field, field_reader, _ = tagged_field_readers[field_tag]
tagged_field_values[field.name] = field_reader(buffer)
# Resolve tagged field implicit defaults.
for field, _, implicit_default in tagged_field_readers.values():
kwargs[field.name] = tagged_field_values.get(field.name, implicit_default)
return entity_type(**kwargs)
if not nullable:
return read_entity
# This is undocumented behavior, formalized in KIP-893.
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-893%3A+The+Kafka+protocol+should+support+nullable+structs
def read_nullable_entity(buffer: IO[bytes]) -> E | None:
marker = NullableEntityMarker(read_int8(buffer))
return None if marker is NullableEntityMarker.null else read_entity(buffer)
return read_nullable_entity