diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go index ef97e9232..ae9fbda9d 100644 --- a/pkg/hooks/types.go +++ b/pkg/hooks/types.go @@ -148,6 +148,17 @@ type Input struct { // closure (e.g. response cache). AgentName string `json:"agent_name,omitempty"` + // ModelID identifies the model the runtime is about to call (for + // before_llm_call) or just called (for after_llm_call), in the + // canonical "/" form expected by + // [modelsdev.Store.GetModel]. Populated from the loop's resolved + // model so it reflects per-tool model overrides and alloy-mode + // random selection — do NOT call Agent.Model() from a hook to + // recompute it, since alloy mode would re-randomize and a per-tool + // override would be invisible. Empty for events that aren't + // model-call-scoped. + ModelID string `json:"model_id,omitempty"` + // LastUserMessage is the text content of the latest user message in // the session at dispatch time. Populated for events that respond to // a user turn (stop, after_llm_call). Empty for events that aren't diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 0a1fad321..097f155d1 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -284,9 +284,16 @@ func (r *LocalRuntime) executeOnToolApprovalDecisionHooks( // / exit 2) stops the run loop — see [hooks.EventBeforeLLMCall] for // the contract. Hooks that just want to contribute system messages // should target turn_start instead. -func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent) (stop bool, message string) { +// +// modelID is the canonical model identifier the loop has just +// resolved (after per-tool overrides and alloy-mode selection); it's +// surfaced to hooks via [hooks.Input.ModelID] so handlers don't need +// to recompute it from the agent. +func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, modelID string) (stop bool, message string) { result := r.dispatchHook(ctx, a, hooks.EventBeforeLLMCall, &hooks.Input{ SessionID: sess.ID, + AgentName: a.Name(), + ModelID: modelID, }, nil) if result == nil || result.Allowed { return false, "" diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index aff218800..7c1120e55 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -375,17 +375,10 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, messages := sess.GetMessages(a, slices.Concat(sessionStartMsgs, userPromptMsgs, turnStartMsgs)...) slog.Debug("Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages)) - // Strip image content from messages if the model doesn't support image input. - // This prevents API errors when conversation history contains images (e.g. from - // tool results or user attachments) but the current model is text-only. - if m != nil && len(m.Modalities.Input) > 0 && !slices.Contains(m.Modalities.Input, "image") { - messages = stripImageContent(messages) - } - // before_llm_call hooks fire just before the model is invoked. // A terminating verdict (e.g. from the max_iterations builtin) // stops the run loop here, before any tokens are spent. - if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a); stop { + if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID); stop { slog.Warn("before_llm_call hook signalled run termination", "agent", a.Name(), "session_id", sess.ID, "reason", msg) r.emitHookDrivenShutdown(ctx, a, sess, msg, events) @@ -393,6 +386,16 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, return } + // Apply registered before_llm_call message transforms (e.g. + // strip_unsupported_modalities for text-only models, plus any + // embedder-supplied redactor / scrubber registered via + // WithMessageTransform). Runs after the gate so a transform + // failure cannot waste the gate's allow verdict. modelID is + // passed explicitly so transforms see the actual model the + // loop chose (per-tool override + alloy-mode selection), + // not whatever a fresh agent.Model() call would re-randomize. + messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, modelID, messages) + // Try primary model with fallback chain if configured res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events) if err != nil { diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 222429b52..323c44d5e 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -161,6 +161,13 @@ type LocalRuntime struct { // construction, so no locking is needed. hooksExecByAgent map[string]*hooks.Executor + // transforms is the runtime's [MessageTransform] chain, applied to + // every LLM call in registration order. Populated by + // [NewLocalRuntime] (for the runtime-shipped strip transform) and by + // [WithMessageTransform] (for embedder-supplied transforms). + // Read-only after construction. + transforms []registeredTransform + fallback *fallbackExecutor // observers receive every event the runtime produces, in @@ -392,6 +399,17 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err) } + // strip_unsupported_modalities is the runtime-shipped + // before_llm_call message transform that drops image content from + // messages when the agent's model is text-only. Like + // cache_response it captures the runtime closure (to resolve the + // agent and its model from Input.AgentName) and is therefore + // registered here rather than in pkg/hooks/builtins. + r.transforms = append(r.transforms, registeredTransform{ + name: BuiltinStripUnsupportedModalities, + fn: r.stripUnsupportedModalitiesTransform, + }) + for _, opt := range opts { opt(r) } diff --git a/pkg/runtime/streaming.go b/pkg/runtime/streaming.go index d3635ba82..937fdb227 100644 --- a/pkg/runtime/streaming.go +++ b/pkg/runtime/streaming.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "log/slog" "strings" "github.com/docker/docker-agent/pkg/agent" @@ -236,39 +235,3 @@ func handleStream(ctx context.Context, stream chat.MessageStream, a *agent.Agent Usage: messageUsage, }, nil } - -// stripImageContent returns a copy of messages with all image-related content -// removed. This is used when the target model doesn't support image input to -// prevent API errors. Text content is preserved; image parts in MultiContent -// are filtered out, and file attachments with image MIME types are dropped. -func stripImageContent(messages []chat.Message) []chat.Message { - result := make([]chat.Message, len(messages)) - for i, msg := range messages { - result[i] = msg - - if len(msg.MultiContent) == 0 { - continue - } - - var filtered []chat.MessagePart - for _, part := range msg.MultiContent { - switch part.Type { - case chat.MessagePartTypeImageURL: - // Drop image URL parts entirely. - continue - case chat.MessagePartTypeFile: - // Drop file parts that are images. - if part.File != nil && chat.IsImageMimeType(part.File.MimeType) { - continue - } - } - filtered = append(filtered, part) - } - - if len(filtered) != len(msg.MultiContent) { - result[i].MultiContent = filtered - slog.Debug("Stripped image content from message", "role", msg.Role, "original_parts", len(msg.MultiContent), "remaining_parts", len(filtered)) - } - } - return result -} diff --git a/pkg/runtime/strip_modalities.go b/pkg/runtime/strip_modalities.go new file mode 100644 index 000000000..611b568ae --- /dev/null +++ b/pkg/runtime/strip_modalities.go @@ -0,0 +1,115 @@ +package runtime + +import ( + "context" + "log/slog" + "slices" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/hooks" +) + +// BuiltinStripUnsupportedModalities is the name of the runtime-shipped +// before_llm_call message transform that drops image content from the +// outgoing messages when the agent's current model doesn't list image +// in its input modalities. It's the runtime-shipped peer of +// [BuiltinCacheResponse] (a stop hook) — the constant exists mostly +// for log filtering and diagnostics. +// +// Sending images to a text-only model produces hard provider errors +// (HTTP 400 from OpenAI, "image input is not supported" from +// Anthropic text variants, etc.); promoting the strip into a +// registered transform replaces an inline branch in runStreamLoop and +// opens the door to a family of message-mutating transforms +// (redactors, scrubbers, ...). +const BuiltinStripUnsupportedModalities = "strip_unsupported_modalities" + +// modalityImage is the canonical models.dev modality name for image +// input. A constant instead of a literal so a typo trips a compile +// error and the contract with [modelsdev.Modalities.Input] is +// discoverable from the runtime side. +const modalityImage = "image" + +// stripUnsupportedModalitiesTransform is the [MessageTransform] +// registered under [BuiltinStripUnsupportedModalities]. It looks up +// the model definition from [hooks.Input.ModelID] (populated by the +// runtime with the actual model the loop chose, including per-tool +// overrides and alloy-mode selection) and applies +// [stripImageContent] when image is missing from the model's input +// modalities. +// +// The transform is a no-op for every "we don't know enough to act" +// case (missing ModelID, models.dev miss, empty modalities, image +// already supported): erring on the side of "send the messages +// as-is" matches the previous inline behavior in runStreamLoop, +// where an unknown model also fell through. Each fall-through emits +// a Debug log so operators can tell strip_unsupported_modalities +// from a transform that's silently inactive. +func (r *LocalRuntime) stripUnsupportedModalitiesTransform( + ctx context.Context, + in *hooks.Input, + msgs []chat.Message, +) ([]chat.Message, error) { + if in == nil || in.ModelID == "" { + slog.Debug("strip_unsupported_modalities: skipping, no ModelID on input") + return msgs, nil + } + m, err := r.modelsStore.GetModel(ctx, in.ModelID) + if err != nil || m == nil { + // Unknown model: keep the previous (inline) behavior of + // passing messages through untouched. The model call will + // surface any modality mismatch as a provider error. + slog.Debug("strip_unsupported_modalities: skipping, model definition unavailable", + "model_id", in.ModelID, "error", err) + return msgs, nil + } + if len(m.Modalities.Input) == 0 || slices.Contains(m.Modalities.Input, modalityImage) { + return msgs, nil + } + return stripImageContent(msgs), nil +} + +// stripImageContent returns a copy of messages with all image-related +// content removed. Text content is preserved; image parts in +// [chat.Message.MultiContent] are filtered out, and file attachments +// with image MIME types are dropped. +// +// Lives next to [stripUnsupportedModalitiesTransform] (rather than in +// streaming.go where it originated) so the builtin's storage, +// transform, and helper are co-located. Kept as an unexported helper +// because the only legitimate caller is the transform itself — direct +// use bypasses the modality check. +func stripImageContent(messages []chat.Message) []chat.Message { + result := make([]chat.Message, len(messages)) + for i, msg := range messages { + result[i] = msg + + if len(msg.MultiContent) == 0 { + continue + } + + var filtered []chat.MessagePart + for _, part := range msg.MultiContent { + switch part.Type { + case chat.MessagePartTypeImageURL: + // Drop image URL parts entirely. + continue + case chat.MessagePartTypeFile: + // Drop file parts that are images. + if part.File != nil && chat.IsImageMimeType(part.File.MimeType) { + continue + } + } + filtered = append(filtered, part) + } + + if len(filtered) != len(msg.MultiContent) { + result[i].MultiContent = filtered + slog.Debug("Stripped image content from message", + "role", msg.Role, + "original_parts", len(msg.MultiContent), + "remaining_parts", len(filtered)) + } + } + return result +} diff --git a/pkg/runtime/transforms.go b/pkg/runtime/transforms.go new file mode 100644 index 000000000..fd14d87f1 --- /dev/null +++ b/pkg/runtime/transforms.go @@ -0,0 +1,101 @@ +package runtime + +import ( + "context" + "log/slog" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/session" +) + +// MessageTransform is the in-process-only handler signature for a +// before_llm_call transform that rewrites the chat messages about to +// be sent to the model. It receives the full message slice in chain +// order and returns the (possibly-rewritten) replacement. +// +// Transforms are intentionally a runtime-private contract: the cost of +// JSON-roundtripping a full conversation through the cross-process +// hook protocol would be prohibitive, so command and model hooks +// cannot rewrite messages. Embedders register transforms via +// [WithMessageTransform]; the runtime ships +// [BuiltinStripUnsupportedModalities] out of the box. +// +// Transforms run AFTER the standard before_llm_call gate (see +// [LocalRuntime.executeBeforeLLMCallHooks]) — a hook that wants to +// abort the call should target the gate, not a transform. +// +// Returning a non-nil error logs a warning and falls through to the +// previous message slice; a transform failure must never break the +// run loop. +type MessageTransform func(ctx context.Context, in *hooks.Input, msgs []chat.Message) ([]chat.Message, error) + +// registeredTransform pairs a [MessageTransform] with the name it was +// registered under. The name is purely diagnostic — it shows up in +// slog records when a transform errors out — so re-registering the +// same name simply appends another entry without any de-duplication. +type registeredTransform struct { + name string + fn MessageTransform +} + +// WithMessageTransform registers a [MessageTransform] under name so +// it is applied to every LLM call, in registration order, after the +// before_llm_call gate. Transforms are runtime-global: per-agent +// scoping (if needed) lives in the transform body, where +// [hooks.Input.AgentName] is available — the runtime-shipped strip +// transform is an example. +// +// Empty name or nil fn are silently ignored, matching the no-error +// shape of the other [Opt] helpers. +func WithMessageTransform(name string, fn MessageTransform) Opt { + return func(r *LocalRuntime) { + if name == "" || fn == nil { + slog.Warn("Ignoring message transform with empty name or nil fn", "name", name) + return + } + r.transforms = append(r.transforms, registeredTransform{name: name, fn: fn}) + } +} + +// applyBeforeLLMCallTransforms runs every registered +// [MessageTransform] in chain order, just before the model call and +// AFTER [LocalRuntime.executeBeforeLLMCallHooks] has approved it. +// Errors from individual transforms are logged at warn level and the +// chain continues with the previous slice — a transform failure must +// never break the run loop. +// +// modelID is the canonical model identifier the loop has just +// resolved (after per-tool overrides and alloy-mode selection); +// transforms read it via [hooks.Input.ModelID]. Calling +// agent.Model() from a transform would re-randomize the alloy pick +// and miss the per-tool override. +func (r *LocalRuntime) applyBeforeLLMCallTransforms( + ctx context.Context, + sess *session.Session, + a *agent.Agent, + modelID string, + msgs []chat.Message, +) []chat.Message { + if len(r.transforms) == 0 { + return msgs + } + in := &hooks.Input{ + SessionID: sess.ID, + AgentName: a.Name(), + ModelID: modelID, + HookEventName: hooks.EventBeforeLLMCall, + Cwd: r.workingDir, + } + for _, t := range r.transforms { + out, err := t.fn(ctx, in, msgs) + if err != nil { + slog.Warn("Message transform failed; continuing with previous messages", + "transform", t.name, "agent", a.Name(), "error", err) + continue + } + msgs = out + } + return msgs +} diff --git a/pkg/runtime/transforms_test.go b/pkg/runtime/transforms_test.go new file mode 100644 index 000000000..fc6c193ce --- /dev/null +++ b/pkg/runtime/transforms_test.go @@ -0,0 +1,365 @@ +// TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap covers the hot +package runtime + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" +) + +// modalityModelStore returns a fixed [modelsdev.Model] regardless of +// the requested ID. Tests configure its Modalities to exercise the +// strip_unsupported_modalities transform's three branches: text-only +// (strip), image-supporting (no-op), and unknown-model (no-op). +type modalityModelStore struct { + ModelStore + + model *modelsdev.Model + err error +} + +func (m modalityModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { + return m.model, m.err +} + +// modalityByIDStore returns a different [modelsdev.Model] depending +// on the requested ID, letting tests prove the transform consulted +// the right ID (via [hooks.Input.ModelID]) rather than recomputing +// it from the agent. +type modalityByIDStore struct { + ModelStore + + models map[string]*modelsdev.Model +} + +func (m modalityByIDStore) GetModel(_ context.Context, id string) (*modelsdev.Model, error) { + return m.models[id], nil +} + +// recordingMsgProvider captures the messages each model call sees so +// a test can confirm a transform actually rewrote what reached the +// provider (rather than just what the in-memory slice ended up +// looking like). +type recordingMsgProvider struct { + mockProvider + + got [][]chat.Message +} + +func (p *recordingMsgProvider) CreateChatCompletionStream(_ context.Context, msgs []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { + p.got = append(p.got, append([]chat.Message{}, msgs...)) + return p.stream, nil +} + +// TestStripUnsupportedModalitiesTransform pins the three branches of +// the runtime-shipped transform: a text-only model strips images, a +// multimodal model passes them through, and an unknown model also +// passes them through (the call surfaces any modality mismatch as a +// provider error rather than panicking transform-side). +func TestStripUnsupportedModalitiesTransform(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + imgMsg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "look at this"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + }, + } + + cases := []struct { + name string + store modalityModelStore + modelID string + wantStrip bool + }{ + {name: "text-only model strips images", modelID: "test/text", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text"}}}}, wantStrip: true}, + {name: "multimodal model passes through", modelID: "test/multimodal", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}}}, + {name: "nil model passes through", modelID: "test/unknown", store: modalityModelStore{model: nil}}, + {name: "lookup error passes through", modelID: "test/unknown", store: modalityModelStore{err: errors.New("not found")}}, + {name: "empty modalities passes through", modelID: "test/empty", store: modalityModelStore{model: &modelsdev.Model{}}}, + {name: "empty ModelID passes through", modelID: "", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text"}}}}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, err := NewLocalRuntime(tm, WithModelStore(tc.store)) + require.NoError(t, err) + + got, err := r.stripUnsupportedModalitiesTransform(t.Context(), + &hooks.Input{ModelID: tc.modelID}, []chat.Message{imgMsg}) + require.NoError(t, err) + require.Len(t, got, 1) + if tc.wantStrip { + require.Len(t, got[0].MultiContent, 1, "image part must be stripped") + assert.Equal(t, chat.MessagePartTypeText, got[0].MultiContent[0].Type) + } else { + assert.Equal(t, imgMsg, got[0], "messages must reach the model untouched") + } + }) + } +} + +// TestStripUnsupportedModalitiesTransform_UsesInputModelID pins the +// fix for an alloy-mode / per-tool-override correctness bug: the +// transform must trust [hooks.Input.ModelID] (populated by the loop +// with the model it actually picked) and NOT recompute the model by +// calling agent.Model() — doing so would re-randomize the alloy +// pick and miss any per-tool override the loop had applied. +// +// The test wires a store that reports text-only for one ID and +// multimodal for another. Querying by the text-only ID must strip; +// querying by the multimodal ID must pass through. The agent's own +// model (its pool) is irrelevant — it's never consulted. +func TestStripUnsupportedModalitiesTransform_UsesInputModelID(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/agent-pool-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + store := modalityByIDStore{models: map[string]*modelsdev.Model{ + "text/only": {Modalities: modelsdev.Modalities{Input: []string{"text"}}}, + "multi/modal": {Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}, + "test/agent-pool-model": {Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}, + }} + r, err := NewLocalRuntime(tm, WithModelStore(store)) + require.NoError(t, err) + + imgMsg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "describe"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + }, + } + + // ModelID = text-only — strip must happen even though the agent's + // pool model is multimodal. + stripped, err := r.stripUnsupportedModalitiesTransform(t.Context(), + &hooks.Input{ModelID: "text/only"}, []chat.Message{imgMsg}) + require.NoError(t, err) + require.Len(t, stripped[0].MultiContent, 1, "image must be stripped when ModelID is text-only") + assert.Equal(t, chat.MessagePartTypeText, stripped[0].MultiContent[0].Type) + + // ModelID = multimodal — strip must NOT happen even if some other + // model in scope is text-only. Proves the lookup keys off ModelID. + passed, err := r.stripUnsupportedModalitiesTransform(t.Context(), + &hooks.Input{ModelID: "multi/modal"}, []chat.Message{imgMsg}) + require.NoError(t, err) + assert.Equal(t, imgMsg, passed[0], "images must reach a multimodal ModelID untouched") +} + +// path: a runtime with no registered transforms returns the input +// slice as-is without allocating a [hooks.Input]. +func TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + // Drop the runtime-shipped strip transform so we can observe the + // cheap-path behavior. + r.transforms = nil + + sess := session.New(session.WithUserMessage("hi")) + msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} + + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, "", msgs) + assert.Equal(t, msgs, got) +} + +// TestApplyBeforeLLMCallTransforms_OrderAndChain verifies that +// transforms registered via [WithMessageTransform] run in +// registration order and feed each transform the cumulative output of +// the previous one (chain semantics, not parallel). +func TestApplyBeforeLLMCallTransforms_OrderAndChain(t *testing.T) { + t.Parallel() + + type call struct { + name string + seenIn int + } + var calls []call + tag := func(name string) MessageTransform { + return func(_ context.Context, _ *hooks.Input, msgs []chat.Message) ([]chat.Message, error) { + calls = append(calls, call{name: name, seenIn: len(msgs)}) + return append(msgs, chat.Message{Role: chat.MessageRoleSystem, Content: name}), nil + } + } + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, + WithModelStore(mockModelStore{}), + WithMessageTransform("tag_a", tag("tag_a")), + WithMessageTransform("tag_b", tag("tag_b")), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("hi")) + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, "test/mock-model", + []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}}) + + require.Len(t, calls, 2, "expected tag_a + tag_b to fire exactly once each") + assert.Equal(t, "tag_a", calls[0].name, "transforms must run in registration order") + assert.Equal(t, "tag_b", calls[1].name) + assert.Greater(t, calls[1].seenIn, calls[0].seenIn, + "tag_b must see tag_a's appended message (chain semantics, not parallel)") + + var contents []string + for _, m := range got { + contents = append(contents, m.Content) + } + assert.Contains(t, contents, "tag_a") + assert.Contains(t, contents, "tag_b") +} + +// TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed pins the +// fail-soft contract: a transform that returns an error must NOT +// break the run loop; the previous slice continues through the +// chain. +func TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed(t *testing.T) { + t.Parallel() + + failing := func(_ context.Context, _ *hooks.Input, _ []chat.Message) ([]chat.Message, error) { + return nil, errors.New("boom") + } + tag := func(_ context.Context, _ *hooks.Input, msgs []chat.Message) ([]chat.Message, error) { + return append(msgs, chat.Message{Role: chat.MessageRoleSystem, Content: "after_failure"}), nil + } + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, + WithModelStore(mockModelStore{}), + WithMessageTransform("failing", failing), + WithMessageTransform("tag", tag), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("hi")) + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, "test/mock-model", + []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}}) + + var contents []string + for _, m := range got { + contents = append(contents, m.Content) + } + assert.Contains(t, contents, "after_failure", + "a transform error must not abort the chain") +} + +// TestRunStream_StripsImagesForTextOnlyModel is the end-to-end smoke +// test confirming the inline strip in runStreamLoop has been +// replaced: messages reaching the provider must no longer carry +// image parts when the agent's model is text-only. +func TestRunStream_StripsImagesForTextOnlyModel(t *testing.T) { + t.Parallel() + + stream := newStreamBuilder().AddContent("ok").AddStopWithUsage(1, 1).Build() + prov := &recordingMsgProvider{mockProvider: mockProvider{id: "test/text-only", stream: stream}} + + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + store := modalityModelStore{model: &modelsdev.Model{ + Modalities: modelsdev.Modalities{Input: []string{"text"}}, + }} + r, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(store)) + require.NoError(t, err) + + sess := session.New() + sess.AddMessage(session.UserMessage("", + chat.MessagePart{Type: chat.MessagePartTypeText, Text: "describe"}, + chat.MessagePart{Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + )) + + for range r.RunStream(t.Context(), sess) { + // drain — only the recorded provider state matters + } + + require.NotEmpty(t, prov.got, "provider must have been called") + for _, m := range prov.got[0] { + for _, p := range m.MultiContent { + assert.NotEqual(t, chat.MessagePartTypeImageURL, p.Type, + "image parts must be stripped before reaching a text-only model") + } + } +} + +// TestRunStream_TransformErrorDoesNotBreakRun is the end-to-end smoke +// test confirming the fail-soft contract: a transform error must not +// prevent the model from being called and the run from completing. +func TestRunStream_TransformErrorDoesNotBreakRun(t *testing.T) { + t.Parallel() + + stream := newStreamBuilder().AddContent("ok").AddStopWithUsage(1, 1).Build() + prov := &mockProvider{id: "test/mock-model", stream: stream} + + failing := func(_ context.Context, _ *hooks.Input, _ []chat.Message) ([]chat.Message, error) { + return nil, errors.New("boom") + } + + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithMessageTransform("failing", failing), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("hi")) + var sawStop bool + for ev := range r.RunStream(t.Context(), sess) { + if _, ok := ev.(*StreamStoppedEvent); ok { + sawStop = true + } + } + assert.True(t, sawStop, "run must complete despite a failing transform") +} + +// TestWithMessageTransform_RejectsEmptyAndNil pins the input +// validation: empty name or nil fn must be silently ignored +// (matching the no-error shape of other Opts). +func TestWithMessageTransform_RejectsEmptyAndNil(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + r, err := NewLocalRuntime(tm, + WithModelStore(mockModelStore{}), + WithMessageTransform("", func(_ context.Context, _ *hooks.Input, msgs []chat.Message) ([]chat.Message, error) { + return msgs, nil + }), + WithMessageTransform("nilfn", nil), + ) + require.NoError(t, err, "WithMessageTransform must not surface a constructor error") + + // Only the runtime-shipped strip transform should remain. + require.Len(t, r.transforms, 1, "invalid transforms must be silently ignored") + assert.Equal(t, BuiltinStripUnsupportedModalities, r.transforms[0].name) +}