summary refs log tree commit diff
path: root/pkgs/development/python-modules/jax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jax/default.nix')
-rw-r--r--pkgs/development/python-modules/jax/default.nix26
1 files changed, 22 insertions, 4 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix
index a92148ada6b..203aa49db8f 100644
--- a/pkgs/development/python-modules/jax/default.nix
+++ b/pkgs/development/python-modules/jax/default.nix
@@ -12,6 +12,7 @@
 , numpy
 , opt-einsum
 , pytestCheckHook
+, pytest-xdist
 , pythonOlder
 , scipy
 , stdenv
@@ -26,17 +27,17 @@ let
 in
 buildPythonPackage rec {
   pname = "jax";
-  version = "0.4.16";
-  format = "pyproject";
+  version = "0.4.20";
+  pyproject = true;
 
   disabled = pythonOlder "3.9";
 
   src = fetchFromGitHub {
     owner = "google";
-    repo = pname;
+    repo = "jax";
     # google/jax contains tags for jax and jaxlib. Only use jax tags!
     rev = "refs/tags/${pname}-v${version}";
-    hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg=";
+    hash = "sha256-WLYXUtchOaA6SGnKuVhN9CmV06xMCLQTEuEtL13ttZU=";
   };
 
   nativeBuildInputs = [
@@ -61,13 +62,18 @@ buildPythonPackage rec {
     jaxlib'
     matplotlib
     pytestCheckHook
+    pytest-xdist
   ];
 
+  # high parallelism will result in the tests getting stuck
+  dontUsePytestXdist = true;
+
   # NOTE: Don't run the tests in the expiremental directory as they require flax
   # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
   # Not a big deal, this is how the JAX docs suggest running the test suite
   # anyhow.
   pytestFlagsArray = [
+    "--numprocesses=4"
     "-W ignore::DeprecationWarning"
     "tests/"
   ];
@@ -94,6 +100,18 @@ buildPythonPackage rec {
     "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
     "testQdwhWithRandomMatrix3"
     "testScanGrad_jit_scan"
+
+    # See https://github.com/google/jax/issues/17867.
+    "test_array"
+    "test_async"
+    "test_copy0"
+    "test_device_put"
+    "test_make_array_from_callback"
+    "test_make_array_from_single_device_arrays"
+
+    # Fails on some hardware due to some numerical error
+    # See https://github.com/google/jax/issues/18535
+    "testQdwhWithOnRankDeficientInput5"
   ];
 
   disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [