import json
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM
import torch
from core.lrm import LearnRiskModel
from core.risk_utils import var2pro
import os
from core.build_up import building_up
import copy
import re
from datasets_ import Animals, RealworldQA, MMBench, MMStar, SeedBench, ScienceQA

# Few-shot instruction template that modify_lang_x wraps around each question.
# `{question}` is the only str.format placeholder. The two worked examples show
# the requested layout: the answer letter first, numbered reasons, and finally
# the answer wrapped as "## X ##" (the wrapper extract_answer searches for).
PROMPT = '''Given the question and the corresponding options below, first provide the correct answer. After that, explain step by step why this answer is the best choice, considering the information available. Finally, wrap the correct answer in ## ##.
Example:

Question:
Which of the following is an example of a mammal?
A. Shark.
B. Whale.
C. Lizard.
D. Frog.

Output:
B
Reason:
1.Mammals are vertebrates that typically have hair or fur and give live birth (except for monotremes).
2.Whales, although aquatic, are mammals because they give live birth, nurse their young with milk, and have a warm-blooded metabolism.
## B ##

Question:
What is the chemical symbol for gold? A. Au. B. Ag. C. Fe. D. Hg

Output:
A
Reason:
1.The chemical symbol for gold is derived from its Latin name "Aurum."
2.Au is the symbol that corresponds to gold, making it the correct answer.
## A ##

Question:
{question}

Output:
'''


def modify_lang_x(lang_x):
    """Wrap the question text of every conversation in the PROMPT template.

    Only the first message of each conversation is kept; any later messages
    are dropped from the result. Every content part of that first message
    whose ``type`` is ``'text'`` has its ``text`` replaced by
    ``PROMPT.format(question=<old text>)``.

    Note that the message dicts are modified in place: the returned
    conversations hold the same (now rewritten) dict objects as ``lang_x``.

    Args:
        lang_x: Iterable of conversations, each an indexable sequence whose
            first item is a dict with a ``'content'`` list of part dicts.

    Returns:
        A list with one single-message list per input conversation.
    """

    def wrap_first_turn(conversation):
        first_turn = conversation[0]
        for part in first_turn['content']:
            if part['type'] == 'text':
                part['text'] = PROMPT.format(question=part['text'])
        return [first_turn]

    return [wrap_first_turn(conversation) for conversation in lang_x]

# Patterns are tried in order; the first one yielding a non-empty capture wins.
_ANSWER_PATTERNS = (
    # An option letter wrapped as "## A ##" (whitespace optional).
    re.compile(r'##\s*([ABCD])\s*##'),
    # Anything else wrapped in "## ... ##" on one line. This also covers
    # digit answers such as "## 2 ##", so no separate digit pattern is needed.
    re.compile(r'##(.*?)##'),
    # An option letter followed by a period, e.g. "B. Whale".
    re.compile(r'([ABCD])\.'),
    # Last resort: the first alphabetic word (with optional trailing digits)
    # or the first run of digits.
    re.compile(r'([A-Za-z]+[0-9]*|[0-9]+)'),
)


def extract_answer(res):
    """Pull the final answer out of a model response.

    The response is searched with progressively looser patterns: a wrapped
    option letter (``## A ##``), any wrapped text (``## ... ##``), an option
    letter followed by a period, and finally the first word or number.
    The captured text is stripped of surrounding whitespace.

    An empty wrapper such as ``####`` or ``## ##`` is skipped rather than
    returned as an empty answer, so the looser patterns still get a chance.

    Args:
        res: The decoded response text. Non-string input (e.g. ``None``)
            is treated as "no answer".

    Returns:
        The extracted answer string, or ``None`` when nothing matches.
    """
    if not isinstance(res, str):
        return None

    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(res)
        if match is None:
            continue
        answer = match.group(1).strip()
        if answer:
            return answer

    return None

def get_var_score(lr_model, activate_vector, mu):
    """Run the risk model on one batch and convert its outputs with var2pro.

    Both tensors are cast to float32 and moved to the default CUDA device
    before the forward pass, so a CUDA device is required.

    Args:
        lr_model: Callable invoked as ``lr_model(mu, activate_vector)`` that
            returns a pair of outputs.
        activate_vector: Tensor passed as the second model argument.
        mu: Tensor passed as the first model argument.

    Returns:
        Whatever ``var2pro`` produces from the two model outputs.
        # NOTE(review): callers index it as ``score[ix, 0]`` / ``score[ix, 1]``,
        # so it is presumably a (batch, 2) tensor - confirm against var2pro.
    """
    activation_gpu = activate_vector.float().cuda()
    mu_gpu = mu.float().cuda()
    first_out, second_out = lr_model(mu_gpu, activation_gpu)
    return var2pro(first_out, second_out)

