summary refs log tree commit diff
diff options
context:
space:
mode:
authorSamuel Ainsworth <skainsworth@gmail.com>2022-02-25 21:51:29 +0000
committerFrederik Rietdijk <freddyrietdijk@fridh.nl>2022-02-26 07:33:28 +0100
commit33984cd89ca94954353be396631bcb6a9190c0d2 (patch)
treeb96a2f3d10246dbb8352959f23fe00c953c8f29d
parentb398f196e6ce1ca65cb3b0c9472eea71b251e092 (diff)
downloadnixpkgs-33984cd89ca94954353be396631bcb6a9190c0d2.tar
nixpkgs-33984cd89ca94954353be396631bcb6a9190c0d2.tar.gz
nixpkgs-33984cd89ca94954353be396631bcb6a9190c0d2.tar.bz2
nixpkgs-33984cd89ca94954353be396631bcb6a9190c0d2.tar.lz
nixpkgs-33984cd89ca94954353be396631bcb6a9190c0d2.tar.xz
nixpkgs-33984cd89ca94954353be396631bcb6a9190c0d2.tar.zst
nixpkgs-33984cd89ca94954353be396631bcb6a9190c0d2.zip
python3Packages.jax: support MKL BLAS/LAPACK implementations
-rw-r--r--pkgs/development/python-modules/jax/default.nix8
1 files changed, 8 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix
index c91b6e48522..a616a83255b 100644
--- a/pkgs/development/python-modules/jax/default.nix
+++ b/pkgs/development/python-modules/jax/default.nix
@@ -1,8 +1,10 @@
 { lib
 , absl-py
+, blas
 , buildPythonPackage
 , fetchFromGitHub
 , jaxlib
+, lapack
 , numpy
 , opt-einsum
 , pytestCheckHook
@@ -12,6 +14,9 @@
 , typing-extensions
 }:
 
+let
+  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
+in
 buildPythonPackage rec {
   pname = "jax";
   version = "0.3.1";
@@ -59,6 +64,9 @@ buildPythonPackage rec {
     "tests/"
   ];
 
+  # See https://github.com/google/jax/issues/9705.
+  disabledTests = lib.optionals usingMKL [ "test_custom_root_with_aux" ];
+
   pythonImportsCheck = [
     "jax"
   ];