import torch
from tqdm import tqdm
import numpy as np
import clip
import pandas as pd
import sklearn.preprocessing


class CLIPCapDataset(torch.utils.data.Dataset):
    """Dataset that yields CLIP token tensors for prefixed captions.

    Each item is the caption at the requested index, prepended with
    ``prefix`` and tokenized with ``clip.tokenize``.

    Args:
        data: Indexable, sized collection of caption strings.
        prefix: Text prepended to every caption. A single trailing space is
            appended when the prefix is non-empty and does not already end
            with one. An empty prefix is kept empty (captions are tokenized
            as-is).
    """

    def __init__(self, data, prefix='A photo depicts'):
        self.data = data
        self.prefix = prefix
        # The truthiness check guards the empty-prefix case, where indexing
        # ``prefix[-1]`` would raise IndexError.
        if self.prefix and not self.prefix.endswith(' '):
            self.prefix += ' '

    def __getitem__(self, idx):
        """Return ``{'caption': tokens}`` for the caption at ``idx``."""
        text = self.prefix + self.data[idx]
        # tokenize() receives a single string; squeeze() drops the resulting
        # size-1 dimension so the DataLoader can stack items into a batch.
        tokens = clip.tokenize(text, truncate=True).squeeze()
        return {'caption': tokens}

    def __len__(self):
        """Return the number of captions."""
        return len(self.data)


def extract_all_captions(captions, model, device, batch_size=256, num_workers=8):
    """Encode captions with the model's text encoder.

    Args:
        captions: Collection of caption strings; wrapped in a
            ``CLIPCapDataset`` (default prefix) for tokenization.
        model: Object exposing ``encode_text(tokens)``.
        device: Device the token batches are moved to before encoding.
        batch_size: Number of captions per DataLoader batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A NumPy array with the per-batch ``encode_text`` outputs stacked
        vertically, in the same order as ``captions`` (``shuffle=False``).
    """
    loader = torch.utils.data.DataLoader(
        CLIPCapDataset(captions),
        batch_size=batch_size, num_workers=num_workers, shuffle=False)
    all_text_features = []
    with torch.no_grad():
        # The file imports the class via ``from tqdm import tqdm``, so it is
        # called directly; ``tqdm.tqdm(...)`` would raise AttributeError.
        for batch in tqdm(loader):
            tokens = batch['caption'].to(device)
            all_text_features.append(model.encode_text(tokens).cpu().numpy())
    return np.vstack(all_text_features)


# Score every caption in the CSV by its mean cosine similarity (in CLIP text
# space) to the captions of a known poisoned subset, and store the result in
# a new 'scores' column.

# Fall back to CPU so the script still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/32", device=device, jit=False)
model.eval()
df = pd.read_csv('../backdoor_banana_blended_blended_595375_3000_label_scores.csv')
captions = df['caption'].tolist()
poisoned_subset_indices = np.load('../npy/blended/3000/pure_poison.npy')

# Reference set: unit-normalized text features of the poisoned captions.
candidates = [captions[idx] for idx in poisoned_subset_indices]
candidates = extract_all_captions(candidates, model, device)
candidates = sklearn.preprocessing.normalize(candidates, axis=1)
# For unit vectors, dot(f, mean_j(c_j)) == mean_j(dot(f, c_j)), so one dot
# product against the centroid gives the mean cosine similarity to all
# candidates. float32 avoids accumulating in the encoder's output dtype.
# NOTE(review): assumes the intended score is the mean similarity over all
# poisoned captions. The previous elementwise `feat * candidates` only
# broadcast when both arrays happened to have matching row counts.
candidate_centroid = candidates.astype(np.float32).mean(axis=0)

captions_to_identify = captions
data = torch.utils.data.DataLoader(
    CLIPCapDataset(captions_to_identify),
    batch_size=128, num_workers=8, shuffle=False)

textual_similarity = []
with torch.no_grad():
    # ``tqdm`` is the class imported via ``from tqdm import tqdm``.
    for b in tqdm(data):
        b = b['caption'].to(device)
        feat = model.encode_text(b).cpu().numpy()
        feat = sklearn.preprocessing.normalize(feat, axis=1)
        textual_similarity.append(feat.astype(np.float32) @ candidate_centroid)

# Flatten the per-batch score arrays into one value per CSV row; assigning
# the raw list of batches would not match the DataFrame length.
df['scores'] = np.concatenate(textual_similarity)
# NOTE(review): '.....' looks like a placeholder output path — set before running.
df.to_csv('.....', index=False)
