summary refs log tree commit diff
path: root/nixos/lib/test-driver/test_driver
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib/test-driver/test_driver')
-rwxr-xr-xnixos/lib/test-driver/test_driver/__init__.py20
-rw-r--r--nixos/lib/test-driver/test_driver/driver.py33
-rw-r--r--nixos/lib/test-driver/test_driver/logger.py14
-rw-r--r--nixos/lib/test-driver/test_driver/machine.py76
-rw-r--r--nixos/lib/test-driver/test_driver/polling_condition.py8
-rw-r--r--nixos/lib/test-driver/test_driver/qmp.py98
-rw-r--r--nixos/lib/test-driver/test_driver/vlan.py2
7 files changed, 220 insertions, 31 deletions
diff --git a/nixos/lib/test-driver/test_driver/__init__.py b/nixos/lib/test-driver/test_driver/__init__.py
index c90e3d9e1cd..9daae1e941a 100755
--- a/nixos/lib/test-driver/test_driver/__init__.py
+++ b/nixos/lib/test-driver/test_driver/__init__.py
@@ -1,11 +1,12 @@
-from pathlib import Path
 import argparse
-import ptpython.repl
 import os
 import time
+from pathlib import Path
+
+import ptpython.repl
 
-from test_driver.logger import rootlog
 from test_driver.driver import Driver
+from test_driver.logger import rootlog
 
 
 class EnvDefault(argparse.Action):
@@ -25,9 +26,7 @@ class EnvDefault(argparse.Action):
                 )
         if required and default:
             required = False
-        super(EnvDefault, self).__init__(
-            default=default, required=required, nargs=nargs, **kwargs
-        )
+        super().__init__(default=default, required=required, nargs=nargs, **kwargs)
 
     def __call__(self, parser, namespace, values, option_string=None):  # type: ignore
         setattr(namespace, self.dest, values)
