summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix29
1 files changed, 11 insertions, 18 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index 6fb960f684c..c70ab0ac2b3 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -54,7 +54,7 @@ let
   inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
 
   pname = "jaxlib";
-  version = "0.4.16";
+  version = "0.4.20";
 
   meta = with lib; {
     description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -95,7 +95,6 @@ let
     "absl_py"
     "astor_archive"
     "astunparse_archive"
-    "boringssl"
     # Not packaged in nixpkgs
     # "com_github_googleapis_googleapis"
     # "com_github_googlecloudplatform_google_cloud_cpp"
@@ -137,8 +136,8 @@ let
 
   arch =
     # KeyError: ('Linux', 'arm64')
-    if stdenv.targetPlatform.isLinux && stdenv.targetPlatform.linuxArch == "arm64" then "aarch64"
-    else stdenv.targetPlatform.linuxArch;
+    if stdenv.hostPlatform.isLinux && stdenv.hostPlatform.linuxArch == "arm64" then "aarch64"
+    else stdenv.hostPlatform.linuxArch;
 
   bazel-build = buildBazelPackage rec {
     name = "bazel-build-${pname}-${version}";
@@ -151,7 +150,7 @@ let
       repo = "jax";
       # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
       rev = "refs/tags/${pname}-v${version}";
-      hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg=";
+      hash = "sha256-WLYXUtchOaA6SGnKuVhN9CmV06xMCLQTEuEtL13ttZU=";
     };
 
     nativeBuildInputs = [
@@ -220,7 +219,7 @@ let
       build --python_path="${python}/bin/python"
       build --distinct_host_configuration=false
       build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
-    '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
+    '' + lib.optionalString (stdenv.hostPlatform.avxSupport && stdenv.hostPlatform.isUnix) ''
       build --config=avx_posix
     '' + lib.optionalString mklSupport ''
       build --config=mkl_open_source_only
@@ -264,10 +263,10 @@ let
       ];
 
       sha256 = (if cudaSupport then {
-        x86_64-linux = "sha256-6HkrEWAPjGPj4zRxahl0FLiV7WZO/6zsdCX8STfV5EE=";
+        x86_64-linux = "sha256-QczClHxHElLZCqIZlHc3z3DXJ7rZQJaMs2XIb+lxarI=";
       } else {
-        x86_64-linux = "sha256-MDnuJwJ/xKnC72Qub0ETYj5uQB2r8/AgGm10oqmzzcc=";
-        aarch64-linux = "sha256-aVUm612VNEsjZLDrtiOPTqSk1t+AhmOx+pOG3bZdOAw=";
+        x86_64-linux = "sha256-mqiJe4u0NYh1PKCbQfbo0U2e9/kYiBqj98d+BPHFSxQ=";
+        aarch64-linux = "sha256-EuLqamVBJ+qoVMCFIYUT846AghltZolfLGdtO9UeXSM=";
       }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
     };
 
@@ -293,25 +292,19 @@ let
           --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
         substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
           --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
-      '' + (if stdenv.cc.isGNU then ''
-        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
-        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
-      '' else if stdenv.cc.isClang then ''
-        sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
-        sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
-      '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
+      '';
     };
 
     inherit meta;
   };
   platformTag =
-    if stdenv.targetPlatform.isLinux then
+    if stdenv.hostPlatform.isLinux then
       "manylinux2014_${arch}"
     else if stdenv.system == "x86_64-darwin" then
       "macosx_10_9_${arch}"
     else if stdenv.system == "aarch64-darwin" then
       "macosx_11_0_${arch}"
-    else throw "Unsupported target platform: ${stdenv.targetPlatform}";
+    else throw "Unsupported target platform: ${stdenv.hostPlatform}";
 
 in
 buildPythonPackage {