import warnings
warnings.filterwarnings("ignore")
import argparse
import time
import yaml
import os
import logging
from collections import OrderedDict
import torch
from torch.nn.parallel import DistributedDataParallel as NativeDDP
import torch.nn.functional as F

# timm functions
from timm.models import resume_checkpoint, load_checkpoint, model_parameters
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.utils import ModelEmaV2, distribute_bn, AverageMeter, reduce_tensor, dispatch_clip_grad, accuracy, get_outdir, CheckpointSaver, update_summary

# project-local helpers
from utils import distributed_init, random_seed, create_logger, NormalizeByChannelMeanStd
from model.model import build_model
from model.loss import build_loss, resolve_amp, build_loss_scaler
from data.dataset import build_dataset
from adv.adv_utils import adv_generator


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``off``.

    argparse has no boolean ``type``; without a converter any explicit value,
    including the string ``"False"``, is stored as a (truthy) string.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args_parser():
    """Build the (help-less, parent-style) argument parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` created with ``add_help=False`` so it can
        be used as a ``parents=[...]`` entry by the entry-point parser.
    """
    parser = argparse.ArgumentParser('Robust training script', add_help=False)
    parser.add_argument('--configs', default='', type=str)

    #* distributed setting
    parser.add_argument('--distributed', default=True, type=_str2bool)
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--device-id', type=int, default=0)
    parser.add_argument('--rank', default=-1, type=int, help='rank')
    parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist-backend', default='nccl', help='backend used to set up distributed training')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

    #* amp parameters
    parser.add_argument('--apex-amp', action='store_true', default=False,
                        help='Use NVIDIA Apex AMP mixed precision')
    parser.add_argument('--native-amp', action='store_true', default=False,
                        help='Use Native Torch AMP mixed precision')
    parser.add_argument('--amp_version', default='', help='amp version')

    #* model parameters
    parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', help='Name of model to train')
    parser.add_argument('--num-classes', default=1000, type=int, help='number of classes')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--pretrain', default='', help='pretrain from checkpoint')
    parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                        help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None. (opt)')
    parser.add_argument('--channels-last', action='store_true', default=False,
                        help='Use channels_last memory layout (opt)')

    #* Batch norm parameters
    parser.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)')
    parser.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)')
    parser.add_argument('--sync-bn', action='store_true', default=False, help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
    parser.add_argument('--dist-bn', type=str, default='reduce', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
    parser.add_argument('--split-bn', action='store_true', default=False,
                        help='Enable separate BN layers per augmentation split.')

    #* Optimizer parameters
    parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "sgd")')
    parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: None, use opt default)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=2e-5,
                        help='weight decay (default: 2e-5)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--clip-mode', type=str, default='norm',
                        help='Gradient clipping mode. One of ("norm", "value", "agc")')
    parser.add_argument('--layer-decay', type=float, default=None,
                        help='layer-wise learning rate decay (default: None)')

    #* Learning rate schedule parameters
    parser.add_argument('--epochs', default=150, type=int)
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lrb', type=float, default=0.1, metavar='LR',
                        help='base learning rate, scaled by total batch size / 512 when --lr is unset (default: 0.1)')
    parser.add_argument('--lr', type=float, default=None, help='actual learning rate')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                        help='learning rate cycle len multiplier (default: 1.0)')
    parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
                        help='amount to decay each learning rate cycle (default: 0.5)')
    parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                        help='learning rate cycle limit, cycles enabled if > 1')
    parser.add_argument('--lr-k-decay', type=float, default=1.0,
                        help='learning rate k-decay for cosine/poly (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                        help='warmup learning rate (default: 0.0001)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
    parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                        help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    #* dataset parameters
    parser.add_argument('--batch-size', default=64, type=int)    # batch size per gpu
    parser.add_argument('--train-dir', default='', type=str, help='train dataset path')
    parser.add_argument('--eval-dir', default='', type=str, help='validation dataset path')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')
    parser.add_argument('--crop-pct', default=0.875, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
    parser.add_argument('--mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), metavar='MEAN',
                        help='Override mean pixel value of dataset')
    parser.add_argument('--std', type=float, nargs='+', default=(0.229, 0.224, 0.225), metavar='STD',
                        help='Override std deviation of dataset')

    #* Augmentation & regularization parameters
    parser.add_argument('--no-aug', action='store_true', default=False,
                        help='Disable all training augmentation, override other train aug args')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    parser.add_argument('--aug-repeats', type=float, default=0,
                        help='Number of augmentation repetitions (distributed training only) (default: 0)')
    parser.add_argument('--aug-splits', type=int, default=0,
                        help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
    parser.add_argument('--jsd-loss', action='store_true', default=False,
                        help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
    parser.add_argument('--bce-loss', action='store_true', default=False,
                        help='Enable BCE loss w/ Mixup/CutMix use.')
    parser.add_argument('--bce-target-thresh', type=float, default=None,
                        help='Threshold for binarizing softened BCE targets (default: None, disabled)')
    # random erase
    parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                        help='Random erase prob (default: 0.)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')
    parser.add_argument('--mixup', type=float, default=0.0,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.)')
    parser.add_argument('--cutmix', type=float, default=0.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
    parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                        help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='random',
                        help='Training interpolation (random, bilinear, bicubic default: "random")')
    # drop connection
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                        help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
    parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                        help='Drop path rate (default: None)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    #* ema
    parser.add_argument('--model-ema', action='store_true', default=False,
                        help='Enable tracking moving average of model weights')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                        help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
    parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                        help='decay factor for model weights moving average (default: 0.9998)')

    # misc
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                        help='how many batches to wait before writing recovery checkpoint')
    parser.add_argument('--max-history', type=int, default=5, help='how many recovery checkpoints')
    parser.add_argument('--num-workers', type=int, default=4, metavar='N',
                        help='how many data-loading worker processes to use (default: 4)')
    parser.add_argument('--output-dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                        help='Best metric (default: "top1")')
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')

    # advtrain
    # Explicit-value booleans: parse "False"/"0" as False instead of a truthy string.
    parser.add_argument('--advtrain', default=False, type=_str2bool, help='if use advtrain')
    parser.add_argument('--attack-criterion', type=str, default='regular', choices=['regular', 'smooth', 'mixup'], help='default args for: adversarial training')
    parser.add_argument('--attack-eps', type=float, default=4.0/255, help='attack epsilon.')
    parser.add_argument('--attack-step', type=float, default=8.0/255/3, help='attack step size.')
    parser.add_argument('--attack-it', type=int, default=3, help='attack iteration')
    # advprop
    parser.add_argument('--advprop', default=False, type=_str2bool, help='if use advprop')
    parser.add_argument('--align-loss-weight', type=float, default=0.2, help='Weight for feature alignment loss')

    parser.add_argument('--attack_types', type=str, nargs='*', default=('autoattack',), help='attacks to evaluate: autoattack, pgd100, fgsm')
    parser.add_argument('--norm', type=str, default='Linf', help='You can choose norm for aa attack', choices=['Linf', 'L2', 'L1'])
    parser.add_argument('--ckpt_dir', type=str, default='', help='path of the checkpoint file to evaluate')

    return parser


# Local torch.hub checkout of facebookresearch/dinov2 used when no path is given.
_DINOV2_HUB_DIR = "/home/qianzhuang/.cache/torch/hub/facebookresearch_dinov2_main"


def load_dinov2_teacher(BACKBONE_SIZE='base', repo_or_dir=None):
    """Load a frozen DINOv2 ViT backbone from a local torch.hub checkout.

    Args:
        BACKBONE_SIZE: one of ``'small'``, ``'base'``, ``'large'``, ``'giant'``
            (mapped to ``dinov2_vits14`` / ``vitb14`` / ``vitl14`` / ``vitg14``).
        repo_or_dir: local directory containing the dinov2 hub repo. Defaults to
            the historical hard-coded checkout path.

    Returns:
        The backbone in eval mode, with every parameter's ``requires_grad``
        set to False, moved to CUDA.

    Raises:
        KeyError: if ``BACKBONE_SIZE`` is not one of the known sizes.
    """
    backbone_archs = {
        "small": "vits14",
        "base": "vitb14",
        "large": "vitl14",
        "giant": "vitg14",
    }
    if BACKBONE_SIZE not in backbone_archs:
        raise KeyError(
            f"unknown DINOv2 backbone size {BACKBONE_SIZE!r}; expected one of {sorted(backbone_archs)}")
    backbone_name = f"dinov2_{backbone_archs[BACKBONE_SIZE]}"
    model = torch.hub.load(
        repo_or_dir=repo_or_dir if repo_or_dir is not None else _DINOV2_HUB_DIR,
        model=backbone_name,
        source="local",
    )

    # Used as a fixed teacher: inference mode and no parameter gradients.
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    model.cuda()

    return model

def main(args, args_text):
    """Evaluate a trained checkpoint on clean and adversarial inputs.

    Sets up (optionally distributed) execution and logging, builds the model
    and loads weights from ``args.ckpt_dir``, wraps it for DDP when
    distributed, builds the eval loader and loss, runs ``adv_validate`` at
    eps = 4/255 and then ``validate``, prints both metric dicts, and finally
    synchronizes all ranks and destroys the process group.

    Args:
        args: merged configuration namespace (see ``get_args_parser``).
        args_text: YAML dump of ``args``; unused here because config/checkpoint
            saving is disabled for evaluation.

    Raises:
        ValueError: if ``args.ckpt_dir`` is empty or ``args.aug_splits == 1``.
    """
    # Fail fast, before any process-group setup, when there is nothing to evaluate.
    if not args.ckpt_dir:
        raise ValueError('--ckpt_dir must be the path of the checkpoint file to evaluate')

    # distributed settings and logger
    if "WORLD_SIZE" in os.environ:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1
    distributed_init(args)
    log_path = './test_out'
    args.output_dir = os.path.join(log_path, args.output_dir)
    _logger = create_logger(args.output_dir, dist_rank=args.rank, name='main_train', default_level=logging.INFO)

    # Seed per rank; cudnn benchmark favours speed over bit-exact reproducibility.
    random_seed(args.seed, args.rank)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        if args.aug_splits == 1:
            raise ValueError('A split of 1 makes no sense')
        num_aug_splits = args.aug_splits

    # resolve amp
    resolve_amp(args, _logger)

    # build model and load the weights under evaluation
    model = build_model(args, _logger, num_aug_splits)
    ckpt = torch.load(args.ckpt_dir, map_location='cpu')
    # Training checkpoints nest weights under 'state_dict'; also accept a bare state dict.
    state_dict = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
    model.load_state_dict(state_dict)
    _logger.info(f"Loaded weights from {args.ckpt_dir}")

    # Resolve the effective LR as training does. No optimizer is needed for
    # evaluation; args.lr is kept in case other builders read it
    # (NOTE(review): confirm whether any do).
    if args.lr is None:
        args.lr = args.lrb * args.batch_size * args.world_size / 512

    # Only the autocast context is needed for evaluation; the loss scaler is unused.
    amp_autocast, _ = build_loss_scaler(args, _logger)

    # setup distributed training
    if args.distributed:
        if args.amp_version == 'apex':
            # Apex DDP preferred unless native amp is activated
            from apex.parallel import DistributedDataParallel as ApexDDP
            _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            _logger.info("Using native Torch DistributedDataParallel.")
            # Disable buffer broadcasts during forward to avoid extra collectives in eval
            model = NativeDDP(
                model,
                device_ids=[args.device_id],
                broadcast_buffers=False,
                find_unused_parameters=False,
            )

    # create the eval dataloader (train loader is unused here)
    _, loader_eval, mixup_fn = build_dataset(args, num_aug_splits)

    # setup loss function (only the validation criterion is used)
    _, validate_loss_fn = build_loss(args, mixup_fn, num_aug_splits)

    # distributed bn sync
    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
        _logger.info("Distributing BatchNorm running means and vars")
        distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

    adv_eval_metrics = adv_validate(model, loader_eval, args, 4)

    # calculate evaluation metric
    eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, _logger=_logger)
    print(f"Eval metrics: {eval_metrics}")
    print(f"Adv Eval metrics: {adv_eval_metrics}")

    if args.distributed and torch.distributed.is_initialized():
        # Make sure all ranks finish before exiting to avoid NCCL hanging on
        # some ranks. Both steps are best effort, but failures are reported.
        try:
            torch.distributed.barrier()
        except Exception as err:
            _logger.warning(f"barrier before shutdown failed: {err}")
        try:
            torch.distributed.destroy_process_group()
        except Exception as err:
            _logger.warning(f"destroy_process_group failed: {err}")





def validate(model, loader, loss_fn, args, amp_autocast=None, log_suffix='', _logger=None):
    """Compute clean and (optionally) adversarial loss / top-1 / top-5 on ``loader``.

    Per-batch progress is logged from rank-local meters. The returned metrics
    come from sample-weighted sums that are all-reduced once, after the loop,
    when running distributed.

    Args:
        model: the model, possibly wrapped (anything exposing ``.module``). The
            unwrapped module is used for every forward pass AND for the attack,
            so no DDP collectives run inside the loop.
        loader: iterable of ``(input, target)`` batches; ``len(loader)`` is used
            for progress logging.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: config namespace; reads ``channels_last``, ``advtrain``,
            ``log_interval``, ``distributed``, ``device_id``.
        amp_autocast: accepted for API compatibility; autocast is currently
            disabled in this function, so it is unused.
        log_suffix: appended to the ``'Test'`` label in progress lines.
        _logger: logger for progress lines; defaults to this module's logger.

    Returns:
        OrderedDict with ``loss``, ``top1``, ``top5``, ``advloss``, ``advtop1``,
        ``advtop5`` (accuracies in percent). With no samples for a section
        (e.g. ``args.advtrain`` falsy) its entries fall back to the meter
        averages, which were never updated.
    """
    if _logger is None:
        _logger = logging.getLogger(__name__)
    # Use the bare module to avoid any DDP comms during forward or attack backward.
    run_model = model.module if hasattr(model, 'module') else model
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    adv_losses_m = AverageMeter()
    adv_top1_m = AverageMeter()
    adv_top5_m = AverageMeter()

    # accumulate locally; one reduce at the end only
    total_samples = 0
    sum_loss = 0.0
    sum_top1_correct = 0.0
    sum_top5_correct = 0.0
    adv_total_samples = 0
    adv_sum_loss = 0.0
    adv_sum_top1_correct = 0.0
    adv_sum_top5_correct = 0.0

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # clean evaluation
        with torch.no_grad():
            output = run_model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]
            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

        bs = input.size(0)
        # .item() already waits for the device, so no explicit cuda sync is needed.
        loss_val = loss.detach().item()
        acc1_val = acc1.detach().item()
        acc5_val = acc5.detach().item()

        # rank-local meters for on-the-fly logs
        losses_m.update(loss_val, bs)
        top1_m.update(acc1_val, bs)
        top5_m.update(acc5_val, bs)

        # local sums for the one-shot reduction at the end
        total_samples += bs
        sum_loss += loss_val * bs
        sum_top1_correct += acc1_val * bs / 100.0
        sum_top5_correct += acc5_val * bs / 100.0

        # adversarial evaluation
        if args.advtrain:
            # Only input gradients are needed: freeze params so the attack does not
            # accumulate weight .grad, and restore the flags even if it raises.
            grad_flags = [p.requires_grad for p in run_model.parameters()]
            for p in run_model.parameters():
                p.requires_grad = False
            try:
                adv_input = adv_generator(args, input, target, run_model, 4/255, 10, 1/255,
                                          random_start=True, use_best=False, attack_criterion='regular')
            finally:
                for p, flag in zip(run_model.parameters(), grad_flags):
                    p.requires_grad = flag

            with torch.no_grad():
                adv_output = run_model(adv_input)
                if isinstance(adv_output, (tuple, list)):
                    adv_output = adv_output[0]
                adv_loss = loss_fn(adv_output, target)
                adv_acc1, adv_acc5 = accuracy(adv_output, target, topk=(1, 5))

            bs_adv = adv_input.size(0)
            adv_loss_val = adv_loss.detach().item()
            adv_acc1_val = adv_acc1.detach().item()
            adv_acc5_val = adv_acc5.detach().item()

            adv_losses_m.update(adv_loss_val, bs_adv)
            adv_top1_m.update(adv_acc1_val, bs_adv)
            adv_top5_m.update(adv_acc5_val, bs_adv)

            adv_total_samples += bs_adv
            adv_sum_loss += adv_loss_val * bs_adv
            adv_sum_top1_correct += adv_acc1_val * bs_adv / 100.0
            adv_sum_top5_correct += adv_acc5_val * bs_adv / 100.0

        batch_time_m.update(time.time() - end)
        end = time.time()

        if last_batch or batch_idx % args.log_interval == 0:
            log_name = 'Test' + log_suffix
            _logger.info(
                '{0}: [{1:>4d}/{2}]  '
                'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '
                'AdvLoss: {adv_loss.val:>7.4f} ({adv_loss.avg:>6.4f})  '
                'AdvAcc@1: {adv_top1.val:>7.4f} ({adv_top1.avg:>7.4f})  '
                'AdvAcc@5: {adv_top5.val:>7.4f} ({adv_top5.avg:>7.4f})'.format(
                    log_name, batch_idx, last_idx, batch_time=batch_time_m,
                    loss=losses_m, top1=top1_m, top5=top5_m,
                    adv_loss=adv_losses_m, adv_top1=adv_top1_m, adv_top5=adv_top5_m))

    # One-shot global reduction for metrics
    if args.distributed and torch.distributed.is_initialized():
        try:
            device = torch.device(f"cuda:{args.device_id}")
            vec = torch.tensor(
                [sum_loss, sum_top1_correct, sum_top5_correct, float(total_samples),
                 adv_sum_loss, adv_sum_top1_correct, adv_sum_top5_correct, float(adv_total_samples)],
                dtype=torch.float64, device=device)
            torch.distributed.all_reduce(vec, op=torch.distributed.ReduceOp.SUM)
            (sum_loss, sum_top1_correct, sum_top5_correct, total_samples,
             adv_sum_loss, adv_sum_top1_correct, adv_sum_top5_correct, adv_total_samples) = vec.tolist()
        except Exception as err:
            # best effort: report this rank's local metrics instead
            _logger.warning(f"metric all_reduce failed, reporting rank-local metrics: {err}")

    def _weighted(total, count, fallback, scale=1.0):
        # Sample-weighted mean; fall back to the meter average when nothing was counted.
        return scale * (total / max(1.0, float(count))) if count else fallback

    final_loss = _weighted(sum_loss, total_samples, losses_m.avg)
    final_top1 = _weighted(sum_top1_correct, total_samples, top1_m.avg, 100.0)
    final_top5 = _weighted(sum_top5_correct, total_samples, top5_m.avg, 100.0)
    final_adv_loss = _weighted(adv_sum_loss, adv_total_samples, adv_losses_m.avg)
    final_adv_top1 = _weighted(adv_sum_top1_correct, adv_total_samples, adv_top1_m.avg, 100.0)
    final_adv_top5 = _weighted(adv_sum_top5_correct, adv_total_samples, adv_top5_m.avg, 100.0)

    # ensure all ranks finish validation before returning
    if args.distributed and torch.distributed.is_initialized():
        try:
            torch.distributed.barrier()
        except Exception as err:
            _logger.warning(f"barrier after validation failed: {err}")

    return OrderedDict([
        ('loss', float(final_loss)),
        ('top1', float(final_top1)),
        ('top5', float(final_top5)),
        ('advloss', float(final_adv_loss)),
        ('advtop1', float(final_adv_top1)),
        ('advtop5', float(final_adv_top5)),
    ])

