from copy import deepcopy

from datasets_ import *
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers import AutoModelForVision2Seq
from transformers import LlavaForConditionalGeneration, AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images
from tqdm import tqdm

from core.utils import load_image
import re


# Few-shot prompt for the answering model (Qwen2-VL in compare_experiments).
# It asks for the option letter first, then a numbered step-by-step justification,
# and finally the answer wrapped in "## X ##" so extract_answer can pull it out.
# "{question}" is the only str.format placeholder; the question text is filled in
# verbatim by the caller.
PROMPT = '''Given the question and the corresponding options below, first provide the correct answer. After that, explain step by step why this answer is the best choice, considering the information available. Finally, wrap the correct answer in ## ##.
Example:

Question:
Which of the following is an example of a mammal?
A. Shark.
B. Whale.
C. Lizard.
D. Frog.

Output:
B
Reason:
1.Mammals are vertebrates that typically have hair or fur and give live birth (except for monotremes).
2.Whales, although aquatic, are mammals because they give live birth, nurse their young with milk, and have a warm-blooded metabolism.
## B ##

Question:
What is the chemical symbol for gold? A. Au. B. Ag. C. Fe. D. Hg

Output:
A
Reason:
1.The chemical symbol for gold is derived from its Latin name "Aurum."
2.Au is the symbol that corresponds to gold, making it the correct answer.
## A ##

Question:
{question}

Output:
'''


OUTPROMPT = """
Given a question, a set of options, and two provided answers, your task is to determine which of the two answers is correct and output either 1 or 2, where 1 represents that Answer 1 is correct, and 2 represents that Answer 2 is correct.
Example:

Question:
What is the direction of Earth's rotation?  
A. West to East  
B. East to West  
C. South to North  
D. North to South

Answer1:A
Answer2:B

Output:
1

Question:
{q}

Output:
"""


# Answer patterns tried in priority order; the first one that matches wins.
# The specific "## <digit> ##" form must come before the catch-all "##(.*?)##",
# otherwise it can never be reached.
_ANSWER_PATTERNS = (
    re.compile(r'##\s*([ABCD])\s*##'),         # option letter in hashes: "## B ##"
    re.compile(r'##\s*([0-3])\s*##'),          # digit in hashes: "## 2 ##"
    re.compile(r'##(.*?)##'),                  # anything else in hashes
    re.compile(r'([ABCD])\.'),                 # bare option letter + period: "B."
    re.compile(r'([A-Za-z]+[0-9]*|[0-9]+)'),   # last resort: first alphanumeric token
)


def extract_answer(res):
    """Pull the final answer out of a free-form model response.

    Patterns are tried in this order and the first hit wins:

    1. an option letter wrapped in hashes, e.g. ``## B ##``;
    2. a digit 0-3 wrapped in hashes, e.g. ``## 2 ##``;
    3. any other text wrapped in hashes (whitespace-stripped);
    4. the first option letter followed by a period, e.g. ``B.``;
    5. the first alphanumeric token in the response.

    Args:
        res: Decoded model output (may be empty or ``None``).

    Returns:
        The extracted answer string, or ``None`` if nothing matched.
    """
    if not res:
        return None

    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(res)
        if match:
            return match.group(1).strip()

    return None

def get_answer(res):
    """Normalise a judge model's raw reply for comparison against ``'1'`` / ``'2'``.

    All periods are removed and surrounding whitespace is stripped, so replies
    such as ``" 1."`` or ``"2\\n"`` become ``"1"`` / ``"2"`` instead of being
    silently scored as wrong.

    Args:
        res: Decoded judge output.

    Returns:
        The cleaned reply string.
    """
    return res.replace('.', '').strip()

# with torch.inference_mode():
def compare_experiments(model_name, dataset_name):
    """Measure how often a judge VLM prefers the correct one of two Qwen2-VL answers.

    For every sample of ``dataset_name``:

    1. Qwen2-VL-2B-Instruct answers the question using ``PROMPT``.
    2. It is then asked to review and improve that answer.
    3. Both answers are extracted with ``extract_answer``.
    4. A sample is scored only when the two answers differ and one of them equals
       the label. For those samples the judge model is shown both answers
       (``OUTPROMPT``) and must reply ``1`` or ``2``.

    The sample count and the judge's preference accuracy are printed.

    Args:
        model_name: Judge model. ``'ide'`` selects Idefics2 and ``'llava'``
            selects LLaVA-1.5; any other value selects MiniCPM-V-2.6.
        dataset_name: One of ``'animals'``, ``'real_world'``, ``'mmbench'``,
            ``'mmstar'``, ``'seedbench'``, ``'scienceqa'``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported benchmark.
    """
    # Resolve the dataset before any model weights are loaded, so a typo fails fast.
    data = _load_dataset(dataset_name)
    print('data loaded')

    # Sampling settings shared by every judge model.
    generation_config = {
        "top_p": 0.8,
        "top_k": 100,
        "temperature": 0.7,
        "do_sample": True,
        "repetition_penalty": 1.05,
    }

    # Qwen2-VL produces the two candidate answers (initial + self-reviewed).
    vlm = Qwen2VLForConditionalGeneration.from_pretrained(
        'Qwen2-VL-2B/Qwen2-VL-2B-Instruct',
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    vlm_processor = AutoProcessor.from_pretrained(
        'Qwen2-VL-2B/Qwen2-VL-2B-Instruct',
        min_pixels=256 * 28 * 28,
        max_pixels=512 * 28 * 28,
    )

    extractor = "Your final answer should be put between two ##, like ## A ## (if your final answer is A), at the end of your response."
    real_world = "Review your previous answer and ensure that all relevant aspects of the image have been considered. Are there any elements or details that you missed? Based on your review, improve your answer."
    followup = f'{real_world} {extractor}'

    if model_name == 'ide':
        judge = _make_idefics_judge(generation_config)
    elif model_name == 'llava':
        judge = _make_llava_judge(generation_config)
    else:
        judge = _make_minicpm_judge(generation_config)

    total = 0    # scored samples: answers differ and one of them is correct
    correct = 0  # scored samples where the judge picked the correct answer
    samples = zip(data.image_list, data.prompt, data.label_list)
    for image_path, question, label in tqdm(samples, desc='Processing', total=len(data.image_list)):
        image = load_image(image_path)
        answer_1, answer_2 = _two_round_answers(vlm, vlm_processor, question, image, followup)

        expected = _expected_choice(answer_1, answer_2, label)
        if expected is None:
            # The verdict would not be scored, so skip the (expensive) judge call.
            continue

        q = f'{question}\n\nAnswer1: {answer_1}\nAnswer2: {answer_2}'
        verdict = get_answer(judge(q, image))
        total += 1
        if verdict == expected:
            correct += 1

    print(correct, total)
    accuracy = correct / total if total else 0.0
    print(f"{dataset_name} {model_name} preference acc: ", accuracy)


def _load_dataset(dataset_name):
    """Instantiate the benchmark called ``dataset_name`` with image paths enabled.

    Raises:
        ValueError: If ``dataset_name`` is not a supported benchmark.
    """
    if dataset_name == 'animals':
        return Animals(root='data/Animals_with_Attributes2', test=False, return_image_path=True)
    if dataset_name == 'real_world':
        return RealworldQA(root='data/realworldQA/data', return_image_path=True)
    if dataset_name == 'mmbench':
        return MMBench(root='data/mmbench/data', return_image_path=True)
    if dataset_name == 'mmstar':
        return MMStar(root='data/mmstar', return_image_path=True)
    if dataset_name == 'seedbench':
        return SeedBench(root='data/seedbench/data', return_image_path=True)
    if dataset_name == 'scienceqa':
        return ScienceQA(root='data/science/data', return_image_path=True)
    raise ValueError(f'Unknown dataset_name: {dataset_name!r}')


def _expected_choice(answer_1, answer_2, label):
    """Return the judge reply ('1' or '2') that would be correct, or None if unscored.

    A pair is only scored when the two answers differ and one of them equals ``label``.
    """
    if answer_1 == answer_2:
        return None
    if answer_1 == label:
        return '1'
    if answer_2 == label:
        return '2'
    return None


def _decode_new_tokens(processor, input_ids, generated_ids):
    """Decode only the tokens generated after the prompt, for the first sequence."""
    trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True)[0]


