summary refs log tree commit diff
path: root/pkgs/development/python-modules/objax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/objax/default.nix')
-rw-r--r--pkgs/development/python-modules/objax/default.nix33
1 files changed, 28 insertions, 5 deletions
diff --git a/pkgs/development/python-modules/objax/default.nix b/pkgs/development/python-modules/objax/default.nix
index 548039d63b3..be8a3d8327d 100644
--- a/pkgs/development/python-modules/objax/default.nix
+++ b/pkgs/development/python-modules/objax/default.nix
@@ -1,24 +1,28 @@
 { lib
-, fetchFromGitHub
 , buildPythonPackage
-, jax
+, fetchFromGitHub
+, fetchpatch
 , jaxlib
+, jax
 , numpy
 , parameterized
 , pillow
 , scipy
 , tensorboard
+, keras
+, pytestCheckHook
+, tensorflow
 }:
 
 buildPythonPackage rec {
   pname = "objax";
-  version = "1.7.0";
+  version = "1.8.0";
 
   src = fetchFromGitHub {
     owner = "google";
     repo = "objax";
-    rev = "v${version}";
-    hash = "sha256-1/XmxFZfU+XMD0Mlcv4xTUYZDwltAx1bZOlPuKWQQC0=";
+    rev = "refs/tags/v${version}";
+    hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
   };
 
   # Avoid propagating the dependency on `jaxlib`, see
@@ -40,6 +44,25 @@ buildPythonPackage rec {
     "objax"
   ];
 
+  # This is necessay to ignore the presence of two protobufs version (tensorflow is bringing an
+  # older version).
+  catchConflicts = false;
+
+  nativeCheckInputs = [
+    keras
+    pytestCheckHook
+    tensorflow
+  ];
+
+  pytestFlagsArray = [
+    "tests/*.py"
+  ];
+
+  disabledTests = [
+    # Test requires internet access for prefetching some weights
+    "test_pretrained_keras_weight_0_ResNet50V2"
+  ];
+
   meta = with lib; {
     description = "Objax is a machine learning framework that provides an Object Oriented layer for JAX.";
     homepage = "https://github.com/google/objax";