from copy import deepcopy

from soupsieve.util import lower

from datasets_ import *
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers import AutoModelForVision2Seq
from transformers import LlavaForConditionalGeneration, AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images
from tqdm import tqdm

from core.utils import load_image
import re


# Few-shot template for multiple-choice questions: the model is asked to give
# the answer first, then its reasoning, and finally to repeat the answer
# wrapped in "## ##". `{question}` is the only placeholder (fill it with
# `PROMPT.format(question=...)`).
# NOTE(review): not referenced anywhere in the visible part of this file.
PROMPT = '''Given the question and the corresponding options below, first provide the correct answer. After that, explain step by step why this answer is the best choice, considering the information available. Finally, wrap the correct answer in ## ##.
Example:

Question:
Which of the following is an example of a mammal?
A. Shark.
B. Whale.
C. Lizard.
D. Frog.

Output:
B
Reason:
1.Mammals are vertebrates that typically have hair or fur and give live birth (except for monotremes).
2.Whales, although aquatic, are mammals because they give live birth, nurse their young with milk, and have a warm-blooded metabolism.
## B ##

Question:
What is the chemical symbol for gold? A. Au. B. Ag. C. Fe. D. Hg

Output:
A
Reason:
1.The chemical symbol for gold is derived from its Latin name "Aurum."
2.Au is the symbol that corresponds to gold, making it the correct answer.
## A ##

Question:
{question}

Output:
'''

def extract_answer2(res):
    """Return the last word of a model reply, normalised for label comparison.

    The reply is split on any run of whitespace (spaces, tabs, newlines), so a
    trailing newline or space no longer ends up inside -- or replaces -- the
    extracted word. Double quotes and periods are removed from that word and
    the result is lower-cased with ``str.lower``.

    Args:
        res: Decoded text produced by the model.

    Returns:
        The cleaned, lower-cased final word, or ``''`` when ``res`` contains
        no words at all.
    """
    words = res.split()
    if not words:
        return ''
    return words[-1].replace('"', '').replace('.', '').lower()

# One-shot template for the "judge" step: given a question and two candidate
# answers, the model is asked to output 1 or 2 for the answer it considers
# correct. `{q}` is the only placeholder; compare_experiments fills it with the
# question followed by "Answer1: ..." / "Answer2: ..." lines.
OUTPROMPT = """
Given a question, a set of options, and two provided answers, your task is to determine which of the two answers is correct and output either 1 or 2, where 1 represents that Answer 1 is correct, and 2 represents that Answer 2 is correct.
Example:

Question:
What is the direction of Earth's rotation?  
A. West to East  
B. East to West  
C. South to North  
D. North to South

Answer1:A
Answer2:B

Output:
1

Question:
{q}

Output:
"""

def extract_answer(res):
    """Normalise a decoded model reply before it is compared or re-prompted.

    Callers compare the result with ``==`` against dataset labels and against
    the literals ``'1'`` / ``'2'``, so leading and trailing whitespace is
    stripped from string replies. Any non-string value is returned unchanged.

    Args:
        res: Decoded model output (normally a ``str``).

    Returns:
        ``res`` with surrounding whitespace removed when it is a string,
        otherwise ``res`` itself.
    """
    if isinstance(res, str):
        return res.strip()
    return res

# with torch.inference_mode():
# Local checkpoints / data locations used by the experiment.
_ANIMALS_ROOT = '/home/15t/fzy/code/raie/data/Animals_with_Attributes2'
_QWEN_CHECKPOINT = '/home/15t/fzy/code/raie/Qwen2-VL-2B/Qwen2-VL-2B-Instruct'
_MINICPM_CHECKPOINT = 'miniCPM/MiniCPM-V-2_6'

# Pixel bounds passed to the Qwen2-VL processor.
_QWEN_MIN_PIXELS = 256 * 28 * 28
_QWEN_MAX_PIXELS = 512 * 28 * 28

