summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--pkgs/test/cuda/cuda-library-samples/default.nix7
-rw-r--r--pkgs/test/cuda/cuda-library-samples/generic.nix19
2 files changed, 26 insertions, 0 deletions
diff --git a/pkgs/test/cuda/cuda-library-samples/default.nix b/pkgs/test/cuda/cuda-library-samples/default.nix
index 501828c9a1f..91095fbd3ac 100644
--- a/pkgs/test/cuda/cuda-library-samples/default.nix
+++ b/pkgs/test/cuda/cuda-library-samples/default.nix
@@ -1,16 +1,20 @@
 { callPackage
 , cudatoolkit_10_1, cudatoolkit_10_2
 , cudatoolkit_11_0, cudatoolkit_11_1, cudatoolkit_11_2
+, cutensor_cudatoolkit_10_1, cutensor_cudatoolkit_10_2
+, cutensor_cudatoolkit_11_0, cutensor_cudatoolkit_11_1, cutensor_cudatoolkit_11_2
 }:
 
 rec {
 
   cuda-library-samples_cudatoolkit_10_1 = callPackage ./generic.nix {
     cudatoolkit = cudatoolkit_10_1;
+    cutensor_cudatoolkit = cutensor_cudatoolkit_10_1;
   };
 
   cuda-library-samples_cudatoolkit_10_2 = callPackage ./generic.nix {
     cudatoolkit = cudatoolkit_10_2;
+    cutensor_cudatoolkit = cutensor_cudatoolkit_10_2;
   };
 
   cuda-library-samples_cudatoolkit_10 =
@@ -20,14 +24,17 @@ rec {
 
   cuda-library-samples_cudatoolkit_11_0 = callPackage ./generic.nix {
     cudatoolkit = cudatoolkit_11_0;
+    cutensor_cudatoolkit = cutensor_cudatoolkit_11_0;
   };
 
   cuda-library-samples_cudatoolkit_11_1 = callPackage ./generic.nix {
     cudatoolkit = cudatoolkit_11_1;
+    cutensor_cudatoolkit = cutensor_cudatoolkit_11_1;
   };
 
   cuda-library-samples_cudatoolkit_11_2 = callPackage ./generic.nix {
     cudatoolkit = cudatoolkit_11_2;
+    cutensor_cudatoolkit = cutensor_cudatoolkit_11_2;
   };
 
   cuda-library-samples_cudatoolkit_11 =
diff --git a/pkgs/test/cuda/cuda-library-samples/generic.nix b/pkgs/test/cuda/cuda-library-samples/generic.nix
index 75d4541d986..999e2abd041 100644
--- a/pkgs/test/cuda/cuda-library-samples/generic.nix
+++ b/pkgs/test/cuda/cuda-library-samples/generic.nix
@@ -1,6 +1,7 @@
 { lib, stdenv, fetchFromGitHub
 , cmake, addOpenGLRunpath
 , cudatoolkit
+, cutensor_cudatoolkit
 }:
 
 let
@@ -48,4 +49,22 @@ in
 
     sourceRoot = "cuSOLVER/gesv";
   });
+
+  cutensor = stdenv.mkDerivation (commonAttrs // {
+    pname = "cuda-library-samples-cutensor";
+
+    src = "${src}/cuTENSOR";
+
+    cmakeFlags = [
+      "-DCUTENSOR_EXAMPLE_BINARY_INSTALL_DIR=${builtins.placeholder "out"}/bin"
+    ];
+
+    # CUTENSOR_ROOT is double escaped
+    postPatch = ''
+      substituteInPlace CMakeLists.txt \
+        --replace "\''${CUTENSOR_ROOT}/include" "${cutensor_cudatoolkit.dev}/include"
+    '';
+
+    CUTENSOR_ROOT = cutensor_cudatoolkit;
+  });
 }