Skip to content

Commit e59e163

Browse files
authored
Merge pull request #2573 from dgageot/board/extracting-runtime-features-into-builtin-d52e607b
runtime: extract image-stripping into a registered MessageTransform
2 parents 6a124b2 + a5adb9e commit e59e163

8 files changed

Lines changed: 629 additions & 46 deletions

File tree

pkg/hooks/types.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,17 @@ type Input struct {
148148
// closure (e.g. response cache).
149149
AgentName string `json:"agent_name,omitempty"`
150150

151+
// ModelID identifies the model the runtime is about to call (for
152+
// before_llm_call) or just called (for after_llm_call), in the
153+
// canonical "<provider>/<model>" form expected by
154+
// [modelsdev.Store.GetModel]. Populated from the loop's resolved
155+
// model so it reflects per-tool model overrides and alloy-mode
156+
// random selection — do NOT call Agent.Model() from a hook to
157+
// recompute it, since alloy mode would re-randomize and a per-tool
158+
// override would be invisible. Empty for events that aren't
159+
// model-call-scoped.
160+
ModelID string `json:"model_id,omitempty"`
161+
151162
// LastUserMessage is the text content of the latest user message in
152163
// the session at dispatch time. Populated for events that respond to
153164
// a user turn (stop, after_llm_call). Empty for events that aren't

pkg/runtime/hooks.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,16 @@ func (r *LocalRuntime) executeOnToolApprovalDecisionHooks(
284284
// / exit 2) stops the run loop — see [hooks.EventBeforeLLMCall] for
285285
// the contract. Hooks that just want to contribute system messages
286286
// should target turn_start instead.
287-
func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent) (stop bool, message string) {
287+
//
288+
// modelID is the canonical model identifier the loop has just
289+
// resolved (after per-tool overrides and alloy-mode selection); it's
290+
// surfaced to hooks via [hooks.Input.ModelID] so handlers don't need
291+
// to recompute it from the agent.
292+
func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, modelID string) (stop bool, message string) {
288293
result := r.dispatchHook(ctx, a, hooks.EventBeforeLLMCall, &hooks.Input{
289294
SessionID: sess.ID,
295+
AgentName: a.Name(),
296+
ModelID: modelID,
290297
}, nil)
291298
if result == nil || result.Allowed {
292299
return false, ""

pkg/runtime/loop.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -375,24 +375,27 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session,
375375
messages := sess.GetMessages(a, slices.Concat(sessionStartMsgs, userPromptMsgs, turnStartMsgs)...)
376376
slog.Debug("Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages))
377377

378-
// Strip image content from messages if the model doesn't support image input.
379-
// This prevents API errors when conversation history contains images (e.g. from
380-
// tool results or user attachments) but the current model is text-only.
381-
if m != nil && len(m.Modalities.Input) > 0 && !slices.Contains(m.Modalities.Input, "image") {
382-
messages = stripImageContent(messages)
383-
}
384-
385378
// before_llm_call hooks fire just before the model is invoked.
386379
// A terminating verdict (e.g. from the max_iterations builtin)
387380
// stops the run loop here, before any tokens are spent.
388-
if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a); stop {
381+
if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID); stop {
389382
slog.Warn("before_llm_call hook signalled run termination",
390383
"agent", a.Name(), "session_id", sess.ID, "reason", msg)
391384
r.emitHookDrivenShutdown(ctx, a, sess, msg, events)
392385
streamSpan.End()
393386
return
394387
}
395388

389+
// Apply registered before_llm_call message transforms (e.g.
390+
// strip_unsupported_modalities for text-only models, plus any
391+
// embedder-supplied redactor / scrubber registered via
392+
// WithMessageTransform). Runs after the gate so a rejected call
393+
// costs no transform work; transform errors are only logged. modelID is
394+
// passed explicitly so transforms see the actual model the
395+
// loop chose (per-tool override + alloy-mode selection),
396+
// not whatever a fresh agent.Model() call would re-randomize.
397+
messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, modelID, messages)
398+
396399
// Try primary model with fallback chain if configured
397400
res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events)
398401
if err != nil {

pkg/runtime/runtime.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ type LocalRuntime struct {
161161
// construction, so no locking is needed.
162162
hooksExecByAgent map[string]*hooks.Executor
163163

164+
// transforms is the runtime's [MessageTransform] chain, applied to
165+
// every LLM call in registration order. Populated by
166+
// [NewLocalRuntime] (for the runtime-shipped strip transform) and by
167+
// [WithMessageTransform] (for embedder-supplied transforms).
168+
// Read-only after construction.
169+
transforms []registeredTransform
170+
164171
fallback *fallbackExecutor
165172

166173
// observers receive every event the runtime produces, in
@@ -392,6 +399,17 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
392399
return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err)
393400
}
394401

402+
// strip_unsupported_modalities is the runtime-shipped
403+
// before_llm_call message transform that drops image content from
404+
// messages when the agent's model is text-only. Like
405+
// cache_response it captures the runtime closure (to look up the
406+
// model definition for Input.ModelID in the models store) and is therefore
407+
// registered here rather than in pkg/hooks/builtins.
408+
r.transforms = append(r.transforms, registeredTransform{
409+
name: BuiltinStripUnsupportedModalities,
410+
fn: r.stripUnsupportedModalitiesTransform,
411+
})
412+
395413
for _, opt := range opts {
396414
opt(r)
397415
}

pkg/runtime/streaming.go

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"log/slog"
98
"strings"
109

1110
"github.com/docker/docker-agent/pkg/agent"
@@ -236,39 +235,3 @@ func handleStream(ctx context.Context, stream chat.MessageStream, a *agent.Agent
236235
Usage: messageUsage,
237236
}, nil
238237
}
239-
240-
// stripImageContent returns a copy of messages with all image-related content
241-
// removed. This is used when the target model doesn't support image input to
242-
// prevent API errors. Text content is preserved; image parts in MultiContent
243-
// are filtered out, and file attachments with image MIME types are dropped.
244-
func stripImageContent(messages []chat.Message) []chat.Message {
245-
result := make([]chat.Message, len(messages))
246-
for i, msg := range messages {
247-
result[i] = msg
248-
249-
if len(msg.MultiContent) == 0 {
250-
continue
251-
}
252-
253-
var filtered []chat.MessagePart
254-
for _, part := range msg.MultiContent {
255-
switch part.Type {
256-
case chat.MessagePartTypeImageURL:
257-
// Drop image URL parts entirely.
258-
continue
259-
case chat.MessagePartTypeFile:
260-
// Drop file parts that are images.
261-
if part.File != nil && chat.IsImageMimeType(part.File.MimeType) {
262-
continue
263-
}
264-
}
265-
filtered = append(filtered, part)
266-
}
267-
268-
if len(filtered) != len(msg.MultiContent) {
269-
result[i].MultiContent = filtered
270-
slog.Debug("Stripped image content from message", "role", msg.Role, "original_parts", len(msg.MultiContent), "remaining_parts", len(filtered))
271-
}
272-
}
273-
return result
274-
}

