"""
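Save data to JSON.

This script queries a vision-language model on a dataset and converts the data used to
generate the activation vectors (prompts, labels and model responses) into a more readable
JSON file. Note that the output file is named ``data_for_activate_vector.jsonl`` but is
written as a single indented JSON document, not as one JSON object per line.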
"""
import os
from datasets_ import Animals, RealworldQA, MMBench, MMStar, SeedBench, ScienceQA
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import copy
import torch
import json
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from core.utils import load_vlm
import re

# Few-shot instruction template for multiple-choice questions.
# `{question}` is the only placeholder; QA2Json.modify_lang_x fills it with
# `PROMPT.format(question=...)`. The template asks the model to state the answer,
# give step-by-step reasons, and finally repeat the answer wrapped in `## ##`,
# which is the marker QA2Json.extract_answer searches for.
PROMPT = '''Given the question and the corresponding options below, first provide the correct answer. After that, explain step by step why this answer is the best choice, considering the information available. Finally, wrap the correct answer in ## ##.
Example:

Question:
Which of the following is an example of a mammal?
A. Shark.
B. Whale.
C. Lizard.
D. Frog.

Output:
B
Reason:
1.Mammals are vertebrates that typically have hair or fur and give live birth (except for monotremes).
2.Whales, although aquatic, are mammals because they give live birth, nurse their young with milk, and have a warm-blooded metabolism.
## B ##

Question:
What is the chemical symbol for gold? A. Au. B. Ag. C. Fe. D. Hg

Output:
A
Reason:
1.The chemical symbol for gold is derived from its Latin name "Aurum."
2.Au is the symbol that corresponds to gold, making it the correct answer.
## A ##

Question:
{question}

Output:
'''

# Self-review follow-up prompt: asks the model to re-examine the image, output a
# corrected option, justify it, and wrap the final answer in `## ##`.
# NOTE(review): nothing in the visible code reads this module-level name --
# QA2Json.generate_json assigns its own local `real_world` with different text.
# Kept in case other modules import it; confirm before removing.
real_world = """Review your previous answer and ensure that all relevant aspects of the image have been considered. Are there any elements or details that you missed? Based on your review, improve your answer. First you should output the corrected option. After that, explain step by step why this answer is better than the previous one. Finally, wrap the correct answer in ## ##.

Example:            
A
Reason:
1. The reason1.
2. The reason2.
## A ##

Output:
"""

class ToJson:
    """Run a two-turn VLM dialogue over a dataset and dump the results to JSON.

    For every batch the model is first queried with the zero-shot prompt returned by
    ``data.n_shot``; its reply is then fed back together with a follow-up request for
    a more specific answer (species/breed wording, written for the Animals dataset).
    Both turns are stored: ``id0`` holds the initial responses and ``id1`` the
    refined ones.
    """

    def __init__(self, data: Dataset, batch_size=1, vlm=None, vlm_processor=None, save_path=''):
        """Store the model handles and build a non-shuffled loader over ``data``.

        Args:
            data: Dataset exposing ``collate_fn`` and ``n_shot`` (both are called here).
            batch_size: Number of samples per generation batch.
            vlm: Model object providing ``generate`` and ``device``.
            vlm_processor: Processor providing ``apply_chat_template``, ``__call__``
                and ``batch_decode``.
            save_path: Output directory; an empty string means the current directory.
        """
        self.data = data
        self.data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, collate_fn=self.data.collate_fn)

        self.vlm_processor = vlm_processor
        self.vlm = vlm

        self.save_path = save_path

    def generate_json(self, resume=0, rev=''):
        """Generate both dialogue turns for the whole loader and write them to disk.

        Args:
            resume: Number of leading batches to skip. Accepts an int or a numeric
                string (e.g. an untyped argparse value).
            rev: Unused; kept so existing callers passing it keep working.

        NOTE(review): skipped batches are not reloaded from a previous run and the
        output file is opened in write mode, so resuming replaces the earlier file
        with only the newly generated records.
        """
        resume = int(resume)
        first_turn_records = []
        second_turn_records = []

        with tqdm(self.data_loader, desc='Generating') as dbar:
            for batch_idx, (images, lang_x, label) in enumerate(dbar):
                if batch_idx < resume:
                    continue

                # Fourth return value (labels from n_shot) is not needed here.
                image_path, base_image, base_lang_x, _ = self.data.n_shot(
                    images, lang_x, label, shot=0, cls='diff')

                response_1 = self._generate(base_lang_x, base_image)

                # Second turn: replay the first answer, then ask for a finer-grained one.
                # The follow-up wording below is specific to the Animals dataset.
                base_lang_x2 = copy.deepcopy(base_lang_x)
                for conversation, answer in zip(base_lang_x2, response_1):
                    conversation.append({
                        'role': 'assistant',
                        'content': [{'type': 'text', 'text': f'{answer}'}],
                    })
                    conversation.append({
                        'role': 'user',
                        'content': [{
                            'type': 'text',
                            'text': f"You previously answered '{answer}'. Can you be more specific and provide the exact species or breed of the animal in the image, if possible? Answer only with the name in lowercase.",
                        }],
                    })

                response_2 = self._generate(base_lang_x2, base_image)

                for idx, (prompt_1, prompt_2) in enumerate(zip(base_lang_x, base_lang_x2)):
                    first_turn_records.append({
                        'images': image_path[idx],
                        'prompt': prompt_1,
                        'label': label[idx],
                        'response': response_1[idx],
                    })
                    second_turn_records.append({
                        'images': image_path[idx],
                        'prompt': prompt_2,
                        'label': label[idx],
                        'response': response_2[idx],
                    })

        #  id0: init response, id1: improved response
        self._save({'id0': first_turn_records, 'id1': second_turn_records})

        print('done perfectly')

    def _generate(self, conversations, images, max_new_tokens=5):
        """Render chat templates, run generation and decode only the new tokens."""
        texts = [
            self.vlm_processor.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True)
            for conversation in conversations
        ]
        inputs = self.vlm_processor(images=images,
                                    text=texts,
                                    return_tensors="pt",
                                    padding=True).to(self.vlm.device)

        with torch.inference_mode():
            generated_ids = self.vlm.generate(**inputs, max_new_tokens=max_new_tokens)
            # Strip the prompt tokens so only the continuation is decoded.
            new_token_ids = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            return self.vlm_processor.batch_decode(new_token_ids, skip_special_tokens=True)

    def _save(self, payload):
        """Write ``payload`` as indented UTF-8 JSON into ``self.save_path``."""
        # An empty save_path means "current directory"; os.makedirs('') would raise.
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)

        out_file = os.path.join(self.save_path, 'data_for_activate_vector.jsonl')
        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding instead of
        # relying on the platform default.
        with open(out_file, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=4)

class QA2Json:
    """Run a two-turn multiple-choice QA dialogue with a VLM and dump it to JSON.

    Turn one sends each question wrapped in the few-shot ``PROMPT`` template.
    Turn two replays the original (unwrapped) question plus the model's full first
    reply and asks it to review and improve the answer. The answer extracted from
    each turn is stored: ``id0`` holds the initial answers, ``id1`` the revised ones.
    """

    # Tried in order: a strict option letter between ## markers, then any text
    # between ## markers (this also covers digits such as "## 2 ##").
    _MARKED_ANSWER_PATTERNS = (
        re.compile(r'##\s*([ABCD])\s*##'),
        re.compile(r'##(.*?)##'),
    )
    # Fallbacks for replies without ## markers.
    _OPTION_DOT_PATTERN = re.compile(r'([ABCD])\.')
    _FIRST_TOKEN_PATTERN = re.compile(r'([A-Za-z]+[0-9]*|[0-9]+)')

    def __init__(self, data: Dataset, batch_size=1, vlm=None, vlm_processor=None, save_path=''):
        """Store the model handles and build a non-shuffled loader over ``data``.

        Args:
            data: Dataset exposing ``collate_fn`` and ``n_shot`` (both are called here).
            batch_size: Number of samples per generation batch.
            vlm: Model object providing ``generate`` and ``device``.
            vlm_processor: Processor providing ``apply_chat_template``, ``__call__``
                and ``batch_decode``.
            save_path: Output directory; an empty string means the current directory.
        """
        self.data = data
        self.data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, collate_fn=self.data.collate_fn)

        self.vlm_processor = vlm_processor
        self.vlm = vlm

        self.save_path = save_path

    def generate_json(self, resume=0, extractor="Your final answer should be put between two ##, like ## A ## (if your final answer is A), at the end of your response."):
        """Generate both dialogue turns for the whole loader and write them to disk.

        Args:
            resume: Number of leading batches to skip. Accepts an int or a numeric
                string (e.g. an untyped argparse value).
            extractor: Formatting instruction appended to the review request of the
                second turn. The argument is now honoured; previously it was
                overwritten by an identical hard-coded string.

        NOTE(review): skipped batches are not reloaded from a previous run and the
        output file is opened in write mode, so resuming replaces the earlier file
        with only the newly generated records.
        """
        resume = int(resume)
        first_turn_records = []
        second_turn_records = []
        real_world = "Review your previous answer and ensure that all relevant aspects of the image have been considered. Are there any elements or details that you missed? Based on your review, improve your answer."
        review_request = f'{real_world} {extractor}'

        with tqdm(self.data_loader, desc='Generating') as dbar:
            for batch_idx, (images, lang_x, label) in enumerate(dbar):
                if batch_idx < resume:
                    continue

                # Fourth return value (labels from n_shot) is not needed here.
                image_path, base_image, base_lang_x, _ = self.data.n_shot(
                    images, lang_x, label, shot=0, cls='diff')

                # Keep an untouched copy: modify_lang_x rewrites the question text in
                # place, and the raw question is what gets saved and reused in turn two.
                raw_conversations = copy.deepcopy(base_lang_x)
                base_lang_x = self.modify_lang_x(base_lang_x)

                response_1 = self._generate(base_lang_x, base_image)
                answers_1 = [self.extract_answer(reply) for reply in response_1]

                # Second turn: raw question + full first reply + review request.
                base_lang_x2 = copy.deepcopy(raw_conversations)
                for conversation, reply in zip(base_lang_x2, response_1):
                    conversation.append({
                        'role': 'assistant',
                        'content': [{'type': 'text', 'text': f'{reply}'}],
                    })
                    conversation.append({
                        'role': 'user',
                        'content': [{'type': 'text', 'text': review_request}],
                    })

                response_2 = self._generate(base_lang_x2, base_image)
                answers_2 = [self.extract_answer(reply) for reply in response_2]

                for idx, (prompt_1, prompt_2) in enumerate(zip(raw_conversations, base_lang_x2)):
                    first_turn_records.append({
                        'images': image_path[idx],
                        'prompt': prompt_1,
                        'label': label[idx],
                        'response': answers_1[idx],
                    })
                    second_turn_records.append({
                        'images': image_path[idx],
                        'prompt': prompt_2,
                        'label': label[idx],
                        'response': answers_2[idx],
                    })

        #  id0: init response, id1: improved response
        self._save({'id0': first_turn_records, 'id1': second_turn_records})

        print('done perfectly')

    def _generate(self, conversations, images, max_new_tokens=1024):
        """Render chat templates, run generation and decode only the new tokens."""
        texts = [
            self.vlm_processor.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True)
            for conversation in conversations
        ]
        inputs = self.vlm_processor(images=images,
                                    text=texts,
                                    return_tensors="pt",
                                    padding=True).to(self.vlm.device)

        with torch.inference_mode():
            generated_ids = self.vlm.generate(**inputs, max_new_tokens=max_new_tokens)
            # Strip the prompt tokens so only the continuation is decoded.
            new_token_ids = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            return self.vlm_processor.batch_decode(new_token_ids, skip_special_tokens=True)

    def _save(self, payload):
        """Write ``payload`` as indented UTF-8 JSON into ``self.save_path``."""
        # An empty save_path means "current directory"; os.makedirs('') would raise.
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)

        out_file = os.path.join(self.save_path, 'data_for_activate_vector.jsonl')
        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding instead of
        # relying on the platform default.
        with open(out_file, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=4)

    def modify_lang_x(self, lang_x):
        """Wrap the question text of each conversation in the ``PROMPT`` template.

        Only the first message of every conversation is kept. Its text parts are
        rewritten in place, so callers that need the original text must copy it
        beforehand.

        Args:
            lang_x: List of conversations, each a list of chat messages whose
                ``content`` is a list of typed parts.

        Returns:
            A new list of single-message conversations.
        """
        wrapped = []
        for conversation in lang_x:
            first_message = conversation[0]
            for part in first_message['content']:
                if part['type'] == 'text':
                    part['text'] = PROMPT.format(question=part['text'])
            wrapped.append([first_message])

        return wrapped

    def extract_answer(self, res):
        """Pull the final answer out of a model reply.

        Lookup order: an option letter between ``##`` markers, any non-empty text
        between ``##`` markers, an option letter followed by a dot, and finally the
        first alphanumeric token.

        Args:
            res: Decoded model reply; ``None`` or an empty string yields ``None``.

        Returns:
            The stripped answer string, or ``None`` when nothing usable is found.
        """
        if not res:
            return None

        for pattern in self._MARKED_ANSWER_PATTERNS:
            for match in pattern.finditer(res):
                candidate = match.group(1).strip()
                # Skip empty captures such as a markdown "####" heading, which would
                # otherwise be reported as an empty answer.
                if candidate:
                    return candidate

        for pattern in (self._OPTION_DOT_PATTERN, self._FIRST_TOKEN_PATTERN):
            match = pattern.search(res)
            if match:
                return match.group(1).strip()

        return None

if __name__ == '__main__':
    # Script entry point: build the dataset, load the VLM and write the JSON dump.
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_root', default='data/science/data')
    parser.add_argument('--vlm_weight', default='Qwen2-VL-2B/Qwen2-VL-2B-Instruct')
    parser.add_argument('--save_path', default='json_for_rules/science_new')
    # Explicit types: without them any value given on the command line arrives as a
    # string, which breaks the `idxx < resume` comparison and `alpha * data`.
    parser.add_argument('--resume', type=int, default=0,
                        help='number of leading batches to skip')
    parser.add_argument('--alpha', type=float, default=0.4,
                        help='factor multiplied with the dataset (alpha * data)')

    args = parser.parse_args()

    #  The data for generating activate vectors
    #  Alternative datasets (swap in as needed):
    # data = Animals(root=args.data_root, test=False, return_image_path=True)
    # data = args.alpha * data  # alpha = 0.02
    # data = RealworldQA(root=args.data_root, return_image_path=True)
    # data = MMBench(root=args.data_root, return_image_path=True)
    # data = MMStar(root=args.data_root, return_image_path=True)
    # data = SeedBench(root=args.data_root, return_image_path=True)
    data = ScienceQA(root=args.data_root, return_image_path=True)
    data = args.alpha * data

    vlm, vlm_processor = load_vlm(vlm_path=args.vlm_weight)

    save_json = QA2Json(data=data, batch_size=4, vlm=vlm,
                        vlm_processor=vlm_processor,
                        save_path=args.save_path)

    save_json.generate_json(resume=args.resume)