summary refs log tree commit diff
path: root/pkgs/development/libraries/science/math/cudnn/extension.nix
blob: 81e4b76f21c6b021b9acb33fc492f56b39dfb30b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# cuDNN overlay: adds one `cudnn_X_Y_Z` attribute per cuDNN release that is
# compatible with the package set's `cudaVersion`, plus `cudnn`, which points
# at the newest compatible release.
#
# Support matrix can be found at
# https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-880/support-matrix/index.html
#
# Type aliases
# Release = {
#   version: String,
#   minCudaVersion: String,
#   maxCudaVersion: String,
#   url: String,
#   hash: String,
# }
final: prev: let
  inherit (final) callPackage;
  inherit (prev) cudaVersion;
  inherit (prev.lib) attrsets lists versions strings trivial;

  # The first three components of a version string, as split by
  # `versions.splitVersion`; e.g. "8.6.0.163" -> [ "8" "6" "0" ].
  # versionPrefix :: String -> List String
  versionPrefix = version: lists.take 3 (versions.splitVersion version);

  # Attribute name under which a release is exposed in this package set,
  # e.g. "8.6.0.163" -> "cudnn_8_6_0".
  # computeName :: String -> String
  computeName = version: "cudnn_" + strings.concatStringsSep "_" (versionPrefix version);

  # A release is usable when cudaVersion lies in the inclusive range
  # [minCudaVersion, maxCudaVersion]. `strings.versionAtLeast` compares using
  # Nix version ordering, not plain lexicographic string order, so for
  # example "11.10" counts as newer than "11.9".
  # isSupported :: Release -> Bool
  isSupported = release: let
    atLeastMin = strings.versionAtLeast cudaVersion release.minCudaVersion;
    atMostMax = strings.versionAtLeast release.maxCudaVersion cudaVersion;
  in
    atLeastMin && atMostMax;

  # True when cudaVersion is older than "11.3.999", i.e. (presumably) for
  # CUDA 11.3.x and earlier. Only forwarded to ./generic.nix here.
  # useCudatoolkitRunfile :: Bool
  useCudatoolkitRunfile = strings.versionOlder cudaVersion "11.3.999";

  # Builder shared by every release: ./generic.nix is instantiated once via
  # callPackage, and the result is then applied to each Release.
  # buildCuDnnPackage :: Release -> Derivation
  buildCuDnnPackage = callPackage ./generic.nix {inherit useCudatoolkitRunfile;};

  # All known releases from ./releases.nix, reversed so the newest comes
  # first (assumes releases.nix lists them oldest-first — TODO confirm).
  # cudnnReleases :: List Release
  cudnnReleases = lists.reverseList (import ./releases.nix);

  # Releases compatible with cudaVersion, in cudnnReleases order. Throws when
  # none are compatible, so any successful evaluation yields a non-empty list.
  # supportedReleases :: NonEmptyList Release
  supportedReleases = let
    compatible = builtins.filter isSupported cudnnReleases;
  in
    trivial.throwIf (compatible == [])
    ''
      CUDNN does not support your cuda version ${cudaVersion}
    ''
    compatible;

  # Pair exposing one release under its computed attribute name.
  # toBuildAttrs :: Release -> { name: String, value: Derivation }
  toBuildAttrs = release:
    attrsets.nameValuePair (computeName release.version) (buildCuDnnPackage release);

  # One attribute per supported release. If two releases map to the same
  # name, listToAttrs keeps the first occurrence, i.e. the one earlier in
  # supportedReleases (the newer one, given the newest-first order above).
  # allBuilds :: AttrSet String Derivation
  allBuilds = builtins.listToAttrs (builtins.map toBuildAttrs supportedReleases);

  # Name of the first (newest) supported release; it becomes `cudnn`.
  # latestReleaseName :: String
  latestReleaseName = let
    latest = builtins.head supportedReleases;
  in
    computeName latest.version;
in
  # Every versioned build, plus `cudnn` aliasing the newest supported one.
  allBuilds // {cudnn = allBuilds.${latestReleaseName};}