import requests
import torch
import os, sys, copy
import io
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen
from loguru import logger


class TransformerAgentPhimi:
    """Local inference agent for Phi-4-multimodal-instruct style checkpoints.

    Loads the model, processor and generation config from ``model_path`` and
    answers chat requests given as a list of ``{"role": ..., "content": ...}``
    message dicts.
    """

    def __init__(self, model_path, attn_implementation="flash_attention_2") -> None:
        """Load the model, processor and generation config.

        Args:
            model_path: Checkpoint location passed to ``from_pretrained``.
            attn_implementation: Attention backend passed to ``from_pretrained``.
                The default ``"flash_attention_2"`` needs an Ampere or later GPU;
                pass ``"eager"`` on older hardware.
        """
        self.device = "cuda"
        # device_map already places the weights on the GPU, so no extra
        # ``.cuda()`` call is needed afterwards.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=self.device,
            torch_dtype="auto",
            trust_remote_code=True,
            _attn_implementation=attn_implementation,
        )
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.generation_config = GenerationConfig.from_pretrained(model_path)

    def convert_openai_to_paligemma(self, messages):
        """Flatten chat messages into one prompt string and at most one image.

        System messages are skipped. A message's ``content`` may be missing or
        ``None`` (treated as empty), a plain string, or a list of items shaped
        like ``{"type": "text", "text": ...}`` or ``{"type": "image", "image": path}``.
        Other item types are ignored. Text items are joined with single spaces.
        Images are opened with PIL and converted to RGB. Missing files are logged
        and skipped. Several images are pasted left to right onto one canvas whose
        width is the sum of the widths and whose height is the largest height.

        Returns:
            ``(prompt_text, image)``. ``image`` is a PIL image, or ``None`` when no
            image was loaded.
        """
        text_parts = []
        pil_images = []
        for message in messages:
            if message["role"] == "system":
                continue
            content = message.get("content")
            # A missing key and an explicit None both mean "no content". Without
            # this check, None would be wrapped as a text item and break the
            # " ".join() below.
            if content is None:
                content = []
            elif not isinstance(content, list):
                content = [{"type": "text", "text": content}]
            for item in content:
                if item["type"] == "text":
                    text_parts.append(item["text"])
                elif item["type"] == "image":
                    image_path = item["image"]
                    try:
                        # The context manager closes the source file once the
                        # independent RGB copy has been made.
                        with Image.open(image_path) as opened:
                            pil_images.append(opened.convert("RGB"))
                    except FileNotFoundError:
                        logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")

        final_image = None
        if len(pil_images) == 1:
            final_image = pil_images[0]
        elif pil_images:
            total_width = sum(img.width for img in pil_images)
            max_height = max(img.height for img in pil_images)
            stitched_image = Image.new("RGB", (total_width, max_height))
            current_x = 0
            for img in pil_images:
                stitched_image.paste(img, (current_x, 0))
                current_x += img.width
            final_image = stitched_image

        return " ".join(text_parts), final_image

    def send_chat_request(self, messages):
        """Generate a reply for ``messages`` with the loaded model.

        Returns:
            ``(output_text, None, None, None)``. Only the first slot is filled.
        """
        prompt_cat, image_cat = self.convert_openai_to_paligemma(messages)

        user_prompt = "<|user|>"
        assistant_prompt = "<|assistant|>"
        prompt_suffix = "<|end|>"
        # Add the image placeholder only when there is an image to pair with it.
        image_tag = "<|image_1|>" if image_cat is not None else ""
        prompt = f"{user_prompt}{image_tag}{prompt_cat}{prompt_suffix}{assistant_prompt}"

        inputs = self.processor(text=prompt, images=image_cat, return_tensors="pt").to(self.model.device)
        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1000,
            generation_config=self.generation_config,
        )
        # Drop the echoed prompt tokens so only newly generated text is decoded.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
        output_text = self.processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text, None, None, None


