import pandas as pd
import clip
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import os
import numpy as np
import time
import json
from collections import defaultdict
import torchvision.transforms as transforms
import csv
import shutil
import argparse
from torch.utils.data import Dataset, DataLoader
 
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from sklearn.metrics import f1_score, precision_recall_curve, roc_curve, auc
import sys

from sklearn.decomposition import PCA
from utils.eda import *
import joblib
import statistics

def _augment_text(caption):
    """Return a single text-augmented variant of *caption*.

    Delegates to ``eda`` (star-imported from ``utils.eda``; presumably "Easy Data
    Augmentation" — confirm) and keeps only element 0 of the sequence it returns.
    """
    return eda(caption)[0]

 



def _augment_batch_image2(batch_image):
    """Apply torchvision ``AutoAugment`` (default policy) independently to every image.

    Each element of *batch_image* (iterated along its first dimension) is converted to
    a PIL image, augmented, and converted back with ``ToTensor``. The resulting tensors
    are stacked along a new leading dimension.
    """
    to_pil = T.ToPILImage()
    to_tensor = T.ToTensor()

    pil_batch = list(map(to_pil, batch_image))
    policy = torchvision.transforms.AutoAugment()
    augmented = list(map(policy, pil_batch))
    return torch.stack(list(map(to_tensor, augmented)))





 
class ImageLabelDataset_baseclean(Dataset):
    """Map-style dataset yielding ``(base_cleanset_x[idx], base_cleanset_y[idx])``.

    Both sequences are stored by reference, without copying or transforming them;
    ``len()`` follows ``base_cleanset_x``.
    """

    def __init__(self, base_cleanset_x, base_cleanset_y):
        self.base_cleanset_x, self.base_cleanset_y = base_cleanset_x, base_cleanset_y

    def __len__(self):
        return len(self.base_cleanset_x)

    def __getitem__(self, idx):
        return self.base_cleanset_x[idx], self.base_cleanset_y[idx]



class ImageLabelDataset_(Dataset):
    """Dataset of backdoored inputs, every item labelled with one fixed target class.

    ``bd_targets`` is stored but not used for labelling: item ``idx`` is
    ``(bd_inputs[idx], target_label)``.
    """

    # Default label for every item. NOTE(review): presumably the attack's target class
    # (ImageNet index 954, "banana", matching the ``*_banana`` sample directories) — confirm.
    DEFAULT_TARGET_LABEL = 954

    def __init__(self, bd_inputs, bd_targets, target_label=DEFAULT_TARGET_LABEL):
        """
        Args:
            bd_inputs: indexable sequence of backdoored inputs.
            bd_targets: original targets; kept for reference only.
            target_label: label returned for every item (defaults to 954, as before).
        """
        self.bd_inputs = bd_inputs
        self.bd_targets = bd_targets
        self.target_label = target_label

    def __len__(self):
        return len(self.bd_inputs)

    def __getitem__(self, idx):
        return self.bd_inputs[idx], self.target_label

class ImageLabelDataset_clean(Dataset):
    """Map-style dataset returning ``(clean_inputs[idx], clean_targets[idx])`` pairs.

    The two sequences are kept by reference; ``len()`` follows ``clean_inputs``.
    """

    def __init__(self, clean_inputs, clean_targets):
        self.clean_inputs = clean_inputs
        self.clean_targets = clean_targets

    def __len__(self):
        count = len(self.clean_inputs)
        return count

    def __getitem__(self, idx):
        sample = self.clean_inputs[idx]
        return sample, self.clean_targets[idx]

# Category-description table, read once at import time. The description helpers below
# look its rows up by the ``class_index`` column.
# NOTE(review): nothing visible in this file reads this module-level name (the helpers
# load the CSV themselves), so it looks redundant. Confirm that no other module imports
# it before removing it.
description_df = pd.read_csv('./detection_imagenet/ptp/imagenet_category_description.csv')
def get_descriptions_for_targets(targets):
    description_df = pd.read_csv('./detection_imagenet/ptp/imagenet_category_description.csv')
  
    descriptions_list = []
 
    for i,target in enumerate(targets):
        target_value = target.item()   
        
       
        row = description_df[description_df['class_index'] == target_value]
        if not row.empty:
   
            descriptions = row.iloc[0, 2]   
            descriptions_list.append(descriptions)
 
    return descriptions_list

def get_descriptions_for_targets3(targets):
    description_df = pd.read_csv('./detection_imagenet/ptp/imagenet_category_description.csv')
   
    descriptions_list = []
 
    for i,target in enumerate(targets):
        target_value = target.item()  
 
        row = description_df[description_df['class_index'] == target_value]
        if not row.empty:
  
            descriptions = row.iloc[0, 3]  
            descriptions_list.append(descriptions)
 
    return descriptions_list

def get_descriptions_for_targets4(targets):
    description_df = pd.read_csv('./detection_imagenet/ptp/imagenet_category_description.csv')
 
    descriptions_list = []
 
    for i,target in enumerate(targets):
        target_value = target.item()   
 
        row = description_df[description_df['class_index'] == target_value]
        if not row.empty:
 
            descriptions = row.iloc[0, 4]   
            descriptions_list.append(descriptions)
 
    return descriptions_list


def get_descriptions_for_targets6(targets):
    description_df = pd.read_csv('./detection_imagenet/ptp/imagenet_category_description.csv')
  
    descriptions_list = []
 
    for i,target in enumerate(targets):
        target_value = target.item()  
  
        row = description_df[description_df['class_index'] == target_value]
        if not row.empty:
 
            descriptions = row.iloc[0, 6] 
            descriptions_list.append(descriptions)
 
    return descriptions_list

def get_descriptions_for_targets7(targets):
    description_df = pd.read_csv('./detection_imagenet/ptp/imagenet_category_description.csv')
 
    descriptions_list = []
 
    for i,target in enumerate(targets):
        target_value = target.item()  
 
        row = description_df[description_df['class_index'] == target_value]
        if not row.empty:
 
            descriptions = row.iloc[0, 7]   
            descriptions_list.append(descriptions)
 
    return descriptions_list


def get_descriptions_for_targets5(targets):
    description_df = pd.read_csv('./detection_imagenet/ptp/imagenet_category_description.csv')
 
    descriptions_list = []
 
    for i,target in enumerate(targets):
        target_value = target.item()  
 
        row = description_df[description_df['class_index'] == target_value]
        if not row.empty:
 
            descriptions = row.iloc[0, 5]   
            descriptions_list.append(descriptions)
 
    return descriptions_list

def to_fraktur_unicode1(text):
    """Render the ASCII letters of *text* in Unicode Mathematical Fraktur.

    Lowercase a-z map to the contiguous range U+1D51E..U+1D537. Capitals map to
    U+1D504 + offset, except C, H, I, R and Z. Their slots in the Mathematical
    Alphanumeric Symbols block are reserved, and the glyphs live in Letterlike Symbols
    (U+212D, U+210C, U+2111, U+211C, U+2128). All other characters are unchanged.
    """
    # Fix: the previous hand-written table mapped 'I' to Fraktur capital M and 'H'
    # to the reserved (unassigned) code point U+1D50B.
    letterlike = {'C': '\u212d', 'H': '\u210c', 'I': '\u2111', 'R': '\u211c', 'Z': '\u2128'}

    def convert(ch):
        if 'a' <= ch <= 'z':
            return chr(0x1D51E + ord(ch) - ord('a'))
        if 'A' <= ch <= 'Z':
            return letterlike.get(ch, chr(0x1D504 + ord(ch) - ord('A')))
        return ch

    return ''.join(convert(ch) for ch in text)
def to_fraktur_unicode2(text):
    """Underline each ASCII letter of *text* with U+0332 COMBINING LOW LINE.

    Despite the function name this is not a Fraktur mapping. Each of a-z / A-Z is
    followed by the combining underline, and every other character passes through
    unchanged.
    """
    ascii_letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
    underlined = {letter: letter + '\u0332' for letter in ascii_letters}
    pieces = (underlined.get(ch, ch) for ch in text)
    return ''.join(pieces)
def to_fraktur_unicode3(text):
    """Render the ASCII letters of *text* in Mathematical Sans-Serif Bold.

    Capitals map to U+1D5D4..U+1D5ED and lowercase letters to U+1D5EE..U+1D607, a
    block with no reserved holes. Other characters are returned unchanged.
    """
    offsets = range(26)
    sans_bold = {chr(ord('A') + k): chr(0x1D5D4 + k) for k in offsets}
    sans_bold.update({chr(ord('a') + k): chr(0x1D5EE + k) for k in offsets})
    return ''.join(map(lambda ch: sans_bold.get(ch, ch), text))
