import torch

from datasets_ import Animals, MMBench, MMStar, RealworldQA, SeedBench, ScienceQA
from torch.utils.data import DataLoader
from tqdm import tqdm
from core.utils import load_vlm
from fuzzywuzzy import fuzz
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration, Idefics3Processor, Gemma3ForConditionalGeneration
from peft import PeftModel, get_peft_model, PeftConfig


def load_ide(path, v=3):
    """Load a vision-to-sequence model together with its processor.

    Args:
        path: Model identifier or directory handed to ``from_pretrained``.
        v: When equal to 3 the processor is built with ``Idefics3Processor``;
            any other value falls back to ``AutoProcessor``.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded in bfloat16 with
        ``device_map='auto'``.
    """
    # Pick the processor class first; the processor is still created before
    # the model, as in the original ordering.
    processor_cls = Idefics3Processor if v == 3 else AutoProcessor
    processor = processor_cls.from_pretrained(path)

    model = AutoModelForVision2Seq.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
    return model, processor

def evaluate(args):
    """Evaluate a VLM on the Animals dataset and print fuzzy-match accuracy.

    For every batch the chat template is applied to the zero-shot prompts,
    up to 5 new tokens are generated, and each decoded answer (spaces and
    periods removed, lower-cased) is compared against its label with
    ``fuzz.ratio``; a ratio above 75 counts as correct.

    Args:
        args: Namespace providing ``data_root``, ``batch_size``, ``vlm_path``
            and ``checkpoint`` (an empty string means "no checkpoint").

    Returns:
        The accuracy ``correct / evaluated`` as a float, or ``None`` when the
        data loader yielded no samples.
    """
    # NOTE(review): ``test=False`` is passed although this is an evaluation
    # run -- confirm it selects the intended split.
    test_data = Animals(args.data_root, test=False)
    test_data_loader = DataLoader(test_data, batch_size=args.batch_size, collate_fn=test_data.collate_fn)

    # vlm, vlm_processor = load_vlm(args.vlm_path)
    vlm, vlm_processor = load_ide(args.vlm_path)

    if args.checkpoint != '':
        # map_location='cpu' keeps the checkpoint tensors off the GPU: without
        # it they are restored onto the device they were saved from, sitting
        # next to the already-loaded model (or failing if that device is
        # missing). load_state_dict then copies them into the model's tensors.
        # SECURITY: torch.load unpickles the file -- only load trusted
        # checkpoints (consider weights_only=True where supported).
        state_dict = torch.load(args.checkpoint, map_location='cpu')
        missing_keys, unexpected_keys = vlm.load_state_dict(state_dict, strict=False)
        print(f"Frozen keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")

    correct = 0
    evaluated = 0
    with tqdm(test_data_loader, desc="Evaluating:", ncols=100) as dbar:
        for images, text, label in dbar:
            base_image, base_lang_x, _ = test_data.n_shot(images, text, label,
                                                          shot=0, cls='diff')

            texts = [
                vlm_processor.apply_chat_template(
                    message, tokenize=False, add_generation_prompt=True
                )
                for message in base_lang_x
            ]

            inputs = vlm_processor(images=base_image,
                                   text=texts,
                                   return_tensors="pt",
                                   padding=True).to(vlm.device)

            evaluated += inputs.input_ids.shape[0]

            with torch.inference_mode():
                generated_ids = vlm.generate(**inputs, max_new_tokens=5)
                # Drop the prompt tokens so only the newly generated part is decoded.
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                generated_texts = vlm_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)

            for generated, expected in zip(generated_texts, label):
                normalized = generated.replace(' ', '').replace('.', '').lower()
                if fuzz.ratio(normalized, expected) > 75:
                    correct += 1

            dbar.set_postfix(acc=f"{correct/evaluated:.4f}")

    # An empty dataset used to crash here with ZeroDivisionError.
    if evaluated == 0:
        print('Acc: n/a (no samples were evaluated)')
        return None

    accuracy = correct / evaluated
    print(f'Acc: {accuracy}')
    return accuracy


