summary refs log tree commit diff
path: root/nixos/lib/test-driver/test_driver/machine.py
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib/test-driver/test_driver/machine.py')
-rw-r--r--nixos/lib/test-driver/test_driver/machine.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py
index b3dbe5126fc..e050cbd7d99 100644
--- a/nixos/lib/test-driver/test_driver/machine.py
+++ b/nixos/lib/test-driver/test_driver/machine.py
@@ -318,6 +318,7 @@ class Machine:
     # Store last serial console lines for use
     # of wait_for_console_text
     last_lines: Queue = Queue()
+    callbacks: List[Callable]
 
     def __repr__(self) -> str:
         return f"<Machine '{self.name}'>"
@@ -329,12 +330,14 @@ class Machine:
         name: str = "machine",
         keep_vm_state: bool = False,
         allow_reboot: bool = False,
+        callbacks: Optional[List[Callable]] = None,
     ) -> None:
         self.tmp_dir = tmp_dir
         self.keep_vm_state = keep_vm_state
         self.allow_reboot = allow_reboot
         self.name = name
         self.start_command = start_command
+        self.callbacks = callbacks if callbacks is not None else []
 
         # set up directories
         self.shared_dir = self.tmp_dir / "shared-xchg"
@@ -406,6 +409,7 @@ class Machine:
             return answer
 
     def send_monitor_command(self, command: str) -> str:
+        self.run_callbacks()
         with self.nested("sending monitor command: {}".format(command)):
             message = ("{}\n".format(command)).encode()
             assert self.monitor is not None
@@ -509,6 +513,7 @@ class Machine:
     def execute(
         self, command: str, check_return: bool = True, timeout: Optional[int] = 900
     ) -> Tuple[int, str]:
+        self.run_callbacks()
         self.connect()
 
         if timeout is not None:
@@ -969,3 +974,7 @@ class Machine:
         self.shell.close()
         self.monitor.close()
         self.serial_thread.join()
+
+    def run_callbacks(self) -> None:
+        for callback in self.callbacks:
+            callback()