summary refs log tree commit diff
path: root/nixos/lib/test-driver/test_driver/driver.py
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib/test-driver/test_driver/driver.py')
-rw-r--r--nixos/lib/test-driver/test_driver/driver.py225
1 files changed, 225 insertions, 0 deletions
diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py
new file mode 100644
index 00000000000..880b1c5fdec
--- /dev/null
+++ b/nixos/lib/test-driver/test_driver/driver.py
@@ -0,0 +1,225 @@
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
+import os
+import tempfile
+
+from test_driver.logger import rootlog
+from test_driver.machine import Machine, NixStartScript, retry
+from test_driver.vlan import VLan
+from test_driver.polling_condition import PollingCondition
+
+
+def get_tmp_dir() -> Path:
+    """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
+    Raises an exception in case the retrieved temporary directory is not writeable
+    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
+    """
+    tmp_dir = Path(tempfile.gettempdir())
+    tmp_dir.mkdir(mode=0o700, exist_ok=True)
+    if not tmp_dir.is_dir():
+        raise NotADirectoryError(
+            "The directory defined by TMPDIR, TEMP, TMP or CWD: {0} is not a directory".format(
+                tmp_dir
+            )
+        )
+    if not os.access(tmp_dir, os.W_OK):
+        raise PermissionError(
+            "The directory defined by TMPDIR, TEMP, TMP, or CWD: {0} is not writeable".format(
+                tmp_dir
+            )
+        )
+    return tmp_dir
+
+
+class Driver:
+    """A handle to the driver that sets up the environment
+    and runs the tests"""
+
+    tests: str
+    vlans: List[VLan]
+    machines: List[Machine]
+    polling_conditions: List[PollingCondition]
+
+    def __init__(
+        self,
+        start_scripts: List[str],
+        vlans: List[int],
+        tests: str,
+        out_dir: Path,
+        keep_vm_state: bool = False,
+    ):
+        self.tests = tests
+        self.out_dir = out_dir
+
+        tmp_dir = get_tmp_dir()
+
+        with rootlog.nested("start all VLans"):
+            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
+
+        def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
+            for s in scripts:
+                yield NixStartScript(s)
+
+        self.polling_conditions = []
+
+        self.machines = [
+            Machine(
+                start_command=cmd,
+                keep_vm_state=keep_vm_state,
+                name=cmd.machine_name,
+                tmp_dir=tmp_dir,
+                callbacks=[self.check_polling_conditions],
+                out_dir=self.out_dir,
+            )
+            for cmd in cmd(start_scripts)
+        ]
+
+    def __enter__(self) -> "Driver":
+        return self
+
+    def __exit__(self, *_: Any) -> None:
+        with rootlog.nested("cleanup"):
+            for machine in self.machines:
+                machine.release()
+
+    def subtest(self, name: str) -> Iterator[None]:
+        """Group logs under a given test name"""
+        with rootlog.nested(name):
+            try:
+                yield
+                return True
+            except Exception as e:
+                rootlog.error(f'Test "{name}" failed with error: "{e}"')
+                raise e
+
+    def test_symbols(self) -> Dict[str, Any]:
+        @contextmanager
+        def subtest(name: str) -> Iterator[None]:
+            return self.subtest(name)
+
+        general_symbols = dict(
+            start_all=self.start_all,
+            test_script=self.test_script,
+            machines=self.machines,
+            vlans=self.vlans,
+            driver=self,
+            log=rootlog,
+            os=os,
+            create_machine=self.create_machine,
+            subtest=subtest,
+            run_tests=self.run_tests,
+            join_all=self.join_all,
+            retry=retry,
+            serial_stdout_off=self.serial_stdout_off,
+            serial_stdout_on=self.serial_stdout_on,
+            polling_condition=self.polling_condition,
+            Machine=Machine,  # for typing
+        )
+        machine_symbols = {m.name: m for m in self.machines}
+        # If there's exactly one machine, make it available under the name
+        # "machine", even if it's not called that.
+        if len(self.machines) == 1:
+            (machine_symbols["machine"],) = self.machines
+        vlan_symbols = {
+            f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
+        }
+        print(
+            "additionally exposed symbols:\n    "
+            + ", ".join(map(lambda m: m.name, self.machines))
+            + ",\n    "
+            + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
+            + ",\n    "
+            + ", ".join(list(general_symbols.keys()))
+        )
+        return {**general_symbols, **machine_symbols, **vlan_symbols}
+
+    def test_script(self) -> None:
+        """Run the test script"""
+        with rootlog.nested("run the VM test script"):
+            symbols = self.test_symbols()  # call eagerly
+            exec(self.tests, symbols, None)
+
+    def run_tests(self) -> None:
+        """Run the test script (for non-interactive test runs)"""
+        self.test_script()
+        # TODO: Collect coverage data
+        for machine in self.machines:
+            if machine.is_up():
+                machine.execute("sync")
+
+    def start_all(self) -> None:
+        """Start all machines"""
+        with rootlog.nested("start all VMs"):
+            for machine in self.machines:
+                machine.start()
+
+    def join_all(self) -> None:
+        """Wait for all machines to shut down"""
+        with rootlog.nested("wait for all VMs to finish"):
+            for machine in self.machines:
+                machine.wait_for_shutdown()
+
+    def create_machine(self, args: Dict[str, Any]) -> Machine:
+        rootlog.warning(
+            "Using legacy create_machine(), please instantiate the"
+            "Machine class directly, instead"
+        )
+
+        tmp_dir = get_tmp_dir()
+
+        if args.get("startCommand"):
+            start_command: str = args.get("startCommand", "")
+            cmd = NixStartScript(start_command)
+            name = args.get("name", cmd.machine_name)
+        else:
+            cmd = Machine.create_startcommand(args)  # type: ignore
+            name = args.get("name", "machine")
+
+        return Machine(
+            tmp_dir=tmp_dir,
+            out_dir=self.out_dir,
+            start_command=cmd,
+            name=name,
+            keep_vm_state=args.get("keep_vm_state", False),
+            allow_reboot=args.get("allow_reboot", False),
+        )
+
+    def serial_stdout_on(self) -> None:
+        rootlog._print_serial_logs = True
+
+    def serial_stdout_off(self) -> None:
+        rootlog._print_serial_logs = False
+
+    def check_polling_conditions(self) -> None:
+        for condition in self.polling_conditions:
+            condition.maybe_raise()
+
+    def polling_condition(
+        self,
+        fun_: Optional[Callable] = None,
+        *,
+        seconds_interval: float = 2.0,
+        description: Optional[str] = None,
+    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
+        driver = self
+
+        class Poll:
+            def __init__(self, fun: Callable):
+                self.condition = PollingCondition(
+                    fun,
+                    seconds_interval,
+                    description,
+                )
+
+            def __enter__(self) -> None:
+                driver.polling_conditions.append(self.condition)
+
+            def __exit__(self, a, b, c) -> None:  # type: ignore
+                res = driver.polling_conditions.pop()
+                assert res is self.condition
+
+        if fun_ is None:
+            return Poll
+        else:
+            return Poll(fun_)