summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Arnold <dar@xoe.solutions>2021-06-12 17:47:25 -0500
committerDavid Arnold <david.arnold@iohk.io>2021-10-05 14:38:48 -0500
commitb0fc9da879812e47c1ed3438fb0fd51db00a3494 (patch)
treec238d3e8ce9c6ad17c47e8414001a29e137d8e52
parent3069ba0dd1dec75c5dc4f6a1ee238a4fab9828cd (diff)
downloadnixpkgs-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar
nixpkgs-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.gz
nixpkgs-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.bz2
nixpkgs-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.lz
nixpkgs-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.xz
nixpkgs-b0fc9da879812e47c1ed3438fb0fd51db00a3494.tar.zst
nixpkgs-b0fc9da879812e47c1ed3438fb0fd51db00a3494.zip
nixos/test/test-driver: Class-ify the test driver
This commit encapsulates the involved domain into classes and
defines explicit and typed arguments where untyped dicts were used.

It preserves backwards compatibility through legacy wrappers.
-rwxr-xr-xnixos/lib/test-driver/test-driver.py804
-rw-r--r--nixos/lib/testing-python.nix11
-rw-r--r--nixos/modules/installer/tools/nixos-build-vms/build-vms.nix19
3 files changed, 527 insertions, 307 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index f8502188bde..fdc440a896a 100755
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -21,7 +21,6 @@ import shutil
 import socket
 import subprocess
 import sys
-import telnetlib
 import tempfile
 import time
 import unicodedata
@@ -89,55 +88,6 @@ CHAR_TO_KEY = {
     ")": "shift-0x0B",
 }
 
-global log, machines, test_script
-
-
-def eprint(*args: object, **kwargs: Any) -> None:
-    print(*args, file=sys.stderr, **kwargs)
-
-
-def make_command(args: list) -> str:
-    return " ".join(map(shlex.quote, (map(str, args))))
-
-
-def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
-    log.log("starting VDE switch for network {}".format(vlan_nr))
-    vde_socket = tempfile.mkdtemp(
-        prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
-    )
-    pty_master, pty_slave = pty.openpty()
-    vde_process = subprocess.Popen(
-        ["vde_switch", "-s", vde_socket, "--dirmode", "0700"],
-        stdin=pty_slave,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        shell=False,
-    )
-    fd = os.fdopen(pty_master, "w")
-    fd.write("version\n")
-    # TODO: perl version checks if this can be read from
-    # an if not, dies. we could hang here forever. Fix it.
-    assert vde_process.stdout is not None
-    vde_process.stdout.readline()
-    if not os.path.exists(os.path.join(vde_socket, "ctl")):
-        raise Exception("cannot start vde_switch")
-
-    return (vlan_nr, vde_socket, vde_process, fd)
-
-
-def retry(fn: Callable, timeout: int = 900) -> None:
-    """Call the given function repeatedly, with 1 second intervals,
-    until it returns True or a timeout is reached.
-    """
-
-    for _ in range(timeout):
-        if fn(False):
-            return
-        time.sleep(1)
-
-    if not fn(True):
-        raise Exception(f"action timed out after {timeout} seconds")
-
 
 class Logger:
     def __init__(self) -> None:
@@ -151,6 +101,10 @@ class Logger:
 
         self._print_serial_logs = True
 
+    @staticmethod
+    def _eprint(*args: object, **kwargs: Any) -> None:
+        print(*args, file=sys.stderr, **kwargs)
+
     def close(self) -> None:
         self.xml.endElement("logfile")
         self.xml.endDocument()
@@ -169,15 +123,27 @@ class Logger:
         self.xml.characters(message)
         self.xml.endElement("line")
 
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+        sys.exit(1)
+
     def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
-        eprint(self.maybe_prefix(message, attributes))
+        self._eprint(self.maybe_prefix(message, attributes))
         self.drain_log_queue()
         self.log_line(message, attributes)
 
     def log_serial(self, message: str, machine: str) -> None:
         self.enqueue({"msg": message, "machine": machine, "type": "serial"})
         if self._print_serial_logs:
-            eprint(Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL)
+            self._eprint(
+                Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL
+            )
 
     def enqueue(self, item: Dict[str, str]) -> None:
         self.queue.put(item)
@@ -194,7 +160,7 @@ class Logger:
 
     @contextmanager
     def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
