Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/hooks/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,17 @@ type Input struct {
// closure (e.g. response cache).
AgentName string `json:"agent_name,omitempty"`

// ModelID identifies the model the runtime is about to call (for
// before_llm_call) or just called (for after_llm_call), in the
// canonical "<provider>/<model>" form expected by
// [modelsdev.Store.GetModel]. Populated from the loop's resolved
// model so it reflects per-tool model overrides and alloy-mode
// random selection — do NOT call Agent.Model() from a hook to
// recompute it, since alloy mode would re-randomize and a per-tool
// override would be invisible. Empty for events that aren't
// model-call-scoped.
ModelID string `json:"model_id,omitempty"`

// LastUserMessage is the text content of the latest user message in
// the session at dispatch time. Populated for events that respond to
// a user turn (stop, after_llm_call). Empty for events that aren't
Expand Down
9 changes: 8 additions & 1 deletion pkg/runtime/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,16 @@ func (r *LocalRuntime) executeOnToolApprovalDecisionHooks(
// / exit 2) stops the run loop — see [hooks.EventBeforeLLMCall] for
// the contract. Hooks that just want to contribute system messages
// should target turn_start instead.
func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent) (stop bool, message string) {
//
// modelID is the canonical model identifier the loop has just
// resolved (after per-tool overrides and alloy-mode selection); it's
// surfaced to hooks via [hooks.Input.ModelID] so handlers don't need
// to recompute it from the agent.
func (r *LocalRuntime) executeBeforeLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, modelID string) (stop bool, message string) {
result := r.dispatchHook(ctx, a, hooks.EventBeforeLLMCall, &hooks.Input{
SessionID: sess.ID,
AgentName: a.Name(),
ModelID: modelID,
}, nil)
if result == nil || result.Allowed {
return false, ""
Expand Down
19 changes: 11 additions & 8 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,24 +375,27 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session,
messages := sess.GetMessages(a, slices.Concat(sessionStartMsgs, userPromptMsgs, turnStartMsgs)...)
slog.Debug("Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages))

// Strip image content from messages if the model doesn't support image input.
// This prevents API errors when conversation history contains images (e.g. from
// tool results or user attachments) but the current model is text-only.
if m != nil && len(m.Modalities.Input) > 0 && !slices.Contains(m.Modalities.Input, "image") {
messages = stripImageContent(messages)
}

// before_llm_call hooks fire just before the model is invoked.
// A terminating verdict (e.g. from the max_iterations builtin)
// stops the run loop here, before any tokens are spent.
if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a); stop {
if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID); stop {
slog.Warn("before_llm_call hook signalled run termination",
"agent", a.Name(), "session_id", sess.ID, "reason", msg)
r.emitHookDrivenShutdown(ctx, a, sess, msg, events)
streamSpan.End()
return
}

// Apply registered before_llm_call message transforms (e.g.
// strip_unsupported_modalities for text-only models, plus any
// embedder-supplied redactor / scrubber registered via
// WithMessageTransform). Runs after the gate so a transform
// failure cannot waste the gate's allow verdict. modelID is
// passed explicitly so transforms see the actual model the
// loop chose (per-tool override + alloy-mode selection),
// not whatever a fresh agent.Model() call would re-randomize.
messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, modelID, messages)

// Try primary model with fallback chain if configured
res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events)
if err != nil {
Expand Down
18 changes: 18 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ type LocalRuntime struct {
// construction, so no locking is needed.
hooksExecByAgent map[string]*hooks.Executor

// transforms is the runtime's [MessageTransform] chain, applied to
// every LLM call in registration order. Populated by
// [NewLocalRuntime] (for the runtime-shipped strip transform) and by
// [WithMessageTransform] (for embedder-supplied transforms).
// Read-only after construction.
transforms []registeredTransform

fallback *fallbackExecutor

// observers receive every event the runtime produces, in
Expand Down Expand Up @@ -392,6 +399,17 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err)
}

// strip_unsupported_modalities is the runtime-shipped
// before_llm_call message transform that drops image content from
// messages when the agent's model is text-only. Like
// cache_response it captures the runtime closure (to resolve the
// agent and its model from Input.AgentName) and is therefore
// registered here rather than in pkg/hooks/builtins.
r.transforms = append(r.transforms, registeredTransform{
name: BuiltinStripUnsupportedModalities,
fn: r.stripUnsupportedModalitiesTransform,
})

for _, opt := range opts {
opt(r)
}
Expand Down
37 changes: 0 additions & 37 deletions pkg/runtime/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"strings"

"github.com/docker/docker-agent/pkg/agent"
Expand Down Expand Up @@ -236,39 +235,3 @@ func handleStream(ctx context.Context, stream chat.MessageStream, a *agent.Agent
Usage: messageUsage,
}, nil
}