def adv_validate(model, loader, args, eps_int, log_suffix='robust acc'):
    """Measure robust top-1 accuracy under every attack in ``args.attack_types``.

    A sample counts as robust only if the model still classifies it correctly
    after *each* listed attack (per-attack correctness flags are AND-ed).

    Args:
        model: the model, possibly wrapped (anything exposing ``.module``);
            attacks and forwards use the unwrapped module.
        loader: iterable of ``(input, target)`` batches.
        args: config namespace; reads ``attack_types``, ``norm``,
            ``distributed``, ``world_size``, ``device_id``.
        eps_int: perturbation radius in 0-255 pixel units (eps = eps_int / 255).
        log_suffix: text inserted into the per-batch progress lines.

    Returns:
        OrderedDict with ``'top1'``: robust accuracy in percent, computed from
        counts summed over all ranks when distributed.

    Raises:
        ValueError: if ``args.attack_types`` is empty or names an unsupported attack.
    """
    log = logging.getLogger(__name__)
    supported = ('autoattack', 'pgd100', 'fgsm')
    attack_types = list(args.attack_types)
    if not attack_types:
        raise ValueError('args.attack_types is empty; nothing to evaluate')
    unknown = [a for a in attack_types if a not in supported]
    if unknown:
        # Unknown names would otherwise leave x_adv undefined or silently reuse
        # the previous attack's adversarial batch.
        raise ValueError(f'unsupported attack type(s) {unknown}; expected any of {supported}')

    batch_time_m = AverageMeter()
    top1_m = AverageMeter()
    eps = eps_int / 255
    # Use the bare module to avoid any DDP comms during forward
    run_model = model.module if hasattr(model, 'module') else model
    run_model.eval()

    # set attackers
    attackers = {}
    for attack_type in attack_types:
        if attack_type == 'autoattack':
            if args.distributed:
                from adv.autoattack_ddp import AutoAttack
            else:
                from adv.autoattack import AutoAttack
            attackers[attack_type] = AutoAttack(run_model, norm=args.norm, eps=eps, version='standard')
        elif attack_type == 'pgd100':
            from adv.adv_utils import pgd_attack
            attackers[attack_type] = pgd_attack
        elif attack_type == 'fgsm':
            from adv.adv_utils import fgsm_attack
            attackers[attack_type] = fgsm_attack

    def _attack(attack_type, x, y):
        """Return adversarial examples for ``x`` using the named attack."""
        if attack_type == 'autoattack':
            return attackers[attack_type].run_standard_evaluation(x, y, bs=y.size(0))
        if attack_type == 'pgd100':
            return attackers[attack_type](x, y, run_model, eps, 100, 1/255, 1, gpu=args.device_id)
        # fgsm: one step of size eps (previously hard-coded to 4/255, i.e. eps_int=4)
        return attackers[attack_type](x, y, run_model, eps, eps, gpu=args.device_id)

    end = time.time()
    last_idx = len(loader) - 1
    total_robust = 0.0
    total_count = 0.0
    log_name = 'Test ' + log_suffix + ' of eps ' + str(eps_int)
    for batch_idx, (input, target) in enumerate(loader):
        input = input.cuda()
        target = target.cuda()
        # Ensure all ranks use the SAME per-rank batch size for DDP attackers
        if args.distributed and torch.distributed.is_initialized():
            try:
                local_bsz = torch.tensor([target.size(0)], device=input.device, dtype=torch.int64)
                bsz_list = [torch.zeros_like(local_bsz) for _ in range(args.world_size)]
                torch.distributed.all_gather(bsz_list, local_bsz)
                min_bsz = int(torch.stack(bsz_list).min().item())
                if min_bsz < target.size(0):
                    input = input[:min_bsz]
                    target = target[:min_bsz]
            except Exception as err:
                log.warning(f'could not sync per-rank batch sizes, using local size: {err}')
        batch_size = target.size(0)
        if batch_size == 0:
            # An empty batch would put NaN (0/0) into the running accuracy.
            continue
        robust_flag = torch.ones_like(target, dtype=torch.bool)

        # Only input gradients are needed, so temporarily disable param grads
        # (keeps DDP from expecting gradient sync); always restore them.
        grad_flags = [p.requires_grad for p in run_model.parameters()]
        for p in run_model.parameters():
            p.requires_grad = False
        try:
            for attack_type in attack_types:
                x_adv = _attack(attack_type, input, target)
                with torch.no_grad():
                    output = run_model(x_adv.detach())
                    _, label = torch.max(output, dim=1)
                    robust_flag = torch.logical_and(robust_flag, label == target)
        finally:
            for p, flag in zip(run_model.parameters(), grad_flags):
                p.requires_grad = flag

        # accumulate locally; no per-batch cross-rank reduce (.item() syncs the device)
        n_robust = robust_flag.float().sum()
        acc1 = n_robust * 100. / batch_size
        top1_m.update(acc1.item(), batch_size)
        total_robust += n_robust.item()
        total_count += float(batch_size)
        batch_time_m.update(time.time() - end)
        end = time.time()

        print(
            '{0}: [{1:>4d}/{2}]  '
            'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
            'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '.format(
                log_name, batch_idx, last_idx, batch_time=batch_time_m, top1=top1_m))

    # One-shot reduce robust totals
    final_top1 = top1_m.avg
    if args.distributed and torch.distributed.is_initialized():
        try:
            device = torch.device(f"cuda:{args.device_id}")
            vec = torch.tensor([total_robust, total_count], dtype=torch.float64, device=device)
            torch.distributed.all_reduce(vec, op=torch.distributed.ReduceOp.SUM)
            total_robust, total_count = vec.tolist()
            if total_count > 0:
                final_top1 = 100.0 * (total_robust / total_count)
        except Exception as err:
            # best effort: report this rank's local robust accuracy instead
            log.warning(f'robust-accuracy all_reduce failed, reporting rank-local value: {err}')

    # ensure all ranks finish robust eval before returning
    if args.distributed and torch.distributed.is_initialized():
        try:
            torch.distributed.barrier()
        except Exception as err:
            log.warning(f'barrier after robust eval failed: {err}')

    return OrderedDict([('top1', float(final_top1))])

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Robust training script', parents=[get_args_parser()])
    args = parser.parse_args()
    opt = vars(args)
    if args.configs:
        # YAML values override command-line/default values. NOTE: FullLoader can
        # construct some Python objects, so only load trusted config files.
        # An empty YAML file loads as None, hence the `or {}`.
        with open(args.configs, encoding='utf-8') as config_file:
            opt.update(yaml.load(config_file, Loader=yaml.FullLoader) or {})

    args = argparse.Namespace(**opt)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    main(args, args_text)