summary refs log tree commit diff
path: root/pkgs/development/interpreters/python/update-python-libraries/update-python-libraries.py
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/interpreters/python/update-python-libraries/update-python-libraries.py')
-rwxr-xr-x pkgs/development/interpreters/python/update-python-libraries/update-python-libraries.py 144
1 files changed, 112 insertions, 32 deletions
diff --git a/pkgs/development/interpreters/python/update-python-libraries/update-python-libraries.py b/pkgs/development/interpreters/python/update-python-libraries/update-python-libraries.py
index feb44bef079..14b3ed4f3f1 100755
--- a/pkgs/development/interpreters/python/update-python-libraries/update-python-libraries.py
+++ b/pkgs/development/interpreters/python/update-python-libraries/update-python-libraries.py
@@ -12,14 +12,16 @@ to update all non-pinned libraries in that folder.
 """
 
 import argparse
+import json
+import logging
 import os
-import pathlib
 import re
 import requests
 from concurrent.futures import ThreadPoolExecutor as Pool
 from packaging.version import Version as _Version
 from packaging.version import InvalidVersion
 from packaging.specifiers import SpecifierSet
+from typing import Optional, Any
 import collections
 import subprocess
 
@@ -31,11 +33,12 @@ EXTENSIONS = ['tar.gz', 'tar.bz2', 'tar', 'zip', '.whl']
 
 PRERELEASES = False
 
+BULK_UPDATE = False
+
 GIT = "git"
 
-NIXPGKS_ROOT = subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode('utf-8').strip()
+NIXPKGS_ROOT = subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode('utf-8').strip()
 
-import logging
 logging.basicConfig(level=logging.INFO)
 
 
@@ -67,6 +70,22 @@ def _get_values(attribute, text):
     values = regex.findall(text)
     return values
 
+
+def _get_attr_value(attr_path: str) -> Optional[Any]:
+    """Evaluate attr_path in nixpkgs with `nix eval --json` and decode the result.
+
+    Returns None when the command fails or its output cannot be decoded.
+    """
+    command = [
+        "nix", "--extra-experimental-features", "nix-command",
+        "eval", "-f", f"{NIXPKGS_ROOT}/default.nix", "--json", attr_path,
+    ]
+    try:
+        return json.loads(subprocess.check_output(command).decode())
+    except (subprocess.CalledProcessError, ValueError):
+        return None
+
+
 def _get_unique_value(attribute, text):
     """Match attribute in text and return unique match.
 
@@ -81,23 +100,29 @@ def _get_unique_value(attribute, text):
     else:
         raise ValueError("no value found for {}".format(attribute))
 
-def _get_line_and_value(attribute, text):
+def _get_line_and_value(attribute, text, value=None):
     """Match attribute in text. Return the line and the value of the attribute."""
-    regex = '({}\s+=\s+"(.*)";)'.format(attribute)
+    # A given value must match literally: SRI hashes contain '+', which an
+    # unescaped pattern would treat as a quantifier and fail to match.
+    expected = ".*" if value is None else re.escape(value)
+    regex = rf'({attribute}\s+=\s+\"({expected})\";)'
     regex = re.compile(regex)
-    value = regex.findall(text)
-    n = len(value)
+    results = regex.findall(text)
+    n = len(results)
     if n > 1:
         raise ValueError("found too many values for {}".format(attribute))
     elif n == 1:
-        return value[0]
+        return results[0]
     else:
         raise ValueError("no value found for {}".format(attribute))
 
 
-def _replace_value(attribute, value, text):
+def _replace_value(attribute, value, text, oldvalue=None):
     """Search and replace value of attribute in text."""
-    old_line, old_value = _get_line_and_value(attribute, text)
+    # oldvalue=None is forwarded as-is: _get_line_and_value then matches
+    # any value, exactly as when the argument is left out.
+    match = _get_line_and_value(attribute, text, oldvalue)
+    old_line, old_value = match
     new_line = old_line.replace(old_value, value)
     new_text = text.replace(old_line, new_line)
     return new_text
@@ -124,6 +149,23 @@ def _fetch_github(url):
         raise ValueError("request for {} failed".format(url))
 
 
+def _hash_to_sri(algorithm, value):
+    """Convert hash `value` of type `algorithm` to SRI form via `nix hash to-sri`."""
+    return subprocess.check_output([
+        "nix",
+        "--extra-experimental-features", "nix-command",  # `nix hash` needs nix-command
+        "hash", "to-sri",
+        "--type", algorithm,
+        value
+    ]).decode().strip()
+
+
+def _skip_bulk_update(attr_name: str) -> bool:
+    """Return True when `<attr_name>.skipBulkUpdate` evaluates to a truthy value."""
+    flag = _get_attr_value(f"{attr_name}.skipBulkUpdate")
+    return bool(flag)
+
+
 SEMVER = {
     'major' : 0,
     'minor' : 1,
@@ -198,7 +240,7 @@ def _get_latest_version_github(package, extension, current_version, target):
     attr_path = os.environ.get("UPDATE_NIX_ATTR_PATH", f"python3Packages.{package}")
     try:
         homepage = subprocess.check_output(
-            ["nix", "eval", "-f", f"{NIXPGKS_ROOT}/default.nix", "--raw", f"{attr_path}.src.meta.homepage"])\
+            ["nix", "eval", "-f", f"{NIXPKGS_ROOT}/default.nix", "--raw", f"{attr_path}.src.meta.homepage"])\
             .decode('utf-8')
     except Exception as e:
         raise ValueError(f"Unable to determine homepage: {e}")
