import datetime
import os
import time
from timm.utils import *



import matplotlib.pyplot as plt
import torch
import torch.utils.data
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import math
from torch.cuda import amp
import model, utils
from spikingjelly.clock_driven import functional
from spikingjelly.datasets import cifar10_dvs
from timm.models import create_model
from timm.data import Mixup
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
import autoaugment
from factory import Betascheduler,TET_loss,compute_mutual_info_matrix


from collections import OrderedDict
from contextlib import suppress


import random
root_path = os.path.abspath(__file__)

import numpy as np
#writer = SummaryWriter("./")
from timm.utils import update_summary
# import criterion




import adv



def parse_args():
    '''
    Build the command-line parser for CIFAR10-DVS training and parse ``sys.argv``.

    Note: ``--amp`` and ``--tb`` are declared with ``default=True`` and
    ``action='store_true'``, so they are always enabled; passing them is a no-op.

    :return: the parsed ``argparse.Namespace``
    '''
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--model', default='SEWResNet', help='model')
    parser.add_argument('--dataset', default='cifar10dvs', help='dataset')
    parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                        help='number of label classes (default: 10)')
    parser.add_argument('--data-path', default='/zhuzizheng/data/cifar10dvs', help='dataset')
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('-b', '--batch-size', default=16, type=int)
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    parser.add_argument('--print-freq', default=256, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./logs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )

    # Mixed precision training parameters (default=True + store_true: always on)
    parser.add_argument('--amp', default=True, action='store_true',
                        help='Use AMP training (always enabled)')

    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--tb', default=True, action='store_true',
                        help='Use TensorBoard to record logs (always enabled)')
    parser.add_argument('--T', default=16, type=int, help='simulation steps')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar="OPTIMIZER", help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)')
    # nargs='+' so that e.g. `--opt-betas 0.9 0.999` produces a sequence; timm's optimizer
    # factory forwards this value as `betas`, which Adam-style optimizers expect as a pair.
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use optimizer default)')
    parser.add_argument('--weight-decay', default=0.06, type=float, help='weight decay (default: 0.06)')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='Momentum for SGD. Adam will not use momentum')

    parser.add_argument('--connect_f', default='ADD', type=str, help='element-wise connect function')
    parser.add_argument('--T_train', default=None, type=int,
                        help='if set, randomly sample this many of the T frames per training batch')

    # Learning rate scheduler
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-3, metavar='LR',
                        help='learning rate (default: 5e-3)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                        help='learning rate cycle len multiplier (default: 1.0)')
    parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                        help='learning rate cycle limit')
    parser.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
                        help='warmup learning rate (default: 1e-5)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--epochs', type=int, default=120, metavar='N',
                        help='number of epochs to train (default: 120)')
    parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                        help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--decay-epochs', type=float, default=20, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation & regularization parameters
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--mixup', type=float, default=0.5,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.5)')
    parser.add_argument('--cutmix', type=float, default=0.,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=0.5,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
    parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                        help='Turn off mixup after this epoch (NOTE: not read; training turns mixup off at epoch 75)')
    parser.add_argument('--pretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')

    # Top-down feedback parameters
    parser.add_argument('--top_down', type=int, default=2,
                        help='top-down mode: <2 plain forward, 2 two-half feedback, 3 step-wise feedback')
    parser.add_argument("--beta", default=[0.5, 0.5],
                        type=float, nargs=2, metavar='N N',
                        help='start/end weight of the feedback-pass loss; scheduled when they differ')
    parser.add_argument('--V', type=int, default=4,
                        help='type of processing module')
    parser.add_argument('--num_decoder_layers', type=int, default=-1,
                        help='decoder_layers')
    parser.add_argument('--aug', type=float, default=1.0,
                        help='the parameter of top_down feedback')
    parser.add_argument('--layer_td', type=str, default='batch')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--loss', type=str, default='SDT',
                        help="training loss form; 'TET' selects TET_loss in feedback mode 2")

    # Adversarial evaluation
    parser.add_argument('--adv', type=str, default=None,
                        help="adversarial attack name ('FGSM' or 'PGD')")

    args = parser.parse_args()
    return args

def split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data.Dataset, num_classes: int, random_split: bool = False):
    '''
    Split a labelled dataset into per-class (stratified) train/test subsets.

    :param train_ratio: fraction of the samples of each class that goes to the train set;
        the per-class cut point is ``ceil(len(class_samples) * train_ratio)``
    :type train_ratio: float
    :param origin_dataset: the origin dataset; ``origin_dataset[i][1]`` must be an integer
        class label in ``[0, num_classes)`` (a Python/numpy int, or a one-element
        ``np.ndarray`` / ``torch.Tensor``)
    :type origin_dataset: torch.utils.data.Dataset
    :param num_classes: total number of classes, e.g. ``10`` for CIFAR10-DVS
    :type num_classes: int
    :param random_split: if ``False``, the first ``train_ratio`` fraction of each class (in
        dataset order) goes to the train set and the rest to the test set. If ``True``, the
        indices of each class are shuffled first with ``np.random.shuffle``, so the split is
        controlled by ``numpy.random.seed``.
    :type random_split: bool
    :return: a tuple ``(train_set, test_set)`` of ``torch.utils.data.Subset``
    :rtype: tuple
    :raises ValueError: if a label lies outside ``[0, num_classes)``
    '''
    label_idx = [[] for _ in range(num_classes)]

    # Iterating the dataset materialises every sample; only the label is kept.
    for i, item in enumerate(origin_dataset):
        y = item[1]
        if isinstance(y, (np.ndarray, torch.Tensor)):
            y = y.item()
        # Guard explicitly: a negative label would otherwise silently index
        # the last class bucket instead of failing.
        if not 0 <= y < num_classes:
            raise ValueError(f'sample {i} has label {y}, expected a value in [0, {num_classes})')
        label_idx[y].append(i)

    if random_split:
        for idx in label_idx:
            np.random.shuffle(idx)

    train_idx = []
    test_idx = []
    for idx in label_idx:
        pos = math.ceil(len(idx) * train_ratio)
        train_idx.extend(idx[:pos])
        test_idx.extend(idx[pos:])

    return torch.utils.data.Subset(origin_dataset, train_idx), torch.utils.data.Subset(origin_dataset, test_idx)



def _train_forward(model, criterion, image, target, beta):
    '''
    Run the training forward pass for the configured top-down mode; return ``(output, loss)``.

    Reads ``args.top_down`` and ``args.loss`` from the module-level ``args`` (set in the
    ``__main__`` guard), as the original inline code did.

    :param image: batch as [N, T, C, H, W]; feedback modes permute it to time-major
    :param beta: weight of the second-pass loss in feedback mode 2
    :raises ValueError: for an unsupported ``args.top_down``
    '''
    if args.top_down < 2:
        output = model(image)
        return output, criterion(output, target)

    if args.top_down == 2:
        # feedback 2: the first half of the time steps yields the top-down signal `td`,
        # which conditions the pass over the second half.
        image = image.permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
        T = image.size(0)
        x1, td, _ = model(image[:T // 2])
        if args.loss == 'TET':
            x2, output = model(image[T // 2:], td=td)
            loss = (1 - beta) * TET_loss(x1, target, criterion=criterion) + \
                beta * TET_loss(x2, target, criterion=criterion)
        else:
            output = model(image[T // 2:], td=td)
            loss = (1 - beta) * criterion(x1, target) + \
                beta * criterion(output, target)
        return output, loss

    if args.top_down == 3:
        # feedback 3: step-by-step pass threading `td` through time. Iterate over the
        # actual number of frames so that --T_train (fewer frames than args.T) works.
        image = image.permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
        td = None
        for t in range(image.size(0)):
            _, td = model(image[t], td=td)
        # NOTE(review): the original also accumulated a per-step loss here and then
        # discarded it; the loss actually optimised comes from the full-sequence pass
        # below, which runs without resetting the state left by the loop. Kept as-is —
        # confirm this is the intended objective.
        output = model(image)
        return output, criterion(output, target)

    raise ValueError(f'unsupported --top_down value: {args.top_down}')


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, scaler=None,
                    T_train=None, aug=None, trival_aug=None, mixup_fn=None, td=False, beta=None):
    '''
    Train ``model`` for one epoch and return the averaged ``(loss, acc1, acc5)``.

    Per batch: optional per-sample ``aug`` and ``trival_aug`` transforms, optional
    ``mixup_fn`` (soft targets; accuracy is then measured against their argmax), optional
    random sub-sampling of ``T_train`` frames, then the forward pass selected by the
    module-level ``args.top_down``. The forward runs under ``amp.autocast`` when a
    ``scaler`` is given and in full precision otherwise.

    :param td: kept for call compatibility; the mode is read from ``args.top_down``
    :param beta: weight of the second-pass loss in feedback mode 2
    :raises ValueError: if the input is not 5-D, the loss is NaN, or ``args.top_down`` is unsupported
    '''
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        image = image.float()  # [N, T, C, H, W]
        if image.dim() != 5:
            raise ValueError(f'expected a 5-D [N, T, C, H, W] batch, got shape {tuple(image.shape)}')
        # Record N now: the feedback forward passes permute the batch to time-major.
        batch_size = image.shape[0]

        if aug is not None:
            image = torch.stack([aug(image[i]) for i in range(batch_size)])

        if trival_aug is not None:
            image = torch.stack([trival_aug(image[i]) for i in range(batch_size)])

        acc_target = target
        if mixup_fn is not None:
            image, target = mixup_fn(image, target)
            # Mixup yields soft targets; accuracy is measured against the dominant class.
            acc_target = target.argmax(dim=-1)

        if T_train:
            sec_list = np.random.choice(image.shape[1], T_train, replace=False)
            sec_list.sort()
            image = image[:, sec_list]

        # Previously a missing scaler terminated the process via exit(0); now the
        # same forward simply runs in full precision.
        with amp.autocast(enabled=scaler is not None):
            output, loss = _train_forward(model, criterion, image, target, beta)

        optimizer.zero_grad()

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Spiking neurons keep membrane state between calls; clear it per batch.
        functional.reset_net(model)
        acc1, acc5 = utils.accuracy(output, acc_target, topk=(1, 5))
        loss_s = loss.item()
        if math.isnan(loss_s):
            raise ValueError('loss is Nan')
        acc1_s = acc1.item()
        acc5_s = acc5.item()

        metric_logger.update(loss=loss_s, lr=optimizer.param_groups[0]["lr"])

        metric_logger.meters['acc1'].update(acc1_s, n=batch_size)
        metric_logger.meters['acc5'].update(acc5_s, n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg



def evaluate(model,args, criterion, data_loader, device, print_freq=100, header='Test:', td=2):
    '''
    Evaluate ``model`` on ``data_loader`` and report loss / top-1 / top-5 accuracy.

    The forward protocol is selected by ``args.top_down`` (the ``td`` argument is kept
    for call compatibility and is not read):
      * ``< 2``: plain ``model(image)`` (consistent with training);
      * ``2``: time-major input; the first half of the time steps yields ``td`` which
        conditions the pass over the second half;
      * ``3``: step-by-step pass threading ``td``, followed by a full-sequence pass.

    :return: ``(loss, acc1, acc5)`` averaged over the dataset. With ``args.test_only``:
        ``top_down == 1`` returns the first batch's raw ``output`` immediately;
        ``top_down == 2`` returns ``(loss, acc1, acc5, x1, output, td)``;
        ``top_down < 2`` (other) returns ``(loss, acc1, acc5, output)``;
        ``top_down == 3`` returns ``(loss, acc1, acc5, output_list)``;
        these tensors come from the last batch.
    :raises ValueError: for an unsupported ``args.top_down``
    '''
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True).float()
            target = target.to(device, non_blocking=True)
            # Count samples before any time-major permute moves T to dim 0.
            batch_size = image.shape[0]

            if args.top_down < 2:
                output = model(image)
                if args.test_only and args.top_down == 1:
                    # Early exit relied on by main's --test-only path: hand back the
                    # first batch's output without resetting or logging metrics.
                    return output

            elif args.top_down == 2:
                image = image.permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
                T = image.size(0)
                x1, td, tmp = model(image[:T // 2])
                output = model(image[T // 2:], td=td)

            elif args.top_down == 3:
                image = image.permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
                td = None
                output_list = []
                for t in range(image.size(0)):
                    step_output, td = model(image[t], td=td)
                    output_list.append(step_output)
                # NOTE(review): the reported metrics use the full-sequence pass below,
                # run without resetting the state left by the step loop; the step
                # outputs are only returned in test-only mode. Confirm intent.
                output = model(image)

            else:
                raise ValueError(f'unsupported --top_down value: {args.top_down}')

            loss = criterion(output, target)
            functional.reset_net(model)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    loss, acc1, acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
    print(f' * Acc@1 = {acc1}, Acc@5 = {acc5}, loss = {loss}')

    if args.test_only:
        if args.top_down == 2:
            return loss, acc1, acc5, x1, output, td
        elif args.top_down < 2:
            return loss, acc1, acc5, output
        elif args.top_down == 3:
            return loss, acc1, acc5, output_list
    return loss, acc1, acc5

def load_data(dataset_dir, distributed, T, train_ratio=0.9, num_classes=10):
    '''
    Build the CIFAR10-DVS train/test split and matching samplers.

    :param dataset_dir: root directory passed to ``cifar10_dvs.CIFAR10DVS``
    :param distributed: if true, use ``DistributedSampler`` for both splits; otherwise a
        ``RandomSampler`` for train and a ``SequentialSampler`` for test
    :param T: number of frames per sample (``frames_number``, ``split_by='number'``)
    :param train_ratio: per-class fraction assigned to the train split (default 0.9)
    :param num_classes: number of classes in the dataset (default 10)
    :return: ``(dataset_train, dataset_test, train_sampler, test_sampler)``
    '''
    print("Loading data")

    st = time.time()
    origin_set = cifar10_dvs.CIFAR10DVS(root=dataset_dir, data_type='frame', frames_number=T, split_by='number')
    # Stratified, deterministic split (random_split=False): first train_ratio of each class.
    dataset_train, dataset_test = split_to_train_test_set(train_ratio, origin_set, num_classes)
    print("Took", time.time() - st)

    print("Creating data loaders")
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset_train, dataset_test, train_sampler, test_sampler

def _build_output_dir(args):
    '''Create (if needed) and return the run directory derived from the hyper-parameters in ``args``.'''
    output_dir = os.path.join(args.output_dir, f'{args.model}_b{args.batch_size}_T{args.T}')
    if args.T_train:
        output_dir += f'_Ttrain{args.T_train}'
    if args.weight_decay:
        output_dir += f'_wd{args.weight_decay}'
    # Any optimizer other than adamw is labelled '_sgd'.
    output_dir += '_adamw' if args.opt == 'adamw' else '_sgd'
    if args.connect_f:
        output_dir += f'_cnf_{args.connect_f}'
    if not os.path.exists(output_dir):
        utils.mkdir(output_dir)

    output_dir = os.path.join(output_dir, f'lr{args.lr}')
    if not os.path.exists(output_dir):
        utils.mkdir(output_dir)
    return output_dir


def _run_adversarial_eval(model, data_loader_test, criterion, args):
    '''Measure accuracy under the attack named by ``args.adv`` ('FGSM' or 'PGD'); return the per-setting accuracies.'''
    sum_acc = 0
    res = []
    if args.adv == 'FGSM':
        eps_list = [0.001, 0.01, 0.025, 0.05, 0.1, 0.2, 0.3]
        for eps in eps_list:
            acc = validate_adv(model, data_loader_test, criterion, args,
                               adv_fn=adv.fgsm, eps=eps, td=args.top_down)
            sum_acc += acc
            res.append(acc)
            print("\neps is", eps, "acc is", acc)
            print("\n")
        print("\naverage acc is ", sum_acc / len(eps_list))
    elif args.adv == 'PGD':
        eps = 8 / 255
        iters_list = [5, 10, 30, 50]
        for n_iter in iters_list:
            acc = validate_adv(model, data_loader_test, criterion, args,
                               adv_fn=adv.pgd, eps=eps, iters=n_iter, td=args.top_down)
            sum_acc += acc
            res.append(acc)
            print("\niter is", n_iter, "eps is", eps, "acc is", acc)
            print("\n")
        print("\naverage acc is ", sum_acc / len(iters_list))
    return res


def main(args):
    '''
    Entry point: train QKFormer on CIFAR10-DVS, or only evaluate it with ``--test-only``.

    Training logs per-epoch metrics to TensorBoard and ``summary.csv`` in the run
    directory, saves ``checkpoint_max_test_acc1.pth`` whenever test top-1 improves and
    ``checkpoint_{epoch}.pth`` after the last epoch.

    With ``--test-only`` the model outputs are passed to ``compute_mutual_info_matrix``;
    if ``--adv`` is also given, adversarial accuracy is measured. The process then exits.

    :return: best test top-1 accuracy seen during training
    '''
    _seed_ = args.seed
    random.seed(_seed_)
    torch.manual_seed(_seed_)  # seeds the RNG for all devices (CPU and CUDA)
    torch.cuda.manual_seed_all(_seed_)
    np.random.seed(_seed_)
    # cudnn autotuning is preferred over bitwise reproducibility.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    max_test_acc1 = 0.
    test_acc5_at_max_test_acc1 = 0.

    train_tb_writer = None
    te_tb_writer = None

    utils.init_distributed_mode(args)
    print(args)

    output_dir = _build_output_dir(args)

    device = torch.device(args.device)

    dataset_train, dataset_test, train_sampler, test_sampler = load_data(args.data_path, args.distributed, args.T)

    # Use the samplers from load_data: RandomSampler/SequentialSampler are equivalent to
    # shuffle=True/False, and DistributedSampler gives each process its own shard
    # (which also makes train_sampler.set_epoch below effective).
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        drop_last=True,
        pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset=dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        drop_last=False,
        pin_memory=True)

    model = create_model(
        "QKFormer",
        pretrained=args.pretrained,
        drop_rate=0.,
        drop_path_rate=0.1,
        drop_block_rate=None,
        T=args.T, td=args.top_down, V=args.V, num_decoder_layers=args.num_decoder_layers,
        top_down_aug=args.aug, layer_td=args.layer_td, loss=args.loss, test_only=args.test_only
    )
    print("Creating model")
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"number of params: {n_parameters}")
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # The training loss must match the target format: mixup produces soft targets,
    # otherwise the loader's hard class indices reach the criterion.
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        criterion_train = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        criterion_train = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion_train = nn.CrossEntropyLoss()
    criterion = nn.CrossEntropyLoss()

    optimizer = create_optimizer(args, model)
    scaler = amp.GradScaler() if args.amp else None
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    model_without_ddp = model

    # NOTE(review): built before any checkpoint is loaded, so a resumed run restarts the
    # beta schedule over num_epochs epochs — confirm that is intended.
    if args.beta[0] != args.beta[1]:
        beta_scheduler = Betascheduler(args.beta[0], args.beta[1], num_epochs)
    else:
        beta_scheduler = None

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            max_test_acc1 = checkpoint['max_test_acc1']
            test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']

    if args.test_only:
        if args.top_down == 2:
            loss, acc1, acc5, x1, x2, td = evaluate(model, args, criterion, data_loader_test, device=device,
                                                    header='Test:', td=args.top_down)
            all_output = torch.cat((x1, x2), dim=0)
        elif args.top_down == 1:
            all_output = evaluate(model, args, criterion, data_loader_test, device=device,
                                  header='Test:', td=args.top_down)
        elif args.top_down == 3:
            loss, acc1, acc5, output_list = evaluate(model, args, criterion, data_loader_test, device=device,
                                                     header='Test:', td=args.top_down)
            # evaluate returns a list of per-step outputs here; concatenate along dim 0 like
            # the top_down == 2 case. NOTE(review): confirm this is the layout
            # compute_mutual_info_matrix expects.
            all_output = torch.cat(output_list, dim=0)
        else:
            raise ValueError(f'--test-only does not support --top_down {args.top_down}')
        print(all_output.shape)
        mi = compute_mutual_info_matrix(all_output)
        print(mi)

        # Previously an unconditional exit here made the --adv evaluation unreachable.
        if args.adv is None:
            exit(0)
        res = _run_adversarial_eval(model, data_loader_test, criterion, args)
        print(res)
        exit(0)

    if args.tb and utils.is_main_process():
        purge_step_train = args.start_epoch
        purge_step_te = args.start_epoch
        train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
        te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
        with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
            args_txt.write(str(args))

        print(f'purge_step_train={purge_step_train}, purge_step_te={purge_step_te}')

    train_snn_aug = transforms.Compose([
                    transforms.RandomHorizontalFlip(p=0.5)
                    ])
    train_trivalaug = autoaugment.SNNAugmentWide()
    mixup_fn = None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        mixup_fn = Mixup(**mixup_args)
    print("Start training")
    start_time = time.time()
    checkpoint = None  # stays None if the epoch range is empty
    for epoch in range(args.start_epoch, num_epochs):
        save_max = False
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # Mixup is switched off for the final epochs (hard-coded at epoch 75).
        if mixup_fn is not None and epoch >= 75:
            mixup_fn.mixup_enabled = False

        beta = beta_scheduler.get() if beta_scheduler is not None else args.beta[0]

        train_loss, train_acc1, train_acc5 = train_one_epoch(
            model, criterion_train, optimizer, data_loader, device, epoch,
            args.print_freq, scaler, args.T_train,
            train_snn_aug, train_trivalaug, mixup_fn, td=args.top_down, beta=beta)

        if beta_scheduler is not None:
            beta_scheduler.step()

        # Writers exist only on the main process when --tb is on.
        if train_tb_writer is not None:
            train_tb_writer.add_scalar('train_loss', train_loss, epoch)
            train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
            train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)
        lr_scheduler.step(epoch + 1)

        test_loss, test_acc1, test_acc5 = evaluate(model, args, criterion, data_loader_test, device=device, header='Test:', td=args.top_down)
        if te_tb_writer is not None and utils.is_main_process():
            te_tb_writer.add_scalar('test_loss', test_loss, epoch)
            te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
            te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)

        if output_dir:
            save_train_metrics = {"loss": train_loss}
            save_eval_metrics = {"loss": test_loss, "top1": round(test_acc1, 4), "top5": round(test_acc5, 4)}
            # Header only on the first epoch of this run (previously `None is None`,
            # i.e. repeated every epoch).
            update_summary(
                epoch, save_train_metrics, save_eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=epoch == args.start_epoch)

        if max_test_acc1 < test_acc1:
            max_test_acc1 = test_acc1
            test_acc5_at_max_test_acc1 = test_acc5
            save_max = True

        if output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args,
                'max_test_acc1': max_test_acc1,
                'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
            }

            if save_max:
                utils.save_on_master(
                    checkpoint,
                    os.path.join(output_dir, 'checkpoint_max_test_acc1.pth'))
        print(args)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))

        print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1, 'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1)
        print(output_dir)
    if output_dir and checkpoint is not None:
        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

    return max_test_acc1


def validate_adv(model, loader, loss_fn, args,adv_fn,
                 amp_autocast=suppress,iteration = None,alpha = None, eps = None,iters = None,td = False):
    '''
    Measure the accuracy of ``model`` on adversarial examples produced by ``adv_fn``.

    :param adv_fn: attack callable (e.g. ``adv.fgsm`` / ``adv.pgd``) called with keyword
        arguments ``model, images, labels, loss_fn, eps, td`` and, when ``iters`` is given, ``iters``
    :param amp_autocast: context-manager factory wrapping attack + forward (default: no-op)
    :param iteration: accepted for call compatibility; unused
    :param alpha: accepted for call compatibility; unused
    :param eps: attack budget forwarded to ``adv_fn``
    :param iters: ``None`` selects the single-step (FGSM-style) call, otherwise the
        iterative (PGD-style) call with this many iterations
    :param td: top-down flag; when truthy the model is run twice, the second pass
        conditioned on the feedback returned by the first. NOTE(review): this two-pass
        protocol differs from evaluate()'s time-major feedback-2 protocol — confirm it
        matches the model's current forward signature.
    :return: average top-1 accuracy over ``loader``
    '''
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.eval()
    last_idx = len(loader) - 1
    # Decide the mode once. The feedback tensor gets its own name; previously it rebound
    # `td`, so later batches evaluated `if td:` on a tensor and passed it to adv_fn.
    use_td = bool(td)

    for batch_idx, (image, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        image = image.cuda().float()
        target = target.cuda()

        with amp_autocast():
            attack_kwargs = dict(model=model, images=image, labels=target,
                                 loss_fn=loss_fn, eps=eps, td=td)
            if iters is not None:  # PGD-style iterative attack
                attack_kwargs['iters'] = iters
            image = adv_fn(**attack_kwargs)

            # The attack needed gradients; the evaluation forward does not.
            with torch.no_grad():
                if use_td:
                    _, feedback, _ = model(image)
                    outputs, _ = model(image, td=feedback)
                else:
                    outputs = model(image)

            functional.reset_net(model)

            acc1, acc5 = accuracy(outputs, target, topk=(1, 5))

            top1.update(acc1.item(), outputs.size(0))
            top5.update(acc5.item(), outputs.size(0))

            if last_batch or batch_idx % 40 == 0:
                log_name = 'Test'
                print(
                    '{0}: [{1:>4d}/{2}]  '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,top1=top1, top5=top5))

    return top1.avg



if __name__ == "__main__":
    # `args` must remain a module-level global: the training code in this module
    # reads the global `args` rather than taking it as a parameter.
    args = parse_args()
    # Print tensors in full (no "..." summarisation) so logged matrices are complete.
    torch.set_printoptions(threshold=float('inf'))
    main(args)

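'''
Legacy notes / example invocations. They appear to come from a different training
setup: the path points at a DVS128Gesture dataset, while this script loads CIFAR10-DVS.
As written, argparse would reject these command lines: `--lr-step-size` is not defined
by parse_args, and `--epoch` is an ambiguous prefix of `--epochs` / `--epoch-repeats`.

/raid/wfang/datasets/DVS128Gesture

python train_imagenet.py --tb --amp --output-dir ./logs --model PlainNet --device cuda:0 --lr-step-size 64 --epoch 192 --T_train 12 --T 16 --data-path /raid/wfang/datasets/DVS128Gesture

python train_imagenet.py --tb --amp --output-dir ./logs --model SEWResNet --connect_f ADD --device cuda:0 --lr-step-size 64 --epoch 192 --T_train 12 --T 16 --data-path /raid/wfang/datasets/DVS128Gesture

'''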
