diff options
Diffstat (limited to 'nixos/lib')
-rw-r--r-- | nixos/lib/test-driver/default.nix | 4 | ||||
-rw-r--r-- | nixos/lib/test-driver/setup.py | 2 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/driver.py | 41 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/machine.py | 9 | ||||
-rw-r--r-- | nixos/lib/test-driver/test_driver/polling_condition.py | 77 |
5 files changed, 129 insertions, 4 deletions
diff --git a/nixos/lib/test-driver/default.nix b/nixos/lib/test-driver/default.nix index 3f63bc705b9..3aee9134318 100644 --- a/nixos/lib/test-driver/default.nix +++ b/nixos/lib/test-driver/default.nix @@ -14,7 +14,7 @@ python3Packages.buildPythonApplication rec { pname = "nixos-test-driver"; - version = "1.0"; + version = "1.1"; src = ./.; propagatedBuildInputs = [ coreutils netpbm python3Packages.colorama python3Packages.ptpython qemu_pkg socat vde2 ] @@ -26,7 +26,7 @@ python3Packages.buildPythonApplication rec { mypy --disallow-untyped-defs \ --no-implicit-optional \ --ignore-missing-imports ${src}/test_driver - pylint --errors-only ${src}/test_driver + pylint --errors-only --enable=unused-import ${src}/test_driver black --check --diff ${src}/test_driver ''; } diff --git a/nixos/lib/test-driver/setup.py b/nixos/lib/test-driver/setup.py index 15699547216..476c7b2dab2 100644 --- a/nixos/lib/test-driver/setup.py +++ b/nixos/lib/test-driver/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="nixos-test-driver", - version='1.0', + version='1.1', packages=find_packages(), entry_points={ "console_scripts": [ diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py index f3af98537ad..49a42fe5fb4 100644 --- a/nixos/lib/test-driver/test_driver/driver.py +++ b/nixos/lib/test-driver/test_driver/driver.py @@ -1,12 +1,13 @@ from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, Iterator, List +from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager import os import tempfile from test_driver.logger import rootlog from test_driver.machine import Machine, NixStartScript, retry from test_driver.vlan import VLan +from test_driver.polling_condition import PollingCondition class Driver: @@ -16,6 +17,7 @@ class Driver: tests: str vlans: List[VLan] machines: List[Machine] + polling_conditions: List[PollingCondition] def __init__( self, @@ -36,12 +38,15 @@ class Driver: for s in scripts: yield NixStartScript(s) + self.polling_conditions = [] + self.machines = [ Machine( start_command=cmd, keep_vm_state=keep_vm_state, name=cmd.machine_name, tmp_dir=tmp_dir, + callbacks=[self.check_polling_conditions], ) for cmd in cmd(start_scripts) ] @@ -84,6 +89,7 @@ class Driver: retry=retry, serial_stdout_off=self.serial_stdout_off, serial_stdout_on=self.serial_stdout_on, + polling_condition=self.polling_condition, Machine=Machine, # for typing ) machine_symbols = {m.name: m for m in self.machines} @@ -159,3 +165,36 @@ class Driver: def serial_stdout_off(self) -> None: rootlog._print_serial_logs = False + + def check_polling_conditions(self) -> None: + for condition in self.polling_conditions: + condition.maybe_raise() + + def polling_condition( + self, + fun_: Optional[Callable] = None, + *, + seconds_interval: float = 2.0, + description: Optional[str] = None, + ) -> Union[Callable[[Callable], ContextManager], ContextManager]: + driver = self + + class Poll: + def __init__(self, fun: Callable): + self.condition = PollingCondition( + fun, + seconds_interval, + description, + ) + + def __enter__(self) -> None: + driver.polling_conditions.append(self.condition) + + def __exit__(self, a, b, c) -> None: # type: ignore + res = driver.polling_conditions.pop() + assert res is self.condition + + if fun_ is None: + return Poll + else: + return Poll(fun_) diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py index b3dbe5126fc..e050cbd7d99 100644 --- a/nixos/lib/test-driver/test_driver/machine.py +++ b/nixos/lib/test-driver/test_driver/machine.py @@ -318,6 +318,7 @@ class Machine: # Store last serial console lines for use # of wait_for_console_text last_lines: Queue = Queue() + callbacks: List[Callable] def __repr__(self) -> str: return f"<Machine '{self.name}'>" @@ -329,12 +330,14 @@ class Machine: name: str = "machine", keep_vm_state: bool = False, allow_reboot: bool = False, + callbacks: Optional[List[Callable]] = None, ) -> None: self.tmp_dir = tmp_dir self.keep_vm_state = keep_vm_state self.allow_reboot = allow_reboot self.name = name self.start_command = start_command + self.callbacks = callbacks if callbacks is not None else [] # set up directories self.shared_dir = self.tmp_dir / "shared-xchg" @@ -406,6 +409,7 @@ class Machine: return answer def send_monitor_command(self, command: str) -> str: + self.run_callbacks() with self.nested("sending monitor command: {}".format(command)): message = ("{}\n".format(command)).encode() assert self.monitor is not None @@ -509,6 +513,7 @@ class Machine: def execute( self, command: str, check_return: bool = True, timeout: Optional[int] = 900 ) -> Tuple[int, str]: + self.run_callbacks() self.connect() if timeout is not None: @@ -969,3 +974,7 @@ class Machine: self.shell.close() self.monitor.close() self.serial_thread.join() + + def run_callbacks(self) -> None: + for callback in self.callbacks: + callback() diff --git a/nixos/lib/test-driver/test_driver/polling_condition.py b/nixos/lib/test-driver/test_driver/polling_condition.py new file mode 100644 index 00000000000..459845452fa --- /dev/null +++ b/nixos/lib/test-driver/test_driver/polling_condition.py @@ -0,0 +1,77 @@ +from typing import Callable, Optional +import time + +from .logger import rootlog + + +class PollingConditionFailed(Exception): + pass + + +class PollingCondition: + condition: Callable[[], bool] + seconds_interval: float + description: Optional[str] + + last_called: float + entered: bool + + def __init__( + self, + condition: Callable[[], Optional[bool]], + seconds_interval: float = 2.0, + description: Optional[str] = None, + ): + self.condition = condition # type: ignore + self.seconds_interval = seconds_interval + + if description is None: + if condition.__doc__: + self.description = condition.__doc__ + else: + self.description = condition.__name__ + else: + self.description = str(description) + + self.last_called = float("-inf") + self.entered = False + + def check(self) -> bool: + if self.entered or not self.overdue: + return True + + with self, rootlog.nested(self.nested_message): + rootlog.info(f"Time since last: {time.monotonic() - self.last_called:.2f}s") + try: + res = self.condition() # type: ignore + except Exception: + res = False + res = res is None or res + rootlog.info(self.status_message(res)) + return res + + def maybe_raise(self) -> None: + if not self.check(): + raise PollingConditionFailed(self.status_message(False)) + + def status_message(self, status: bool) -> str: + return f"Polling condition {'succeeded' if status else 'failed'}: {self.description}" + + @property + def nested_message(self) -> str: + nested_message = ["Checking polling condition"] + if self.description is not None: + nested_message.append(repr(self.description)) + + return " ".join(nested_message) + + @property + def overdue(self) -> bool: + return self.last_called + self.seconds_interval < time.monotonic() + + def __enter__(self) -> None: + self.entered = True + + def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore + self.entered = False + self.last_called = time.monotonic() |