summary refs log tree commit diff
path: root/nixos/lib/test-driver/test-driver.py
diff options
context:
space:
mode:
authorJörg Thalheim <joerg@thalheim.io>2019-11-08 10:01:29 +0000
committerJörg Thalheim <joerg@thalheim.io>2019-11-11 13:49:48 +0000
commit03e6ca15e205bf892eac08a5f561c3a16284c90a (patch)
tree3fce743ca1a921659ab661cc542136bf14be5134 /nixos/lib/test-driver/test-driver.py
parent556a169f14f4970927b8c18a997dbf323ed9a865 (diff)
downloadnixpkgs-03e6ca15e205bf892eac08a5f561c3a16284c90a.tar
nixpkgs-03e6ca15e205bf892eac08a5f561c3a16284c90a.tar.gz
nixpkgs-03e6ca15e205bf892eac08a5f561c3a16284c90a.tar.bz2
nixpkgs-03e6ca15e205bf892eac08a5f561c3a16284c90a.tar.lz
nixpkgs-03e6ca15e205bf892eac08a5f561c3a16284c90a.tar.xz
nixpkgs-03e6ca15e205bf892eac08a5f561c3a16284c90a.tar.zst
nixpkgs-03e6ca15e205bf892eac08a5f561c3a16284c90a.zip
test-driver: add mypy support
It's a good idea to expand this in future to test code as well,
so we get type checking there as well.
Diffstat (limited to 'nixos/lib/test-driver/test-driver.py')
-rw-r--r--nixos/lib/test-driver/test-driver.py189
1 files changed, 99 insertions, 90 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index c8d4936ac52..9d89960876e 100644
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -1,6 +1,5 @@
 #! /somewhere/python3
-
-from contextlib import contextmanager
+from contextlib import contextmanager, _GeneratorContextManager
 from xml.sax.saxutils import XMLGenerator
 import _thread
 import atexit
@@ -8,7 +7,7 @@ import json
 import os
 import ptpython.repl
 import pty
-import queue
+from queue import Queue, Empty
 import re
 import shutil
 import socket
@@ -17,6 +16,7 @@ import sys
 import tempfile
 import time
 import unicodedata
