summary refs log tree commit diff
path: root/nixos/lib/test-driver/test-driver.py
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib/test-driver/test-driver.py')
-rw-r--r--nixos/lib/test-driver/test-driver.py184
1 files changed, 120 insertions, 64 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index f4e2bb6100f..2a3e4d94b94 100644
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -1,8 +1,9 @@
 #! /somewhere/python3
 from contextlib import contextmanager, _GeneratorContextManager
 from queue import Queue, Empty
-from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List
+from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List, Iterable
 from xml.sax.saxutils import XMLGenerator
+from colorama import Style
 import queue
 import io
 import _thread
@@ -20,6 +21,7 @@ import shutil
 import socket
 import subprocess
 import sys
+import telnetlib
 import tempfile
 import time
 import traceback
@@ -110,7 +112,6 @@ def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]
     pty_master, pty_slave = pty.openpty()
     vde_process = subprocess.Popen(
         ["vde_switch", "-s", vde_socket, "--dirmode", "0700"],
-        bufsize=1,
         stdin=pty_slave,
         stdout=subprocess.PIPE,
         stderr=subprocess.PIPE,
@@ -128,18 +129,18 @@ def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]
     return (vlan_nr, vde_socket, vde_process, fd)
 
 
-def retry(fn: Callable) -> None:
+def retry(fn: Callable, timeout: int = 900) -> None:
     """Call the given function repeatedly, with 1 second intervals,
     until it returns True or a timeout is reached.
     """
 
-    for _ in range(900):
+    for _ in range(timeout):
         if fn(False):
             return
         time.sleep(1)
 
     if not fn(True):
-        raise Exception("action timed out")
+        raise Exception(f"action timed out after {timeout} seconds")
 
 
 class Logger:
@@ -152,6 +153,8 @@ class Logger:
         self.xml.startDocument()
         self.xml.startElement("logfile", attrs={})
 
+        self._print_serial_logs = True
+
     def close(self) -> None:
         self.xml.endElement("logfile")
         self.xml.endDocument()
@@ -175,15 +178,21 @@ class Logger:
         self.drain_log_queue()
         self.log_line(message, attributes)
 
-    def enqueue(self, message: Dict[str, str]) -> None:
-        self.queue.put(message)
+    def log_serial(self, message: str, machine: str) -> None:
+        self.enqueue({"msg": message, "machine": machine, "type": "serial"})
+        if self._print_serial_logs:
+            eprint(Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL)
+
+    def enqueue(self, item: Dict[str, str]) -> None:
+        self.queue.put(item)
 
     def drain_log_queue(self) -> None:
         try:
             while True:
                 item = self.queue.get_nowait()
-                attributes = {"machine": item["machine"], "type": "serial"}
-                self.log_line(self.sanitise(item["msg"]), attributes)
+                msg = self.sanitise(item["msg"])
+                del item["msg"]
+                self.log_line(msg, item)
         except Empty:
             pass
 
@@ -206,6 +215,37 @@ class Logger:
         self.xml.endElement("nest")
 
 
+def _perform_ocr_on_screenshot(
+    screenshot_path: str, model_ids: Iterable[int]
+) -> List[str]:
+    if shutil.which("tesseract") is None:
+        raise Exception("OCR requested but enableOCR is false")
+
+    magick_args = (
+        "-filter Catrom -density 72 -resample 300 "
+        + "-contrast -normalize -despeckle -type grayscale "
+        + "-sharpen 1 -posterize 3 -negate -gamma 100 "
+        + "-blur 1x65535"
+    )
+
+    tess_args = f"-c debug_file=/dev/null --psm 11"
+
+    cmd = f"convert {magick_args} {screenshot_path} tiff:{screenshot_path}.tiff"
+    ret = subprocess.run(cmd, shell=True, capture_output=True)
+    if ret.returncode != 0:
+        raise Exception(f"TIFF conversion failed with exit code {ret.returncode}")
+
+    model_results = []
+    for model_id in model_ids:
+        cmd = f"tesseract {screenshot_path}.tiff - {tess_args} --oem {model_id}"
+        ret = subprocess.run(cmd, shell=True, capture_output=True)
+        if ret.returncode != 0:
+            raise Exception(f"OCR failed with exit code {ret.returncode}")
+        model_results.append(ret.stdout.decode("utf-8"))
+
+    return model_results
+
+
 class Machine:
     def __init__(self, args: Dict[str, Any]) -> None:
         if "name" in args:
@@ -217,7 +257,7 @@ class Machine:
                 match = re.search("run-(.+)-vm$", cmd)
                 if match:
                     self.name = match.group(1)
-
+        self.logger = args["log"]
         self.script = args.get("startCommand", self.create_startcommand(args))
 
         tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
@@ -227,7 +267,10 @@ class Machine:
             os.makedirs(path, mode=0o700, exist_ok=True)
             return path
 