# Sampling settings shared by every judge model.
_JUDGE_GENERATION_CONFIG = {
    "top_p": 0.8,
    "top_k": 100,
    "temperature": 0.7,
    "do_sample": True,
    "repetition_penalty": 1.05,
}

# Second-round question. The original Qwen branches had a stray leading quote
# in front of this sentence; all branches now use the same wording.
_FOLLOW_UP_TEMPLATE = (
    "You previously answered '{answer}'. Can you be more specific and provide the exact "
    "species or breed of the animal in the image, if possible? "
    "Answer only with the name in lowercase."
)


def _load_dataset(dataset_name):
    """Build the dataset object for ``dataset_name`` or raise ``ValueError``."""
    if dataset_name == 'animals':
        return Animals(root=_ANIMALS_ROOT, test=False, return_image_path=True)
    raise ValueError(f"unsupported dataset_name: {dataset_name!r} (expected 'animals')")


def _decode_new_tokens(processor, prompt_ids, generated_ids):
    """Decode only the tokens generated after each prompt in the batch."""
    continuations = [full[len(prefix):] for prefix, full in zip(prompt_ids, generated_ids)]
    return processor.batch_decode(continuations, skip_special_tokens=True)


def _normalize_choice(text):
    """Reduce a judge reply such as ' 1.' to the bare choice string '1'."""
    return text.replace('.', '').strip()


def _correct_choice(first, second, label):
    """Return '1' or '2' for the candidate equal to ``label``, else ``None``.

    ``None`` means the sample cannot be scored: both candidates are identical
    or neither of them matches the label.
    """
    if first == second:
        return None
    if first == label:
        return '1'
    if second == label:
        return '2'
    return None


