292 lines
9.7 KiB
Python
292 lines
9.7 KiB
Python
"""
Tests for the SSH MCP tool (ssh_execute).

All tests mock asyncssh.connect, so no real SSH connections are made.
The secret_store is used directly, so handle issuance and resolution are
tested end-to-end through the real store.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import asyncssh
|
|
import pytest
|
|
|
|
from mcp_privileged.secret_store import secret_store
|
|
from mcp_privileged.ssh.server import _truncate, _format_result, ssh_execute
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
def _make_ctx() -> MagicMock:
    """Build a stand-in MCP Context with awaitable logging hooks.

    ``info`` and ``error`` are AsyncMocks, so the tool under test can await
    them and tests can inspect the awaited calls afterwards. The request
    carries no headers and ``client`` is None.
    """
    ctx = MagicMock()
    ctx.configure_mock(info=AsyncMock(), error=AsyncMock())
    # Fields read by _extract_client_ip in the server module.
    request = ctx.request_context.request
    request.headers = {}
    request.client = None
    return ctx
|
|
|
|
|
|
def _make_ssh_cm(
    stdout: str = "",
    stderr: str = "",
    exit_status: int = 0,
) -> tuple[AsyncMock, AsyncMock]:
    """
    Fake the object returned by ``asyncssh.connect(...)`` for ``async with`` use.

    Entering the returned context manager yields a connection mock. Its
    awaitable ``run`` returns an object whose ``stdout``, ``stderr`` and
    ``exit_status`` attributes hold the given values. Exiting returns False,
    so an exception raised inside the ``async with`` body still propagates.

    Returns (context_manager_mock, conn_mock).
    Patch asyncssh.connect with return_value=context_manager_mock.
    """
    completed = MagicMock(stdout=stdout, stderr=stderr, exit_status=exit_status)

    conn = AsyncMock()
    conn.run = AsyncMock(return_value=completed)

    cm = AsyncMock()
    cm.__aenter__ = AsyncMock(return_value=conn)
    # A falsy __aexit__ result means exceptions are not suppressed.
    cm.__aexit__ = AsyncMock(return_value=False)

    return cm, conn
|
|
|
|
|
|
async def _fresh_handle(username: str = "svc_user", password: str = "P@ssw0rd!") -> str:
    """Register *username*/*password* with the real secret_store.

    Returns the newly issued handle, which has not been consumed yet, so the
    calling test can resolve it exactly as the tool would.
    """
    issued = await secret_store.store(username, password)
    return issued
|
|
|
|
|
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
async def test_ssh_execute_success() -> None:
    """A zero-exit command reports exit code 0 and echoes host, command and stdout."""
    ctx = _make_ctx()
    handle = await _fresh_handle()
    fake_connect, _conn = _make_ssh_cm(stdout="hello world\n", exit_status=0)

    with patch("mcp_privileged.ssh.server.asyncssh.connect", return_value=fake_connect):
        output = await ssh_execute(
            host="linux01.internal",
            command="echo hello world",
            secret_handle=handle,
            ctx=ctx,
        )

    expected_fragments = (
        "Exit code: 0",
        "hello world",
        "Host: linux01.internal",
        "Command: echo hello world",
    )
    for fragment in expected_fragments:
        assert fragment in output, f"missing {fragment!r} in result"
|
|
|
|
|
|
async def test_ssh_execute_nonzero_exit_not_raised() -> None:
    """A failing remote command is reported through its exit code rather than raised."""
    handle = await _fresh_handle()
    ctx = _make_ctx()
    connect_cm, _unused_conn = _make_ssh_cm(
        stdout="",
        stderr="command not found\n",
        exit_status=127,
    )

    with patch("mcp_privileged.ssh.server.asyncssh.connect", return_value=connect_cm):
        output = await ssh_execute(
            host="linux01.internal",
            command="notacommand",
            secret_handle=handle,
            ctx=ctx,
        )

    # Getting here at all means no exception escaped ssh_execute.
    assert "Exit code: 127" in output
    assert "command not found" in output
|
|
|
|
|
|
async def test_ssh_execute_stderr_included() -> None:
    """Both stdout and stderr appear in the result when both are non-empty."""
    handle = await _fresh_handle()
    ctx = _make_ctx()
    # Use distinctive payloads. A generic stdout such as "result" could also
    # occur in the surrounding result formatting, letting the stdout assertion
    # pass even if stdout were dropped from the output.
    stdout_line = "/dev/sda1 50G 48G 96% /"
    stderr_line = "warning: low disk"
    mock_cm, _ = _make_ssh_cm(
        stdout=f"{stdout_line}\n",
        stderr=f"{stderr_line}\n",
        exit_status=0,
    )

    with patch("mcp_privileged.ssh.server.asyncssh.connect", return_value=mock_cm):
        result = await ssh_execute(
            host="host1",
            command="df -h",
            secret_handle=handle,
            ctx=ctx,
        )

    assert stdout_line in result
    assert stderr_line in result
|
|
|
|
|
|
async def test_ssh_execute_username_override() -> None:
    """An explicit username_override wins over the username stored with the credential."""
    handle = await _fresh_handle(username="original_user")
    ctx = _make_ctx()
    fake_cm, _conn = _make_ssh_cm(stdout="uid=0(root)\n", exit_status=0)

    connect_target = "mcp_privileged.ssh.server.asyncssh.connect"
    with patch(connect_target, return_value=fake_cm) as connect_spy:
        await ssh_execute(
            host="host1",
            command="id",
            secret_handle=handle,
            ctx=ctx,
            username_override="root",
        )

    assert connect_spy.call_args.kwargs["username"] == "root"
|
|
|
|
|
|
async def test_ssh_execute_credential_username_used_by_default() -> None:
    """With no username_override, connect receives the username stored with the credential."""
    handle = await _fresh_handle(username="db_admin")
    ctx = _make_ctx()
    fake_cm, _conn = _make_ssh_cm(stdout="ok\n", exit_status=0)

    connect_target = "mcp_privileged.ssh.server.asyncssh.connect"
    with patch(connect_target, return_value=fake_cm) as connect_spy:
        await ssh_execute(
            host="host1",
            command="whoami",
            secret_handle=handle,
            ctx=ctx,
        )

    connect_kwargs = connect_spy.call_args.kwargs
    assert connect_kwargs["username"] == "db_admin"
|
|
|
|
|
|
async def test_ssh_execute_invalid_handle_raises() -> None:
    """A handle the store does not know raises KeyError and is reported once via ctx.error."""
    ctx = _make_ctx()
    unknown_handle = "secret://doesnotexist0000000000000000"

    with pytest.raises(KeyError):
        await ssh_execute(
            host="host1",
            command="id",
            secret_handle=unknown_handle,
            ctx=ctx,
        )

    ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_ssh_execute_connect_os_error_propagates() -> None:
    """A connection-level OSError (e.g. refused) reaches the caller; ctx.error is awaited once."""
    handle = await _fresh_handle()
    ctx = _make_ctx()
    refused = OSError("Connection refused")

    with patch(
        "mcp_privileged.ssh.server.asyncssh.connect", side_effect=refused
    ), pytest.raises(OSError):
        await ssh_execute(
            host="dead.host",
            command="id",
            secret_handle=handle,
            ctx=ctx,
        )

    ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_ssh_execute_permission_denied_propagates() -> None:
    """An authentication failure (asyncssh.PermissionDenied) is re-raised; ctx.error is awaited once."""
    handle = await _fresh_handle()
    ctx = _make_ctx()
    denied = asyncssh.PermissionDenied("Permission denied")

    with patch(
        "mcp_privileged.ssh.server.asyncssh.connect", side_effect=denied
    ), pytest.raises(asyncssh.PermissionDenied):
        await ssh_execute(
            host="host1",
            command="id",
            secret_handle=handle,
            ctx=ctx,
        )

    ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_ssh_execute_command_timeout_propagates() -> None:
    """asyncio.TimeoutError from conn.run propagates and calls ctx.error."""
    handle = await _fresh_handle()
    ctx = _make_ctx()

    # Reuse the shared connect fake instead of rebuilding the mock wiring by
    # hand, then make the remote command itself time out. An exception
    # side_effect takes precedence over run's canned return value.
    mock_cm, mock_conn = _make_ssh_cm()
    mock_conn.run.side_effect = asyncio.TimeoutError()

    with patch("mcp_privileged.ssh.server.asyncssh.connect", return_value=mock_cm):
        with pytest.raises(asyncio.TimeoutError):
            await ssh_execute(
                host="slow.host",
                command="sleep 999",
                secret_handle=handle,
                ctx=ctx,
                timeout_seconds=1,
            )

    # Pin the source of the timeout: it must come from running the command,
    # not from some other awaited step.
    mock_conn.run.assert_awaited_once()
    ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_ssh_execute_password_not_in_result() -> None:
    """Handle resolution must never inject the credential password into output or logs.

    Two scenarios share one ctx:

    1. A misconfigured remote command echoes the password on stdout. That leak
       originates on the remote host, so it may pass through to the result.
       What we verify is that none of our ctx.info / ctx.error calls carry it.
    2. A well-behaved command whose output does not contain the password. Here
       the password must not appear anywhere in the tool's return value.
    """
    secret_password = "SuperSecret!123"
    ctx = _make_ctx()
    connect_target = "mcp_privileged.ssh.server.asyncssh.connect"

    # Scenario 1: the command echoes env vars that contain the password.
    leaky_handle = await _fresh_handle(password=secret_password)
    leaky_cm, _ = _make_ssh_cm(stdout=f"PASSWORD={secret_password}\n", exit_status=0)
    with patch(connect_target, return_value=leaky_cm):
        leaky_result = await ssh_execute(
            host="host1",
            command="env",
            secret_handle=leaky_handle,
            ctx=ctx,
        )
    # Sanity check: the command actually ran to completion.
    assert "Exit code: 0" in leaky_result

    # Scenario 2: clean remote output, so any password in the result would
    # have been injected by our own code.
    clean_handle = await _fresh_handle(password=secret_password)
    clean_cm, _ = _make_ssh_cm(stdout="ok\n", exit_status=0)
    with patch(connect_target, return_value=clean_cm):
        clean_result = await ssh_execute(
            host="host1",
            command="whoami",
            secret_handle=clean_handle,
            ctx=ctx,
        )
    assert secret_password not in clean_result, "Password leaked into tool result"

    # Across both runs, nothing we sent to the MCP context may carry the password.
    for logged in ctx.error.await_args_list + ctx.info.await_args_list:
        assert secret_password not in str(logged), "Password leaked into MCP context log"
|
|
|
|
|
|
# ── Unit tests for helpers ────────────────────────────────────────────────────
|
|
|
|
def test_truncate_short_text_unchanged() -> None:
    """Text comfortably under the limit is returned verbatim."""
    original = "hello world"
    assert _truncate(original, 1024, "stdout") == original
|
|
|
|
|
|
def test_truncate_long_text_truncated() -> None:
    """Oversized text is cut down and carries a truncation marker."""
    limit = 100
    truncated = _truncate("x" * 10_000, limit, "stdout")
    assert "truncated" in truncated
    # Leave headroom above the limit for the short marker suffix.
    assert len(truncated.encode("utf-8")) <= 2 * limit
|
|
|
|
|
|
def test_format_result_no_stderr() -> None:
    """An empty stderr omits the stderr section; exit code and stdout are shown."""
    formatted = _format_result("myhost", "ls /", 0, "bin\nlib\n", "")
    assert "Exit code: 0" in formatted
    assert "bin" in formatted
    assert "--- stderr ---" not in formatted
|
|
|
|
|
|
def test_format_result_with_stderr() -> None:
    """A non-empty stderr gets its own section alongside the exit code."""
    formatted = _format_result("myhost", "bad_cmd", 1, "", "not found\n")
    for expected in ("--- stderr ---", "not found", "Exit code: 1"):
        assert expected in formatted
|
|
|
|
|
|
def test_format_result_empty_stdout_shows_empty_marker() -> None:
    """With no output at all, the formatter shows an explicit "(empty)" placeholder."""
    formatted = _format_result("myhost", "true", 0, "", "")
    assert "(empty)" in formatted
|