Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/google/adk/tools/_ssrf_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared SSRF protection helpers for tools that make HTTP requests."""

from __future__ import annotations

import ipaddress
import socket
from urllib.parse import urlparse

_ALLOWED_URL_SCHEMES = frozenset({"http", "https"})
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address


def is_blocked_hostname(hostname: str) -> bool:
normalized = hostname.rstrip(".").lower()
return normalized == "localhost" or normalized.endswith(".localhost")


def is_blocked_address(address: _ResolvedAddress) -> bool:
return not address.is_global


def resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
try:
addr = ipaddress.ip_address(hostname)
return (addr,)
except ValueError:
pass

try:
address_info = socket.getaddrinfo(
hostname,
None,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
)
except (socket.gaierror, UnicodeError) as exc:
raise ValueError(f"Unable to resolve host: {hostname}") from exc

resolved: list[_ResolvedAddress] = []
for family, _, _, _, sockaddr in address_info:
if family not in (socket.AF_INET, socket.AF_INET6):
continue
resolved.append(ipaddress.ip_address(sockaddr[0]))

if not resolved:
raise ValueError(f"Unable to resolve host: {hostname}")

return tuple(dict.fromkeys(resolved))


def validate_url(url: str) -> None:
"""Validate a URL against SSRF attacks.

