diff options
Diffstat (limited to 'pkgs/development/libraries/science/math/libtorch/bin.nix')
-rw-r--r-- | pkgs/development/libraries/science/math/libtorch/bin.nix | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/pkgs/development/libraries/science/math/libtorch/bin.nix b/pkgs/development/libraries/science/math/libtorch/bin.nix index 9631f3931ca..72c4e5ac1eb 100644 --- a/pkgs/development/libraries/science/math/libtorch/bin.nix +++ b/pkgs/development/libraries/science/math/libtorch/bin.nix @@ -8,10 +8,16 @@ , fixDarwinDylibNames , cudaSupport -, nvidia_x11 +, cudatoolkit_10_2 +, cudnn_cudatoolkit_10_2 }: let + # The binary libtorch distribution statically links the CUDA + # toolkit. This means that we do not need to provide CUDA to + # this derivation. However, we should ensure on version bumps + # that the CUDA toolkit for `passthru.tests` is still + # up-to-date. version = "1.7.1"; device = if cudaSupport then "cuda" else "cpu"; srcs = import ./binary-hashes.nix version; @@ -24,12 +30,7 @@ in stdenv.mkDerivation { nativeBuildInputs = if stdenv.isDarwin then [ fixDarwinDylibNames ] - else [ addOpenGLRunpath patchelf ] - ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; - - buildInputs = [ - stdenv.cc.cc - ] ++ lib.optionals cudaSupport [ nvidia_x11 ]; + else [ patchelf ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; dontBuild = true; dontConfigure = true; @@ -56,9 +57,7 @@ in stdenv.mkDerivation { ''; postFixup = let - libPaths = [ stdenv.cc.cc.lib ] - ++ lib.optionals cudaSupport [ nvidia_x11 ]; - rpath = lib.makeLibraryPath libPaths; + rpath = lib.makeLibraryPath [ stdenv.cc.cc.lib ]; in lib.optionalString stdenv.isLinux '' find $out/lib -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do echo "setting rpath for $lib..." @@ -108,12 +107,17 @@ in stdenv.mkDerivation { outputs = [ "out" "dev" ]; - passthru.tests.cmake = callPackage ./test { }; + passthru.tests.cmake = callPackage ./test { + inherit cudaSupport; + cudatoolkit = cudatoolkit_10_2; + cudnn = cudnn_cudatoolkit_10_2; + }; meta = with lib; { description = "C++ API of the PyTorch machine learning framework"; homepage = "https://pytorch.org/"; license = licenses.unfree; # Includes CUDA and Intel MKL. + maintainers = with maintainers; [ danieldk ]; platforms = with platforms; linux ++ darwin; }; } |