@@ -217,17 +259,47 @@ def _get_latest_version_github(package, extension, current_version, target):
 
     release = next(filter(lambda x: strip_prefix(x['tag_name']) == version, releases))
     prefix = get_prefix(release['tag_name'])
-    try:
-        sha256 = subprocess.check_output(["nix-prefetch-url", "--type", "sha256", "--unpack", f"{release['tarball_url']}"], stderr=subprocess.DEVNULL)\
-            .decode('utf-8').strip()
-    except:
-        # this may fail if they have both a branch and a tag of the same name, attempt tag name
-        tag_url = str(release['tarball_url']).replace("tarball","tarball/refs/tags")
-        sha256 = subprocess.check_output(["nix-prefetch-url", "--type", "sha256", "--unpack", tag_url], stderr=subprocess.DEVNULL)\
-            .decode('utf-8').strip()
-
 
-    return version, sha256, prefix
+    # some attributes require using the fetchgit
+    git_fetcher_args = []
+    if (_get_attr_value(f"{attr_path}.src.fetchSubmodules")):
+        git_fetcher_args.append("--fetch-submodules")
+    if (_get_attr_value(f"{attr_path}.src.fetchLFS")):
+        git_fetcher_args.append("--fetch-lfs")
+    if (_get_attr_value(f"{attr_path}.src.leaveDotGit")):
+        git_fetcher_args.append("--leave-dotGit")
+
+    if git_fetcher_args:
+        algorithm = "sha256"
+        cmd = [
+            "nix-prefetch-git",
+            f"https://github.com/{owner}/{repo}.git",
+            "--hash", algorithm,
+            "--rev", f"refs/tags/{release['tag_name']}"
+        ]
+        cmd.extend(git_fetcher_args)
+        response = subprocess.check_output(cmd)
+        document = json.loads(response.decode())
+        hash = _hash_to_sri(algorithm, document[algorithm])
+    else:
+        try:
+            hash = subprocess.check_output([
+                "nix-prefetch-url",
+                "--type", "sha256",
+                "--unpack",
+                f"{release['tarball_url']}"
+            ], stderr=subprocess.DEVNULL).decode('utf-8').strip()
+        except (subprocess.CalledProcessError, UnicodeError):
+            # this may fail if they have both a branch and a tag of the same name, attempt tag name
+            tag_url = str(release['tarball_url']).replace("tarball","tarball/refs/tags")
+            hash = subprocess.check_output([
+                "nix-prefetch-url",
+                "--type", "sha256",
+                "--unpack",
+                tag_url
+            ], stderr=subprocess.DEVNULL).decode('utf-8').strip()
+
+    return version, hash, prefix
 
 
 FETCHERS = {
@@ -272,12 +344,12 @@ def _determine_extension(text, fetcher):
     if fetcher == 'fetchPypi':
         try:
             src_format = _get_unique_value('format', text)
-        except ValueError as e:
+        except ValueError:
             src_format = None   # format was not given
 
         try:
             extension = _get_unique_value('extension', text)
-        except ValueError as e:
+        except ValueError:
             extension = None    # extension was not given
 
         if extension is None:
@@ -294,8 +366,6 @@ def _determine_extension(text, fetcher):
             raise ValueError('url does not point to PyPI.')
 
     elif fetcher == 'fetchFromGitHub':
-        if "fetchSubmodules" in text:
-            raise ValueError("fetchFromGitHub fetcher doesn't support submodules")
         extension = "tar.gz"
 
     return extension
@@ -321,6 +391,8 @@ def _update_package(path, target):
     # Attempt a fetch using each pname, e.g. backports-zoneinfo vs backports.zoneinfo
     successful_fetch = False
     for pname in pnames:
+        if BULK_UPDATE and _skip_bulk_update(f"python3Packages.{pname}"):
+            raise ValueError(f"Bulk update skipped for {pname}")
         try:
             new_version, new_sha256, prefix = FETCHERS[fetcher](pname, extension, version, target)
             successful_fetch = True
@@ -340,16 +412,20 @@ def _update_package(path, target):
         raise ValueError("no file available for {}.".format(pname))
 
     text = _replace_value('version', new_version, text)
+
     # hashes from pypi are 16-bit encoded sha256's, normalize it to sri to avoid merge conflicts
     # sri hashes have been the default format since nix 2.4+
-    sri_hash = subprocess.check_output(["nix", "--extra-experimental-features", "nix-command", "hash", "to-sri", "--type", "sha256", new_sha256]).decode('utf-8').strip()
+    sri_hash = _hash_to_sri("sha256", new_sha256)
 
-
-    # fetchers can specify a sha256, or a sri hash
-    try:
-        text = _replace_value('sha256', sri_hash, text)
-    except ValueError:
-        text = _replace_value('hash', sri_hash, text)
+    # retrieve the old output hash for a more precise match
+    if old_hash := _get_attr_value(f"python3Packages.{pname}.src.outputHash"):
+        # fetchers can specify a sha256, or a sri hash
+        try:
+            text = _replace_value('hash', sri_hash, text, old_hash)
+        except ValueError:
+            text = _replace_value('sha256', sri_hash, text, old_hash)
+    else:
+        raise ValueError(f"Unable to retrieve old hash for {pname}")
 
     if fetcher == 'fetchFromGitHub':
         # in the case of fetchFromGitHub, it's common to see `rev = version;` or `rev = "v${version}";`
@@ -441,6 +517,10 @@ environment variables:
 
     packages = list(map(os.path.abspath, args.package))
 
+    if len(packages) > 1:
+        global BULK_UPDATE
+        BULK_UPDATE = True  # Python's boolean literal; `true` is an undefined name
+
     logging.info("Updating packages...")
 
     # Use threads to update packages concurrently