### Git diff relative to the upstream LLaVA commit below (lets `generate()` accept a caller-supplied `inputs_embeds`):
commit 28b96e296d980b92e0810a5d1d00f8f1e988be85 (HEAD -> main)


diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py
index 069d0d1..765c0d3 100644
--- a/llava/model/language_model/llava_llama.py
+++ b/llava/model/language_model/llava_llama.py
@@ -106,13 +106,14 @@ def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         images: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         image_sizes: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         position_ids = kwargs.pop("position_ids", None)
         attention_mask = kwargs.pop("attention_mask", None)
-        if "inputs_embeds" in kwargs:
-            raise NotImplementedError("`inputs_embeds` is not supported")
+        # `inputs_embeds` is now an explicit parameter (used below when no
+        # images are given), so the old NotImplementedError guard is removed.

         if images is not None:
             (
@@ -132,7 +133,10 @@ def generate(
                 image_sizes=image_sizes
             )
         else:
-            inputs_embeds = self.get_model().embed_tokens(inputs)
+            if inputs_embeds is None:
+                inputs_embeds = self.get_model().embed_tokens(inputs)
+
+        #print(images, inputs_embeds, position_ids, attention_mask, kwargs)

         return super().generate(
             position_ids=position_ids,
diff --git a/llava/model/language_model/llava_mistral.py b/llava/model/language_model/llava_mistral.py
index 0def682..f8c3c1d 100644
--- a/llava/model/language_model/llava_mistral.py
+++ b/llava/model/language_model/llava_mistral.py
@@ -106,6 +106,7 @@ def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         images: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         image_sizes: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
@@ -132,7 +133,8 @@ def generate(
                 image_sizes=image_sizes
             )
         else:
-            inputs_embeds = self.get_model().embed_tokens(inputs)
+            if inputs_embeds is None:
+                inputs_embeds = self.get_model().embed_tokens(inputs)

         return super().generate(
             position_ids=position_ids,
