diff options
Diffstat (limited to 'pkgs/development/python-modules/jax/default.nix')
-rw-r--r-- | pkgs/development/python-modules/jax/default.nix | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index a92148ada6b..203aa49db8f 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -12,6 +12,7 @@ , numpy , opt-einsum , pytestCheckHook +, pytest-xdist , pythonOlder , scipy , stdenv @@ -26,17 +27,17 @@ let in buildPythonPackage rec { pname = "jax"; - version = "0.4.16"; - format = "pyproject"; + version = "0.4.20"; + pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; - repo = pname; + repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg="; + hash = "sha256-WLYXUtchOaA6SGnKuVhN9CmV06xMCLQTEuEtL13ttZU="; }; nativeBuildInputs = [ @@ -61,13 +62,18 @@ buildPythonPackage rec { jaxlib' matplotlib pytestCheckHook + pytest-xdist ]; + # high parallelism will result in the tests getting stuck + dontUsePytestXdist = true; + # NOTE: Don't run the tests in the expiremental directory as they require flax # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. # Not a big deal, this is how the JAX docs suggest running the test suite # anyhow. pytestFlagsArray = [ + "--numprocesses=4" "-W ignore::DeprecationWarning" "tests/" ]; @@ -94,6 +100,18 @@ buildPythonPackage rec { "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" "testQdwhWithRandomMatrix3" "testScanGrad_jit_scan" + + # See https://github.com/google/jax/issues/17867. + "test_array" + "test_async" + "test_copy0" + "test_device_put" + "test_make_array_from_callback" + "test_make_array_from_single_device_arrays" + + # Fails on some hardware due to some numerical error + # See https://github.com/google/jax/issues/18535 + "testQdwhWithOnRankDeficientInput5" ]; disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ |