summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix39
1 files changed, 23 insertions, 16 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index eee432f7185..456c9108593 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -9,11 +9,14 @@
 , buildBazelPackage
 , buildPythonPackage
 , cctools
+, curl
 , cython
 , fetchFromGitHub
 , git
 , IOKit
 , jsoncpp
+, nsync
+, openssl
 , pybind11
 , setuptools
 , symlinkJoin
@@ -50,7 +53,7 @@ let
   inherit (cudaPackages) cudatoolkit cudnn nccl;
 
   pname = "jaxlib";
-  version = "0.3.0";
+  version = "0.3.15";
 
   meta = with lib; {
     description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -93,7 +96,7 @@ let
       owner = "google";
       repo = "jax";
       rev = "${pname}-v${version}";
-      sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
+      sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s=";
     };
 
     nativeBuildInputs = [
@@ -103,15 +106,19 @@ let
       setuptools
       wheel
       which
+    ] ++ lib.optionals stdenv.isDarwin [
+      cctools
     ];
 
     buildInputs = [
+      curl
       double-conversion
       giflib
       grpc
       jsoncpp
       libjpeg_turbo
       numpy
+      openssl
       pkgs.flatbuffers
       pkgs.protobuf
       pybind11
@@ -124,6 +131,8 @@ let
       cudnn
     ] ++ lib.optionals stdenv.isDarwin [
       IOKit
+    ] ++ lib.optionals (!stdenv.isDarwin) [
+      nsync
     ];
 
     postPatch = ''
@@ -149,6 +158,7 @@ let
       build --action_env=PYENV_ROOT
       build --python_path="${python}/bin/python"
       build --distinct_host_configuration=false
+      build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
     '' + lib.optionalString cudaSupport ''
       build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
       build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@@ -163,7 +173,7 @@ let
     # Copy-paste from TF derivation.
     # Most of these are not really used in jaxlib compilation but it's simpler to keep it
     # 'as is' so that it's more compatible with TF derivation.
-    TF_SYSTEM_LIBS = lib.concatStringsSep "," [
+    TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
       "absl_py"
       "astor_archive"
       "astunparse_archive"
@@ -179,7 +189,6 @@ let
       "cython"
       "dill_archive"
       "double_conversion"
-      "enum34_archive"
       "flatbuffers"
       "functools32_archive"
       "gast_archive"
@@ -190,11 +199,9 @@ let
       "libjpeg_turbo"
       "lmdb"
       "nasm"
-      # "nsync" # not packaged in nixpkgs
       "opt_einsum_archive"
       "org_sqlite"
       "pasta"
-      "pcre"
       "png"
       "pybind11"
       "six_archive"
@@ -204,7 +211,9 @@ let
       "typing_extensions_archive"
       "wrapt"
       "zlib"
-    ];
+    ] ++ lib.optionals (!stdenv.isDarwin) [
+      "nsync" # fails to build on darwin
+    ]);
 
     # Make sure Bazel knows about our configuration flags during fetching so that the
     # relevant dependencies can be downloaded.
@@ -226,9 +235,11 @@ let
     fetchAttrs = {
       sha256 =
         if cudaSupport then
-          "sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo="
+          "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk="
+        else if stdenv.isDarwin then
+          "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34="
         else
-          "sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo=";
+          "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A=";
     };
 
     buildAttrs = {
@@ -239,15 +250,10 @@ let
       # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
       #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
       # 3) Patch python path in the compiler driver.
-      # 4) Patch tensorflow sources to work with later versions of protobuf. See
-      #    https://github.com/google/jax/issues/9534. Note that this should be
-      #    removed on the next release after 0.3.0.
       preBuild = ''
-        for src in ./jaxlib/*.{cc,h}; do
+        for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
           sed -i 's@include/pybind11@pybind11@g' $src
         done
-        substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
-          --replace "status.message()" "std::string{status.message()}"
       '' + lib.optionalString cudaSupport ''
         patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
       '' + lib.optionalString stdenv.isDarwin ''
@@ -275,7 +281,7 @@ let
   };
   platformTag =
     if stdenv.targetPlatform.isLinux then
-      "manylinux2010_${stdenv.targetPlatform.linuxArch}"
+      "manylinux2014_${stdenv.targetPlatform.linuxArch}"
     else if stdenv.system == "x86_64-darwin" then
       "macosx_10_9_${stdenv.targetPlatform.linuxArch}"
     else if stdenv.system == "aarch64-darwin" then
@@ -306,6 +312,7 @@ buildPythonPackage {
 
   propagatedBuildInputs = [
     absl-py
+    curl
     double-conversion
     flatbuffers
     giflib