import datetime
import os
import time
import logging
import random
import warnings

import torch
import torch.utils.data
import torchvision
import torchvision.transforms
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from mytorch.sampler import RASampler
from mytorch.transforms import get_mixup_cutmix
import mytorch.presets as presets
import mytorch.utils as utils
from mymodels import MODEL_DICT, WEIGHTS_DICT
from settings import selected_classes


class FcModel(nn.Module):
    """Single fully connected classification head initialised from given weights.

    The layer is created as ``nn.Linear(feature_dim, num_classes)`` and its
    parameters are then replaced by copies of ``weight_ma`` / ``bias_ma``.

    Args:
        feature_dim: Size of the last dimension of the input features.
        num_classes: Number of output logits.
        weight_ma: Tensor used as the initial weight matrix; it is multiplied
            with the features as ``x @ weight.T``.
        bias_ma: Tensor used as the initial bias, or ``None`` for a zero bias.
    """

    def __init__(self, feature_dim, num_classes, weight_ma, bias_ma):
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)
        # Clone so that optimising this head never mutates the tensor (often a
        # parameter of another model) that was handed in by the caller.
        weight = weight_ma.detach().clone()
        self.fc.weight.data = weight
        if bias_ma is None:
            # Keep the zero bias on the same dtype/device as the weight so the
            # layer is usable even before an explicit ``.to(device)``.
            self.fc.bias.data = torch.zeros(num_classes, dtype=weight.dtype, device=weight.device)
        else:
            self.fc.bias.data = bias_ma.detach().clone()

    def forward_cosine(self, x):
        """Return the cosine similarity between ``x`` and every class weight row.

        Both ``x`` and the weight rows are L2-normalised along the last
        dimension. No epsilon is added, so an all-zero feature or weight row
        yields NaN.
        """
        cls_weights = self.fc.weight
        cls_weights = cls_weights / cls_weights.norm(dim=-1, p=2, keepdim=True)
        cosine_x = x / x.norm(dim=-1, p=2, keepdim=True)
        return cosine_x @ cls_weights.permute(1, 0)

    def forward(self, x):
        """Return ``(logits, x, cosine_similarities)`` for the input features.

        ``logits`` is the plain linear output, ``x`` is passed through
        untouched, and ``cosine_similarities`` comes from ``forward_cosine``.
        """
        cosine_similarities = self.forward_cosine(x)
        output = self.fc(x)
        return output, x, cosine_similarities