+from typing import Tuple, TextIO, Any, Callable, Dict, Iterator, Optional, List
 
 CHAR_TO_KEY = {
     "A": "shift-a",
@@ -81,12 +81,18 @@ CHAR_TO_KEY = {
     ")": "shift-0x0B",
 }
 
+# Forward references
+nr_tests: int
+nr_succeeded: int
+log: "Logger"
+machines: "List[Machine]"
+
 
-def eprint(*args, **kwargs):
+def eprint(*args: object, **kwargs: Any) -> None:
     print(*args, file=sys.stderr, **kwargs)
 
 
-def create_vlan(vlan_nr):
+def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
     global log
     log.log("starting VDE switch for network {}".format(vlan_nr))
     vde_socket = os.path.abspath("./vde{}.ctl".format(vlan_nr))
@@ -110,7 +116,7 @@ def create_vlan(vlan_nr):
     return (vlan_nr, vde_socket, vde_process, fd)
 
 
-def retry(fn):
+def retry(fn: Callable) -> None:
     """Call the given function repeatedly, with 1 second intervals,
     until it returns True or a timeout is reached.
     """
@@ -125,52 +131,52 @@ def retry(fn):
 
 
 class Logger:
-    def __init__(self):
+    def __init__(self) -> None:
         self.logfile = os.environ.get("LOGFILE", "/dev/null")
         self.logfile_handle = open(self.logfile, "wb")
         self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
-        self.queue = queue.Queue(1000)
+        self.queue: "Queue[Dict[str, str]]" = Queue(1000)
 
         self.xml.startDocument()
         self.xml.startElement("logfile", attrs={})
 
-    def close(self):
+    def close(self) -> None:
         self.xml.endElement("logfile")
         self.xml.endDocument()
         self.logfile_handle.close()
 
-    def sanitise(self, message):
+    def sanitise(self, message: str) -> str:
         return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
 
-    def maybe_prefix(self, message, attributes):
+    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
         if "machine" in attributes:
             return "{}: {}".format(attributes["machine"], message)
         return message
 
-    def log_line(self, message, attributes):
+    def log_line(self, message: str, attributes: Dict[str, str]) -> None:
         self.xml.startElement("line", attributes)
         self.xml.characters(message)
         self.xml.endElement("line")
 
-    def log(self, message, attributes={}):
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
         eprint(self.maybe_prefix(message, attributes))
         self.drain_log_queue()
         self.log_line(message, attributes)
 
-    def enqueue(self, message):
+    def enqueue(self, message: Dict[str, str]) -> None:
         self.queue.put(message)
 
-    def drain_log_queue(self):
+    def drain_log_queue(self) -> None:
         try:
             while True:
                 item = self.queue.get_nowait()
                 attributes = {"machine": item["machine"], "type": "serial"}
                 self.log_line(self.sanitise(item["msg"]), attributes)
-        except queue.Empty:
+        except Empty:
             pass
 
     @contextmanager
-    def nested(self, message, attributes={}):
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
         eprint(self.maybe_prefix(message, attributes))
 
         self.xml.startElement("nest", attrs={})
@@ -189,24 +195,22 @@ class Logger:
 
 
 class Machine:
-    def __init__(self, args):
+    def __init__(self, args: Dict[str, Any]) -> None:
         if "name" in args:
             self.name = args["name"]
         else:
             self.name = "machine"
-            try:
-                cmd = args["startCommand"]
-                self.name = re.search("run-(.+)-vm$", cmd).group(1)
-            except KeyError:
-                pass
-            except AttributeError:
-                pass
+            cmd = args.get("startCommand", None)
+            if cmd:
+                match = re.search("run-(.+)-vm$", cmd)
+                if match:
+                    self.name = match.group(1)
 
         self.script = args.get("startCommand", self.create_startcommand(args))
 
         tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
 
-        def create_dir(name):
+        def create_dir(name: str) -> str:
             path = os.path.join(tmp_dir, name)
             os.makedirs(path, mode=0o700, exist_ok=True)
             return path
@@ -216,14 +220,14 @@ class Machine:
 
         self.booted = False
         self.connected = False
-        self.pid = None
+        self.pid: Optional[int] = None
         self.socket = None
-        self.monitor = None
-        self.logger = args["log"]
+        self.monitor: Optional[socket.socket] = None
+        self.logger: Logger = args["log"]
         self.allow_reboot = args.get("allowReboot", False)
 
     @staticmethod
-    def create_startcommand(args):
+    def create_startcommand(args: Dict[str, str]) -> str:
         net_backend = "-netdev user,id=net0"
         net_frontend = "-device virtio-net-pci,netdev=net0"
 
@@ -273,30 +277,32 @@ class Machine:
 
         return start_command
 
-    def is_up(self):
+    def is_up(self) -> bool:
         return self.booted and self.connected
 
-    def log(self, msg):
+    def log(self, msg: str) -> None:
         self.logger.log(msg, {"machine": self.name})
 
-    def nested(self, msg, attrs={}):
+    def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
         my_attrs = {"machine": self.name}
         my_attrs.update(attrs)
         return self.logger.nested(msg, my_attrs)
 
-    def wait_for_monitor_prompt(self):
+    def wait_for_monitor_prompt(self) -> str:
+        assert self.monitor is not None
         while True:
             answer = self.monitor.recv(1024).decode()
             if answer.endswith("(qemu) "):
                 return answer
 
-    def send_monitor_command(self, command):
+    def send_monitor_command(self, command: str) -> str:
         message = ("{}\n".format(command)).encode()
         self.log("sending monitor command: {}".format(command))
+        assert self.monitor is not None
         self.monitor.send(message)
         return self.wait_for_monitor_prompt()
 
-    def wait_for_unit(self, unit, user=None):
+    def wait_for_unit(self, unit: str, user: Optional[str] = None) -> bool:
         while True:
             info = self.get_unit_info(unit, user)
             state = info["ActiveState"]
@@ -316,7 +322,7 @@ class Machine:
             if state == "active":
                 return True
 
-    def get_unit_info(self, unit, user=None):
+    def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]:
         status, lines = self.systemctl('--no-pager show "{}"'.format(unit), user)
         if status != 0:
             raise Exception(
@@ -327,8 +333,9 @@ class Machine:
 
         line_pattern = re.compile(r"^([^=]+)=(.*)$")
 
-        def tuple_from_line(line):
+        def tuple_from_line(line: str) -> Tuple[str, str]:
             match = line_pattern.match(line)
+            assert match is not None
             return match[1], match[2]
 
         return dict(
@@ -337,7 +344,7 @@ class Machine:
             if line_pattern.match(line)
         )
 
-    def systemctl(self, q, user=None):
+    def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]:
         if user is not None:
             q = q.replace("'", "\\'")
             return self.execute(
@@ -349,7 +356,7 @@ class Machine:
             )
         return self.execute("systemctl {}".format(q))
 