# Few-shot instruction template for multiple-choice QA. It contains two worked
# examples followed by a ``{question}`` placeholder that ``evaluate_qa`` fills
# in with ``PROMPT.format(question=...)``. The model is asked to give the
# answer first, then its reasoning, and finally the answer wrapped in ``## ##``.
PROMPT = '''Given the question and the corresponding options below, first provide the correct answer. After that, explain step by step why this answer is the best choice, considering the information available. Finally, wrap the correct answer in ## ##.
Example:

Question:
Which of the following is an example of a mammal?
A. Shark.
B. Whale.
C. Lizard.
D. Frog.

Output:
B
Reason:
1.Mammals are vertebrates that typically have hair or fur and give live birth (except for monotremes).
2.Whales, although aquatic, are mammals because they give live birth, nurse their young with milk, and have a warm-blooded metabolism.
## B ##

Question:
What is the chemical symbol for gold? A. Au. B. Ag. C. Fe. D. Hg

Output:
A
Reason:
1.The chemical symbol for gold is derived from its Latin name "Aurum."
2.Au is the symbol that corresponds to gold, making it the correct answer.
## A ##

Question:
{question}

Output:
'''

def evaluate_qa(args):
    """Evaluate a VLM on a multiple-choice QA benchmark and print its accuracy.

    Each question is wrapped in ``PROMPT``, up to 50 new tokens are generated,
    an answer is extracted from the decoded text and compared with the label
    via ``fuzz.ratio``; a ratio above 75 counts as correct.

    Args:
        args: Namespace providing ``data_root``, ``batch_size``, ``vlm_path``,
            ``checkpoint`` (empty string means "no checkpoint") and ``lora``
            (truthy: ``checkpoint`` is a PEFT adapter; falsy: it is a state
            dict file loaded with ``strict=False``).

    Returns:
        The accuracy ``correct / evaluated`` as a float, or ``None`` when the
        data loader yielded no samples.
    """
    import re

    # Tried in order on the whole response. The generic ``##(.*?)##`` pattern
    # also covers numeric answers such as ``## 2 ##``, so no separate digit
    # pattern is needed (one listed after it could never be reached).
    marked_answer_patterns = (
        re.compile(r'##\s*([ABCD])\s*##'),
        re.compile(r'##(.*?)##'),
    )
    option_letter_pattern = re.compile(r'([ABCD])\.')
    first_token_pattern = re.compile(r'([A-Za-z]+[0-9]*|[0-9]+)')

    def modify_lang_x(lang_x):
        """Return one-message conversations whose text parts are wrapped in PROMPT.

        Only the first message of each conversation is kept. New dicts are
        built so the caller's message objects are not modified.
        """
        wrapped = []
        for conversation in lang_x:
            first_message = conversation[0]
            content = [
                {**part, 'text': PROMPT.format(question=part['text'])}
                if part['type'] == 'text' else part
                for part in first_message['content']
            ]
            wrapped.append([{**first_message, 'content': content}])
        return wrapped

    def extract_answer(res):
        """Pull the answer out of a generated response, or return None.

        Preference order: text wrapped in ``## ##``, then the first option
        letter followed by a period, then the first alphanumeric token.
        """
        for pattern in (*marked_answer_patterns, option_letter_pattern, first_token_pattern):
            match = pattern.search(res)
            if match:
                return match.group(1).strip()
        return None

    # Alternative loaders; not called below, kept so the model family can be
    # switched by editing the ``load_ide`` call further down.
    def load_intern(path):
        # NOTE(review): despite the name this instantiates
        # Gemma3ForConditionalGeneration -- confirm that is intended.
        processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        model = Gemma3ForConditionalGeneration.from_pretrained(path, device_map='auto', trust_remote_code=True)

        return model, processor

    def load_llava(path):
        processor = AutoProcessor.from_pretrained(path)
        model = LlavaForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16, device_map='auto')

        return model, processor

    # test_data = MMStar(args.data_root)
    # test_data = RealworldQA(args.data_root)
    test_data = SeedBench(args.data_root)
    # test_data = MMBench(args.data_root)
    # test_data = ScienceQA(args.data_root)

    num_sample = len(test_data)
    print('num_sample:', num_sample)
    test_data_loader = DataLoader(test_data, batch_size=args.batch_size, collate_fn=test_data.collate_fn)

    # Uses the module-level ``load_ide``; the nested copy that used to live
    # here was a verbatim duplicate of it.
    vlm, vlm_processor = load_ide(args.vlm_path)

    if args.checkpoint != '':
        if args.lora:
            vlm = PeftModel.from_pretrained(vlm, args.checkpoint)
            # vlm = lora_model.merge_and_unload()
        else:
            # Previously nothing was loaded on this path although
            # 'checkpoint loaded' was still printed. Load the state dict the
            # same way ``evaluate`` does.
            # map_location='cpu' avoids materialising a second copy on GPU.
            # SECURITY: torch.load unpickles the file -- only load trusted
            # checkpoints (consider weights_only=True where supported).
            state_dict = torch.load(args.checkpoint, map_location='cpu')
            missing_keys, unexpected_keys = vlm.load_state_dict(state_dict, strict=False)
            print(f"Frozen keys: {missing_keys}")
            print(f"Unexpected keys: {unexpected_keys}")
        print('checkpoint loaded')

    correct = 0
    evaluated = 0
    with tqdm(test_data_loader, desc="Evaluating:", ncols=100) as dbar:
        for images, text, label in dbar:
            base_image, base_lang_x, _ = test_data.n_shot(images, text, label,
                                                          shot=0, cls='diff')
            base_lang_x = modify_lang_x(base_lang_x)

            texts = [
                vlm_processor.apply_chat_template(
                    message, tokenize=False, add_generation_prompt=True
                )
                for message in base_lang_x
            ]

            inputs = vlm_processor(images=base_image,
                                   text=texts,
                                   return_tensors="pt",
                                   padding=True).to(vlm.device)

            evaluated += inputs.input_ids.shape[0]

            with torch.inference_mode():
                generated_ids = vlm.generate(**inputs, max_new_tokens=50)  # 250
                # Drop the prompt tokens so only the newly generated part is decoded.
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                generated_texts = vlm_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)

            for generated, expected in zip(generated_texts, label):
                if fuzz.ratio(extract_answer(generated), expected) > 75:
                    correct += 1

            # Upper bound on the final accuracy: assume every sample that has
            # not been evaluated yet will be answered correctly.
            remaining = num_sample - evaluated
            dbar.set_postfix(acc=f"{correct/evaluated:.4f}",
                             max_acc=f"{(correct + remaining) / num_sample:.4f}")

    # An empty dataset used to crash here with ZeroDivisionError.
    if evaluated == 0:
        print('Acc: n/a (no samples were evaluated)')
        return None

    accuracy = correct / evaluated
    print(f'Acc: {accuracy}')
    return accuracy

if __name__ == '__main__':
    import argparse

    def _str2bool(value):
        """Convert a command-line string such as 'True'/'false'/'0' to a bool.

        Without this, argparse hands over the raw string and ``--lora False``
        would be truthy.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', default='data/Animals_with_Attributes2')
    # type=int: a value given on the command line would otherwise reach
    # DataLoader as a str, which it rejects.
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--vlm_path', default='idefics3/Idefics3-8B-Llama3')
    parser.add_argument('--checkpoint', default='') #checkpoint/animals/finetuned_checkpoint2.pth
    # parser.add_argument('--checkpoint', default='checkpoint/all/finetuned_checkpoint3.pth')
    parser.add_argument('--lora', type=_str2bool, default=True)

    args = parser.parse_args()

    # evaluate_qa(args)
    evaluate(args)