diff options
Diffstat (limited to 'nixos/lib/test-driver/test_driver')
-rwxr-xr-x | nixos/lib/test-driver/test_driver/__init__.py | 20 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/driver.py | 33 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/logger.py | 14 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/machine.py | 76 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/polling_condition.py | 8 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/qmp.py | 98 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/vlan.py | 2 |
7 files changed, 220 insertions, 31 deletions
diff --git a/nixos/lib/test-driver/test_driver/__init__.py b/nixos/lib/test-driver/test_driver/__init__.py index c90e3d9e1cd..9daae1e941a 100755 --- a/nixos/lib/test-driver/test_driver/__init__.py +++ b/nixos/lib/test-driver/test_driver/__init__.py @@ -1,11 +1,12 @@ -from pathlib import Path import argparse -import ptpython.repl import os import time +from pathlib import Path + +import ptpython.repl -from test_driver.logger import rootlog from test_driver.driver import Driver +from test_driver.logger import rootlog class EnvDefault(argparse.Action): @@ -25,9 +26,7 @@ class EnvDefault(argparse.Action): ) if required and default: required = False - super(EnvDefault, self).__init__( - default=default, required=required, nargs=nargs, **kwargs - ) + super().__init__(default=default, required=required, nargs=nargs, **kwargs) def __call__(self, parser, namespace, values, option_string=None): # type: ignore setattr(namespace, self.dest, values) @@ -78,6 +77,14 @@ def main() -> None: help="vlans to span by the driver", ) arg_parser.add_argument( + "--global-timeout", + type=int, + metavar="GLOBAL_TIMEOUT", + action=EnvDefault, + envvar="globalTimeout", + help="Timeout in seconds for the whole test", + ) + arg_parser.add_argument( "-o", "--output_directory", help="""The path to the directory where outputs copied from the VM will be placed. @@ -104,6 +111,7 @@ def main() -> None: args.testscript.read_text(), args.output_directory.resolve(), args.keep_vm_state, + args.global_timeout, ) as driver: if args.interactive: history_dir = os.getcwd() diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py index 835d60ec3b4..786821b0cc0 100644 --- a/nixos/lib/test-driver/test_driver/driver.py +++ b/nixos/lib/test-driver/test_driver/driver.py @@ -1,14 +1,16 @@ -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager import os import re +import signal import tempfile +import threading +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union from test_driver.logger import rootlog from test_driver.machine import Machine, NixStartScript, retry -from test_driver.vlan import VLan from test_driver.polling_condition import PollingCondition +from test_driver.vlan import VLan def get_tmp_dir() -> Path: @@ -41,6 +43,8 @@ class Driver: vlans: List[VLan] machines: List[Machine] polling_conditions: List[PollingCondition] + global_timeout: int + race_timer: threading.Timer def __init__( self, @@ -49,9 +53,12 @@ class Driver: tests: str, out_dir: Path, keep_vm_state: bool = False, + global_timeout: int = 24 * 60 * 60 * 7, ): self.tests = tests self.out_dir = out_dir + self.global_timeout = global_timeout + self.race_timer = threading.Timer(global_timeout, self.terminate_test) tmp_dir = get_tmp_dir() @@ -82,6 +89,7 @@ class Driver: def __exit__(self, *_: Any) -> None: with rootlog.nested("cleanup"): + self.race_timer.cancel() for machine in self.machines: machine.release() @@ -144,6 +152,10 @@ class Driver: def run_tests(self) -> None: """Run the test script (for non-interactive test runs)""" + rootlog.info( + f"Test will time out and terminate in {self.global_timeout} seconds" + ) + self.race_timer.start() self.test_script() # TODO: Collect coverage data for machine in self.machines: @@ -161,6 +173,19 @@ class Driver: with rootlog.nested("wait for all VMs to finish"): for machine in self.machines: machine.wait_for_shutdown() + self.race_timer.cancel() + + def terminate_test(self) -> None: + # This will be usually running in another thread than + # the thread actually executing the test script. + with rootlog.nested("timeout reached; test terminating..."): + for machine in self.machines: + machine.release() + # As we cannot `sys.exit` from another thread + # We can at least force the main thread to get SIGTERM'ed. + # This will prevent any user who caught all the exceptions + # to swallow them and prevent itself from terminating. + os.kill(os.getpid(), signal.SIGTERM) def create_machine(self, args: Dict[str, Any]) -> Machine: tmp_dir = get_tmp_dir() diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py index e6182ff7c76..116244b5e4a 100644 --- a/nixos/lib/test-driver/test_driver/logger.py +++ b/nixos/lib/test-driver/test_driver/logger.py @@ -1,13 +1,17 @@ -from colorama import Style, Fore -from contextlib import contextmanager -from typing import Any, Dict, Iterator -from queue import Queue, Empty -from xml.sax.saxutils import XMLGenerator +# mypy: disable-error-code="no-untyped-call" +# drop the above line when mypy is upgraded to include +# https://github.com/python/typeshed/commit/49b717ca52bf0781a538b04c0d76a5513f7119b8 import codecs import os import sys import time import unicodedata +from contextlib import contextmanager +from queue import Empty, Queue +from typing import Any, Dict, Iterator +from xml.sax.saxutils import XMLGenerator + +from colorama import Fore, Style class Logger: diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py index 2afcbc95c66..f430321bb60 100644 --- a/nixos/lib/test-driver/test_driver/machine.py +++ b/nixos/lib/test-driver/test_driver/machine.py @@ -1,7 +1,3 @@ -from contextlib import _GeneratorContextManager, nullcontext -from pathlib import Path -from queue import Queue -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import base64 import io import os @@ -16,9 +12,15 @@ import sys import tempfile import threading import time +from contextlib import _GeneratorContextManager, nullcontext +from pathlib import Path +from queue import Queue +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from test_driver.logger import rootlog +from .qmp import QMPSession + CHAR_TO_KEY = { "A": "shift-a", "N": "shift-n", @@ -144,6 +146,7 @@ class StartCommand: def cmd( self, monitor_socket_path: Path, + qmp_socket_path: Path, shell_socket_path: Path, allow_reboot: bool = False, ) -> str: @@ -167,6 +170,7 @@ class StartCommand: return ( f"{self._cmd}" + f" -qmp unix:{qmp_socket_path},server=on,wait=off" f" -monitor unix:{monitor_socket_path}" f" -chardev socket,id=shell,path={shell_socket_path}" f"{qemu_opts}" @@ -194,11 +198,14 @@ class StartCommand: state_dir: Path, shared_dir: Path, monitor_socket_path: Path, + qmp_socket_path: Path, shell_socket_path: Path, allow_reboot: bool, ) -> subprocess.Popen: return subprocess.Popen( - self.cmd(monitor_socket_path, shell_socket_path, allow_reboot), + self.cmd( + monitor_socket_path, qmp_socket_path, shell_socket_path, allow_reboot + ), stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, @@ -236,14 +243,14 @@ class LegacyStartCommand(StartCommand): def __init__( self, - netBackendArgs: Optional[str] = None, - netFrontendArgs: Optional[str] = None, + netBackendArgs: Optional[str] = None, # noqa: N803 + netFrontendArgs: Optional[str] = None, # noqa: N803 hda: Optional[Tuple[Path, str]] = None, cdrom: Optional[str] = None, usb: Optional[str] = None, bios: Optional[str] = None, - qemuBinary: Optional[str] = None, - qemuFlags: Optional[str] = None, + qemuBinary: Optional[str] = None, # noqa: N803 + qemuFlags: Optional[str] = None, # noqa: N803 ): if qemuBinary is not None: self._cmd = qemuBinary @@ -309,6 +316,7 @@ class Machine: shared_dir: Path state_dir: Path monitor_path: Path + qmp_path: Path shell_path: Path start_command: StartCommand @@ -317,6 +325,7 @@ class Machine: process: Optional[subprocess.Popen] pid: Optional[int] monitor: Optional[socket.socket] + qmp_client: Optional[QMPSession] shell: Optional[socket.socket] serial_thread: Optional[threading.Thread] @@ -352,6 +361,7 @@ class Machine: self.state_dir = self.tmp_dir / f"vm-state-{self.name}" self.monitor_path = self.state_dir / "monitor" + self.qmp_path = self.state_dir / "qmp" self.shell_path = self.state_dir / "shell" if (not self.keep_vm_state) and self.state_dir.exists(): self.cleanup_statedir() @@ -360,6 +370,7 @@ class Machine: self.process = None self.pid = None self.monitor = None + self.qmp_client = None self.shell = None self.serial_thread = None @@ -599,7 +610,7 @@ class Machine: return (-1, output.decode()) # Get the return code - self.shell.send("echo ${PIPESTATUS[0]}\n".encode()) + self.shell.send(b"echo ${PIPESTATUS[0]}\n") rc = int(self._next_newline_closed_block_from_shell().strip()) return (rc, output.decode(errors="replace")) @@ -791,6 +802,28 @@ class Machine: with self.nested(f"waiting for TCP port {port} on {addr}"): retry(port_is_open, timeout) + def wait_for_open_unix_socket( + self, addr: str, is_datagram: bool = False, timeout: int = 900 + ) -> None: + """ + Wait until a process is listening on the given UNIX-domain socket + (default to a UNIX-domain stream socket). + """ + + nc_flags = [ + "-z", + "-uU" if is_datagram else "-U", + ] + + def socket_is_open(_: Any) -> bool: + status, _ = self.execute(f"nc {' '.join(nc_flags)} {addr}") + return status == 0 + + with self.nested( + f"waiting for UNIX-domain {'datagram' if is_datagram else 'stream'} on '{addr}'" + ): + retry(socket_is_open, timeout) + def wait_for_closed_port( self, port: int, addr: str = "localhost", timeout: int = 900 ) -> None: @@ -843,6 +876,9 @@ class Machine: while True: chunk = self.shell.recv(1024) + # No need to print empty strings, it means we are waiting. + if len(chunk) == 0: + continue self.log(f"Guest shell says: {chunk!r}") # NOTE: for this to work, nothing must be printed after this line! if b"Spawning backdoor root shell..." in chunk: @@ -1087,11 +1123,13 @@ class Machine: self.state_dir, self.shared_dir, self.monitor_path, + self.qmp_path, self.shell_path, allow_reboot, ) self.monitor, _ = monitor_socket.accept() self.shell, _ = shell_socket.accept() + self.qmp_client = QMPSession.from_path(self.qmp_path) # Store last serial console lines for use # of wait_for_console_text @@ -1129,7 +1167,7 @@ class Machine: return assert self.shell - self.shell.send("poweroff\n".encode()) + self.shell.send(b"poweroff\n") self.wait_for_shutdown() def crash(self) -> None: @@ -1240,3 +1278,19 @@ class Machine: def run_callbacks(self) -> None: for callback in self.callbacks: callback() + + def switch_root(self) -> None: + """ + Transition from stage 1 to stage 2. This requires the + machine to be configured with `testing.initrdBackdoor = true` + and `boot.initrd.systemd.enable = true`. + """ + self.wait_for_unit("initrd.target") + self.execute( + "systemctl isolate --no-block initrd-switch-root.target 2>/dev/null >/dev/null", + check_return=False, + check_output=False, + ) + self.wait_for_console_text(r"systemd\[1\]:.*Switching root\.") + self.connected = False + self.connect() diff --git a/nixos/lib/test-driver/test_driver/polling_condition.py b/nixos/lib/test-driver/test_driver/polling_condition.py index 02ca0a03ab3..12cbad69e34 100644 --- a/nixos/lib/test-driver/test_driver/polling_condition.py +++ b/nixos/lib/test-driver/test_driver/polling_condition.py @@ -1,11 +1,11 @@ -from typing import Callable, Optional -from math import isfinite import time +from math import isfinite +from typing import Callable, Optional from .logger import rootlog -class PollingConditionFailed(Exception): +class PollingConditionError(Exception): pass @@ -60,7 +60,7 @@ class PollingCondition: def maybe_raise(self) -> None: if not self.check(): - raise PollingConditionFailed(self.status_message(False)) + raise PollingConditionError(self.status_message(False)) def status_message(self, status: bool) -> str: return f"Polling condition {'succeeded' if status else 'failed'}: {self.description}" diff --git a/nixos/lib/test-driver/test_driver/qmp.py b/nixos/lib/test-driver/test_driver/qmp.py new file mode 100644 index 00000000000..62ca6d7d5b8 --- /dev/null +++ b/nixos/lib/test-driver/test_driver/qmp.py @@ -0,0 +1,98 @@ +import json +import logging +import os +import socket +from collections.abc import Iterator +from pathlib import Path +from queue import Queue +from typing import Any + +logger = logging.getLogger(__name__) + + +class QMPAPIError(RuntimeError): + def __init__(self, message: dict[str, Any]): + assert "error" in message, "Not an error message!" + try: + self.class_name = message["class"] + self.description = message["desc"] + # NOTE: Some errors can occur before the Server is able to read the + # id member; in these cases the id member will not be part of the + # error response, even if provided by the client. + self.transaction_id = message.get("id") + except KeyError: + raise RuntimeError("Malformed QMP API error response") + + def __str__(self) -> str: + return f"<QMP API error related to transaction {self.transaction_id} [{self.class_name}]: {self.description}>" + + +class QMPSession: + def __init__(self, sock: socket.socket) -> None: + self.sock = sock + self.results: Queue[dict[str, str]] = Queue() + self.pending_events: Queue[dict[str, Any]] = Queue() + self.reader = sock.makefile("r") + self.writer = sock.makefile("w") + # Make the reader non-blocking so we can kind of select on it. + os.set_blocking(self.reader.fileno(), False) + hello = self._wait_for_new_result() + logger.debug(f"Got greeting from QMP API: {hello}") + # The greeting message format is: + # { "QMP": { "version": json-object, "capabilities": json-array } } + assert "QMP" in hello, f"Unexpected result: {hello}" + self.send("qmp_capabilities") + + @classmethod + def from_path(cls, path: Path) -> "QMPSession": + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(str(path)) + return cls(sock) + + def __del__(self) -> None: + self.sock.close() + + def _wait_for_new_result(self) -> dict[str, str]: + assert self.results.empty(), "Results set is not empty, missed results!" + while self.results.empty(): + self.read_pending_messages() + return self.results.get() + + def read_pending_messages(self) -> None: + line = self.reader.readline() + if not line: + return + evt_or_result = json.loads(line) + logger.debug(f"Received a message: {evt_or_result}") + + # It's a result + if "return" in evt_or_result or "QMP" in evt_or_result: + self.results.put(evt_or_result) + # It's an event + elif "event" in evt_or_result: + self.pending_events.put(evt_or_result) + else: + raise QMPAPIError(evt_or_result) + + def wait_for_event(self, timeout: int = 10) -> dict[str, Any]: + while self.pending_events.empty(): + self.read_pending_messages() + + return self.pending_events.get(timeout=timeout) + + def events(self, timeout: int = 10) -> Iterator[dict[str, Any]]: + while not self.pending_events.empty(): + yield self.pending_events.get(timeout=timeout) + + def send(self, cmd: str, args: dict[str, str] = {}) -> dict[str, str]: + self.read_pending_messages() + assert self.results.empty(), "Results set is not empty, missed results!" + data: dict[str, Any] = dict(execute=cmd) + if args != {}: + data["arguments"] = args + + logger.debug(f"Sending {data} to QMP...") + json.dump(data, self.writer) + self.writer.write("\n") + self.writer.flush() + return self._wait_for_new_result() diff --git a/nixos/lib/test-driver/test_driver/vlan.py b/nixos/lib/test-driver/test_driver/vlan.py index f2a7f250d1d..ec9679108e5 100644 --- a/nixos/lib/test-driver/test_driver/vlan.py +++ b/nixos/lib/test-driver/test_driver/vlan.py @@ -1,8 +1,8 @@ -from pathlib import Path import io import os import pty import subprocess +from pathlib import Path from test_driver.logger import rootlog |