-    def require_unit_state(self, unit, require_state="active"):
+    def require_unit_state(self, unit: str, require_state: str = "active") -> None:
         with self.nested(
             "checking if unit ‘{}’ has reached state '{}'".format(unit, require_state)
         ):
@@ -361,7 +368,7 @@ class Machine:
                     + "'active' but it is in state ‘{}’".format(state)
                 )
 
-    def execute(self, command):
+    def execute(self, command: str) -> Tuple[int, str]:
         self.connect()
 
         out_command = "( {} ); echo '|!EOF' $?\n".format(command)
@@ -379,7 +386,7 @@ class Machine:
                 return (status_code, output)
             output += chunk
 
-    def succeed(self, *commands):
+    def succeed(self, *commands: str) -> str:
         """Execute each command and check that it succeeds."""
         output = ""
         for command in commands:
@@ -393,7 +400,7 @@ class Machine:
                 output += out
         return output
 
-    def fail(self, *commands):
+    def fail(self, *commands: str) -> None:
         """Execute each command and check that it fails."""
         for command in commands:
             with self.nested("must fail: {}".format(command)):
@@ -403,21 +410,21 @@ class Machine:
                         "command `{}` unexpectedly succeeded".format(command)
                     )
 
-    def wait_until_succeeds(self, command):
+    def wait_until_succeeds(self, command: str) -> str:
         with self.nested("waiting for success: {}".format(command)):
             while True:
                 status, output = self.execute(command)
                 if status == 0:
                     return output
 
-    def wait_until_fails(self, command):
+    def wait_until_fails(self, command: str) -> str:
         with self.nested("waiting for failure: {}".format(command)):
             while True:
                 status, output = self.execute(command)
                 if status != 0:
                     return output
 
-    def wait_for_shutdown(self):
+    def wait_for_shutdown(self) -> None:
         if not self.booted:
             return
 
@@ -429,14 +436,14 @@ class Machine:
             self.booted = False
             self.connected = False
 
-    def get_tty_text(self, tty):
+    def get_tty_text(self, tty: str) -> str:
         status, output = self.execute(
             "fold -w$(stty -F /dev/tty{0} size | "
             "awk '{{print $2}}') /dev/vcs{0}".format(tty)
         )
         return output
 
-    def wait_until_tty_matches(self, tty, regexp):
+    def wait_until_tty_matches(self, tty: str, regexp: str) -> bool:
         matcher = re.compile(regexp)
         with self.nested("waiting for {} to appear on tty {}".format(regexp, tty)):
             while True:
@@ -444,43 +451,43 @@ class Machine:
                 if len(matcher.findall(text)) > 0:
                     return True
 
-    def send_chars(self, chars):
+    def send_chars(self, chars: List[str]) -> None:
         with self.nested("sending keys ‘{}‘".format(chars)):
             for char in chars:
                 self.send_key(char)
 
