diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/bin.nix | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 7e6b00429df..0929831e32a 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -120,9 +120,15 @@ buildPythonPackage rec { done ''; - # pip dependencies and optionally cudatoolkit. Note that cudatoolkit is - # necessary since jaxlib looks for "ptxas" in $PATH. - propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11; + propagatedBuildInputs = [ absl-py flatbuffers scipy ]; + + # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. + # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for + # more info. + postInstall = lib.optional cudaSupport '' + mkdir -p $out/bin + ln -s ${cudatoolkit_11}/bin/ptxas $out/bin/ptxas + ''; pythonImportsCheck = [ "jaxlib" ]; |