Skip to content

Commit 7a516b4

Browse files
committed
refactor(fetch): simplify domain matcher and instructions
- Drop the matchesAnyDomain helper; use slices.ContainsFunc inline.
- Collapse checkDomainAllowed branches into a switch.
- Simplify matchesDomain by leveraging the leading dot directly (no more subdomainOnly bool, no dead IPv6 bracket-strip).
- Tighten Instructions() with fmt.Fprintf and shorter phrasing.

No functional change; matcher truth table and integration tests unchanged.

Assisted-By: docker-agent
1 parent ff15023 commit 7a516b4

2 files changed

Lines changed: 28 additions & 50 deletions

File tree

pkg/tools/builtin/fetch.go

Lines changed: 26 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"net/url"
10+
"slices"
1011
"strings"
1112
"time"
1213

@@ -270,56 +271,38 @@ func (h *fetchHandler) checkDomainAllowed(u *url.URL) error {
270271
if host == "" {
271272
return errors.New("URL has no host")
272273
}
273-
if len(h.blockedDomains) > 0 && matchesAnyDomain(host, h.blockedDomains) {
274-
return fmt.Errorf("URL host %q is blocked by blocked_domains", host)
274+
matchesAny := func(patterns []string) bool {
275+
return slices.ContainsFunc(patterns, func(p string) bool {
276+
return matchesDomain(host, p)
277+
})
275278
}
276-
if len(h.allowedDomains) > 0 && !matchesAnyDomain(host, h.allowedDomains) {
279+
switch {
280+
case len(h.blockedDomains) > 0 && matchesAny(h.blockedDomains):
281+
return fmt.Errorf("URL host %q is blocked by blocked_domains", host)
282+
case len(h.allowedDomains) > 0 && !matchesAny(h.allowedDomains):
277283
return fmt.Errorf("URL host %q is not in allowed_domains", host)
278284
}
279285
return nil
280286
}
281287

282-
// matchesAnyDomain reports whether host matches any of the supplied patterns.
283-
// See matchesDomain for the matching rules.
284-
func matchesAnyDomain(host string, patterns []string) bool {
285-
for _, p := range patterns {
286-
if matchesDomain(host, p) {
287-
return true
288-
}
289-
}
290-
return false
291-
}
292-
293-
// matchesDomain reports whether host matches pattern.
288+
// matchesDomain reports whether host matches pattern (case-insensitive).
294289
//
295-
// Matching rules (case-insensitive):
296-
// - An empty pattern matches nothing.
297-
// - A pattern with a leading dot (".example.com") matches strict subdomains
298-
// of example.com but NOT example.com itself.
299-
// - Any other pattern ("example.com") matches the host exactly and any of
300-
// its subdomains (e.g. "docs.example.com"). It does NOT match unrelated
301-
// hosts that share a suffix (e.g. "badexample.com").
290+
// A bare pattern ("example.com") matches the host exactly or any subdomain
291+
// ("docs.example.com"); it does NOT match unrelated hosts that share a suffix
292+
// ("badexample.com"). A pattern with a leading dot (".example.com") matches
293+
// strict subdomains only — the apex "example.com" is excluded.
302294
func matchesDomain(host, pattern string) bool {
303295
host = strings.ToLower(strings.TrimSpace(host))
304296
pattern = strings.ToLower(strings.TrimSpace(pattern))
305-
if host == "" || pattern == "" {
297+
if host == "" || pattern == "" || pattern == "." {
306298
return false
307299
}
308-
// Strip IPv6 brackets if any (url.URL.Hostname already does this, but be safe).
309-
host = strings.Trim(host, "[]")
310-
311-
subdomainOnly := strings.HasPrefix(pattern, ".")
312-
if subdomainOnly {
313-
pattern = strings.TrimPrefix(pattern, ".")
314-
if pattern == "" {
315-
return false
316-
}
317-
return strings.HasSuffix(host, "."+pattern)
300+
if strings.HasPrefix(pattern, ".") {
301+
// Strict subdomain match: ".example.com" matches "x.example.com" but not "example.com".
302+
return strings.HasSuffix(host, pattern)
318303
}
319-
if host == pattern {
320-
return true
321-
}
322-
return strings.HasSuffix(host, "."+pattern)
304+
// Apex or subdomain match.
305+
return host == pattern || strings.HasSuffix(host, "."+pattern)
323306
}
324307

325308
func htmlToMarkdown(html string) string {
@@ -376,17 +359,12 @@ func WithBlockedDomains(domains []string) FetchToolOption {
376359

377360
func (t *FetchTool) Instructions() string {
378361
var b strings.Builder
379-
b.WriteString("## Fetch Tool\n\n")
380-
b.WriteString("Fetch content from HTTP/HTTPS URLs. Supports multiple URLs per call, output format selection (text, markdown, html), and respects robots.txt.")
381-
if len(t.handler.allowedDomains) > 0 {
382-
b.WriteString("\n\nThis tool is restricted to the following domains (and their subdomains): ")
383-
b.WriteString(strings.Join(t.handler.allowedDomains, ", "))
384-
b.WriteString(". Requests to any other host will fail without making a network call.")
385-
}
386-
if len(t.handler.blockedDomains) > 0 {
387-
b.WriteString("\n\nThis tool is forbidden from fetching the following domains (and their subdomains): ")
388-
b.WriteString(strings.Join(t.handler.blockedDomains, ", "))
389-
b.WriteString(". Requests to those hosts will fail without making a network call.")
362+
b.WriteString("## Fetch Tool\n\nFetch content from HTTP/HTTPS URLs. Supports multiple URLs per call, output format selection (text, markdown, html), and respects robots.txt.")
363+
if d := t.handler.allowedDomains; len(d) > 0 {
364+
fmt.Fprintf(&b, "\n\nThis tool is restricted to these domains (and any subdomain): %s. Other hosts are rejected without a network call.", strings.Join(d, ", "))
365+
}
366+
if d := t.handler.blockedDomains; len(d) > 0 {
367+
fmt.Fprintf(&b, "\n\nThis tool must not fetch these domains (or any subdomain): %s. They are rejected without a network call.", strings.Join(d, ", "))
390368
}
391369
return b.String()
392370
}

pkg/tools/builtin/fetch_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ func TestFetchTool_AllowedDomainsAppearInInstructions(t *testing.T) {
399399

400400
instructions := tools.GetInstructions(tool)
401401

402-
assert.Contains(t, instructions, "restricted to the following domains")
402+
assert.Contains(t, instructions, "restricted to these domains")
403403
assert.Contains(t, instructions, "docker.com")
404404
assert.Contains(t, instructions, "github.com")
405405
}
@@ -409,7 +409,7 @@ func TestFetchTool_BlockedDomainsAppearInInstructions(t *testing.T) {
409409

410410
instructions := tools.GetInstructions(tool)
411411

412-
assert.Contains(t, instructions, "forbidden from fetching")
412+
assert.Contains(t, instructions, "must not fetch")
413413
assert.Contains(t, instructions, "169.254.169.254")
414414
}
415415

0 commit comments

Comments
 (0)