def evaluate_qa(original_data, batch_size=1, vlm=None, vlm_processor=None, mu_path='', activate_vector_path='', checkpoint_path='', resume=0):
    """Check how often the risk model favours the correct one of two VLM answers.

    For every batch the VLM first answers the PROMPT-wrapped question, is then
    asked to review and improve that answer, and both replies are reduced to
    an answer string with ``extract_answer``. A sample is counted only when
    the two extracted answers differ and one of them equals the label. It is
    a hit when the score column belonging to the correct reply (column 0 for
    the first reply, column 1 for the second) is strictly lower than the
    other column.

    The running ratio is shown on the progress bar and printed at the end.

    Args:
        original_data: Dataset exposing ``collate_fn`` and ``n_shot``.
        batch_size: Batch size for both the dataset loader and the loader
            returned by ``building_up``.
        vlm: Model providing ``generate`` and ``device``.
        vlm_processor: Processor providing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        mu_path: Forwarded to ``building_up`` as ``mu_root``.
        activate_vector_path: Forwarded to ``building_up`` as ``activate_root``.
        checkpoint_path: State-dict file loaded into ``LearnRiskModel``.
        resume: Number of leading batches to skip.

    Returns:
        The final hit ratio as a float (0.0 when no sample was counted).
    """
    # NOTE(review): these two counters are never updated anywhere in this
    # function; they are kept only so the progress-bar fields stay the same.
    chosen_rejected = 0
    rejected_chosen = 0

    correct = 0
    # Starts just above zero so the running ratio never divides by zero.
    total = 1e-6

    extractor = "Your final answer should be put between two ##, like ## A ## (if your final answer is A), at the end of your response."
    real_world = "Review your previous answer and ensure that all relevant aspects of the image have been considered. Are there any elements or details that you missed? Based on your review, improve your answer."
    # The review request is identical for every sample, so build it once.
    follow_up = f'{real_world} {extractor}'

    def generate(conversations, images):
        """Render the chats, run the VLM and decode only the new tokens."""
        texts = [vlm_processor.apply_chat_template(
            message, tokenize=False, add_generation_prompt=True
        ) for message in conversations]

        inputs = vlm_processor(images=images,
                               text=texts,
                               return_tensors="pt",
                               padding=True).to(vlm.device)

        # NOTE(review): 5 new tokens is far shorter than the answer+reasoning
        # format PROMPT asks for - confirm this truncation is intended.
        generated_ids = vlm.generate(**inputs, max_new_tokens=5)
        # Drop the prompt tokens so only the model's continuation is decoded.
        trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return vlm_processor.batch_decode(trimmed, skip_special_tokens=True)

    data_loader = DataLoader(original_data,
                             batch_size=batch_size,
                             collate_fn=original_data.collate_fn,
                             shuffle=False)

    risk_data_loader, (in_dim, id_dim) = building_up(mu_root=mu_path,
                                                     activate_root=activate_vector_path,
                                                     risk_label_root='',
                                                     batch_size=batch_size)

    lr_model = LearnRiskModel(in_dim=in_dim, n_id=id_dim)
    # Load onto CPU first so a checkpoint saved from another device still
    # loads; the model is moved to CUDA right after.
    lr_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    lr_model.eval()
    lr_model.cuda()

    with tqdm(data_loader, desc='Generating', ncols=100) as dbar:
        # zip stops at the shorter loader; the two loaders are assumed to
        # yield matching batches in the same order - TODO confirm.
        for idx, ((image_path, lang_x, label), (mu, activate_vectors)) in enumerate(zip(dbar, risk_data_loader)):

            if idx < resume:
                continue

            _, base_image, base_lang_x, _ = original_data.n_shot(image_path, lang_x, label,
                                                                 shot=0, cls='diff')

            base_lang_x = modify_lang_x(base_lang_x)

            with torch.inference_mode():
                # Round 1: the model's initial answer.
                response_1 = generate(base_lang_x, base_image)
                answers_1 = [extract_answer(res) for res in response_1]

                # Round 2: replay the first reply and ask for a reviewed answer.
                base_lang_x2 = copy.deepcopy(base_lang_x)
                for i, message in enumerate(base_lang_x2):
                    message.append({'role': 'assistant',
                                    'content': [{'type': 'text', 'text': f'{response_1[i]}'}]})
                    message.append({'role': 'user',
                                    'content': [{'type': 'text', 'text': follow_up}]})

                response_2 = generate(base_lang_x2, base_image)
                answers_2 = [extract_answer(res) for res in response_2]

                score = get_var_score(lr_model, activate_vectors, mu)

                for ix, (r1, r2) in enumerate(zip(answers_1, answers_2)):
                    # Identical answers give the risk model nothing to rank.
                    if r1 == r2:
                        continue

                    if r1 == label[ix]:
                        total += 1
                        if score[ix, 0] < score[ix, 1]:
                            correct += 1
                    elif r2 == label[ix]:
                        total += 1
                        if score[ix, 1] < score[ix, 0]:
                            correct += 1

            dbar.set_postfix(acc=f'{correct / total:.4f}', c_r=f'{chosen_rejected}', r_c=f'{rejected_chosen}')
        print(f'{correct / total:.4f}')

    return correct / total