// stripImageContent returns a copy of messages with all image-related content
// removed. This is used when the target model doesn't support image input to
// prevent API errors. Text content is preserved; image parts in MultiContent
// are filtered out, and file attachments with image MIME types are dropped.
//
// The input slice is never mutated: each message is shallow-copied, and
// MultiContent is only replaced (with a freshly built slice) on messages
// where something was actually filtered out.
func stripImageContent(messages []chat.Message) []chat.Message {
	result := make([]chat.Message, len(messages))
	for i, msg := range messages {
		// Shallow copy; MultiContent is swapped below only if filtering
		// removed at least one part.
		result[i] = msg

		if len(msg.MultiContent) == 0 {
			continue
		}

		var filtered []chat.MessagePart
		for _, part := range msg.MultiContent {
			switch part.Type {
			case chat.MessagePartTypeImageURL:
				// Drop image URL parts entirely.
				continue
			case chat.MessagePartTypeFile:
				// Drop file parts that are images.
				if part.File != nil && chat.IsImageMimeType(part.File.MimeType) {
					continue
				}
			}
			filtered = append(filtered, part)
		}

		// Only touch the message (and log) when filtering changed it, so
		// unaffected messages keep their original MultiContent slice.
		if len(filtered) != len(msg.MultiContent) {
			result[i].MultiContent = filtered
			slog.Debug("Stripped image content from message", "role", msg.Role, "original_parts", len(msg.MultiContent), "remaining_parts", len(filtered))
		}
	}
	return result
}
115 changes: 115 additions & 0 deletions pkg/runtime/strip_modalities.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package runtime

import (
"context"
"log/slog"
"slices"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/hooks"
)

// BuiltinStripUnsupportedModalities is the name of the runtime-shipped
// before_llm_call message transform that drops image content from the
// outgoing messages when the agent's current model doesn't list image
// in its input modalities. It's the runtime-shipped peer of
// [BuiltinCacheResponse] (a stop hook) — the constant exists mostly
// for log filtering and diagnostics; registration does not de-duplicate
// by name.
//
// Sending images to a text-only model produces hard provider errors
// (HTTP 400 from OpenAI, "image input is not supported" from
// Anthropic text variants, etc.); promoting the strip into a
// registered transform replaces an inline branch in runStreamLoop and
// opens the door to a family of message-mutating transforms
// (redactors, scrubbers, ...).
const BuiltinStripUnsupportedModalities = "strip_unsupported_modalities"

// modalityImage is the canonical models.dev modality name for image
// input. A constant instead of a literal so a typo trips a compile
// error and the contract with [modelsdev.Modalities.Input] is
// discoverable from the runtime side.
const modalityImage = "image"

// stripUnsupportedModalitiesTransform is the [MessageTransform]
// registered under [BuiltinStripUnsupportedModalities]. It resolves
// the model definition named by [hooks.Input.ModelID] (which the
// runtime populates with the model the loop actually chose, including
// per-tool overrides and alloy-mode selection) and, when image is
// absent from that model's input modalities, runs [stripImageContent]
// over the outgoing messages.
//
// Every "not enough information to act" case is a deliberate no-op
// that passes msgs through unchanged — missing ModelID, a models.dev
// lookup miss, an empty modality list, or image already supported.
// That matches the previous inline behavior in runStreamLoop, where
// an unknown model also fell through; the skip paths log at Debug
// level so operators can distinguish an inactive transform from a
// broken one.
func (r *LocalRuntime) stripUnsupportedModalitiesTransform(
	ctx context.Context,
	in *hooks.Input,
	msgs []chat.Message,
) ([]chat.Message, error) {
	if in == nil || in.ModelID == "" {
		slog.Debug("strip_unsupported_modalities: skipping, no ModelID on input")
		return msgs, nil
	}

	def, err := r.modelsStore.GetModel(ctx, in.ModelID)
	switch {
	case err != nil || def == nil:
		// Unknown model: keep the previous (inline) behavior of
		// passing messages through untouched. The model call will
		// surface any modality mismatch as a provider error.
		slog.Debug("strip_unsupported_modalities: skipping, model definition unavailable",
			"model_id", in.ModelID, "error", err)
		return msgs, nil
	case len(def.Modalities.Input) == 0:
		// No modality data — can't prove image is unsupported.
		return msgs, nil
	case slices.Contains(def.Modalities.Input, modalityImage):
		// Model accepts images; nothing to strip.
		return msgs, nil
	}

	return stripImageContent(msgs), nil
}

