from types import MethodType
import torch

def generate_with_past_key_values(self, input_ids=None, past_key_values=None, **kwargs):
    """Call ``self._generate``, synthesizing ``input_ids`` when only a cache is given.

    Intended to be bound to a model instance (e.g. via ``types.MethodType``)
    in place of its ``generate`` method. The wrapper delegates to
    ``self._generate``, so that attribute must exist on the instance.

    Args:
        input_ids: Token ids forwarded unchanged when provided.
        past_key_values: Optional cache. When given without ``input_ids``,
            ``past_key_values[0][0].shape[0]`` is read as the batch size.
        **kwargs: Forwarded untouched to ``self._generate``.

    Returns:
        Whatever ``self._generate`` returns.

    Raises:
        ValueError: If a dummy ``input_ids`` is needed but
            ``self.config.eos_token_id`` is ``None`` or an empty sequence.
    """
    if past_key_values is not None and input_ids is None:
        # Build a one-token dummy prompt with the same batch size as the cache.
        batch_size = past_key_values[0][0].shape[0]

        eos_token_id = self.config.eos_token_id
        # The configured EOS id may be a sequence of ids; torch.full needs a
        # single scalar fill value, so use the first one.
        if isinstance(eos_token_id, (list, tuple)):
            eos_token_id = eos_token_id[0] if eos_token_id else None
        if eos_token_id is None:
            raise ValueError(
                "Cannot build dummy input_ids: config.eos_token_id is not set; "
                "pass input_ids explicitly."
            )

        input_ids = torch.full(
            (batch_size, 1),
            eos_token_id,
            dtype=torch.long,
            device=self.device,
        )
    return self._generate(
        input_ids=input_ids,
        past_key_values=past_key_values,
        **kwargs,
    )


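# Monkey-patch the generate method. The wrapper delegates to `self._generate`,
# so keep the original method reachable under that name before overwriting it
# (unless the model already defines `_generate`):
# model._generate = model.generate
# model.generate = MethodType(generate_with_past_key_values, model)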


# generated_ids = model.generate(
#     input_ids=None,  # Set to None since we're using past_key_values
#     past_key_values=past_key_values,
#     max_new_tokens=50,
#     do_sample=True,
#     temperature=0.6,
#     top_p=0.9,
# )

# generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# print("Generated text:", generated_text)
