summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Hilhorst <git@hilhorst.be>2022-01-02 22:52:17 +0100
committerPatrick Hilhorst <git@hilhorst.be>2022-01-02 22:52:17 +0100
commita2f5092867927ea6a9bfc916ae191d3722350a33 (patch)
tree4583320837fd635d1990eeefa96551b98dab9ce8
parent7830f000c57bb616b178a6a8eaef9659938ca7ea (diff)
downloadnixpkgs-a2f5092867927ea6a9bfc916ae191d3722350a33.tar
nixpkgs-a2f5092867927ea6a9bfc916ae191d3722350a33.tar.gz
nixpkgs-a2f5092867927ea6a9bfc916ae191d3722350a33.tar.bz2
nixpkgs-a2f5092867927ea6a9bfc916ae191d3722350a33.tar.lz
nixpkgs-a2f5092867927ea6a9bfc916ae191d3722350a33.tar.xz
nixpkgs-a2f5092867927ea6a9bfc916ae191d3722350a33.tar.zst
nixpkgs-a2f5092867927ea6a9bfc916ae191d3722350a33.zip
nixos/test-driver: simplify logic, reduce interaction surface
-rw-r--r--nixos/lib/test-driver/test_driver/driver.py11
-rw-r--r--nixos/lib/test-driver/test_driver/machine.py16
-rw-r--r--nixos/lib/test-driver/test_driver/polling_condition.py25
3 files changed, 28 insertions, 24 deletions
diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py
index e22f9ee7a75..49a42fe5fb4 100644
--- a/nixos/lib/test-driver/test_driver/driver.py
+++ b/nixos/lib/test-driver/test_driver/driver.py
@@ -17,7 +17,7 @@ class Driver:
     tests: str
     vlans: List[VLan]
     machines: List[Machine]
-    polling_conditions: List[Callable]
+    polling_conditions: List[PollingCondition]
 
     def __init__(
         self,
@@ -46,7 +46,7 @@ class Driver:
                 keep_vm_state=keep_vm_state,
                 name=cmd.machine_name,
                 tmp_dir=tmp_dir,
-                fail_early=self.fail_early,
+                callbacks=[self.check_polling_conditions],
             )
             for cmd in cmd(start_scripts)
         ]
@@ -166,8 +166,9 @@ class Driver:
     def serial_stdout_off(self) -> None:
         rootlog._print_serial_logs = False
 
-    def fail_early(self) -> bool:
-        return any(not f() for f in self.polling_conditions)
+    def check_polling_conditions(self) -> None:
+        for condition in self.polling_conditions:
+            condition.maybe_raise()
 
     def polling_condition(
         self,
@@ -184,7 +185,7 @@ class Driver:
                     fun,
                     seconds_interval,
                     description,
-                ).check
+                )
 
             def __enter__(self) -> None:
                 driver.polling_conditions.append(self.condition)
diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py
index dbf9fd24486..8615030b22e 100644
--- a/nixos/lib/test-driver/test_driver/machine.py
+++ b/nixos/lib/test-driver/test_driver/machine.py
@@ -17,7 +17,7 @@ import threading
 import time
 
 from test_driver.logger import rootlog
-from test_driver.polling_condition import PollingCondition, coopmulti
+from test_driver.polling_condition import PollingCondition
 
 CHAR_TO_KEY = {
     "A": "shift-a",
@@ -319,7 +319,7 @@ class Machine:
     # Store last serial console lines for use
     # of wait_for_console_text
     last_lines: Queue = Queue()
-    fail_early: Callable
+    callbacks: List[Callable]
 
     def __repr__(self) -> str:
         return f"<Machine '{self.name}'>"
@@ -331,14 +331,14 @@ class Machine:
         name: str = "machine",
         keep_vm_state: bool = False,
         allow_reboot: bool = False,
-        fail_early: Callable = lambda: False,
+        callbacks: Optional[List[Callable]] = None,
     ) -> None:
         self.tmp_dir = tmp_dir
         self.keep_vm_state = keep_vm_state
         self.allow_reboot = allow_reboot
         self.name = name
         self.start_command = start_command
-        self.fail_early = fail_early
+        self.callbacks = callbacks if callbacks is not None else []
 
         # set up directories
         self.shared_dir = self.tmp_dir / "shared-xchg"
@@ -409,8 +409,8 @@ class Machine:
                     break
             return answer
 
-    @coopmulti
     def send_monitor_command(self, command: str) -> str:
+        self.run_callbacks()
         with self.nested("sending monitor command: {}".format(command)):
             message = ("{}\n".format(command)).encode()
             assert self.monitor is not None
@@ -511,10 +511,10 @@ class Machine:
                 break
         return "".join(output_buffer)
 
-    @coopmulti
     def execute(
         self, command: str, check_return: bool = True, timeout: Optional[int] = 900
     ) -> Tuple[int, str]:
+        self.run_callbacks()
         self.connect()
 
         if timeout is not None:
@@ -975,3 +975,7 @@ class Machine:
         self.shell.close()
         self.monitor.close()
         self.serial_thread.join()
+
+    def run_callbacks(self) -> None:
+        for callback in self.callbacks:
+            callback()
diff --git a/nixos/lib/test-driver/test_driver/polling_condition.py b/nixos/lib/test-driver/test_driver/polling_condition.py
index fe064b1f830..65b00114336 100644
--- a/nixos/lib/test-driver/test_driver/polling_condition.py
+++ b/nixos/lib/test-driver/test_driver/polling_condition.py
@@ -10,17 +10,6 @@ class PollingConditionFailed(Exception):
     pass
 
 
-def coopmulti(fun: Callable) -> Callable:
-    @wraps(fun)
-    def wrapper(machine: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
-        if machine.fail_early():  # type: ignore
-            raise PollingConditionFailed("Test interrupted early...")
-
-        return fun(machine, *args, **kwargs)
-
-    return wrapper
-
-
 class PollingCondition:
     condition: Callable[[], bool]
     seconds_interval: float
@@ -39,7 +28,10 @@ class PollingCondition:
         self.seconds_interval = seconds_interval
 
         if description is None:
-            self.description = condition.__doc__
+            if condition.__doc__:
+                self.description = condition.__doc__
+            else:
+                self.description = condition.__name__
         else:
             self.description = str(description)
 
@@ -57,9 +49,16 @@ class PollingCondition:
             except Exception:
                 res = False
             res = res is None or res
-            rootlog.info(f"Polling condition {'succeeded' if res else 'failed'}")
+            rootlog.info(self.status_message(res))
             return res
 
+    def maybe_raise(self) -> None:
+        if not self.check():
+            raise PollingConditionFailed(self.status_message(False))
+
+    def status_message(self, status: bool) -> str:
+        return f"Polling condition {'succeeded' if status else 'failed'}: {self.description}"
+
     @property
     def nested_message(self) -> str:
         nested_message = ["Checking polling condition"]