Files
MCP_CyberArk/tests/test_ssh_server.py
2026-03-29 19:51:51 +02:00

292 lines
9.7 KiB
Python

"""
Tests for the SSH MCP tool (ssh_execute).
All tests mock asyncssh.connect — no real SSH connections are made.
The secret_store is used directly so handle issuance/resolution is tested
end-to-end through the real store.
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import asyncssh
import pytest
from mcp_privileged.secret_store import secret_store
from mcp_privileged.ssh.server import _truncate, _format_result, ssh_execute
# ── Helpers ───────────────────────────────────────────────────────────────────
def _make_ctx() -> MagicMock:
"""Return a minimal mock MCP Context."""
ctx = MagicMock()
ctx.info = AsyncMock()
ctx.error = AsyncMock()
# _extract_client_ip uses these
ctx.request_context.request.headers = {}
ctx.request_context.request.client = None
return ctx
def _make_ssh_cm(
stdout: str = "",
stderr: str = "",
exit_status: int = 0,
) -> tuple[AsyncMock, AsyncMock]:
"""
Build a mock for asyncssh.connect used as an async context manager.
Returns (context_manager_mock, conn_mock).
Patch asyncssh.connect with return_value=context_manager_mock.
"""
mock_conn = AsyncMock()
mock_conn.run = AsyncMock(
return_value=MagicMock(stdout=stdout, stderr=stderr, exit_status=exit_status)
)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=False)
return mock_cm, mock_conn
async def _fresh_handle(username: str = "svc_user", password: str = "P@ssw0rd!") -> str:
    """Register ``username``/``password`` in the real ``secret_store``.

    Returns the handle produced by ``secret_store.store``. The handle has not
    been consumed yet, so the test can resolve it once.
    """
    new_handle = await secret_store.store(username, password)
    return new_handle
# ── Tests ─────────────────────────────────────────────────────────────────────
async def test_ssh_execute_success() -> None:
    """Happy path: command runs, stdout is returned, exit code is 0."""
    ctx = _make_ctx()
    handle = await _fresh_handle()
    fake_cm, _ = _make_ssh_cm(stdout="hello world\n", exit_status=0)
    target = "mcp_privileged.ssh.server.asyncssh.connect"

    with patch(target, return_value=fake_cm):
        output = await ssh_execute(
            host="linux01.internal",
            command="echo hello world",
            secret_handle=handle,
            ctx=ctx,
        )

    # The formatted result must echo the exit code, the output, and the
    # host/command header lines.
    for fragment in (
        "Exit code: 0",
        "hello world",
        "Host: linux01.internal",
        "Command: echo hello world",
    ):
        assert fragment in output
async def test_ssh_execute_nonzero_exit_not_raised() -> None:
    """A non-zero exit code is returned in the result, not raised as an exception."""
    ctx = _make_ctx()
    handle = await _fresh_handle()
    fake_cm, _ = _make_ssh_cm(
        stdout="", stderr="command not found\n", exit_status=127
    )
    target = "mcp_privileged.ssh.server.asyncssh.connect"

    # If a failing exit status were turned into an exception, this await would raise.
    with patch(target, return_value=fake_cm):
        output = await ssh_execute(
            host="linux01.internal",
            command="notacommand",
            secret_handle=handle,
            ctx=ctx,
        )

    for fragment in ("Exit code: 127", "command not found"):
        assert fragment in output
async def test_ssh_execute_stderr_included() -> None:
    """Both stdout and stderr appear in the result when both are non-empty."""
    ctx = _make_ctx()
    handle = await _fresh_handle()
    fake_cm, _ = _make_ssh_cm(
        stdout="result\n", stderr="warning: low disk\n", exit_status=0
    )
    target = "mcp_privileged.ssh.server.asyncssh.connect"

    with patch(target, return_value=fake_cm):
        output = await ssh_execute(
            host="host1",
            command="df -h",
            secret_handle=handle,
            ctx=ctx,
        )

    for fragment in ("result", "warning: low disk"):
        assert fragment in output
async def test_ssh_execute_username_override() -> None:
    """username_override replaces the credential's username in the connect call."""
    handle = await _fresh_handle(username="original_user")
    ctx = _make_ctx()
    mock_cm, _ = _make_ssh_cm(stdout="uid=0(root)\n", exit_status=0)
    with patch(
        "mcp_privileged.ssh.server.asyncssh.connect", return_value=mock_cm
    ) as mock_connect:
        await ssh_execute(
            host="host1",
            command="id",
            secret_handle=handle,
            ctx=ctx,
            username_override="root",
        )
    # Exactly one connection attempt, made as the override user rather than
    # the stored credential's username ("original_user").
    mock_connect.assert_called_once()
    assert mock_connect.call_args.kwargs["username"] == "root"