-        self.state_dir = create_dir("vm-state-{}".format(self.name))
+        self.state_dir = os.path.join(tmp_dir, f"vm-state-{self.name}")
+        if not args.get("keepVmState", False):
+            self.cleanup_statedir()
+        os.makedirs(self.state_dir, mode=0o700, exist_ok=True)
         self.shared_dir = create_dir("shared-xchg")
 
         self.booted = False
@@ -235,7 +278,6 @@ class Machine:
         self.pid: Optional[int] = None
         self.socket = None
         self.monitor: Optional[socket.socket] = None
-        self.logger: Logger = args["log"]
         self.allow_reboot = args.get("allowReboot", False)
 
     @staticmethod
@@ -250,7 +292,12 @@ class Machine:
             net_frontend += "," + args["netFrontendArgs"]
 
         start_command = (
-            "qemu-kvm -m 384 " + net_backend + " " + net_frontend + " $QEMU_OPTS "
+            args.get("qemuBinary", "qemu-kvm")
+            + " -m 384 "
+            + net_backend
+            + " "
+            + net_frontend
+            + " $QEMU_OPTS "
         )
 
         if "hda" in args:
@@ -275,8 +322,9 @@ class Machine:
             start_command += "-cdrom " + args["cdrom"] + " "
 
         if "usb" in args:
