from __future__ import print_function

import argparse
import os
import time
import random
import logging
from tqdm import tqdm
from copy import deepcopy as dcopy

import torch
from torch.cuda.amp import GradScaler, autocast

from replace import clip
from models.prompters import TokenPrompter, NullPrompter
from method_utils import *
from attacks import *
from func import clip_img_preprocessing, multiGPU_CLIP, multiGPU_CLIP_v1, multiGPU_CLIP_v2, multiGPU_CLIP_v2_orig

# torch.cuda.set_device(4)
# device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
# gpu_id = os.getenv("GPU_ID", "0")  # falls back to device 0 when GPU_ID is not set
# device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_options():
    """Build the command-line parser for TTC evaluation and parse ``sys.argv``.

    Attack budgets and step sizes (``--test_eps``, ``--test_stepsize``,
    ``--ttc_eps``, ``--ttc_stepsize``) are given in 0-255 pixel units;
    ``main`` divides them by 255.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--evaluate', type=_str2bool, default=True) # eval mode
    parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=0, help='num of workers to use')
    parser.add_argument('--cache', type=str, default='./cache')

    # test setting
    parser.add_argument('--test_set', default=[], type=str, nargs='+') # falls back to DATASETS if not specified
    parser.add_argument('--test_attack_type', type=str, default="pgd", choices=['pgd', 'CW', 'autoattack',])
    parser.add_argument('--test_eps', type=float, default=1, help='test attack budget')
    parser.add_argument('--test_numsteps', type=int, default=10)
    # float (not int) so fractional step sizes are accepted, like --ttc_stepsize
    parser.add_argument('--test_stepsize', type=float, default=1)

    # model
    parser.add_argument('--model', type=str, default='clip')
    parser.add_argument('--backbone', type=str, default='ViT-B/32')
    parser.add_argument('--method', type=str, default='null_patch',
                        choices=['null_patch'], help='choose visual prompting method')
    parser.add_argument('--name', type=str, default='')
    parser.add_argument('--prompt_size', type=int, default=30, help='size for visual prompts')
    parser.add_argument('--add_prompt_size', type=int, default=0, help='size for additional visual prompts')

    # data
    parser.add_argument('--root', type=str, default='/NAS/zhuxyu/alldatasets', help='dataset path')

    parser.add_argument('--dataset', type=str, default='tinyImageNet', help='dataset used for AFT methods')
    parser.add_argument('--image_size', type=int, default=224, help='image size')

    # TTC config
    parser.add_argument('--seed', type=int, default=0, help='seed for initializing training')
    parser.add_argument('--victim_resume', type=str, default=None, help='model weights of victim to attack.')
    parser.add_argument('--outdir', type=str, default=None, help='output directory for results')

    parser.add_argument('--tau_thres', type=float, default=0.2)
    parser.add_argument('--beta', type=float, default=2.,)
    parser.add_argument('--ttc_eps', type=float, default=4.)
    parser.add_argument('--ttc_numsteps', type=int, default=2)
    parser.add_argument('--ttc_stepsize', type=float, default=1.)

    args = parser.parse_args()
    return args

def calculate_batch_entropy(logits):
    """Shannon entropy (in nats) of ``softmax(logits)`` over the last dimension.

    Any leading batch shape is allowed; the last dimension is reduced away.
    """
    log_probs = logits.log_softmax(dim=-1)
    probs = logits.softmax(dim=-1)
    return torch.sum(-(probs * log_probs), dim=-1)

@torch.no_grad()
def get_entropy_weight(output, img_t=0.5, text_t=0.5):
    """Entropy-based weights over images and over text descriptors.

    ``output`` is a 3-D logits tensor indexed ``[image, descriptor, class]``
    (the code unpacks ``_, n_des, n_cls = output.shape``). Lower entropy means
    larger weight, because weights are ``softmax(-entropy / temperature)``.

    * Image weights are computed from the entropy of each image's
      descriptor-averaged logits.
    * Descriptor weights use only the first image (``output[0]``). For
      descriptor ``d`` and class ``j``, the row is the descriptor-mean logits
      with entry ``j`` replaced by descriptor ``d``'s logit for class ``j``.
      The entropy of that row is scored.

    Args:
        output: logits of shape ``[n_images, n_des, n_cls]``.
        img_t: softmax temperature for the image weights.
        text_t: softmax temperature for the descriptor weights.

    Returns:
        ``(image_weights, text_weights)``. ``image_weights`` has shape
        ``[n_images]`` and sums to 1. ``text_weights`` has shape
        ``[n_cls, n_des]``, and each row sums to 1.
    """
    # Mixed precision only applies to CUDA tensors. On CPU, autocast is disabled
    # instead of warning that CUDA is unavailable.
    with torch.cuda.amp.autocast(enabled=output.is_cuda):
        # get weights for images
        image_entropy = calculate_batch_entropy(output.mean(1))
        image_weights = torch.softmax(-image_entropy / img_t, dim=-1)

        # get weights for descriptors
        _, n_des, n_cls = output.shape
        anchor = output[0].mean(0)[None, None, :].repeat(n_des, n_cls, 1)
        output_des = output[0].unsqueeze(-1)
        # Index j at [d, j, 0]: scatter writes output[0, d, j] into anchor[d, j, j].
        # Built on the input's device rather than the hard-coded current CUDA device.
        scatter_indices = torch.arange(n_cls, device=output.device)[None, :, None].repeat(n_des, 1, 1)

        anchor.scatter_(dim=2, index=scatter_indices, src=output_des)  # n_des, n_cls, n_cls
        text_entropy = calculate_batch_entropy(anchor)
        text_weights = torch.softmax(-text_entropy.t() / text_t, dim=-1)  # n_cls, n_des

    return image_weights, text_weights

def Sinkhorn(K, u, v, max_iter=100, thresh=1e-2):
    """Sinkhorn-Knopp scaling of a batch of positive kernels.

    Finds scaling vectors ``r`` and ``c`` so that
    ``T = r[..., :, None] * K * c[..., None, :]`` approximately has row sums
    ``u`` and column sums ``v``. The scaling is done independently for every
    leading (batch) index of ``K``.

    Args:
        K: positive kernel of shape ``[B, M, N]``.
        u: target row marginals, broadcastable to ``[B, M]``.
        v: target column marginals, broadcastable to ``[B, N]``.
        max_iter: maximum number of scaling iterations (default 100).
        thresh: stop early once the mean absolute change of ``r`` between
            consecutive iterations drops below this value (default 1e-2).

    Returns:
        Transport plan ``T`` with the same shape as ``K``.
    """
    r = torch.ones_like(u)
    c = torch.ones_like(v)
    # Loop-invariant: transpose once instead of copying K on every iteration.
    K_t = K.transpose(-2, -1).contiguous()
    for _ in range(max_iter):
        r_prev = r
        r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1)
        c = v / torch.matmul(K_t, r.unsqueeze(-1)).squeeze(-1)
        if (r - r_prev).abs().mean().item() < thresh:
            break
    return torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K

def optimal_transport(logits, logit_scale, image_weights, text_weights, eps=0.1):
    """Aggregate per-(image, descriptor) logits with an entropic OT plan.

    The similarity is ``logits / logit_scale.exp()`` and the cost is
    ``1 - similarity``. The transport plan comes from Sinkhorn on the Gibbs
    kernel ``exp(-cost / eps)``, computed without gradients, and is used to
    weight the logits.

    Args:
        logits: tensor of shape ``[M, N, n_cls]`` (M images, N descriptors).
        logit_scale: the *log* of the logit scale. It is exponentiated here,
            so pass the raw parameter (e.g. ``model.logit_scale``), not its exp.
        image_weights: row marginal of shape ``[M]``.
        text_weights: per-class column marginals of shape ``[n_cls, N]``.
        eps: entropic regularization strength (default 0.1).

    Returns:
        Aggregated logits of shape ``[1, n_cls]``.

    Raises:
        FloatingPointError: if the transport plan contains NaN values.
    """
    sim = logits / logit_scale.exp()
    sim = sim.permute(2, 0, 1)  # n_cls x M x N

    wdist = 1.0 - sim
    with torch.no_grad():
        KK = torch.exp(-wdist / eps)
        T = Sinkhorn(KK, image_weights, text_weights)
        T = T.permute(1, 2, 0)  # back to M x N x n_cls
    # Explicit check (an ``assert`` would be stripped under ``python -O``).
    if torch.isnan(T).any():
        raise FloatingPointError("optimal_transport: transport plan contains NaN")

    return torch.sum(T * logits, dim=(0, 1)).unsqueeze(0)

def compute_tau(clip_visual, images, n):
    """Relative change of the visual embedding caused by adding ``n`` to ``images``.

    Both inputs go through ``clip_img_preprocessing`` and then
    ``clip_visual(x, None)``. The result is
    ``||f(images + n) - f(images)|| / ||f(images)||``, with norms over the last
    dimension, giving one value per sample.
    """
    def _embed(x):
        return clip_visual(clip_img_preprocessing(x), None)

    clean_emb = _embed(images)  # [bs, d]; the original note says [bs, 512]
    shifted_emb = _embed(images + n)
    return torch.norm(shifted_emb - clean_emb, dim=-1) / torch.norm(clean_emb, dim=-1)

def tau_thres_weighted_counterattacks(model, X, prompter, add_prompter, alpha, attack_iters,
                           norm="l_inf", epsilon=0, visual_model_orig=None,
                           tau_thres:float=None, beta:float=None, clip_visual=None):
    """Compute a tau-thresholded, step-weighted test-time counterattack perturbation.

    Starts from a random perturbation inside the ``epsilon`` ball. It then runs
    ``attack_iters`` ascent steps that *maximize* the squared L2 distance
    between the image embeddings of ``X + delta`` and ``X``. These are sign
    steps for ``"l_inf"`` and normalized-gradient steps for ``"l_2"``. The
    perturbation after every step, including the random start, is recorded.
    The records are combined per sample as follows:

    * if ``tau < tau_thres``, the steps are averaged with weights
      ``softmax(beta * step_index)``, so later steps dominate;
    * otherwise only the random start (step 0) is kept.

    ``tau`` is the relative embedding change caused by the random start. It is
    measured with ``compute_tau(clip_visual, ...)`` when ``clip_visual`` is
    given, and with ``model`` on the first step otherwise.

    Args:
        model: provides ``encode_image(images, prompt_token)`` and ``parameters()``.
        X: input images of shape ``[bs, C, H, W]``, presumably in pixel space
            bounded by ``lower_limit``/``upper_limit`` — confirm with callers.
        prompter: visual prompter applied after ``clip_img_preprocessing``.
        add_prompter: callable returning the prompt token for ``encode_image``.
        alpha: step size per iteration.
        attack_iters: number of ascent steps; 0 returns the random start.
        norm: ``"l_inf"`` or ``"l_2"``.
        epsilon: perturbation budget; ``<= 0`` returns all zeros.
        visual_model_orig: unused; kept for interface compatibility.
        tau_thres: threshold on ``tau``.
        beta: sharpness of the soft step weighting.
        clip_visual: optional encoder used to measure ``tau``.

    Returns:
        Perturbation with the same shape as ``X``.

    Raises:
        ValueError: if ``norm`` is neither ``"l_inf"`` nor ``"l_2"``.
    """
    delta = torch.zeros_like(X)
    if epsilon <= 0.:
        return delta

    # Random start inside the epsilon ball.
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    elif norm == "l_2":
        delta.normal_()
        d_flat = delta.view(delta.size(0), -1)
        n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon
    else:
        raise ValueError(f"unsupported norm: {norm!r}")

    delta = clamp(delta, lower_limit - X, upper_limit - X)
    delta.requires_grad = True

    if attack_iters == 0:  # apply random noise (RN)
        return delta.data

    # tau is a pure measurement. Detach it so it never drags an autograd graph
    # into the step weights or the returned tensor.
    diff_ratio = compute_tau(clip_visual, X, delta.data).detach() if clip_visual is not None else None

    # Temporarily freeze the model so the gradient below flows only to ``delta``.
    # The ``finally`` restores requires_grad even if encoding or autograd raises.
    frozen_params = [p for p in model.parameters() if p.requires_grad]
    for p in frozen_params:
        p.requires_grad = False

    try:
        prompt_token = add_prompter()
        with torch.no_grad():
            X_ori_reps = model.encode_image(
                prompter(clip_img_preprocessing(X)), prompt_token
            )
            X_ori_norm = torch.norm(X_ori_reps, dim=-1)  # [bs]

        deltas_per_step = [delta.data.clone()]

        for step_id in range(attack_iters):
            prompted_images = prompter(clip_img_preprocessing(X + delta))
            X_att_reps = model.encode_image(prompted_images, prompt_token)

            if step_id == 0 and diff_ratio is None:  # compute tau at the zero-th step
                feature_diff = X_att_reps.detach() - X_ori_reps  # [bs, d]
                diff_ratio = torch.norm(feature_diff, dim=-1) / X_ori_norm  # [bs]

            # Ascend the squared L2 distance to the clean embedding.
            l2_loss = ((X_att_reps - X_ori_reps) ** 2).sum(1).sum()
            grad = torch.autograd.grad(l2_loss, delta)[0]

            if norm == "l_inf":
                d = torch.clamp(delta.data + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
            elif norm == "l_2":
                g_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                scaled_g = grad / (g_norm + 1e-10)
                d = (delta.data + scaled_g * alpha).view(delta.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(delta)
            d = clamp(d, lower_limit - X, upper_limit - X)
            delta.data.copy_(d)
            deltas_per_step.append(delta.data.clone())

        Delta = torch.stack(deltas_per_step, dim=1)  # [bs, numsteps+1, C, H, W]

        # +1 where tau < tau_thres, -1 where tau > tau_thres, 0 when equal.
        scheme_sign = (tau_thres - diff_ratio).sign()

        # Soft weights over steps: softmax(sign * beta * step_index).
        steps = torch.arange(attack_iters + 1, device=X.device).unsqueeze(0).expand(X.size(0), -1)  # [bs, numsteps+1]
        weights = torch.exp(scheme_sign.view(-1, 1) * steps * beta)
        weights /= weights.sum(dim=1, keepdim=True)

        # Hard weights: all mass on the random start (step 0).
        weights_hard = torch.zeros_like(weights)
        weights_hard[:, 0] = 1.

        # Use the soft weights only where tau < tau_thres.
        weights = torch.where(scheme_sign.unsqueeze(1) > 0, weights, weights_hard)
        weights = weights.view(X.size(0), attack_iters + 1, 1, 1, 1)

        Delta = (weights * Delta).sum(dim=1)
    finally:
        for p in frozen_params:
            p.requires_grad = True

    return Delta


def validate(args, val_dataset_name, model, model_text, model_image,
             prompter, add_prompter, criterion, visual_model_orig=None,
             clip_visual=None
    ):
    """Log clean and robust top-1 accuracy, with and without OT aggregation, per dataset.

    For every dataset name in ``val_dataset_name``:

    1. Encode the class text prompts in chunks of 2048 and L2-normalize them.
    2. For each batch, build ``aug_num`` views per image with
       ``clip_img_aug_preprocessing`` and measure:

       * clean accuracy, using the first view and the first prompt of each class;
       * "ttc" clean accuracy, using entropy weighting plus optimal transport
         over the ``multiGPU_CLIP_v2`` outputs for every sample;
       * robust accuracy on images attacked with ``args.test_attack_type``,
         using ``adv_output[:, 0, 0, :]``;
       * "ttc" robust accuracy, using the same weighting plus OT.

    3. Log running averages every 50 batches and a summary per dataset.

    Args:
        args: parsed options; budgets are expected already divided by 255.
        val_dataset_name: sequence of dataset names to evaluate.
        model: CLIP model exposing ``encode_text``/``encode_image``/``logit_scale``.
        model_text, model_image: forwarded to the PGD/CW attack helpers.
        prompter, add_prompter: visual prompter and prompt-token provider.
        criterion: attack loss forwarded to the PGD/CW attack helpers.
        visual_model_orig: unused.
        clip_visual: optional encoder forwarded to the counterattack for tau.

    Returns:
        ``(None, None, None, None)``; results are only logged.
    """
    logging.info(f"Evaluate with Attack method: {args.test_attack_type}")

    dataset_num = len(val_dataset_name)
    # Per-dataset final accuracies (collected but not yet aggregated or returned).
    all_clean_org, all_clean_ttc, all_adv_org, all_adv_ttc = {}, {}, {}, {}

    test_stepsize = args.test_stepsize

    ttc_eps = args.ttc_eps
    ttc_numsteps = args.ttc_numsteps
    ttc_stepsize = args.ttc_stepsize
    beta = args.beta
    tau_thres = args.tau_thres

    for dataset_name in val_dataset_name:

        val_dataset, val_loader = load_val_dataset(args, dataset_name)

        texts = get_text_prompts_val_v1([val_dataset], [dataset_name])[0]
        class_nums = len(val_dataset.classes)

        binary = ['PCAM', 'hateful_memes']
        attacks_to_run = ['apgd-ce', 'apgd-dlr']
        if dataset_name in binary:
            attacks_to_run = ['apgd-ce']

        batch_time = AverageMeter('Time', ':6.3f')
        top1_org = AverageMeter('Original Acc@1', ':6.2f')
        top1_org_ttc = AverageMeter('Prompt Acc@1', ':6.2f')
        top1_adv = AverageMeter('Adv Original Acc@1', ':6.2f')
        top1_adv_ttc = AverageMeter('Adv Prompt Acc@1', ':6.2f')

        # switch to evaluation mode
        prompter.eval()
        add_prompter.eval()
        model.eval()

        # Encode all prompts in chunks. Ceiling division avoids tokenizing an
        # empty trailing chunk when len(texts) is an exact multiple of batch_size.
        batch_size = 2048
        text_features = list()
        batch_num = (len(texts) + batch_size - 1) // batch_size
        with torch.no_grad():
            for i in tqdm(range(batch_num)):
                start = i * batch_size
                end = start + batch_size
                text_tokens = clip.tokenize(texts[start:end]).cuda()

                features = model.encode_text(text_tokens)
                text_features.append(features.cpu())
            text_features = torch.cat(text_features)
        end = time.time()

        text_features = (text_features / text_features.norm(dim=-1, keepdim=True)).cuda()  # [len(texts), d_emb]

        # Prompts grouped per class: [class_nums, n_des, d_emb]
        text_features_reshape = text_features.clone().view(class_nums, -1, text_features.shape[-1]).contiguous()
        # Per-class [d_emb, d_emb] Gram matrix of the prompt embeddings.
        text_projector = torch.einsum('cdn,cnb->cdb', text_features_reshape.permute(0, 2, 1), text_features_reshape).cuda()

        for i, (images, target) in enumerate(tqdm(val_loader)):

            images = images.cuda()
            target = target.cuda()

            # aug_images: [batch, aug_num, C, H, W]
            aug_images = clip_img_aug_preprocessing()(images)
            batch, aug_num, c, h, w = aug_images.shape[:5]
            reshape_aug_images = aug_images.reshape(-1, c, h, w).contiguous()

            with autocast():

                # original acc of clean images (first view, first prompt per class)
                with torch.no_grad():
                    image_features = model.encode_image(prompter(clip_img_preprocessing(reshape_aug_images)), None)
                    image_features = F.normalize(image_features, dim=-1)

                    clean_output_aug = image_features @ text_features.t()
                    clean_output = clean_output_aug.view(image_features.shape[0], class_nums, -1)[:, :, 0]
                    clean_output = clean_output.view(batch, aug_num, class_nums)[:, 0, :]

                    clean_acc = accuracy(clean_output, target, topk=(1,))

                    top1_org.update(clean_acc[0].item(), batch)

                # TTC on clean images.
                # NOTE(review): ttc_delta_clean is computed but not added to the
                # images below (the "+ ttc_delta" variant is disabled). The reported
                # "ttc" accuracy therefore reflects augmentation + weighting + OT
                # only. The call still draws from the torch RNG (random init), so
                # removing it may change later random augmentations/attacks.
                ttc_delta_clean = tau_thres_weighted_counterattacks(
                    model, aug_images[:, 0, :, :, :].reshape(-1, c, h, w), prompter, add_prompter,
                    alpha=ttc_stepsize, attack_iters=ttc_numsteps,
                    norm='l_inf', epsilon=ttc_eps, visual_model_orig=None,
                    tau_thres=tau_thres, beta=beta,
                    clip_visual=clip_visual
                )

                with torch.no_grad():
                    clean_output_ttc, _, _, _ = multiGPU_CLIP_v2(
                        None, None, None, model, prompter(clip_img_preprocessing(reshape_aug_images)),
                        text_features=text_features,
                        text_projector=text_projector,
                        prompt_token=None,
                        dataset_name=dataset_name,
                        dataset_class_num=class_nums,
                        batch=batch,
                        aug_num=aug_num,
                    )

                    image_temperature = 1
                    text_temperature = 1

                    for idx in range(batch):
                        image_weights, text_weights = get_entropy_weight(clean_output_ttc[idx], img_t=image_temperature, text_t=text_temperature)
                        # optimal_transport exponentiates logit_scale itself, so pass the
                        # raw (log-space) parameter. Passing model.logit_scale.exp()
                        # exponentiated twice, which overflows for a typical CLIP scale
                        # (~4.6), zeroing the similarity and the transport cost.
                        # NOTE(review): assumes multiGPU_CLIP_v2 returns logits scaled by
                        # logit_scale.exp() — confirm.
                        output_ot = optimal_transport(clean_output_ttc[idx], model.logit_scale, image_weights, text_weights)

                        clean_acc_ttc = accuracy(output_ot, target[idx][None], topk=(1,))

                        top1_org_ttc.update(clean_acc_ttc[0].item(), 1)

                # generate adv samples for this batch
                torch.cuda.empty_cache()
                if args.test_attack_type in ("pgd", "CW"):
                    attack_fn = attack_pgd_v1 if args.test_attack_type == "pgd" else attack_CW_v1
                    delta_prompt = attack_fn(
                        args, prompter, model, model_text, model_image, add_prompter, criterion,
                        images, target, test_stepsize, args.test_numsteps, 'l_inf',
                        text_features=text_features,
                        text_projector=text_projector,
                        epsilon=args.test_eps,
                        dataset_name=dataset_name,
                        dataset_class_num=class_nums,
                        batch=batch,
                        aug_num=aug_num,
                    )
                    # Apply the augmentation pipeline to the adversarial images.
                    attacked_images = clip_img_aug_preprocessing()(images + delta_prompt)
                    attacked_images = attacked_images.view(batch * aug_num, c, h, w).contiguous()
                elif args.test_attack_type == "autoattack":
                    # NOTE(review): text_tokens is only the last chunk from the text
                    # encoding loop above. The result is also not augmented/reshaped to
                    # batch * aug_num like the pgd/CW branch — confirm against
                    # attack_auto and multiGPU_CLIP_v2 before relying on this path.
                    attacked_images = attack_auto(model, images, target, text_tokens,
                        None, None, epsilon=args.test_eps, attacks_to_run=attacks_to_run)

                # acc of adv images without ttc
                with torch.no_grad():
                    adv_output, _, _, _ = multiGPU_CLIP_v2_orig(
                        None, None, None,
                        model,
                        prompter(clip_img_preprocessing(attacked_images)),
                        text_features=text_features,
                        text_projector=text_projector,
                        prompt_token=None,
                        dataset_name=dataset_name,
                        dataset_class_num=class_nums,
                        batch=batch,
                        aug_num=aug_num,
                    )

                    adv_output_clip = adv_output[:, 0, 0, :]
                    adv_acc = accuracy(adv_output_clip, target, topk=(1,))
                    # Weight by batch size like top1_org. A weight of 1 made this an
                    # unweighted mean over batches, skewed by a smaller last batch.
                    top1_adv.update(adv_acc[0].item(), batch)

                # acc of adv images with ttc
                # NOTE(review): as above, ttc_delta_adv is computed but not applied.
                ttc_delta_adv = tau_thres_weighted_counterattacks(
                    model, attacked_images.data, prompter, add_prompter,
                    alpha=ttc_stepsize, attack_iters=ttc_numsteps,
                    norm='l_inf', epsilon=ttc_eps, visual_model_orig=None,
                    tau_thres=tau_thres, beta=beta,
                    clip_visual=clip_visual
                )
                with torch.no_grad():
                    adv_output_ttc, _, _, _ = multiGPU_CLIP_v2(
                        None, None, None,
                        model,
                        prompter(clip_img_preprocessing(attacked_images)),
                        text_features=text_features,
                        text_projector=text_projector,
                        prompt_token=None,
                        dataset_name=dataset_name,
                        dataset_class_num=class_nums,
                        batch=batch,
                        aug_num=aug_num,
                    )

                    image_temperature = 0.5
                    text_temperature = 0.5

                    for idx in range(batch):
                        image_weights, text_weights = get_entropy_weight(adv_output_ttc[idx], img_t=image_temperature, text_t=text_temperature)

                        # Raw log-space parameter; see the clean-image loop above.
                        output_ot_adv = optimal_transport(adv_output_ttc[idx], model.logit_scale, image_weights, text_weights)

                        adv_output_acc = accuracy(output_ot_adv, target[idx][None], topk=(1,))

                        top1_adv_ttc.update(adv_output_acc[0].item(), 1)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % 50 == 0:
                show_results = f"{dataset_name}:\n\t"
                show_results += f"- clean acc.  {top1_org.avg:.2f} (ttc: {top1_org_ttc.avg:.2f})\n\t"
                show_results += f"- robust acc. {top1_adv.avg:.2f} (ttc: {top1_adv_ttc.avg:.2f})"

                logging.info(show_results)

        torch.cuda.empty_cache()
        clean_acc = top1_org.avg
        clean_ttc_acc = top1_org_ttc.avg
        adv_acc = top1_adv.avg
        adv_ttc_acc = top1_adv_ttc.avg

        all_clean_org[dataset_name] = clean_acc
        all_clean_ttc[dataset_name] = clean_ttc_acc
        all_adv_org[dataset_name] = adv_acc
        all_adv_ttc[dataset_name] = adv_ttc_acc

        show_text = f"===== SUMMARY ACROSS {dataset_num} DATASETS =====\n\t"
        show_text += f"{dataset_name}:\n\t"
        show_text += f"- clean acc.  {clean_acc:.2f} (ttc: {clean_ttc_acc:.2f})\n\t"
        show_text += f"- robust acc. {adv_acc:.2f} (ttc: {adv_ttc_acc:.2f})"

        logging.info(show_text)

    # NOTE(review): cross-dataset averages are not computed; the all_* dicts
    # hold the per-dataset results if that summary is re-enabled.
    return None, None, None, None

# device = "cuda" if torch.cuda.is_available() else "cpu"

def main():
    """Entry point: parse options, load CLIP, configure logging and run ``validate``.

    Side effects:
    * creates the results directory;
    * seeds ``random``, NumPy and torch (CPU and CUDA) and makes cuDNN deterministic;
    * configures the root logger to write into the results directory;
    * prints the datasets to be evaluated.
    """
    args = parse_options()

    # The results directory is keyed on the attack settings as given on the
    # command line, i.e. before the budgets are rescaled below.
    base_dir = "TTC_results" if args.outdir is None else args.outdir
    outdir = os.path.join(
        base_dir, f"{args.test_attack_type}_eps_{args.test_eps}_numsteps_{args.test_numsteps}"
    )
    os.makedirs(outdir, exist_ok=True)

    # Budgets and step sizes arrive in 0-255 pixel units; convert to the [0, 1] scale.
    for attr in ("test_eps", "test_stepsize", "ttc_stepsize", "ttc_eps"):
        setattr(args, attr, getattr(args, attr) / 255.)

    # Seed every RNG in play and make cuDNN deterministic for reproducibility.
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.imagenet_root = '/NAS/zhuxyu/alldatasets/imagenet'
    args.tinyimagenet_root = "/NAS/zhuxyu/alldatasets/tiny-imagenet-200"

    # Load CLIP on the GPU and freeze all of its parameters.
    model, _ = clip.load(args.backbone, jit=False, prompt_len=0)
    model = model.cuda()
    model.requires_grad_(False)

    if args.victim_resume:
        # TTC on an AFT checkpoint: keep a copy of the visual encoder from
        # before the checkpoint is loaded; it is passed to validate() as clip_visual.
        clip_visual = dcopy(model.visual)
        model = load_checkpoints2(args, args.victim_resume, model, None)
        model_name = args.victim_resume.split('/')[-1].split('_')[0]
    else:
        # TTC on the original CLIP.
        clip_visual = None
        model_name = 'CLIP'

    log_name = f"test_eps_{args.test_eps}_numsteps_{args.test_numsteps}_stepsize_{args.test_stepsize}_model_{model_name}_seed_{seed}.log".replace(" ", "")
    logging.basicConfig(
        filename=os.path.join(outdir, log_name),
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    logging.info("Arguments:")
    for key, value in vars(args).items():
        logging.info(f"  {key}: {value}")

    model.eval()
    prompter = NullPrompter().cuda()
    add_prompter = TokenPrompter(0).cuda()
    logging.info("done loading model.")

    test_set = DATASETS if len(args.test_set) == 0 else args.test_set
    print(test_set)
    print(len(test_set))

    # Attack loss; the original author notes that 'sum' reduction matters for
    # effective attacks.
    criterion_attack = torch.nn.CrossEntropyLoss(reduction='sum').cuda()

    validate(
        args, test_set, model,
        model_text=None, model_image=None,
        prompter=prompter, add_prompter=add_prompter,
        criterion=criterion_attack, visual_model_orig=None,
        clip_visual=clip_visual,
    )

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