async def test_ssh_execute_credential_username_used_by_default() -> None:
    """Without username_override, the credential's username is passed to connect."""
    handle = await _fresh_handle(username="db_admin")
    ctx = _make_ctx()
    mock_cm, _ = _make_ssh_cm(stdout="ok\n", exit_status=0)
    with patch(
        "mcp_privileged.ssh.server.asyncssh.connect", return_value=mock_cm
    ) as mock_connect:
        await ssh_execute(
            host="host1",
            command="whoami",
            secret_handle=handle,
            ctx=ctx,
        )
    # Exactly one connection attempt, using the username stored with the handle.
    mock_connect.assert_called_once()
    assert mock_connect.call_args.kwargs["username"] == "db_admin"
async def test_ssh_execute_invalid_handle_raises() -> None:
    """An unknown handle raises KeyError, calls ctx.error, and never connects.

    asyncssh.connect is patched here as in every other tool test. Without the
    patch, a regression in handle resolution could open a real SSH connection.
    """
    ctx = _make_ctx()
    with patch("mcp_privileged.ssh.server.asyncssh.connect") as mock_connect:
        with pytest.raises(KeyError):
            await ssh_execute(
                host="host1",
                command="id",
                secret_handle="secret://doesnotexist0000000000000000",
                ctx=ctx,
            )
    # With no resolvable credential there is nothing to authenticate with,
    # so no connection attempt should be made.
    mock_connect.assert_not_called()
    ctx.error.assert_awaited_once()
async def test_ssh_execute_connect_os_error_propagates() -> None:
    """An OSError (e.g. connection refused) propagates and calls ctx.error."""
    ctx = _make_ctx()
    handle = await _fresh_handle()
    refused = OSError("Connection refused")
    target = "mcp_privileged.ssh.server.asyncssh.connect"

    with patch(target, side_effect=refused), pytest.raises(OSError):
        await ssh_execute(
            host="dead.host",
            command="id",
            secret_handle=handle,
            ctx=ctx,
        )

    ctx.error.assert_awaited_once()
async def test_ssh_execute_permission_denied_propagates() -> None:
    """asyncssh.PermissionDenied propagates and calls ctx.error."""
    ctx = _make_ctx()
    handle = await _fresh_handle()
    denied = asyncssh.PermissionDenied("Permission denied")
    target = "mcp_privileged.ssh.server.asyncssh.connect"

    with patch(target, side_effect=denied), pytest.raises(
        asyncssh.PermissionDenied
    ):
        await ssh_execute(
            host="host1",
            command="id",
            secret_handle=handle,
            ctx=ctx,
        )

    ctx.error.assert_awaited_once()
