228 lines
7.6 KiB
Python
228 lines
7.6 KiB
Python
"""
Tests for the PowerShell MCP tool (ps_execute).

pypsrp is a synchronous library. The server wraps _run_ps_sync() in
asyncio's run_in_executor, so we patch _run_ps_sync directly — no real WinRM
connections are made.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
import pytest
|
|
|
|
from mcp_privileged.powershell.server import _format_result, _truncate, ps_execute
|
|
from mcp_privileged.secret_store import secret_store
|
|
from tests.conftest import make_ps_result
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _handle(username: str = "svc_user", password: str = "P@ss!") -> str:
    """Store a credential pair in ``secret_store`` and return the value it yields.

    The tests below pass that value to ``ps_execute`` as ``secret_handle``.
    """
    stored = await secret_store.store(username, password)
    return stored
|
|
|
|
|
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
async def test_ps_execute_success(mock_ctx) -> None:
    """Successful run: the result names the host, carries the output, no errors."""
    credential_handle = await _handle()
    fake_result = make_ps_result(output=["Win2022", "Server"], had_errors=False)
    run_patch = patch(
        "mcp_privileged.powershell.server._run_ps_sync", return_value=fake_result
    )

    with run_patch:
        report = await ps_execute(
            host="win01.internal",
            script="$PSVersionTable.OS; hostname",
            secret_handle=credential_handle,
            ctx=mock_ctx,
        )

    # Every fragment must show up somewhere in the formatted result text.
    expected_fragments = (
        "Had errors: False",
        "Win2022",
        "Server",
        "Host: win01.internal",
    )
    for fragment in expected_fragments:
        assert fragment in report
|
|
|
|
|
|
async def test_ps_execute_with_errors(mock_ctx) -> None:
    """Failing script: the result flags had_errors and includes the error records."""
    credential_handle = await _handle()
    error_records = ["Get-Item : Cannot find path 'C:\\missing'"]
    fake_result = make_ps_result(output=[], had_errors=True, errors=error_records)
    run_patch = patch(
        "mcp_privileged.powershell.server._run_ps_sync", return_value=fake_result
    )

    with run_patch:
        report = await ps_execute(
            host="win01.internal",
            script="Get-Item C:\\missing",
            secret_handle=credential_handle,
            ctx=mock_ctx,
        )

    for fragment in ("Had errors: True", "Cannot find path", "--- errors ---"):
        assert fragment in report
|
|
|
|
|
|
async def test_ps_execute_no_output(mock_ctx) -> None:
    """A script that emits nothing (e.g. Set-* cmdlets) yields the placeholder."""
    credential_handle = await _handle()
    silent_result = make_ps_result(output=[], had_errors=False)
    run_patch = patch(
        "mcp_privileged.powershell.server._run_ps_sync", return_value=silent_result
    )

    with run_patch:
        report = await ps_execute(
            host="win01.internal",
            script="Set-TimeZone -Id 'UTC'",
            secret_handle=credential_handle,
            ctx=mock_ctx,
        )

    for fragment in ("Had errors: False", "(no output)"):
        assert fragment in report
|
|
|
|
|
|
async def test_ps_execute_username_override(mock_ctx) -> None:
    """An explicit username_override is what reaches _run_ps_sync."""
    credential_handle = await _handle(username="domain\\original")
    fake_result = make_ps_result(output=["ok"])
    run_patch = patch(
        "mcp_privileged.powershell.server._run_ps_sync", return_value=fake_result
    )

    with run_patch as run_mock:
        await ps_execute(
            host="win01.internal",
            script="whoami",
            secret_handle=credential_handle,
            ctx=mock_ctx,
            username_override="domain\\admin",
        )

    # The username is the third positional argument handed to _run_ps_sync.
    positional = run_mock.call_args[0]
    assert positional[2] == "domain\\admin"
|
|
|
|
|
|
async def test_ps_execute_credential_username_used_by_default(mock_ctx) -> None:
    """With no username_override, the stored credential's username is forwarded."""
    credential_handle = await _handle(username="domain\\svc_ps")
    fake_result = make_ps_result(output=["ok"])
    run_patch = patch(
        "mcp_privileged.powershell.server._run_ps_sync", return_value=fake_result
    )

    with run_patch as run_mock:
        await ps_execute(
            host="win01.internal",
            script="whoami",
            secret_handle=credential_handle,
            ctx=mock_ctx,
        )

    # Index 2 of the positional arguments is the username slot.
    positional = run_mock.call_args[0]
    assert positional[2] == "domain\\svc_ps"
