summary refs log tree commit diff
path: root/pkgs/development/libraries/science/math/libtorch/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/libraries/science/math/libtorch/bin.nix')
-rw-r--r--pkgs/development/libraries/science/math/libtorch/bin.nix26
1 files changed, 15 insertions, 11 deletions
diff --git a/pkgs/development/libraries/science/math/libtorch/bin.nix b/pkgs/development/libraries/science/math/libtorch/bin.nix
index 9631f3931ca..72c4e5ac1eb 100644
--- a/pkgs/development/libraries/science/math/libtorch/bin.nix
+++ b/pkgs/development/libraries/science/math/libtorch/bin.nix
@@ -8,10 +8,16 @@
 , fixDarwinDylibNames
 
 , cudaSupport
-, nvidia_x11
+, cudatoolkit_10_2
+, cudnn_cudatoolkit_10_2
 }:
 
 let
+  # The binary libtorch distribution statically links the CUDA
+  # toolkit. This means that we do not need to provide CUDA to
+  # this derivation. However, we should ensure on version bumps
+  # that the CUDA toolkit for `passthru.tests` is still
+  # up-to-date.
   version = "1.7.1";
   device = if cudaSupport then "cuda" else "cpu";
   srcs = import ./binary-hashes.nix version;
@@ -24,12 +30,7 @@ in stdenv.mkDerivation {
 
   nativeBuildInputs =
     if stdenv.isDarwin then [ fixDarwinDylibNames ]
-    else [ addOpenGLRunpath patchelf ]
-      ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
-
-  buildInputs = [
-    stdenv.cc.cc
-  ] ++ lib.optionals cudaSupport [ nvidia_x11 ];
+    else [ patchelf ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
 
   dontBuild = true;
   dontConfigure = true;
@@ -56,9 +57,7 @@ in stdenv.mkDerivation {
   '';
 
   postFixup = let
-    libPaths = [ stdenv.cc.cc.lib ]
-      ++ lib.optionals cudaSupport [ nvidia_x11 ];
-    rpath = lib.makeLibraryPath libPaths;
+    rpath = lib.makeLibraryPath [ stdenv.cc.cc.lib ];
   in lib.optionalString stdenv.isLinux ''
     find $out/lib -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
       echo "setting rpath for $lib..."
@@ -108,12 +107,17 @@ in stdenv.mkDerivation {
 
   outputs = [ "out" "dev" ];
 
-  passthru.tests.cmake = callPackage ./test { };
+  passthru.tests.cmake = callPackage ./test {
+    inherit cudaSupport;
+    cudatoolkit = cudatoolkit_10_2;
+    cudnn = cudnn_cudatoolkit_10_2;
+  };
 
   meta = with lib; {
     description = "C++ API of the PyTorch machine learning framework";
     homepage = "https://pytorch.org/";
     license = licenses.unfree; # Includes CUDA and Intel MKL.
+    maintainers = with maintainers; [ danieldk ];
     platforms = with platforms; linux ++ darwin;
   };
 }