def to_fraktur_unicode4(text):
    """Render the ASCII letters of *text* in Unicode Mathematical Script.

    Capitals map to U+1D49C + offset and lowercase letters to U+1D4B6 + offset. The
    reserved slots (B E F H I L M R, e g o) are filled from Letterlike Symbols. Other
    characters are returned unchanged.
    """
    overrides = {
        'B': '\u212c', 'E': '\u2130', 'H': '\u210b', 'L': '\u2112',
        'M': '\u2133', 'R': '\u211b',
        # Fix: the previous table mapped 'F' and 'I' to *lowercase* script f/i.
        'F': '\u2131', 'I': '\u2110',
        'e': '\u212f', 'o': '\u2134',
        # Kept from the previous table: 'g' uses MATHEMATICAL ITALIC SMALL G (U+1D454)
        # rather than SCRIPT SMALL G (U+210A).
        'g': '\U0001d454',
    }

    def convert(ch):
        if ch in overrides:
            return overrides[ch]
        if 'A' <= ch <= 'Z':
            return chr(0x1D49C + ord(ch) - ord('A'))
        if 'a' <= ch <= 'z':
            return chr(0x1D4B6 + ord(ch) - ord('a'))
        return ch

    return ''.join(convert(ch) for ch in text)
   

def sample_batch_from_gaussian(text_features: torch.Tensor) -> torch.Tensor:
    """Draw one sample per column from a Gaussian fit across the views of that column.

    Args:
        text_features: 3-D tensor ``(num_transforms, num_samples, feature_dim)``.
            Column ``i`` holds ``num_transforms`` embeddings of the same item.

    Returns:
        ``(num_samples, feature_dim)`` tensor in ``text_features.dtype``. Each row is
        one draw from ``N(mean_i, cov_i + 1e-5 * I)``, estimated in float32.
    """
    num_transforms, num_samples, feature_dim = text_features.shape
    # The unbiased covariance needs at least two observations. With a single view,
    # divide by 1 so the covariance is zero (plus jitter) instead of 0/0 = NaN,
    # which made MultivariateNormal fail.
    denom = max(num_transforms - 1, 1)

    draws = []
    for i in range(num_samples):
        feats = text_features[:, i, :].float()
        mean = feats.mean(dim=0)
        centered = feats - mean
        cov = centered.T @ centered / denom
        # Diagonal jitter keeps the covariance positive definite.
        cov += torch.eye(feature_dim, device=feats.device, dtype=feats.dtype) * 1e-5
        dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
        draws.append(dist.sample())

    if not draws:
        # torch.stack rejects an empty list; return an empty (0, feature_dim) batch.
        return text_features.new_empty((0, feature_dim))
    return torch.stack(draws, dim=0).to(text_features.dtype)

def sample_batch_from_gaussian_pca(text_features: torch.Tensor) -> torch.Tensor:
    """Draw one sample per column from a Gaussian fit in a per-column PCA subspace.

    For every column ``i`` of *text_features* (3-D,
    ``(num_transforms, num_samples, feature_dim)``), a PCA with
    ``min(6, num_transforms - 1)`` components is fit on that column's views. A Gaussian
    (sample covariance + 1e-5 jitter) is fit on the projected coordinates, one point is
    drawn, and it is mapped back with ``inverse_transform``.

    Returns:
        ``(num_samples, feature_dim)`` tensor in ``text_features.dtype``.
    """
    n_components = min(6, text_features.shape[0] - 1)
    _, n_items, _ = text_features.shape

    def _draw(view_matrix):
        # view_matrix: float32 (num_transforms, feature_dim) views of one item.
        projector = PCA(n_components=n_components)
        coords = torch.from_numpy(projector.fit_transform(view_matrix.detach().cpu().numpy()))
        coords = coords.to(view_matrix.device)
        centre = coords.mean(dim=0)
        spread = coords - centre
        covariance = spread.T @ spread / (coords.shape[0] - 1)
        covariance += torch.eye(n_components, device=view_matrix.device, dtype=view_matrix.dtype) * 1e-5
        gaussian = torch.distributions.MultivariateNormal(centre, covariance_matrix=covariance)
        point = gaussian.sample()
        lifted = projector.inverse_transform(point.detach().cpu().numpy())
        return torch.tensor(lifted).to(view_matrix.device)

    draws = [_draw(text_features[:, i, :].float()) for i in range(n_items)]
    return torch.stack(draws, dim=0).to(text_features.dtype)




