|
7 | 7 | "io" |
8 | 8 | "net/http" |
9 | 9 | "net/url" |
| 10 | + "slices" |
10 | 11 | "strings" |
11 | 12 | "time" |
12 | 13 |
|
@@ -270,56 +271,38 @@ func (h *fetchHandler) checkDomainAllowed(u *url.URL) error { |
270 | 271 | if host == "" { |
271 | 272 | return errors.New("URL has no host") |
272 | 273 | } |
273 | | - if len(h.blockedDomains) > 0 && matchesAnyDomain(host, h.blockedDomains) { |
274 | | - return fmt.Errorf("URL host %q is blocked by blocked_domains", host) |
| 274 | + matchesAny := func(patterns []string) bool { |
| 275 | + return slices.ContainsFunc(patterns, func(p string) bool { |
| 276 | + return matchesDomain(host, p) |
| 277 | + }) |
275 | 278 | } |
276 | | - if len(h.allowedDomains) > 0 && !matchesAnyDomain(host, h.allowedDomains) { |
| 279 | + switch { |
| 280 | + case len(h.blockedDomains) > 0 && matchesAny(h.blockedDomains): |
| 281 | + return fmt.Errorf("URL host %q is blocked by blocked_domains", host) |
| 282 | + case len(h.allowedDomains) > 0 && !matchesAny(h.allowedDomains): |
277 | 283 | return fmt.Errorf("URL host %q is not in allowed_domains", host) |
278 | 284 | } |
279 | 285 | return nil |
280 | 286 | } |
281 | 287 |
|
282 | | -// matchesAnyDomain reports whether host matches any of the supplied patterns. |
283 | | -// See matchesDomain for the matching rules. |
284 | | -func matchesAnyDomain(host string, patterns []string) bool { |
285 | | - for _, p := range patterns { |
286 | | - if matchesDomain(host, p) { |
287 | | - return true |
288 | | - } |
289 | | - } |
290 | | - return false |
291 | | -} |
292 | | - |
293 | | -// matchesDomain reports whether host matches pattern. |
| 288 | +// matchesDomain reports whether host matches pattern (case-insensitive). |
294 | 289 | // |
295 | | -// Matching rules (case-insensitive): |
296 | | -// - An empty pattern matches nothing. |
297 | | -// - A pattern with a leading dot (".example.com") matches strict subdomains |
298 | | -// of example.com but NOT example.com itself. |
299 | | -// - Any other pattern ("example.com") matches the host exactly and any of |
300 | | -// its subdomains (e.g. "docs.example.com"). It does NOT match unrelated |
301 | | -// hosts that share a suffix (e.g. "badexample.com"). |
| 290 | +// A bare pattern ("example.com") matches the host exactly or any subdomain |
| 291 | +// ("docs.example.com"); it does NOT match unrelated hosts that share a suffix |
| 292 | +// ("badexample.com"). A pattern with a leading dot (".example.com") matches |
| 293 | +// strict subdomains only — the apex "example.com" is excluded. |
302 | 294 | func matchesDomain(host, pattern string) bool { |
303 | 295 | host = strings.ToLower(strings.TrimSpace(host)) |
304 | 296 | pattern = strings.ToLower(strings.TrimSpace(pattern)) |
305 | | - if host == "" || pattern == "" { |
| 297 | + if host == "" || pattern == "" || pattern == "." { |
306 | 298 | return false |
307 | 299 | } |
308 | | - // Strip IPv6 brackets if any (url.URL.Hostname already does this, but be safe). |
309 | | - host = strings.Trim(host, "[]") |
310 | | - |
311 | | - subdomainOnly := strings.HasPrefix(pattern, ".") |
312 | | - if subdomainOnly { |
313 | | - pattern = strings.TrimPrefix(pattern, ".") |
314 | | - if pattern == "" { |
315 | | - return false |
316 | | - } |
317 | | - return strings.HasSuffix(host, "."+pattern) |
| 300 | + if strings.HasPrefix(pattern, ".") { |
| 301 | + // Strict subdomain match: ".example.com" matches "x.example.com" but not "example.com". |
| 302 | + return strings.HasSuffix(host, pattern) |
318 | 303 | } |
319 | | - if host == pattern { |
320 | | - return true |
321 | | - } |
322 | | - return strings.HasSuffix(host, "."+pattern) |
| 304 | + // Apex or subdomain match. |
| 305 | + return host == pattern || strings.HasSuffix(host, "."+pattern) |
323 | 306 | } |
324 | 307 |
|
325 | 308 | func htmlToMarkdown(html string) string { |
@@ -376,17 +359,12 @@ func WithBlockedDomains(domains []string) FetchToolOption { |
376 | 359 |
|
377 | 360 | func (t *FetchTool) Instructions() string { |
378 | 361 | var b strings.Builder |
379 | | - b.WriteString("## Fetch Tool\n\n") |
380 | | - b.WriteString("Fetch content from HTTP/HTTPS URLs. Supports multiple URLs per call, output format selection (text, markdown, html), and respects robots.txt.") |
381 | | - if len(t.handler.allowedDomains) > 0 { |
382 | | - b.WriteString("\n\nThis tool is restricted to the following domains (and their subdomains): ") |
383 | | - b.WriteString(strings.Join(t.handler.allowedDomains, ", ")) |
384 | | - b.WriteString(". Requests to any other host will fail without making a network call.") |
385 | | - } |
386 | | - if len(t.handler.blockedDomains) > 0 { |
387 | | - b.WriteString("\n\nThis tool is forbidden from fetching the following domains (and their subdomains): ") |
388 | | - b.WriteString(strings.Join(t.handler.blockedDomains, ", ")) |
389 | | - b.WriteString(". Requests to those hosts will fail without making a network call.") |
| 362 | + b.WriteString("## Fetch Tool\n\nFetch content from HTTP/HTTPS URLs. Supports multiple URLs per call, output format selection (text, markdown, html), and respects robots.txt.") |
| 363 | + if d := t.handler.allowedDomains; len(d) > 0 { |
| 364 | + fmt.Fprintf(&b, "\n\nThis tool is restricted to these domains (and any subdomain): %s. Other hosts are rejected without a network call.", strings.Join(d, ", ")) |
| 365 | + } |
| 366 | + if d := t.handler.blockedDomains; len(d) > 0 { |
| 367 | + fmt.Fprintf(&b, "\n\nThis tool must not fetch these domains (or any subdomain): %s. They are rejected without a network call.", strings.Join(d, ", ")) |
390 | 368 | } |
391 | 369 | return b.String() |
392 | 370 | } |
|
0 commit comments