summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Reilly <OmnipotentEntity@gmail.com>2023-01-11 13:05:46 -0600
committerMichael Reilly <OmnipotentEntity@gmail.com>2023-03-21 16:42:57 -0500
commit98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5 (patch)
treef58848cac147f9e6e00ebc5bccf73bbb63919b4e
parentdcb5b5500e86d59bb7d790fd05cc374a16ded610 (diff)
downloadnixpkgs-98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5.tar
nixpkgs-98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5.tar.gz
nixpkgs-98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5.tar.bz2
nixpkgs-98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5.tar.lz
nixpkgs-98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5.tar.xz
nixpkgs-98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5.tar.zst
nixpkgs-98df16e2417ddfc1ae12a43bd90b47f6e2e5e2e5.zip
katago: 1.11.0 -> 1.12.4
Added TensorRT backend, refactored backend selection, IMPORTANT also
bumps TensorRT version to 8.5.2 for CUDA 11.7 and 11.8.  (Per Nvidia's
documentation, 8.4.0.6 does not officially support 11.7, but we're using
for it.  Additionally, KataGo requires 8.5 or better for the new
TensorRT backend.  8.4.0.6 remains the default for 11.6 and lower, in
interest of not messing things up as much as possible.)

New version also adds support for nested residual block networks, such as the
b18 network that was used in the UEC.
-rw-r--r--pkgs/development/libraries/science/math/tensorrt/extension.nix78
-rw-r--r--pkgs/development/libraries/science/math/tensorrt/generic.nix4
-rw-r--r--pkgs/games/katago/default.nix36
-rw-r--r--pkgs/top-level/all-packages.nix8
4 files changed, 102 insertions, 24 deletions
diff --git a/pkgs/development/libraries/science/math/tensorrt/extension.nix b/pkgs/development/libraries/science/math/tensorrt/extension.nix
index b4f47a8969c..57e71c4928a 100644
--- a/pkgs/development/libraries/science/math/tensorrt/extension.nix
+++ b/pkgs/development/libraries/science/math/tensorrt/extension.nix
@@ -29,22 +29,91 @@ final: prev: let
       else throw "tensorrt-${tensorRTDefaultVersion} does not support your cuda version ${cudaVersion}"; };
   in allBuilds // defaultBuild;
 
