Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/anthropic/lib/streaming/_beta_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,23 @@
from ..._utils._utils import is_given
from .._parse._response import ResponseFormatT, parse_text
from ...types.beta.parsed_beta_message import ParsedBetaMessage, ParsedBetaContentBlock

from ...types.beta.beta_raw_message_start_event import BetaRawMessageStartEvent
from ...types.beta.beta_raw_message_delta_event import BetaRawMessageDeltaEvent
from ...types.beta.beta_raw_message_stop_event import BetaRawMessageStopEvent
from ...types.beta.beta_raw_content_block_start_event import BetaRawContentBlockStartEvent
from ...types.beta.beta_raw_content_block_delta_event import BetaRawContentBlockDeltaEvent
from ...types.beta.beta_raw_content_block_stop_event import BetaRawContentBlockStopEvent



_BETA_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
"message_start": BetaRawMessageStartEvent,
"message_delta": BetaRawMessageDeltaEvent,
"message_stop": BetaRawMessageStopEvent,
"content_block_start": BetaRawContentBlockStartEvent,
"content_block_delta": BetaRawContentBlockDeltaEvent,
"content_block_stop": BetaRawContentBlockStopEvent,
}

class BetaMessageStream(Generic[ResponseFormatT]):
text_stream: Iterator[str]
Expand Down
31 changes: 30 additions & 1 deletion src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@
from ..._utils._utils import is_given
from .._parse._response import ResponseFormatT, parse_text
from ...types.parsed_message import ParsedMessage, ParsedContentBlock
from ...types.raw_message_start_event import RawMessageStartEvent
from ...types.raw_message_delta_event import RawMessageDeltaEvent
from ...types.raw_message_stop_event import RawMessageStopEvent
from ...types.raw_content_block_start_event import RawContentBlockStartEvent
from ...types.raw_content_block_delta_event import RawContentBlockDeltaEvent
from ...types.raw_content_block_stop_event import RawContentBlockStopEvent



_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
"message_start": RawMessageStartEvent,
"message_delta": RawMessageDeltaEvent,
"message_stop": RawMessageStopEvent,
"content_block_start": RawContentBlockStartEvent,
"content_block_delta": RawContentBlockDeltaEvent,
"content_block_stop": RawContentBlockStopEvent,
}


class MessageStream(Generic[ResponseFormatT]):
Expand Down Expand Up @@ -445,7 +462,19 @@ def accumulate_event(
),
)
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")
# Union discriminator deserialization silently returned the raw dict in some
# environments (e.g. older pydantic versions). Fall back to a direct type-map
# lookup using the 'type' field so that well-formed events (including
# content_block_delta) are always promoted to the correct BaseModel. See #941.
raw = cast(Any, event)
if isinstance(raw, dict):
event_type = raw.get("type")
target_cls = _RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None
if target_cls is not None:
event = cast(RawMessageStreamEvent, target_cls.model_construct(**raw))

if not isinstance(cast(Any, event), BaseModel):
raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")

if current_snapshot is None:
if event.type == "message_start":
Expand Down