class FeatureData(Dataset):
    """Dataset over pre-extracted features stored with ``torch.save``.

    Args:
        data_path: Path of a saved tensor holding one feature vector per sample.
        target_path: Path of a saved tensor holding one class id per sample.
        data_T: Temperature multiplied with the similarities before the
            per-class softmax. Required when ``clip_img_similarities_path`` is
            given.
        clip_img_similarities_path: Optional path of a saved similarity tensor.
            A 1-D tensor holds one score per sample; a 2-D tensor is indexed as
            ``[sample, class_id]``. When given, the dataset is in "train" mode
            and yields a per-sample weight as a third item.
        selected_classes: Optional sequence of original class ids to keep.
            Samples are regrouped class by class in the given order and
            relabelled ``0 .. len(selected_classes) - 1``.

    Attributes set: ``all_features``, ``all_target``, ``classes`` (sorted unique
    targets), ``is_train`` and, in train mode, ``all_img_weights``.

    Raises:
        ValueError: In train mode, if ``data_T`` is ``None`` or the similarity
            tensor is neither 1-D nor 2-D.
    """

    def __init__(self, data_path, target_path, data_T=None, clip_img_similarities_path=None, selected_classes=None):
        original_all_features = torch.load(data_path).numpy()
        original_all_target = torch.load(target_path).numpy()
        self.all_features = original_all_features
        self.all_target = original_all_target

        per_class_ids = None
        if selected_classes is not None:
            # Row indices (into the original arrays) of every selected class.
            per_class_ids = [np.nonzero(original_all_target == cls_id)[0] for cls_id in selected_classes]
            self.all_features = np.concatenate([original_all_features[ids] for ids in per_class_ids], axis=0)
            # Explicit int64 keeps the labels integral even if a class has no samples.
            self.all_target = np.concatenate(
                [np.full(ids.shape[0], n_id, dtype=np.int64) for n_id, ids in enumerate(per_class_ids)],
                axis=0,
            )
        self.classes = list(np.unique(self.all_target))

        self.is_train = clip_img_similarities_path is not None
        if self.is_train:
            self.all_img_weights = self._compute_img_weights(
                clip_img_similarities_path, data_T, selected_classes, per_class_ids
            )

    def _compute_img_weights(self, similarities_path, data_T, selected_classes, per_class_ids):
        """Return one float32 weight per sample, aligned with ``all_features``.

        Within each class the weights are a softmax over
        ``similarity * data_T``; the whole vector is then rescaled to mean 1.
        """
        if data_T is None:
            raise ValueError("data_T must be set when clip_img_similarities_path is given")

        similarities = torch.load(similarities_path).detach().cpu().to(torch.float32)
        if similarities.dim() not in (1, 2):
            raise ValueError(
                f"clip image similarities must be 1-D or 2-D, got {similarities.dim()} dimensions"
            )

        if selected_classes is not None:
            # Reduce to a 1-D tensor ordered exactly like the regrouped features.
            chunks = []
            for cls_id, ids in zip(selected_classes, per_class_ids):
                rows = torch.from_numpy(ids)
                if similarities.dim() == 2:
                    chunks.append(similarities[rows, int(cls_id)])
                else:
                    chunks.append(similarities[rows])
            similarities = torch.cat(chunks, dim=0)

        targets = torch.as_tensor(self.all_target, dtype=torch.long)
        weights = torch.zeros(targets.shape[0], dtype=torch.float32)
        for cls_id in torch.unique(targets).tolist():
            rows = torch.nonzero(targets == cls_id).squeeze(dim=1)
            if similarities.dim() == 2:
                cls_similarities = similarities[rows, cls_id]
            else:
                cls_similarities = similarities[rows]
            # Write back at the samples' own positions so that weight i always
            # belongs to sample i, whatever the ordering of the dataset.
            weights[rows] = torch.softmax(cls_similarities * data_T, dim=0)
        weights = weights * (1.0 / weights.mean())
        return weights.numpy()

    def __len__(self):
        return self.all_features.shape[0]

    def __getitem__(self, idx):
        """Return ``(feature, target, weight)`` in train mode, else ``(feature, target)``."""
        if self.is_train:
            return self.all_features[idx], self.all_target[idx], self.all_img_weights[idx]
        return self.all_features[idx], self.all_target[idx]


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every random number generator below.
    """
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # Numpy module
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # If you are using multi-GPU
    # Disable the cuDNN autotuner and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def weighted_cross_entropy(logits, target, sample_weights=None):
    """Cross-entropy loss with optional per-sample weights.

    Args:
        logits: Tensor of raw class scores; the softmax is taken over dim 1.
        target: Integer class indices, one per row of ``logits``.
        sample_weights: Optional per-sample weights. When given, the result is
            the weighted average ``sum(w_i * nll_i) / sum(w_i)``; otherwise the
            plain mean of the negative log-likelihoods.

    Returns:
        A scalar tensor holding the loss.
    """
    # log_softmax is evaluated in a numerically stable way, so very large or
    # very small logits do not overflow/underflow the way exp() followed by
    # log() would.
    log_probs = torch.log_softmax(logits, dim=1)
    negative_log_likelihood = -log_probs.gather(1, target.reshape(-1, 1))

    if sample_weights is None:
        return negative_log_likelihood.mean()

    sample_weights = sample_weights.reshape(-1, 1)
    return (negative_log_likelihood * sample_weights).sum() / sample_weights.sum()


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    """Run one optimisation pass over ``data_loader`` and return the mean top-1 accuracy.

    Each batch must unpack into ``(feature, target, img_weights)``. The model is
    expected to return a 3-tuple whose first element holds the logits, and the
    loss is computed as ``criterion(logits, target, img_weights)``.

    Args:
        model: Module being trained.
        criterion: Callable taking ``(output, target, img_weights)``.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable of training batches.
        device: Device the batch tensors are moved to.
        epoch: Current epoch index (used for logging and the EMA warm-up reset).
        args: Namespace providing ``print_freq``, ``clip_grad_norm``,
            ``model_ema_steps`` and ``lr_warmup_epochs``.
        model_ema: Optional EMA wrapper, updated every ``args.model_ema_steps`` steps.
        scaler: Optional gradient scaler; when given, autocast is enabled.

    Returns:
        ``metric_logger.acc1.global_avg`` accumulated over the epoch.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    use_amp = scaler is not None
    batches = metric_logger.log_every(data_loader, args.print_freq, header)
    for step, (feature, target, img_weights) in enumerate(batches):
        tic = time.time()
        feature = feature.to(device)
        target = target.to(device)
        img_weights = img_weights.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            # Only the logits are needed here; the other two outputs are ignored.
            output, _, _ = model(feature)
            loss = criterion(output, target, img_weights)

        optimizer.zero_grad()

        # Backward pass (scaled when mixed precision is active).
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Optional gradient clipping; scaled gradients must be unscaled first.
        if args.clip_grad_norm is not None:
            if use_amp:
                scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)

        # Parameter update.
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        if model_ema and step % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # During warm-up the averaging counter is zeroed so the EMA keeps
                # copying the current weights instead of averaging them.
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = feature.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - tic))

    train_acc1 = metric_logger.acc1.global_avg
    print(f"{header} Train Acc@1 {train_acc1:.3f} Train Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the mean top-1 accuracy.

    Batches must unpack into ``(image, target)``. If the model returns a tuple,
    its first element is used as the logits. The loss is computed as
    ``criterion(output, target)``.

    Args:
        model: Module to evaluate (switched to eval mode).
        criterion: Callable taking ``(output, target)``.
        data_loader: Iterable of evaluation batches.
        device: Device the batch tensors are moved to.
        print_freq: Accepted for interface compatibility; not used in the body.
        log_suffix: Only used to build a header string that is not printed.

    Returns:
        ``metric_logger.acc1.global_avg`` after synchronising between processes.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = model(image)
            output = output[0] if isinstance(output, tuple) else output
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            for meter_name, value in (("acc1", acc1), ("acc5", acc5)):
                metric_logger.meters[meter_name].update(value.item(), n=batch_size)
            num_processed_samples += batch_size

    # Gather the sample count from all processes.
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)

    dataset = data_loader.dataset
    if hasattr(dataset, "__len__"):
        dataset_size = len(dataset)
        # See FIXME above: warn (on rank 0 only) when the number of evaluated
        # samples differs from the dataset size.
        if dataset_size != num_processed_samples and torch.distributed.get_rank() == 0:
            warnings.warn(
                f"It looks like the dataset has {dataset_size} samples, but {num_processed_samples} "
                "samples were used for the validation, which might bias the results. "
                "Try adjusting the batch size and / or the world size. "
                "Setting the world size to 1 is always a safe bet."
            )

    metric_logger.synchronize_between_processes()

    return metric_logger.acc1.global_avg


