diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/bin.nix | 77 |
1 files changed, 50 insertions, 27 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 3504c6bf320..7e6b00429df 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -24,50 +24,73 @@ , flatbuffers , isPy39 , lib +, python , scipy , stdenv # Options: , cudaSupport ? config.cudaSupport or false }: -# Note that these values are tied to the specific version of the GPU wheel that -# we fetch. When updating, try to go for the latest possible versions that are -# still compatible with the cudatoolkit and cudnn versions available in nixpkgs. +# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are +# wheels for cudatoolkit <11.1, we don't support them. assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; let - device = if cudaSupport then "gpu" else "cpu"; -in -buildPythonPackage rec { - pname = "jaxlib"; version = "0.3.0"; - format = "wheel"; - # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting - # all of them is a pain, so we focus on 3.9, the current nixpkgs python3 - # version. - disabled = !isPy39; + pythonVersion = python.pythonVersion; - # Find new releases at https://storage.googleapis.com/jax-releases. - src = { - cpu = fetchurl { + # Find new releases at https://storage.googleapis.com/jax-releases. When + # upgrading, you can get these hashes from prefetch.sh. + cpuSrcs = { + "3.9" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01"; + hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q="; }; - gpu = fetchurl { - # Note that there's also a release targeting cuDNN 8.2, but unfortunately - # we don't yet have that packaged at the time of writing (02/03/2022). - # Check pkgs/development/libraries/science/math/cudnn/default.nix for more - # details. - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8"; + "3.10" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl"; + hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ="; + }; + }; - # This is what the cuDNN 8.2 download looks like for future reference: - # url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; - # sha256 = "000mnm2masm3sx3haddcmgw43j4gxa3m4fcm14p9nb8dnncjkgpb"; + gpuSrcs = { + "3.9-805" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; + hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw="; + }; + "3.9-82" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; + hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4="; }; - }.${device}; + "3.10-805" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl"; + hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U="; + }; + "3.10-82" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl"; + hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go="; + }; + }; +in +buildPythonPackage rec { + pname = "jaxlib"; + inherit version; + format = "wheel"; + + # At the time of writing (2022-03-03), there are releases for <=3.10. + # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs + # python3 version, and 3.10. + disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10"); + + src = + if !cudaSupport then cpuSrcs."${pythonVersion}" else + let + # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and + # 8.2. Try to use 8.2 whenever possible. + cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805"; + in + gpuSrcs."${pythonVersion}-${cudnnVersion}"; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. |