summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlexander Kiselyov <aleksandr.kiselyov@gmail.com>2021-08-08 20:42:58 +0300
committerGitHub <noreply@github.com>2021-08-08 19:42:58 +0200
commit717538e9082a42d97dce9f53740045a73865b18e (patch)
treead47d66d3797fc49b06fd0aa86e863d45902e7b5
parent0d078fcdb23c98086b0c46d1e6c9b55e52892dab (diff)
downloadnixpkgs-717538e9082a42d97dce9f53740045a73865b18e.tar
nixpkgs-717538e9082a42d97dce9f53740045a73865b18e.tar.gz
nixpkgs-717538e9082a42d97dce9f53740045a73865b18e.tar.bz2
nixpkgs-717538e9082a42d97dce9f53740045a73865b18e.tar.lz
nixpkgs-717538e9082a42d97dce9f53740045a73865b18e.tar.xz
nixpkgs-717538e9082a42d97dce9f53740045a73865b18e.tar.zst
nixpkgs-717538e9082a42d97dce9f53740045a73865b18e.zip
python3Packages.torchvision: added cudaSupport option (#132917)
Co-authored-by: Sandro <sandro.jaeckel@gmail.com>
-rw-r--r--pkgs/development/python-modules/pytorch/default.nix5
-rw-r--r--pkgs/development/python-modules/torchvision/default.nix24
2 files changed, 26 insertions, 3 deletions
diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix
index 35eb79d8b2d..0de0015ab1e 100644
--- a/pkgs/development/python-modules/pytorch/default.nix
+++ b/pkgs/development/python-modules/pytorch/default.nix
@@ -301,6 +301,11 @@ in buildPythonPackage rec {
   # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
   requiredSystemFeatures = [ "big-parallel" ];
 
+  passthru = {
+    inherit cudaSupport;
+    cudaArchList = final_cudaArchList;
+  };
+
   meta = with lib; {
     description = "Open source, prototype-to-production deep learning platform";
     homepage    = "https://pytorch.org/";
diff --git a/pkgs/development/python-modules/torchvision/default.nix b/pkgs/development/python-modules/torchvision/default.nix
index a42c517ede9..fc9905881cb 100644
--- a/pkgs/development/python-modules/torchvision/default.nix
+++ b/pkgs/development/python-modules/torchvision/default.nix
@@ -1,4 +1,5 @@
 { lib
+, symlinkJoin
 , buildPythonPackage
 , fetchFromGitHub
 , ninja
@@ -10,9 +11,18 @@
 , pillow
 , pytorch
 , pytest
+, cudatoolkit
+, cudnn
+, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch
 }:
 
-buildPythonPackage rec {
+let
+  cudatoolkit_joined = symlinkJoin {
+    name = "${cudatoolkit.name}-unsplit";
+    paths = [ cudatoolkit.out cudatoolkit.lib ];
+  };
+  cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList;
+in buildPythonPackage rec {
   pname = "torchvision";
   version = "0.10.0";
 
@@ -23,15 +33,22 @@ buildPythonPackage rec {
     sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc";
   };
 
-  nativeBuildInputs = [ libpng ninja which ];
+  nativeBuildInputs = [ libpng ninja which ]
+    ++ lib.optionals cudaSupport [ cudatoolkit_joined ];
 
   TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
   TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";
 
-  buildInputs = [ libjpeg_turbo libpng ];
+  buildInputs = [ libjpeg_turbo libpng ]
+    ++ lib.optionals cudaSupport [ cudnn ];
 
   propagatedBuildInputs = [ numpy pillow pytorch scipy ];
 
+  preBuild = lib.optionalString cudaSupport ''
+    export TORCH_CUDA_ARCH_LIST="${cudaArchStr}"
+    export FORCE_CUDA=1
+  '';
+
   # tries to download many datasets for tests
   doCheck = false;
 
@@ -45,6 +62,7 @@ buildPythonPackage rec {
     description = "PyTorch vision library";
     homepage = "https://pytorch.org/";
     license = licenses.bsd3;
+    platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin;
     maintainers = with maintainers; [ ericsagnes ];
   };
 }