def _load_qwen():
    """Load the Qwen2-VL model and processor that generate candidate answers."""
    vlm = Qwen2VLForConditionalGeneration.from_pretrained(
        _QWEN_CHECKPOINT,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    vlm_processor = AutoProcessor.from_pretrained(
        _QWEN_CHECKPOINT, min_pixels=_QWEN_MIN_PIXELS, max_pixels=_QWEN_MAX_PIXELS)
    return vlm, vlm_processor


def _qwen_reply(vlm, vlm_processor, conversation, image):
    """Generate Qwen2-VL's next reply for ``conversation`` about ``image``."""
    text = vlm_processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True)
    inputs = vlm_processor(
        text=[text],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(vlm.device)
    with torch.inference_mode():
        generated_ids = vlm.generate(**inputs, max_new_tokens=1024)
    return _decode_new_tokens(vlm_processor, inputs.input_ids, generated_ids)[0]


def _qwen_two_round_answers(vlm, vlm_processor, image, question):
    """Ask Qwen2-VL the question, then ask it to be more specific.

    Returns the two answers after ``extract_answer``; the raw first reply is
    what gets stored in the chat history.
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": str(question)},
            ],
        },
    ]
    raw_first = _qwen_reply(vlm, vlm_processor, conversation, image)
    first = extract_answer(raw_first)

    conversation.append(
        {'role': 'assistant', 'content': [{'type': 'text', 'text': str(raw_first)}]})
    conversation.append(
        {'role': 'user',
         'content': [{'type': 'text', 'text': _FOLLOW_UP_TEMPLATE.format(answer=first)}]})
    second = extract_answer(_qwen_reply(vlm, vlm_processor, conversation, image))
    return first, second


def _judge_chat(judge_prompt):
    """Single-turn image + text chat in the Hugging Face message format."""
    return [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": judge_prompt},
        ],
    }]


def _make_idefics_judge():
    """Load Idefics2 and return ``judge(image, judge_prompt) -> str``."""
    device = 'cuda:1'
    processor = AutoProcessor.from_pretrained("idefics2")
    model = AutoModelForVision2Seq.from_pretrained(
        "idefics2", torch_dtype=torch.bfloat16).to(device)

    def judge(image, judge_prompt):
        text = processor.apply_chat_template(_judge_chat(judge_prompt), add_generation_prompt=True)
        inputs = processor(text=text, images=[image], padding=True, return_tensors="pt")
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.inference_mode():
            generated_ids = model.generate(
                **inputs, max_new_tokens=10, **_JUDGE_GENERATION_CONFIG)
        return _decode_new_tokens(processor, inputs['input_ids'], generated_ids)[0]

    return judge


def _make_llava_judge():
    """Load LLaVA-1.5 and return ``judge(image, judge_prompt) -> str``."""
    processor = AutoProcessor.from_pretrained("llava1.5-hf")
    model = LlavaForConditionalGeneration.from_pretrained(
        "llava1.5-hf", torch_dtype=torch.float16).to(0)

    def judge(image, judge_prompt):
        text = processor.apply_chat_template(_judge_chat(judge_prompt), add_generation_prompt=True)
        inputs = processor(images=image, text=text, return_tensors='pt').to(0, torch.float16)
        with torch.inference_mode():
            generated_ids = model.generate(
                **inputs, max_new_tokens=10, **_JUDGE_GENERATION_CONFIG)
        return _decode_new_tokens(processor, inputs['input_ids'], generated_ids)[0]

    return judge


def _make_minicpm_judge():
    """Load MiniCPM-V 2.6 and return ``judge(image, judge_prompt) -> str``."""
    model = AutoModel.from_pretrained(
        _MINICPM_CHECKPOINT, trust_remote_code=True,
        attn_implementation='sdpa', torch_dtype=torch.bfloat16)
    model = model.eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(_MINICPM_CHECKPOINT, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(_MINICPM_CHECKPOINT, trust_remote_code=True)

    def judge(image, judge_prompt):
        # A single user turn of [image, text]: the image becomes a placeholder
        # line followed by the prompt text.
        # assumes load_image returns a PIL image (the original code only kept
        # the image when it passed an isinstance check) -- TODO confirm
        chat = [{"role": "user", "content": "(<image>./</image>)\n" + judge_prompt}]
        text = processor.tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False)
        inputs = processor([text],
                           [[image]],
                           return_tensors="pt",
                           max_length=8192).to(model.device)
        inputs.pop("image_sizes")
        with torch.inference_mode():
            return model.generate(
                **inputs, tokenizer=tokenizer, max_new_tokens=520, decode_text=True,
                **_JUDGE_GENERATION_CONFIG)[0]

    return judge


def _run_judged(model_name, samples):
    """Score a judge model on candidate answers produced by Qwen2-VL.

    Returns ``(correct, total)`` where ``total`` counts the samples on which
    exactly one of the two distinct candidates equals the label.
    """
    vlm, vlm_processor = _load_qwen()
    if model_name == 'ide':
        judge = _make_idefics_judge()
    elif model_name == 'llava':
        judge = _make_llava_judge()
    else:
        # Any other model name falls back to MiniCPM-V, as before.
        judge = _make_minicpm_judge()

    correct = 0
    total = 0
    for image_path, question, label in samples:
        image = load_image(image_path)
        first, second = _qwen_two_round_answers(vlm, vlm_processor, image, question)

        expected = _correct_choice(first, second, label)
        if expected is None:
            # The sample can never be counted, so the judge call is skipped.
            continue

        q = f'{question}\n\nAnswer1: {first}\nAnswer2: {second}'
        choice = _normalize_choice(str(judge(image, OUTPROMPT.format(q=q))))
        total += 1
        if choice == expected:
            correct += 1
    return correct, total


def _deepseek_turn(vl_gpt, vl_chat_processor, tokenizer, conversation, max_new_tokens):
    """Run one DeepSeek-VL turn.

    Returns ``(text, confidence)``: the greedily decoded reply and the largest
    next-token probability at the last prompt position.
    NOTE(review): this confidence only reflects the first generated token.
    """
    pil_images = load_pil_images(conversation)
    prepared = vl_chat_processor(
        conversations=conversation,
        images=pil_images,
        force_batchify=True,
    ).to(vl_gpt.device)

    with torch.inference_mode():
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepared)
        outputs = vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepared.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )
        # Plain forward pass: generation-only options (max_new_tokens,
        # do_sample, pad/bos/eos ids) do not belong here and are not passed.
        forward = vl_gpt.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=prepared.attention_mask,
            use_cache=False,
        )
        confidence = torch.softmax(forward.logits[0, -1], dim=-1).max().item()

    text = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return text, confidence


def _run_deepseek(samples):
    """Score DeepSeek-VL's own confidence as the preference signal.

    Returns ``(correct, total)`` with the same counting rule as the judged
    variants; the "choice" is whichever turn had the higher confidence.
    """
    vl_chat_processor = VLChatProcessor.from_pretrained('deepseekvl')
    tokenizer = vl_chat_processor.tokenizer
    vl_gpt = AutoModelForCausalLM.from_pretrained('deepseekvl', trust_remote_code=True)
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

    correct = 0
    total = 0
    for image_path, question, label in samples:
        user_turn = {
            "role": "User",
            "content": f"<image_placeholder>{question}",
            "images": [image_path],
        }

        raw_first, confidence_first = _deepseek_turn(
            vl_gpt, vl_chat_processor, tokenizer,
            [user_turn, {"role": "Assistant", "content": ""}],
            max_new_tokens=30,
        )
        first = extract_answer2(raw_first)

        raw_second, confidence_second = _deepseek_turn(
            vl_gpt, vl_chat_processor, tokenizer,
            [
                user_turn,
                {"role": "Assistant", "content": f"{raw_first}"},
                {"role": "User", "content": _FOLLOW_UP_TEMPLATE.format(answer=first)},
                {"role": "Assistant", "content": ""},
            ],
            max_new_tokens=20,
        )
        second = extract_answer2(raw_second)

        expected = _correct_choice(first, second, label)
        if expected is None:
            continue
        total += 1
        if confidence_first > confidence_second:
            choice = '1'
        elif confidence_first < confidence_second:
            choice = '2'
        else:
            choice = None  # a tie favours neither answer
        if choice == expected:
            correct += 1
    return correct, total


def compare_experiments(model_name, dataset_name):
    """Measure how often a model prefers the correct one of two answers.

    For every sample two candidate answers are produced (a first answer and a
    "be more specific" follow-up). A sample is scored only when the two
    candidates differ and exactly one of them equals the label. The model under
    test then either judges the pair (``'ide'``, ``'llava'``, and MiniCPM-V for
    any other name) or, for ``'deepseek'``, its own next-token confidence for
    each turn is compared.

    The counts and the resulting accuracy are printed; the accuracy is ``nan``
    when no sample could be scored.

    Args:
        model_name: ``'ide'``, ``'deepseek'``, ``'llava'``; anything else
            selects MiniCPM-V.
        dataset_name: Only ``'animals'`` is supported.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    data = _load_dataset(dataset_name)
    print('data loaded')

    samples = tqdm(
        zip(data.image_list, data.prompt, data.label_list),
        desc='Processing',
        total=len(data.image_list),
    )

    if model_name == 'deepseek':
        # Qwen2-VL is not used by this variant, so it is not loaded.
        correct, total = _run_deepseek(samples)
    else:
        correct, total = _run_judged(model_name, samples)

    accuracy = correct / total if total else float('nan')
    print(correct, total)
    print(f"{dataset_name} {model_name} preference acc: ", accuracy)

if __name__ == '__main__':
    # Runs the 'ide' experiment once per entry, i.e. three times on 'animals'.
    # NOTE(review): the original inline comment listed "ide, deepseek, llava",
    # which suggests one run per model was intended -- confirm before changing.
    data_list = ['animals'] * 3
    for dataset_name in data_list:
        compare_experiments('ide', dataset_name)




