summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix6
1 files changed, 6 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index 664e109719a..363bfe56134 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -259,7 +259,13 @@ buildPythonPackage {
 
   src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
 
+  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
+  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
+  # more info.
   postInstall = lib.optionalString cudaSupport ''
+    mkdir -p $out/bin
+    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
+
     find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
       addOpenGLRunpath "$lib"
       patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"