def sample_text_features_sampling(transformed_text_features: torch.Tensor, targets_ori_features : torch.Tensor, n_samples: int=10, tolerance: float = 0.15, max_retry_per_sample: int = 20, max_rounds_per_sample: int = 1000):
    """Augment every view with perturbed copies pushed away from the reference embedding.

    For sample column ``i`` let ``t = targets_ori_features[i]``. For each view
    ``b = transformed_text_features[idx, i]``, ``n_samples`` candidates
    ``b + s * (b - t) / ||b - t||`` are accepted. Each candidate starts from a fresh
    ``s ~ U[0, 1)``, and ``s`` is halved while the relative cosine-similarity drop
    ``(cos(t, b) - cos(t, candidate)) / cos(t, b)`` exceeds ``tolerance``. A candidate
    is accepted when that drop lies in ``(0, tolerance]``.

    Args:
        transformed_text_features: ``(num_transforms, num_samples, feature_dim)`` views.
        targets_ori_features: ``(num_samples, feature_dim)`` reference embeddings.
        n_samples: perturbations to accept per view.
        tolerance: largest accepted relative similarity drop.
        max_retry_per_sample: halvings tried for one random scale before drawing a new one.
        max_rounds_per_sample: random scales tried per perturbation before giving up.

    Returns:
        ``(num_transforms * (1 + n_samples), num_samples, feature_dim)``. The original
        views come first, followed by the perturbations grouped by view (all of view
        0's, then view 1's, ...).

    Raises:
        RuntimeError: if a view cannot be perturbed within ``max_rounds_per_sample``
            rounds, e.g. when it coincides with its reference embedding and the push
            direction is zero. Previously this looped forever.
    """
    device = transformed_text_features.device
    num_transforms, num_samples, _ = transformed_text_features.shape
    cosine = torch.nn.functional.cosine_similarity

    per_sample = []
    for i in range(num_samples):
        anchor = targets_ori_features[i, :]
        views = transformed_text_features[:, i, :]
        extras = []
        for idx in range(num_transforms):
            base = views[idx]
            ref_sim = cosine(anchor, base, dim=-1)
            # Unit direction from the reference towards this view, shape (1, feature_dim).
            # It is loop-invariant, so it is computed once per view.
            direction = base - anchor.unsqueeze(0)
            direction = direction / (direction.norm(dim=1, keepdim=True) + 1e-8)

            # Fix: draw n_samples perturbations for *this* view. The old loop condition
            # compared the running total against n_samples * num_transforms, so view 0
            # produced every perturbation and the other views produced none.
            for _ in range(n_samples):
                accepted = None
                for _round in range(max_rounds_per_sample):
                    noise_scale = torch.rand(1, 1).to(device)
                    for _attempt in range(max_retry_per_sample):
                        candidate = base + direction * noise_scale
                        delta = (ref_sim - cosine(anchor, candidate, dim=-1)) / (ref_sim + 1e-8)
                        if 0 < delta <= tolerance:
                            accepted = candidate
                            break
                        if delta > tolerance:
                            # Pushed too far: shrink the step and retry.
                            noise_scale = noise_scale * 0.5
                        else:
                            # delta <= 0 (or NaN): the same scale would fail again.
                            break
                    if accepted is not None:
                        break
                if accepted is None:
                    raise RuntimeError(
                        f"could not perturb view {idx} of sample {i} within "
                        f"{max_rounds_per_sample} rounds"
                    )
                extras.append(accepted)

        if extras:
            per_sample.append(torch.cat([views, torch.cat(extras, dim=0)], dim=0))
        else:
            per_sample.append(views)

    return torch.stack(per_sample, dim=1)


def _normalized_recon_loss(centered, components, k, denom):
    """Mean squared residual of projecting *centered* onto the first *k* components, over *denom*."""
    basis = components[:k, :]
    recon = (centered @ basis.T) @ basis
    return np.mean(np.sum((centered - recon) ** 2, axis=1)) / denom


