summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/bin.nix77
1 files changed, 50 insertions, 27 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix
index 3504c6bf320..7e6b00429df 100644
--- a/pkgs/development/python-modules/jaxlib/bin.nix
+++ b/pkgs/development/python-modules/jaxlib/bin.nix
@@ -24,50 +24,73 @@
 , flatbuffers
 , isPy39
 , lib
+, python
 , scipy
 , stdenv
   # Options:
 , cudaSupport ? config.cudaSupport or false
 }:
 
-# Note that these values are tied to the specific version of the GPU wheel that
-# we fetch. When updating, try to go for the latest possible versions that are
-# still compatible with the cudatoolkit and cudnn versions available in nixpkgs.
+# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are
+# wheels for cudatoolkit <11.1, we don't support them.
 assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
 assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5";
 
 let
-  device = if cudaSupport then "gpu" else "cpu";
-in
-buildPythonPackage rec {
-  pname = "jaxlib";
   version = "0.3.0";
-  format = "wheel";
 
-  # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
-  # all of them is a pain, so we focus on 3.9, the current nixpkgs python3
-  # version.
-  disabled = !isPy39;
+  pythonVersion = python.pythonVersion;
 
-  # Find new releases at https://storage.googleapis.com/jax-releases.
-  src = {
-    cpu = fetchurl {
+  # Find new releases at https://storage.googleapis.com/jax-releases. When
+  # upgrading, you can get these hashes from prefetch.sh.
+  cpuSrcs = {
+    "3.9" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
-      sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01";
+      hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q=";
     };
-    gpu = fetchurl {
-      # Note that there's also a release targeting cuDNN 8.2, but unfortunately
-      # we don't yet have that packaged at the time of writing (02/03/2022).
-      # Check pkgs/development/libraries/science/math/cudnn/default.nix for more
-      # details.
-      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
-      sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8";
+    "3.10" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl";
+      hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ=";
+    };
+  };
 
-      # This is what the cuDNN 8.2 download looks like for future reference:
-      # url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
-      # sha256 = "000mnm2masm3sx3haddcmgw43j4gxa3m4fcm14p9nb8dnncjkgpb";
+  gpuSrcs = {
+    "3.9-805" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
+      hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw=";
+    };
+    "3.9-82" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
+      hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4=";
     };
-  }.${device};
+    "3.10-805" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl";
+      hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U=";
+    };
+    "3.10-82" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl";
+      hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go=";
+    };
+  };
+in
+buildPythonPackage rec {
+  pname = "jaxlib";
+  inherit version;
+  format = "wheel";
+
+  # At the time of writing (2022-03-03), there are releases for <=3.10.
+  # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs
+  # python3 version, and 3.10.
+  disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10");
+
+  src =
+    if !cudaSupport then cpuSrcs."${pythonVersion}" else
+    let
+      # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and
+      # 8.2. Try to use 8.2 whenever possible.
+      cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805";
+    in
+    gpuSrcs."${pythonVersion}-${cudnnVersion}";
 
   # Prebuilt wheels are dynamically linked against things that nix can't find.
   # Run `autoPatchelfHook` to automagically fix them.