diff --git a/src/google/adk/tools/_ssrf_protection.py b/src/google/adk/tools/_ssrf_protection.py new file mode 100644 index 0000000000..44dee5b98c --- /dev/null +++ b/src/google/adk/tools/_ssrf_protection.py @@ -0,0 +1,84 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared SSRF protection helpers for tools that make HTTP requests.""" + +from __future__ import annotations + +import ipaddress +import socket +from urllib.parse import urlparse + +_ALLOWED_URL_SCHEMES = frozenset({"http", "https"}) +_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address + + +def is_blocked_hostname(hostname: str) -> bool: + normalized = hostname.rstrip(".").lower() + return normalized == "localhost" or normalized.endswith(".localhost") + + +def is_blocked_address(address: _ResolvedAddress) -> bool: + return not address.is_global + + +def resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]: + try: + addr = ipaddress.ip_address(hostname) + return (addr,) + except ValueError: + pass + + try: + address_info = socket.getaddrinfo( + hostname, + None, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + ) + except (socket.gaierror, UnicodeError) as exc: + raise ValueError(f"Unable to resolve host: {hostname}") from exc + + resolved: list[_ResolvedAddress] = [] + for family, _, _, _, sockaddr in address_info: + if family not in (socket.AF_INET, socket.AF_INET6): + continue + resolved.append(ipaddress.ip_address(sockaddr[0])) + + if not resolved: + raise ValueError(f"Unable to resolve host: {hostname}") + + return tuple(dict.fromkeys(resolved)) + + +def validate_url(url: str) -> None: + """Validate a URL against SSRF attacks. + + Raises ValueError if the URL targets a blocked host or scheme. + """ + parsed = urlparse(url) + scheme = parsed.scheme.lower() + if scheme not in _ALLOWED_URL_SCHEMES: + raise ValueError(f"Unsupported url scheme: {url}") + + hostname = parsed.hostname + if not hostname: + raise ValueError(f"URL is missing a hostname: {url}") + + if is_blocked_hostname(hostname): + raise ValueError(f"Blocked host: {hostname}") + + resolved = resolve_host_addresses(hostname) + if any(is_blocked_address(addr) for addr in resolved): + raise ValueError(f"Blocked host: {hostname}") diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index fa32ce932a..8eb900def6 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -570,8 +570,13 @@ def __repr__(self): async def _request(**request_params) -> httpx.Response: + from ..._ssrf_protection import validate_url + + url = request_params.get("url", "") + validate_url(url) + async with httpx.AsyncClient( verify=request_params.pop("verify", True), - timeout=None, + timeout=30, ) as client: return await client.request(**request_params) diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index fa21201488..6019684bbc 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -1117,7 +1117,7 @@ async def test_call_with_verify_options( else: assert call_kwargs["verify"] == expected_verify_in_call - async def test_request_uses_no_default_timeout( + async def test_request_uses_finite_timeout( self, mock_tool_context, sample_endpoint, @@ -1125,12 +1125,10 @@ async def test_request_uses_no_default_timeout( sample_auth_scheme, sample_auth_credential, ): - """Test that _request creates AsyncClient with timeout=None. + """Test that _request creates AsyncClient with a finite timeout. - httpx defaults to a 5-second timeout, which is too short for many - real-world API calls. Verify that we explicitly disable the timeout - to match the previous requests-library behavior (no timeout). - Regression test for https://github.com/google/adk-python/issues/4431. + An unbounded timeout allows hanging connections and resource exhaustion. + Verify that the client uses a reasonable finite timeout. """ mock_response = mock.create_autospec(requests.Response, instance=True) mock_response.json.return_value = {"result": "success"} @@ -1157,7 +1155,7 @@ async def test_request_uses_no_default_timeout( assert mock_async_client.called _, call_kwargs = mock_async_client.call_args - assert call_kwargs["timeout"] is None + assert call_kwargs["timeout"] == 30 async def test_call_with_configure_verify( self, @@ -1502,3 +1500,43 @@ def test_snake_to_lower_camel(): assert snake_to_lower_camel("three_word_example") == "threeWordExample" assert not snake_to_lower_camel("") assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase" + + +class TestRequestSsrfProtection: + + @pytest.mark.asyncio + async def test_request_blocks_localhost(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request(method="GET", url="http://localhost:8080/internal") + + @pytest.mark.asyncio + async def test_request_blocks_loopback(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request(method="GET", url="http://127.0.0.1/internal") + + @pytest.mark.asyncio + async def test_request_blocks_metadata_endpoint(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request( + method="GET", url="http://169.254.169.254/latest/meta-data/" + ) + + @pytest.mark.asyncio + async def test_request_blocks_private_ip(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request(method="GET", url="http://10.0.0.1/admin") + + @pytest.mark.asyncio + async def test_request_blocks_file_scheme(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Unsupported url scheme"): + await _request(method="GET", url="file:///etc/passwd") diff --git a/tests/unittests/tools/test_ssrf_protection.py b/tests/unittests/tools/test_ssrf_protection.py new file mode 100644 index 0000000000..c421c3128e --- /dev/null +++ b/tests/unittests/tools/test_ssrf_protection.py @@ -0,0 +1,109 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.adk.tools._ssrf_protection import is_blocked_address +from google.adk.tools._ssrf_protection import is_blocked_hostname +from google.adk.tools._ssrf_protection import validate_url + +import ipaddress + + +class TestIsBlockedHostname: + + def test_localhost_blocked(self): + assert is_blocked_hostname("localhost") + + def test_localhost_trailing_dot(self): + assert is_blocked_hostname("localhost.") + + def test_subdomain_localhost_blocked(self): + assert is_blocked_hostname("foo.localhost") + + def test_case_insensitive(self): + assert is_blocked_hostname("LOCALHOST") + + def test_normal_hostname_allowed(self): + assert not is_blocked_hostname("example.com") + + def test_hostname_containing_localhost_allowed(self): + assert not is_blocked_hostname("notlocalhost.com") + + +class TestIsBlockedAddress: + + def test_loopback_blocked(self): + assert is_blocked_address(ipaddress.ip_address("127.0.0.1")) + + def test_link_local_blocked(self): + assert is_blocked_address(ipaddress.ip_address("169.254.169.254")) + + def test_private_blocked(self): + assert is_blocked_address(ipaddress.ip_address("10.0.0.1")) + assert is_blocked_address(ipaddress.ip_address("192.168.1.1")) + assert is_blocked_address(ipaddress.ip_address("172.16.0.1")) + + def test_ipv6_loopback_blocked(self): + assert is_blocked_address(ipaddress.ip_address("::1")) + + def test_global_allowed(self): + assert not is_blocked_address(ipaddress.ip_address("8.8.8.8")) + + +class TestValidateUrl: + + def test_localhost_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://localhost:8080/path") + + def test_loopback_ip_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://127.0.0.1/path") + + def test_link_local_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://169.254.169.254/latest/meta-data/") + + def test_private_ip_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://10.0.0.1/internal") + + def test_ftp_scheme_blocked(self): + with pytest.raises(ValueError, match="Unsupported url scheme"): + validate_url("ftp://example.com/file") + + def test_file_scheme_blocked(self): + with pytest.raises(ValueError, match="Unsupported url scheme"): + validate_url("file:///etc/passwd") + + def test_no_hostname_blocked(self): + with pytest.raises(ValueError, match="missing a hostname"): + validate_url("http:///path") + + @pytest.fixture(autouse=True) + def _patch_dns(self, monkeypatch): + import socket as _socket + + original = _socket.getaddrinfo + + def fake_getaddrinfo(host, port, *args, **kwargs): + if host == "api.example.com": + return [(_socket.AF_INET, _socket.SOCK_STREAM, 6, "", ("93.184.215.14", 0))] + return original(host, port, *args, **kwargs) + + monkeypatch.setattr(_socket, "getaddrinfo", fake_getaddrinfo) + + def test_public_url_allowed(self): + validate_url("https://api.example.com/v1/resource")