class TransformerAgentPhivi:
    """Local inference agent for Phi-3.5-vision-instruct style checkpoints.

    Builds the prompt with the tokenizer's chat template and decodes greedily.
    """

    def __init__(self, model_path, attn_implementation="flash_attention_2", num_crops=4) -> None:
        """Load the model and processor.

        Args:
            model_path: Checkpoint location passed to ``from_pretrained``.
            attn_implementation: Attention backend passed to ``from_pretrained``.
                The default ``"flash_attention_2"`` needs an Ampere or later GPU;
                pass ``"eager"`` on older hardware.
            num_crops: Forwarded to ``AutoProcessor.from_pretrained``. The
                default of 4 matches the previously hard-coded value.
        """
        self.device = "auto"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=self.device,
            trust_remote_code=True,
            torch_dtype="auto",
            _attn_implementation=attn_implementation,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path, trust_remote_code=True, num_crops=num_crops
        )

    def convert_openai_to_paligemma(self, messages):
        """Flatten chat messages into one prompt string and at most one image.

        System messages are skipped. A message's ``content`` may be missing or
        ``None`` (treated as empty), a plain string, or a list of items shaped
        like ``{"type": "text", "text": ...}`` or ``{"type": "image", "image": path}``.
        Other item types are ignored. Text items are joined with single spaces.
        Images are opened with PIL and converted to RGB. Missing files are logged
        and skipped. Several images are pasted left to right onto one canvas whose
        width is the sum of the widths and whose height is the largest height.

        Returns:
            ``(prompt_text, image)``. ``image`` is a PIL image, or ``None`` when no
            image was loaded.
        """
        text_parts = []
        pil_images = []
        for message in messages:
            if message["role"] == "system":
                continue
            content = message.get("content")
            # A missing key and an explicit None both mean "no content". Without
            # this check, None would be wrapped as a text item and break the
            # " ".join() below.
            if content is None:
                content = []
            elif not isinstance(content, list):
                content = [{"type": "text", "text": content}]
            for item in content:
                if item["type"] == "text":
                    text_parts.append(item["text"])
                elif item["type"] == "image":
                    image_path = item["image"]
                    try:
                        # The context manager closes the source file once the
                        # independent RGB copy has been made.
                        with Image.open(image_path) as opened:
                            pil_images.append(opened.convert("RGB"))
                    except FileNotFoundError:
                        logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")

        final_image = None
        if len(pil_images) == 1:
            final_image = pil_images[0]
        elif pil_images:
            total_width = sum(img.width for img in pil_images)
            max_height = max(img.height for img in pil_images)
            stitched_image = Image.new("RGB", (total_width, max_height))
            current_x = 0
            for img in pil_images:
                stitched_image.paste(img, (current_x, 0))
                current_x += img.width
            final_image = stitched_image

        return " ".join(text_parts), final_image

    def send_chat_request(self, messages):
        """Generate a greedy reply for ``messages`` with the loaded model.

        Returns:
            ``(output_text, None, None, None)``. Only the first slot is filled.
        """
        prompt_cat, image_cat = self.convert_openai_to_paligemma(messages)
        # Add the image placeholder and image list only when an image exists.
        # Otherwise the processor would receive ``[None]`` as its images.
        has_image = image_cat is not None
        image_tag = "<|image_1|>\n" if has_image else ""
        messages_for_processor = [
            {"role": "user", "content": image_tag + f"{prompt_cat}"},
        ]
        prompt = self.processor.tokenizer.apply_chat_template(
            messages_for_processor, tokenize=False, add_generation_prompt=True
        )
        images = [image_cat] if has_image else None
        inputs = self.processor(prompt, images, return_tensors="pt").to(self.model.device)
        # Greedy decoding. Temperature is not passed because it only applies when
        # sampling, and setting it alongside do_sample=False just produces a warning.
        generate_ids = self.model.generate(
            **inputs,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            max_new_tokens=4096,
            do_sample=False,
        )
        # Drop the echoed prompt tokens so only newly generated text is decoded.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
        output_text = self.processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text, None, None, None


if __name__ == "__main__":
    # CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
    # initialised in this process, so set it before any model is loaded.
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    model_path = "models/Phi-4-multimodal-instruct"  # NOTE
    # model_path = "models/Phi-3.5-vision-instruct"  # NOTE

    # miniforge3/condabin/conda run -n phimi --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_phi.py
    # Make the parent of this file's directory importable so `utils.eval` resolves.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    model_name = model_path.split("/")[-1]
    # Pick the agent that matches the checkpoint, so switching model_path above
    # is enough to evaluate either model.
    agent_cls = TransformerAgentPhivi if "Phi-3.5-vision" in model_name else TransformerAgentPhimi
    agent = agent_cls(model_path)
    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result", model_name))
    evaluator.run_one_prediction_local()