def evaluate_animals(original_data, batch_size=1, vlm=None, vlm_processor=None, mu_path='', activate_vector_path='', checkpoint_path='', resume=0):
    """Check how often the risk model favours the correct one of two VLM answers.

    For every batch the VLM answers the dataset's prompt, is then asked to be
    more specific about the animal, and both raw replies are lowercased and
    compared with the label. A sample is counted only when the two replies
    differ and one of them equals the label. It is a hit when the score
    column belonging to the correct reply (column 0 for the first reply,
    column 1 for the second) is strictly lower than the other column.

    Generation uses sampling (``do_sample=True``), so results vary between
    runs. The running ratio is shown on the progress bar and printed at the
    end.

    Args:
        original_data: Dataset exposing ``collate_fn`` and ``n_shot``.
        batch_size: Batch size for both the dataset loader and the loader
            returned by ``building_up``.
        vlm: Model providing ``generate`` and ``device``.
        vlm_processor: Processor providing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        mu_path: Forwarded to ``building_up`` as ``mu_root``.
        activate_vector_path: Forwarded to ``building_up`` as ``activate_root``.
        checkpoint_path: State-dict file loaded into ``LearnRiskModel``.
        resume: Number of leading batches to skip.

    Returns:
        The final hit ratio as a float (0.0 when no sample was counted).
    """
    # NOTE(review): these two counters are never updated anywhere in this
    # function; they are kept only so the progress-bar fields stay the same.
    chosen_rejected = 0
    rejected_chosen = 0

    do_sample = True

    correct = 0
    # Starts just above zero so the running ratio never divides by zero.
    total = 1e-6

    def generate(conversations, images):
        """Render the chats, run the VLM and decode only the new tokens."""
        texts = [vlm_processor.apply_chat_template(
            message, tokenize=False, add_generation_prompt=True
        ) for message in conversations]

        inputs = vlm_processor(images=images,
                               text=texts,
                               return_tensors="pt",
                               padding=True).to(vlm.device)

        generated_ids = vlm.generate(**inputs, max_new_tokens=5, do_sample=do_sample)
        # Drop the prompt tokens so only the model's continuation is decoded.
        trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return vlm_processor.batch_decode(trimmed, skip_special_tokens=True)

    data_loader = DataLoader(original_data,
                             batch_size=batch_size,
                             collate_fn=original_data.collate_fn,
                             shuffle=False)

    risk_data_loader, (in_dim, id_dim) = building_up(mu_root=mu_path,
                                                     activate_root=activate_vector_path,
                                                     risk_label_root='',
                                                     batch_size=batch_size)

    lr_model = LearnRiskModel(in_dim=in_dim, n_id=id_dim)
    # Load onto CPU first so a checkpoint saved from another device still
    # loads; the model is moved to CUDA right after.
    lr_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    lr_model.eval()
    lr_model.cuda()

    with tqdm(data_loader, desc='Generating', ncols=100) as dbar:
        # zip stops at the shorter loader; the two loaders are assumed to
        # yield matching batches in the same order - TODO confirm.
        for idx, ((image_path, lang_x, label), (mu, activate_vectors)) in enumerate(zip(dbar, risk_data_loader)):

            if idx < resume:
                continue

            _, base_image, base_lang_x, _ = original_data.n_shot(image_path, lang_x, label,
                                                                 shot=0, cls='diff')

            with torch.inference_mode():
                # Round 1: the model's initial answer.
                response_1 = generate(base_lang_x, base_image)

                # Round 2: replay the first reply and ask for a more specific one.
                base_lang_x2 = copy.deepcopy(base_lang_x)
                for i, message in enumerate(base_lang_x2):
                    message.append({'role': 'assistant',
                                    'content': [{'type': 'text', 'text': f'{response_1[i]}'}]})
                    message.append({'role': 'user', 'content': [{'type': 'text',
                                                                 'text': f'You previously answered \'{response_1[i]}\'. Can you be more specific and provide the exact species or breed of the animal in the image, if possible? Answer only with the name in lowercase.'}]})

                response_2 = generate(base_lang_x2, base_image)

                score = get_var_score(lr_model, activate_vectors, mu)

                for ix, (r1, r2) in enumerate(zip(response_1, response_2)):
                    # Replies are lowercased (not stripped) before comparison.
                    # NOTE(review): stray whitespace in a reply will therefore
                    # prevent a match with the label - confirm this is intended.
                    tem_r1 = r1.lower()
                    tem_r2 = r2.lower()

                    # Identical answers give the risk model nothing to rank.
                    if tem_r1 == tem_r2:
                        continue

                    if tem_r1 == label[ix]:
                        total += 1
                        if score[ix, 0] < score[ix, 1]:
                            correct += 1
                    elif tem_r2 == label[ix]:
                        total += 1
                        if score[ix, 1] < score[ix, 0]:
                            correct += 1

            dbar.set_postfix(acc=f'{correct / total:.4f}', c_r=f'{chosen_rejected}', r_c=f'{rejected_chosen}')
        print(f'{correct / total:.4f}')

    return correct / total