-        eprint(self.maybe_prefix(message, attributes))
+        self._eprint(self.maybe_prefix(message, attributes))
 
         self.xml.startElement("nest", attrs={})
         self.xml.startElement("head", attributes)
@@ -211,6 +177,27 @@ class Logger:
         self.xml.endElement("nest")
 
 
+rootlog = Logger()
+
+
+def make_command(args: list) -> str:
+    return " ".join(map(shlex.quote, (map(str, args))))
+
+
+def retry(fn: Callable, timeout: int = 900) -> None:
+    """Call the given function repeatedly, with 1 second intervals,
+    until it returns True or a timeout is reached.
+    """
+
+    for _ in range(timeout):
+        if fn(False):
+            return
+        time.sleep(1)
+
+    if not fn(True):
+        raise Exception(f"action timed out after {timeout} seconds")
+
+
 def _perform_ocr_on_screenshot(
     screenshot_path: str, model_ids: Iterable[int]
 ) -> List[str]:
@@ -242,113 +229,256 @@ def _perform_ocr_on_screenshot(
     return model_results
 
 
-class Machine:
-    def __repr__(self) -> str:
-        return f"<Machine '{self.name}'>"
-
-    def __init__(self, args: Dict[str, Any]) -> None:
-        if "name" in args:
-            self.name = args["name"]
-        else:
-            self.name = "machine"
-            cmd = args.get("startCommand", None)
-            if cmd:
-                match = re.search("run-(.+)-vm$", cmd)
-                if match:
-                    self.name = match.group(1)
-        self.logger = args["log"]
-        self.script = args.get("startCommand", self.create_startcommand(args))
-
-        tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
-
-        def create_dir(name: str) -> str:
-            path = os.path.join(tmp_dir, name)
-            os.makedirs(path, mode=0o700, exist_ok=True)
-            return path
+class StartCommand:
+    """The Base Start Command knows how to append the necesary
+    runtime qemu options as determined by a particular test driver
+    run. Any such start command is expected to happily receive and
+    append additional qemu args.
+    """
 
-        self.state_dir = os.path.join(tmp_dir, f"vm-state-{self.name}")
-        if not args.get("keepVmState", False):
-            self.cleanup_statedir()
-        os.makedirs(self.state_dir, mode=0o700, exist_ok=True)
-        self.shared_dir = create_dir("shared-xchg")
+    _cmd: str
 
-        self.booted = False
-        self.connected = False
-        self.pid: Optional[int] = None
-        self.socket = None
-        self.monitor: Optional[socket.socket] = None
-        self.allow_reboot = args.get("allowReboot", False)
+    def cmd(
+        self,
+        monitor_socket_path: pathlib.Path,
+        shell_socket_path: pathlib.Path,
+        allow_reboot: bool = False,  # TODO: unused, legacy?
+    ) -> str:
+        display_opts = ""
+        display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
+        if display_available:
+            display_opts += " -nographic"
+
+        # qemu options
+        qemu_opts = ""
+        qemu_opts += (
+            ""
+            if allow_reboot
+            else " -no-reboot"
+            " -device virtio-serial"
+            " -device virtconsole,chardev=shell"
+            " -device virtio-rng-pci"
+            " -serial stdio"
+        )
+        # TODO: qemu script already catpures this env variable, legacy?
+        qemu_opts += " " + os.environ.get("QEMU_OPTS", "")
+
+        return (
+            f"{self._cmd}"
+            f" -monitor unix:{monitor_socket_path}"
+            f" -chardev socket,id=shell,path={shell_socket_path}"
+            f"{qemu_opts}"
+            f"{display_opts}"
+        )
 
     @staticmethod
-    def create_startcommand(args: Dict[str, str]) -> str:
-        net_backend = "-netdev user,id=net0"
-        net_frontend = "-device virtio-net-pci,netdev=net0"
+    def build_environment(
+        state_dir: pathlib.Path,
+        shared_dir: pathlib.Path,
+    ) -> dict:
+        # We make a copy to not update the current environment
+        env = dict(os.environ)
+        env.update(
+            {
+                "TMPDIR": str(state_dir),
+                "SHARED_DIR": str(shared_dir),
+                "USE_TMPDIR": "1",
+            }
+        )
+        return env
+
+    def run(
+        self,
+        state_dir: pathlib.Path,
+        shared_dir: pathlib.Path,
+        monitor_socket_path: pathlib.Path,
+        shell_socket_path: pathlib.Path,
+    ) -> subprocess.Popen:
+        return subprocess.Popen(
+            self.cmd(monitor_socket_path, shell_socket_path),
+            stdin=subprocess.DEVNULL,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            shell=True,
+            cwd=state_dir,
+            env=self.build_environment(state_dir, shared_dir),
+        )
+
 
-        if "netBackendArgs" in args:
-            net_backend += "," + args["netBackendArgs"]
+class NixStartScript(StartCommand):
+    """A start script from nixos/modules/virtualiation/qemu-vm.nix
+    that also satisfies the requirement of the BaseStartCommand.
+    These Nix commands have the particular charactersitic that the
+    machine name can be extracted out of them via a regex match.
+    (Admittedly a _very_ implicit contract, evtl. TODO fix)
+    """
 
-        if "netFrontendArgs" in args:
-            net_frontend += "," + args["netFrontendArgs"]
+    def __init__(self, script: str):
+        self._cmd = script
 
-        start_command = (
-            args.get("qemuBinary", "qemu-kvm")
-            + " -m 384 "
-            + net_backend
-            + " "
-            + net_frontend
-            + " $QEMU_OPTS "
-        )
+    @property
+    def machine_name(self) -> str:
+        match = re.search("run-(.+)-vm$", self._cmd)
+        name = "machine"
+        if match:
+            name = match.group(1)
+        return name
 
-        if "hda" in args:
-            hda_path = os.path.abspath(args["hda"])
-            if args.get("hdaInterface", "") == "scsi":
-                start_command += (
-                    "-drive id=hda,file="
-                    + hda_path
-                    + ",werror=report,if=none "
-                    + "-device scsi-hd,drive=hda "
+
+class LegacyStartCommand(StartCommand):
+    """Used in some places to create an ad-hoc machine instead of
+    using nix test instrumentation + module system for that purpose.
+    Legacy.
+    """
+
+    def __init__(
+        self,
+        netBackendArgs: Optional[str] = None,
+        netFrontendArgs: Optional[str] = None,
+        hda: Optional[Tuple[pathlib.Path, str]] = None,
+        cdrom: Optional[str] = None,
+        usb: Optional[str] = None,
+        bios: Optional[str] = None,
+        qemuFlags: Optional[str] = None,
+    ):
+        self._cmd = "qemu-kvm -m 384"
+
+        # networking
+        net_backend = "-netdev user,id=net0"
+        net_frontend = "-device virtio-net-pci,netdev=net0"
+        if netBackendArgs is not None:
+            net_backend += "," + netBackendArgs
+        if netFrontendArgs is not None:
+            net_frontend += "," + netFrontendArgs
+        self._cmd += f" {net_backend} {net_frontend}"
+
+        # hda
+        hda_cmd = ""
+        if hda is not None:
+            hda_path = hda[0].resolve()
+            hda_interface = hda[1]
+            if hda_interface == "scsi":
+                hda_cmd += (
+                    f" -drive id=hda,file={hda_path},werror=report,if=none"
+                    " -device scsi-hd,drive=hda"
                 )
             else:
-                start_command += (
-                    "-drive file="
-                    + hda_path
-                    + ",if="
-                    + args["hdaInterface"]
-                    + ",werror=report "
-                )
+                hda_cmd += f" -drive file={hda_path},if={hda_interface},werror=report"
+        self._cmd += hda_cmd
 
-        if "cdrom" in args:
-            start_command += "-cdrom " + args["cdrom"] + " "
+        # cdrom
+        if cdrom is not None:
+            self._cmd += f" -cdrom {cdrom}"
 
-        if "usb" in args:
+        # usb
+        usb_cmd = ""
+        if usb is not None:
             # https://github.com/qemu/qemu/blob/master/docs/usb2.txt
-            start_command += (
-                "-device usb-ehci -drive "
-                + "id=usbdisk,file="
-                + args["usb"]
-                + ",if=none,readonly "
-                + "-device usb-storage,drive=usbdisk "
+            usb_cmd += (
+                " -device usb-ehci"
+                f" -drive id=usbdisk,file={usb},if=none,readonly"
+                " -device usb-storage,drive=usbdisk "
             )
-        if "bios" in args:
-            start_command += "-bios " + args["bios"] + " "
+        self._cmd += usb_cmd
+
+        # bios
+        if bios is not None:
+            self._cmd += f" -bios {bios}"
+
+        # qemu flags
+        if qemuFlags is not None:
+            self._cmd += f" {qemuFlags}"
+
+
+class Machine:
+    """A handle to the machine with this name, that also knows how to manage
+    the machine lifecycle with the help of a start script / command."""
+
+    name: str
+    tmp_dir: pathlib.Path
+    shared_dir: pathlib.Path
+    state_dir: pathlib.Path
+    monitor_path: pathlib.Path
+    shell_path: pathlib.Path
+
+    start_command: StartCommand
+    keep_vm_state: bool
+    allow_reboot: bool
+
+    process: Optional[subprocess.Popen] = None
+    pid: Optional[int] = None
+    monitor: Optional[socket.socket] = None
+    shell: Optional[socket.socket] = None
+
+    booted: bool = False
+    connected: bool = False
+    # Store last serial console lines for use
+    # of wait_for_console_text
+    last_lines: Queue = Queue()
 
-        start_command += args.get("qemuFlags", "")
+    def __repr__(self) -> str:
+        return f"<Machine '{self.name}'>"
+
+    def __init__(
+        self,
+        tmp_dir: pathlib.Path,
+        start_command: StartCommand,
+        name: str = "machine",
+        keep_vm_state: bool = False,
+        allow_reboot: bool = False,
+    ) -> None:
+        self.tmp_dir = tmp_dir
+        self.keep_vm_state = keep_vm_state
+        self.allow_reboot = allow_reboot
+        self.name = name
+        self.start_command = start_command
+
+        # set up directories
+        self.shared_dir = self.tmp_dir / "shared-xchg"
+        self.shared_dir.mkdir(mode=0o700, exist_ok=True)
+
+        self.state_dir = self.tmp_dir / f"vm-state-{self.name}"
+        self.monitor_path = self.state_dir / "monitor"
+        self.shell_path = self.state_dir / "shell"
+        if (not self.keep_vm_state) and self.state_dir.exists():
+            self.cleanup_statedir()
+        self.state_dir.mkdir(mode=0o700, exist_ok=True)
 
-        return start_command
+    @staticmethod
+    def create_startcommand(args: Dict[str, str]) -> StartCommand:
+        rootlog.warning(
+            "Using legacy create_startcommand(),"
+            "please use proper nix test vm instrumentation, instead"
+            "to generate the appropriate nixos test vm qemu startup script"
+        )
+        hda = None
+        if args.get("hda"):
+            hda_arg: str = args.get("hda", "")
+            hda_arg_path: pathlib.Path = pathlib.Path(hda_arg)
+            hda = (hda_arg_path, args.get("hdaInterface", ""))
+        return LegacyStartCommand(
+            netBackendArgs=args.get("netBackendArgs"),
+            netFrontendArgs=args.get("netFrontendArgs"),
+            hda=hda,
+            cdrom=args.get("cdrom"),
+            usb=args.get("usb"),
+            bios=args.get("bios"),
+            qemuFlags=args.get("qemuFlags"),
+        )
 
     def is_up(self) -> bool:
         return self.booted and self.connected
 
     def log(self, msg: str) -> None:
-        self.logger.log(msg, {"machine": self.name})
+        rootlog.log(msg, {"machine": self.name})
 
     def log_serial(self, msg: str) -> None:
-        self.logger.log_serial(msg, self.name)
+        rootlog.log_serial(msg, self.name)
 
     def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
         my_attrs = {"machine": self.name}
         my_attrs.update(attrs)
-        return self.logger.nested(msg, my_attrs)
+        return rootlog.nested(msg, my_attrs)
 
     def wait_for_monitor_prompt(self) -> str:
         assert self.monitor is not None
@@ -446,6 +576,7 @@ class Machine:
         self.connect()
 
         out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command)
+        assert self.shell
         self.shell.send(out_command.encode())
 
         output = ""
@@ -466,6 +597,8 @@ class Machine:
         Should only be used during test development, not in the production test."""
         self.connect()
         self.log("Terminal is ready (there is no prompt):")
+
+        assert self.shell
         subprocess.run(
             ["socat", "READLINE", f"FD:{self.shell.fileno()}"],
             pass_fds=[self.shell.fileno()],
@@ -534,6 +667,7 @@ class Machine:
 
         with self.nested("waiting for the VM to power off"):
             sys.stdout.flush()
+            assert self.process
             self.process.wait()
 
             self.pid = None
@@ -611,6 +745,8 @@ class Machine:
         with self.nested("waiting for the VM to finish booting"):
             self.start()
 
+            assert self.shell
+
             tic = time.time()
             self.shell.recv(1024)
             # TODO: Timeout
@@ -750,65 +886,35 @@ class Machine:
 
         self.log("starting vm")
 
-        def create_socket(path: str) -> socket.socket:
-            if os.path.exists(path):
-                os.unlink(path)
+        def clear(path: pathlib.Path) -> pathlib.Path:
+            if path.exists():
+                path.unlink()
+            return path
+
+        def create_socket(path: pathlib.Path) -> socket.socket:
             s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
-            s.bind(path)
+            s.bind(str(path))
             s.listen(1)
             return s
 
-        monitor_path = os.path.join(self.state_dir, "monitor")
-        self.monitor_socket = create_socket(monitor_path)
-
-        shell_path = os.path.join(self.state_dir, "shell")
-        self.shell_socket = create_socket(shell_path)
-
-        display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
-        qemu_options = (
-            " ".join(
-                [
-                    "" if self.allow_reboot else "-no-reboot",
-                    "-monitor unix:{}".format(monitor_path),
-                    "-chardev socket,id=shell,path={}".format(shell_path),
-                    "-device virtio-serial",
-                    "-device virtconsole,chardev=shell",
-                    "-device virtio-rng-pci",
-                    "-serial stdio" if display_available else "-nographic",
-                ]
-            )
-            + " "
-            + os.environ.get("QEMU_OPTS", "")
+        monitor_socket = create_socket(clear(self.monitor_path))
+        shell_socket = create_socket(clear(self.shell_path))
+        self.process = self.start_command.run(
+            self.state_dir,
+            self.shared_dir,
+            self.monitor_path,
+            self.shell_path,
         )
-
-        environment = dict(os.environ)
-        environment.update(
-            {
-                "TMPDIR": self.state_dir,
-                "SHARED_DIR": self.shared_dir,
-                "USE_TMPDIR": "1",
-                "QEMU_OPTS": qemu_options,
-            }
-        )
-
-        self.process = subprocess.Popen(
-            self.script,
-            stdin=subprocess.DEVNULL,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            shell=True,
-            cwd=self.state_dir,
-            env=environment,
-        )
-        self.monitor, _ = self.monitor_socket.accept()
-        self.shell, _ = self.shell_socket.accept()
+        self.monitor, _ = monitor_socket.accept()
+        self.shell, _ = shell_socket.accept()
 
         # Store last serial console lines for use
         # of wait_for_console_text
         self.last_lines: Queue = Queue()
 
         def process_serial_output() -> None:
-            assert self.process.stdout is not None
+            assert self.process
+            assert self.process.stdout
             for _line in self.process.stdout:
                 # Ignore undecodable bytes that may occur in boot menus
                 line = _line.decode(errors="ignore").replace("\r", "").rstrip()
@@ -825,15 +931,15 @@ class Machine:
         self.log("QEMU running (pid {})".format(self.pid))
 
     def cleanup_statedir(self) -> None:
-        if os.path.isdir(self.state_dir):
-            shutil.rmtree(self.state_dir)
-            self.logger.log(f"deleting VM state directory {self.state_dir}")
-            self.logger.log("if you want to keep the VM state, pass --keep-vm-state")
+        shutil.rmtree(self.state_dir)
+        rootlog.log(f"deleting VM state directory {self.state_dir}")
+        rootlog.log("if you want to keep the VM state, pass --keep-vm-state")
 
     def shutdown(self) -> None:
         if not self.booted:
             return
 
+        assert self.shell
         self.shell.send("poweroff\n".encode())
         self.wait_for_shutdown()
 
@@ -908,41 +1014,225 @@ class Machine:
         """Make the machine reachable."""
         self.send_monitor_command("set_link virtio-net-pci.1 on")
 
+    def release(self) -> None:
+        if self.pid is None:
+            return
+        rootlog.info(f"kill machine (pid {self.pid})")
+        assert self.process
+        assert self.shell
+        assert self.monitor
+        self.process.terminate()
+        self.shell.close()
+        self.monitor.close()
+
+
+class VLan:
+    """A handle to the vlan with this number, that also knows how to manage
+    it's lifecycle.
+    """
 
-def create_machine(args: Dict[str, Any]) -> Machine:
-    args["log"] = log
-    return Machine(args)
+    nr: int
+    socket_dir: pathlib.Path
 
+    process: Optional[subprocess.Popen]
+    pid: Optional[int]
+    fd: Optional[io.TextIOBase]
 
-def start_all() -> None:
-    with log.nested("starting all VMs"):
-        for machine in machines:
-            machine.start()
+    def __repr__(self) -> str:
+        return f"<Vlan Nr. {self.nr}>"
 
+    def __init__(self, nr: int, tmp_dir: pathlib.Path):
+        self.nr = nr
+        self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
 
-def join_all() -> None:
-    with log.nested("waiting for all VMs to finish"):
-        for machine in machines:
-            machine.wait_for_shutdown()
+        # TODO: don't side-effect environment here
+        os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
 
+    def start(self) -> None:
 
-def run_tests(interactive: bool = False) -> None:
-    if interactive:
-        ptpython.repl.embed(test_symbols(), {})
-    else:
-        test_script()
+        rootlog.info("start vlan")
+        pty_master, pty_slave = pty.openpty()
+
+        self.process = subprocess.Popen(
+            ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"],
+            stdin=pty_slave,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=False,
+        )
+        self.pid = self.process.pid
+        self.fd = os.fdopen(pty_master, "w")
+        self.fd.write("version\n")
+
+        # TODO: perl version checks if this can be read from
+        # an if not, dies. we could hang here forever. Fix it.
+        assert self.process.stdout is not None
+        self.process.stdout.readline()
+        if not (self.socket_dir / "ctl").exists():
+            rootlog.error("cannot start vde_switch")
+
+        rootlog.info(f"running vlan (pid {self.pid})")
+
+    def release(self) -> None:
+        if self.pid is None:
+            return
+        rootlog.info(f"kill vlan (pid {self.pid})")
+        assert self.fd
+        assert self.process
+        self.fd.close()
+        self.process.terminate()
+
+
+class Driver:
+    """A handle to the driver that sets up the environment
+    and runs the tests"""
+
+    tests: str
+    vlans: List[VLan]
+    machines: List[Machine]
+
+    def __init__(
+        self,
+        start_scripts: List[str],
+        vlans: List[int],
+        tests: str,
+        keep_vm_state: bool = False,
+    ):
+        self.tests = tests
+
+        tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+        tmp_dir.mkdir(mode=0o700, exist_ok=True)
+
+        self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
+        with rootlog.nested("start all VLans"):
+            for vlan in self.vlans:
+                vlan.start()
+
+        def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
+            for s in scripts:
+                yield NixStartScript(s)
+
+        self.machines = [
+            Machine(
+                start_command=cmd,
+                keep_vm_state=keep_vm_state,
+                name=cmd.machine_name,
+                tmp_dir=tmp_dir,
+            )
+            for cmd in cmd(start_scripts)
+        ]
+
+        @atexit.register
+        def clean_up() -> None:
+            with rootlog.nested("clean up"):
+                for machine in self.machines:
+                    machine.release()
+                for vlan in self.vlans:
+                    vlan.release()
+
+    def subtest(self, name: str) -> Iterator[None]:
+        """Group logs under a given test name"""
+        with rootlog.nested(name):
+            try:
+                yield
+                return True
+            except:
+                rootlog.error(f'Test "{name}" failed with error:')
+                raise
+
+    def test_symbols(self) -> Dict[str, Any]:
+        @contextmanager
+        def subtest(name: str) -> Iterator[None]:
+            return self.subtest(name)
+
+        general_symbols = dict(
+            start_all=self.start_all,
+            test_script=self.test_script,
+            machines=self.machines,
+            vlans=self.vlans,
+            driver=self,
+            log=rootlog,
+            os=os,
+            create_machine=self.create_machine,
+            subtest=subtest,
+            run_tests=self.run_tests,
+            join_all=self.join_all,
+            retry=retry,
+            serial_stdout_off=self.serial_stdout_off,
+            serial_stdout_on=self.serial_stdout_on,
+            Machine=Machine,  # for typing
+        )
+        machine_symbols = {
+            m.name: self.machines[idx] for idx, m in enumerate(self.machines)
+        }
+        vlan_symbols = {
+            f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
+        }
+        print(
+            "additionally exposed symbols:\n    "
+            + ", ".join(map(lambda m: m.name, self.machines))
+            + ",\n    "
+            + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
+            + ",\n    "
+            + ", ".join(list(general_symbols.keys()))
+        )
+        return {**general_symbols, **machine_symbols, **vlan_symbols}
+
+    def test_script(self) -> None:
+        """Run the test script"""
+        with rootlog.nested("run the VM test script"):
+            symbols = self.test_symbols()  # call eagerly
+            exec(self.tests, symbols, None)
+
+    def run_tests(self) -> None:
+        """Run the test script (for non-interactive test runs)"""
+        self.test_script()
         # TODO: Collect coverage data
-        for machine in machines:
+        for machine in self.machines:
             if machine.is_up():
                 machine.execute("sync")
 
+    def start_all(self) -> None:
+        """Start all machines"""
+        with rootlog.nested("start all VMs"):
+            for machine in self.machines:
+                machine.start()
+
+    def join_all(self) -> None:
+        """Wait for all machines to shut down"""
+        with rootlog.nested("wait for all VMs to finish"):
+            for machine in self.machines:
+                machine.wait_for_shutdown()
+
+    def create_machine(self, args: Dict[str, Any]) -> Machine:
+        rootlog.warning(
+            "Using legacy create_machine(), please instantiate the"
+            "Machine class directly, instead"
+        )
+        tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
+        tmp_dir.mkdir(mode=0o700, exist_ok=True)
 
-def serial_stdout_on() -> None:
-    log._print_serial_logs = True
+        if args.get("startCommand"):
+            start_command: str = args.get("startCommand", "")
+            cmd = NixStartScript(start_command)
+            name = args.get("name", cmd.machine_name)
+        else:
+            cmd = Machine.create_startcommand(args)  # type: ignore
+            name = args.get("name", "machine")
+
+        return Machine(
+            tmp_dir=tmp_dir,
+            start_command=cmd,
+            name=name,
+            keep_vm_state=args.get("keep_vm_state", False),
+            allow_reboot=args.get("allow_reboot", False),
+        )
 
+    def serial_stdout_on(self) -> None:
+        rootlog._print_serial_logs = True
 
-def serial_stdout_off() -> None:
-    log._print_serial_logs = False
+    def serial_stdout_off(self) -> None:
+        rootlog._print_serial_logs = False
 
 
 class EnvDefault(argparse.Action):
@@ -970,52 +1260,6 @@ class EnvDefault(argparse.Action):
         setattr(namespace, self.dest, values)
 
 
-@contextmanager
-def subtest(name: str) -> Iterator[None]:
-    with log.nested(name):
-        try:
-            yield
-            return True
-        except Exception as e:
-            log.log(f'Test "{name}" failed with error: "{e}"')
-            raise e
-
-    return False
-
-
-def _test_symbols() -> Dict[str, Any]:
-    general_symbols = dict(
-        start_all=start_all,
-        test_script=globals().get("test_script"),  # same
-        machines=globals().get("machines"),  # without being initialized
-        log=globals().get("log"),  # extracting those symbol keys
-        os=os,
-        create_machine=create_machine,
-        subtest=subtest,
-        run_tests=run_tests,
-        join_all=join_all,
-        retry=retry,
-        serial_stdout_off=serial_stdout_off,
-        serial_stdout_on=serial_stdout_on,
-        Machine=Machine,  # for typing
-    )
-    return general_symbols
-
-
-def test_symbols() -> Dict[str, Any]:
-
-    general_symbols = _test_symbols()
-
-    machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
-    print(
-        "additionally exposed symbols:\n    "
-        + ", ".join(map(lambda m: m.name, machines))
-        + ",\n    "
-        + ", ".join(list(general_symbols.keys()))
-    )
-    return {**general_symbols, **machine_symbols}
-
-
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
     arg_parser.add_argument(
@@ -1055,44 +1299,18 @@ if __name__ == "__main__":
     )
 
     args = arg_parser.parse_args()
-    testscript = pathlib.Path(args.testscript).read_text()
-
-    global log, machines, test_script
-
-    log = Logger()
-
-    vde_sockets = [create_vlan(v) for v in args.vlans]
-    for nr, vde_socket, _, _ in vde_sockets:
-        os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket
-
-    machines = [
-        create_machine({"startCommand": s, "keepVmState": args.keep_vm_state})
-        for s in args.start_scripts
-    ]
-    machine_eval = [
-        "{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines)
-    ]
-    exec("\n".join(machine_eval))
-
-    @atexit.register
-    def clean_up() -> None:
-        with log.nested("cleaning up"):
-            for machine in machines:
-                if machine.pid is None:
-                    continue
-                log.log("killing {} (pid {})".format(machine.name, machine.pid))
-                machine.process.kill()
-            for _, _, process, _ in vde_sockets:
-                process.terminate()
-        log.close()
-
-    def test_script() -> None:
-        with log.nested("running the VM test script"):
-            symbols = test_symbols()  # call eagerly
-            exec(testscript, symbols, None)
-
-    interactive = args.interactive or (not bool(testscript))
-    tic = time.time()
-    run_tests(interactive)
-    toc = time.time()
-    print("test script finished in {:.2f}s".format(toc - tic))
+
+    if not args.keep_vm_state:
+        rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
+
+    driver = Driver(
+        args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
+    )
+
+    if args.interactive:
+        ptpython.repl.embed(driver.test_symbols(), {})
+    else:
+        tic = time.time()
+        driver.run_tests()
+        toc = time.time()
+        rootlog.info(f"test script finished in {(toc-tic):.2f}s")
diff --git a/nixos/lib/testing-python.nix b/nixos/lib/testing-python.nix
index 43b4f9b159b..1969f40edb6 100644
--- a/nixos/lib/testing-python.nix
+++ b/nixos/lib/testing-python.nix
@@ -43,7 +43,8 @@ rec {
         from pydoc import importfile
         with open('driver-symbols', 'w') as fp:
           t = importfile('${testDriverScript}')
-          test_symbols = t._test_symbols()
+          d = t.Driver([],[],"")
+          test_symbols = d.test_symbols()
           fp.write(','.join(test_symbols.keys()))
         EOF
       '';
@@ -188,14 +189,6 @@ rec {
           --set startScripts "''${vmStartScripts[*]}" \
           --set testScript "$out/test-script" \
           --set vlans '${toString vlans}'
-
-        ${lib.optionalString (testScript == "") ''
-          ln -s ${testDriver}/bin/nixos-test-driver $out/bin/nixos-run-vms
-          wrapProgram $out/bin/nixos-run-vms \
-            --set startScripts "''${vmStartScripts[*]}" \
-            --set testScript "${pkgs.writeText "start-all" "start_all(); join_all();"}" \
-            --set vlans '${toString vlans}'
-        ''}
       '');
 
   # Make a full-blown test
diff --git a/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix b/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix
index e49ceba2424..ce69b16cffa 100644
--- a/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix
+++ b/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix
@@ -8,11 +8,20 @@ let
     _file = "${networkExpr}@node-${vm}";
     imports = [ module ];
   }) (import networkExpr);
+
+  pkgs = import ../../../../.. { inherit system config; }; # also needed by the runCommand call below
+  testing = import ../../../../lib/testing-python.nix {
+    inherit system pkgs;
+  };
+
+  interactiveDriver = (testing.makeTest { inherit nodes; testScript = "start_all(); join_all();"; }).driverInteractive;
 in
 
-with import ../../../../lib/testing-python.nix {
-  inherit system;
-  pkgs = import ../../../../.. { inherit system config; };
-};
 
-(makeTest { inherit nodes; testScript = ""; }).driverInteractive
+pkgs.runCommand "nixos-build-vms" { nativeBuildInputs = [ pkgs.makeWrapper ]; } ''
+  mkdir -p $out/bin
+  ln -s ${interactiveDriver}/bin/nixos-test-driver $out/bin/nixos-test-driver
+  ln -s ${interactiveDriver}/bin/nixos-test-driver $out/bin/nixos-run-vms # left unwrapped: no --interactive added
+  wrapProgram $out/bin/nixos-test-driver \
+    --add-flags "--interactive"
+''