from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests

from agents.agent import Agent
import json
import os
import base64
from io import BytesIO



class LLAVAAgent(Agent):
    """Agent backed by a LLaVA-NeXT vision-language model.

    The processor and model are loaded once at construction time. Each call to
    ``respond`` builds a prompt via ``get_conversation_history_string`` (from
    ``Agent``), optionally attaches an image, generates a reply and records it
    in ``self.conversation``.

    Reference: https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf
    """

    def __init__(self, role="SOLVER", *,
                 model_id="llava-hf/llava-v1.6-mistral-7b-hf",
                 device="cuda:0") -> None:
        """Load the processor and fp16 model, then move the model to ``device``.

        Args:
            role: Agent role forwarded to ``Agent.__init__``.
            model_id: Hugging Face Hub id (or local path) of the LLaVA-NeXT
                checkpoint. Defaults to the Mistral-7B v1.6 checkpoint.
            device: torch device string the model and its inputs are placed on.
        """
        super().__init__(role)
        # List of message strings. Intended to alternate between solver and
        # expert; `respond` appends this agent's replies, `clear` resets it.
        self.conversation = []
        self.device = device

        self.processor = LlavaNextProcessor.from_pretrained(model_id)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self.model.to(device)

    def clear(self):
        """Forget all recorded conversation messages."""
        self.conversation = []

    @staticmethod
    def _load_rgb(source):
        """Open an image path/file object and return an RGB copy, closing the file."""
        # convert() materializes pixel data, so the returned image stays
        # usable after the `with` block closes the underlying file.
        with Image.open(source) as img:
            return img.convert("RGB")

    def respond(self, image_data, actions, message=None, max_new_tokens=200):
        """Generate this agent's next message given the conversation history.

        Args:
            image_data: Optional image source (anything ``PIL.Image.open``
                accepts, e.g. a path). Used only when the history helper does
                not return its own image path.
            actions: Forwarded to ``get_conversation_history_string``.
            message: Optional incoming message, forwarded likewise.
            max_new_tokens: Maximum number of tokens to generate for the reply.

        Returns:
            The decoded reply text (prompt tokens excluded). The same string is
            appended to ``self.conversation``.
        """
        ret = self.get_conversation_history_string(
            image_data=image_data, actions=actions, message=message, model="llava")

        if isinstance(ret, tuple):
            # The history helper supplied an image path alongside the prompt;
            # that image (the puzzle manual) replaces `image_data`.
            image_path, llm_input = ret
            llm_input = "<image>\nThis is a picture of the puzzle manual. " + llm_input
            image = self._load_rgb(image_path)
        else:
            llm_input = ret
            image = self._load_rgb(image_data) if image_data is not None else None

        # Pass by keyword: the processor's positional order (text, images vs.
        # images, text) differs between transformers releases.
        inputs = self.processor(
            text=llm_input, images=image, return_tensors="pt"
        ).to(self.device)

        # autoregressively complete prompt
        output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # generate() returns prompt + completion for decoder-only models; keep
        # only the newly generated tokens so the stored message is the reply
        # itself rather than a copy of the whole prompt.
        prompt_len = inputs["input_ids"].shape[1]
        predicted_action = self.processor.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()
        self.conversation.append(predicted_action)
        return predicted_action