import json
from core import SPRA
from transformers import AutoProcessor, AutoModelForCausalLM
from core import rules
import os
import torch
from core.utils import cat_activate_vector
import pandas as pd
from datasets_ import Animals, RealworldQA, MMStar, MMBench, SeedBench, ScienceQA
from core.utils import load_vlm, load_llm, load_florence2

# Per-rule switch: a rule whose name maps to True is run with the LLM loaded
# alongside the VLM; a rule that maps to False is run with the VLM only.
config = dict(
    is_cls_animal=True,
    n_shot_top5_distr=False,
    where_live=True,
    image_feature=True,
    descriptive_similarity=True,
    judging_the_species=True,
    judging_the_species2=True,
)

# Resume switches. For every column index listed under 'id', each rule flagged
# True below is skipped, and the final per-column concatenation is skipped too.
resume = dict(
    id=[],
    n_shot_top5_distr=True,
    where_live=True,
    image_feature=True,
    descriptive_similarity=True,
    judging_the_species=True,
    judging_the_species2=True,
    qa_rules=False,
)

def _load_aux_models(kind, llm_path):
    """Load the auxiliary models needed by one family of rules.

    ``kind == 'qa'`` returns ``(llm, llm_processor, detect_model, detect_processor)``
    for the ``qa_rules`` rule; any other kind returns ``(llm, llm_processor)``
    for the rules flagged True in ``config``.
    """
    if kind == 'qa':
        llm, llm_processor = load_llm(llm_path)
        detect_model, detect_processor = load_florence2('Florence2-large-ft')
        return llm, llm_processor, detect_model, detect_processor
    llm = AutoModelForCausalLM.from_pretrained(llm_path, device_map='auto')
    llm_processor = AutoProcessor.from_pretrained(llm_path)
    return llm, llm_processor


def _build_spra(rule_name, vlm, vlm_processor, llm_path, aux_cache):
    """Return a ``SPRA`` wired with the models that ``rule_name`` needs.

    ``aux_cache`` is a dict holding ``'kind'`` and ``'models'`` for the single
    auxiliary model set currently kept in memory. A set is loaded only when a
    rule needs a different kind than the one cached, rather than once per rule
    per column.
    """
    if rule_name == 'qa_rules':
        kind = 'qa'
    elif config.get(rule_name):
        kind = 'config'
    else:
        # Rules not flagged True in ``config`` only need the VLM.
        return SPRA(vlm, vlm_processor)

    if aux_cache.get('kind') != kind:
        # Drop the previous set before loading the next one, so two large
        # auxiliary models are never referenced at the same time.
        aux_cache['kind'] = None
        aux_cache['models'] = None
        torch.cuda.empty_cache()
        aux_cache['models'] = _load_aux_models(kind, llm_path)
        aux_cache['kind'] = kind

    if kind == 'qa':
        llm, llm_processor, detect_model, detect_processor = aux_cache['models']
        return SPRA(vlm=vlm, vlm_processor=vlm_processor, llm=llm, llm_processor=llm_processor,
                    other_vlm=detect_model, other_processor=detect_processor)
    llm, llm_processor = aux_cache['models']
    return SPRA(vlm, vlm_processor, llm, llm_processor)


@torch.inference_mode()
def _run_rule(spra, rule, column, data, reference, activate_path, cls_shot, dataset):
    """Run ``spra.get_activate_vector`` for one rule on one column, without autograd.

    ``n_shot_top5_distr`` is called once per entry of ``cls_shot`` with target
    ``{column}_{rule}_{c_s}.pth``. Every other rule is called once with target
    ``{column}_{rule}.pth``.
    """
    name = rule.__name__
    if name == 'n_shot_top5_distr':
        for c_s in cls_shot:
            file_name = os.path.join(activate_path, f'{column}_{name}_{c_s}.pth')
            spra.get_activate_vector(original_data=data,
                                     activate_path=file_name,
                                     rules=rule,
                                     cls_shot=c_s,
                                     dataset=dataset,
                                     reference=reference,
                                     n_rules=3)
        return

    file_name = os.path.join(activate_path, f'{column}_{name}.pth')
    spra.get_activate_vector(original_data=data,
                             activate_path=file_name,
                             rules=rule,
                             reference=reference,
                             n_rules=10 if name == 'qa_rules' else 1)