def choose_k_by_reconstruction_loss(transformed_text_feature_with_sampled: torch.Tensor,  loss_threshold: float = 0.1):
    """Pick, per sample column, the smallest PCA rank whose reconstruction loss meets a threshold.

    Column ``i`` (dimension 1) is a set of row vectors (dimension 0). The vectors are
    centred and a PCA is fit on them. For ranks ``k = 1, 2, ...`` the loss is the mean
    squared reconstruction residual divided by the total per-feature variance (+1e-8).

    Returns:
        ``list[int]``, one rank per column. It is the first ``k`` with
        ``loss <= loss_threshold``, or the number of rows when no rank qualifies.
    """
    features = transformed_text_feature_with_sampled
    chosen = []
    for col in range(features.shape[1]):
        fallback_k = int(features.shape[0])
        rows = features[:, col, :].detach().cpu().numpy().astype(np.float32)
        centered = rows - np.mean(rows, axis=0, keepdims=True)

        pca = PCA(n_components=min(rows.shape))
        pca.fit(centered)
        denom = np.sum(np.var(centered, axis=0)) + 1e-8

        candidates = range(1, min(fallback_k, rows.shape[0]) + 1)
        first_ok = next(
            (k for k in candidates
             if _normalized_recon_loss(centered, pca.components_, k, denom) <= loss_threshold),
            fallback_k,
        )
        chosen.append(first_ok)
    return chosen

def sample_batch_from_gaussian_pca_sampling_recon(transformed_text_features: torch.Tensor, targets_ori_features : torch.Tensor) -> torch.Tensor:
    """Draw five embeddings per sample from a low-rank Gaussian fit to its text views.

    For each sample column ``i``:

    1. ``sample_text_features_sampling`` appends perturbed copies of every view.
    2. ``choose_k_by_reconstruction_loss`` (``loss_threshold=0.05``) picks the PCA
       rank ``k_i``.
    3. A ``k_i``-component PCA is fit on the augmented views plus
       ``targets_ori_features[i]``, which is appended as the last row. That last row is
       dropped again before the Gaussian is fit. The Gaussian's covariance is the
       sample covariance scaled by 1e-3 plus a 1e-5 diagonal jitter.
    4. Five draws are mapped back to feature space with ``pca.inverse_transform``.

    Debug output (``num_transforms``, ``pca_dim`` and each draw's cosine similarity to
    ``targets_ori_features[i]``) is printed to stdout.

    Args:
        transformed_text_features: 3-D tensor ``(num_views, num_samples, feature_dim)``.
        targets_ori_features: ``(num_samples, feature_dim)`` reference embeddings.

    Returns:
        Despite the ``torch.Tensor`` annotation, a tuple ``(V_A_list, samples)``.
        ``V_A_list`` holds one ``pca.components_.T`` array, shape
        ``(feature_dim, k_i)``, per sample. ``samples`` has shape
        ``(num_samples, 5, feature_dim)`` and the dtype of ``transformed_text_features``.
    """
    new_text_features = sample_text_features_sampling(transformed_text_features, targets_ori_features)


    best_k_list = choose_k_by_reconstruction_loss(new_text_features,  loss_threshold=0.05)

    # num_transforms counts the original views plus their perturbed copies.
    num_transforms, num_samples, feature_dim = new_text_features.shape

    sampled_features_256 = []  # one (5, feature_dim) tensor per sample
    V_A_list = []  # PCA bases, one (feature_dim, k_i) array per sample
    for i in range(num_samples):
        pca_dim = best_k_list[i]
        print(f"num_transforms:{num_transforms}")
        print(f"pca_dim: {pca_dim}")
        # Views of sample i, followed by its reference embedding as the last row:
        # (num_transforms + 1, feature_dim).
        feats = torch.cat([new_text_features[:, i, :].float(),targets_ori_features[i].unsqueeze(0).float() ], dim=0)

        pca = PCA(n_components=pca_dim)
        feats_pca = pca.fit_transform(feats.detach().cpu().numpy())
        V_A_list.append(pca.components_.T)

        feats_pca = torch.from_numpy(feats_pca)

        explained_variance_ratio = pca.explained_variance_ratio_  # unused

        # Drop the reference-embedding row: the Gaussian is fit on the views only.
        feats_pca = feats_pca[:-1,:].to(feats.device)
        mean = feats_pca.mean(dim=0)


        feats_centered = feats_pca - mean

        # Sample covariance shrunk by 1e-3, so draws stay close to the mean. Shape: (pca_dim, pca_dim).
        cov = (feats_centered.T @ feats_centered / (feats_pca.shape[0] - 1)) *0.001


        # Diagonal jitter keeps the covariance positive definite.
        cov += torch.eye(pca_dim, device=feats.device, dtype=feats.dtype) * 1e-5

        dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
        samples_list = []
        for k in range(5):
            sampled_feature_pca = dist.sample()  # shape: (pca_dim,)
            sampled_feature_1024 = pca.inverse_transform(sampled_feature_pca.detach().cpu().numpy())  # back to feature space: (feature_dim,)
            # Diagnostic only: similarity of the draw to the reference embedding.
            cosine_sim = torch.nn.functional.cosine_similarity(torch.tensor(sampled_feature_1024).to(targets_ori_features.device), targets_ori_features[i], dim=-1)
            print(f"sampling_cosine_sim:{cosine_sim}")
            samples_list.append(torch.tensor(sampled_feature_1024).to(feats.device))
        sampled_features = torch.stack(samples_list, dim=0)  # (5, feature_dim)
        sampled_features_256.append(sampled_features)
    sampled_features = torch.stack((sampled_features_256), dim=0)  # shape: (num_samples, 5, feature_dim)
    return V_A_list  , sampled_features.to(transformed_text_features.dtype)
 

 
 
