304 lines
9.4 KiB
Python
304 lines
9.4 KiB
Python
"""
|
|
Tests for the Database MCP tool (db_query).
|
|
|
|
We patch _dispatch_query (the internal router) rather than individual drivers
|
|
so the tests stay driver-agnostic. Driver-specific tests (asyncpg / aiomysql /
|
|
pyodbc) are covered in the integration section at the bottom.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from mcp_privileged.config import settings
|
|
from mcp_privileged.database.server import (
|
|
_cell_str,
|
|
_format_result,
|
|
db_query,
|
|
)
|
|
from mcp_privileged.secret_store import secret_store
|
|
from tests.conftest import make_db_result
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _handle(username: str = "db_svc", password: str = "DbP@ss!") -> str:
    """Store a throwaway credential in the secret store and return its handle.

    Args:
        username: Username to store alongside the password.
        password: Password to store.

    Returns:
        The handle returned by ``secret_store.store``.
    """
    stored_handle = await secret_store.store(username, password)
    return stored_handle
|
|
|
|
|
|
def _patch_dispatch(columns: list[str], rows: list[list]):
    """Build a patcher that replaces ``_dispatch_query`` with a canned result.

    The replacement is an ``AsyncMock`` whose awaited value is
    ``make_db_result(columns, rows)``, so ``db_query`` runs without a real
    database. Used as a context manager, the patcher yields that mock, which
    lets tests inspect ``call_args``.
    """
    target = "mcp_privileged.database.server._dispatch_query"
    fake_dispatch = AsyncMock(return_value=make_db_result(columns, rows))
    return patch(target, new=fake_dispatch)
|
|
|
|
|
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
async def test_db_query_success_postgres(mock_ctx) -> None:
    """Happy path: a postgres query renders its header, columns and every row."""
    handle = await _handle()
    cols = ["id", "name", "email"]
    rows = [[1, "Alice", "alice@example.com"], [2, "Bob", "bob@example.com"]]

    with _patch_dispatch(cols, rows):
        result = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT id, name, email FROM users",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )

    assert "Rows returned: 2" in result
    assert "Database: mydb (postgres)" in result
    # One assert per column so a failure names the missing header instead of
    # collapsing the whole check into a single opaque ``and`` chain.
    for column in cols:
        assert column in result
    # Both rows must be rendered, not just the first one.
    assert "Alice" in result
    assert "Bob" in result
|
|
|
|
|
|
async def test_db_query_success_mysql(mock_ctx) -> None:
    """MySQL variant: the result header is labelled with the mysql db_type."""
    handle = await _handle()

    with _patch_dispatch(["host_name"], [["mysql-server-01"]]):
        result = await db_query(
            host="mysql.internal",
            database="ops",
            query="SELECT @@hostname",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="mysql",
        )

    # A bare ``"mysql" in result`` is vacuous here: both the host name and the
    # returned cell contain "mysql". Pin the actual header label instead, in
    # the same "Database: <db> (<db_type>)" form the postgres test checks.
    assert "Database: ops (mysql)" in result
    assert "mysql-server-01" in result
|
|
|
|
|
|
async def test_db_query_success_mssql(mock_ctx) -> None:
    """SQL Server variant: the result header is labelled with the mssql db_type."""
    handle = await _handle()

    with _patch_dispatch(["name"], [["SQLSERVER01"]]):
        result = await db_query(
            host="sql.internal",
            database="master",
            query="SELECT @@SERVERNAME",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="mssql",
        )

    # Check the full "Database: <db> (<db_type>)" label (as the postgres test
    # does) rather than a loose substring that could match anywhere.
    assert "Database: master (mssql)" in result
    assert "SQLSERVER01" in result
|
|
|
|
|
|
async def test_db_query_default_port_resolved(mock_ctx) -> None:
    """Passing port=0 makes db_query forward the postgres default port (5432)."""
    handle = await _handle()

    with _patch_dispatch(["v"], [[42]]) as fake_dispatch:
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 42",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
            port=0,
        )

    forwarded = fake_dispatch.call_args.kwargs
    assert forwarded["port"] == 5432
|
|
|
|
|
|
async def test_db_query_custom_port_forwarded(mock_ctx) -> None:
    """A non-zero port reaches _dispatch_query exactly as given."""
    handle = await _handle()

    with _patch_dispatch(["v"], [[1]]) as fake_dispatch:
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
            port=15432,
        )

    forwarded = fake_dispatch.call_args.kwargs
    assert forwarded["port"] == 15432
|
|
|
|
|
|
async def test_db_query_username_override(mock_ctx) -> None:
    """username_override wins over the username stored with the credential."""
    handle = await _handle(username="readonly_user")

    with _patch_dispatch(["v"], [[1]]) as fake_dispatch:
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
            username_override="dba_user",
        )

    forwarded = fake_dispatch.call_args.kwargs
    assert forwarded["username"] == "dba_user"
|
|
|
|
|
|
async def test_db_query_invalid_db_type(mock_ctx) -> None:
    """An unknown db_type is rejected with a ValueError ("Unsupported db_type").

    NOTE(review): a valid handle is stored first, so this test does not prove
    that the rejection happens before the credential store is consulted; it
    only pins the exception type and message.
    """
    handle = await _handle()

    with pytest.raises(ValueError, match="Unsupported db_type"):
        await db_query(
            host="db.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="oracle",
        )
|
|
|
|
|
|
async def test_db_query_invalid_handle(mock_ctx) -> None:
    """A handle that was never stored raises KeyError and is reported once via ctx.error."""
    unknown_handle = "secret://doesnotexist0000000000000000"

    with pytest.raises(KeyError):
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=unknown_handle,
            ctx=mock_ctx,
            db_type="postgres",
        )

    mock_ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_db_query_driver_exception_propagates(mock_ctx) -> None:
    """A failure raised by _dispatch_query reaches the caller and is reported via ctx.error."""
    handle = await _handle()
    broken_dispatch = patch(
        "mcp_privileged.database.server._dispatch_query",
        new=AsyncMock(side_effect=ConnectionRefusedError("DB port closed")),
    )

    # Entered left-to-right and exited in reverse, same as the nested form.
    with broken_dispatch, pytest.raises(ConnectionRefusedError):
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )

    mock_ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_db_query_rows_capped(mock_ctx) -> None:
    """With db_max_rows lowered to 10, a 2000-row result is cut down and flagged."""
    handle = await _handle()
    many_rows = [[idx, f"user_{idx}"] for idx in range(2000)]

    dispatch_stub = _patch_dispatch(["id", "name"], many_rows)
    row_cap = patch.object(settings, "db_max_rows", 10)

    with dispatch_stub, row_cap:
        result = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT id, name FROM big_table",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )

    assert "Rows returned: 10" in result
    assert "more rows exist" in result
|
|
|
|
|
|
async def test_db_query_empty_result(mock_ctx) -> None:
    """A query yielding no columns and no rows still produces a readable result."""
    handle = await _handle()
    no_columns: list[str] = []
    no_rows: list[list] = []

    with _patch_dispatch(no_columns, no_rows):
        result = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1 WHERE false",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )

    assert "Rows returned: 0" in result
    assert "No rows returned" in result
|
|
|
|
|
|
async def test_db_query_password_not_in_ctx_messages(mock_ctx) -> None:
    """The stored password must not leak into ctx.info, ctx.error or the tool output."""
    secret_password = "DB$ecretPass99"
    handle = await secret_store.store("db_user", secret_password)

    with _patch_dispatch(["v"], [[1]]):
        result = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )

    # The returned text goes straight back to the caller, so it is as much a
    # leak channel as the ctx log messages.
    assert secret_password not in result

    ctx_calls = mock_ctx.info.await_args_list + mock_ctx.error.await_args_list
    for ctx_call in ctx_calls:
        assert secret_password not in str(ctx_call)
|
|
|
|
|
|
# ── Unit tests for helpers ────────────────────────────────────────────────────
|
|
|
|
def test_cell_str_none() -> None:
    """A None cell renders as the empty string."""
    rendered = _cell_str(None)
    assert rendered == ""
|
|
|
|
|
|
def test_cell_str_normal() -> None:
    """Short non-None values render as their plain string form."""
    for raw, expected in ((42, "42"), ("hello", "hello")):
        assert _cell_str(raw) == expected
|
|
|
|
|
|
def test_cell_str_truncated() -> None:
    """A value longer than db_max_cell_bytes is shortened and marked with an ellipsis."""
    oversized = "x" * 10_000

    with patch.object(settings, "db_max_cell_bytes", 10):
        clipped = _cell_str(oversized)

    assert "…" in clipped
    assert len(clipped) < 20
|
|
|
|
|
|
def test_format_result_no_rows() -> None:
    """Empty columns and rows produce the "No rows returned" message."""
    rendered = _format_result(
        "host", "db", "postgres", "SELECT 1", [], [], False, 5.0
    )
    assert "No rows returned" in rendered
|
|
|
|
|
|
def test_format_result_with_rows() -> None:
    """Headers, every row value and the row count appear in the formatted output."""
    header = ["id", "name"]
    records = [[1, "Alice"], [2, "Bob"]]

    rendered = _format_result(
        "host", "db", "postgres", "SELECT ...", header, records, False, 12.3
    )

    for needle in ("id", "Alice", "Bob", "Rows returned: 2"):
        assert needle in rendered
|
|
|
|
|
|
def test_format_result_truncated_flag() -> None:
    """Setting the truncation flag (7th positional argument) makes the output mention capping."""
    header = ["id"]
    records = [[n] for n in range(5)]

    rendered = _format_result(
        "host", "db", "postgres", "SELECT ...", header, records, True, 1.0
    )

    assert "capped" in rendered
|