import json
import os
from agents.agent import Agent
import torch
from PIL import Image
from huggingface_hub import login
from transformers import MllamaForConditionalGeneration, AutoProcessor


class LLaMA32(Agent):
    """Agent backed by the Llama 3.2 11B Vision Instruct model.

    The model and its processor are loaded once at construction time and
    reused for every call to :meth:`respond`.
    """

    # Hugging Face Hub identifier of the checkpoint loaded in ``__init__``.
    MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    def __init__(self, role="SOLVER") -> None:
        """
        Initialize the agent and load the model onto the current CUDA device.

        Args:
            role: Role name forwarded unchanged to the ``Agent`` base class.
        """
        super().__init__(role)

        self.model = MllamaForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.bfloat16,
        ).cuda()
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)

    def clear(self):
        """
        Clears the conversation history.
        """
        self.conversation = []

    def respond(self, input_data, actions, message=None):
        """
        Generate the model's reply for the current turn and record it.

        The prompt is obtained from ``get_conversation_history_string``
        (not defined in this class; presumably inherited from ``Agent``).
        When that helper returns a tuple it is unpacked as
        ``(image, llm_input)`` and the image is sent to the model together
        with the text; otherwise the returned value is used as a text-only
        prompt.

        Args:
            input_data: Forwarded to ``get_conversation_history_string`` as
                ``image_data``.
            actions: Forwarded to ``get_conversation_history_string``.
            message: Optional message, forwarded to
                ``get_conversation_history_string``.

        Returns:
            predicted_action: The text generated by the model, with the
            prompt removed and any ``<|eot_id|>`` marker stripped.
        """
        ret = self.get_conversation_history_string(
            image_data=input_data, actions=actions, message=message
        )

        if isinstance(ret, tuple):
            image_source, llm_input = ret
            # Image.open is lazy and keeps the file open; copy the pixel data
            # inside the context manager so the handle is released promptly.
            with Image.open(image_source) as opened:
                image = opened.copy()
            content = [
                {"type": "image"},
                {"type": "text", "text": "This is a picture of the puzzle manual. " + llm_input},
            ]
        else:
            image = None
            content = [{"type": "text", "text": ret}]

        messages = [{"role": "user", "content": content}]

        input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        # The chat template already contains the special tokens, hence
        # add_special_tokens=False. Inputs follow the model's device rather
        # than a hard-coded "cuda:0" so the two can never disagree.
        inputs = self.processor(
            images=image,
            text=input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.model.device)

        output = self.model.generate(**inputs, max_new_tokens=2000)

        # generate() returns prompt + continuation. Drop the prompt at the
        # token level: slicing the decoded string by len(input_text) is only
        # correct if decoding reproduces the prompt character for character.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_tokens = output[0][prompt_length:]
        predicted_action = self.processor.decode(generated_tokens).replace("<|eot_id|>", "")

        # Append to conversation history
        self.conversation.append(predicted_action)

        return predicted_action
