diff options
Diffstat (limited to 'pkgs/development/python-modules/openai-triton/default.nix')
-rw-r--r-- | pkgs/development/python-modules/openai-triton/default.nix | 229 |
1 files changed, 80 insertions, 149 deletions
diff --git a/pkgs/development/python-modules/openai-triton/default.nix b/pkgs/development/python-modules/openai-triton/default.nix index 0e10642f069..e1ac9cb4cef 100644 --- a/pkgs/development/python-modules/openai-triton/default.nix +++ b/pkgs/development/python-modules/openai-triton/default.nix @@ -1,32 +1,28 @@ { lib +, config , buildPythonPackage -, python -, fetchpatch , fetchFromGitHub , addOpenGLRunpath +, pytestCheckHook +, pythonRelaxDepsHook +, pkgsTargetTarget , cmake -, cudaPackages -, llvmPackages +, ninja , pybind11 , gtest , zlib , ncurses , libxml2 , lit +, llvm , filelock , torchWithRocm -, pytest -, pytestCheckHook -, pythonRelaxDepsHook -, pkgsTargetTarget +, python +, cudaPackages +, cudaSupport ? config.cudaSupport }: let - pname = "triton"; - version = "2.0.0"; - - inherit (cudaPackages) cuda_cudart backendStdenv; - # A time may come we'll want to be cross-friendly # # Short explanation: we need pkgsTargetTarget, because we use string @@ -38,20 +34,11 @@ let # pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to # be executed on the GPU. # Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra - ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; - - llvm = (llvmPackages.llvm.override { - llvmTargetsToBuild = [ "NATIVE" "NVPTX" ]; - # Upstream CI sets these too: - # targetProjects = [ "mlir" ]; - extraCMakeFlags = [ - "-DLLVM_INSTALL_UTILS=ON" - ]; - }); + ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py) in -buildPythonPackage { - inherit pname version; - +buildPythonPackage rec { + pname = "triton"; + version = "2.0.0"; format = "setuptools"; src = fetchFromGitHub { @@ -62,21 +49,6 @@ buildPythonPackage { }; patches = [ - # Prerequisite for llvm15 patch - (fetchpatch { - url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch"; - hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk="; - }) - (fetchpatch { - url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch"; - hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg="; - }) - - # Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch - # The original patch adds ptxas binary, so we include our own clean copy - # Drop with the next update - ./llvm15.patch - # TODO: there have been commits upstream aimed at removing the "torch" # circular dependency, but the patches fail to apply on the release # revision. Keeping the link for future reference @@ -86,72 +58,15 @@ buildPythonPackage { # url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch"; # hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig="; # }) + ] ++ lib.optionals (!cudaSupport) [ + ./0000-dont-download-ptxas.patch ]; - postPatch = '' - substituteInPlace python/setup.py \ - --replace \ - '= get_thirdparty_packages(triton_cache_path)' \ - '= os.environ["cmakeFlags"].split()' - '' - # Wiring triton=2.0.0 with llcmPackages_rocm.llvm=5.4.3 - # Revisit when updating either triton or llvm - + '' - substituteInPlace CMakeLists.txt \ - --replace "nvptx" "NVPTX" \ - --replace "LLVM 11" "LLVM" - sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt - sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt - find -iname '*.td' -exec \ - sed -i \ - -e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \ - -e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \ - '{}' ';' - substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)" - sed -i 's/^include.*$//' unittest/CMakeLists.txt - sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt - sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt - '' - # TritonMLIRIR already links MLIRIR. Not transitive? - # + '' - # echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt - # '' - # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS - + '' - substituteInPlace bin/CMakeLists.txt \ - --replace "add_subdirectory(FileCheck)" "" - - rm cmake/FindLLVM.cmake - '' - + - ( - let - # Bash was getting weird without linting, - # but basically upstream contains [cc, ..., "-lcuda", ...] - # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...] - old = [ "-lcuda" ]; - new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ]; - - quote = x: ''"${x}"''; - oldStr = lib.concatMapStringsSep ", " quote old; - newStr = lib.concatMapStringsSep ", " quote new; - in - '' - substituteInPlace python/triton/compiler.py \ - --replace '${oldStr}' '${newStr}' - '' - ) - # Triton seems to be looking up cuda.h - + '' - sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py - ''; - nativeBuildInputs = [ - cmake pythonRelaxDepsHook - - # Requires torch (circular dependency) and probably needs GPUs: - # pytestCheckHook + # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs: + cmake + ninja # Note for future: # These *probably* should go in depsTargetTarget @@ -159,7 +74,6 @@ buildPythonPackage { # because we only support cudaPackages on x86_64-linux atm lit llvm - llvmPackages.mlir ]; buildInputs = [ @@ -170,17 +84,41 @@ buildPythonPackage { zlib ]; - propagatedBuildInputs = [ - filelock - ]; + propagatedBuildInputs = [ filelock ]; + + postPatch = let + # Bash was getting weird without linting, + # but basically upstream contains [cc, ..., "-lcuda", ...] + # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...] + old = [ "-lcuda" ]; + new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cudaPackages.cuda_cudart}/lib/stubs/" ]; + + quote = x: ''"${x}"''; + oldStr = lib.concatMapStringsSep ", " quote old; + newStr = lib.concatMapStringsSep ", " quote new; + in '' + # Use our `cmakeFlags` instead and avoid downloading dependencies + substituteInPlace python/setup.py \ + --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()" + + # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS + substituteInPlace bin/CMakeLists.txt \ + --replace "add_subdirectory(FileCheck)" "" + + # Don't fetch googletest + substituteInPlace unittest/CMakeLists.txt \ + --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\ + --replace "include(GoogleTest)" "find_package(GTest REQUIRED)" + '' + lib.optionalString cudaSupport '' + # Use our linker flags + substituteInPlace python/triton/compiler.py \ + --replace '${oldStr}' '${newStr}' + ''; # Avoid GLIBCXX mismatch with other cuda-enabled python packages preConfigure = '' - export CC="${backendStdenv.cc}/bin/cc"; - export CXX="${backendStdenv.cc}/bin/c++"; - # Upstream's setup.py tries to write cache somewhere in ~/ - export HOME=$TMPDIR + export HOME=$(mktemp -d) # Upstream's github actions patch setup.cfg to write base-dir. May be redundant echo " @@ -188,52 +126,44 @@ buildPythonPackage { base-dir=$PWD" >> python/setup.cfg # The rest (including buildPhase) is relative to ./python/ - cd python/ + cd python + '' + lib.optionalString cudaSupport '' + export CC=${cudaPackages.backendStdenv.cc}/bin/cc; + export CXX=${cudaPackages.backendStdenv.cc}/bin/c++; # Work around download_and_copy_ptxas() - dst_cuda="$PWD/triton/third_party/cuda/bin" - mkdir -p "$dst_cuda" - ln -s "${ptxas}" "$dst_cuda/" + mkdir -p $PWD/triton/third_party/cuda/bin + ln -s ${ptxas} $PWD/triton/third_party/cuda/bin ''; # CMake is run by setup.py instead dontUseCmakeConfigure = true; - cmakeFlags = [ - "-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir" - ]; - postFixup = - let - ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas"; - in - # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink - '' - rm -f ${ptxasDestination} - ln -s ${ptxas} ${ptxasDestination} - ''; - - checkInputs = [ - cmake # ctest - ]; + # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink + postFixup = lib.optionalString cudaSupport '' + rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas + ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas + ''; + + checkInputs = [ cmake ]; # ctest dontUseSetuptoolsCheck = true; - preCheck = + + preCheck = '' # build/temp* refers to build_ext.build_temp (looked up in the build logs) - '' - (cd /build/source/python/build/temp* ; ctest) - '' # For pytestCheckHook - + '' - cd test/unit - ''; - pythonImportsCheck = [ - # Circular dependency on torch - # "triton" - # "triton.language" - ]; + (cd /build/source/python/build/temp* ; ctest) + + # For pytestCheckHook + cd test/unit + ''; + + # Circular dependency on torch + # pythonImportsCheck = [ + # "triton" + # "triton.language" + # ]; # Ultimately, torch is our test suite: - passthru.tests = { - inherit torchWithRocm; - }; + passthru.tests = { inherit torchWithRocm; }; pythonRemoveDeps = [ # Circular dependency, cf. https://github.com/openai/triton/issues/1374 @@ -243,11 +173,12 @@ buildPythonPackage { "cmake" "lit" ]; + meta = with lib; { - description = "Development repository for the Triton language and compiler"; - homepage = "https://github.com/openai/triton/"; + description = "Language and compiler for writing highly efficient custom Deep-Learning primitives"; + homepage = "https://github.com/openai/triton"; platforms = lib.platforms.unix; license = licenses.mit; - maintainers = with maintainers; [ SomeoneSerge ]; + maintainers = with maintainers; [ SomeoneSerge Madouura ]; }; } |