def _build_optimizer(args, parameters):
    """Create the optimizer named by ``args.opt`` (SGD variants, RMSprop or AdamW)."""
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        return torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        return torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")


def _build_lr_scheduler(args, optimizer):
    """Create the main LR scheduler, optionally preceded by a warm-up phase.

    ``args.lr_scheduler`` is lower-cased in place before it is interpreted.
    """
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
    )


def main(args):
    """Fine-tune the linear classification head of a pretrained model on saved features.

    Steps: seed the RNGs, load pre-extracted train/test features (optionally
    restricted to ``selected_classes``), copy the classifier weights of the
    pretrained model into an ``FcModel``, evaluate it, train it with the
    sample-weighted cross-entropy, and print the accuracy before training and
    the best accuracy seen after any epoch.

    Args:
        args: Namespace produced by ``get_args_parser()``.
    """
    set_seed(args.seed)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Note: this overrides the cudnn.benchmark value chosen by set_seed().
    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    log_dir = os.path.join(args.output_dir, args.model)
    utils.mkdir(log_dir)

    # Locations of the pre-extracted features / targets of the target model.
    activations_dir = os.path.join('saved_activations', args.model)
    train_data_path = os.path.join(activations_dir, 'target_model_all_train_features.pth')
    train_target_path = os.path.join(activations_dir, 'target_model_all_train_target.pth')
    test_data_path = os.path.join(activations_dir, 'target_model_all_test_features.pth')
    test_target_path = os.path.join(activations_dir, 'target_model_all_test_target.pth')
    if args.use_target_cav:
        clip_img_similarities_path = os.path.join(
            'saved_activations', args.model, 'target_model_all_train_img_similarites(cls-relative-text).pth'
        )
    elif args.use_selected_classes:
        clip_img_similarities_path = os.path.join(
            'saved_contents', args.clip_model, 'clip_model_all_train_img_similarites(cls-selected-text).pth'
        )
    else:
        clip_img_similarities_path = os.path.join(
            'saved_contents', args.clip_model, 'clip_model_all_train_img_similarites(cls-relative-text).pth'
        )

    # Load data; the train set also carries per-sample weights.
    chosen_classes = selected_classes if args.use_selected_classes else None
    dataset = FeatureData(
        train_data_path,
        train_target_path,
        args.data_T,
        clip_img_similarities_path=clip_img_similarities_path,
        selected_classes=chosen_classes,
    )
    dataset_test = FeatureData(test_data_path, test_target_path, selected_classes=chosen_classes)
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    num_classes = len(dataset.classes)
    mixup_cutmix = get_mixup_cutmix(
        mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
    )
    if mixup_cutmix is not None:
        # NOTE(review): training batches are (feature, target, weight) triples, so
        # mixup_cutmix receives three positional arguments here — confirm that the
        # project's transform supports this before enabling mixup/cutmix.
        def collate_fn(batch):
            return mixup_cutmix(*default_collate(batch))
    else:
        collate_fn = default_collate

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    # Only the final linear layer of the pretrained model is needed.
    ori_model = MODEL_DICT[args.model](weights=WEIGHTS_DICT[args.model])
    if 'resnet' in args.model:
        ori_model = ori_model.fc
    elif 'densenet' in args.model:
        ori_model = ori_model.classifier
    weight_ma = ori_model.weight
    bias_ma = ori_model.bias
    if args.use_selected_classes:
        ma_selected_ids = torch.LongTensor(selected_classes)
        weight_ma = weight_ma[ma_selected_ids, :]
        # A bias-free head has bias None; FcModel then falls back to zeros.
        if bias_ma is not None:
            bias_ma = bias_ma[ma_selected_ids]

    # FC model
    model = FcModel(feature_dim=weight_ma.shape[1], num_classes=num_classes, weight_ma=weight_ma, bias_ma=bias_ma)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = weighted_cross_entropy

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    optimizer = _build_optimizer(args, parameters)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    lr_scheduler = _build_lr_scheduler(args, optimizer)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        # args.start_epoch is deliberately left untouched: it is not restored
        # from checkpoint["epoch"].
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    ori_acc = evaluate(model, criterion, data_loader_test, device=device)
    print('Original Acc@1 %.3f' % (ori_acc))
    if args.test_only:
        # --test-only: report the accuracy of the loaded head and stop.
        return

    print("Start training")
    start_time = time.time()
    final_acc = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # RandomSampler has no set_epoch(); only samplers that define it
        # (e.g. DistributedSampler) need the per-epoch reseed.
        if args.distributed and hasattr(train_sampler, "set_epoch"):
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        cur_acc = evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            # NOTE(review): the checkpoint dict is assembled but never written
            # to disk inside this function.
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
        # Track the best test accuracy observed after any epoch.
        if cur_acc >= final_acc:
            final_acc = cur_acc
    print('Before Model Debug, Acc@1: %.3f' % (ori_acc))
    print('After Model Debug, Acc@1: %.3f' % (final_acc))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
    
    parser.add_argument("--data-set", default="ImageNet", type=str, help="dataset name")
    parser.add_argument("--data-path", default="datasets/ILSVRC2012", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=1e-3, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print-freq", default=800, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
    # My added
    parser.add_argument("--clip-model", default="ViT-L-14", type=str)
    parser.add_argument("--seed", default=1028, type=int)
    parser.add_argument("--data-T", type=float, default=6)
    parser.add_argument("--use-target-cav", action="store_true")
    parser.add_argument("--use-selected-classes", action="store_true")

    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    main(get_args_parser().parse_args())