From 4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2 Mon Sep 17 00:00:00 2001 From: Patrick Hilhorst Date: Sat, 1 Jan 2022 22:35:20 +0100 Subject: nixos/test-driver: add polling_condition --- nixos/lib/test-driver/test_driver/driver.py | 40 ++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) (limited to 'nixos/lib/test-driver/test_driver/driver.py') diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py index f3af98537ad..e22f9ee7a75 100644 --- a/nixos/lib/test-driver/test_driver/driver.py +++ b/nixos/lib/test-driver/test_driver/driver.py @@ -1,12 +1,13 @@ from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, Iterator, List +from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager import os import tempfile from test_driver.logger import rootlog from test_driver.machine import Machine, NixStartScript, retry from test_driver.vlan import VLan +from test_driver.polling_condition import PollingCondition class Driver: @@ -16,6 +17,7 @@ class Driver: tests: str vlans: List[VLan] machines: List[Machine] + polling_conditions: List[Callable] def __init__( self, @@ -36,12 +38,15 @@ class Driver: for s in scripts: yield NixStartScript(s) + self.polling_conditions = [] + self.machines = [ Machine( start_command=cmd, keep_vm_state=keep_vm_state, name=cmd.machine_name, tmp_dir=tmp_dir, + fail_early=self.fail_early, ) for cmd in cmd(start_scripts) ] @@ -84,6 +89,7 @@ class Driver: retry=retry, serial_stdout_off=self.serial_stdout_off, serial_stdout_on=self.serial_stdout_on, + polling_condition=self.polling_condition, Machine=Machine, # for typing ) machine_symbols = {m.name: m for m in self.machines} @@ -159,3 +165,35 @@ class Driver: def serial_stdout_off(self) -> None: rootlog._print_serial_logs = False + + def fail_early(self) -> bool: + return any(not f() for f in self.polling_conditions) + + def polling_condition( + self, + fun_: Optional[Callable] = None, + *, + seconds_interval: float = 2.0, + description: Optional[str] = None, + ) -> Union[Callable[[Callable], ContextManager], ContextManager]: + driver = self + + class Poll: + def __init__(self, fun: Callable): + self.condition = PollingCondition( + fun, + seconds_interval, + description, + ).check + + def __enter__(self) -> None: + driver.polling_conditions.append(self.condition) + + def __exit__(self, a, b, c) -> None: # type: ignore + res = driver.polling_conditions.pop() + assert res is self.condition + + if fun_ is None: + return Poll + else: + return Poll(fun_) -- cgit 1.4.1