diff options
Diffstat (limited to 'nixos/lib/test-driver/test-driver.py')
-rw-r--r-- | nixos/lib/test-driver/test-driver.py | 184 |
1 files changed, 120 insertions, 64 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py index f4e2bb6100f..2a3e4d94b94 100644 --- a/nixos/lib/test-driver/test-driver.py +++ b/nixos/lib/test-driver/test-driver.py @@ -1,8 +1,9 @@ #! /somewhere/python3 from contextlib import contextmanager, _GeneratorContextManager from queue import Queue, Empty -from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List +from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List, Iterable from xml.sax.saxutils import XMLGenerator +from colorama import Style import queue import io import _thread @@ -20,6 +21,7 @@ import shutil import socket import subprocess import sys +import telnetlib import tempfile import time import traceback @@ -110,7 +112,6 @@ def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any] pty_master, pty_slave = pty.openpty() vde_process = subprocess.Popen( ["vde_switch", "-s", vde_socket, "--dirmode", "0700"], - bufsize=1, stdin=pty_slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -128,18 +129,18 @@ def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any] return (vlan_nr, vde_socket, vde_process, fd) -def retry(fn: Callable) -> None: +def retry(fn: Callable, timeout: int = 900) -> None: """Call the given function repeatedly, with 1 second intervals, until it returns True or a timeout is reached. """ - for _ in range(900): + for _ in range(timeout): if fn(False): return time.sleep(1) if not fn(True): - raise Exception("action timed out") + raise Exception(f"action timed out after {timeout} seconds") class Logger: @@ -152,6 +153,8 @@ class Logger: self.xml.startDocument() self.xml.startElement("logfile", attrs={}) + self._print_serial_logs = True + def close(self) -> None: self.xml.endElement("logfile") self.xml.endDocument() @@ -175,15 +178,21 @@ class Logger: self.drain_log_queue() self.log_line(message, attributes) - def enqueue(self, message: Dict[str, str]) -> None: - self.queue.put(message) + def log_serial(self, message: str, machine: str) -> None: + self.enqueue({"msg": message, "machine": machine, "type": "serial"}) + if self._print_serial_logs: + eprint(Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL) + + def enqueue(self, item: Dict[str, str]) -> None: + self.queue.put(item) def drain_log_queue(self) -> None: try: while True: item = self.queue.get_nowait() - attributes = {"machine": item["machine"], "type": "serial"} - self.log_line(self.sanitise(item["msg"]), attributes) + msg = self.sanitise(item["msg"]) + del item["msg"] + self.log_line(msg, item) except Empty: pass @@ -206,6 +215,37 @@ class Logger: self.xml.endElement("nest") +def _perform_ocr_on_screenshot( + screenshot_path: str, model_ids: Iterable[int] +) -> List[str]: + if shutil.which("tesseract") is None: + raise Exception("OCR requested but enableOCR is false") + + magick_args = ( + "-filter Catrom -density 72 -resample 300 " + + "-contrast -normalize -despeckle -type grayscale " + + "-sharpen 1 -posterize 3 -negate -gamma 100 " + + "-blur 1x65535" + ) + + tess_args = f"-c debug_file=/dev/null --psm 11" + + cmd = f"convert {magick_args} {screenshot_path} tiff:{screenshot_path}.tiff" + ret = subprocess.run(cmd, shell=True, capture_output=True) + if ret.returncode != 0: + raise Exception(f"TIFF conversion failed with exit code {ret.returncode}") + + model_results = [] + for model_id in model_ids: + cmd = f"tesseract {screenshot_path}.tiff - {tess_args} --oem {model_id}" + ret = subprocess.run(cmd, shell=True, capture_output=True) + if ret.returncode != 0: + raise Exception(f"OCR failed with exit code {ret.returncode}") + model_results.append(ret.stdout.decode("utf-8")) + + return model_results + + class Machine: def __init__(self, args: Dict[str, Any]) -> None: if "name" in args: @@ -217,7 +257,7 @@ class Machine: match = re.search("run-(.+)-vm$", cmd) if match: self.name = match.group(1) - + self.logger = args["log"] self.script = args.get("startCommand", self.create_startcommand(args)) tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir()) @@ -227,7 +267,10 @@ class Machine: os.makedirs(path, mode=0o700, exist_ok=True) return path - self.state_dir = create_dir("vm-state-{}".format(self.name)) + self.state_dir = os.path.join(tmp_dir, f"vm-state-{self.name}") + if not args.get("keepVmState", False): + self.cleanup_statedir() + os.makedirs(self.state_dir, mode=0o700, exist_ok=True) self.shared_dir = create_dir("shared-xchg") self.booted = False @@ -235,7 +278,6 @@ class Machine: self.pid: Optional[int] = None self.socket = None self.monitor: Optional[socket.socket] = None - self.logger: Logger = args["log"] self.allow_reboot = args.get("allowReboot", False) @staticmethod @@ -250,7 +292,12 @@ class Machine: net_frontend += "," + args["netFrontendArgs"] start_command = ( - "qemu-kvm -m 384 " + net_backend + " " + net_frontend + " $QEMU_OPTS " + args.get("qemuBinary", "qemu-kvm") + + " -m 384 " + + net_backend + + " " + + net_frontend + + " $QEMU_OPTS " ) if "hda" in args: @@ -275,8 +322,9 @@ class Machine: start_command += "-cdrom " + args["cdrom"] + " " if "usb" in args: + # https://github.com/qemu/qemu/blob/master/docs/usb2.txt start_command += ( - "-device piix3-usb-uhci -drive " + "-device usb-ehci -drive " + "id=usbdisk,file=" + args["usb"] + ",if=none,readonly " @@ -295,6 +343,9 @@ class Machine: def log(self, msg: str) -> None: self.logger.log(msg, {"machine": self.name}) + def log_serial(self, msg: str) -> None: + self.logger.log_serial(msg, self.name) + def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager: my_attrs = {"machine": self.name} my_attrs.update(attrs) @@ -395,7 +446,7 @@ class Machine: def execute(self, command: str) -> Tuple[int, str]: self.connect() - out_command = "( {} ); echo '|!=EOF' $?\n".format(command) + out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command) self.shell.send(out_command.encode()) output = "" @@ -410,6 +461,17 @@ class Machine: return (status_code, output) output += chunk + def shell_interact(self) -> None: + """Allows you to interact with the guest shell + + Should only be used during test development, not in the production test.""" + self.connect() + self.log("Terminal is ready (there is no prompt):") + subprocess.run( + ["socat", "READLINE", f"FD:{self.shell.fileno()}"], + pass_fds=[self.shell.fileno()], + ) + def succeed(self, *commands: str) -> str: """Execute each command and check that it succeeds.""" output = "" @@ -437,7 +499,7 @@ class Machine: output += out return output - def wait_until_succeeds(self, command: str) -> str: + def wait_until_succeeds(self, command: str, timeout: int = 900) -> str: """Wait until a command returns success and return its output. Throws an exception on timeout. """ @@ -449,7 +511,7 @@ class Machine: return status == 0 with self.nested("waiting for success: {}".format(command)): - retry(check_success) + retry(check_success, timeout) return output def wait_until_fails(self, command: str) -> str: @@ -633,47 +695,32 @@ class Machine: shutil.copy(intermediate, abs_target) def dump_tty_contents(self, tty: str) -> None: - """Debugging: Dump the contents of the TTY<n> - """ + """Debugging: Dump the contents of the TTY<n>""" self.execute("fold -w 80 /dev/vcs{} | systemd-cat".format(tty)) - def get_screen_text(self) -> str: - if shutil.which("tesseract") is None: - raise Exception("get_screen_text used but enableOCR is false") - - magick_args = ( - "-filter Catrom -density 72 -resample 300 " - + "-contrast -normalize -despeckle -type grayscale " - + "-sharpen 1 -posterize 3 -negate -gamma 100 " - + "-blur 1x65535" - ) - - tess_args = "-c debug_file=/dev/null --psm 11 --oem 2" + def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]: + with tempfile.TemporaryDirectory() as tmpdir: + screenshot_path = os.path.join(tmpdir, "ppm") + self.send_monitor_command(f"screendump {screenshot_path}") + return _perform_ocr_on_screenshot(screenshot_path, model_ids) - with self.nested("performing optical character recognition"): - with tempfile.NamedTemporaryFile() as tmpin: - self.send_monitor_command("screendump {}".format(tmpin.name)) + def get_screen_text_variants(self) -> List[str]: + return self._get_screen_text_variants([0, 1, 2]) - cmd = "convert {} {} tiff:- | tesseract - - {}".format( - magick_args, tmpin.name, tess_args - ) - ret = subprocess.run(cmd, shell=True, capture_output=True) - if ret.returncode != 0: - raise Exception( - "OCR failed with exit code {}".format(ret.returncode) - ) - - return ret.stdout.decode("utf-8") + def get_screen_text(self) -> str: + return self._get_screen_text_variants([2])[0] def wait_for_text(self, regex: str) -> None: def screen_matches(last: bool) -> bool: - text = self.get_screen_text() - matches = re.search(regex, text) is not None + variants = self.get_screen_text_variants() + for text in variants: + if re.search(regex, text) is not None: + return True - if last and not matches: - self.log("Last OCR attempt failed. Text was: {}".format(text)) + if last: + self.log("Last OCR attempt failed. Text was: {}".format(variants)) - return matches + return False with self.nested("waiting for {} to appear on screen".format(regex)): retry(screen_matches) @@ -718,6 +765,7 @@ class Machine: shell_path = os.path.join(self.state_dir, "shell") self.shell_socket = create_socket(shell_path) + display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"]) qemu_options = ( " ".join( [ @@ -727,7 +775,7 @@ class Machine: "-device virtio-serial", "-device virtconsole,chardev=shell", "-device virtio-rng-pci", - "-serial stdio" if "DISPLAY" in os.environ else "-nographic", + "-serial stdio" if display_available else "-nographic", ] ) + " " @@ -746,7 +794,6 @@ class Machine: self.process = subprocess.Popen( self.script, - bufsize=1, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, @@ -767,8 +814,7 @@ class Machine: # Ignore undecodable bytes that may occur in boot menus line = _line.decode(errors="ignore").replace("\r", "").rstrip() self.last_lines.put(line) - eprint("{} # {}".format(self.name, line)) - self.logger.enqueue({"msg": line, "machine": self.name}) + self.log_serial(line) _thread.start_new_thread(process_serial_output, ()) @@ -780,9 +826,10 @@ class Machine: self.log("QEMU running (pid {})".format(self.pid)) def cleanup_statedir(self) -> None: - self.log("delete the VM state directory") - if os.path.isfile(self.state_dir): + if os.path.isdir(self.state_dir): shutil.rmtree(self.state_dir) + self.logger.log(f"deleting VM state directory {self.state_dir}") + self.logger.log("if you want to keep the VM state, pass --keep-vm-state") def shutdown(self) -> None: if not self.booted: @@ -840,7 +887,8 @@ class Machine: retry(window_is_visible) def sleep(self, secs: int) -> None: - time.sleep(secs) + # We want to sleep in *guest* time, not *host* time. + self.succeed(f"sleep {secs}") def forward_port(self, host_port: int = 8080, guest_port: int = 80) -> None: """Forward a TCP port on the host to a TCP port on the guest. @@ -858,15 +906,13 @@ class Machine: self.send_monitor_command("set_link virtio-net-pci.1 off") def unblock(self) -> None: - """Make the machine reachable. - """ + """Make the machine reachable.""" self.send_monitor_command("set_link virtio-net-pci.1 on") def create_machine(args: Dict[str, Any]) -> Machine: global log args["log"] = log - args["redirectSerial"] = os.environ.get("USE_SERIAL", "0") == "1" return Machine(args) @@ -909,6 +955,16 @@ def run_tests() -> None: machine.execute("sync") +def serial_stdout_on() -> None: + global log + log._print_serial_logs = True + + +def serial_stdout_off() -> None: + global log + log._print_serial_logs = False + + @contextmanager def subtest(name: str) -> Iterator[None]: with log.nested(name): @@ -923,7 +979,7 @@ def subtest(name: str) -> Iterator[None]: if __name__ == "__main__": - arg_parser = argparse.ArgumentParser() + arg_parser = argparse.ArgumentParser(prog="nixos-test-driver") arg_parser.add_argument( "-K", "--keep-vm-state", @@ -939,10 +995,10 @@ if __name__ == "__main__": for nr, vde_socket, _, _ in vde_sockets: os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket - machines = [create_machine({"startCommand": s}) for s in vm_scripts] - for machine in machines: - if not cli_args.keep_vm_state: - machine.cleanup_statedir() + machines = [ + create_machine({"startCommand": s, "keepVmState": cli_args.keep_vm_state}) + for s in vm_scripts + ] machine_eval = [ "{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines) ] |