From 2b50abcac63a3e85ea6fbcb17d230ec0fc36c7bf Mon Sep 17 00:00:00 2001 From: David Gageot Date: Tue, 28 Apr 2026 11:11:57 +0200 Subject: [PATCH 1/3] extract strip_unsupported_modalities into a registered before_llm_call transform --- pkg/runtime/hooks.go | 18 +- pkg/runtime/hooks_wiring_test.go | 31 +- pkg/runtime/loop.go | 14 +- pkg/runtime/runtime.go | 31 ++ pkg/runtime/streaming.go | 37 --- pkg/runtime/strip_modalities.go | 122 ++++++++ pkg/runtime/transforms.go | 222 +++++++++++++++ pkg/runtime/transforms_test.go | 471 +++++++++++++++++++++++++++++++ 8 files changed, 885 insertions(+), 61 deletions(-) create mode 100644 pkg/runtime/strip_modalities.go create mode 100644 pkg/runtime/transforms.go create mode 100644 pkg/runtime/transforms_test.go diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 0a1fad321..aa19cdc8e 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -16,17 +16,23 @@ import ( // buildHooksExecutors builds a [hooks.Executor] for every agent in the // team that has user-configured hooks, an agent-flag that maps to a -// builtin (AddDate / AddEnvironmentInfo / AddPromptFiles), or a +// builtin (AddDate / AddEnvironmentInfo / AddPromptFiles), a // configured response cache (which auto-injects a cache_response stop -// hook). Agents with no hooks have no entry; lookups fall through to -// nil so callers can short-circuit cheaply. +// hook), or any registered [MessageTransform] (which auto-injects a +// before_llm_call builtin). Agents with no hooks have no entry; +// lookups fall through to nil so callers can short-circuit cheaply. +// +// The matching [resolvedTransform] chain for each agent is also +// pre-resolved here into r.transformsByAgent so per-LLM-call +// dispatch is a flat slice walk. // // Called once from [NewLocalRuntime] after r.workingDir, r.env and -// r.hooksRegistry are finalized; the resulting map is read-only for +// r.hooksRegistry are finalized; the resulting maps are read-only for // the lifetime of the runtime, so per-dispatch lookups don't need to // lock. func (r *LocalRuntime) buildHooksExecutors() { r.hooksExecByAgent = make(map[string]*hooks.Executor) + r.transformsByAgent = make(map[string][]resolvedTransform) for _, name := range r.team.AgentNames() { a, err := r.team.Agent(name) if err != nil { @@ -38,10 +44,14 @@ func (r *LocalRuntime) buildHooksExecutors() { AddPromptFiles: a.AddPromptFiles(), }) cfg = applyCacheDefault(cfg, a) + cfg = r.applyMessageTransformDefaults(cfg) if cfg == nil { continue } r.hooksExecByAgent[name] = hooks.NewExecutorWithRegistry(cfg, r.workingDir, r.env, r.hooksRegistry) + if transforms := r.resolveTransforms(cfg); len(transforms) > 0 { + r.transformsByAgent[name] = transforms + } } } diff --git a/pkg/runtime/hooks_wiring_test.go b/pkg/runtime/hooks_wiring_test.go index 2fe0f1923..775aee53e 100644 --- a/pkg/runtime/hooks_wiring_test.go +++ b/pkg/runtime/hooks_wiring_test.go @@ -21,6 +21,12 @@ import ( // - AddPromptFiles -> turn_start (file may be edited mid-session) // - AddEnvironmentInfo -> session_start (wd/OS/arch don't change) // +// Every agent additionally receives an auto-injected +// [BuiltinStripUnsupportedModalities] entry on before_llm_call (the +// runtime-shipped message transform that drops images for text-only +// models), so the executor is always non-nil — even for an agent +// without any explicit flags. +// // The behavior of each builtin (what it puts in AdditionalContext) is // covered by pkg/hooks/builtins; this test only asserts the wiring, // using a smoke Dispatch to confirm that the registered builtin name @@ -33,16 +39,14 @@ func TestHooksExecWiresAgentFlagsToBuiltins(t *testing.T) { prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} cases := []struct { - name string - opts []agent.Opt - wantNoExecutor bool - wantTurnStart bool - wantSessStart bool + name string + opts []agent.Opt + wantTurnStart bool + wantSessStart bool }{ { - name: "no flags: no implicit hooks, no executor", - opts: []agent.Opt{agent.WithModel(prov)}, - wantNoExecutor: true, + name: "no flags: only the auto-injected strip transform on before_llm_call", + opts: []agent.Opt{agent.WithModel(prov)}, }, { name: "AddDate wires turn_start", @@ -82,16 +86,17 @@ func TestHooksExecWiresAgentFlagsToBuiltins(t *testing.T) { require.NoError(t, err) exec := r.hooksExec(a) - if tc.wantNoExecutor { - assert.Nil(t, exec, "no flags must not produce an executor") - return - } - require.NotNil(t, exec) + require.NotNil(t, exec, "every agent receives the auto-injected strip transform") // hooksExec is read-only after [LocalRuntime.buildHooksExecutors], // so calling it twice returns the same pointer. assert.Same(t, exec, r.hooksExec(a), "hooksExec must be stable across calls") + // before_llm_call always carries the strip_unsupported_modalities + // builtin, regardless of agent flags. + assert.True(t, exec.Has(hooks.EventBeforeLLMCall), + "before_llm_call must always carry the auto-injected strip transform") + assert.Equal(t, tc.wantTurnStart, exec.Has(hooks.EventTurnStart), "turn_start activation must match flags") assert.Equal(t, tc.wantSessStart, exec.Has(hooks.EventSessionStart), diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index aff218800..35e1ab82a 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -375,13 +375,6 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, messages := sess.GetMessages(a, slices.Concat(sessionStartMsgs, userPromptMsgs, turnStartMsgs)...) slog.Debug("Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages)) - // Strip image content from messages if the model doesn't support image input. - // This prevents API errors when conversation history contains images (e.g. from - // tool results or user attachments) but the current model is text-only. - if m != nil && len(m.Modalities.Input) > 0 && !slices.Contains(m.Modalities.Input, "image") { - messages = stripImageContent(messages) - } - // before_llm_call hooks fire just before the model is invoked. // A terminating verdict (e.g. from the max_iterations builtin) // stops the run loop here, before any tokens are spent. @@ -393,6 +386,13 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, return } + // Apply registered before_llm_call message transforms (e.g. + // strip_unsupported_modalities for text-only models, plus any + // embedder-supplied redactor / scrubber registered via + // WithMessageTransform). Runs after the gate so a transform + // failure cannot waste the gate's allow verdict. + messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, messages) + // Try primary model with fallback chain if configured res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events) if err != nil { diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 222429b52..c86f9e973 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -161,6 +161,27 @@ type LocalRuntime struct { // construction, so no locking is needed. hooksExecByAgent map[string]*hooks.Executor + // transforms holds the runtime-private [MessageTransform] table, + // keyed by builtin name. Populated by [registerMessageTransform] + // from [NewLocalRuntime] (for the runtime-shipped strippers) and + // from [WithMessageTransform] (for embedder-supplied transforms). + // Read-only after construction. + transforms map[string]MessageTransform + + // transformNames is the registration-order list of names in + // [transforms], used by [applyMessageTransformDefaults] to inject + // hook entries deterministically (a Go map iteration would scramble + // the order across runs and break the chain semantics tests rely + // on). Read-only after construction. + transformNames []string + + // transformsByAgent is the per-agent resolution of the message + // transforms registered in [transforms], pre-walked from the + // agent's before_llm_call hook config. Built alongside + // [hooksExecByAgent] in [buildHooksExecutors] so per-LLM-call + // dispatch is a flat slice walk. + transformsByAgent map[string][]resolvedTransform + fallback *fallbackExecutor // observers receive every event the runtime produces, in @@ -392,6 +413,16 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err) } + // strip_unsupported_modalities is the runtime-shipped + // before_llm_call message transform that drops image content from + // messages when the agent's model is text-only. Like + // cache_response it captures the runtime closure (to resolve the + // agent and its model from Input.AgentName) and is therefore + // registered here rather than in pkg/hooks/builtins. + if err := r.registerMessageTransform(BuiltinStripUnsupportedModalities, r.stripUnsupportedModalitiesTransform); err != nil { + return nil, fmt.Errorf("register %q transform: %w", BuiltinStripUnsupportedModalities, err) + } + for _, opt := range opts { opt(r) } diff --git a/pkg/runtime/streaming.go b/pkg/runtime/streaming.go index d3635ba82..937fdb227 100644 --- a/pkg/runtime/streaming.go +++ b/pkg/runtime/streaming.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "log/slog" "strings" "github.com/docker/docker-agent/pkg/agent" @@ -236,39 +235,3 @@ func handleStream(ctx context.Context, stream chat.MessageStream, a *agent.Agent Usage: messageUsage, }, nil } - -// stripImageContent returns a copy of messages with all image-related content -// removed. This is used when the target model doesn't support image input to -// prevent API errors. Text content is preserved; image parts in MultiContent -// are filtered out, and file attachments with image MIME types are dropped. -func stripImageContent(messages []chat.Message) []chat.Message { - result := make([]chat.Message, len(messages)) - for i, msg := range messages { - result[i] = msg - - if len(msg.MultiContent) == 0 { - continue - } - - var filtered []chat.MessagePart - for _, part := range msg.MultiContent { - switch part.Type { - case chat.MessagePartTypeImageURL: - // Drop image URL parts entirely. - continue - case chat.MessagePartTypeFile: - // Drop file parts that are images. - if part.File != nil && chat.IsImageMimeType(part.File.MimeType) { - continue - } - } - filtered = append(filtered, part) - } - - if len(filtered) != len(msg.MultiContent) { - result[i].MultiContent = filtered - slog.Debug("Stripped image content from message", "role", msg.Role, "original_parts", len(msg.MultiContent), "remaining_parts", len(filtered)) - } - } - return result -} diff --git a/pkg/runtime/strip_modalities.go b/pkg/runtime/strip_modalities.go new file mode 100644 index 000000000..6483312bc --- /dev/null +++ b/pkg/runtime/strip_modalities.go @@ -0,0 +1,122 @@ +package runtime + +import ( + "context" + "log/slog" + "slices" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/hooks" +) + +// BuiltinStripUnsupportedModalities is the name of the builtin +// before_llm_call message transform that drops image content from the +// outgoing messages when the agent's current model doesn't list image +// in its input modalities. It is auto-injected by +// [LocalRuntime.applyMessageTransformDefaults] for every agent, +// mirroring [BuiltinCacheResponse]'s auto-injection from +// [applyCacheDefault]. +// +// Sending images to a text-only model produces hard provider errors +// (HTTP 400 from OpenAI, "image input is not supported" from Anthropic +// text variants, etc.); the runtime previously side-stepped this with +// an inline strip in runStreamLoop. Promoting it to a registered +// transform makes the behavior visible to user-authored hook +// configurations, lets it be deduplicated/ordered alongside other +// transforms, and opens the door to a family of message-mutating +// builtins (redactors, scrubbers, ...). +const BuiltinStripUnsupportedModalities = "strip_unsupported_modalities" + +// modalityImage is the canonical models.dev modality name for image +// input. Constants instead of literals so a typo trips a compile +// error and the contract with [modelsdev.Modalities.Input] is +// discoverable from the runtime side. +const modalityImage = "image" + +// stripUnsupportedModalitiesTransform is the [MessageTransform] +// registered as [BuiltinStripUnsupportedModalities]. It resolves the +// agent (and therefore its current model) through the runtime closure +// and the [hooks.Input.AgentName] field, looks up the model +// definition, and applies [stripImageContent] when the model's input +// modalities are known and don't include image. +// +// The transform is a no-op (returns msgs unchanged, nil error) for +// every "we don't know enough to act" case: missing agent, missing +// model definition (unknown model ID, models.dev fetch failed), +// missing modalities list, or image already supported. Erring on the +// side of "send the messages as-is" matches the previous inline +// behavior in runStreamLoop, where an unknown model also fell through. +func (r *LocalRuntime) stripUnsupportedModalitiesTransform( + ctx context.Context, + in *hooks.Input, + _ []string, + msgs []chat.Message, +) ([]chat.Message, error) { + if in == nil || in.AgentName == "" { + return msgs, nil + } + a, err := r.team.Agent(in.AgentName) + if err != nil || a == nil { + return msgs, nil + } + model := a.Model() + if model == nil { + return msgs, nil + } + m, err := r.modelsStore.GetModel(ctx, model.ID()) + if err != nil || m == nil { + // Unknown model: keep the previous (inline) behavior of + // passing messages through untouched. The model call will + // surface any modality mismatch as a provider error. + return msgs, nil + } + if len(m.Modalities.Input) == 0 || slices.Contains(m.Modalities.Input, modalityImage) { + return msgs, nil + } + return stripImageContent(msgs), nil +} + +// stripImageContent returns a copy of messages with all image-related +// content removed. Text content is preserved; image parts in +// [chat.Message.MultiContent] are filtered out, and file attachments +// with image MIME types are dropped. +// +// Lives next to [stripUnsupportedModalitiesTransform] (rather than in +// streaming.go where it originated) so the builtin's storage, +// transform, and helper are co-located. Kept as an unexported helper +// because the only legitimate caller is the transform itself — direct +// use bypasses the modality check. +func stripImageContent(messages []chat.Message) []chat.Message { + result := make([]chat.Message, len(messages)) + for i, msg := range messages { + result[i] = msg + + if len(msg.MultiContent) == 0 { + continue + } + + var filtered []chat.MessagePart + for _, part := range msg.MultiContent { + switch part.Type { + case chat.MessagePartTypeImageURL: + // Drop image URL parts entirely. + continue + case chat.MessagePartTypeFile: + // Drop file parts that are images. + if part.File != nil && chat.IsImageMimeType(part.File.MimeType) { + continue + } + } + filtered = append(filtered, part) + } + + if len(filtered) != len(msg.MultiContent) { + result[i].MultiContent = filtered + slog.Debug("Stripped image content from message", + "role", msg.Role, + "original_parts", len(msg.MultiContent), + "remaining_parts", len(filtered)) + } + } + return result +} diff --git a/pkg/runtime/transforms.go b/pkg/runtime/transforms.go new file mode 100644 index 000000000..ca3565192 --- /dev/null +++ b/pkg/runtime/transforms.go @@ -0,0 +1,222 @@ +package runtime + +import ( + "context" + "errors" + "log/slog" + "strings" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/session" +) + +// MessageTransform is the in-process-only handler signature for a +// before_llm_call builtin that rewrites the chat messages about to be +// sent to the model. It receives the full message slice in configured +// order and returns the (possibly-rewritten) replacement. +// +// Transforms are intentionally a runtime-private contract: the cost of +// JSON-roundtripping a full conversation through the cross-process +// hook protocol would be prohibitive, so command and model hooks +// cannot rewrite messages. Embedders register transforms via +// [WithMessageTransform]; the runtime ships [BuiltinStripUnsupportedModalities] +// out of the box. +// +// Transforms run AFTER the standard before_llm_call gate (see +// [LocalRuntime.executeBeforeLLMCallHooks]) — a hook that wants to +// abort the call should target the gate, not a transform. +// +// Returning a non-nil error logs a warning and falls through to the +// previous message slice; a transform failure must never break the run +// loop. +type MessageTransform func(ctx context.Context, in *hooks.Input, args []string, msgs []chat.Message) ([]chat.Message, error) + +// resolvedTransform pairs a registered [MessageTransform] with the +// hook-config args that selected it. Pre-resolved once during +// [LocalRuntime.buildHooksExecutors] so per-call dispatch is a flat +// slice walk. +type resolvedTransform struct { + name string + args []string + fn MessageTransform +} + +// WithMessageTransform registers a [MessageTransform] under name. The +// transform is auto-applied to every agent for the before_llm_call +// event (the runtime injects a corresponding `{type: builtin, command: +// name}` entry into each agent's hook config), giving custom +// redactors / scrubbers / modality strippers an always-on lifecycle +// without any per-agent YAML. +// +// Empty name or nil fn are silently ignored, matching the no-error +// shape of the other [Opt] helpers; explicit registration via +// [LocalRuntime.RegisterMessageTransform] (called internally) returns +// errors for those cases. +func WithMessageTransform(name string, fn MessageTransform) Opt { + return func(r *LocalRuntime) { + if name == "" || fn == nil { + return + } + if err := r.registerMessageTransform(name, fn); err != nil { + slog.Warn("Failed to register message transform; ignoring", "name", name, "error", err) + } + } +} + +// registerMessageTransform records fn under name in the runtime's +// transform table AND registers a no-op [hooks.BuiltinFunc] shim for +// the same name on the runtime's hooks registry. The shim makes the +// auto-injected `{type: builtin, command: name}` entry resolvable by +// the standard [hooks.Executor] (so it doesn't fail with "no builtin +// hook registered as ..."), while the actual rewrite happens through +// the typed transform path in [applyBeforeLLMCallTransforms]. +// +// Re-registering an existing name replaces the previous transform but +// does NOT change its position in the auto-injection order, so users +// of [WithMessageTransform] get a stable, predictable chain regardless +// of how often they patch in test code. +func (r *LocalRuntime) registerMessageTransform(name string, fn MessageTransform) error { + if name == "" { + return errors.New("message transform name must not be empty") + } + if fn == nil { + return errors.New("message transform function must not be nil") + } + if r.transforms == nil { + r.transforms = make(map[string]MessageTransform) + } + if _, exists := r.transforms[name]; !exists { + r.transformNames = append(r.transformNames, name) + } + r.transforms[name] = fn + // No-op shim: the standard hooks executor will see the + // auto-injected builtin entry and dispatch this; returning nil + // signals "ran cleanly, no opinion" so aggregate() doesn't + // surface a warning. The actual message rewrite happens through + // applyBeforeLLMCallTransforms. + return r.hooksRegistry.RegisterBuiltin(name, noopBuiltin) +} + +// noopBuiltin is the [hooks.BuiltinFunc] companion used by every +// registered [MessageTransform]: it accepts the JSON-serialized input +// and returns nothing. Pulled out as a package-level value so all +// transforms share the same function pointer (cheap dedup, easier to +// recognize in logs). +func noopBuiltin(_ context.Context, _ *hooks.Input, _ []string) (*hooks.Output, error) { + return nil, nil +} + +// applyMessageTransformDefaults appends a `{type: builtin, command: +// name}` entry to cfg.BeforeLLMCall for every registered +// [MessageTransform], mirroring the role of [applyCacheDefault] for +// the cache_response stop builtin and of [builtins.ApplyAgentDefaults] +// for the date / env / prompt-files turn_start builtins. +// +// Transforms are auto-injected in registration order (see +// [registerMessageTransform]), giving callers a stable, predictable +// chain even though the underlying lookup table is a map. +// +// The helper accepts (and may return) a nil cfg so callers can chain +// it after the other default helpers without an extra branch. It is a +// no-op when no transforms are registered, in which case it preserves +// the cfg-may-be-nil contract. +func (r *LocalRuntime) applyMessageTransformDefaults(cfg *hooks.Config) *hooks.Config { + if len(r.transformNames) == 0 { + return cfg + } + if cfg == nil { + cfg = &hooks.Config{} + } + for _, name := range r.transformNames { + cfg.BeforeLLMCall = append(cfg.BeforeLLMCall, hooks.Hook{ + Type: hooks.HookTypeBuiltin, + Command: name, + }) + } + return cfg +} + +// resolveTransforms walks cfg.BeforeLLMCall in configured order and +// returns the registered [MessageTransform]s to apply, deduplicated by +// (name, args) so a user-authored YAML entry that overlaps the +// runtime's auto-injected builtin doesn't run the transform twice. +// +// Returns nil for an empty resolution so callers can short-circuit +// cheaply on the (common) no-transforms path. +func (r *LocalRuntime) resolveTransforms(cfg *hooks.Config) []resolvedTransform { + if cfg == nil || len(cfg.BeforeLLMCall) == 0 || len(r.transforms) == 0 { + return nil + } + var out []resolvedTransform + seen := make(map[string]bool) + for _, h := range cfg.BeforeLLMCall { + if h.Type != hooks.HookTypeBuiltin { + continue + } + fn, ok := r.transforms[h.Command] + if !ok { + continue + } + key := transformDedupKey(h.Command, h.Args) + if seen[key] { + continue + } + seen[key] = true + out = append(out, resolvedTransform{name: h.Command, args: h.Args, fn: fn}) + } + return out +} + +// transformDedupKey mirrors [hooks.dedupKey]'s (command, args) shape so +// transforms are deduplicated on the same axis as the standard hook +// executor. Type is always `builtin` for transforms, so it's not part +// of the key. +func transformDedupKey(name string, args []string) string { + var b strings.Builder + b.WriteString(name) + for _, a := range args { + b.WriteByte(0) + b.WriteString(a) + } + return b.String() +} + +// applyBeforeLLMCallTransforms dispatches the agent's pre-resolved +// [MessageTransform] chain just before the model call, AFTER +// [LocalRuntime.executeBeforeLLMCallHooks] has run its gate. Transforms +// rewrite (or drop) messages but cannot abort the call — that +// responsibility lives with the gate. +// +// Returns the (possibly-rewritten) message slice. Errors from +// individual transforms are logged at warn level and the chain +// continues with the previous slice, matching the executor's "warn, +// don't break the loop" stance for non-fail-closed events. +func (r *LocalRuntime) applyBeforeLLMCallTransforms( + ctx context.Context, + sess *session.Session, + a *agent.Agent, + msgs []chat.Message, +) []chat.Message { + transforms := r.transformsByAgent[a.Name()] + if len(transforms) == 0 { + return msgs + } + in := &hooks.Input{ + SessionID: sess.ID, + AgentName: a.Name(), + HookEventName: hooks.EventBeforeLLMCall, + Cwd: r.workingDir, + } + for _, t := range transforms { + out, err := t.fn(ctx, in, t.args, msgs) + if err != nil { + slog.Warn("Message transform failed; continuing with previous messages", + "transform", t.name, "agent", a.Name(), "error", err) + continue + } + msgs = out + } + return msgs +} diff --git a/pkg/runtime/transforms_test.go b/pkg/runtime/transforms_test.go new file mode 100644 index 000000000..32cd0c094 --- /dev/null +++ b/pkg/runtime/transforms_test.go @@ -0,0 +1,471 @@ +package runtime + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" +) + +// modalityModelStore returns a fixed [modelsdev.Model] regardless of +// the requested ID. Tests configure its Modalities to exercise the +// strip_unsupported_modalities transform's three branches: text-only +// (strip), image-supporting (no-op), and unknown-modality (no-op, +// no panic). +type modalityModelStore struct { + ModelStore + + model *modelsdev.Model + err error +} + +func (m modalityModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { + return m.model, m.err +} + +// recordingMsgProvider captures the messages each model call sees so a +// test can confirm a transform actually rewrote what reached the +// provider (rather than just what the in-memory slice ended up looking +// like). +type recordingMsgProvider struct { + mockProvider + + got [][]chat.Message +} + +func (p *recordingMsgProvider) CreateChatCompletionStream(_ context.Context, msgs []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { + snap := append([]chat.Message{}, msgs...) + p.got = append(p.got, snap) + return p.stream, nil +} + +// TestStripUnsupportedModalitiesTransform_TextOnlyModelDropsImages +// pins the runtime-shipped [stripUnsupportedModalitiesTransform]'s +// happy path: a text-only model receives messages with all image +// content stripped, while text content is preserved. +func TestStripUnsupportedModalitiesTransform_TextOnlyModelDropsImages(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/text-only", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + store := modalityModelStore{model: &modelsdev.Model{ + Modalities: modelsdev.Modalities{Input: []string{"text"}}, + }} + r, err := NewLocalRuntime(tm, WithModelStore(store)) + require.NoError(t, err) + + in := &hooks.Input{AgentName: "root"} + msgs := []chat.Message{ + { + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "look at this"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + }, + }, + } + + got, err := r.stripUnsupportedModalitiesTransform(t.Context(), in, nil, msgs) + require.NoError(t, err) + require.Len(t, got, 1) + require.Len(t, got[0].MultiContent, 1, "image part must be stripped") + assert.Equal(t, chat.MessagePartTypeText, got[0].MultiContent[0].Type) +} + +// TestStripUnsupportedModalitiesTransform_ImageModelPassThrough pins +// the no-op branch: when the model's input modalities include "image", +// messages must reach the provider unchanged. +func TestStripUnsupportedModalitiesTransform_ImageModelPassThrough(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/multimodal", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + store := modalityModelStore{model: &modelsdev.Model{ + Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}, + }} + r, err := NewLocalRuntime(tm, WithModelStore(store)) + require.NoError(t, err) + + in := &hooks.Input{AgentName: "root"} + msgs := []chat.Message{{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "describe this"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + }, + }} + + got, err := r.stripUnsupportedModalitiesTransform(t.Context(), in, nil, msgs) + require.NoError(t, err) + assert.Equal(t, msgs, got, "messages must reach a multimodal model untouched") +} + +// TestStripUnsupportedModalitiesTransform_UnknownModelPassThrough pins +// the safe-fallback branch: when the models.dev lookup fails (or +// returns nil), the transform returns msgs unchanged so the request +// still reaches the provider; any modality mismatch surfaces as a +// provider error rather than a transform-side panic. +func TestStripUnsupportedModalitiesTransform_UnknownModelPassThrough(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/unknown", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + cases := []struct { + name string + store modalityModelStore + }{ + {name: "nil model", store: modalityModelStore{model: nil}}, + {name: "lookup error", store: modalityModelStore{err: errors.New("not found")}}, + {name: "empty modalities", store: modalityModelStore{model: &modelsdev.Model{}}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, err := NewLocalRuntime(tm, WithModelStore(tc.store)) + require.NoError(t, err) + + msgs := []chat.Message{{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "x"}}, + }, + }} + got, err := r.stripUnsupportedModalitiesTransform(t.Context(), &hooks.Input{AgentName: "root"}, nil, msgs) + require.NoError(t, err) + assert.Equal(t, msgs, got, "unknown model must fall through unchanged") + }) + } +} + +// TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap covers the hot +// path: an agent without any registered transforms runs no allocator, +// no slog noise, and returns the input slice as-is. The test also +// covers the "agent not in transformsByAgent" branch for agents +// constructed outside the runtime's normal flow. +func TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + // Drop the auto-registered strip_unsupported_modalities so we can + // observe the cheap-path behavior. + r.transforms = nil + r.transformNames = nil + r.transformsByAgent = nil + + sess := session.New(session.WithUserMessage("hi")) + msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} + + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, msgs) + assert.Equal(t, msgs, got) +} + +// TestApplyBeforeLLMCallTransforms_OrderAndArgs verifies that +// transforms registered via [WithMessageTransform] (a) auto-inject a +// before_llm_call entry on every agent, (b) run in configured order, +// and (c) receive the per-hook args from the YAML / auto-injection. +func TestApplyBeforeLLMCallTransforms_OrderAndArgs(t *testing.T) { + t.Parallel() + + type call struct { + name string + args []string + seen []chat.Message + } + var calls []call + + tagA := func(_ context.Context, _ *hooks.Input, args []string, msgs []chat.Message) ([]chat.Message, error) { + seen := append([]chat.Message{}, msgs...) + calls = append(calls, call{name: "tag_a", args: args, seen: seen}) + out := append([]chat.Message{}, msgs...) + out = append(out, chat.Message{Role: chat.MessageRoleSystem, Content: "tag_a"}) + return out, nil + } + tagB := func(_ context.Context, _ *hooks.Input, args []string, msgs []chat.Message) ([]chat.Message, error) { + seen := append([]chat.Message{}, msgs...) + calls = append(calls, call{name: "tag_b", args: args, seen: seen}) + out := append([]chat.Message{}, msgs...) + out = append(out, chat.Message{Role: chat.MessageRoleSystem, Content: "tag_b"}) + return out, nil + } + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, + WithModelStore(mockModelStore{}), + WithMessageTransform("tag_a", tagA), + WithMessageTransform("tag_b", tagB), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("hi")) + msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} + + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, msgs) + + // The two registered tag transforms each fire exactly once. + // (The runtime-shipped strip transform also runs, but it doesn't + // append to `calls` since it's a different function.) + require.Len(t, calls, 2, "expected tag_a + tag_b to fire exactly once each") + + // Registration order must be preserved: tag_a was registered first, + // so it must be invoked first; tag_b second. + assert.Equal(t, "tag_a", calls[0].name, "transforms must run in registration order") + assert.Equal(t, "tag_b", calls[1].name, "transforms must run in registration order") + + // Cumulative semantics: the second transform must have observed the + // first transform's appended message. + assert.Greater(t, len(calls[1].seen), len(calls[0].seen), + "tag_b must see tag_a's appended message (chain semantics, not parallel)") + + // The final slice must contain both tags. + var finalContent []string + for _, m := range got { + finalContent = append(finalContent, m.Content) + } + assert.Contains(t, finalContent, "tag_a") + assert.Contains(t, finalContent, "tag_b") +} + +// TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed pins the +// fail-soft contract: a transform that returns an error must NOT +// break the run loop; the previous slice continues through the chain. +func TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed(t *testing.T) { + t.Parallel() + + failing := func(_ context.Context, _ *hooks.Input, _ []string, _ []chat.Message) ([]chat.Message, error) { + return nil, errors.New("boom") + } + tag := func(_ context.Context, _ *hooks.Input, _ []string, msgs []chat.Message) ([]chat.Message, error) { + out := append([]chat.Message{}, msgs...) + out = append(out, chat.Message{Role: chat.MessageRoleSystem, Content: "after_failure"}) + return out, nil + } + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, + WithModelStore(mockModelStore{}), + WithMessageTransform("failing", failing), + WithMessageTransform("tag", tag), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("hi")) + msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} + + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, msgs) + + // The "tag" transform must have run despite the failing one + // erroring out, and its output must be present. + var contents []string + for _, m := range got { + contents = append(contents, m.Content) + } + assert.Contains(t, contents, "after_failure", + "a transform error must not abort the chain") +} + +// TestRunStream_StripsImagesForTextOnlyModel confirms the inline +// strip in runStreamLoop has been replaced end-to-end: messages +// reaching the provider must no longer carry image parts when the +// agent's model is text-only. +func TestRunStream_StripsImagesForTextOnlyModel(t *testing.T) { + t.Parallel() + + stream := newStreamBuilder().AddContent("ok").AddStopWithUsage(1, 1).Build() + prov := &recordingMsgProvider{mockProvider: mockProvider{id: "test/text-only", stream: stream}} + + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + store := modalityModelStore{model: &modelsdev.Model{ + Modalities: modelsdev.Modalities{Input: []string{"text"}}, + }} + r, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(store)) + require.NoError(t, err) + + sess := session.New() + sess.AddMessage(session.UserMessage("")) + // Replace the empty user message with a multi-part one carrying an image. + last := &sess.Messages[len(sess.Messages)-1] + last.Message.Message.MultiContent = []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "describe"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + } + + for range r.RunStream(t.Context(), sess) { + // drain — only the recorded provider state matters + } + + require.NotEmpty(t, prov.got, "provider must have been called") + for _, m := range prov.got[0] { + for _, p := range m.MultiContent { + assert.NotEqual(t, chat.MessagePartTypeImageURL, p.Type, + "image parts must be stripped before reaching a text-only model") + } + } +} + +// TestApplyMessageTransformDefaults_NoTransformsPreservesNil keeps the +// "nil cfg may stay nil when there are no defaults to add" contract, +// matching [applyCacheDefault]'s shape so [buildHooksExecutors] can +// continue to skip executor construction for agents with no hooks. +func TestApplyMessageTransformDefaults_NoTransformsPreservesNil(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + r.transforms = nil // simulate "no transforms registered" + r.transformNames = nil + got := r.applyMessageTransformDefaults(nil) + assert.Nil(t, got, "no transforms registered must preserve a nil cfg") +} + +// TestResolveTransforms_DedupsByCommandAndArgs guards against double +// invocation when an agent's user-authored YAML already lists a +// builtin that the runtime auto-injects on top. +func TestResolveTransforms_DedupsByCommandAndArgs(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + cfg := &hooks.Config{ + BeforeLLMCall: []hooks.Hook{ + {Type: hooks.HookTypeBuiltin, Command: BuiltinStripUnsupportedModalities}, + {Type: hooks.HookTypeBuiltin, Command: BuiltinStripUnsupportedModalities}, + {Type: hooks.HookTypeBuiltin, Command: BuiltinStripUnsupportedModalities, Args: []string{"foo"}}, + }, + } + + got := r.resolveTransforms(cfg) + require.Len(t, got, 2, "duplicate (name, args) must collapse to one") + assert.Equal(t, BuiltinStripUnsupportedModalities, got[0].name) + assert.Empty(t, got[0].args) + assert.Equal(t, []string{"foo"}, got[1].args, "differing args must NOT be deduplicated") +} + +// TestRegisterMessageTransform_ShimAvoidsExecutorErrors confirms the +// shim wired by [registerMessageTransform]: a hooks.Executor.Dispatch +// for a registered-transform builtin must succeed (Allowed=true, +// Result is a no-op) instead of failing with "no builtin hook +// registered as ...". +func TestRegisterMessageTransform_ShimAvoidsExecutorErrors(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + called := 0 + tag := func(_ context.Context, _ *hooks.Input, _ []string, msgs []chat.Message) ([]chat.Message, error) { + called++ + return msgs, nil + } + r, err := NewLocalRuntime(tm, + WithModelStore(mockModelStore{}), + WithMessageTransform("tag", tag), + ) + require.NoError(t, err) + + exec := r.hooksExec(a) + require.NotNil(t, exec) + + res, err := exec.Dispatch(t.Context(), hooks.EventBeforeLLMCall, &hooks.Input{ + SessionID: "session", AgentName: "root", + }) + require.NoError(t, err, "executor must not error on a transform-only builtin") + assert.True(t, res.Allowed, "shim must report success") + // The transform itself isn't invoked through the executor — only via + // applyBeforeLLMCallTransforms — so `called` stays 0 here. + assert.Equal(t, 0, called, "executor path must NOT invoke the transform body") +} + +// TestRunStream_TransformErrorDoesNotBreakRun is an integration smoke +// test confirming end-to-end: a transform that returns an error must +// not prevent the model from being called; the run completes +// normally and the messages reaching the provider are the pre-error +// snapshot. +func TestRunStream_TransformErrorDoesNotBreakRun(t *testing.T) { + t.Parallel() + + stream := newStreamBuilder().AddContent("ok").AddStopWithUsage(1, 1).Build() + prov := &mockProvider{id: "test/mock-model", stream: stream} + + failing := func(_ context.Context, _ *hooks.Input, _ []string, _ []chat.Message) ([]chat.Message, error) { + return nil, errors.New("boom") + } + + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + r, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithMessageTransform("failing", failing), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("hi")) + var sawStop bool + for ev := range r.RunStream(t.Context(), sess) { + if _, ok := ev.(*StreamStoppedEvent); ok { + sawStop = true + } + } + assert.True(t, sawStop, "run must complete despite a failing transform") +} + +// TestWithMessageTransform_RejectsEmptyAndNil pins the input +// validation: empty name or nil fn must be silently ignored (matching +// the no-error shape of other Opts), with a slog warning. +func TestWithMessageTransform_RejectsEmptyAndNil(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + r, err := NewLocalRuntime(tm, + WithModelStore(mockModelStore{}), + WithMessageTransform("", func(_ context.Context, _ *hooks.Input, _ []string, msgs []chat.Message) ([]chat.Message, error) { + return msgs, nil + }), + WithMessageTransform("nilfn", nil), + ) + require.NoError(t, err, "WithMessageTransform must not surface a constructor error") + + // Only the runtime-shipped strip transform should be in the table. + require.Len(t, r.transforms, 1, "invalid transforms must be silently ignored") + _, ok := r.transforms[BuiltinStripUnsupportedModalities] + assert.True(t, ok) +} From 48e2b71f445b58686dd8466e0fbec58745bbda52 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Tue, 28 Apr 2026 11:18:33 +0200 Subject: [PATCH 2/3] simplify message transforms: drop the YAML auto-injection plumbing --- pkg/runtime/hooks.go | 18 +- pkg/runtime/hooks_wiring_test.go | 31 ++- pkg/runtime/runtime.go | 31 +-- pkg/runtime/strip_modalities.go | 52 ++--- pkg/runtime/transforms.go | 189 +++--------------- pkg/runtime/transforms_test.go | 328 ++++++++----------------------- 6 files changed, 156 insertions(+), 493 deletions(-) diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index aa19cdc8e..0a1fad321 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -16,23 +16,17 @@ import ( // buildHooksExecutors builds a [hooks.Executor] for every agent in the // team that has user-configured hooks, an agent-flag that maps to a -// builtin (AddDate / AddEnvironmentInfo / AddPromptFiles), a +// builtin (AddDate / AddEnvironmentInfo / AddPromptFiles), or a // configured response cache (which auto-injects a cache_response stop -// hook), or any registered [MessageTransform] (which auto-injects a -// before_llm_call builtin). Agents with no hooks have no entry; -// lookups fall through to nil so callers can short-circuit cheaply. -// -// The matching [resolvedTransform] chain for each agent is also -// pre-resolved here into r.transformsByAgent so per-LLM-call -// dispatch is a flat slice walk. +// hook). Agents with no hooks have no entry; lookups fall through to +// nil so callers can short-circuit cheaply. // // Called once from [NewLocalRuntime] after r.workingDir, r.env and -// r.hooksRegistry are finalized; the resulting maps are read-only for +// r.hooksRegistry are finalized; the resulting map is read-only for // the lifetime of the runtime, so per-dispatch lookups don't need to // lock. func (r *LocalRuntime) buildHooksExecutors() { r.hooksExecByAgent = make(map[string]*hooks.Executor) - r.transformsByAgent = make(map[string][]resolvedTransform) for _, name := range r.team.AgentNames() { a, err := r.team.Agent(name) if err != nil { @@ -44,14 +38,10 @@ func (r *LocalRuntime) buildHooksExecutors() { AddPromptFiles: a.AddPromptFiles(), }) cfg = applyCacheDefault(cfg, a) - cfg = r.applyMessageTransformDefaults(cfg) if cfg == nil { continue } r.hooksExecByAgent[name] = hooks.NewExecutorWithRegistry(cfg, r.workingDir, r.env, r.hooksRegistry) - if transforms := r.resolveTransforms(cfg); len(transforms) > 0 { - r.transformsByAgent[name] = transforms - } } } diff --git a/pkg/runtime/hooks_wiring_test.go b/pkg/runtime/hooks_wiring_test.go index 775aee53e..2fe0f1923 100644 --- a/pkg/runtime/hooks_wiring_test.go +++ b/pkg/runtime/hooks_wiring_test.go @@ -21,12 +21,6 @@ import ( // - AddPromptFiles -> turn_start (file may be edited mid-session) // - AddEnvironmentInfo -> session_start (wd/OS/arch don't change) // -// Every agent additionally receives an auto-injected -// [BuiltinStripUnsupportedModalities] entry on before_llm_call (the -// runtime-shipped message transform that drops images for text-only -// models), so the executor is always non-nil — even for an agent -// without any explicit flags. -// // The behavior of each builtin (what it puts in AdditionalContext) is // covered by pkg/hooks/builtins; this test only asserts the wiring, // using a smoke Dispatch to confirm that the registered builtin name @@ -39,14 +33,16 @@ func TestHooksExecWiresAgentFlagsToBuiltins(t *testing.T) { prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} cases := []struct { - name string - opts []agent.Opt - wantTurnStart bool - wantSessStart bool + name string + opts []agent.Opt + wantNoExecutor bool + wantTurnStart bool + wantSessStart bool }{ { - name: "no flags: only the auto-injected strip transform on before_llm_call", - opts: []agent.Opt{agent.WithModel(prov)}, + name: "no flags: no implicit hooks, no executor", + opts: []agent.Opt{agent.WithModel(prov)}, + wantNoExecutor: true, }, { name: "AddDate wires turn_start", @@ -86,17 +82,16 @@ func TestHooksExecWiresAgentFlagsToBuiltins(t *testing.T) { require.NoError(t, err) exec := r.hooksExec(a) - require.NotNil(t, exec, "every agent receives the auto-injected strip transform") + if tc.wantNoExecutor { + assert.Nil(t, exec, "no flags must not produce an executor") + return + } + require.NotNil(t, exec) // hooksExec is read-only after [LocalRuntime.buildHooksExecutors], // so calling it twice returns the same pointer. assert.Same(t, exec, r.hooksExec(a), "hooksExec must be stable across calls") - // before_llm_call always carries the strip_unsupported_modalities - // builtin, regardless of agent flags. - assert.True(t, exec.Has(hooks.EventBeforeLLMCall), - "before_llm_call must always carry the auto-injected strip transform") - assert.Equal(t, tc.wantTurnStart, exec.Has(hooks.EventTurnStart), "turn_start activation must match flags") assert.Equal(t, tc.wantSessStart, exec.Has(hooks.EventSessionStart), diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index c86f9e973..19b9c0c0b 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -161,26 +161,12 @@ type LocalRuntime struct { // construction, so no locking is needed. hooksExecByAgent map[string]*hooks.Executor - // transforms holds the runtime-private [MessageTransform] table, - // keyed by builtin name. Populated by [registerMessageTransform] - // from [NewLocalRuntime] (for the runtime-shipped strippers) and - // from [WithMessageTransform] (for embedder-supplied transforms). + // transforms is the runtime's [MessageTransform] chain, applied to + // every LLM call in registration order. Populated by + // [NewLocalRuntime] (for the runtime-shipped strippers) and by + // [WithMessageTransform] (for embedder-supplied transforms). // Read-only after construction. - transforms map[string]MessageTransform - - // transformNames is the registration-order list of names in - // [transforms], used by [applyMessageTransformDefaults] to inject - // hook entries deterministically (a Go map iteration would scramble - // the order across runs and break the chain semantics tests rely - // on). Read-only after construction. - transformNames []string - - // transformsByAgent is the per-agent resolution of the message - // transforms registered in [transforms], pre-walked from the - // agent's before_llm_call hook config. Built alongside - // [hooksExecByAgent] in [buildHooksExecutors] so per-LLM-call - // dispatch is a flat slice walk. - transformsByAgent map[string][]resolvedTransform + transforms []registeredTransform fallback *fallbackExecutor @@ -419,9 +405,10 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { // cache_response it captures the runtime closure (to resolve the // agent and its model from Input.AgentName) and is therefore // registered here rather than in pkg/hooks/builtins. - if err := r.registerMessageTransform(BuiltinStripUnsupportedModalities, r.stripUnsupportedModalitiesTransform); err != nil { - return nil, fmt.Errorf("register %q transform: %w", BuiltinStripUnsupportedModalities, err) - } + r.transforms = append(r.transforms, registeredTransform{ + name: BuiltinStripUnsupportedModalities, + fn: r.stripUnsupportedModalitiesTransform, + }) for _, opt := range opts { opt(r) diff --git a/pkg/runtime/strip_modalities.go b/pkg/runtime/strip_modalities.go index 6483312bc..63c679659 100644 --- a/pkg/runtime/strip_modalities.go +++ b/pkg/runtime/strip_modalities.go @@ -9,61 +9,51 @@ import ( "github.com/docker/docker-agent/pkg/hooks" ) -// BuiltinStripUnsupportedModalities is the name of the builtin +// BuiltinStripUnsupportedModalities is the name of the runtime-shipped // before_llm_call message transform that drops image content from the // outgoing messages when the agent's current model doesn't list image -// in its input modalities. It is auto-injected by -// [LocalRuntime.applyMessageTransformDefaults] for every agent, -// mirroring [BuiltinCacheResponse]'s auto-injection from -// [applyCacheDefault]. +// in its input modalities. It's the runtime-shipped peer of +// [BuiltinCacheResponse] (a stop hook) — the constant exists mostly +// for log filtering and diagnostics. // // Sending images to a text-only model produces hard provider errors -// (HTTP 400 from OpenAI, "image input is not supported" from Anthropic -// text variants, etc.); the runtime previously side-stepped this with -// an inline strip in runStreamLoop. Promoting it to a registered -// transform makes the behavior visible to user-authored hook -// configurations, lets it be deduplicated/ordered alongside other -// transforms, and opens the door to a family of message-mutating -// builtins (redactors, scrubbers, ...). +// (HTTP 400 from OpenAI, "image input is not supported" from +// Anthropic text variants, etc.); promoting the strip into a +// registered transform replaces an inline branch in runStreamLoop and +// opens the door to a family of message-mutating transforms +// (redactors, scrubbers, ...). const BuiltinStripUnsupportedModalities = "strip_unsupported_modalities" // modalityImage is the canonical models.dev modality name for image -// input. Constants instead of literals so a typo trips a compile +// input. A constant instead of a literal so a typo trips a compile // error and the contract with [modelsdev.Modalities.Input] is // discoverable from the runtime side. const modalityImage = "image" // stripUnsupportedModalitiesTransform is the [MessageTransform] -// registered as [BuiltinStripUnsupportedModalities]. It resolves the -// agent (and therefore its current model) through the runtime closure -// and the [hooks.Input.AgentName] field, looks up the model -// definition, and applies [stripImageContent] when the model's input -// modalities are known and don't include image. +// registered under [BuiltinStripUnsupportedModalities]. It resolves +// the agent (and therefore its current model) from +// [hooks.Input.AgentName], looks up the model's input modalities, and +// applies [stripImageContent] when image is missing from the list. // -// The transform is a no-op (returns msgs unchanged, nil error) for -// every "we don't know enough to act" case: missing agent, missing -// model definition (unknown model ID, models.dev fetch failed), -// missing modalities list, or image already supported. Erring on the -// side of "send the messages as-is" matches the previous inline -// behavior in runStreamLoop, where an unknown model also fell through. +// The transform is a no-op for every "we don't know enough to act" +// case (missing agent, missing model, models.dev miss, empty +// modalities, image already supported): erring on the side of "send +// the messages as-is" matches the previous inline behavior in +// runStreamLoop, where an unknown model also fell through. func (r *LocalRuntime) stripUnsupportedModalitiesTransform( ctx context.Context, in *hooks.Input, - _ []string, msgs []chat.Message, ) ([]chat.Message, error) { if in == nil || in.AgentName == "" { return msgs, nil } a, err := r.team.Agent(in.AgentName) - if err != nil || a == nil { + if err != nil || a == nil || a.Model() == nil { return msgs, nil } - model := a.Model() - if model == nil { - return msgs, nil - } - m, err := r.modelsStore.GetModel(ctx, model.ID()) + m, err := r.modelsStore.GetModel(ctx, a.Model().ID()) if err != nil || m == nil { // Unknown model: keep the previous (inline) behavior of // passing messages through untouched. The model call will diff --git a/pkg/runtime/transforms.go b/pkg/runtime/transforms.go index ca3565192..81a468f0e 100644 --- a/pkg/runtime/transforms.go +++ b/pkg/runtime/transforms.go @@ -2,9 +2,7 @@ package runtime import ( "context" - "errors" "log/slog" - "strings" "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" @@ -13,194 +11,67 @@ import ( ) // MessageTransform is the in-process-only handler signature for a -// before_llm_call builtin that rewrites the chat messages about to be -// sent to the model. It receives the full message slice in configured +// before_llm_call transform that rewrites the chat messages about to +// be sent to the model. It receives the full message slice in chain // order and returns the (possibly-rewritten) replacement. // // Transforms are intentionally a runtime-private contract: the cost of // JSON-roundtripping a full conversation through the cross-process // hook protocol would be prohibitive, so command and model hooks // cannot rewrite messages. Embedders register transforms via -// [WithMessageTransform]; the runtime ships [BuiltinStripUnsupportedModalities] -// out of the box. +// [WithMessageTransform]; the runtime ships +// [BuiltinStripUnsupportedModalities] out of the box. // // Transforms run AFTER the standard before_llm_call gate (see // [LocalRuntime.executeBeforeLLMCallHooks]) — a hook that wants to // abort the call should target the gate, not a transform. // // Returning a non-nil error logs a warning and falls through to the -// previous message slice; a transform failure must never break the run -// loop. -type MessageTransform func(ctx context.Context, in *hooks.Input, args []string, msgs []chat.Message) ([]chat.Message, error) +// previous message slice; a transform failure must never break the +// run loop. +type MessageTransform func(ctx context.Context, in *hooks.Input, msgs []chat.Message) ([]chat.Message, error) -// resolvedTransform pairs a registered [MessageTransform] with the -// hook-config args that selected it. Pre-resolved once during -// [LocalRuntime.buildHooksExecutors] so per-call dispatch is a flat -// slice walk. -type resolvedTransform struct { +// registeredTransform pairs a [MessageTransform] with the name it was +// registered under. The name is purely diagnostic — it shows up in +// slog records when a transform errors out — so re-registering the +// same name simply appends another entry without any de-duplication. +type registeredTransform struct { name string - args []string fn MessageTransform } -// WithMessageTransform registers a [MessageTransform] under name. The -// transform is auto-applied to every agent for the before_llm_call -// event (the runtime injects a corresponding `{type: builtin, command: -// name}` entry into each agent's hook config), giving custom -// redactors / scrubbers / modality strippers an always-on lifecycle -// without any per-agent YAML. +// WithMessageTransform registers a [MessageTransform] under name so +// it is applied to every LLM call, in registration order, after the +// before_llm_call gate. Transforms are runtime-global: per-agent +// scoping (if needed) lives in the transform body, where +// [hooks.Input.AgentName] is available — the runtime-shipped strip +// transform is an example. // // Empty name or nil fn are silently ignored, matching the no-error -// shape of the other [Opt] helpers; explicit registration via -// [LocalRuntime.RegisterMessageTransform] (called internally) returns -// errors for those cases. +// shape of the other [Opt] helpers. func WithMessageTransform(name string, fn MessageTransform) Opt { return func(r *LocalRuntime) { if name == "" || fn == nil { + slog.Warn("Ignoring message transform with empty name or nil fn", "name", name) return } - if err := r.registerMessageTransform(name, fn); err != nil { - slog.Warn("Failed to register message transform; ignoring", "name", name, "error", err) - } - } -} - -// registerMessageTransform records fn under name in the runtime's -// transform table AND registers a no-op [hooks.BuiltinFunc] shim for -// the same name on the runtime's hooks registry. The shim makes the -// auto-injected `{type: builtin, command: name}` entry resolvable by -// the standard [hooks.Executor] (so it doesn't fail with "no builtin -// hook registered as ..."), while the actual rewrite happens through -// the typed transform path in [applyBeforeLLMCallTransforms]. -// -// Re-registering an existing name replaces the previous transform but -// does NOT change its position in the auto-injection order, so users -// of [WithMessageTransform] get a stable, predictable chain regardless -// of how often they patch in test code. -func (r *LocalRuntime) registerMessageTransform(name string, fn MessageTransform) error { - if name == "" { - return errors.New("message transform name must not be empty") - } - if fn == nil { - return errors.New("message transform function must not be nil") - } - if r.transforms == nil { - r.transforms = make(map[string]MessageTransform) - } - if _, exists := r.transforms[name]; !exists { - r.transformNames = append(r.transformNames, name) - } - r.transforms[name] = fn - // No-op shim: the standard hooks executor will see the - // auto-injected builtin entry and dispatch this; returning nil - // signals "ran cleanly, no opinion" so aggregate() doesn't - // surface a warning. The actual message rewrite happens through - // applyBeforeLLMCallTransforms. - return r.hooksRegistry.RegisterBuiltin(name, noopBuiltin) -} - -// noopBuiltin is the [hooks.BuiltinFunc] companion used by every -// registered [MessageTransform]: it accepts the JSON-serialized input -// and returns nothing. Pulled out as a package-level value so all -// transforms share the same function pointer (cheap dedup, easier to -// recognize in logs). -func noopBuiltin(_ context.Context, _ *hooks.Input, _ []string) (*hooks.Output, error) { - return nil, nil -} - -// applyMessageTransformDefaults appends a `{type: builtin, command: -// name}` entry to cfg.BeforeLLMCall for every registered -// [MessageTransform], mirroring the role of [applyCacheDefault] for -// the cache_response stop builtin and of [builtins.ApplyAgentDefaults] -// for the date / env / prompt-files turn_start builtins. -// -// Transforms are auto-injected in registration order (see -// [registerMessageTransform]), giving callers a stable, predictable -// chain even though the underlying lookup table is a map. -// -// The helper accepts (and may return) a nil cfg so callers can chain -// it after the other default helpers without an extra branch. It is a -// no-op when no transforms are registered, in which case it preserves -// the cfg-may-be-nil contract. -func (r *LocalRuntime) applyMessageTransformDefaults(cfg *hooks.Config) *hooks.Config { - if len(r.transformNames) == 0 { - return cfg - } - if cfg == nil { - cfg = &hooks.Config{} + r.transforms = append(r.transforms, registeredTransform{name: name, fn: fn}) } - for _, name := range r.transformNames { - cfg.BeforeLLMCall = append(cfg.BeforeLLMCall, hooks.Hook{ - Type: hooks.HookTypeBuiltin, - Command: name, - }) - } - return cfg -} - -// resolveTransforms walks cfg.BeforeLLMCall in configured order and -// returns the registered [MessageTransform]s to apply, deduplicated by -// (name, args) so a user-authored YAML entry that overlaps the -// runtime's auto-injected builtin doesn't run the transform twice. -// -// Returns nil for an empty resolution so callers can short-circuit -// cheaply on the (common) no-transforms path. -func (r *LocalRuntime) resolveTransforms(cfg *hooks.Config) []resolvedTransform { - if cfg == nil || len(cfg.BeforeLLMCall) == 0 || len(r.transforms) == 0 { - return nil - } - var out []resolvedTransform - seen := make(map[string]bool) - for _, h := range cfg.BeforeLLMCall { - if h.Type != hooks.HookTypeBuiltin { - continue - } - fn, ok := r.transforms[h.Command] - if !ok { - continue - } - key := transformDedupKey(h.Command, h.Args) - if seen[key] { - continue - } - seen[key] = true - out = append(out, resolvedTransform{name: h.Command, args: h.Args, fn: fn}) - } - return out } -// transformDedupKey mirrors [hooks.dedupKey]'s (command, args) shape so -// transforms are deduplicated on the same axis as the standard hook -// executor. Type is always `builtin` for transforms, so it's not part -// of the key. -func transformDedupKey(name string, args []string) string { - var b strings.Builder - b.WriteString(name) - for _, a := range args { - b.WriteByte(0) - b.WriteString(a) - } - return b.String() -} - -// applyBeforeLLMCallTransforms dispatches the agent's pre-resolved -// [MessageTransform] chain just before the model call, AFTER -// [LocalRuntime.executeBeforeLLMCallHooks] has run its gate. Transforms -// rewrite (or drop) messages but cannot abort the call — that -// responsibility lives with the gate. -// -// Returns the (possibly-rewritten) message slice. Errors from -// individual transforms are logged at warn level and the chain -// continues with the previous slice, matching the executor's "warn, -// don't break the loop" stance for non-fail-closed events. +// applyBeforeLLMCallTransforms runs every registered +// [MessageTransform] in chain order, just before the model call and +// AFTER [LocalRuntime.executeBeforeLLMCallHooks] has approved it. +// Errors from individual transforms are logged at warn level and the +// chain continues with the previous slice — a transform failure must +// never break the run loop. func (r *LocalRuntime) applyBeforeLLMCallTransforms( ctx context.Context, sess *session.Session, a *agent.Agent, msgs []chat.Message, ) []chat.Message { - transforms := r.transformsByAgent[a.Name()] - if len(transforms) == 0 { + if len(r.transforms) == 0 { return msgs } in := &hooks.Input{ @@ -209,8 +80,8 @@ func (r *LocalRuntime) applyBeforeLLMCallTransforms( HookEventName: hooks.EventBeforeLLMCall, Cwd: r.workingDir, } - for _, t := range transforms { - out, err := t.fn(ctx, in, t.args, msgs) + for _, t := range r.transforms { + out, err := t.fn(ctx, in, msgs) if err != nil { slog.Warn("Message transform failed; continuing with previous messages", "transform", t.name, "agent", a.Name(), "error", err) diff --git a/pkg/runtime/transforms_test.go b/pkg/runtime/transforms_test.go index 32cd0c094..3225e53dd 100644 --- a/pkg/runtime/transforms_test.go +++ b/pkg/runtime/transforms_test.go @@ -20,8 +20,7 @@ import ( // modalityModelStore returns a fixed [modelsdev.Model] regardless of // the requested ID. Tests configure its Modalities to exercise the // strip_unsupported_modalities transform's three branches: text-only -// (strip), image-supporting (no-op), and unknown-modality (no-op, -// no panic). +// (strip), image-supporting (no-op), and unknown-model (no-op). type modalityModelStore struct { ModelStore @@ -33,10 +32,10 @@ func (m modalityModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Mo return m.model, m.err } -// recordingMsgProvider captures the messages each model call sees so a -// test can confirm a transform actually rewrote what reached the -// provider (rather than just what the in-memory slice ended up looking -// like). +// recordingMsgProvider captures the messages each model call sees so +// a test can confirm a transform actually rewrote what reached the +// provider (rather than just what the in-memory slice ended up +// looking like). type recordingMsgProvider struct { mockProvider @@ -44,95 +43,40 @@ type recordingMsgProvider struct { } func (p *recordingMsgProvider) CreateChatCompletionStream(_ context.Context, msgs []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { - snap := append([]chat.Message{}, msgs...) - p.got = append(p.got, snap) + p.got = append(p.got, append([]chat.Message{}, msgs...)) return p.stream, nil } -// TestStripUnsupportedModalitiesTransform_TextOnlyModelDropsImages -// pins the runtime-shipped [stripUnsupportedModalitiesTransform]'s -// happy path: a text-only model receives messages with all image -// content stripped, while text content is preserved. -func TestStripUnsupportedModalitiesTransform_TextOnlyModelDropsImages(t *testing.T) { +// TestStripUnsupportedModalitiesTransform pins the three branches of +// the runtime-shipped transform: a text-only model strips images, a +// multimodal model passes them through, and an unknown model also +// passes them through (the call surfaces any modality mismatch as a +// provider error rather than panicking transform-side). +func TestStripUnsupportedModalitiesTransform(t *testing.T) { t.Parallel() - prov := &mockProvider{id: "test/text-only", stream: &mockStream{}} + prov := &mockProvider{id: "test/model", stream: &mockStream{}} a := agent.New("root", "instructions", agent.WithModel(prov)) tm := team.New(team.WithAgents(a)) - store := modalityModelStore{model: &modelsdev.Model{ - Modalities: modelsdev.Modalities{Input: []string{"text"}}, - }} - r, err := NewLocalRuntime(tm, WithModelStore(store)) - require.NoError(t, err) - - in := &hooks.Input{AgentName: "root"} - msgs := []chat.Message{ - { - Role: chat.MessageRoleUser, - MultiContent: []chat.MessagePart{ - {Type: chat.MessagePartTypeText, Text: "look at this"}, - {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, - }, - }, - } - - got, err := r.stripUnsupportedModalitiesTransform(t.Context(), in, nil, msgs) - require.NoError(t, err) - require.Len(t, got, 1) - require.Len(t, got[0].MultiContent, 1, "image part must be stripped") - assert.Equal(t, chat.MessagePartTypeText, got[0].MultiContent[0].Type) -} - -// TestStripUnsupportedModalitiesTransform_ImageModelPassThrough pins -// the no-op branch: when the model's input modalities include "image", -// messages must reach the provider unchanged. -func TestStripUnsupportedModalitiesTransform_ImageModelPassThrough(t *testing.T) { - t.Parallel() - - prov := &mockProvider{id: "test/multimodal", stream: &mockStream{}} - a := agent.New("root", "instructions", agent.WithModel(prov)) - tm := team.New(team.WithAgents(a)) - - store := modalityModelStore{model: &modelsdev.Model{ - Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}, - }} - r, err := NewLocalRuntime(tm, WithModelStore(store)) - require.NoError(t, err) - - in := &hooks.Input{AgentName: "root"} - msgs := []chat.Message{{ + imgMsg := chat.Message{ Role: chat.MessageRoleUser, MultiContent: []chat.MessagePart{ - {Type: chat.MessagePartTypeText, Text: "describe this"}, + {Type: chat.MessagePartTypeText, Text: "look at this"}, {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, }, - }} - - got, err := r.stripUnsupportedModalitiesTransform(t.Context(), in, nil, msgs) - require.NoError(t, err) - assert.Equal(t, msgs, got, "messages must reach a multimodal model untouched") -} - -// TestStripUnsupportedModalitiesTransform_UnknownModelPassThrough pins -// the safe-fallback branch: when the models.dev lookup fails (or -// returns nil), the transform returns msgs unchanged so the request -// still reaches the provider; any modality mismatch surfaces as a -// provider error rather than a transform-side panic. -func TestStripUnsupportedModalitiesTransform_UnknownModelPassThrough(t *testing.T) { - t.Parallel() - - prov := &mockProvider{id: "test/unknown", stream: &mockStream{}} - a := agent.New("root", "instructions", agent.WithModel(prov)) - tm := team.New(team.WithAgents(a)) + } cases := []struct { - name string - store modalityModelStore + name string + store modalityModelStore + wantStrip bool }{ - {name: "nil model", store: modalityModelStore{model: nil}}, - {name: "lookup error", store: modalityModelStore{err: errors.New("not found")}}, - {name: "empty modalities", store: modalityModelStore{model: &modelsdev.Model{}}}, + {name: "text-only model strips images", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text"}}}}, wantStrip: true}, + {name: "multimodal model passes through", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}}}, + {name: "nil model passes through", store: modalityModelStore{model: nil}}, + {name: "lookup error passes through", store: modalityModelStore{err: errors.New("not found")}}, + {name: "empty modalities passes through", store: modalityModelStore{model: &modelsdev.Model{}}}, } for _, tc := range cases { @@ -140,24 +84,23 @@ func TestStripUnsupportedModalitiesTransform_UnknownModelPassThrough(t *testing. r, err := NewLocalRuntime(tm, WithModelStore(tc.store)) require.NoError(t, err) - msgs := []chat.Message{{ - Role: chat.MessageRoleUser, - MultiContent: []chat.MessagePart{ - {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "x"}}, - }, - }} - got, err := r.stripUnsupportedModalitiesTransform(t.Context(), &hooks.Input{AgentName: "root"}, nil, msgs) + got, err := r.stripUnsupportedModalitiesTransform(t.Context(), + &hooks.Input{AgentName: "root"}, []chat.Message{imgMsg}) require.NoError(t, err) - assert.Equal(t, msgs, got, "unknown model must fall through unchanged") + require.Len(t, got, 1) + if tc.wantStrip { + require.Len(t, got[0].MultiContent, 1, "image part must be stripped") + assert.Equal(t, chat.MessagePartTypeText, got[0].MultiContent[0].Type) + } else { + assert.Equal(t, imgMsg, got[0], "messages must reach the model untouched") + } }) } } // TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap covers the hot -// path: an agent without any registered transforms runs no allocator, -// no slog noise, and returns the input slice as-is. The test also -// covers the "agent not in transformsByAgent" branch for agents -// constructed outside the runtime's normal flow. +// path: a runtime with no registered transforms returns the input +// slice as-is without allocating a [hooks.Input]. func TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap(t *testing.T) { t.Parallel() @@ -167,11 +110,9 @@ func TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap(t *testing.T) { r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) require.NoError(t, err) - // Drop the auto-registered strip_unsupported_modalities so we can - // observe the cheap-path behavior. + // Drop the runtime-shipped strip transform so we can observe the + // cheap-path behavior. r.transforms = nil - r.transformNames = nil - r.transformsByAgent = nil sess := session.New(session.WithUserMessage("hi")) msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} @@ -180,33 +121,23 @@ func TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap(t *testing.T) { assert.Equal(t, msgs, got) } -// TestApplyBeforeLLMCallTransforms_OrderAndArgs verifies that -// transforms registered via [WithMessageTransform] (a) auto-inject a -// before_llm_call entry on every agent, (b) run in configured order, -// and (c) receive the per-hook args from the YAML / auto-injection. -func TestApplyBeforeLLMCallTransforms_OrderAndArgs(t *testing.T) { +// TestApplyBeforeLLMCallTransforms_OrderAndChain verifies that +// transforms registered via [WithMessageTransform] run in +// registration order and feed each transform the cumulative output of +// the previous one (chain semantics, not parallel). +func TestApplyBeforeLLMCallTransforms_OrderAndChain(t *testing.T) { t.Parallel() type call struct { - name string - args []string - seen []chat.Message + name string + seenIn int } var calls []call - - tagA := func(_ context.Context, _ *hooks.Input, args []string, msgs []chat.Message) ([]chat.Message, error) { - seen := append([]chat.Message{}, msgs...) - calls = append(calls, call{name: "tag_a", args: args, seen: seen}) - out := append([]chat.Message{}, msgs...) - out = append(out, chat.Message{Role: chat.MessageRoleSystem, Content: "tag_a"}) - return out, nil - } - tagB := func(_ context.Context, _ *hooks.Input, args []string, msgs []chat.Message) ([]chat.Message, error) { - seen := append([]chat.Message{}, msgs...) - calls = append(calls, call{name: "tag_b", args: args, seen: seen}) - out := append([]chat.Message{}, msgs...) - out = append(out, chat.Message{Role: chat.MessageRoleSystem, Content: "tag_b"}) - return out, nil + tag := func(name string) MessageTransform { + return func(_ context.Context, _ *hooks.Input, msgs []chat.Message) ([]chat.Message, error) { + calls = append(calls, call{name: name, seenIn: len(msgs)}) + return append(msgs, chat.Message{Role: chat.MessageRoleSystem, Content: name}), nil + } } prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} @@ -214,53 +145,41 @@ func TestApplyBeforeLLMCallTransforms_OrderAndArgs(t *testing.T) { tm := team.New(team.WithAgents(a)) r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{}), - WithMessageTransform("tag_a", tagA), - WithMessageTransform("tag_b", tagB), + WithMessageTransform("tag_a", tag("tag_a")), + WithMessageTransform("tag_b", tag("tag_b")), ) require.NoError(t, err) sess := session.New(session.WithUserMessage("hi")) - msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} - - got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, msgs) + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, + []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}}) - // The two registered tag transforms each fire exactly once. - // (The runtime-shipped strip transform also runs, but it doesn't - // append to `calls` since it's a different function.) require.Len(t, calls, 2, "expected tag_a + tag_b to fire exactly once each") - - // Registration order must be preserved: tag_a was registered first, - // so it must be invoked first; tag_b second. assert.Equal(t, "tag_a", calls[0].name, "transforms must run in registration order") - assert.Equal(t, "tag_b", calls[1].name, "transforms must run in registration order") - - // Cumulative semantics: the second transform must have observed the - // first transform's appended message. - assert.Greater(t, len(calls[1].seen), len(calls[0].seen), + assert.Equal(t, "tag_b", calls[1].name) + assert.Greater(t, calls[1].seenIn, calls[0].seenIn, "tag_b must see tag_a's appended message (chain semantics, not parallel)") - // The final slice must contain both tags. - var finalContent []string + var contents []string for _, m := range got { - finalContent = append(finalContent, m.Content) + contents = append(contents, m.Content) } - assert.Contains(t, finalContent, "tag_a") - assert.Contains(t, finalContent, "tag_b") + assert.Contains(t, contents, "tag_a") + assert.Contains(t, contents, "tag_b") } // TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed pins the // fail-soft contract: a transform that returns an error must NOT -// break the run loop; the previous slice continues through the chain. +// break the run loop; the previous slice continues through the +// chain. func TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed(t *testing.T) { t.Parallel() - failing := func(_ context.Context, _ *hooks.Input, _ []string, _ []chat.Message) ([]chat.Message, error) { + failing := func(_ context.Context, _ *hooks.Input, _ []chat.Message) ([]chat.Message, error) { return nil, errors.New("boom") } - tag := func(_ context.Context, _ *hooks.Input, _ []string, msgs []chat.Message) ([]chat.Message, error) { - out := append([]chat.Message{}, msgs...) - out = append(out, chat.Message{Role: chat.MessageRoleSystem, Content: "after_failure"}) - return out, nil + tag := func(_ context.Context, _ *hooks.Input, msgs []chat.Message) ([]chat.Message, error) { + return append(msgs, chat.Message{Role: chat.MessageRoleSystem, Content: "after_failure"}), nil } prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} @@ -274,12 +193,9 @@ func TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed(t *testing.T) { require.NoError(t, err) sess := session.New(session.WithUserMessage("hi")) - msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, + []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}}) - got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, msgs) - - // The "tag" transform must have run despite the failing one - // erroring out, and its output must be present. var contents []string for _, m := range got { contents = append(contents, m.Content) @@ -288,10 +204,10 @@ func TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed(t *testing.T) { "a transform error must not abort the chain") } -// TestRunStream_StripsImagesForTextOnlyModel confirms the inline -// strip in runStreamLoop has been replaced end-to-end: messages -// reaching the provider must no longer carry image parts when the -// agent's model is text-only. +// TestRunStream_StripsImagesForTextOnlyModel is the end-to-end smoke +// test confirming the inline strip in runStreamLoop has been +// replaced: messages reaching the provider must no longer carry +// image parts when the agent's model is text-only. func TestRunStream_StripsImagesForTextOnlyModel(t *testing.T) { t.Parallel() @@ -309,7 +225,6 @@ func TestRunStream_StripsImagesForTextOnlyModel(t *testing.T) { sess := session.New() sess.AddMessage(session.UserMessage("")) - // Replace the empty user message with a multi-part one carrying an image. last := &sess.Messages[len(sess.Messages)-1] last.Message.Message.MultiContent = []chat.MessagePart{ {Type: chat.MessagePartTypeText, Text: "describe"}, @@ -329,100 +244,16 @@ func TestRunStream_StripsImagesForTextOnlyModel(t *testing.T) { } } -// TestApplyMessageTransformDefaults_NoTransformsPreservesNil keeps the -// "nil cfg may stay nil when there are no defaults to add" contract, -// matching [applyCacheDefault]'s shape so [buildHooksExecutors] can -// continue to skip executor construction for agents with no hooks. -func TestApplyMessageTransformDefaults_NoTransformsPreservesNil(t *testing.T) { - t.Parallel() - - prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} - a := agent.New("root", "instructions", agent.WithModel(prov)) - tm := team.New(team.WithAgents(a)) - r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) - require.NoError(t, err) - - r.transforms = nil // simulate "no transforms registered" - r.transformNames = nil - got := r.applyMessageTransformDefaults(nil) - assert.Nil(t, got, "no transforms registered must preserve a nil cfg") -} - -// TestResolveTransforms_DedupsByCommandAndArgs guards against double -// invocation when an agent's user-authored YAML already lists a -// builtin that the runtime auto-injects on top. -func TestResolveTransforms_DedupsByCommandAndArgs(t *testing.T) { - t.Parallel() - - prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} - a := agent.New("root", "instructions", agent.WithModel(prov)) - tm := team.New(team.WithAgents(a)) - r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) - require.NoError(t, err) - - cfg := &hooks.Config{ - BeforeLLMCall: []hooks.Hook{ - {Type: hooks.HookTypeBuiltin, Command: BuiltinStripUnsupportedModalities}, - {Type: hooks.HookTypeBuiltin, Command: BuiltinStripUnsupportedModalities}, - {Type: hooks.HookTypeBuiltin, Command: BuiltinStripUnsupportedModalities, Args: []string{"foo"}}, - }, - } - - got := r.resolveTransforms(cfg) - require.Len(t, got, 2, "duplicate (name, args) must collapse to one") - assert.Equal(t, BuiltinStripUnsupportedModalities, got[0].name) - assert.Empty(t, got[0].args) - assert.Equal(t, []string{"foo"}, got[1].args, "differing args must NOT be deduplicated") -} - -// TestRegisterMessageTransform_ShimAvoidsExecutorErrors confirms the -// shim wired by [registerMessageTransform]: a hooks.Executor.Dispatch -// for a registered-transform builtin must succeed (Allowed=true, -// Result is a no-op) instead of failing with "no builtin hook -// registered as ...". -func TestRegisterMessageTransform_ShimAvoidsExecutorErrors(t *testing.T) { - t.Parallel() - - prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} - a := agent.New("root", "instructions", agent.WithModel(prov)) - tm := team.New(team.WithAgents(a)) - - called := 0 - tag := func(_ context.Context, _ *hooks.Input, _ []string, msgs []chat.Message) ([]chat.Message, error) { - called++ - return msgs, nil - } - r, err := NewLocalRuntime(tm, - WithModelStore(mockModelStore{}), - WithMessageTransform("tag", tag), - ) - require.NoError(t, err) - - exec := r.hooksExec(a) - require.NotNil(t, exec) - - res, err := exec.Dispatch(t.Context(), hooks.EventBeforeLLMCall, &hooks.Input{ - SessionID: "session", AgentName: "root", - }) - require.NoError(t, err, "executor must not error on a transform-only builtin") - assert.True(t, res.Allowed, "shim must report success") - // The transform itself isn't invoked through the executor — only via - // applyBeforeLLMCallTransforms — so `called` stays 0 here. - assert.Equal(t, 0, called, "executor path must NOT invoke the transform body") -} - -// TestRunStream_TransformErrorDoesNotBreakRun is an integration smoke -// test confirming end-to-end: a transform that returns an error must -// not prevent the model from being called; the run completes -// normally and the messages reaching the provider are the pre-error -// snapshot. +// TestRunStream_TransformErrorDoesNotBreakRun is the end-to-end smoke +// test confirming the fail-soft contract: a transform error must not +// prevent the model from being called and the run from completing. func TestRunStream_TransformErrorDoesNotBreakRun(t *testing.T) { t.Parallel() stream := newStreamBuilder().AddContent("ok").AddStopWithUsage(1, 1).Build() prov := &mockProvider{id: "test/mock-model", stream: stream} - failing := func(_ context.Context, _ *hooks.Input, _ []string, _ []chat.Message) ([]chat.Message, error) { + failing := func(_ context.Context, _ *hooks.Input, _ []chat.Message) ([]chat.Message, error) { return nil, errors.New("boom") } @@ -446,8 +277,8 @@ func TestRunStream_TransformErrorDoesNotBreakRun(t *testing.T) { } // TestWithMessageTransform_RejectsEmptyAndNil pins the input -// validation: empty name or nil fn must be silently ignored (matching -// the no-error shape of other Opts), with a slog warning. +// validation: empty name or nil fn must be silently ignored +// (matching the no-error shape of other Opts). func TestWithMessageTransform_RejectsEmptyAndNil(t *testing.T) { t.Parallel() @@ -457,15 +288,14 @@ func TestWithMessageTransform_RejectsEmptyAndNil(t *testing.T) { r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{}), - WithMessageTransform("", func(_ context.Context, _ *hooks.Input, _ []string, msgs []chat.Message) ([]chat.Message, error) { + WithMessageTransform("", func(_ context.Context, _ *hooks.Input, msgs []chat.Message) ([]chat.Message, error) { return msgs, nil }), WithMessageTransform("nilfn", nil), ) require.NoError(t, err, "WithMessageTransform must not surface a constructor error") - // Only the runtime-shipped strip transform should be in the table. + // Only the runtime-shipped strip transform should remain. require.Len(t, r.transforms, 1, "invalid transforms must be silently ignored") - _, ok := r.transforms[BuiltinStripUnsupportedModalities] - assert.True(t, ok) + assert.Equal(t, BuiltinStripUnsupportedModalities, r.transforms[0].name) } From a5adb9ecab5c91e8d02bacdfe1f54014255362a6 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Tue, 28 Apr 2026 11:30:35 +0200 Subject: [PATCH 3/3] fix strip transform reading wrong model in alloy / per-tool override mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The transform was calling agent.Model() which re-randomizes alloy picks and ignores per-tool overrides — it could end up consulting modalities for a different model than the one the loop was actually about to call. Pass the resolved modelID through hooks.Input.ModelID instead. --- pkg/hooks/types.go | 11 ++++ pkg/runtime/hooks.go | 9 +++- pkg/runtime/loop.go | 9 ++-- pkg/runtime/runtime.go | 2 +- pkg/runtime/strip_modalities.go | 31 ++++++----- pkg/runtime/transforms.go | 8 +++ pkg/runtime/transforms_test.go | 96 +++++++++++++++++++++++++++------ 7 files changed, 131 insertions(+), 35 deletions(-) diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go index ef97e9232..ae9fbda9d 100644 --- a/pkg/hooks/types.go +++ b/pkg/hooks/types.go @@ -148,6 +148,17 @@ type Input struct { // closure (e.g. response cache). AgentName string `json:"agent_name,omitempty"` + // ModelID identifies the model the runtime is about to call (for + // before_llm_call) or just called (for after_llm_call), in the + // canonical "/" form expected by + // [modelsdev.Store.GetModel]. Populated from the loop's resolved + // model so it reflects per-tool model overrides and alloy-mode + // random selection — do NOT call Agent.Model() from a hook to + // recompute it, since alloy mode would re-randomize and a per-tool + // override would be invisible. Empty for events that aren't + // model-call-scoped. + ModelID string `json:"model_id,omitempty"` + // LastUserMessage is the text content of the latest user message in // the session at dispatch time. Populated for events that respond to // a user turn (stop, after_llm_call). Empty for events that aren't diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 0a1fad321..097f155d1 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -284,9 +284,16 @@ func (r *LocalRuntime) executeOnToolApprovalDecisionHooks( // / exit 2) stops the run loop — see [hooks.EventBeforeLLMCall] for // the contract. Hooks that just want to contribute system messages // should target turn_start instead. -func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent) (stop bool, message string) { +// +// modelID is the canonical model identifier the loop has just +// resolved (after per-tool overrides and alloy-mode selection); it's +// surfaced to hooks via [hooks.Input.ModelID] so handlers don't need +// to recompute it from the agent. +func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, modelID string) (stop bool, message string) { result := r.dispatchHook(ctx, a, hooks.EventBeforeLLMCall, &hooks.Input{ SessionID: sess.ID, + AgentName: a.Name(), + ModelID: modelID, }, nil) if result == nil || result.Allowed { return false, "" diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 35e1ab82a..7c1120e55 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -378,7 +378,7 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, // before_llm_call hooks fire just before the model is invoked. // A terminating verdict (e.g. from the max_iterations builtin) // stops the run loop here, before any tokens are spent. - if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a); stop { + if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID); stop { slog.Warn("before_llm_call hook signalled run termination", "agent", a.Name(), "session_id", sess.ID, "reason", msg) r.emitHookDrivenShutdown(ctx, a, sess, msg, events) @@ -390,8 +390,11 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, // strip_unsupported_modalities for text-only models, plus any // embedder-supplied redactor / scrubber registered via // WithMessageTransform). Runs after the gate so a transform - // failure cannot waste the gate's allow verdict. - messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, messages) + // failure cannot waste the gate's allow verdict. modelID is + // passed explicitly so transforms see the actual model the + // loop chose (per-tool override + alloy-mode selection), + // not whatever a fresh agent.Model() call would re-randomize. + messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, modelID, messages) // Try primary model with fallback chain if configured res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 19b9c0c0b..323c44d5e 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -163,7 +163,7 @@ type LocalRuntime struct { // transforms is the runtime's [MessageTransform] chain, applied to // every LLM call in registration order. Populated by - // [NewLocalRuntime] (for the runtime-shipped strippers) and by + // [NewLocalRuntime] (for the runtime-shipped strip transform) and by // [WithMessageTransform] (for embedder-supplied transforms). // Read-only after construction. transforms []registeredTransform diff --git a/pkg/runtime/strip_modalities.go b/pkg/runtime/strip_modalities.go index 63c679659..611b568ae 100644 --- a/pkg/runtime/strip_modalities.go +++ b/pkg/runtime/strip_modalities.go @@ -31,33 +31,36 @@ const BuiltinStripUnsupportedModalities = "strip_unsupported_modalities" const modalityImage = "image" // stripUnsupportedModalitiesTransform is the [MessageTransform] -// registered under [BuiltinStripUnsupportedModalities]. It resolves -// the agent (and therefore its current model) from -// [hooks.Input.AgentName], looks up the model's input modalities, and -// applies [stripImageContent] when image is missing from the list. +// registered under [BuiltinStripUnsupportedModalities]. It looks up +// the model definition from [hooks.Input.ModelID] (populated by the +// runtime with the actual model the loop chose, including per-tool +// overrides and alloy-mode selection) and applies +// [stripImageContent] when image is missing from the model's input +// modalities. // // The transform is a no-op for every "we don't know enough to act" -// case (missing agent, missing model, models.dev miss, empty -// modalities, image already supported): erring on the side of "send -// the messages as-is" matches the previous inline behavior in -// runStreamLoop, where an unknown model also fell through. +// case (missing ModelID, models.dev miss, empty modalities, image +// already supported): erring on the side of "send the messages +// as-is" matches the previous inline behavior in runStreamLoop, +// where an unknown model also fell through. Each fall-through emits +// a Debug log so operators can tell strip_unsupported_modalities +// from a transform that's silently inactive. func (r *LocalRuntime) stripUnsupportedModalitiesTransform( ctx context.Context, in *hooks.Input, msgs []chat.Message, ) ([]chat.Message, error) { - if in == nil || in.AgentName == "" { + if in == nil || in.ModelID == "" { + slog.Debug("strip_unsupported_modalities: skipping, no ModelID on input") return msgs, nil } - a, err := r.team.Agent(in.AgentName) - if err != nil || a == nil || a.Model() == nil { - return msgs, nil - } - m, err := r.modelsStore.GetModel(ctx, a.Model().ID()) + m, err := r.modelsStore.GetModel(ctx, in.ModelID) if err != nil || m == nil { // Unknown model: keep the previous (inline) behavior of // passing messages through untouched. The model call will // surface any modality mismatch as a provider error. + slog.Debug("strip_unsupported_modalities: skipping, model definition unavailable", + "model_id", in.ModelID, "error", err) return msgs, nil } if len(m.Modalities.Input) == 0 || slices.Contains(m.Modalities.Input, modalityImage) { diff --git a/pkg/runtime/transforms.go b/pkg/runtime/transforms.go index 81a468f0e..fd14d87f1 100644 --- a/pkg/runtime/transforms.go +++ b/pkg/runtime/transforms.go @@ -65,10 +65,17 @@ func WithMessageTransform(name string, fn MessageTransform) Opt { // Errors from individual transforms are logged at warn level and the // chain continues with the previous slice — a transform failure must // never break the run loop. +// +// modelID is the canonical model identifier the loop has just +// resolved (after per-tool overrides and alloy-mode selection); +// transforms read it via [hooks.Input.ModelID]. Calling +// agent.Model() from a transform would re-randomize the alloy pick +// and miss the per-tool override. func (r *LocalRuntime) applyBeforeLLMCallTransforms( ctx context.Context, sess *session.Session, a *agent.Agent, + modelID string, msgs []chat.Message, ) []chat.Message { if len(r.transforms) == 0 { @@ -77,6 +84,7 @@ func (r *LocalRuntime) applyBeforeLLMCallTransforms( in := &hooks.Input{ SessionID: sess.ID, AgentName: a.Name(), + ModelID: modelID, HookEventName: hooks.EventBeforeLLMCall, Cwd: r.workingDir, } diff --git a/pkg/runtime/transforms_test.go b/pkg/runtime/transforms_test.go index 3225e53dd..fc6c193ce 100644 --- a/pkg/runtime/transforms_test.go +++ b/pkg/runtime/transforms_test.go @@ -1,3 +1,4 @@ +// TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap covers the hot package runtime import ( @@ -32,6 +33,20 @@ func (m modalityModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Mo return m.model, m.err } +// modalityByIDStore returns a different [modelsdev.Model] depending +// on the requested ID, letting tests prove the transform consulted +// the right ID (via [hooks.Input.ModelID]) rather than recomputing +// it from the agent. +type modalityByIDStore struct { + ModelStore + + models map[string]*modelsdev.Model +} + +func (m modalityByIDStore) GetModel(_ context.Context, id string) (*modelsdev.Model, error) { + return m.models[id], nil +} + // recordingMsgProvider captures the messages each model call sees so // a test can confirm a transform actually rewrote what reached the // provider (rather than just what the in-memory slice ended up @@ -70,13 +85,15 @@ func TestStripUnsupportedModalitiesTransform(t *testing.T) { cases := []struct { name string store modalityModelStore + modelID string wantStrip bool }{ - {name: "text-only model strips images", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text"}}}}, wantStrip: true}, - {name: "multimodal model passes through", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}}}, - {name: "nil model passes through", store: modalityModelStore{model: nil}}, - {name: "lookup error passes through", store: modalityModelStore{err: errors.New("not found")}}, - {name: "empty modalities passes through", store: modalityModelStore{model: &modelsdev.Model{}}}, + {name: "text-only model strips images", modelID: "test/text", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text"}}}}, wantStrip: true}, + {name: "multimodal model passes through", modelID: "test/multimodal", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}}}, + {name: "nil model passes through", modelID: "test/unknown", store: modalityModelStore{model: nil}}, + {name: "lookup error passes through", modelID: "test/unknown", store: modalityModelStore{err: errors.New("not found")}}, + {name: "empty modalities passes through", modelID: "test/empty", store: modalityModelStore{model: &modelsdev.Model{}}}, + {name: "empty ModelID passes through", modelID: "", store: modalityModelStore{model: &modelsdev.Model{Modalities: modelsdev.Modalities{Input: []string{"text"}}}}}, } for _, tc := range cases { @@ -85,7 +102,7 @@ func TestStripUnsupportedModalitiesTransform(t *testing.T) { require.NoError(t, err) got, err := r.stripUnsupportedModalitiesTransform(t.Context(), - &hooks.Input{AgentName: "root"}, []chat.Message{imgMsg}) + &hooks.Input{ModelID: tc.modelID}, []chat.Message{imgMsg}) require.NoError(t, err) require.Len(t, got, 1) if tc.wantStrip { @@ -98,7 +115,56 @@ func TestStripUnsupportedModalitiesTransform(t *testing.T) { } } -// TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap covers the hot +// TestStripUnsupportedModalitiesTransform_UsesInputModelID pins the +// fix for an alloy-mode / per-tool-override correctness bug: the +// transform must trust [hooks.Input.ModelID] (populated by the loop +// with the model it actually picked) and NOT recompute the model by +// calling agent.Model() — doing so would re-randomize the alloy +// pick and miss any per-tool override the loop had applied. +// +// The test wires a store that reports text-only for one ID and +// multimodal for another. Querying by the text-only ID must strip; +// querying by the multimodal ID must pass through. The agent's own +// model (its pool) is irrelevant — it's never consulted. +func TestStripUnsupportedModalitiesTransform_UsesInputModelID(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/agent-pool-model", stream: &mockStream{}} + a := agent.New("root", "instructions", agent.WithModel(prov)) + tm := team.New(team.WithAgents(a)) + + store := modalityByIDStore{models: map[string]*modelsdev.Model{ + "text/only": {Modalities: modelsdev.Modalities{Input: []string{"text"}}}, + "multi/modal": {Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}, + "test/agent-pool-model": {Modalities: modelsdev.Modalities{Input: []string{"text", "image"}}}, + }} + r, err := NewLocalRuntime(tm, WithModelStore(store)) + require.NoError(t, err) + + imgMsg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "describe"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + }, + } + + // ModelID = text-only — strip must happen even though the agent's + // pool model is multimodal. + stripped, err := r.stripUnsupportedModalitiesTransform(t.Context(), + &hooks.Input{ModelID: "text/only"}, []chat.Message{imgMsg}) + require.NoError(t, err) + require.Len(t, stripped[0].MultiContent, 1, "image must be stripped when ModelID is text-only") + assert.Equal(t, chat.MessagePartTypeText, stripped[0].MultiContent[0].Type) + + // ModelID = multimodal — strip must NOT happen even if some other + // model in scope is text-only. Proves the lookup keys off ModelID. + passed, err := r.stripUnsupportedModalitiesTransform(t.Context(), + &hooks.Input{ModelID: "multi/modal"}, []chat.Message{imgMsg}) + require.NoError(t, err) + assert.Equal(t, imgMsg, passed[0], "images must reach a multimodal ModelID untouched") +} + // path: a runtime with no registered transforms returns the input // slice as-is without allocating a [hooks.Input]. func TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap(t *testing.T) { @@ -117,7 +183,7 @@ func TestApplyBeforeLLMCallTransforms_NoTransformsIsCheap(t *testing.T) { sess := session.New(session.WithUserMessage("hi")) msgs := []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}} - got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, msgs) + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, "", msgs) assert.Equal(t, msgs, got) } @@ -151,7 +217,7 @@ func TestApplyBeforeLLMCallTransforms_OrderAndChain(t *testing.T) { require.NoError(t, err) sess := session.New(session.WithUserMessage("hi")) - got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, "test/mock-model", []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}}) require.Len(t, calls, 2, "expected tag_a + tag_b to fire exactly once each") @@ -193,7 +259,7 @@ func TestApplyBeforeLLMCallTransforms_ErrorsAreSwallowed(t *testing.T) { require.NoError(t, err) sess := session.New(session.WithUserMessage("hi")) - got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, + got := r.applyBeforeLLMCallTransforms(t.Context(), sess, a, "test/mock-model", []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}}) var contents []string @@ -224,12 +290,10 @@ func TestRunStream_StripsImagesForTextOnlyModel(t *testing.T) { require.NoError(t, err) sess := session.New() - sess.AddMessage(session.UserMessage("")) - last := &sess.Messages[len(sess.Messages)-1] - last.Message.Message.MultiContent = []chat.MessagePart{ - {Type: chat.MessagePartTypeText, Text: "describe"}, - {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, - } + sess.AddMessage(session.UserMessage("", + chat.MessagePart{Type: chat.MessagePartTypeText, Text: "describe"}, + chat.MessagePart{Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,abc"}}, + )) for range r.RunStream(t.Context(), sess) { // drain — only the recorded provider state matters