@@ -78,6 +77,14 @@ def main() -> None:
         help="vlans to span by the driver",
     )
     arg_parser.add_argument(
+        "--global-timeout",
+        type=int,
+        metavar="GLOBAL_TIMEOUT",
+        action=EnvDefault,
+        envvar="globalTimeout",
+        help="Timeout in seconds for the whole test",
+    )
+    arg_parser.add_argument(
         "-o",
         "--output_directory",
         help="""The path to the directory where outputs copied from the VM will be placed.
@@ -104,6 +111,7 @@ def main() -> None:
         args.testscript.read_text(),
         args.output_directory.resolve(),
         args.keep_vm_state,
+        args.global_timeout,
     ) as driver:
         if args.interactive:
             history_dir = os.getcwd()
diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py
index 835d60ec3b4..786821b0cc0 100644
--- a/nixos/lib/test-driver/test_driver/driver.py
+++ b/nixos/lib/test-driver/test_driver/driver.py
@@ -1,14 +1,16 @@
-from contextlib import contextmanager
-from pathlib import Path
-from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
 import os
 import re
+import signal
 import tempfile
+import threading
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union
 
 from test_driver.logger import rootlog
 from test_driver.machine import Machine, NixStartScript, retry
-from test_driver.vlan import VLan
 from test_driver.polling_condition import PollingCondition
+from test_driver.vlan import VLan
 
 
 def get_tmp_dir() -> Path:
@@ -41,6 +43,8 @@ class Driver:
     vlans: List[VLan]
     machines: List[Machine]
     polling_conditions: List[PollingCondition]
+    global_timeout: int
+    race_timer: threading.Timer
 
     def __init__(
         self,
@@ -49,9 +53,12 @@ class Driver:
         tests: str,
         out_dir: Path,
         keep_vm_state: bool = False,
+        global_timeout: int = 24 * 60 * 60 * 7,
     ):
         self.tests = tests
         self.out_dir = out_dir
+        self.global_timeout = global_timeout
+        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
 
         tmp_dir = get_tmp_dir()
 
@@ -82,6 +89,7 @@ class Driver:
 
     def __exit__(self, *_: Any) -> None:
         with rootlog.nested("cleanup"):
+            self.race_timer.cancel()
             for machine in self.machines:
                 machine.release()
 
@@ -144,6 +152,10 @@ class Driver:
 
     def run_tests(self) -> None:
         """Run the test script (for non-interactive test runs)"""
+        rootlog.info(
+            f"Test will time out and terminate in {self.global_timeout} seconds"
+        )
+        self.race_timer.start()
         self.test_script()
         # TODO: Collect coverage data
         for machine in self.machines:
@@ -161,6 +173,19 @@ class Driver:
         with rootlog.nested("wait for all VMs to finish"):
             for machine in self.machines:
                 machine.wait_for_shutdown()
+            self.race_timer.cancel()
+
+    def terminate_test(self) -> None:
+        # This will be usually running in another thread than
+        # the thread actually executing the test script.
+        with rootlog.nested("timeout reached; test terminating..."):
+            for machine in self.machines:
+                machine.release()
+            # As we cannot `sys.exit` from another thread
+            # We can at least force the main thread to get SIGTERM'ed.
+            # This will prevent any user who caught all the exceptions
+            # to swallow them and prevent itself from terminating.
+            os.kill(os.getpid(), signal.SIGTERM)
 
     def create_machine(self, args: Dict[str, Any]) -> Machine:
         tmp_dir = get_tmp_dir()
diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py
index e6182ff7c76..116244b5e4a 100644
--- a/nixos/lib/test-driver/test_driver/logger.py
+++ b/nixos/lib/test-driver/test_driver/logger.py
@@ -1,13 +1,17 @@
-from colorama import Style, Fore
-from contextlib import contextmanager
-from typing import Any, Dict, Iterator
-from queue import Queue, Empty
-from xml.sax.saxutils import XMLGenerator
+# mypy: disable-error-code="no-untyped-call"
+# drop the above line when mypy is upgraded to include
+# https://github.com/python/typeshed/commit/49b717ca52bf0781a538b04c0d76a5513f7119b8
 import codecs
 import os
 import sys
 import time
 import unicodedata
+from contextlib import contextmanager
+from queue import Empty, Queue
+from typing import Any, Dict, Iterator
+from xml.sax.saxutils import XMLGenerator
+
+from colorama import Fore, Style
 
 
 class Logger:
diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py
index 2afcbc95c66..f430321bb60 100644
--- a/nixos/lib/test-driver/test_driver/machine.py
+++ b/nixos/lib/test-driver/test_driver/machine.py
@@ -1,7 +1,3 @@
-from contextlib import _GeneratorContextManager, nullcontext
-from pathlib import Path
-from queue import Queue
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 import base64
 import io
 import os
@@ -16,9 +12,15 @@ import sys
 import tempfile
 import threading
 import time
+from contextlib import _GeneratorContextManager, nullcontext
+from pathlib import Path
+from queue import Queue
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 from test_driver.logger import rootlog
 
+from .qmp import QMPSession
+
 CHAR_TO_KEY = {
     "A": "shift-a",
     "N": "shift-n",
@@ -144,6 +146,7 @@ class StartCommand:
     def cmd(
         self,
         monitor_socket_path: Path,
+        qmp_socket_path: Path,
         shell_socket_path: Path,
         allow_reboot: bool = False,
     ) -> str:
@@ -167,6 +170,7 @@ class StartCommand:
 
         return (
             f"{self._cmd}"
+            f" -qmp unix:{qmp_socket_path},server=on,wait=off"
             f" -monitor unix:{monitor_socket_path}"
             f" -chardev socket,id=shell,path={shell_socket_path}"
             f"{qemu_opts}"
@@ -194,11 +198,14 @@ class StartCommand:
         state_dir: Path,
         shared_dir: Path,
         monitor_socket_path: Path,
+        qmp_socket_path: Path,
         shell_socket_path: Path,
         allow_reboot: bool,
     ) -> subprocess.Popen:
         return subprocess.Popen(
-            self.cmd(monitor_socket_path, shell_socket_path, allow_reboot),
+            self.cmd(
+                monitor_socket_path, qmp_socket_path, shell_socket_path, allow_reboot
+            ),
             stdin=subprocess.PIPE,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT,
@@ -236,14 +243,14 @@ class LegacyStartCommand(StartCommand):
 
     def __init__(
         self,
-        netBackendArgs: Optional[str] = None,
-        netFrontendArgs: Optional[str] = None,
+        netBackendArgs: Optional[str] = None,  # noqa: N803
+        netFrontendArgs: Optional[str] = None,  # noqa: N803
         hda: Optional[Tuple[Path, str]] = None,
         cdrom: Optional[str] = None,
         usb: Optional[str] = None,
         bios: Optional[str] = None,
-        qemuBinary: Optional[str] = None,
-        qemuFlags: Optional[str] = None,
+        qemuBinary: Optional[str] = None,  # noqa: N803
+        qemuFlags: Optional[str] = None,  # noqa: N803
     ):
         if qemuBinary is not None:
             self._cmd = qemuBinary
@@ -309,6 +316,7 @@ class Machine:
     shared_dir: Path
     state_dir: Path
     monitor_path: Path
+    qmp_path: Path
     shell_path: Path
 
     start_command: StartCommand
@@ -317,6 +325,7 @@ class Machine:
     process: Optional[subprocess.Popen]
     pid: Optional[int]
     monitor: Optional[socket.socket]
+    qmp_client: Optional[QMPSession]
     shell: Optional[socket.socket]
     serial_thread: Optional[threading.Thread]
 
@@ -352,6 +361,7 @@ class Machine:
 
         self.state_dir = self.tmp_dir / f"vm-state-{self.name}"
         self.monitor_path = self.state_dir / "monitor"
+        self.qmp_path = self.state_dir / "qmp"
         self.shell_path = self.state_dir / "shell"
         if (not self.keep_vm_state) and self.state_dir.exists():
             self.cleanup_statedir()
@@ -360,6 +370,7 @@ class Machine:
         self.process = None
         self.pid = None
         self.monitor = None
+        self.qmp_client = None
         self.shell = None
         self.serial_thread = None
 
@@ -599,7 +610,7 @@ class Machine:
             return (-1, output.decode())
 
         # Get the return code
-        self.shell.send("echo ${PIPESTATUS[0]}\n".encode())
+        self.shell.send(b"echo ${PIPESTATUS[0]}\n")
         rc = int(self._next_newline_closed_block_from_shell().strip())
 
         return (rc, output.decode(errors="replace"))
@@ -791,6 +802,28 @@ class Machine:
         with self.nested(f"waiting for TCP port {port} on {addr}"):
             retry(port_is_open, timeout)
 
+    def wait_for_open_unix_socket(
+        self, addr: str, is_datagram: bool = False, timeout: int = 900
+    ) -> None:
+        """
+        Wait until a process is listening on the given UNIX-domain socket
+        (default to a UNIX-domain stream socket).
+        """
+
+        nc_flags = [
+            "-z",
+            "-uU" if is_datagram else "-U",
+        ]
+
+        def socket_is_open(_: Any) -> bool:
+            status, _ = self.execute(f"nc {' '.join(nc_flags)} {addr}")
+            return status == 0
+
+        with self.nested(
+            f"waiting for UNIX-domain {'datagram' if is_datagram else 'stream'} on '{addr}'"
+        ):
+            retry(socket_is_open, timeout)
+
     def wait_for_closed_port(
         self, port: int, addr: str = "localhost", timeout: int = 900
     ) -> None:
@@ -843,6 +876,9 @@ class Machine:
 
             while True:
                 chunk = self.shell.recv(1024)
+                # No need to print empty strings, it means we are waiting.
+                if len(chunk) == 0:
+                    continue
                 self.log(f"Guest shell says: {chunk!r}")
                 # NOTE: for this to work, nothing must be printed after this line!
                 if b"Spawning backdoor root shell..." in chunk:
@@ -1087,11 +1123,13 @@ class Machine:
             self.state_dir,
             self.shared_dir,
             self.monitor_path,
+            self.qmp_path,
             self.shell_path,
             allow_reboot,
         )
         self.monitor, _ = monitor_socket.accept()
         self.shell, _ = shell_socket.accept()
+        self.qmp_client = QMPSession.from_path(self.qmp_path)
 
         # Store last serial console lines for use
         # of wait_for_console_text
@@ -1129,7 +1167,7 @@ class Machine:
             return
 
         assert self.shell
-        self.shell.send("poweroff\n".encode())
+        self.shell.send(b"poweroff\n")
         self.wait_for_shutdown()
 
     def crash(self) -> None:
@@ -1240,3 +1278,19 @@ class Machine:
     def run_callbacks(self) -> None:
         for callback in self.callbacks:
             callback()
+
+    def switch_root(self) -> None:
+        """
+        Transition from stage 1 to stage 2. This requires the
+        machine to be configured with `testing.initrdBackdoor = true`
+        and `boot.initrd.systemd.enable = true`.
+        """
+        self.wait_for_unit("initrd.target")
+        self.execute(
+            "systemctl isolate --no-block initrd-switch-root.target 2>/dev/null >/dev/null",
+            check_return=False,
+            check_output=False,
+        )
+        self.wait_for_console_text(r"systemd\[1\]:.*Switching root\.")
+        self.connected = False
+        self.connect()
diff --git a/nixos/lib/test-driver/test_driver/polling_condition.py b/nixos/lib/test-driver/test_driver/polling_condition.py
index 02ca0a03ab3..12cbad69e34 100644
--- a/nixos/lib/test-driver/test_driver/polling_condition.py
+++ b/nixos/lib/test-driver/test_driver/polling_condition.py
@@ -1,11 +1,11 @@
-from typing import Callable, Optional
-from math import isfinite
 import time
+from math import isfinite
+from typing import Callable, Optional
 
 from .logger import rootlog
 
 
-class PollingConditionFailed(Exception):
+class PollingConditionError(Exception):
     pass
 
 
@@ -60,7 +60,7 @@ class PollingCondition:
 
     def maybe_raise(self) -> None:
         if not self.check():
-            raise PollingConditionFailed(self.status_message(False))
+            raise PollingConditionError(self.status_message(False))
 
     def status_message(self, status: bool) -> str:
         return f"Polling condition {'succeeded' if status else 'failed'}: {self.description}"
diff --git a/nixos/lib/test-driver/test_driver/qmp.py b/nixos/lib/test-driver/test_driver/qmp.py
new file mode 100644
index 00000000000..62ca6d7d5b8
--- /dev/null
+++ b/nixos/lib/test-driver/test_driver/qmp.py
@@ -0,0 +1,98 @@
+import json
+import logging
+import os
+import socket
+from collections.abc import Iterator
+from pathlib import Path
+from queue import Queue
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+class QMPAPIError(RuntimeError):
+    def __init__(self, message: dict[str, Any]):
+        assert "error" in message, "Not an error message!"
+        try:
+            self.class_name = message["class"]
+            self.description = message["desc"]
+            # NOTE: Some errors can occur before the Server is able to read the
+            # id member; in these cases the id member will not be part of the
+            # error response, even if provided by the client.
+            self.transaction_id = message.get("id")
+        except KeyError:
+            raise RuntimeError("Malformed QMP API error response")
+
+    def __str__(self) -> str:
+        return f"<QMP API error related to transaction {self.transaction_id} [{self.class_name}]: {self.description}>"
+
+
+class QMPSession:
+    def __init__(self, sock: socket.socket) -> None:
+        self.sock = sock
+        self.results: Queue[dict[str, str]] = Queue()
+        self.pending_events: Queue[dict[str, Any]] = Queue()
+        self.reader = sock.makefile("r")
+        self.writer = sock.makefile("w")
+        # Make the reader non-blocking so we can kind of select on it.
+        os.set_blocking(self.reader.fileno(), False)
+        hello = self._wait_for_new_result()
+        logger.debug(f"Got greeting from QMP API: {hello}")
+        # The greeting message format is:
+        # { "QMP": { "version": json-object, "capabilities": json-array } }
+        assert "QMP" in hello, f"Unexpected result: {hello}"
+        self.send("qmp_capabilities")
+
+    @classmethod
+    def from_path(cls, path: Path) -> "QMPSession":
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        sock.connect(str(path))
+        return cls(sock)
+
+    def __del__(self) -> None:
+        self.sock.close()
+
+    def _wait_for_new_result(self) -> dict[str, str]:
+        assert self.results.empty(), "Results set is not empty, missed results!"
+        while self.results.empty():
+            self.read_pending_messages()
+        return self.results.get()
+
+    def read_pending_messages(self) -> None:
+        line = self.reader.readline()
+        if not line:
+            return
+        evt_or_result = json.loads(line)
+        logger.debug(f"Received a message: {evt_or_result}")
+
+        # It's a result
+        if "return" in evt_or_result or "QMP" in evt_or_result:
+            self.results.put(evt_or_result)
+        # It's an event
+        elif "event" in evt_or_result:
+            self.pending_events.put(evt_or_result)
+        else:
+            raise QMPAPIError(evt_or_result)
+
+    def wait_for_event(self, timeout: int = 10) -> dict[str, Any]:
+        while self.pending_events.empty():
+            self.read_pending_messages()
+
+        return self.pending_events.get(timeout=timeout)
+
+    def events(self, timeout: int = 10) -> Iterator[dict[str, Any]]:
+        while not self.pending_events.empty():
+            yield self.pending_events.get(timeout=timeout)
+
+    def send(self, cmd: str, args: dict[str, str] = {}) -> dict[str, str]:
+        self.read_pending_messages()
+        assert self.results.empty(), "Results set is not empty, missed results!"
+        data: dict[str, Any] = dict(execute=cmd)
+        if args != {}:
+            data["arguments"] = args
+
+        logger.debug(f"Sending {data} to QMP...")
+        json.dump(data, self.writer)
+        self.writer.write("\n")
+        self.writer.flush()
+        return self._wait_for_new_result()
diff --git a/nixos/lib/test-driver/test_driver/vlan.py b/nixos/lib/test-driver/test_driver/vlan.py
index f2a7f250d1d..ec9679108e5 100644
--- a/nixos/lib/test-driver/test_driver/vlan.py
+++ b/nixos/lib/test-driver/test_driver/vlan.py
@@ -1,8 +1,8 @@
-from pathlib import Path
 import io
 import os
 import pty
 import subprocess
+from pathlib import Path
 
 from test_driver.logger import rootlog