summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/bin.nix12
1 files changed, 9 insertions, 3 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix
index 7e6b00429df..0929831e32a 100644
--- a/pkgs/development/python-modules/jaxlib/bin.nix
+++ b/pkgs/development/python-modules/jaxlib/bin.nix
@@ -120,9 +120,15 @@ buildPythonPackage rec {
     done
   '';
 
-  # pip dependencies and optionally cudatoolkit. Note that cudatoolkit is
-  # necessary since jaxlib looks for "ptxas" in $PATH.
-  propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
+  propagatedBuildInputs = [ absl-py flatbuffers scipy ];
+
+  # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
+  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
+  # more info.
+  postInstall = lib.optional cudaSupport ''
+    mkdir -p $out/bin
+    ln -s ${cudatoolkit_11}/bin/ptxas $out/bin/ptxas
+  '';
 
   pythonImportsCheck = [ "jaxlib" ];