Raises ValueError if the URL targets a blocked host or scheme.
"""
parsed = urlparse(url)
scheme = parsed.scheme.lower()
if scheme not in _ALLOWED_URL_SCHEMES:
raise ValueError(f"Unsupported url scheme: {url}")

hostname = parsed.hostname
if not hostname:
raise ValueError(f"URL is missing a hostname: {url}")

if is_blocked_hostname(hostname):
raise ValueError(f"Blocked host: {hostname}")

resolved = resolve_host_addresses(hostname)
if any(is_blocked_address(addr) for addr in resolved):
raise ValueError(f"Blocked host: {hostname}")
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,13 @@ def __repr__(self):


async def _request(**request_params) -> httpx.Response:
from ..._ssrf_protection import validate_url

url = request_params.get("url", "")
validate_url(url)

async with httpx.AsyncClient(
verify=request_params.pop("verify", True),
timeout=None,
timeout=30,
) as client:
return await client.request(**request_params)
Original file line number Diff line number Diff line change
Expand Up @@ -1117,20 +1117,18 @@ async def test_call_with_verify_options(
else:
assert call_kwargs["verify"] == expected_verify_in_call

async def test_request_uses_no_default_timeout(
async def test_request_uses_finite_timeout(
self,
mock_tool_context,
sample_endpoint,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
"""Test that _request creates AsyncClient with timeout=None.
"""Test that _request creates AsyncClient with a finite timeout.

httpx defaults to a 5-second timeout, which is too short for many
real-world API calls. Verify that we explicitly disable the timeout
to match the previous requests-library behavior (no timeout).
Regression test for https://github.com/google/adk-python/issues/4431.
An unbounded timeout allows hanging connections and resource exhaustion.
Verify that the client uses a reasonable finite timeout.
"""
mock_response = mock.create_autospec(requests.Response, instance=True)
mock_response.json.return_value = {"result": "success"}
Expand All @@ -1157,7 +1155,7 @@ async def test_request_uses_no_default_timeout(

assert mock_async_client.called
_, call_kwargs = mock_async_client.call_args
assert call_kwargs["timeout"] is None
assert call_kwargs["timeout"] == 30

async def test_call_with_configure_verify(
self,
Expand Down Expand Up @@ -1502,3 +1500,43 @@ def test_snake_to_lower_camel():
assert snake_to_lower_camel("three_word_example") == "threeWordExample"
assert not snake_to_lower_camel("")
assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"


class TestRequestSsrfProtection:

@pytest.mark.asyncio
async def test_request_blocks_localhost(self):
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request

with pytest.raises(ValueError, match="Blocked host"):
await _request(method="GET", url="http://localhost:8080/internal")

@pytest.mark.asyncio
async def test_request_blocks_loopback(self):
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request

with pytest.raises(ValueError, match="Blocked host"):
await _request(method="GET", url="http://127.0.0.1/internal")

@pytest.mark.asyncio
async def test_request_blocks_metadata_endpoint(self):
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request

with pytest.raises(ValueError, match="Blocked host"):
await _request(
method="GET", url="http://169.254.169.254/latest/meta-data/"
)

@pytest.mark.asyncio
async def test_request_blocks_private_ip(self):
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request

with pytest.raises(ValueError, match="Blocked host"):
await _request(method="GET", url="http://10.0.0.1/admin")

@pytest.mark.asyncio
async def test_request_blocks_file_scheme(self):
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request

with pytest.raises(ValueError, match="Unsupported url scheme"):
await _request(method="GET", url="file:///etc/passwd")
109 changes: 109 additions & 0 deletions tests/unittests/tools/test_ssrf_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from google.adk.tools._ssrf_protection import is_blocked_address
from google.adk.tools._ssrf_protection import is_blocked_hostname
from google.adk.tools._ssrf_protection import validate_url

import ipaddress


class TestIsBlockedHostname:

def test_localhost_blocked(self):
assert is_blocked_hostname("localhost")

def test_localhost_trailing_dot(self):
assert is_blocked_hostname("localhost.")

def test_subdomain_localhost_blocked(self):
assert is_blocked_hostname("foo.localhost")

def test_case_insensitive(self):
assert is_blocked_hostname("LOCALHOST")

def test_normal_hostname_allowed(self):
assert not is_blocked_hostname("example.com")

def test_hostname_containing_localhost_allowed(self):
assert not is_blocked_hostname("notlocalhost.com")


class TestIsBlockedAddress:

def test_loopback_blocked(self):
assert is_blocked_address(ipaddress.ip_address("127.0.0.1"))

def test_link_local_blocked(self):
assert is_blocked_address(ipaddress.ip_address("169.254.169.254"))

def test_private_blocked(self):
assert is_blocked_address(ipaddress.ip_address("10.0.0.1"))
assert is_blocked_address(ipaddress.ip_address("192.168.1.1"))
assert is_blocked_address(ipaddress.ip_address("172.16.0.1"))

def test_ipv6_loopback_blocked(self):
assert is_blocked_address(ipaddress.ip_address("::1"))

def test_global_allowed(self):
assert not is_blocked_address(ipaddress.ip_address("8.8.8.8"))


class TestValidateUrl:

def test_localhost_blocked(self):
with pytest.raises(ValueError, match="Blocked host"):
validate_url("http://localhost:8080/path")

def test_loopback_ip_blocked(self):
with pytest.raises(ValueError, match="Blocked host"):
validate_url("http://127.0.0.1/path")

def test_link_local_blocked(self):
with pytest.raises(ValueError, match="Blocked host"):
validate_url("http://169.254.169.254/latest/meta-data/")

def test_private_ip_blocked(self):
with pytest.raises(ValueError, match="Blocked host"):
validate_url("http://10.0.0.1/internal")

def test_ftp_scheme_blocked(self):
with pytest.raises(ValueError, match="Unsupported url scheme"):
validate_url("ftp://example.com/file")

def test_file_scheme_blocked(self):
with pytest.raises(ValueError, match="Unsupported url scheme"):
validate_url("file:///etc/passwd")

def test_no_hostname_blocked(self):
with pytest.raises(ValueError, match="missing a hostname"):
validate_url("http:///path")

@pytest.fixture(autouse=True)
def _patch_dns(self, monkeypatch):
import socket as _socket

original = _socket.getaddrinfo

def fake_getaddrinfo(host, port, *args, **kwargs):
if host == "api.example.com":
return [(_socket.AF_INET, _socket.SOCK_STREAM, 6, "", ("93.184.215.14", 0))]
return original(host, port, *args, **kwargs)

monkeypatch.setattr(_socket, "getaddrinfo", fake_getaddrinfo)

def test_public_url_allowed(self):
validate_url("https://api.example.com/v1/resource")
Loading