-    def wait_for_file(self, filename):
+    def wait_for_file(self, filename: str) -> bool:
         with self.nested("waiting for file ‘{}‘".format(filename)):
             while True:
                 status, _ = self.execute("test -e {}".format(filename))
                 if status == 0:
                     return True
 
-    def wait_for_open_port(self, port):
-        def port_is_open(_):
+    def wait_for_open_port(self, port: int) -> None:
+        def port_is_open(_: Any) -> bool:
             status, _ = self.execute("nc -z localhost {}".format(port))
             return status == 0
 
         with self.nested("waiting for TCP port {}".format(port)):
             retry(port_is_open)
 
-    def wait_for_closed_port(self, port):
-        def port_is_closed(_):
+    def wait_for_closed_port(self, port: int) -> None:
+        def port_is_closed(_: Any) -> bool:
             status, _ = self.execute("nc -z localhost {}".format(port))
             return status != 0
 
         retry(port_is_closed)
 
-    def start_job(self, jobname, user=None):
+    def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
         return self.systemctl("start {}".format(jobname), user)
 
-    def stop_job(self, jobname, user=None):
+    def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
         return self.systemctl("stop {}".format(jobname), user)
 
-    def wait_for_job(self, jobname):
+    def wait_for_job(self, jobname: str) -> bool:
         return self.wait_for_unit(jobname)
 
-    def connect(self):
+    def connect(self) -> None:
         if self.connected:
             return
 
@@ -496,7 +503,7 @@ class Machine:
             self.log("(connecting took {:.2f} seconds)".format(toc - tic))
             self.connected = True
 
-    def screenshot(self, filename):
+    def screenshot(self, filename: str) -> None:
         out_dir = os.environ.get("out", os.getcwd())
         word_pattern = re.compile(r"^\w+$")
         if word_pattern.match(filename):
@@ -513,12 +520,12 @@ class Machine:
             if ret.returncode != 0:
                 raise Exception("Cannot convert screenshot")
 
-    def dump_tty_contents(self, tty):
+    def dump_tty_contents(self, tty: str) -> None:
         """Debugging: Dump the contents of the TTY<n>
         """
         self.execute("fold -w 80 /dev/vcs{} | systemd-cat".format(tty))
 
-    def get_screen_text(self):
+    def get_screen_text(self) -> str:
         if shutil.which("tesseract") is None:
             raise Exception("get_screen_text used but enableOCR is false")
 
@@ -546,30 +553,30 @@ class Machine:
 
                 return ret.stdout.decode("utf-8")
 
-    def wait_for_text(self, regex):
-        def screen_matches(last):
+    def wait_for_text(self, regex: str) -> None:
+        def screen_matches(last: bool) -> bool:
             text = self.get_screen_text()
-            m = re.search(regex, text)
+            matches = re.search(regex, text) is not None
 
-            if last and not m:
+            if last and not matches:
                 self.log("Last OCR attempt failed. Text was: {}".format(text))
 
-            return m
+            return matches
 
         with self.nested("waiting for {} to appear on screen".format(regex)):
             retry(screen_matches)
 
-    def send_key(self, key):
+    def send_key(self, key: str) -> None:
         key = CHAR_TO_KEY.get(key, key)
         self.send_monitor_command("sendkey {}".format(key))
 
-    def start(self):
+    def start(self) -> None:
         if self.booted:
             return
 
         self.log("starting vm")
 
-        def create_socket(path):
+        def create_socket(path: str) -> socket.socket:
             if os.path.exists(path):
                 os.unlink(path)
             s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
@@ -619,9 +626,9 @@ class Machine:
         self.monitor, _ = self.monitor_socket.accept()
         self.shell, _ = self.shell_socket.accept()
 
-        def process_serial_output():
-            for line in self.process.stdout:
-                line = line.decode("unicode_escape").replace("\r", "").rstrip()
+        def process_serial_output() -> None:
+            for _line in self.process.stdout:
+                line = _line.decode("unicode_escape").replace("\r", "").rstrip()
                 eprint("{} # {}".format(self.name, line))
                 self.logger.enqueue({"msg": line, "machine": self.name})
 
