summary refs log tree commit diff
path: root/pkgs/development/python-modules/treex/default.nix
blob: 7ed83adc64d15898df4e4fce2dea1d309c133079 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Nix expression for `treex`, a pytree Module system for deep learning in JAX.
{ lib
, buildPythonPackage
, fetchFromGitHub
  # build tooling
, poetry-core
  # runtime dependencies
, einops
, flax
, jaxlib
, pyyaml
, rich
, treeo
  # test-only dependencies
, cloudpickle
, dm-haiku
, hypothesis
, keras
, pytestCheckHook
, tensorflow
}:

buildPythonPackage rec {
  pname = "treex";
  version = "0.6.9";
  format = "pyproject";

  # Source is the upstream GitHub repository at the revision named after `version`.
  src = fetchFromGitHub {
    owner = "cgarciae";
    repo = "treex";
    rev = version;
    sha256 = "1yvlldmhji12h249j14ba44hnb9x1fhrj7rh1cx2vn0vxj5wpg7x";
  };

  # Turn Poetry caret (^) constraints for rich, PyYAML and optax into plain
  # lower bounds, so versions newer than the caret range are accepted.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace 'rich = "^10.7.0"' 'rich = ">=10.7.0"' \
      --replace 'PyYAML = "^5.4.1"' 'PyYAML = ">=5.4.1"' \
      --replace 'optax = "^0.0.9"' 'optax = ">=0.0.9"'
  '';

  nativeBuildInputs = [ poetry-core ];

  # jaxlib is listed in buildInputs rather than propagatedBuildInputs.
  # NOTE(review): presumably so dependents can choose their own jaxlib variant — confirm.
  buildInputs = [ jaxlib ];

  propagatedBuildInputs = [ einops flax pyyaml rich treeo ];

  checkInputs = [ cloudpickle dm-haiku hypothesis keras pytestCheckHook tensorflow ];

  pythonImportsCheck = [ "treex" ];

  disabledTestPaths = [
    # These tests need `torchmetrics`, which is not packaged in nixpkgs.
    "tests/metrics/test_mean_absolute_error.py"
    "tests/metrics/test_mean_square_error.py"
  ];

  meta = {
    description = "Pytree Module system for Deep Learning in JAX";
    homepage = "https://github.com/cgarciae/treex";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.ndl ];
  };
}