+  tarballURL =
+  {fullVersion, fileVersionCuda, fileVersionCudnn ? null} :
+    "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}"
+    + lib.optionalString (fileVersionCudnn != null) ".cudnn${fileVersionCudnn}"
+    + ".tar.gz";
+
   tensorRTVersions = {
+    "8.6.0" = [
+      rec {
+        fileVersionCuda = "11.8";
+        fullVersion = "8.6.0.12";
+        sha256 = "sha256-wXMqEJPFerefoLaH8GG+Np5EnJwXeStmDzZj7Nj6e2M=";
+        tarball = tarballURL { inherit fileVersionCuda fullVersion; };
+        supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
+      }
+    ];
+    "8.5.3" = [
+      rec {
+        fileVersionCuda = "11.8";
+        fileVersionCudnn = "8.6";
+        fullVersion = "8.5.3.1";
+        sha256 = "sha256-BNeuOYvPTUAfGxI0DVsNrX6Z/FAB28+SE0ptuGu7YDY=";
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
+        supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
+      }
+      rec {
+        fileVersionCuda = "10.2";
+        fileVersionCudnn = "8.6";
+        fullVersion = "8.5.3.1";
+        sha256 = "sha256-WCt6yfOmFbrjqdYCj6AE2+s2uFpISwk6urP+2I0BnGQ=";
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
+        supportedCudaVersions = [ "10.2" ];
+      }
+    ];
+    "8.5.2" = [
+      rec {
+        fileVersionCuda = "11.8";
+        fileVersionCudnn = "8.6";
+        fullVersion = "8.5.2.2";
+        sha256 = "sha256-Ov5irNS/JETpEz01FIFNMs9YVmjGHL7lSXmDpgCdgao=";
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
+        supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
+      }
+      rec {
+        fileVersionCuda = "10.2";
+        fileVersionCudnn = "8.6";
+        fullVersion = "8.5.2.2";
+        sha256 = "sha256-UruwQShYcHLY5d81lKNG7XaoUsZr245c+PUpUN6pC5E=";
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
+        supportedCudaVersions = [ "10.2" ];
+      }
+    ];
+    "8.5.1" = [
+      rec {
+        fileVersionCuda = "11.8";
+        fileVersionCudnn = "8.6";
+        fullVersion = "8.5.1.7";
+        sha256 = "sha256-Ocx/B3BX0TY3lOj/UcTPIaXb7M8RFrACC6Da4PMGMHY=";
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
+        supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" "11.8" ];
+      }
+      rec {
+        fileVersionCuda = "10.2";
+        fileVersionCudnn = "8.6";
+        fullVersion = "8.5.1.7";
+        sha256 = "sha256-CcFGJhw7nFdPnSYYSxcto2MHK3F84nLQlJYjdIw8dPM=";
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
+        supportedCudaVersions = [ "10.2" ];
+      }
+    ];
     "8.4.0" = [
       rec {
         fileVersionCuda = "11.6";
         fileVersionCudnn = "8.3";
         fullVersion = "8.4.0.6";
         sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
-        tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
-        supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" ];
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
+        supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" ];
       }
       rec {
         fileVersionCuda = "10.2";
         fileVersionCudnn = "8.3";
         fullVersion = "8.4.0.6";
         sha256 = "sha256-aCzH0ZI6BrJ0v+e5Bnm7b8mNltA7NNuIa8qRKzAQv+I=";
-        tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
+        tarball = tarballURL { inherit fileVersionCuda fileVersionCudnn fullVersion; };
         supportedCudaVersions = [ "10.2" ];
       }
     ];
@@ -60,7 +129,8 @@ final: prev: let
     "11.4" = "8.4.0";
     "11.5" = "8.4.0";
     "11.6" = "8.4.0";
-    "11.7" = "8.4.0";
+    "11.7" = "8.5.3";
+    "11.8" = "8.5.3";
   }.${cudaVersion} or "8.4.0";
 
 in tensorRTPackages
diff --git a/pkgs/development/libraries/science/math/tensorrt/generic.nix b/pkgs/development/libraries/science/math/tensorrt/generic.nix
index 3447087051f..492fde77e51 100644
--- a/pkgs/development/libraries/science/math/tensorrt/generic.nix
+++ b/pkgs/development/libraries/science/math/tensorrt/generic.nix
@@ -9,13 +9,13 @@
 }:
 
 { fullVersion
-, fileVersionCudnn
+, fileVersionCudnn ? null
 , tarball
 , sha256
 , supportedCudaVersions ? [ ]
 }:
 
-assert lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
+assert fileVersionCudnn == null || lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
   "This version of TensorRT requires at least cuDNN ${fileVersionCudnn} (current version is ${cudnn.version})";
 
 stdenv.mkDerivation rec {
diff --git a/pkgs/games/katago/default.nix b/pkgs/games/katago/default.nix
index a3d7ed5cba9..bc3ac458de4 100644
--- a/pkgs/games/katago/default.nix
+++ b/pkgs/games/katago/default.nix
@@ -14,28 +14,26 @@
 , openssl
 , writeShellScriptBin
 , enableAVX2 ? stdenv.hostPlatform.avx2Support
+, backend ? "opencl"
 , enableBigBoards ? false
-, enableCuda ? false
 , enableContrib ? false
-, enableGPU ? true
 , enableTcmalloc ? true
 }:
 
-assert !enableGPU -> (
-  !enableCuda);
+assert lib.assertOneOf "backend" backend [ "opencl" "cuda" "tensorrt" "eigen" ];
 
 # N.b. older versions of cuda toolkit (e.g. 10) do not support newer versions
 # of gcc.  If you need to use cuda10, please override stdenv with gcc8Stdenv
 stdenv.mkDerivation rec {
   pname = "katago";
-  version = "1.11.0";
-  githash = "d8d0cd76cf73df08af3d7061a639488ae9494419";
+  version = "1.12.4";
+  githash = "75280bf26582090dd4985dca62bc7124116c856d";
 
   src = fetchFromGitHub {
     owner = "lightvector";
     repo = "katago";
     rev = "v${version}";
-    sha256 = "sha256-TZKkkYe2PPzgPhItBZBSJDwU3anhsujuCGIYru55OtU=";
+    sha256 = "sha256-1rznAxEFJ/Ah5/WiSwc+rtITOUOPYOua5BLKeqHOBr0=";
   };
 
   fakegit = writeShellScriptBin "git" "echo ${githash}";
@@ -48,13 +46,17 @@ stdenv.mkDerivation rec {
   buildInputs = [
     libzip
     boost
-  ] ++ lib.optionals (!enableGPU) [
+  ] ++ lib.optionals (backend == "eigen") [
     eigen
-  ] ++ lib.optionals (enableGPU && enableCuda) [
+  ] ++ lib.optionals (backend == "cuda") [
     cudaPackages.cudnn
     cudaPackages.cudatoolkit
     mesa.drivers
-  ] ++ lib.optionals (enableGPU && !enableCuda) [
+  ] ++ lib.optionals (backend == "tensorrt") [
+      cudaPackages.cudatoolkit
+      cudaPackages.tensorrt
+      mesa.drivers
+  ] ++ lib.optionals (backend == "opencl") [
     opencl-headers
     ocl-icd
   ] ++ lib.optionals enableContrib [
@@ -65,13 +67,15 @@ stdenv.mkDerivation rec {
 
   cmakeFlags = [
     "-DNO_GIT_REVISION=ON"
-  ] ++ lib.optionals (!enableGPU) [
-    "-DUSE_BACKEND=EIGEN"
   ] ++ lib.optionals enableAVX2 [
     "-DUSE_AVX2=ON"
-  ] ++ lib.optionals (enableGPU && enableCuda) [
+  ] ++ lib.optionals (backend == "eigen") [
+    "-DUSE_BACKEND=EIGEN"
+  ] ++ lib.optionals (backend == "cuda") [
     "-DUSE_BACKEND=CUDA"
-  ] ++ lib.optionals (enableGPU && !enableCuda) [
+  ] ++ lib.optionals (backend == "tensorrt") [
+    "-DUSE_BACKEND=TENSORRT"
+  ] ++ lib.optionals (backend == "opencl") [
     "-DUSE_BACKEND=OPENCL"
   ] ++ lib.optionals enableContrib [
     "-DBUILD_DISTRIBUTED=1"
@@ -85,7 +89,7 @@ stdenv.mkDerivation rec {
 
   preConfigure = ''
     cd cpp/
-  '' + lib.optionalString enableCuda ''
+  '' + lib.optionalString (backend == "cuda" || backend == "tensorrt") ''
     export CUDA_PATH="${cudaPackages.cudatoolkit}"
     export EXTRA_LDFLAGS="-L/run/opengl-driver/lib"
   '';
@@ -93,7 +97,7 @@ stdenv.mkDerivation rec {
   installPhase = ''
     runHook preInstall
     mkdir -p $out/bin; cp katago $out/bin;
-  '' + lib.optionalString enableCuda ''
+  '' + lib.optionalString (backend == "cuda" || backend == "tensorrt") ''
     wrapProgram $out/bin/katago \
       --prefix LD_LIBRARY_PATH : "/run/opengl-driver/lib"
   '' + ''
diff --git a/pkgs/top-level/all-packages.nix b/pkgs/top-level/all-packages.nix
index 27359d4df42..32343297b12 100644
--- a/pkgs/top-level/all-packages.nix
+++ b/pkgs/top-level/all-packages.nix
@@ -35399,11 +35399,15 @@ with pkgs;
   katago = callPackage ../games/katago { };
 
   katagoWithCuda = katago.override {
-    enableCuda = true;
+    backend = "cuda";
   };
 
   katagoCPU = katago.override {
-    enableGPU = false;
+    backend = "eigen";
+  };
+
+  katagoTensorRT = katago.override {
+    backend = "tensorrt";
   };
 
   klavaro = callPackage ../games/klavaro { };