pkg/runtime/strip_modalities.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"slices"
7+
8+
"github.com/docker/docker-agent/pkg/chat"
9+
"github.com/docker/docker-agent/pkg/hooks"
10+
)
11+
12+
// BuiltinStripUnsupportedModalities is the name of the runtime-shipped
13+
// before_llm_call message transform that drops image content from the
14+
// outgoing messages when the agent's current model doesn't list image
15+
// in its input modalities. It's the runtime-shipped peer of
16+
// [BuiltinCacheResponse] (a stop hook) — the constant exists mostly
17+
// for log filtering and diagnostics.
18+
//
19+
// Sending images to a text-only model produces hard provider errors
20+
// (HTTP 400 from OpenAI, "image input is not supported" from
21+
// Anthropic text variants, etc.); promoting the strip into a
22+
// registered transform replaces an inline branch in runStreamLoop and
23+
// opens the door to a family of message-mutating transforms
24+
// (redactors, scrubbers, ...).
25+
const BuiltinStripUnsupportedModalities = "strip_unsupported_modalities"
26+
27+
// modalityImage is the canonical models.dev modality name for image
28+
// input. A constant instead of a literal so a typo trips a compile
29+
// error and the contract with [modelsdev.Modalities.Input] is
30+
// discoverable from the runtime side.
31+
const modalityImage = "image"
32+
33+
// stripUnsupportedModalitiesTransform is the [MessageTransform]
34+
// registered under [BuiltinStripUnsupportedModalities]. It looks up
35+
// the model definition from [hooks.Input.ModelID] (populated by the
36+
// runtime with the actual model the loop chose, including per-tool
37+
// overrides and alloy-mode selection) and applies
38+
// [stripImageContent] when image is missing from the model's input
39+
// modalities.
40+
//
41+
// The transform is a no-op for every "we don't know enough to act"
42+
// case (missing ModelID, models.dev miss, empty modalities, image
43+
// already supported): erring on the side of "send the messages
44+
// as-is" matches the previous inline behavior in runStreamLoop,
45+
// where an unknown model also fell through. The missing-ModelID and
46+
// models.dev-miss fall-throughs emit a Debug log so operators can see
47+
// why the transform was inactive; the other two cases return silently.
48+
func (r *LocalRuntime) stripUnsupportedModalitiesTransform(
49+
ctx context.Context,
50+
in *hooks.Input,
51+
msgs []chat.Message,
52+
) ([]chat.Message, error) {
53+
if in == nil || in.ModelID == "" {
54+
slog.Debug("strip_unsupported_modalities: skipping, no ModelID on input")
55+
return msgs, nil
56+
}
57+
m, err := r.modelsStore.GetModel(ctx, in.ModelID)
58+
if err != nil || m == nil {
59+
// Unknown model: keep the previous (inline) behavior of
60+
// passing messages through untouched. The model call will
61+
// surface any modality mismatch as a provider error.
62+
slog.Debug("strip_unsupported_modalities: skipping, model definition unavailable",
63+
"model_id", in.ModelID, "error", err)
64+
return msgs, nil
65+
}
66+
if len(m.Modalities.Input) == 0 || slices.Contains(m.Modalities.Input, modalityImage) {
67+
return msgs, nil
68+
}
69+
return stripImageContent(msgs), nil
70+
}
71+
72+
// stripImageContent returns a copy of messages with all image-related
73+
// content removed. Text content is preserved; image parts in
74+
// [chat.Message.MultiContent] are filtered out, and file attachments
75+
// with image MIME types are dropped.
76+
//
77+
// Lives next to [stripUnsupportedModalitiesTransform] (rather than in
78+
// streaming.go where it originated) so the builtin's storage,
79+
// transform, and helper are co-located. Kept as an unexported helper
80+
// because the only legitimate caller is the transform itself — direct
81+
// use bypasses the modality check.
82+
func stripImageContent(messages []chat.Message) []chat.Message {
83+
result := make([]chat.Message, len(messages))
84+
for i, msg := range messages {
85+
result[i] = msg
86+
87+
if len(msg.MultiContent) == 0 {
88+
continue
89+
}
90+
91+
var filtered []chat.MessagePart
92+
for _, part := range msg.MultiContent {
93+
switch part.Type {
94+
case chat.MessagePartTypeImageURL:
95+
// Drop image URL parts entirely.
96+
continue
97+
case chat.MessagePartTypeFile:
98+
// Drop file parts that are images.
99+
if part.File != nil && chat.IsImageMimeType(part.File.MimeType) {
100+
continue
101+
}
102+
}
103+
filtered = append(filtered, part)
104+
}
105+
106+
if len(filtered) != len(msg.MultiContent) {
107+
result[i].MultiContent = filtered
108+
slog.Debug("Stripped image content from message",
109+
"role", msg.Role,
110+
"original_parts", len(msg.MultiContent),
111+
"remaining_parts", len(filtered))
112+
}
113+
}
114+
return result
115+
}