# Transformations that main() can score. Any other value would leave the score lists
# empty and only fail later, inside the ROC computation.
_SUPPORTED_TRANSFORMATIONS = ('sample_batch_from_gaussian_pca_sampling_recon',)


def _load_class_names(path):
    """Return the ``"classes"`` entry of a Python-literal config file at *path*.

    NOTE(review): the file content is executed with ``eval`` (unchanged from the
    original script); only point this at trusted, repository-controlled files.
    """
    with open(path, "r") as config_file:
        return eval(config_file.read())["classes"]


def _score_loader(loader, model, classes, classes_arabic, device):
    """Return one detection score per sample yielded by *loader*.

    For each batch:

    - images get AutoAugment and the original class names get text augmentation;
    - nine text variants of each label are encoded with CLIP (two Unicode fonts, six
      descriptions, the Arabic name);
    - ``sample_batch_from_gaussian_pca_sampling_recon`` fits a PCA Gaussian over those
      variants and draws text embeddings per sample.

    A sample's score is the sum of L2 distances between its image embedding and each
    drawn text embedding.
    """
    encode_text = model.module.encode_text
    scores = []
    for inputs, targets in loader:
        inputs_ = _augment_batch_image2(inputs).to(device)
        names = [classes[targets[i]] for i in range(len(targets))]
        ori_tokens = clip.tokenize([_augment_text(name) for name in names]).to(device)
        arabic_tokens = clip.tokenize([classes_arabic[targets[i]] for i in range(len(targets))]).to(device)

        # The stacking order of these views matters: downstream sampling walks the views in order.
        view_texts = [
            [to_fraktur_unicode3(name) for name in names],
            [to_fraktur_unicode4(name) for name in names],
            get_descriptions_for_targets(targets),
            get_descriptions_for_targets3(targets),
            get_descriptions_for_targets4(targets),
            get_descriptions_for_targets5(targets),
            get_descriptions_for_targets6(targets),
            get_descriptions_for_targets7(targets),
        ]
        view_tokens = [clip.tokenize(texts).to(device) for texts in view_texts]
        view_tokens.append(arabic_tokens)

        with torch.no_grad():
            image_features = model.module.encode_image(inputs_)
            ori_features = encode_text(ori_tokens)
            view_features = [encode_text(tokens) for tokens in view_tokens]

        stacked_views = torch.stack(view_features).to(device)
        V_A_list, draws = sample_batch_from_gaussian_pca_sampling_recon(stacked_views, ori_features)

        for i in range(len(V_A_list)):
            image_row = image_features[i].unsqueeze(0).float()
            total = 0
            for j in range(draws.shape[1]):
                total += np.linalg.norm((image_row - draws[i][j].unsqueeze(0).to(device)).detach().cpu().numpy())
            scores.append(float(total))
    return scores