async def test_ssh_execute_command_timeout_propagates() -> None:
    """asyncio.TimeoutError from conn.run propagates and calls ctx.error."""
    handle = await _fresh_handle()
    ctx = _make_ctx()
    # Reuse the shared connection fake rather than re-wiring the mocks by hand.
    # A side_effect exception takes precedence over run's configured
    # return_value, so awaiting conn.run raises.
    mock_cm, mock_conn = _make_ssh_cm()
    mock_conn.run.side_effect = asyncio.TimeoutError()
    with patch("mcp_privileged.ssh.server.asyncssh.connect", return_value=mock_cm):
        with pytest.raises(asyncio.TimeoutError):
            await ssh_execute(
                host="slow.host",
                command="sleep 999",
                secret_handle=handle,
                ctx=ctx,
                timeout_seconds=1,
            )
    ctx.error.assert_awaited_once()
async def test_ssh_execute_password_not_in_result() -> None:
    """The credential password must never be injected by the tool itself.

    Two scenarios are checked:

    * Benign remote output: the password must not appear anywhere in the
      tool's return value, so handle resolution must not leak it into the
      result.
    * Remote output that echoes the password (a misconfigured command): the
      return value necessarily contains it via stdout, and that leak is the
      remote application's problem. It must still not be copied into any
      ctx.info / ctx.error message.
    """
    secret_password = "SuperSecret!123"
    target = "mcp_privileged.ssh.server.asyncssh.connect"

    def _assert_not_logged(ctx: MagicMock) -> None:
        for logged in ctx.error.await_args_list + ctx.info.await_args_list:
            assert secret_password not in str(logged), "Password leaked into MCP context log"

    # Scenario 1: benign output, so the returned text must be password-free.
    handle = await _fresh_handle(password=secret_password)
    ctx = _make_ctx()
    mock_cm, _ = _make_ssh_cm(stdout="host1\n", exit_status=0)
    with patch(target, return_value=mock_cm):
        result = await ssh_execute(
            host="host1",
            command="hostname",
            secret_handle=handle,
            ctx=ctx,
        )
    assert secret_password not in result, "Password leaked into tool result"
    _assert_not_logged(ctx)

    # Scenario 2: the command echoes env vars containing the password. The
    # result contains it via stdout, but our own logging must not.
    # A fresh handle is used because the first one may already be consumed.
    handle = await _fresh_handle(password=secret_password)
    ctx = _make_ctx()
    mock_cm, _ = _make_ssh_cm(stdout=f"PASSWORD={secret_password}\n", exit_status=0)
    with patch(target, return_value=mock_cm):
        await ssh_execute(
            host="host1",
            command="env",
            secret_handle=handle,
            ctx=ctx,
        )
    _assert_not_logged(ctx)
# ── Unit tests for helpers ────────────────────────────────────────────────────
def test_truncate_short_text_unchanged() -> None:
    """Text well inside the byte budget comes back exactly as given."""
    original = "hello world"
    truncated = _truncate(original, 1024, "stdout")
    assert truncated == original
def test_truncate_long_text_truncated() -> None:
    """Oversized text is cut down and flagged with a truncation marker."""
    oversized = "x" * 10_000
    shortened = _truncate(oversized, 100, "stdout")
    assert "truncated" in shortened
    # 100-byte budget plus headroom for the marker's short suffix.
    encoded_size = len(shortened.encode("utf-8"))
    assert encoded_size <= 200
def test_format_result_no_stderr() -> None:
    """Empty stderr produces no stderr section; exit code and stdout are shown."""
    formatted = _format_result("myhost", "ls /", 0, "bin\nlib\n", "")
    assert "--- stderr ---" not in formatted
    for needle in ("Exit code: 0", "bin"):
        assert needle in formatted
def test_format_result_with_stderr() -> None:
    """Non-empty stderr gets its own labelled section in the formatted output."""
    formatted = _format_result("myhost", "bad_cmd", 1, "", "not found\n")
    for needle in ("--- stderr ---", "not found", "Exit code: 1"):
        assert needle in formatted
def test_format_result_empty_stdout_shows_empty_marker() -> None:
    """With no output at all, an explicit '(empty)' placeholder is rendered."""
    formatted = _format_result("myhost", "true", 0, "", "")
    assert "(empty)" in formatted