diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/default.nix | 39 |
1 files changed, 23 insertions, 16 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index eee432f7185..456c9108593 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -9,11 +9,14 @@ , buildBazelPackage , buildPythonPackage , cctools +, curl , cython , fetchFromGitHub , git , IOKit , jsoncpp +, nsync +, openssl , pybind11 , setuptools , symlinkJoin @@ -50,7 +53,7 @@ let inherit (cudaPackages) cudatoolkit cudnn nccl; pname = "jaxlib"; - version = "0.3.0"; + version = "0.3.15"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; @@ -93,7 +96,7 @@ let owner = "google"; repo = "jax"; rev = "${pname}-v${version}"; - sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72"; + sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s="; }; nativeBuildInputs = [ @@ -103,15 +106,19 @@ let setuptools wheel which + ] ++ lib.optionals stdenv.isDarwin [ + cctools ]; buildInputs = [ + curl double-conversion giflib grpc jsoncpp libjpeg_turbo numpy + openssl pkgs.flatbuffers pkgs.protobuf pybind11 @@ -124,6 +131,8 @@ let cudnn ] ++ lib.optionals stdenv.isDarwin [ IOKit + ] ++ lib.optionals (!stdenv.isDarwin) [ + nsync ]; postPatch = '' @@ -149,6 +158,7 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false + build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" @@ -163,7 +173,7 @@ let # Copy-paste from TF derivation. # Most of these are not really used in jaxlib compilation but it's simpler to keep it # 'as is' so that it's more compatible with TF derivation. - TF_SYSTEM_LIBS = lib.concatStringsSep "," [ + TF_SYSTEM_LIBS = lib.concatStringsSep "," ([ "absl_py" "astor_archive" "astunparse_archive" @@ -179,7 +189,6 @@ let "cython" "dill_archive" "double_conversion" - "enum34_archive" "flatbuffers" "functools32_archive" "gast_archive" @@ -190,11 +199,9 @@ let "libjpeg_turbo" "lmdb" "nasm" - # "nsync" # not packaged in nixpkgs "opt_einsum_archive" "org_sqlite" "pasta" - "pcre" "png" "pybind11" "six_archive" @@ -204,7 +211,9 @@ let "typing_extensions_archive" "wrapt" "zlib" - ]; + ] ++ lib.optionals (!stdenv.isDarwin) [ + "nsync" # fails to build on darwin + ]); # Make sure Bazel knows about our configuration flags during fetching so that the # relevant dependencies can be downloaded. @@ -226,9 +235,11 @@ let fetchAttrs = { sha256 = if cudaSupport then - "sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo=" + "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk=" + else if stdenv.isDarwin then + "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34=" else - "sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo="; + "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A="; }; buildAttrs = { @@ -239,15 +250,10 @@ let # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. # 3) Patch python path in the compiler driver. - # 4) Patch tensorflow sources to work with later versions of protobuf. See - # https://github.com/google/jax/issues/9534. Note that this should be - # removed on the next release after 0.3.0. preBuild = '' - for src in ./jaxlib/*.{cc,h}; do + for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done - substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \ - --replace "status.message()" "std::string{status.message()}" '' + lib.optionalString cudaSupport '' patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' @@ -275,7 +281,7 @@ let }; platformTag = if stdenv.targetPlatform.isLinux then - "manylinux2010_${stdenv.targetPlatform.linuxArch}" + "manylinux2014_${stdenv.targetPlatform.linuxArch}" else if stdenv.system == "x86_64-darwin" then "macosx_10_9_${stdenv.targetPlatform.linuxArch}" else if stdenv.system == "aarch64-darwin" then @@ -306,6 +312,7 @@ buildPythonPackage { propagatedBuildInputs = [ absl-py + curl double-conversion flatbuffers giflib |