summary refs log tree commit diff
path: root/nixos/lib/test-driver/test_driver/driver.py
diff options
context:
space:
mode:
authorPatrick Hilhorst <git@hilhorst.be>2022-01-01 22:35:20 +0100
committerPatrick Hilhorst <git@hilhorst.be>2022-01-01 23:17:32 +0100
commit4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2 (patch)
tree30e14d078a28988a7234c66b41fb93e3bd0f1b6b /nixos/lib/test-driver/test_driver/driver.py
parent69856d9ba78905337407136f48012c23962871e7 (diff)
downloadnixpkgs-4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2.tar
nixpkgs-4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2.tar.gz
nixpkgs-4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2.tar.bz2
nixpkgs-4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2.tar.lz
nixpkgs-4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2.tar.xz
nixpkgs-4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2.tar.zst
nixpkgs-4e1556ed4d43da1f930b3fcf0fc20d827a34f3d2.zip
nixos/test-driver: add polling_condition
Diffstat (limited to 'nixos/lib/test-driver/test_driver/driver.py')
-rw-r--r--nixos/lib/test-driver/test_driver/driver.py40
1 files changed, 39 insertions, 1 deletions
diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py
index f3af98537ad..e22f9ee7a75 100644
--- a/nixos/lib/test-driver/test_driver/driver.py
+++ b/nixos/lib/test-driver/test_driver/driver.py
@@ -1,12 +1,13 @@
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, Iterator, List
+from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
 import os
 import tempfile
 
 from test_driver.logger import rootlog
 from test_driver.machine import Machine, NixStartScript, retry
 from test_driver.vlan import VLan
+from test_driver.polling_condition import PollingCondition
 
 
 class Driver:
@@ -16,6 +17,7 @@ class Driver:
     tests: str
     vlans: List[VLan]
     machines: List[Machine]
+    polling_conditions: List[Callable]
 
     def __init__(
         self,
@@ -36,12 +38,15 @@ class Driver:
             for s in scripts:
                 yield NixStartScript(s)
 
+        self.polling_conditions = []
+
         self.machines = [
             Machine(
                 start_command=cmd,
                 keep_vm_state=keep_vm_state,
                 name=cmd.machine_name,
                 tmp_dir=tmp_dir,
+                fail_early=self.fail_early,
             )
             for cmd in cmd(start_scripts)
         ]
@@ -84,6 +89,7 @@ class Driver:
             retry=retry,
             serial_stdout_off=self.serial_stdout_off,
             serial_stdout_on=self.serial_stdout_on,
+            polling_condition=self.polling_condition,
             Machine=Machine,  # for typing
         )
         machine_symbols = {m.name: m for m in self.machines}
@@ -159,3 +165,35 @@ class Driver:
 
     def serial_stdout_off(self) -> None:
         rootlog._print_serial_logs = False
+
+    def fail_early(self) -> bool:
+        return any(not f() for f in self.polling_conditions)
+
+    def polling_condition(
+        self,
+        fun_: Optional[Callable] = None,
+        *,
+        seconds_interval: float = 2.0,
+        description: Optional[str] = None,
+    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
+        driver = self
+
+        class Poll:
+            def __init__(self, fun: Callable):
+                self.condition = PollingCondition(
+                    fun,
+                    seconds_interval,
+                    description,
+                ).check
+
+            def __enter__(self) -> None:
+                driver.polling_conditions.append(self.condition)
+
+            def __exit__(self, a, b, c) -> None:  # type: ignore
+                res = driver.polling_conditions.pop()
+                assert res is self.condition
+
+        if fun_ is None:
+            return Poll
+        else:
+            return Poll(fun_)