from core import SPRA
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from core import rules
import os
import torch
from class_ids import cifar10_dict
import pandas as pd
from core.utils import load_activate_vectors

def generate_mu(json_path: str = '', batch_size: int = 1,
                activate_vector_path: str = '', mu_path: str = '') -> None:
    """Call ``SPRA.get_mu`` once for every column of the table at *json_path*.

    The file is read with ``pandas.read_json``. For each column label, a
    tensor file named ``<label>_activate.pth`` is loaded from
    *activate_vector_path*; the loaded objects are then passed together
    through ``load_activate_vectors`` before being handed, one per column,
    to ``SPRA.get_mu``.

    Args:
        json_path: Path of the JSON file to read.
        batch_size: Forwarded unchanged to ``SPRA.get_mu``.
        activate_vector_path: Directory holding the ``*_activate.pth`` files.
        mu_path: Directory created if missing and forwarded to
            ``SPRA.get_mu`` (presumably the output location -- confirm
            against ``SPRA.get_mu``).
    """
    # exist_ok=True replaces the exists()/makedirs() pair, which could fail
    # if another process created the directory between the two calls.
    os.makedirs(mu_path, exist_ok=True)

    df = pd.read_json(json_path)
    column_ids = list(df.columns)

    # NOTE(review): torch.load unpickles the file contents; only point this
    # at trusted paths.
    raw_vectors = [
        torch.load(os.path.join(activate_vector_path, f'{column_id}_activate.pth'))
        for column_id in column_ids
    ]
    # Return type of load_activate_vectors is not visible here; it is only
    # assumed to support integer indexing in column order.
    activate_vectors = load_activate_vectors(raw_vectors)

    spra = SPRA()
    for idx, column_id in enumerate(column_ids):
        column_data = df[column_id].tolist()
        spra.get_mu(column_data, batch_size, activate_vectors[idx], mu_path, column_id)

if __name__ == '__main__':
    # Disabled step, kept for reference:
    # cat_activate_vector(rules, cls_shot, 'activate_vector/cifar10_train')

    # NOTE(review): json_path ends in .jsonl, but generate_mu reads it with
    # pd.read_json without lines=True -- confirm the file is a single JSON
    # document rather than one record per line.
    run_config = dict(
        json_path='json_for_rules/mmstar_new/data_for_activate_vector.jsonl',
        batch_size=4,
        activate_vector_path=r'activate_vector/mmstar_new_train',
        mu_path='mu/mmstar_new',
    )
    generate_mu(**run_config)


