import json
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM
import torch
from core.lrm import LearnRiskModel
from core.risk_utils import var2pro
import os
from core.build_up import building_up
import copy
import re
# Few-shot prompt template wrapped around every question before it is sent to
# the model.  `{question}` is the only placeholder; it is filled in with
# `PROMPT.format(question=...)` by `GeneratePere.modify_lang_x`.  The template
# asks for the final answer to be wrapped in `## ##`, which is the marker
# `GeneratePere.extract_answer` looks for first.
PROMPT = '''Given the question and the corresponding options below, first provide the correct answer. After that, explain step by step why this answer is the best choice, considering the information available. Finally, wrap the correct answer in ## ##.
Example:

Question:
Which of the following is an example of a mammal?
A. Shark.
B. Whale.
C. Lizard.
D. Frog.

Output:
B
Reason:
1.Mammals are vertebrates that typically have hair or fur and give live birth (except for monotremes).
2.Whales, although aquatic, are mammals because they give live birth, nurse their young with milk, and have a warm-blooded metabolism.
## B ##

Question:
What is the chemical symbol for gold? A. Au. B. Ag. C. Fe. D. Hg

Output:
A
Reason:
1.The chemical symbol for gold is derived from its Latin name "Aurum."
2.Au is the symbol that corresponds to gold, making it the correct answer.
## A ##

Question:
{question}

Output:
'''