def _qwen_reply(vlm, vlm_processor, messages, image):
    """Generate one Qwen2-VL reply to the chat ``messages`` (with one image)."""
    text = vlm_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = vlm_processor(
        text=[text],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(vlm.device)
    generated_ids = vlm.generate(**inputs, max_new_tokens=1024)
    return _decode_new_tokens(vlm_processor, inputs['input_ids'], generated_ids)


def _two_round_answers(vlm, vlm_processor, question, image, followup):
    """Ask Qwen2-VL ``question``, then ask it to revise; return both extracted answers."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": PROMPT.format(question=question)},
            ],
        },
    ]
    with torch.inference_mode():
        first_reply = _qwen_reply(vlm, vlm_processor, messages, image)
        messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': first_reply}]})
        messages.append({'role': 'user', 'content': [{'type': 'text', 'text': followup}]})
        second_reply = _qwen_reply(vlm, vlm_processor, messages, image)
    return extract_answer(first_reply), extract_answer(second_reply)


def _judge_messages(q):
    """Single-turn chat (image placeholder + filled OUTPROMPT) for HF chat-template judges."""
    return [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": OUTPROMPT.format(q=q)},
        ],
    }]


def _make_idefics_judge(generation_config):
    """Load Idefics2 on ``cuda:1`` and return ``judge(q, image) -> raw reply text``."""
    device = 'cuda:1'
    processor = AutoProcessor.from_pretrained("idefics2")
    model = AutoModelForVision2Seq.from_pretrained("idefics2", torch_dtype=torch.bfloat16).to(device)

    def judge(q, image):
        prompt_text = processor.apply_chat_template(_judge_messages(q), add_generation_prompt=True)
        inputs = processor(text=prompt_text, images=[image], padding=True, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Pure inference: no autograd bookkeeping needed.
        with torch.inference_mode():
            generated_ids = model.generate(**inputs, max_new_tokens=10, **generation_config)
        return _decode_new_tokens(processor, inputs['input_ids'], generated_ids)

    return judge


def _make_llava_judge(generation_config):
    """Load LLaVA-1.5 (fp16, device 0) and return ``judge(q, image) -> raw reply text``."""
    processor = AutoProcessor.from_pretrained("llava1.5-hf")
    model = LlavaForConditionalGeneration.from_pretrained("llava1.5-hf", torch_dtype=torch.float16).to(0)

    def judge(q, image):
        prompt_text = processor.apply_chat_template(_judge_messages(q), add_generation_prompt=True)
        inputs = processor(images=image, text=prompt_text, return_tensors='pt').to(0, torch.float16)
        with torch.inference_mode():
            generated_ids = model.generate(**inputs, max_new_tokens=10, **generation_config)
        return _decode_new_tokens(processor, inputs['input_ids'], generated_ids)

    return judge


def _make_minicpm_judge(generation_config):
    """Load MiniCPM-V-2.6 on CUDA and return ``judge(q, image) -> raw reply text``."""
    model_path = 'miniCPM/MiniCPM-V-2_6'
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                      attn_implementation='sdpa', torch_dtype=torch.bfloat16)
    model = model.eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    def judge(q, image):
        # Images are referenced in the text by a "(<image>./</image>)" placeholder and
        # passed to the processor separately; plain strings are kept as text.
        images = []
        parts = []
        for item in (image, OUTPROMPT.format(q=q)):
            if isinstance(item, Image.Image):
                images.append(item)
                parts.append("(<image>./</image>)")
            elif isinstance(item, str):
                parts.append(item)
        chat = [{"role": "user", "content": "\n".join(parts)}]

        text = processor.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        inputs = processor([text], [images], return_tensors="pt", max_length=8192).to(model.device)
        inputs.pop("image_sizes")
        with torch.inference_mode():
            return model.generate(**inputs, tokenizer=tokenizer, max_new_tokens=520,
                                  decode_text=True, **generation_config)[0]

    return judge


if __name__ == '__main__':
    # Benchmarks evaluated in this run. Other runs used:
    #   ['mmstar', 'seedbench', 'scienceqa']
    #   ['mmbench']
    benchmarks = ['real_world']
    for benchmark in benchmarks:
        # 'cpm' is not 'ide' or 'llava', so compare_experiments uses the MiniCPM-V judge.
        compare_experiments('cpm', benchmark)