def main():
    """Score clean and backdoored test samples and report AUROC / best F1.

    The AUROC and best F1 are printed and appended as a row to a per-attack CSV.
    """
    parser = argparse.ArgumentParser(description="Your script description")

    # Command-line arguments
    parser.add_argument("--attack_samples_model_target", type=str, help="attack_samples_rn50_banana", default="attack_samples_rn50_banana")
    parser.add_argument("--checkpoint", type=str, help="checkpoint", default="./logs/SIG_5000samples_rn50_augment/checkpoints/epoch_8.pt")
    parser.add_argument("--attack_name", type=str, help="badnet", default='sig')
    parser.add_argument("--transformation", type=str, help="font3_edu, font4_edu, descriptions1_edu, descriptions3_edu, descriptions4_edu", default='sample_batch_from_gaussian_pca_sampling_recon')

    args = parser.parse_args()
    if args.transformation not in _SUPPORTED_TRANSFORMATIONS:
        # Fail fast instead of scoring nothing and crashing inside roc_curve later.
        parser.error(
            f"unsupported --transformation {args.transformation!r}; "
            f"supported: {', '.join(_SUPPORTED_TRANSFORMATIONS)}"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    data_prefix = f"./detection_imagenet/{args.attack_samples_model_target}/{args.attack_name}"
    bd_x = np.load(f"{data_prefix}_bd_test_inputs.npy")[:1000]
    bd_y = np.load(f"{data_prefix}_bd_test_targets.npy")[:1000]
    clean_x = np.load(f"{data_prefix}_clean_test_inputs.npy")[:1000]
    clean_y = np.load(f"{data_prefix}_clean_test_targets.npy")[:1000]

    classes = _load_class_names("./data/ImageNet1K/validation/classes.py")
    classes_arabic = _load_class_names("./data/ImageNet1K/validation/classes_arabic.py")

    model, _ = clip.load("RN50", device=device)
    model = nn.DataParallel(model)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(args.checkpoint, map_location=device)["state_dict"]
    model.load_state_dict(state_dict)
    model.eval()

    # Three stochastic passes (random image/text augmentation and sampling). Per-sample
    # scores are averaged across passes. Clean runs before backdoor in each pass.
    clean_runs, bd_runs = [], []
    for _ in range(3):
        clean_loader = DataLoader(ImageLabelDataset_clean(clean_x, clean_y), batch_size=256, shuffle=False)
        clean_runs.append(_score_loader(clean_loader, model, classes, classes_arabic, device))
        bd_loader = DataLoader(ImageLabelDataset_(bd_x, bd_y), batch_size=256, shuffle=False)
        bd_runs.append(_score_loader(bd_loader, model, classes, classes_arabic, device))

    clean_scores = np.mean(np.vstack(clean_runs), axis=0)
    bd_scores = np.mean(np.vstack(bd_runs), axis=0)

    # Clean = 0, backdoor = 1. The sizes come from the data actually scored; the
    # arrays may hold fewer than 1000 rows, and the old hard-coded 1000 then broke the metrics.
    true_labels = [0] * len(clean_scores) + [1] * len(bd_scores)
    # float32 matches the previous torch.FloatTensor round-trip. Keeping a numpy array
    # (not a list) makes the `>=` comparison below valid; list >= float raised TypeError.
    scores = np.concatenate([clean_scores, bd_scores]).astype(np.float32)

    fpr, tpr, _ = roc_curve(true_labels, scores)
    roc_auc = auc(fpr, tpr)
    print(f"AUROC: {roc_auc}")

    precision, recall, thresholds_pr = precision_recall_curve(true_labels, scores)
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    # The last (precision=1, recall=0) point has no threshold, so exclude it from argmax.
    best_threshold = thresholds_pr[np.argmax(f1_scores[:-1])]
    pred_labels = (scores >= best_threshold).astype(int)
    final_f1 = f1_score(true_labels, pred_labels)

    filename = f'./detection_imagenet/ptp/csv_result_for_attacks_rn50_banana/{args.attack_name}.csv'
    fields = ['transformation', 'checkpoint', 'AUROC', 'final_f1']
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    write_header = not os.path.exists(filename)
    with open(filename, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=fields)
        if write_header:
            writer.writeheader()
        writer.writerow({
            'transformation': args.transformation,
            'checkpoint': args.checkpoint,
            'AUROC': roc_auc,
            'final_f1': final_f1,
        })

 

 

# Script entry point: run the detection evaluation only when this file is executed directly.
if __name__ == "__main__":
    main()