class GeneratePere:
    """Generate a chosen/rejected preference dataset from two VLM answer passes.

    For every batch the VLM first answers the question, is then asked to review
    and improve that answer, and the two responses are compared.  When the
    extracted answers differ, the two-column score returned by
    ``get_var_score`` decides which full response is stored as ``chosen`` and
    which as ``rejected``.
    """

    # Answer-extraction patterns, tried in order; the first one that matches
    # wins and its group(1), stripped, is the answer.  The digit pattern stays
    # after the generic ``##(.*?)##`` one: it is still reachable when the digit
    # is separated from the markers by newlines, which ``.`` does not match.
    _ANSWER_PATTERNS = tuple(re.compile(p) for p in (
        r'##\s*([ABCD])\s*##',
        r'##(.*?)##',
        r'##\s*([0-3])\s*##',
        r'([ABCD])\.',
        r'([A-Za-z]+[0-9]*|[0-9]+)',
    ))

    def __init__(self, original_data, batch_size=1, data_save_path='', vlm=None, vlm_processor=None, mu_path='', activate_vector_path='', checkpoint_path='', resume=0):
        """Set up the data loaders, the risk model and the VLM handles.

        Args:
            original_data: dataset object; must provide ``collate_fn`` (used by
                the ``DataLoader``) and ``n_shot`` (used while generating).
            batch_size: batch size shared by both data loaders.
            data_save_path: directory the result file is written to; an empty
                string means the current working directory.
            vlm: generation model; must expose ``generate`` and ``device``.
            vlm_processor: processor providing ``apply_chat_template``,
                ``__call__`` and ``batch_decode``.
            mu_path: forwarded to ``building_up`` as ``mu_root``.
            activate_vector_path: forwarded to ``building_up`` as
                ``activate_root``.
            checkpoint_path: state-dict file loaded into ``LearnRiskModel``.
            resume: number of leading batches to skip.
        """
        self.original_data = original_data
        # shuffle=False keeps this loader in the same order as the risk loader
        # it is zipped with in generate_preference_dataset.
        self.data_loader = DataLoader(self.original_data,
                                      batch_size=batch_size,
                                      collate_fn=original_data.collate_fn,
                                      shuffle=False)

        self.risk_data_loader, (in_dim, id_dim) = building_up(mu_root=mu_path,
                                                              activate_root=activate_vector_path,
                                                              risk_label_root='',
                                                              batch_size=batch_size)

        self.lr_model = LearnRiskModel(in_dim=in_dim, n_id=id_dim)
        self.lr_model.load_state_dict(torch.load(checkpoint_path))
        self.lr_model.eval()
        self.lr_model.cuda()

        self.vlm = vlm
        self.vlm_processor = vlm_processor

        self.data_save_path = data_save_path
        self.resume = resume
        self.batch_size = batch_size

    def _generate_batch(self, conversations, images):
        """Run one VLM generation pass and return the decoded continuations."""
        texts = [
            self.vlm_processor.apply_chat_template(
                message, tokenize=False, add_generation_prompt=True
            )
            for message in conversations
        ]
        inputs = self.vlm_processor(images=images,
                                    text=texts,
                                    return_tensors="pt",
                                    padding=True).to(self.vlm.device)
        generated_ids = self.vlm.generate(**inputs, max_new_tokens=1024)
        # Drop the prompt tokens so only the newly generated part is decoded.
        trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.vlm_processor.batch_decode(trimmed, skip_special_tokens=True)

    def _save_preference_data(self, preference_data):
        """Write the collected pairs to ``preference_data.jsonl`` and return the path.

        The file keeps its historical name but holds one indented JSON document
        of the form ``{"id0": [...]}``, not one JSON object per line.
        """
        # os.makedirs('') raises FileNotFoundError, so an empty save path (the
        # constructor default) is treated as "current working directory".
        if self.data_save_path:
            os.makedirs(self.data_save_path, exist_ok=True)

        file_name = os.path.join(self.data_save_path, 'preference_data.jsonl')
        # ensure_ascii=False emits raw non-ASCII characters, so the encoding is
        # pinned instead of depending on the platform default.
        with open(file_name, 'w', encoding='utf-8') as f:
            json.dump({'id0': preference_data}, f, ensure_ascii=False, indent=4)
        return file_name

    def generate_preference_dataset(self):
        """Generate both responses for every batch and save the preference pairs.

        A pair is recorded only when the answers extracted from the two
        responses differ.  The response whose score column is lower (column 0
        for the first response, column 1 for the second) becomes ``chosen``.
        The progress bar also shows how often that choice agrees with the
        label, counted over pairs where exactly one extracted answer equals
        the label.

        Returns:
            None.  The result is written to ``data_save_path``.
        """
        preference_data = []

        chosen_rejected = 0
        rejected_chosen = 0

        acc = 0
        total = 1e-6  # tiny start value avoids a ZeroDivisionError in the postfix

        with tqdm(self.data_loader, desc='Generating', ncols=100) as dbar:
            # NOTE: zip stops at the shorter of the two loaders.
            for idx, ((image_path, lang_x, label), (mu, activate_vectors)) in enumerate(zip(dbar, self.risk_data_loader)):

                if idx < self.resume:
                    continue

                base_image_path, base_image, base_lang_x, _base_label = self.original_data.n_shot(
                    image_path, lang_x, label, shot=0, cls='diff')

                # if data is from animals, moving the function of "modify_lang_x"
                base_lang_x = self.modify_lang_x(base_lang_x)
                # Snapshot of the first-turn prompt, stored with each pair.
                tem_lang_x = copy.deepcopy(base_lang_x)

                with torch.inference_mode():
                    response_1 = self._generate_batch(base_lang_x, base_image)
                    response_1_ = [self.extract_answer(res) for res in response_1]

                    # Second turn: previous answer as assistant message, then
                    # the review request as a new user message.
                    base_lang_x2 = copy.deepcopy(base_lang_x)
                    for i, m in enumerate(base_lang_x2):
                        m.append({'role': 'assistant', 'content': [{'type': 'text', 'text': response_1[i]}]})

                    # For the task of Animals use these:
                    # [m.append({'role': 'user', 'content': [{'type': 'text',
                    #                                         'text': f'You previously answered \'{response_1[i]}\'. Can you be more specific and provide the exact species or breed of the animal in the image, if possible? Answer only with the name in lowercase.'}]})
                    #  for i, m in enumerate(base_lang_x2)]

                    # NOTE: the leading whitespace on the continuation lines is
                    # part of the prompt text sent to the model; it is kept
                    # verbatim so generations stay comparable.
                    real_world = """Review your previous answer and ensure that all relevant aspects of the image have been considered. Are there any elements or details that you missed? Based on your review, improve your answer. First you should output the corrected option. After that, explain step by step why this answer is better than the previous one. Finally, wrap the correct answer in ## ##.
                    Example:
                    A
                    Reason:
                    1. The reason1.
                    2. The reason2.
                    ## A ##

                    Output:
                    """
                    for m in base_lang_x2:
                        m.append({'role': 'user', 'content': [{'type': 'text', 'text': real_world}]})

                    response_2 = self._generate_batch(base_lang_x2, base_image)
                    response_2_ = [self.extract_answer(res) for res in response_2]

                    score = self.get_var_score(activate_vectors, mu)

                    for ix, (r1, r2) in enumerate(zip(response_1_, response_2_)):
                        # For the task of Animals compare r1.lower() / r2.lower().
                        if r1 == r2:
                            # Same extracted answer: no preference signal.
                            continue

                        first_preferred = bool(score[ix, 0] < score[ix, 1])
                        if first_preferred:
                            chosen, rejected = response_1[ix], response_2[ix]
                            chosen_rejected += 1
                        else:
                            chosen, rejected = response_2[ix], response_1[ix]
                            rejected_chosen += 1

                        preference_data.append({
                            'images': base_image_path[ix],
                            'prompt': tem_lang_x[ix],
                            'chosen': chosen,
                            'rejected': rejected,
                        })

                        # Agreement statistic; pairs where neither answer
                        # equals the label are not counted.
                        if r1 == label[ix]:
                            total += 1
                            if first_preferred:
                                acc += 1
                        elif r2 == label[ix]:
                            total += 1
                            if score[ix, 1] < score[ix, 0]:
                                acc += 1

                dbar.set_postfix(acc=f'{acc/total:.4f}', c_r=f'{chosen_rejected}', r_c=f'{rejected_chosen}')

        self._save_preference_data(preference_data)
        print('done perfectly')

    def extract_answer(self, res):
        """Extract the final answer token from a model response.

        Tries, in order: an option letter between ``##`` markers, any text
        between ``##`` markers, a digit 0-3 between ``##`` markers, the first
        ``A``-``D`` followed by a period, and finally the first alphanumeric
        token.

        Args:
            res: decoded response text.

        Returns:
            The stripped captured text of the first matching pattern, or
            ``None`` when nothing matches.
        """
        for pattern in self._ANSWER_PATTERNS:
            match = pattern.search(res)
            if match:
                return match.group(1).strip()
        return None

    def modify_lang_x(self, lang_x):
        """Wrap the question of each conversation in the few-shot ``PROMPT``.

        Only the first message of every conversation is kept.  Each of its
        ``'text'`` content entries is replaced by ``PROMPT`` formatted with the
        original text; other content entries are passed through unchanged.
        The input is not modified, so calling this twice on the same object
        cannot wrap the prompt twice.

        Args:
            lang_x: list of conversations, each a list of message dicts with a
                ``'content'`` list of ``{'type': ..., ...}`` entries.

        Returns:
            A new list with one single-message conversation per input
            conversation.
        """
        wrapped = []
        for conversation in lang_x:
            first = conversation[0]
            content = [
                {**part, 'text': PROMPT.format(question=part['text'])}
                if part['type'] == 'text' else part
                for part in first['content']
            ]
            wrapped.append([{**first, 'content': content}])

        return wrapped

    def get_var_score(self, activate_vector, mu):
        """Score a batch with the risk model.

        Both inputs are cast to float and moved to CUDA, passed to the risk
        model as ``(mu, activate_vector)``, and the model's two outputs are
        converted with ``var2pro``.

        Returns:
            Whatever ``var2pro`` returns; callers index it as
            ``[sample, column]``.
        """
        activate_vector = activate_vector.float().cuda()
        mu = mu.float().cuda()
        m, s = self.lr_model(mu, activate_vector)
        return var2pro(m, s)