def generate_activate_vector(json_path='', vlm_path=None, llm_path=None, rules_=None, activate_path='', cls_shot=None, dataset=None):
    """Compute activation vectors for every rule on every column of ``json_path``.

    Args:
        json_path: File read with ``pd.read_json``. Each column is processed in
            turn; columns after the first receive the first column's values as
            ``reference``.
        vlm_path: Passed to ``load_vlm``.
        llm_path: Passed to the LLM loaders used by rules that need an LLM.
        rules_: Iterable of rule callables, dispatched on ``rule.__name__``.
        activate_path: Directory that receives the per-rule ``.pth`` paths.
        cls_shot: Shot settings, used by ``n_shot_top5_distr`` and forwarded to
            ``cat_activate_vector``.
        dataset: Forwarded to ``n_shot_top5_distr`` only.
    """
    if activate_path:
        # ``activate_path`` is used as a directory (file names are joined onto
        # it), so create it directly, whether or not it ends with a separator.
        os.makedirs(activate_path, exist_ok=True)

    vlm, vlm_processor = load_vlm(vlm_path)
    df = pd.read_json(json_path)
    aux_cache = {}

    for col_idx, column in enumerate(df.columns):
        print(col_idx)
        data = df[column].tolist()
        reference = df[df.columns[0]].tolist() if col_idx != 0 else []
        resuming = col_idx in resume['id']

        for rule in rules_:
            if resuming and resume.get(rule.__name__):
                print('pass:', rule.__name__)
                continue

            spra = _build_spra(rule.__name__, vlm, vlm_processor, llm_path, aux_cache)
            _run_rule(spra, rule, column, data, reference, activate_path, cls_shot, dataset)
            # Release this rule's SPRA so a later model swap can free memory.
            del spra
            torch.cuda.empty_cache()

        # Resumed columns skip the concatenation step, as in the resume table.
        if resuming:
            continue
        cat_activate_vector(column, rules_, cls_shot, activate_path)

if __name__ == '__main__':

    # Alternative run (Animals with Attributes 2), kept for reference:
    # cls_shot = [('diff', 4), ('diff', 5), ('diff', 6), ('diff', 7), ('same', 4), ('same', 5),
    #             ('same', 6),
    #             ('same', 7)]
    #
    # rule_set = rules['animals']
    #
    # dataset = Animals(root='data/Animals_with_Attributes2', test=False, return_image_path=False)
    # dataset = 0.02 * dataset
    #
    # generate_activate_vector(json_path='json_for_rules/animals_new/data_for_activate_vector.jsonl',
    #                          vlm_path=r"Qwen2-VL-2B/Qwen2-VL-2B-Instruct",
    #                          llm_path=r"idefics3/Llama-3.1-Tulu-3-8B",
    #                          rules_=rule_set,
    #                          activate_path=r'activate_vector/animals_train_new/',
    #                          cls_shot=cls_shot,
    #                          dataset=dataset)

    cls_shot = [('diff', 4), ('diff', 5), ('diff', 6), ('diff', 7)]

    # Use a separate name so the imported ``rules`` registry is not shadowed
    # at module scope.
    rule_set = rules['realworld']

    # Other benchmarks that can be swapped in:
    # dataset = RealworldQA(root='data/realworldQA/data', return_image_path=False)
    # dataset = MMBench(root='data/mmbench/data', return_image_path=False)
    dataset = MMStar(root='data/mmstar', return_image_path=False)
    # dataset = SeedBench(root='data/seedbench/data', return_image_path=False)
    # dataset = ScienceQA(root='data/science/data', return_image_path=False)
    dataset = 0.4 * dataset

    generate_activate_vector(
        json_path='json_for_rules/mmstar_new/data_for_activate_vector.jsonl',
        vlm_path=r"Qwen2-VL-2B/Qwen2-VL-2B-Instruct",
        llm_path=r"idefics3/Llama-3.1-Tulu-3-8B",
        rules_=rule_set,
        activate_path=r'activate_vector/mmstar_new_train/',
        cls_shot=cls_shot,
        dataset=dataset)

    # cat_activate_vector(id='id1', rules=rule_set, cls_shot=cls_shot, activate_vector_path='activate_vector/cifar10_train/')
    # print(torch.load('activate_vector/animals_train/id1_activate.pth').shape)


