from transformers import Qwen2VLForConditionalGeneration,AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from transformers.generation import GenerationConfig
from PIL import Image
import torch
torch.manual_seed(1234)

from agents.agent import Agent
import json
import os


class QwenVLAgent(Agent):
    """Agent that answers using a locally loaded Qwen2-VL instruct model.

    The model and its processor are loaded once at construction time and the
    model is moved to the CUDA device. Each call to :meth:`respond` builds a
    single-turn chat prompt (optionally with one image), generates a
    completion and records it in ``self.conversation``.

    reference: https://github.com/QwenLM/Qwen-VL
    """

    # Checkpoint used for both the model weights and the processor.
    MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
    # Value forwarded to the processor as ``max_pixels``.
    MAX_PIXELS = 256 * 28 * 28
    # Generation hyperparameters passed to ``model.generate``.
    MAX_NEW_TOKENS = 128
    TEMPERATURE = 0.01
    # Text prepended to the prompt when the history helper supplies an image path.
    MANUAL_IMAGE_PREFIX = "This is a picture of the puzzle manual. "

    def __init__(self, role="SOLVER") -> None:
        """Load the model and processor.

        Args:
            role: Role name forwarded to the ``Agent`` base class.
        """
        super().__init__(role)

        # List of strings. Each element (string) is a message alternating
        # between solver and expert. This class itself only appends its own
        # generated replies (see ``respond``).
        self.conversation = []

        # ``from_pretrained`` also loads the generation config shipped with
        # this checkpoint, so it must not be overwritten with the config of a
        # different model (EOS/pad token ids and sampling defaults differ).
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.MODEL_ID, torch_dtype="auto"
        ).cuda()
        self.processor = AutoProcessor.from_pretrained(
            self.MODEL_ID, max_pixels=self.MAX_PIXELS
        )

    def clear(self):
        """Forget all recorded messages."""
        self.conversation = []

    def _build_inputs(self, text, image=None):
        """Turn one user message (text plus optional image) into model inputs.

        Args:
            text: Content of the text part of the user message.
            image: Optional image passed to the processor; when ``None`` a
                text-only prompt is built.

        Returns:
            The processor output (PyTorch tensors), not yet moved to a device.
        """
        content = []
        if image is not None:
            # Placeholder entry; the actual pixels are supplied via ``images=``.
            content.append({"type": "image"})
        content.append({"type": "text", "text": text})
        messages = [{"role": "user", "content": content}]

        text_prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        return self.processor(
            text=[text_prompt],
            images=None if image is None else [image],
            padding=True,
            return_tensors="pt",
        )

    # Given conversation history, respond to message.
    def respond(self, image_data, actions, message=None):
        """Generate the agent's next reply.

        Args:
            image_data: Optional image forwarded to the history helper and,
                when the helper returns a plain string, attached to the prompt.
            actions: Forwarded unchanged to the history helper.
            message: Optional incoming message forwarded to the history helper.

        Returns:
            The decoded model completion. It is also appended to
            ``self.conversation``.
        """
        # ``get_conversation_history_string`` is inherited (not defined in this
        # class). It is handled here as returning either an
        # ``(image_path, text)`` tuple or a single prompt value.
        ret = self.get_conversation_history_string(
            image_data=image_data, actions=actions, message=message, model="qwenVL")

        if isinstance(ret, tuple):
            image_path, llm_input = ret
            llm_input = self.MANUAL_IMAGE_PREFIX + llm_input
            # Keep the file open only while the processor reads the pixels.
            with Image.open(image_path) as image:
                inputs = self._build_inputs(llm_input, image)
        elif isinstance(ret, str) and image_data is not None:
            inputs = self._build_inputs(ret, image_data)
        else:
            # Text-only prompt: either no image was given or the helper
            # returned something that is neither a tuple nor a string.
            inputs = self._build_inputs(ret)

        # Follow the model's device instead of hard-coding it.
        inputs = inputs.to(self.model.device)

        sequences = self.model.generate(
            **inputs,
            max_new_tokens=self.MAX_NEW_TOKENS,
            temperature=self.TEMPERATURE,
        )
        # ``generate`` returns prompt + completion; strip the prompt tokens.
        completions = [
            sequence[len(prompt_ids):]
            for prompt_ids, sequence in zip(inputs.input_ids, sequences)
        ]
        output_text = self.processor.batch_decode(
            completions, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        predicted_action = output_text[0]
        self.conversation.append(predicted_action)
        return predicted_action