// stripImageContent returns a copy of messages with all image-related
// content removed. Text content is preserved; image parts in
// [chat.Message.MultiContent] are filtered out, and file attachments
// with image MIME types are dropped.
//
// Lives next to [stripUnsupportedModalitiesTransform] (rather than in
// streaming.go where it originated) so the builtin's storage,
// transform, and helper are co-located. Kept as an unexported helper
// because the only legitimate caller is the transform itself — direct
// use bypasses the modality check.
func stripImageContent(messages []chat.Message) []chat.Message {
	result := make([]chat.Message, len(messages))
	for i := range messages {
		// Shallow-copy the message; MultiContent is replaced below only
		// when filtering actually removed something.
		result[i] = messages[i]

		original := messages[i].MultiContent
		if len(original) == 0 {
			continue
		}

		var kept []chat.MessagePart
		for _, part := range original {
			if part.Type == chat.MessagePartTypeImageURL {
				// Image URL parts are dropped outright.
				continue
			}
			if part.Type == chat.MessagePartTypeFile &&
				part.File != nil && chat.IsImageMimeType(part.File.MimeType) {
				// File parts carrying an image payload are dropped too.
				continue
			}
			kept = append(kept, part)
		}

		if len(kept) != len(original) {
			result[i].MultiContent = kept
			slog.Debug("Stripped image content from message",
				"role", messages[i].Role,
				"original_parts", len(original),
				"remaining_parts", len(kept))
		}
	}
	return result
}
101 changes: 101 additions & 0 deletions pkg/runtime/transforms.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package runtime

import (
"context"
"log/slog"

"github.com/docker/docker-agent/pkg/agent"
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/hooks"
"github.com/docker/docker-agent/pkg/session"
)

// MessageTransform is the in-process-only handler signature for a
// before_llm_call transform that rewrites the chat messages about to
// be sent to the model. It receives the full message slice in chain
// order and returns the (possibly-rewritten) replacement.
//
// Transforms are intentionally a runtime-private contract: the cost of
// JSON-roundtripping a full conversation through the cross-process
// hook protocol would be prohibitive, so command and model hooks
// cannot rewrite messages. Embedders register transforms via
// [WithMessageTransform]; the runtime ships
// [BuiltinStripUnsupportedModalities] out of the box.
//
// Transforms run AFTER the standard before_llm_call gate (see
// [LocalRuntime.executeBeforeLLMCallHooks]) — a hook that wants to
// abort the call should target the gate, not a transform.
//
// Returning a non-nil error logs a warning and the chain falls
// through to the previous message slice; a transform failure must
// never break the run loop.
type MessageTransform func(ctx context.Context, in *hooks.Input, msgs []chat.Message) ([]chat.Message, error)

// registeredTransform pairs a [MessageTransform] with the name it was
// registered under. The name is purely diagnostic — it shows up in
// slog records when a transform errors out — so re-registering the
// same name simply appends another entry without any de-duplication;
// both entries run, in registration order.
type registeredTransform struct {
	name string
	fn   MessageTransform
}

// WithMessageTransform registers a [MessageTransform] under name so
// it is applied to every LLM call, in registration order, after the
// before_llm_call gate. Transforms are runtime-global: per-agent
// scoping (if needed) lives in the transform body, where
// [hooks.Input.AgentName] is available — the runtime-shipped strip
// transform is an example.
//
// An empty name or a nil fn is rejected — a warning is logged and
// nothing is registered — matching the no-error shape of the other
// [Opt] helpers. (The previous doc claimed the rejection was silent;
// the implementation has always warned.)
func WithMessageTransform(name string, fn MessageTransform) Opt {
	return func(r *LocalRuntime) {
		if name == "" || fn == nil {
			// Warn rather than error: Opt helpers have no error channel,
			// but a misconfigured registration should still be visible.
			slog.Warn("Ignoring message transform with empty name or nil fn", "name", name)
			return
		}
		r.transforms = append(r.transforms, registeredTransform{name: name, fn: fn})
	}
}

// applyBeforeLLMCallTransforms runs every registered
// [MessageTransform] in chain order, just before the model call and
// AFTER [LocalRuntime.executeBeforeLLMCallHooks] has approved it.
// Errors from individual transforms are logged at warn level and the
// chain continues with the previous slice — a transform failure must
// never break the run loop.
//
// modelID is the canonical model identifier the loop has just
// resolved (after per-tool overrides and alloy-mode selection);
// transforms read it via [hooks.Input.ModelID]. Calling
// agent.Model() from a transform would re-randomize the alloy pick
// and miss the per-tool override.
func (r *LocalRuntime) applyBeforeLLMCallTransforms(
	ctx context.Context,
	sess *session.Session,
	a *agent.Agent,
	modelID string,
	msgs []chat.Message,
) []chat.Message {
	if len(r.transforms) == 0 {
		return msgs
	}

	// One shared Input for the whole chain: every transform sees the
	// same event metadata.
	input := &hooks.Input{
		SessionID:     sess.ID,
		AgentName:     a.Name(),
		ModelID:       modelID,
		HookEventName: hooks.EventBeforeLLMCall,
		Cwd:           r.workingDir,
	}

	current := msgs
	for _, entry := range r.transforms {
		next, err := entry.fn(ctx, input, current)
		if err != nil {
			// Keep the pre-failure slice and move on; a transform
			// failure must never break the run loop.
			slog.Warn("Message transform failed; continuing with previous messages",
				"transform", entry.name, "agent", a.Name(), "error", err)
			continue
		}
		current = next
	}
	return current
}
Loading
Loading