summary refs log tree commit diff
path: root/pkgs/development/python-modules/dalle-mini/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/dalle-mini/default.nix')
-rw-r--r--pkgs/development/python-modules/dalle-mini/default.nix18
1 files changed, 12 insertions, 6 deletions
diff --git a/pkgs/development/python-modules/dalle-mini/default.nix b/pkgs/development/python-modules/dalle-mini/default.nix
index 0f5a0f07e06..0c768ba5dbe 100644
--- a/pkgs/development/python-modules/dalle-mini/default.nix
+++ b/pkgs/development/python-modules/dalle-mini/default.nix
@@ -1,6 +1,7 @@
 { lib
 , buildPythonPackage
 , fetchPypi
+, fetchpatch
 , einops
 , emoji
 , flax
@@ -15,17 +16,21 @@
 
 buildPythonPackage rec {
   pname = "dalle-mini";
-  version = "0.1.4";
+  version = "0.1.5";
+  format = "setuptools";
 
   src = fetchPypi {
     inherit pname version;
-    hash = "sha256-UwCcoKbGxZT5XB+Mtv8kAHFdj0iLw8U1Ayo60y3Tm7U=";
+    hash = "sha256-k4XILjNNz0FPcAzwPEeqe5Lj24S2Y139uc9o/1IUS1c=";
   };
 
-  format = "setuptools";
-
-  buildInputs = [
-    jaxlib
+  # Fix incompatibility with the latest JAX versions
+  # See https://github.com/borisdayma/dalle-mini/pull/338
+  patches = [
+    (fetchpatch {
+      url = "https://github.com/borisdayma/dalle-mini/pull/338/commits/22ffccf03f3e207731a481e3e42bdb564ceebb69.patch";
+      hash = "sha256-LIOyfeq/oVYukG+1rfy5PjjsJcjADCjn18x/hVmLkPY=";
+    })
   ];
 
   propagatedBuildInputs = [
@@ -34,6 +39,7 @@ buildPythonPackage rec {
     flax
     ftfy
     jax
+    jaxlib
     pillow
     transformers
     unidecode