From b0fc9da879812e47c1ed3438fb0fd51db00a3494 Mon Sep 17 00:00:00 2001 From: David Arnold Date: Sat, 12 Jun 2021 17:47:25 -0500 Subject: nixos/test/test-driver: Class-ify the test driver This commit encapsulates the involved domain into classes and defines explicit and typed arguments where untyped dicts where used. It preserves backwards compatibility through legacy wrappers. --- nixos/lib/test-driver/test-driver.py | 804 +++++++++++++-------- nixos/lib/testing-python.nix | 11 +- .../installer/tools/nixos-build-vms/build-vms.nix | 19 +- 3 files changed, 527 insertions(+), 307 deletions(-) diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py index f8502188bde..fdc440a896a 100755 --- a/nixos/lib/test-driver/test-driver.py +++ b/nixos/lib/test-driver/test-driver.py @@ -21,7 +21,6 @@ import shutil import socket import subprocess import sys -import telnetlib import tempfile import time import unicodedata @@ -89,55 +88,6 @@ CHAR_TO_KEY = { ")": "shift-0x0B", } -global log, machines, test_script - - -def eprint(*args: object, **kwargs: Any) -> None: - print(*args, file=sys.stderr, **kwargs) - - -def make_command(args: list) -> str: - return " ".join(map(shlex.quote, (map(str, args)))) - - -def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]: - log.log("starting VDE switch for network {}".format(vlan_nr)) - vde_socket = tempfile.mkdtemp( - prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr) - ) - pty_master, pty_slave = pty.openpty() - vde_process = subprocess.Popen( - ["vde_switch", "-s", vde_socket, "--dirmode", "0700"], - stdin=pty_slave, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=False, - ) - fd = os.fdopen(pty_master, "w") - fd.write("version\n") - # TODO: perl version checks if this can be read from - # an if not, dies. we could hang here forever. Fix it. - assert vde_process.stdout is not None - vde_process.stdout.readline() - if not os.path.exists(os.path.join(vde_socket, "ctl")): - raise Exception("cannot start vde_switch") - - return (vlan_nr, vde_socket, vde_process, fd) - - -def retry(fn: Callable, timeout: int = 900) -> None: - """Call the given function repeatedly, with 1 second intervals, - until it returns True or a timeout is reached. - """ - - for _ in range(timeout): - if fn(False): - return - time.sleep(1) - - if not fn(True): - raise Exception(f"action timed out after {timeout} seconds") - class Logger: def __init__(self) -> None: @@ -151,6 +101,10 @@ class Logger: self._print_serial_logs = True + @staticmethod + def _eprint(*args: object, **kwargs: Any) -> None: + print(*args, file=sys.stderr, **kwargs) + def close(self) -> None: self.xml.endElement("logfile") self.xml.endDocument() @@ -169,15 +123,27 @@ class Logger: self.xml.characters(message) self.xml.endElement("line") + def info(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def warning(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def error(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + sys.exit(1) + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: - eprint(self.maybe_prefix(message, attributes)) + self._eprint(self.maybe_prefix(message, attributes)) self.drain_log_queue() self.log_line(message, attributes) def log_serial(self, message: str, machine: str) -> None: self.enqueue({"msg": message, "machine": machine, "type": "serial"}) if self._print_serial_logs: - eprint(Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL) + self._eprint( + Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL + ) def enqueue(self, item: Dict[str, str]) -> None: self.queue.put(item) @@ -194,7 +160,7 @@ class Logger: @contextmanager def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: - eprint(self.maybe_prefix(message, attributes)) + self._eprint(self.maybe_prefix(message, attributes)) self.xml.startElement("nest", attrs={}) self.xml.startElement("head", attributes) @@ -211,6 +177,27 @@ class Logger: self.xml.endElement("nest") +rootlog = Logger() + + +def make_command(args: list) -> str: + return " ".join(map(shlex.quote, (map(str, args)))) + + +def retry(fn: Callable, timeout: int = 900) -> None: + """Call the given function repeatedly, with 1 second intervals, + until it returns True or a timeout is reached. + """ + + for _ in range(timeout): + if fn(False): + return + time.sleep(1) + + if not fn(True): + raise Exception(f"action timed out after {timeout} seconds") + + def _perform_ocr_on_screenshot( screenshot_path: str, model_ids: Iterable[int] ) -> List[str]: @@ -242,113 +229,256 @@ def _perform_ocr_on_screenshot( return model_results -class Machine: - def __repr__(self) -> str: - return f"" - - def __init__(self, args: Dict[str, Any]) -> None: - if "name" in args: - self.name = args["name"] - else: - self.name = "machine" - cmd = args.get("startCommand", None) - if cmd: - match = re.search("run-(.+)-vm$", cmd) - if match: - self.name = match.group(1) - self.logger = args["log"] - self.script = args.get("startCommand", self.create_startcommand(args)) - - tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir()) - - def create_dir(name: str) -> str: - path = os.path.join(tmp_dir, name) - os.makedirs(path, mode=0o700, exist_ok=True) - return path +class StartCommand: + """The Base Start Command knows how to append the necesary + runtime qemu options as determined by a particular test driver + run. Any such start command is expected to happily receive and + append additional qemu args. + """ - self.state_dir = os.path.join(tmp_dir, f"vm-state-{self.name}") - if not args.get("keepVmState", False): - self.cleanup_statedir() - os.makedirs(self.state_dir, mode=0o700, exist_ok=True) - self.shared_dir = create_dir("shared-xchg") + _cmd: str - self.booted = False - self.connected = False - self.pid: Optional[int] = None - self.socket = None - self.monitor: Optional[socket.socket] = None - self.allow_reboot = args.get("allowReboot", False) + def cmd( + self, + monitor_socket_path: pathlib.Path, + shell_socket_path: pathlib.Path, + allow_reboot: bool = False, # TODO: unused, legacy? + ) -> str: + display_opts = "" + display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"]) + if display_available: + display_opts += " -nographic" + + # qemu options + qemu_opts = "" + qemu_opts += ( + "" + if allow_reboot + else " -no-reboot" + " -device virtio-serial" + " -device virtconsole,chardev=shell" + " -device virtio-rng-pci" + " -serial stdio" + ) + # TODO: qemu script already catpures this env variable, legacy? + qemu_opts += " " + os.environ.get("QEMU_OPTS", "") + + return ( + f"{self._cmd}" + f" -monitor unix:{monitor_socket_path}" + f" -chardev socket,id=shell,path={shell_socket_path}" + f"{qemu_opts}" + f"{display_opts}" + ) @staticmethod - def create_startcommand(args: Dict[str, str]) -> str: - net_backend = "-netdev user,id=net0" - net_frontend = "-device virtio-net-pci,netdev=net0" + def build_environment( + state_dir: pathlib.Path, + shared_dir: pathlib.Path, + ) -> dict: + # We make a copy to not update the current environment + env = dict(os.environ) + env.update( + { + "TMPDIR": str(state_dir), + "SHARED_DIR": str(shared_dir), + "USE_TMPDIR": "1", + } + ) + return env + + def run( + self, + state_dir: pathlib.Path, + shared_dir: pathlib.Path, + monitor_socket_path: pathlib.Path, + shell_socket_path: pathlib.Path, + ) -> subprocess.Popen: + return subprocess.Popen( + self.cmd(monitor_socket_path, shell_socket_path), + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + cwd=state_dir, + env=self.build_environment(state_dir, shared_dir), + ) + - if "netBackendArgs" in args: - net_backend += "," + args["netBackendArgs"] +class NixStartScript(StartCommand): + """A start script from nixos/modules/virtualiation/qemu-vm.nix + that also satisfies the requirement of the BaseStartCommand. + These Nix commands have the particular charactersitic that the + machine name can be extracted out of them via a regex match. + (Admittedly a _very_ implicit contract, evtl. TODO fix) + """ - if "netFrontendArgs" in args: - net_frontend += "," + args["netFrontendArgs"] + def __init__(self, script: str): + self._cmd = script - start_command = ( - args.get("qemuBinary", "qemu-kvm") - + " -m 384 " - + net_backend - + " " - + net_frontend - + " $QEMU_OPTS " - ) + @property + def machine_name(self) -> str: + match = re.search("run-(.+)-vm$", self._cmd) + name = "machine" + if match: + name = match.group(1) + return name - if "hda" in args: - hda_path = os.path.abspath(args["hda"]) - if args.get("hdaInterface", "") == "scsi": - start_command += ( - "-drive id=hda,file=" - + hda_path - + ",werror=report,if=none " - + "-device scsi-hd,drive=hda " + +class LegacyStartCommand(StartCommand): + """Used in some places to create an ad-hoc machine instead of + using nix test instrumentation + module system for that purpose. + Legacy. + """ + + def __init__( + self, + netBackendArgs: Optional[str] = None, + netFrontendArgs: Optional[str] = None, + hda: Optional[Tuple[pathlib.Path, str]] = None, + cdrom: Optional[str] = None, + usb: Optional[str] = None, + bios: Optional[str] = None, + qemuFlags: Optional[str] = None, + ): + self._cmd = "qemu-kvm -m 384" + + # networking + net_backend = "-netdev user,id=net0" + net_frontend = "-device virtio-net-pci,netdev=net0" + if netBackendArgs is not None: + net_backend += "," + netBackendArgs + if netFrontendArgs is not None: + net_frontend += "," + netFrontendArgs + self._cmd += f" {net_backend} {net_frontend}" + + # hda + hda_cmd = "" + if hda is not None: + hda_path = hda[0].resolve() + hda_interface = hda[1] + if hda_interface == "scsi": + hda_cmd += ( + f" -drive id=hda,file={hda_path},werror=report,if=none" + " -device scsi-hd,drive=hda" ) else: - start_command += ( - "-drive file=" - + hda_path - + ",if=" - + args["hdaInterface"] - + ",werror=report " - ) + hda_cmd += f" -drive file={hda_path},if={hda_interface},werror=report" + self._cmd += hda_cmd - if "cdrom" in args: - start_command += "-cdrom " + args["cdrom"] + " " + # cdrom + if cdrom is not None: + self._cmd += f" -cdrom {cdrom}" - if "usb" in args: + # usb + usb_cmd = "" + if usb is not None: # https://github.com/qemu/qemu/blob/master/docs/usb2.txt - start_command += ( - "-device usb-ehci -drive " - + "id=usbdisk,file=" - + args["usb"] - + ",if=none,readonly " - + "-device usb-storage,drive=usbdisk " + usb_cmd += ( + " -device usb-ehci" + f" -drive id=usbdisk,file={usb},if=none,readonly" + " -device usb-storage,drive=usbdisk " ) - if "bios" in args: - start_command += "-bios " + args["bios"] + " " + self._cmd += usb_cmd + + # bios + if bios is not None: + self._cmd += f" -bios {bios}" + + # qemu flags + if qemuFlags is not None: + self._cmd += f" {qemuFlags}" + + +class Machine: + """A handle to the machine with this name, that also knows how to manage + the machine lifecycle with the help of a start script / command.""" + + name: str + tmp_dir: pathlib.Path + shared_dir: pathlib.Path + state_dir: pathlib.Path + monitor_path: pathlib.Path + shell_path: pathlib.Path + + start_command: StartCommand + keep_vm_state: bool + allow_reboot: bool + + process: Optional[subprocess.Popen] = None + pid: Optional[int] = None + monitor: Optional[socket.socket] = None + shell: Optional[socket.socket] = None + + booted: bool = False + connected: bool = False + # Store last serial console lines for use + # of wait_for_console_text + last_lines: Queue = Queue() - start_command += args.get("qemuFlags", "") + def __repr__(self) -> str: + return f"" + + def __init__( + self, + tmp_dir: pathlib.Path, + start_command: StartCommand, + name: str = "machine", + keep_vm_state: bool = False, + allow_reboot: bool = False, + ) -> None: + self.tmp_dir = tmp_dir + self.keep_vm_state = keep_vm_state + self.allow_reboot = allow_reboot + self.name = name + self.start_command = start_command + + # set up directories + self.shared_dir = self.tmp_dir / "shared-xchg" + self.shared_dir.mkdir(mode=0o700, exist_ok=True) + + self.state_dir = self.tmp_dir / f"vm-state-{self.name}" + self.monitor_path = self.state_dir / "monitor" + self.shell_path = self.state_dir / "shell" + if (not self.keep_vm_state) and self.state_dir.exists(): + self.cleanup_statedir() + self.state_dir.mkdir(mode=0o700, exist_ok=True) - return start_command + @staticmethod + def create_startcommand(args: Dict[str, str]) -> StartCommand: + rootlog.warning( + "Using legacy create_startcommand()," + "please use proper nix test vm instrumentation, instead" + "to generate the appropriate nixos test vm qemu startup script" + ) + hda = None + if args.get("hda"): + hda_arg: str = args.get("hda", "") + hda_arg_path: pathlib.Path = pathlib.Path(hda_arg) + hda = (hda_arg_path, args.get("hdaInterface", "")) + return LegacyStartCommand( + netBackendArgs=args.get("netBackendArgs"), + netFrontendArgs=args.get("netFrontendArgs"), + hda=hda, + cdrom=args.get("cdrom"), + usb=args.get("usb"), + bios=args.get("bios"), + qemuFlags=args.get("qemuFlags"), + ) def is_up(self) -> bool: return self.booted and self.connected def log(self, msg: str) -> None: - self.logger.log(msg, {"machine": self.name}) + rootlog.log(msg, {"machine": self.name}) def log_serial(self, msg: str) -> None: - self.logger.log_serial(msg, self.name) + rootlog.log_serial(msg, self.name) def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager: my_attrs = {"machine": self.name} my_attrs.update(attrs) - return self.logger.nested(msg, my_attrs) + return rootlog.nested(msg, my_attrs) def wait_for_monitor_prompt(self) -> str: assert self.monitor is not None @@ -446,6 +576,7 @@ class Machine: self.connect() out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command) + assert self.shell self.shell.send(out_command.encode()) output = "" @@ -466,6 +597,8 @@ class Machine: Should only be used during test development, not in the production test.""" self.connect() self.log("Terminal is ready (there is no prompt):") + + assert self.shell subprocess.run( ["socat", "READLINE", f"FD:{self.shell.fileno()}"], pass_fds=[self.shell.fileno()], @@ -534,6 +667,7 @@ class Machine: with self.nested("waiting for the VM to power off"): sys.stdout.flush() + assert self.process self.process.wait() self.pid = None @@ -611,6 +745,8 @@ class Machine: with self.nested("waiting for the VM to finish booting"): self.start() + assert self.shell + tic = time.time() self.shell.recv(1024) # TODO: Timeout @@ -750,65 +886,35 @@ class Machine: self.log("starting vm") - def create_socket(path: str) -> socket.socket: - if os.path.exists(path): - os.unlink(path) + def clear(path: pathlib.Path) -> pathlib.Path: + if path.exists(): + path.unlink() + return path + + def create_socket(path: pathlib.Path) -> socket.socket: s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) - s.bind(path) + s.bind(str(path)) s.listen(1) return s - monitor_path = os.path.join(self.state_dir, "monitor") - self.monitor_socket = create_socket(monitor_path) - - shell_path = os.path.join(self.state_dir, "shell") - self.shell_socket = create_socket(shell_path) - - display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"]) - qemu_options = ( - " ".join( - [ - "" if self.allow_reboot else "-no-reboot", - "-monitor unix:{}".format(monitor_path), - "-chardev socket,id=shell,path={}".format(shell_path), - "-device virtio-serial", - "-device virtconsole,chardev=shell", - "-device virtio-rng-pci", - "-serial stdio" if display_available else "-nographic", - ] - ) - + " " - + os.environ.get("QEMU_OPTS", "") + monitor_socket = create_socket(clear(self.monitor_path)) + shell_socket = create_socket(clear(self.shell_path)) + self.process = self.start_command.run( + self.state_dir, + self.shared_dir, + self.monitor_path, + self.shell_path, ) - - environment = dict(os.environ) - environment.update( - { - "TMPDIR": self.state_dir, - "SHARED_DIR": self.shared_dir, - "USE_TMPDIR": "1", - "QEMU_OPTS": qemu_options, - } - ) - - self.process = subprocess.Popen( - self.script, - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - shell=True, - cwd=self.state_dir, - env=environment, - ) - self.monitor, _ = self.monitor_socket.accept() - self.shell, _ = self.shell_socket.accept() + self.monitor, _ = monitor_socket.accept() + self.shell, _ = shell_socket.accept() # Store last serial console lines for use # of wait_for_console_text self.last_lines: Queue = Queue() def process_serial_output() -> None: - assert self.process.stdout is not None + assert self.process + assert self.process.stdout for _line in self.process.stdout: # Ignore undecodable bytes that may occur in boot menus line = _line.decode(errors="ignore").replace("\r", "").rstrip() @@ -825,15 +931,15 @@ class Machine: self.log("QEMU running (pid {})".format(self.pid)) def cleanup_statedir(self) -> None: - if os.path.isdir(self.state_dir): - shutil.rmtree(self.state_dir) - self.logger.log(f"deleting VM state directory {self.state_dir}") - self.logger.log("if you want to keep the VM state, pass --keep-vm-state") + shutil.rmtree(self.state_dir) + rootlog.log(f"deleting VM state directory {self.state_dir}") + rootlog.log("if you want to keep the VM state, pass --keep-vm-state") def shutdown(self) -> None: if not self.booted: return + assert self.shell self.shell.send("poweroff\n".encode()) self.wait_for_shutdown() @@ -908,41 +1014,225 @@ class Machine: """Make the machine reachable.""" self.send_monitor_command("set_link virtio-net-pci.1 on") + def release(self) -> None: + if self.pid is None: + return + rootlog.info(f"kill machine (pid {self.pid})") + assert self.process + assert self.shell + assert self.monitor + self.process.terminate() + self.shell.close() + self.monitor.close() + + +class VLan: + """A handle to the vlan with this number, that also knows how to manage + it's lifecycle. + """ -def create_machine(args: Dict[str, Any]) -> Machine: - args["log"] = log - return Machine(args) + nr: int + socket_dir: pathlib.Path + process: Optional[subprocess.Popen] + pid: Optional[int] + fd: Optional[io.TextIOBase] -def start_all() -> None: - with log.nested("starting all VMs"): - for machine in machines: - machine.start() + def __repr__(self) -> str: + return f"" + def __init__(self, nr: int, tmp_dir: pathlib.Path): + self.nr = nr + self.socket_dir = tmp_dir / f"vde{self.nr}.ctl" -def join_all() -> None: - with log.nested("waiting for all VMs to finish"): - for machine in machines: - machine.wait_for_shutdown() + # TODO: don't side-effect environment here + os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir) + def start(self) -> None: -def run_tests(interactive: bool = False) -> None: - if interactive: - ptpython.repl.embed(test_symbols(), {}) - else: - test_script() + rootlog.info("start vlan") + pty_master, pty_slave = pty.openpty() + + self.process = subprocess.Popen( + ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"], + stdin=pty_slave, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + ) + self.pid = self.process.pid + self.fd = os.fdopen(pty_master, "w") + self.fd.write("version\n") + + # TODO: perl version checks if this can be read from + # an if not, dies. we could hang here forever. Fix it. + assert self.process.stdout is not None + self.process.stdout.readline() + if not (self.socket_dir / "ctl").exists(): + rootlog.error("cannot start vde_switch") + + rootlog.info(f"running vlan (pid {self.pid})") + + def release(self) -> None: + if self.pid is None: + return + rootlog.info(f"kill vlan (pid {self.pid})") + assert self.fd + assert self.process + self.fd.close() + self.process.terminate() + + +class Driver: + """A handle to the driver that sets up the environment + and runs the tests""" + + tests: str + vlans: List[VLan] + machines: List[Machine] + + def __init__( + self, + start_scripts: List[str], + vlans: List[int], + tests: str, + keep_vm_state: bool = False, + ): + self.tests = tests + + tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + tmp_dir.mkdir(mode=0o700, exist_ok=True) + + self.vlans = [VLan(nr, tmp_dir) for nr in vlans] + with rootlog.nested("start all VLans"): + for vlan in self.vlans: + vlan.start() + + def cmd(scripts: List[str]) -> Iterator[NixStartScript]: + for s in scripts: + yield NixStartScript(s) + + self.machines = [ + Machine( + start_command=cmd, + keep_vm_state=keep_vm_state, + name=cmd.machine_name, + tmp_dir=tmp_dir, + ) + for cmd in cmd(start_scripts) + ] + + @atexit.register + def clean_up() -> None: + with rootlog.nested("clean up"): + for machine in self.machines: + machine.release() + for vlan in self.vlans: + vlan.release() + + def subtest(self, name: str) -> Iterator[None]: + """Group logs under a given test name""" + with rootlog.nested(name): + try: + yield + return True + except: + rootlog.error(f'Test "{name}" failed with error:') + raise + + def test_symbols(self) -> Dict[str, Any]: + @contextmanager + def subtest(name: str) -> Iterator[None]: + return self.subtest(name) + + general_symbols = dict( + start_all=self.start_all, + test_script=self.test_script, + machines=self.machines, + vlans=self.vlans, + driver=self, + log=rootlog, + os=os, + create_machine=self.create_machine, + subtest=subtest, + run_tests=self.run_tests, + join_all=self.join_all, + retry=retry, + serial_stdout_off=self.serial_stdout_off, + serial_stdout_on=self.serial_stdout_on, + Machine=Machine, # for typing + ) + machine_symbols = { + m.name: self.machines[idx] for idx, m in enumerate(self.machines) + } + vlan_symbols = { + f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans) + } + print( + "additionally exposed symbols:\n " + + ", ".join(map(lambda m: m.name, self.machines)) + + ",\n " + + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans)) + + ",\n " + + ", ".join(list(general_symbols.keys())) + ) + return {**general_symbols, **machine_symbols, **vlan_symbols} + + def test_script(self) -> None: + """Run the test script""" + with rootlog.nested("run the VM test script"): + symbols = self.test_symbols() # call eagerly + exec(self.tests, symbols, None) + + def run_tests(self) -> None: + """Run the test script (for non-interactive test runs)""" + self.test_script() # TODO: Collect coverage data - for machine in machines: + for machine in self.machines: if machine.is_up(): machine.execute("sync") + def start_all(self) -> None: + """Start all machines""" + with rootlog.nested("start all VMs"): + for machine in self.machines: + machine.start() + + def join_all(self) -> None: + """Wait for all machines to shut down""" + with rootlog.nested("wait for all VMs to finish"): + for machine in self.machines: + machine.wait_for_shutdown() + + def create_machine(self, args: Dict[str, Any]) -> Machine: + rootlog.warning( + "Using legacy create_machine(), please instantiate the" + "Machine class directly, instead" + ) + tmp_dir = pathlib.Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + tmp_dir.mkdir(mode=0o700, exist_ok=True) -def serial_stdout_on() -> None: - log._print_serial_logs = True + if args.get("startCommand"): + start_command: str = args.get("startCommand", "") + cmd = NixStartScript(start_command) + name = args.get("name", cmd.machine_name) + else: + cmd = Machine.create_startcommand(args) # type: ignore + name = args.get("name", "machine") + + return Machine( + tmp_dir=tmp_dir, + start_command=cmd, + name=name, + keep_vm_state=args.get("keep_vm_state", False), + allow_reboot=args.get("allow_reboot", False), + ) + def serial_stdout_on(self) -> None: + rootlog._print_serial_logs = True -def serial_stdout_off() -> None: - log._print_serial_logs = False + def serial_stdout_off(self) -> None: + rootlog._print_serial_logs = False class EnvDefault(argparse.Action): @@ -970,52 +1260,6 @@ class EnvDefault(argparse.Action): setattr(namespace, self.dest, values) -@contextmanager -def subtest(name: str) -> Iterator[None]: - with log.nested(name): - try: - yield - return True - except Exception as e: - log.log(f'Test "{name}" failed with error: "{e}"') - raise e - - return False - - -def _test_symbols() -> Dict[str, Any]: - general_symbols = dict( - start_all=start_all, - test_script=globals().get("test_script"), # same - machines=globals().get("machines"), # without being initialized - log=globals().get("log"), # extracting those symbol keys - os=os, - create_machine=create_machine, - subtest=subtest, - run_tests=run_tests, - join_all=join_all, - retry=retry, - serial_stdout_off=serial_stdout_off, - serial_stdout_on=serial_stdout_on, - Machine=Machine, # for typing - ) - return general_symbols - - -def test_symbols() -> Dict[str, Any]: - - general_symbols = _test_symbols() - - machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)} - print( - "additionally exposed symbols:\n " - + ", ".join(map(lambda m: m.name, machines)) - + ",\n " - + ", ".join(list(general_symbols.keys())) - ) - return {**general_symbols, **machine_symbols} - - if __name__ == "__main__": arg_parser = argparse.ArgumentParser(prog="nixos-test-driver") arg_parser.add_argument( @@ -1055,44 +1299,18 @@ if __name__ == "__main__": ) args = arg_parser.parse_args() - testscript = pathlib.Path(args.testscript).read_text() - - global log, machines, test_script - - log = Logger() - - vde_sockets = [create_vlan(v) for v in args.vlans] - for nr, vde_socket, _, _ in vde_sockets: - os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket - - machines = [ - create_machine({"startCommand": s, "keepVmState": args.keep_vm_state}) - for s in args.start_scripts - ] - machine_eval = [ - "{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines) - ] - exec("\n".join(machine_eval)) - - @atexit.register - def clean_up() -> None: - with log.nested("cleaning up"): - for machine in machines: - if machine.pid is None: - continue - log.log("killing {} (pid {})".format(machine.name, machine.pid)) - machine.process.kill() - for _, _, process, _ in vde_sockets: - process.terminate() - log.close() - - def test_script() -> None: - with log.nested("running the VM test script"): - symbols = test_symbols() # call eagerly - exec(testscript, symbols, None) - - interactive = args.interactive or (not bool(testscript)) - tic = time.time() - run_tests(interactive) - toc = time.time() - print("test script finished in {:.2f}s".format(toc - tic)) + + if not args.keep_vm_state: + rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state") + + driver = Driver( + args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state + ) + + if args.interactive: + ptpython.repl.embed(driver.test_symbols(), {}) + else: + tic = time.time() + driver.run_tests() + toc = time.time() + rootlog.info(f"test script finished in {(toc-tic):.2f}s") diff --git a/nixos/lib/testing-python.nix b/nixos/lib/testing-python.nix index 43b4f9b159b..1969f40edb6 100644 --- a/nixos/lib/testing-python.nix +++ b/nixos/lib/testing-python.nix @@ -43,7 +43,8 @@ rec { from pydoc import importfile with open('driver-symbols', 'w') as fp: t = importfile('${testDriverScript}') - test_symbols = t._test_symbols() + d = t.Driver([],[],"") + test_symbols = d.test_symbols() fp.write(','.join(test_symbols.keys())) EOF ''; @@ -188,14 +189,6 @@ rec { --set startScripts "''${vmStartScripts[*]}" \ --set testScript "$out/test-script" \ --set vlans '${toString vlans}' - - ${lib.optionalString (testScript == "") '' - ln -s ${testDriver}/bin/nixos-test-driver $out/bin/nixos-run-vms - wrapProgram $out/bin/nixos-run-vms \ - --set startScripts "''${vmStartScripts[*]}" \ - --set testScript "${pkgs.writeText "start-all" "start_all(); join_all();"}" \ - --set vlans '${toString vlans}' - ''} ''); # Make a full-blown test diff --git a/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix b/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix index e49ceba2424..ce69b16cffa 100644 --- a/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix +++ b/nixos/modules/installer/tools/nixos-build-vms/build-vms.nix @@ -8,11 +8,20 @@ let _file = "${networkExpr}@node-${vm}"; imports = [ module ]; }) (import networkExpr); + + testing = import ../../../../lib/testing-python.nix { + inherit system; + pkgs = import ../../../../.. { inherit system config; }; + }; + + interactiveDriver = (testing.makeTest { inherit nodes; testScript = "start_all(); join_all();"; }).driverInteractive; in -with import ../../../../lib/testing-python.nix { - inherit system; - pkgs = import ../../../../.. { inherit system config; }; -}; -(makeTest { inherit nodes; testScript = ""; }).driverInteractive +pkgs.runCommand "nixos-build-vms" '' + mkdir -p $out/bin + ln -s ${interactiveDriver}/bin/nixos-test-driver $out/bin/nixos-test-driver + ln -s ${interactiveDriver}/bin/nixos-test-driver $out/bin/nixos-run-vms + wrapProgram $out/bin/nixos-test-driver \ + --add-flags "--interactive" +'' -- cgit 1.4.1