@@ -634,14 +641,14 @@ class Machine:
 
         self.log("QEMU running (pid {})".format(self.pid))
 
-    def shutdown(self):
+    def shutdown(self) -> None:
         if not self.booted:
             return
 
         self.shell.send("poweroff\n".encode())
         self.wait_for_shutdown()
 
-    def crash(self):
+    def crash(self) -> None:
         if not self.booted:
             return
 
@@ -649,7 +656,7 @@ class Machine:
         self.send_monitor_command("quit")
         self.wait_for_shutdown()
 
-    def wait_for_x(self):
+    def wait_for_x(self) -> None:
         """Wait until it is possible to connect to the X server.  Note that
         testing the existence of /tmp/.X11-unix/X0 is insufficient.
         """
@@ -666,15 +673,15 @@ class Machine:
                 if status == 0:
                     return
 
-    def get_window_names(self):
+    def get_window_names(self) -> List[str]:
         return self.succeed(
             r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'"
         ).splitlines()
 
-    def wait_for_window(self, regexp):
+    def wait_for_window(self, regexp: str) -> None:
         pattern = re.compile(regexp)
 
-        def window_is_visible(last_try):
+        def window_is_visible(last_try: bool) -> bool:
             names = self.get_window_names()
             if last_try:
                 self.log(
@@ -687,10 +694,10 @@ class Machine:
         with self.nested("Waiting for a window to appear"):
             retry(window_is_visible)
 
-    def sleep(self, secs):
+    def sleep(self, secs: int) -> None:
         time.sleep(secs)
 
-    def forward_port(self, host_port=8080, guest_port=80):
+    def forward_port(self, host_port: int = 8080, guest_port: int = 80) -> None:
         """Forward a TCP port on the host to a TCP port on the guest.
         Useful during interactive testing.
         """
@@ -698,43 +705,46 @@ class Machine:
             "hostfwd_add tcp::{}-:{}".format(host_port, guest_port)
         )
 
-    def block(self):
+    def block(self) -> None:
         """Make the machine unreachable by shutting down eth1 (the multicast
         interface used to talk to the other VMs).  We keep eth0 up so that
         the test driver can continue to talk to the machine.
         """
         self.send_monitor_command("set_link virtio-net-pci.1 off")
 
-    def unblock(self):
+    def unblock(self) -> None:
         """Make the machine reachable.
         """
         self.send_monitor_command("set_link virtio-net-pci.1 on")
 
 
-def create_machine(args):
+def create_machine(args: Dict[str, Any]) -> Machine:
     global log
     args["log"] = log
     args["redirectSerial"] = os.environ.get("USE_SERIAL", "0") == "1"
     return Machine(args)
 
 
-def start_all():
+def start_all() -> None:
+    global machines
     with log.nested("starting all VMs"):
         for machine in machines:
             machine.start()
 
 
-def join_all():
+def join_all() -> None:
+    global machines
     with log.nested("waiting for all VMs to finish"):
         for machine in machines:
             machine.wait_for_shutdown()
 
 
-def test_script():
+def test_script() -> None:
     exec(os.environ["testScript"])
 
 
-def run_tests():
+def run_tests() -> None:
+    global machines
     tests = os.environ.get("tests", None)
     if tests is not None:
         with log.nested("running the VM test script"):
@@ -757,7 +767,7 @@ def run_tests():
 
 
 @contextmanager
-def subtest(name):
+def subtest(name: str) -> Iterator[None]:
     global nr_tests
     global nr_succeeded
 
@@ -774,7 +784,6 @@ def subtest(name):
 
 
 if __name__ == "__main__":
-    global log
     log = Logger()
 
     vlan_nrs = list(dict.fromkeys(os.environ["VLANS"].split()))
@@ -793,7 +802,7 @@ if __name__ == "__main__":
     nr_succeeded = 0
 
     @atexit.register
-    def clean_up():
+    def clean_up() -> None:
         with log.nested("cleaning up"):
             for machine in machines:
                 if machine.pid is None: