Files
MCP_CyberArk/tests/test_database_server.py
2026-03-29 19:51:51 +02:00

304 lines
9.4 KiB
Python

"""
Tests for the Database MCP tool (db_query).
We patch _dispatch_query (the internal router) rather than individual drivers
so the tests stay driver-agnostic. Driver-specific behaviour (asyncpg /
aiomysql / pyodbc) is therefore not exercised by the tests in this module.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from mcp_privileged.config import settings
from mcp_privileged.database.server import (
_cell_str,
_format_result,
db_query,
)
from mcp_privileged.secret_store import secret_store
from tests.conftest import make_db_result
# ── Helpers ───────────────────────────────────────────────────────────────────
async def _handle(username: str = "db_svc", password: str = "DbP@ss!") -> str:
    """Register a throw-away credential pair and return its secret handle."""
    new_handle = await secret_store.store(username, password)
    return new_handle
def _patch_dispatch(columns: list[str], rows: list[list]):
    """Build a patcher that stubs out ``_dispatch_query`` with a canned result.

    Use the returned ``patch`` object as a context manager. Inside it,
    ``db_query`` receives ``make_db_result(columns, rows)`` instead of
    reaching a real database. Entering the context yields the AsyncMock
    stub, so callers can inspect ``call_args``.
    """
    canned = make_db_result(columns, rows)
    stub = AsyncMock(return_value=canned)
    return patch("mcp_privileged.database.server._dispatch_query", new=stub)
# ── Tests ─────────────────────────────────────────────────────────────────────
async def test_db_query_success_postgres(mock_ctx) -> None:
    """Happy path: a postgres query reports its columns, rows and label."""
    handle = await _handle()
    header = ["id", "name", "email"]
    records = [
        [1, "Alice", "alice@example.com"],
        [2, "Bob", "bob@example.com"],
    ]
    with _patch_dispatch(header, records):
        output = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT id, name, email FROM users",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )
    assert "Rows returned: 2" in output
    for column in header:
        assert column in output
    assert "Alice" in output
    assert "Database: mydb (postgres)" in output
async def test_db_query_success_mysql(mock_ctx) -> None:
    """MySQL variant: the result is labelled with the mysql db_type.

    The label is checked in its full ``Database: <db> (<db_type>)`` form,
    matching the postgres happy-path test. A bare ``"mysql" in result``
    could be satisfied by the host name ``mysql.internal`` alone if the
    header echoes the host, so it would not prove the label is right.
    """
    handle = await _handle()
    with _patch_dispatch(["host_name"], [["mysql-server-01"]]):
        result = await db_query(
            host="mysql.internal",
            database="ops",
            query="SELECT @@hostname",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="mysql",
        )
    assert "Database: ops (mysql)" in result
    assert "mysql-server-01" in result
async def test_db_query_success_mssql(mock_ctx) -> None:
    """SQL Server variant: the result is labelled with the mssql db_type.

    The full ``Database: <db> (<db_type>)`` label is asserted, the same
    way as the postgres happy-path test. This pins the db_type label
    rather than any occurrence of the substring "mssql".
    """
    handle = await _handle()
    with _patch_dispatch(["name"], [["SQLSERVER01"]]):
        result = await db_query(
            host="sql.internal",
            database="master",
            query="SELECT @@SERVERNAME",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="mssql",
        )
    assert "Database: master (mssql)" in result
    assert "SQLSERVER01" in result
async def test_db_query_default_port_resolved(mock_ctx) -> None:
    """Passing port=0 makes db_query forward the postgres default, 5432."""
    handle = await _handle()
    with _patch_dispatch(["v"], [[42]]) as dispatch_mock:
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 42",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
            port=0,
        )
    forwarded = dispatch_mock.call_args.kwargs
    assert forwarded["port"] == 5432
async def test_db_query_custom_port_forwarded(mock_ctx) -> None:
    """A non-zero port reaches _dispatch_query untouched."""
    explicit_port = 15432
    handle = await _handle()
    with _patch_dispatch(["v"], [[1]]) as dispatch_mock:
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
            port=explicit_port,
        )
    assert dispatch_mock.call_args.kwargs["port"] == explicit_port
async def test_db_query_username_override(mock_ctx) -> None:
    """username_override wins over the username stored with the secret."""
    handle = await _handle(username="readonly_user")
    with _patch_dispatch(["v"], [[1]]) as dispatch_mock:
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
            username_override="dba_user",
        )
    forwarded = dispatch_mock.call_args.kwargs
    assert forwarded["username"] == "dba_user"
async def test_db_query_invalid_db_type(mock_ctx) -> None:
    """An unsupported db_type ("oracle") is rejected with ValueError.

    NOTE(review): the intent is that validation happens before the
    credential store is consulted. This test only checks the exception
    type and message, not that ordering.
    """
    handle = await _handle()
    call_kwargs = dict(
        host="db.internal",
        database="mydb",
        query="SELECT 1",
        secret_handle=handle,
        ctx=mock_ctx,
        db_type="oracle",
    )
    with pytest.raises(ValueError, match="Unsupported db_type"):
        await db_query(**call_kwargs)
async def test_db_query_invalid_handle(mock_ctx) -> None:
    """A handle the store never issued raises KeyError and is reported once via ctx.error."""
    unknown_handle = "secret://doesnotexist0000000000000000"
    with pytest.raises(KeyError):
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=unknown_handle,
            ctx=mock_ctx,
            db_type="postgres",
        )
    mock_ctx.error.assert_awaited_once()
async def test_db_query_driver_exception_propagates(mock_ctx) -> None:
    """A failure raised by _dispatch_query escapes db_query and is reported via ctx.error."""
    handle = await _handle()
    target = "mcp_privileged.database.server._dispatch_query"
    failing_dispatch = AsyncMock(side_effect=ConnectionRefusedError("DB port closed"))
    with patch(target, new=failing_dispatch), pytest.raises(ConnectionRefusedError):
        await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )
    mock_ctx.error.assert_awaited_once()
async def test_db_query_rows_capped(mock_ctx) -> None:
    """With db_max_rows lowered to 10, only 10 rows are reported and the overflow is flagged."""
    handle = await _handle()
    big_result = [[idx, f"user_{idx}"] for idx in range(2000)]
    dispatch_patch = _patch_dispatch(["id", "name"], big_result)
    row_cap_patch = patch.object(settings, "db_max_rows", 10)
    with dispatch_patch, row_cap_patch:
        output = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT id, name FROM big_table",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )
    assert "Rows returned: 10" in output
    assert "more rows exist" in output
async def test_db_query_empty_result(mock_ctx) -> None:
    """A query yielding no columns and no rows still produces a well-formed report."""
    handle = await _handle()
    no_columns: list[str] = []
    no_rows: list[list] = []
    with _patch_dispatch(no_columns, no_rows):
        output = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1 WHERE false",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )
    for expected in ("Rows returned: 0", "No rows returned"):
        assert expected in output
async def test_db_query_password_not_in_ctx_messages(mock_ctx) -> None:
    """The credential password must never leak into ctx.info, ctx.error or the result.

    Besides the context messages, the text returned by db_query is checked
    too, because that text is handed straight back to the caller.
    """
    secret_password = "DB$ecretPass99"
    handle = await secret_store.store("db_user", secret_password)
    with _patch_dispatch(["v"], [[1]]):
        result = await db_query(
            host="pg.internal",
            database="mydb",
            query="SELECT 1",
            secret_handle=handle,
            ctx=mock_ctx,
            db_type="postgres",
        )
    assert secret_password not in result
    all_calls = mock_ctx.info.await_args_list + mock_ctx.error.await_args_list
    for call in all_calls:
        # str(call) renders every positional and keyword argument of the await.
        assert secret_password not in str(call)
# ── Unit tests for helpers ────────────────────────────────────────────────────
def test_cell_str_none() -> None:
    """None renders as an empty cell."""
    rendered = _cell_str(None)
    assert rendered == ""
def test_cell_str_normal() -> None:
    """Short ints and strings are rendered as their plain text form."""
    for value, expected in ((42, "42"), ("hello", "hello")):
        assert _cell_str(value) == expected
def test_cell_str_truncated() -> None:
    """Values longer than db_max_cell_bytes are cut down and visibly marked."""
    long_val = "x" * 10_000
    with patch.object(settings, "db_max_cell_bytes", 10):
        result = _cell_str(long_val)
    # A truncation marker is expected to be appended to the kept payload.
    # NOTE(review): the exact marker text is not pinned here; assert only that
    # something other than the "x" payload is present, and tighten this to the
    # concrete marker (e.g. "…") once confirmed against _cell_str.
    assert set(result) - {"x"}, "expected a truncation marker in the cell"
    assert len(result) < 20
def test_format_result_no_rows() -> None:
    """Empty columns and rows produce the 'No rows returned' notice."""
    text = _format_result(
        "host", "db", "postgres", "SELECT 1",
        [], [], False, 5.0,
    )
    assert "No rows returned" in text
def test_format_result_with_rows() -> None:
    """Column headers, cell values and the row count all appear in the formatted text."""
    header = ["id", "name"]
    data = [[1, "Alice"], [2, "Bob"]]
    text = _format_result(
        "host", "db", "postgres", "SELECT ...",
        header, data, False, 12.3,
    )
    for expected in ("id", "Alice", "Bob", "Rows returned: 2"):
        assert expected in text
def test_format_result_truncated_flag() -> None:
    """When the truncated flag is set, the formatted text mentions the cap."""
    single_column = ["id"]
    rows = [[n] for n in range(5)]
    text = _format_result(
        "host", "db", "postgres", "SELECT ...",
        single_column, rows, True, 1.0,
    )
    assert "capped" in text