# Script entry point: loads a dataset, a Qwen2-VL model/processor and a second
# causal LM, then tries to start preference-data generation.
#
# NOTE(review): this block looks out of sync with the rest of the visible file
# and will not run as written:
#   * `Cifar10` and `rules` are used but never imported or defined here.
#   * `generate_preference_dataset` only exists as a method of `GeneratePere`;
#     there is no module-level function with that name.
#   * The keyword arguments `llm`, `llm_processor`, `mu` and `rules` are not
#     parameters of `GeneratePere.__init__`, which instead takes `mu_path`,
#     `activate_vector_path`, `data_save_path` and `resume`.
# TODO confirm the intended imports and paths before wiring this up to
# `GeneratePere(...).generate_preference_dataset()`.
if __name__ == '__main__':
    # (cls, shot) combinations; not referenced anywhere else in this block.
    cls_shot = [('diff', 4), ('diff', 6), ('diff', 8), ('diff', 10), ('same', 4), ('same', 6),
                ('same', 8),
                ('same', 10)]
    data = Cifar10(root='data/cifar10', test=False, return_image_path=True)
    # presumably the dataset class overloads `*` to take a 1% subset — TODO confirm
    data = 0.01 * data
    rules = rules['cifar10']

    # Pixel budget handed to the processor below (multiples of 28 * 28).
    min_pixels = 256 * 28 * 28
    max_pixels = 512 * 28 * 28

    vlm = Qwen2VLForConditionalGeneration.from_pretrained(
        'Qwen2-VL-2B/Qwen2-VL-2B-Instruct',
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    mu = torch.load('mu/mu_cifar10/mu.pth')


    vlm_processor = AutoProcessor.from_pretrained('Qwen2-VL-2B/Qwen2-VL-2B-Instruct', min_pixels=min_pixels, max_pixels=max_pixels)

    # Second language model and its processor; only passed on to the call below.
    llm = AutoModelForCausalLM.from_pretrained('idefics3/Llama-3.1-Tulu-3-8B', device_map='auto')
    llm_processor = AutoProcessor.from_pretrained('idefics3/Llama-3.1-Tulu-3-8B')

    generate_preference_dataset(original_data=data,
                                batch_size=1,
                                vlm=vlm,
                                vlm_processor=vlm_processor,
                                llm=llm,
                                llm_processor=llm_processor,
                                checkpoint_path='lrm.pth',
                                mu=mu,
                                rules=rules)