if __name__ == '__main__':
    # Which benchmark to run; any value not listed below falls back to ScienceQA.
    mode = 'animals'
    # Number of repeated runs for the Animals benchmark.
    times = 5
    min_pixels = 256 * 28 * 28
    max_pixels = 512 * 28 * 28
    model_dir = 'Qwen2-VL-2B/Qwen2-VL-2B-Instruct'

    vlm_processor = AutoProcessor.from_pretrained(model_dir,
                                                  min_pixels=min_pixels,
                                                  max_pixels=max_pixels)

    vlm = Qwen2VLForConditionalGeneration.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # mode -> (dataset class, dataset root, mu path, activation path, checkpoint)
    qa_setups = {
        'rw': (RealworldQA, 'data/realworldQA/data', 'mu/real_world',
               'activate_vector/realworld_preference', 'lrm_real_world.pth'),
        'mmbench': (MMBench, 'data/mmbench/data', 'mu/mmbench',
                    'activate_vector/mmbench_preference', 'lrm_mmbench.pth'),
        'mmstar': (MMStar, 'data/mmstar', 'mu/mmstar',
                   'activate_vector/mmstar_preference', 'lrm_mmstar.pth'),
        'seedbench': (SeedBench, 'data/seedbench/data', 'mu/seedbench',
                      'activate_vector/seedbench_preference', 'lrm_seedbench.pth'),
    }
    science_setup = (ScienceQA, 'data/science/data', 'mu/science',
                     'activate_vector/science_preference', 'lrm_science.pth')

    if mode == 'animals':
        data = Animals(root='data/Animals_with_Attributes2', test=False, return_image_path=True)
        for _ in range(times):
            evaluate_animals(original_data=data, batch_size=1, vlm=vlm, vlm_processor=vlm_processor,
                             mu_path='mu/mu_animals',
                             activate_vector_path='activate_vector/animals_preference',
                             checkpoint_path='lrm_animals.pth')
    else:
        dataset_cls, data_root, mu_dir, activation_dir, checkpoint = qa_setups.get(mode, science_setup)
        data = dataset_cls(root=data_root, return_image_path=True)
        evaluate_qa(original_data=data, batch_size=1, vlm=vlm, vlm_processor=vlm_processor,
                    mu_path=mu_dir,
                    activate_vector_path=activation_dir,
                    checkpoint_path=checkpoint)