pkg/runtime/transforms.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
7+
"github.com/docker/docker-agent/pkg/agent"
8+
"github.com/docker/docker-agent/pkg/chat"
9+
"github.com/docker/docker-agent/pkg/hooks"
10+
"github.com/docker/docker-agent/pkg/session"
11+
)
12+
13+
// MessageTransform is the in-process-only handler signature for a
14+
// before_llm_call transform that rewrites the chat messages about to
15+
// be sent to the model. It receives the full message slice in chain
16+
// order and returns the (possibly-rewritten) replacement.
17+
//
18+
// Transforms are intentionally a runtime-private contract: the cost of
19+
// JSON-roundtripping a full conversation through the cross-process
20+
// hook protocol would be prohibitive, so command and model hooks
21+
// cannot rewrite messages. Embedders register transforms via
22+
// [WithMessageTransform]; the runtime ships
23+
// [BuiltinStripUnsupportedModalities] out of the box.
24+
//
25+
// Transforms run AFTER the standard before_llm_call gate (see
26+
// [LocalRuntime.executeBeforeLLMCallHooks]) — a hook that wants to
27+
// abort the call should target the gate, not a transform.
28+
//
29+
// Returning a non-nil error logs a warning and falls through to the
30+
// previous message slice; a transform failure must never break the
31+
// run loop.
32+
type MessageTransform func(ctx context.Context, in *hooks.Input, msgs []chat.Message) ([]chat.Message, error)
33+
34+
// registeredTransform pairs a [MessageTransform] with the name it was
35+
// registered under. The name is purely diagnostic — it shows up in
36+
// slog records when a transform errors out — so re-registering the
37+
// same name simply appends another entry without any de-duplication.
38+
type registeredTransform struct {
39+
name string
40+
fn MessageTransform
41+
}
42+
43+
// WithMessageTransform registers a [MessageTransform] under name so
44+
// it is applied to every LLM call, in registration order, after the
45+
// before_llm_call gate. Transforms are runtime-global: per-agent
46+
// scoping (if needed) lives in the transform body, where
47+
// [hooks.Input.AgentName] is available — the runtime-shipped strip
48+
// transform does the analogous per-model scoping via [hooks.Input.ModelID].
49+
//
50+
// An empty name or nil fn is ignored (with a warning log), matching the
51+
// no-error shape of the other [Opt] helpers.
52+
func WithMessageTransform(name string, fn MessageTransform) Opt {
53+
return func(r *LocalRuntime) {
54+
if name == "" || fn == nil {
55+
slog.Warn("Ignoring message transform with empty name or nil fn", "name", name)
56+
return
57+
}
58+
r.transforms = append(r.transforms, registeredTransform{name: name, fn: fn})
59+
}
60+
}
61+
62+
// applyBeforeLLMCallTransforms runs every registered
63+
// [MessageTransform] in chain order, just before the model call and
64+
// AFTER [LocalRuntime.executeBeforeLLMCallHooks] has approved it.
65+
// Errors from individual transforms are logged at warn level and the
66+
// chain continues with the previous slice — a transform failure must
67+
// never break the run loop.
68+
//
69+
// modelID is the canonical model identifier the loop has just
70+
// resolved (after per-tool overrides and alloy-mode selection);
71+
// transforms read it via [hooks.Input.ModelID]. Calling
72+
// agent.Model() from a transform would re-randomize the alloy pick
73+
// and miss the per-tool override.
74+
func (r *LocalRuntime) applyBeforeLLMCallTransforms(
75+
ctx context.Context,
76+
sess *session.Session,
77+
a *agent.Agent,
78+
modelID string,
79+
msgs []chat.Message,
80+
) []chat.Message {
81+
if len(r.transforms) == 0 {
82+
return msgs
83+
}
84+
in := &hooks.Input{
85+
SessionID: sess.ID,
86+
AgentName: a.Name(),
87+
ModelID: modelID,
88+
HookEventName: hooks.EventBeforeLLMCall,
89+
Cwd: r.workingDir,
90+
}
91+
for _, t := range r.transforms {
92+
out, err := t.fn(ctx, in, msgs)
93+
if err != nil {
94+
slog.Warn("Message transform failed; continuing with previous messages",
95+
"transform", t.name, "agent", a.Name(), "error", err)
96+
continue
97+
}
98+
msgs = out
99+
}
100+
return msgs
101+
}

0 commit comments

Comments
 (0)