from preparation_data import CombinData
from raie.datasets_ import Animals, RealworldQA, MMBench, MMStar, SeedBench, ScienceQA
import open_clip
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import warnings
from PIL import Image
import torch
import os
import copy
from glob import glob
warnings.filterwarnings('ignore')

class MappingActivateVector:
    """Map stored activate vectors onto the images of a dataset.

    Every image of ``data`` is encoded with an open_clip vision encoder and
    compared (dot product followed by softmax) against the images of a
    reference set derived from ``data``. The activate vectors loaded from
    ``activate_root`` are then averaged with those softmax weights, giving one
    mapped activate vector per image of ``data``.
    """

    def __init__(self, vit_type, vit_pretrained_path, data: Dataset, activate_root, batch_size, alpha=0.02,
                 device='cuda'):
        """Build the encoder, load the activate vectors and create both loaders.

        Args:
            vit_type: Model name forwarded to ``open_clip.create_model_and_transforms``.
            vit_pretrained_path: Value forwarded as ``pretrained`` to open_clip.
            data: Dataset whose samples start with an image path.
            activate_root: Directory holding the ``*_activate.pth`` files.
            batch_size: Batch size used for both data loaders.
            alpha: Factor the dataset copy is multiplied by to build the
                reference set (see the note next to ``data2`` below).
            device: Device the encoder and the image batches are moved to.
                Defaults to ``'cuda'``, which matches the previous hard-coded
                ``.cuda()`` calls.

        Raises:
            FileNotFoundError: If ``activate_root`` has no ``*_activate.pth`` file.
        """
        self.device = torch.device(device)

        # Third return value of open_clip is kept as the image preprocessing transform.
        self.vision_encoder, _, self.image_processor = open_clip.create_model_and_transforms(
            vit_type,
            pretrained=vit_pretrained_path
        )

        self.vision_encoder.to(self.device)
        self.vision_encoder.eval()

        # Natural (numeric-aware) ordering: a plain lexicographic sort puts
        # 'id10_activate.pth' before 'id2_activate.pth', so the index along the
        # stacked last dimension would stop matching the 'id{i}' file names that
        # mapping() writes.
        activate_vectors_name = sorted(
            glob(os.path.join(activate_root, '*_activate.pth')),
            key=self._natural_key,
        )
        if not activate_vectors_name:
            raise FileNotFoundError(
                f"no '*_activate.pth' files found in {activate_root!r}"
            )
        activate_vectors = [torch.load(a) for a in activate_vectors_name]
        activate_vectors = torch.stack(activate_vectors, dim=-1)  # B, n, n_id (per the original author's note)

        self.data = DataLoader(data, batch_size, shuffle=False)
        # NOTE(review): this relies on the dataset class supporting
        # ``float * dataset``; presumably it keeps a fraction ``alpha`` of the
        # samples -- confirm against the dataset implementation.
        data2 = alpha * copy.copy(data)
        data2 = CombinData(data2, activate_vectors)
        self.data2 = DataLoader(data2, batch_size, shuffle=False)
        self.activate_vectors = activate_vectors

    @staticmethod
    def _natural_key(path):
        """Return a sort key that orders digit runs in the file name by numeric value."""
        import re  # local import: only needed here, keeps the file's import block untouched

        name = os.path.basename(path)
        # re.split with one capturing group alternates text / digit tokens, so
        # odd positions are always digit runs and keys stay comparable.
        tokens = [
            (1, int(tok), '') if idx % 2 else (0, 0, tok)
            for idx, tok in enumerate(re.split(r'(\d+)', name))
        ]
        return tokens, name

    def load_image(self, image_path):
        """Load and preprocess one image or a batch of images.

        Args:
            image_path: A single path or a list/tuple of paths.

        Returns:
            A tensor stacking the preprocessed images along a new first dimension
            (a single path yields a batch of one).

        Raises:
            TypeError: If ``image_path`` is neither a path nor a list/tuple.
        """
        if isinstance(image_path, (str, os.PathLike)):
            image_path = [image_path]
        elif not isinstance(image_path, (list, tuple)):
            # Previously this case returned None and failed later on `.cuda()`.
            raise TypeError(
                f'image_path must be a path or a list/tuple of paths, got {type(image_path).__name__}'
            )

        processed = []
        for path in image_path:
            # Context manager closes the underlying file once the image is processed.
            with Image.open(path) as image:
                processed.append(self.image_processor(image))
        return torch.stack(processed)

    def _encode_reference_images(self):
        """Encode every image of the reference loader once and concatenate the features."""
        features = []
        for images_path_2, _ in self.data2:
            images2 = self.load_image(images_path_2).to(self.device)
            features.append(self.vision_encoder.encode_image(images2))  # B, feature_dim
        return torch.cat(features, dim=0)

    def mapping(self, save_path=''):
        """Compute the mapped activate vectors and save one file per id.

        For each id ``i`` (index along the last dimension of the loaded activate
        vectors) a tensor is written to ``<save_path>/id{i}_activate.pth``.

        Args:
            save_path: Output directory, created if missing. An empty string
                writes into the current working directory.

        Raises:
            ValueError: If the number of reference images differs from the
                first dimension of the loaded activate vectors.
        """
        if save_path:
            # exist_ok avoids the check-then-create race; '' cannot be created.
            os.makedirs(save_path, exist_ok=True)

        new_activate_vectors = []

        # Everything below is pure inference, so no autograd state is needed
        # for either the reference or the query encodings.
        with torch.inference_mode():
            # The reference features do not depend on the query batch: encode
            # them once instead of once per query batch.
            reference_features = self._encode_reference_images()

            n_reference = reference_features.shape[0]
            if n_reference != self.activate_vectors.shape[0]:
                raise ValueError(
                    f'number of reference images ({n_reference}) does not match the number '
                    f'of activate vectors ({self.activate_vectors.shape[0]})'
                )
            # One row per reference image; moved/cast once so it can be multiplied
            # with the similarity weights.
            tem_a = self.activate_vectors.reshape(n_reference, -1).to(
                device=self.device, dtype=reference_features.dtype
            )

            with tqdm(self.data, desc='Mapping') as dabr:
                for images_path_1, *_ in dabr:
                    images1 = self.load_image(images_path_1).to(self.device)
                    images1_f = self.vision_encoder.encode_image(images1)  # B, feature_dim

                    corr = images1_f @ reference_features.T  # B, N; N is number of activate vectors
                    corr = torch.softmax(corr, dim=-1)
                    mapped_activate_vector = corr @ tem_a
                    new_activate_vectors.append(
                        mapped_activate_vector.reshape(corr.shape[0], *self.activate_vectors.shape[1:])
                    )

        # The shape of saved activate vector must be (B, n)
        mapped_all = torch.cat(new_activate_vectors, dim=0)
        for i in range(mapped_all.shape[-1]):
            # clone(): saving a slice view would serialize the whole shared storage.
            torch.save(mapped_all[..., i].clone(), os.path.join(save_path, f'id{i}_activate.pth'))
        print('done')

if __name__ == '__main__':
    # Script entry point: map the stored MMStar training activate vectors onto
    # the MMStar images and write the result to the preference directory.
    #
    # Other dataset choices that can be swapped in for `dataset`:
    #   Animals(root='data/Animals_with_Attributes2', test=False)
    #   MMBench(root='data/mmbench/data')
    #   RealworldQA(root='data/realworldQA/data')
    #   SeedBench(root='data/seedbench/data', return_image_path=False)
    #   ScienceQA(root='data/science/data', return_image_path=False)
    dataset = MMStar(root='data/mmstar', return_image_path=False)

    mapper_options = dict(
        vit_type='ViT-L-14',
        vit_pretrained_path='ViT-L-14.pt',
        data=dataset,
        activate_root='activate_vector/mmstar_new_train/',
        batch_size=30,
        alpha=0.4,
    )
    mapper = MappingActivateVector(**mapper_options)

    output_dir = 'activate_vector/mmstar_preference'
    mapper.mapping(save_path=output_dir)