|
|
|
|
|
|
async def test_ps_execute_invalid_handle(mock_ctx) -> None:
    """Unknown handle raises KeyError before any WinRM connection is attempted.

    ``_run_ps_sync`` is patched for two reasons: a regression in handle
    validation can then never trigger a real connection attempt from the test
    suite, and the "before any WinRM connection" half of the contract is
    asserted instead of merely described.
    """
    with patch("mcp_privileged.powershell.server._run_ps_sync") as mock_run:
        with pytest.raises(KeyError):
            await ps_execute(
                host="win01.internal",
                script="hostname",
                secret_handle="secret://doesnotexist0000000000000000",
                ctx=mock_ctx,
            )

    # The handle lookup must fail before the PowerShell runner is ever invoked.
    mock_run.assert_not_called()
    mock_ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_ps_execute_winrm_exception_propagates(mock_ctx) -> None:
    """An exception raised by _run_ps_sync reaches the caller and hits ctx.error."""
    credential_handle = await _handle()
    failing_run = patch(
        "mcp_privileged.powershell.server._run_ps_sync",
        side_effect=ConnectionRefusedError("WinRM port closed"),
    )

    # The patch is entered first, then the raises-guard, as in a nested ``with``.
    with failing_run, pytest.raises(ConnectionRefusedError):
        await ps_execute(
            host="dead.host",
            script="hostname",
            secret_handle=credential_handle,
            ctx=mock_ctx,
        )

    mock_ctx.error.assert_awaited_once()
|
|
|
|
|
|
async def test_ps_execute_password_not_in_ctx_messages(mock_ctx) -> None:
    """No awaited ctx.info or ctx.error call may contain the stored password."""
    secret_password = "WinRM$ecret99"
    credential_handle = await secret_store.store("user", secret_password)
    fake_result = make_ps_result(output=["ok"])
    run_patch = patch(
        "mcp_privileged.powershell.server._run_ps_sync", return_value=fake_result
    )

    with run_patch:
        await ps_execute(
            host="win01.internal",
            script="hostname",
            secret_handle=credential_handle,
            ctx=mock_ctx,
        )

    # Inspect info-level calls first, then error-level calls.
    logged_calls = [*mock_ctx.info.await_args_list, *mock_ctx.error.await_args_list]
    for logged in logged_calls:
        rendered = str(logged)
        assert secret_password not in rendered, "Password leaked into MCP context log"
|
|
|
|
|
|
async def test_ps_execute_ssl_and_port_forwarded(mock_ctx) -> None:
    """A custom port and use_ssl=True are passed through to _run_ps_sync."""
    credential_handle = await _handle()
    fake_result = make_ps_result(output=["ok"])
    run_patch = patch(
        "mcp_privileged.powershell.server._run_ps_sync", return_value=fake_result
    )

    with run_patch as run_mock:
        await ps_execute(
            host="win01.internal",
            script="hostname",
            secret_handle=credential_handle,
            ctx=mock_ctx,
            port=5986,
            use_ssl=True,
        )

    positional = run_mock.call_args[0]
    forwarded_port = positional[1]
    forwarded_use_ssl = positional[5]
    assert forwarded_port == 5986
    assert forwarded_use_ssl is True
|
|
|
|
|
|
# ── Unit tests for helpers ─────────────────────────────────────────────────────
|
|
|
|
def test_truncate_passthrough() -> None:
    """A five-character string passed with a 1024 limit comes back unchanged."""
    untouched = _truncate("hello", 1024, "output")
    assert untouched == "hello"
|
|
|
|
|
|
def test_truncate_applies_limit() -> None:
    """A 10,000-character input with a 100 limit is cut and marked as truncated."""
    oversized = "x" * 10_000
    shortened = _truncate(oversized, 100, "output")

    assert "truncated" in shortened
    # The bound is loose on purpose: the result also carries a truncation notice.
    encoded_size = len(shortened.encode())
    assert encoded_size < 300
|
|
|
|
|
|
def test_format_result_no_errors() -> None:
    """With no error records the output is shown and the errors section is absent."""
    output_lines = ["proc1", "proc2"]
    formatted = _format_result("win01", "Get-Process", False, output_lines, [])

    for fragment in ("Had errors: False", "proc1"):
        assert fragment in formatted
    assert "--- errors ---" not in formatted
|
|
|
|
|
|
def test_format_result_with_errors() -> None:
    """Error records produce the errors section and the had-errors flag."""
    error_lines = ["Error: not found"]
    formatted = _format_result("win01", "bad_cmd", True, [], error_lines)

    for fragment in ("Had errors: True", "--- errors ---", "not found"):
        assert fragment in formatted
|
|
|
|
|
|
def test_format_result_empty_output() -> None:
    """Empty output and empty errors render the "(no output)" placeholder."""
    formatted = _format_result("win01", "Set-X", False, [], [])
    placeholder = "(no output)"
    assert placeholder in formatted
|