+            # https://github.com/qemu/qemu/blob/master/docs/usb2.txt
             start_command += (
-                "-device piix3-usb-uhci -drive "
+                "-device usb-ehci -drive "
                 + "id=usbdisk,file="
                 + args["usb"]
                 + ",if=none,readonly "
@@ -295,6 +343,9 @@ class Machine:
     def log(self, msg: str) -> None:
         self.logger.log(msg, {"machine": self.name})
 
+    def log_serial(self, msg: str) -> None:
+        self.logger.log_serial(msg, self.name)
+
     def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
         my_attrs = {"machine": self.name}
         my_attrs.update(attrs)
@@ -395,7 +446,7 @@ class Machine:
     def execute(self, command: str) -> Tuple[int, str]:
         self.connect()
 
-        out_command = "( {} ); echo '|!=EOF' $?\n".format(command)
+        out_command = "( set -euo pipefail; {} ); echo '|!=EOF' $?\n".format(command)
         self.shell.send(out_command.encode())
 
         output = ""
@@ -410,6 +461,17 @@ class Machine:
                 return (status_code, output)
             output += chunk
 
+    def shell_interact(self) -> None:
+        """Allows you to interact with the guest shell
+
+        Should only be used during test development, not in the production test."""
+        self.connect()
+        self.log("Terminal is ready (there is no prompt):")
+        subprocess.run(
+            ["socat", "READLINE", f"FD:{self.shell.fileno()}"],
+            pass_fds=[self.shell.fileno()],
+        )
+
     def succeed(self, *commands: str) -> str:
         """Execute each command and check that it succeeds."""
         output = ""
@@ -437,7 +499,7 @@ class Machine:
                 output += out
         return output
 
-    def wait_until_succeeds(self, command: str) -> str:
+    def wait_until_succeeds(self, command: str, timeout: int = 900) -> str:
         """Wait until a command returns success and return its output.
         Throws an exception on timeout.
         """
@@ -449,7 +511,7 @@ class Machine:
             return status == 0
 
         with self.nested("waiting for success: {}".format(command)):
-            retry(check_success)
+            retry(check_success, timeout)
             return output
 
     def wait_until_fails(self, command: str) -> str:
@@ -633,47 +695,32 @@ class Machine:
                 shutil.copy(intermediate, abs_target)
 
     def dump_tty_contents(self, tty: str) -> None:
-        """Debugging: Dump the contents of the TTY<n>
-        """
+        """Debugging: Dump the contents of the TTY<n>"""
         self.execute("fold -w 80 /dev/vcs{} | systemd-cat".format(tty))
 
-    def get_screen_text(self) -> str:
-        if shutil.which("tesseract") is None:
-            raise Exception("get_screen_text used but enableOCR is false")
-
-        magick_args = (
-            "-filter Catrom -density 72 -resample 300 "
-            + "-contrast -normalize -despeckle -type grayscale "
-            + "-sharpen 1 -posterize 3 -negate -gamma 100 "
-            + "-blur 1x65535"
-        )
-
-        tess_args = "-c debug_file=/dev/null --psm 11 --oem 2"
+    def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            screenshot_path = os.path.join(tmpdir, "ppm")
+            self.send_monitor_command(f"screendump {screenshot_path}")
+            return _perform_ocr_on_screenshot(screenshot_path, model_ids)
 
-        with self.nested("performing optical character recognition"):
-            with tempfile.NamedTemporaryFile() as tmpin:
-                self.send_monitor_command("screendump {}".format(tmpin.name))
+    def get_screen_text_variants(self) -> List[str]:
+        return self._get_screen_text_variants([0, 1, 2])
 
-                cmd = "convert {} {} tiff:- | tesseract - - {}".format(
-                    magick_args, tmpin.name, tess_args
-                )
-                ret = subprocess.run(cmd, shell=True, capture_output=True)
-                if ret.returncode != 0:
-                    raise Exception(
-                        "OCR failed with exit code {}".format(ret.returncode)
-                    )
-
-                return ret.stdout.decode("utf-8")
+    def get_screen_text(self) -> str:
+        return self._get_screen_text_variants([2])[0]
 
     def wait_for_text(self, regex: str) -> None:
         def screen_matches(last: bool) -> bool:
-            text = self.get_screen_text()
-            matches = re.search(regex, text) is not None
+            variants = self.get_screen_text_variants()
+            for text in variants:
+                if re.search(regex, text) is not None:
+                    return True
 
-            if last and not matches:
-                self.log("Last OCR attempt failed. Text was: {}".format(text))
+            if last:
+                self.log("Last OCR attempt failed. Text was: {}".format(variants))
 
-            return matches
+            return False
 
         with self.nested("waiting for {} to appear on screen".format(regex)):
             retry(screen_matches)
@@ -718,6 +765,7 @@ class Machine:
         shell_path = os.path.join(self.state_dir, "shell")
         self.shell_socket = create_socket(shell_path)
 
+        display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
         qemu_options = (
             " ".join(
                 [
@@ -727,7 +775,7 @@ class Machine:
                     "-device virtio-serial",
                     "-device virtconsole,chardev=shell",
                     "-device virtio-rng-pci",
-                    "-serial stdio" if "DISPLAY" in os.environ else "-nographic",
+                    "-serial stdio" if display_available else "-nographic",
                 ]
             )
             + " "
@@ -746,7 +794,6 @@ class Machine:
 
         self.process = subprocess.Popen(
             self.script,
-            bufsize=1,
             stdin=subprocess.DEVNULL,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT,
@@ -767,8 +814,7 @@ class Machine:
                 # Ignore undecodable bytes that may occur in boot menus
                 line = _line.decode(errors="ignore").replace("\r", "").rstrip()
                 self.last_lines.put(line)
-                eprint("{} # {}".format(self.name, line))
-                self.logger.enqueue({"msg": line, "machine": self.name})
+                self.log_serial(line)
 
         _thread.start_new_thread(process_serial_output, ())
 
@@ -780,9 +826,10 @@ class Machine:
         self.log("QEMU running (pid {})".format(self.pid))
 
     def cleanup_statedir(self) -> None:
-        self.log("delete the VM state directory")
-        if os.path.isfile(self.state_dir):
+        if os.path.isdir(self.state_dir):
             shutil.rmtree(self.state_dir)
+            self.logger.log(f"deleting VM state directory {self.state_dir}")
+            self.logger.log("if you want to keep the VM state, pass --keep-vm-state")
 
     def shutdown(self) -> None:
         if not self.booted:
@@ -840,7 +887,8 @@ class Machine:
             retry(window_is_visible)
 
     def sleep(self, secs: int) -> None:
-        time.sleep(secs)
+        # We want to sleep in *guest* time, not *host* time.
+        self.succeed(f"sleep {secs}")
 
     def forward_port(self, host_port: int = 8080, guest_port: int = 80) -> None:
         """Forward a TCP port on the host to a TCP port on the guest.
@@ -858,15 +906,13 @@ class Machine:
         self.send_monitor_command("set_link virtio-net-pci.1 off")
 
     def unblock(self) -> None:
-        """Make the machine reachable.
-        """
+        """Make the machine reachable."""
         self.send_monitor_command("set_link virtio-net-pci.1 on")
 
 
 def create_machine(args: Dict[str, Any]) -> Machine:
     global log
     args["log"] = log
-    args["redirectSerial"] = os.environ.get("USE_SERIAL", "0") == "1"
     return Machine(args)
 
 
@@ -909,6 +955,16 @@ def run_tests() -> None:
             machine.execute("sync")
 
 
+def serial_stdout_on() -> None:
+    global log
+    log._print_serial_logs = True
+
+
+def serial_stdout_off() -> None:
+    global log
+    log._print_serial_logs = False
+
+
 @contextmanager
 def subtest(name: str) -> Iterator[None]:
     with log.nested(name):
@@ -923,7 +979,7 @@ def subtest(name: str) -> Iterator[None]:
 
 
 if __name__ == "__main__":
-    arg_parser = argparse.ArgumentParser()
+    arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
     arg_parser.add_argument(
         "-K",
         "--keep-vm-state",
@@ -939,10 +995,10 @@ if __name__ == "__main__":
     for nr, vde_socket, _, _ in vde_sockets:
         os.environ["QEMU_VDE_SOCKET_{}".format(nr)] = vde_socket
 
-    machines = [create_machine({"startCommand": s}) for s in vm_scripts]
-    for machine in machines:
-        if not cli_args.keep_vm_state:
-            machine.cleanup_statedir()
+    machines = [
+        create_machine({"startCommand": s, "keepVmState": cli_args.keep_vm_state})
+        for s in vm_scripts
+    ]
     machine_eval = [
         "{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines)
     ]