summary refs log tree commit diff
path: root/pkgs/development/python-modules/torchmetrics/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torchmetrics/default.nix')
-rw-r--r--pkgs/development/python-modules/torchmetrics/default.nix42
1 files changed, 31 insertions, 11 deletions
diff --git a/pkgs/development/python-modules/torchmetrics/default.nix b/pkgs/development/python-modules/torchmetrics/default.nix
index e3d51f7b551..ee66ee817d1 100644
--- a/pkgs/development/python-modules/torchmetrics/default.nix
+++ b/pkgs/development/python-modules/torchmetrics/default.nix
@@ -1,6 +1,9 @@
 { lib
 , buildPythonPackage
 , fetchFromGitHub
+, pythonOlder
+, numpy
+, lightning-utilities
 , cloudpickle
 , scikit-learn
 , scikit-image
@@ -11,23 +14,30 @@
 , pytestCheckHook
 , torchmetrics
 , pytorch-lightning
+, pytest-doctestplus
+, pytest-xdist
 }:
 
 let
   pname = "torchmetrics";
-  version = "0.11.4";
+  version = "1.2.0";
 in
 buildPythonPackage {
   inherit pname version;
+  pyproject = true;
 
   src = fetchFromGitHub {
-    owner = "PyTorchLightning";
-    repo = "metrics";
+    owner = "Lightning-AI";
+    repo = "torchmetrics";
     rev = "refs/tags/v${version}";
-    hash = "sha256-K8QLdDpnS4N8s3zXsifFloRXW/QXEm36mJXXKEBEJBs=";
+    hash = "sha256-g5JuTbiRd8yWx2nM3UE8ejOhuZ0XpAQdS5AC9AlrSFY=";
   };
 
+  disabled = pythonOlder "3.8";
+
   propagatedBuildInputs = [
+    numpy
+    lightning-utilities
     packaging
     py-deprecate
   ];
@@ -44,22 +54,33 @@ buildPythonPackage {
     cloudpickle
     psutil
     pytestCheckHook
+    pytest-doctestplus
+    pytest-xdist
   ];
 
   # A cyclic dependency in: integrations/test_lightning.py
   doCheck = false;
   passthru.tests.check = torchmetrics.overridePythonAttrs (_: {
+    pname = "${pname}-check";
     doCheck = true;
+    # We don't have to install because the only purpose
+    # of this passthru test is to, well, test.
+    # This fixes having to set `catchConflicts` to false.
+    dontInstall = true;
   });
 
+  disabledTests = [
+    # `IndexError: list index out of range`
+    "test_metric_lightning_log"
+  ];
+
   disabledTestPaths = [
     # These require too many "leftpad-level" dependencies
-    "tests/text"
-    "tests/audio"
-    "tests/image"
+    # Also too cross-dependent
+    "tests/unittests"
 
-    # A few non-deterministic things like test_check_compute_groups_is_faster
-    "tests/bases/test_collections.py"
+    # A trillion import path mismatch errors
+    "src/torchmetrics"
   ];
 
   pythonImportsCheck = [
@@ -68,11 +89,10 @@ buildPythonPackage {
 
   meta = with lib; {
     description = "Machine learning metrics for distributed, scalable PyTorch applications (used in pytorch-lightning)";
-    homepage = "https://torchmetrics.readthedocs.io";
+    homepage = "https://lightning.ai/docs/torchmetrics/";
     license = licenses.asl20;
     maintainers = with maintainers; [
       SomeoneSerge
     ];
   };
 }
-