import datetime
import os
import time
from collections import OrderedDict
import pandas as pd

import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import math
from torch.cuda import amp
import torch.distributed.optim
import argparse

from spikingjelly.clock_driven import functional
from q_module import *
from torch.autograd import Variable  # for lagrange multiplier
import q_sew_resnet, utils

# Seed every RNG the script touches from one constant so that runs are
# repeatable.  All generators share `_seed_`; change it here only.
_seed_ = 2020
import random
random.seed(_seed_)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)

'''
### train ###
python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet18 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path ./datasets/imagenet --quant bin --period 20 --device cuda
python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet18 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path ./datasets/imagenet --quant ter --period 20 --device cuda

python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet34 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path ./datasets/imagenet --quant bin --period 20 --device cuda
python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet34 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path ./datasets/imagenet --quant ter --period 20 --device cuda

### evaluate ###
python main_quantize_cbp.py --model sew_resnet18 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path ./datasets/imagenet --quant bin --device cuda:0 
python main_quantize_cbp.py --model sew_resnet18 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path ./datasets/imagenet --quant ter --device cuda:1
python main_quantize_cbp.py --model sew_resnet34 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path ./datasets/imagenet --quant bin --device cuda:0
python main_quantize_cbp.py --model sew_resnet34 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path ./datasets/imagenet --quant ter --device cuda:1
'''


'''
### train ###
python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet18 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path /home/jsm/dataset/imagenet/ --quant bin --period 20 --device cuda
python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet18 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path /home/jsm/dataset/imagenet/ --quant ter --period 20 --device cuda

python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet34 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path /home/jsm/dataset/imagenet/ --quant bin --period 20 --device cuda
python -m torch.distributed.launch --nproc_per_node=2 --use_env main_quantize_cbp.py --model sew_resnet34 -b 32 --output-dir ./logs --tb --print-freq 128 --amp --cache-dataset --connect_f ADD --T 4 --lr 0.1 --lr_lambda 0.01 --epochs 100 --data-path /home/jsm/dataset/imagenet/ --quant ter --period 20 --device cuda

### evaluate ###
python main_quantize_cbp.py --model sew_resnet18 -b 50 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant bin --device cuda:0 (54.298)
python main_quantize_cbp.py --model sew_resnet18 -b 50 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant ter --device cuda:1 (57.982)
python main_quantize_cbp.py --model sew_resnet34 -b 50 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant bin --device cuda:0 (59.964)
python main_quantize_cbp.py --model sew_resnet34 -b 50 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant ter --device cuda:1 (62.916)

python main_quantize_cbp.py --model sew_resnet18 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant bin --device cuda:0 (54.338)
python main_quantize_cbp.py --model sew_resnet18 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant ter --device cuda:1 (57.88)
python main_quantize_cbp.py --model sew_resnet34 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant bin --device cuda:0 (60.228)
python main_quantize_cbp.py --model sew_resnet34 --test-only --output-dir ./logs --print-freq 1024 --cache-dataset --connect_f ADD --T 4 --data-path /home/jsm/dataset/imagenet/ --quant ter --device cuda:1 (62.976)
'''

### Utils for applying CBP ###

def getparameters(model):
    """Split the parameters of ``model`` into the groups used by CBP training.

    Walks ``model.modules()`` once and sorts tensors by the type of the module
    that owns them:

    * ``QConv2d`` / ``QLinear``: the weight goes to ``qweight`` and receives a
      zero-initialised Lagrange multiplier of the same shape; the module's
      ``scale``, ``factor`` and ``b`` attributes are collected alongside it.
    * any other ``nn.Conv2d`` / ``nn.Linear``: the weight goes to ``nqweight``.
    * batch-norm layers (including ``SyncBatchNorm``): affine weight and bias
      go to ``otherparam``.

    Biases of convolution / linear layers, when present, go to ``otherparam``.

    Args:
        model: the network to inspect.

    Returns:
        Tuple ``(lamb, qweight, nqweight, otherparam, factor, b, scale,
        param_size)`` where every element but the last is a list and
        ``param_size`` is the total number of elements in ``qweight``.
    """
    lamb = []         # Lagrange multiplier
    qweight = []      # weight to be quantized
    nqweight = []     # weight not to be quantized such as first, last layer parameter
    otherparam = []   # otherparams such as batchnorm, bias
    factor = []       # factors of each quantized layers
    b = []            # b of each quantized layers (median)
    scale = []        # scale factor of each quantized layers (fixed)
    param_size = 0
    for p in model.modules():
        # The quantized layer types are tested first so that they are never
        # classified by the generic Conv2d/Linear branch below.
        if isinstance(p, (QConv2d, QLinear)):
            qweight.append(p.weight)
            # Create the multiplier on the weight's own device: a bare
            # `.cuda()` would land on the *current* CUDA device, which differs
            # from the model's device when e.g. `--device cuda:1` is used.
            lamb.append(torch.zeros(p.weight.shape, dtype=torch.float32,
                                    device=p.weight.device, requires_grad=True))
            if p.bias is not None:
                otherparam.append(p.bias)
            scale.append(p.scale)
            factor.append(p.factor)
            b.append(p.b)
            param_size += p.weight.numel()
        elif isinstance(p, (nn.Conv2d, nn.Linear)):
            nqweight.append(p.weight)
            if p.bias is not None:
                otherparam.append(p.bias)
        elif isinstance(p, nn.modules.batchnorm._BatchNorm):
            # `_BatchNorm` is the common base of BatchNorm1d/2d/3d and of
            # SyncBatchNorm; checking only BatchNorm1d/2d drops the parameters
            # of layers converted by `--sync-bn` from the optimizer.
            # Non-affine layers have weight/bias set to None and are skipped.
            if p.weight is not None:
                otherparam.append(p.weight)
            if p.bias is not None:
                otherparam.append(p.bias)
    return lamb, qweight, nqweight, otherparam, factor, b, scale, param_size

def updatelambda(optimizer2, qweight, lamb, scale, factor, b, ucs):
    """Take one gradient-ascent step on the Lagrange multipliers.

    The constraint term is summed over all quantized layers with the weights
    detached, so only the multipliers in ``lamb`` receive a gradient through
    the weights' path.  The sum is negated before ``backward`` so that the
    (minimising) optimizer step increases it.

    Args:
        optimizer2: optimizer that owns the tensors in ``lamb``.
        qweight, lamb, scale, factor, b: per-layer lists of equal length, as
            returned by ``getparameters``.
        ucs: value forwarded unchanged to ``constraints``.
    """
    if not lamb:
        # Nothing to update (and no device to allocate the accumulator on).
        return
    # Accumulate on the multipliers' device instead of the current CUDA device.
    const = torch.zeros(1, device=lamb[0].device)
    for weight, multiplier, layer_scale, layer_factor, layer_b in zip(qweight, lamb, scale, factor, b):
        const = const + constraints(weight.detach(), multiplier, layer_scale, layer_factor, layer_b, ucs) # weight detach
    optimizer2.zero_grad()
    # NOTE(review): retain_graph=True is kept from the original; the graph is
    # rebuilt on every call, so it is probably unnecessary - confirm against
    # the `constraint` autograd function before removing it.
    (-const).backward(retain_graph=True) # gradient ascent
    optimizer2.step()
    
def constraints(weight, lamb, scale, factor, b, ucs):
    """Return the multiplier-weighted constraint of one layer as a scalar.

    Applies the ``constraint`` autograd function to ``weight`` (with the
    layer's ``scale``, ``factor``, ``b`` and the ``ucs`` value), multiplies the
    result element-wise by ``lamb`` and sums every element.
    """
    per_element = constraint().apply(weight, scale, factor, b, ucs)
    weighted = per_element * lamb
    return weighted.sum()
    
def CFS(weight, size, scale, factor, b):
    """Return the summed constraint value of all layers divided by ``size``.

    Each layer's ``constraint`` is evaluated with ``ucs`` fixed to 1, its
    elements are summed, and the per-layer sums are added together.  The
    result is returned as a Python float.
    """
    layer_sums = (
        constraint().apply(layer_weight, layer_scale, layer_factor, layer_b, 1).sum()  # set ucs to 1
        for layer_weight, layer_scale, layer_factor, layer_b in zip(weight, scale, factor, b)
    )
    total = sum(layer_sums)
    return total.item() / size

def adjust_lr(optimizer, decrease_rate):
    """Multiply the learning rate of every parameter group by ``decrease_rate``.

    The groups of ``optimizer.param_groups`` are modified in place; nothing is
    returned.
    """
    for group in optimizer.param_groups:
        group['lr'] *= decrease_rate


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq,  qweight, lamb, scale, factor, b, ucs, scaler=None):
    """Train ``model`` for one epoch on the Lagrangian ``loss + constraints``.

    For every batch the task loss is computed with ``criterion`` and the
    constraint terms of all quantized layers (multipliers detached, so only
    the weights are trained here) are added to it.  After the optimizer step
    the quantized weights are clamped to ``[-1, 1]`` and the network state is
    reset with ``functional.reset_net``.

    Args:
        model, criterion, optimizer, data_loader, device: usual training objects.
        epoch: epoch index, used only in the log header.
        print_freq: logging interval passed to the metric logger.
        qweight, lamb, scale, factor, b: per-layer lists from ``getparameters``.
        ucs: value forwarded unchanged to ``constraints``.
        scaler: optional ``GradScaler``; when given, the forward pass runs
            under autocast and the scaler drives backward/step.

    Returns:
        ``(lagsum, loss_avg, acc1_avg, acc5_avg)``.  ``lagsum`` is a 1-element
        tensor holding the Lagrangian summed over the batches of this epoch
        (averaged over processes when ``torch.distributed`` is initialised);
        the other three are the metric logger's global averages.

    Raises:
        ValueError: if the task loss becomes NaN.
    """
    use_amp = scaler is not None

    lagsum = torch.zeros(1).to(device)
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)

        # autocast(enabled=False) is a no-op, so one code path serves both
        # the mixed-precision and the full-precision case.
        with amp.autocast(enabled=use_amp):
            output = model(image)
            loss = criterion(output, target)
            const = torch.zeros(1).to(device)
            for weight, multiplier, layer_scale, layer_factor, layer_b in zip(qweight, lamb, scale, factor, b):
                const = const + constraints(weight, multiplier.detach(), layer_scale, layer_factor, layer_b, ucs)
            lag = loss + const
            lagsum += lag.detach()

        optimizer.zero_grad()

        # NOTE(review): retain_graph=True is kept from the original; a fresh
        # graph is built every iteration, so it is probably unnecessary -
        # confirm against the `constraint` autograd function before removing.
        if use_amp:
            scaler.scale(lag).backward(retain_graph=True)
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so that clip_value=1 applies to the true
            # gradient values.  (Earlier AMP runs clipped the scaled
            # gradients, so their effective step sizes differ.)
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_value_(parameters=model.parameters(), clip_value=1)
            scaler.step(optimizer)
            scaler.update()
        else:
            lag.backward(retain_graph=True)
            torch.nn.utils.clip_grad_value_(parameters=model.parameters(), clip_value=1)
            optimizer.step()

        # Keep the quantized weights inside [-1, 1] after every update.
        for p in qweight:
            p.data.clamp_(min=-1, max=1)

        functional.reset_net(model)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        loss_s = loss.item()
        if math.isnan(loss_s):
            raise ValueError('loss is Nan')
        acc1_s = acc1.item()
        acc5_s = acc5.item()

        metric_logger.update(loss=loss_s, lr=optimizer.param_groups[0]["lr"])

        metric_logger.meters['acc1'].update(acc1_s, n=batch_size)
        metric_logger.meters['acc5'].update(acc5_s, n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    # Every process sees different batches, so its local lagsum differs.  The
    # caller decides from lagsum when to update the multipliers and the
    # learning rate; averaging it keeps that decision identical on all ranks.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(lagsum)
        lagsum /= torch.distributed.get_world_size()
    return lagsum, metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg



def evaluate(model, criterion, data_loader, device, print_freq=100, header='Test:'):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    The model is switched to eval mode, every batch is moved to ``device``,
    and the network state is reset with ``functional.reset_net`` after each
    forward pass.  Loss and top-1 / top-5 accuracy are accumulated in a metric
    logger, synchronised between processes and printed.

    Returns:
        ``(loss, acc1, acc5)`` as the metric logger's global averages.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    meters = metric_logger.meters
    with torch.no_grad():
        for batch_images, batch_targets in metric_logger.log_every(data_loader, print_freq, header):
            batch_images = batch_images.to(device, non_blocking=True)
            batch_targets = batch_targets.to(device, non_blocking=True)
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_targets)
            functional.reset_net(model)

            top1, top5 = utils.accuracy(logits, batch_targets, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            n_samples = batch_images.shape[0]
            metric_logger.update(loss=batch_loss.item())
            meters['acc1'].update(top1.item(), n=n_samples)
            meters['acc5'].update(top5.item(), n=n_samples)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    loss = metric_logger.loss.global_avg
    acc1 = metric_logger.acc1.global_avg
    acc5 = metric_logger.acc5.global_avg
    print(f' * Acc@1 = {acc1}, Acc@5 = {acc5}, loss = {loss}')
    return loss, acc1, acc5


def _get_cache_path(filepath):
    import hashlib
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path

def load_data(traindir, valdir, cache_dataset, distributed):
    """Build the training / validation ``ImageFolder`` datasets and samplers.

    When ``cache_dataset`` is set and a cache file for a directory already
    exists (see ``_get_cache_path``), the dataset object is read from it with
    ``torch.load``; otherwise it is built from the directory and, if
    ``cache_dataset`` is set, written to the cache via ``utils.save_on_master``.

    Args:
        traindir: root directory of the training images.
        valdir: root directory of the validation images.
        cache_dataset: whether to read / write the dataset cache files.
        distributed: selects ``DistributedSampler`` for both splits; otherwise
            a random sampler (train) and a sequential sampler (validation).

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    started = time.time()
    train_cache = _get_cache_path(traindir)
    train_cached = cache_dataset and os.path.exists(train_cache)
    if not train_cached:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        dataset = torchvision.datasets.ImageFolder(traindir, train_transform)
        if cache_dataset:
            print("Saving dataset_train to {}".format(train_cache))
            utils.mkdir(os.path.dirname(train_cache))
            utils.save_on_master((dataset, traindir), train_cache)
    else:
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(train_cache))
        dataset, _ = torch.load(train_cache)
    print("Took", time.time() - started)

    print("Loading validation data")
    val_cache = _get_cache_path(valdir)
    val_cached = cache_dataset and os.path.exists(val_cache)
    if not val_cached:
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        dataset_test = torchvision.datasets.ImageFolder(valdir, val_transform)
        if cache_dataset:
            print("Saving dataset_test to {}".format(val_cache))
            utils.mkdir(os.path.dirname(val_cache))
            utils.save_on_master((dataset_test, valdir), val_cache)
    else:
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(val_cache))
        dataset_test, _ = torch.load(val_cache)

    print("Creating data loaders")
    if not distributed:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    """Run CBP quantisation training, or evaluation when ``--test-only`` is set.

    Steps, in order:

    1. initialise distributed mode and (for training) create the run directory;
    2. build the datasets, samplers and data loaders;
    3. build the model named by ``args.model`` from ``q_sew_resnet`` and load
       the matching pre-trained checkpoint from ``./trained_params``;
    4. initialise each quantized layer's ``scale`` and collect the parameter
       groups with ``getparameters``;
    5. create ``optimizer1`` (network parameters) and ``optimizer2``
       (Lagrange multipliers), optionally restoring state from ``args.resume``;
    6. with ``--test-only``: load the pre-quantized weights, evaluate, return;
    7. otherwise train for ``args.epochs`` epochs.  After each epoch the
       multipliers, ``g``/``ucs`` and possibly the learning rate are updated
       when the epoch's Lagrangian sum did not decrease or ``args.period``
       epochs have passed since the last update; then the model is evaluated,
       progress is written to a text file and checkpoints are saved.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        NotImplementedError: if ``args.model`` is not provided by
            ``q_sew_resnet`` or has no pre-trained checkpoint registered here.
    """
    max_test_acc1 = 0.
    test_acc5_at_max_test_acc1 = 0.

    train_tb_writer = None
    te_tb_writer = None

    utils.init_distributed_mode(args)
    print(args)
    if not args.test_only:
        # The run directory name encodes the main hyper-parameters.
        output_dir = os.path.join(args.output_dir, f'{args.model}_b{args.batch_size}_lr{args.lr}_lr_lambda{args.lr_lambda}_T{args.T}_cnf_{args.connect_f}_P{args.period}_{args.quant}')

        if args.zero_init_residual:
            output_dir += '_zi'

        if args.weight_decay:
            output_dir += f'_wd{args.weight_decay}'

        utils.mkdir(output_dir)

    device = torch.device(args.device)

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset_train, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
                                                                   args.cache_dataset, args.distributed)
    print(f'dataset_train:{dataset_train.__len__()}, dataset_test:{dataset_test.__len__()}')

    data_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model")

    ### Define model ###
    if args.model in q_sew_resnet.__dict__:
        model = q_sew_resnet.__dict__[args.model](zero_init_residual=args.zero_init_residual, T=args.T, connect_f=args.connect_f, mode=args.quant)
    else:
        raise NotImplementedError(args.model)
    #print(model)

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    ### Get pre-trained weight ###
    pretrained_checkpoints = {
        'sew_resnet18': './trained_params/sew18_checkpoint_319.pth',  # from 'trained_params' directory
        'sew_resnet34': './trained_params/sew34_checkpoint_319.pth',  # from 'trained_params' directory
    }
    if args.model not in pretrained_checkpoints:
        # Fail with a clear message instead of an unbound local variable.
        raise NotImplementedError(f'no pre-trained checkpoint is registered for model {args.model!r}')
    checkpoint = torch.load(pretrained_checkpoints[args.model], map_location='cpu')

    new_state_dict = OrderedDict()
    for n, v in checkpoint['model'].items():
        # NOTE(review): DataParallel/DDP prefixes are normally "module."
        # (trailing dot); confirm that ".module" is the intended pattern.
        name = n.replace(".module", "")
        new_state_dict[name] = v

    # Overlay the checkpoint on the model's own state so that entries missing
    # from the checkpoint keep their freshly initialised values.
    model_state_dict = model.state_dict()
    model_state_dict.update(new_state_dict)
    model.load_state_dict(model_state_dict)
    print("Loading model")

    ### Initialization of scale ###
    for p in model.modules():
        if isinstance(p, (QConv2d, QLinear)):
            p.scale.data[0] = p.weight.abs().mean()

    ### Get parameters ###
    lamb, qweight, nqweight, otherparam, factor, b, scale, param_size = getparameters(model)

    criterion = nn.CrossEntropyLoss()

    ### Optimizer ###
    if args.adam:
        optimizer1 = torch.optim.Adam([{'params':qweight, 'lr':args.lr, 'weight_decay':args.weight_decay},
                                       {'params':nqweight, 'lr':args.lr, 'weight_decay':args.weight_decay},
                                       {'params':otherparam, 'lr':args.lr}])
    else:
        optimizer1 = torch.optim.SGD([{'params':qweight, 'lr':args.lr, 'weight_decay':args.weight_decay},
                                       {'params':nqweight, 'lr':args.lr, 'weight_decay':args.weight_decay},
                                       {'params':otherparam, 'lr':args.lr}], momentum=args.momentum)
    # The multipliers are always optimised with Adam.
    optimizer2 = torch.optim.Adam([{'params':lamb, 'lr':args.lr_lambda}])

    if args.amp:
        scaler = amp.GradScaler()
    else:
        scaler = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer1.load_state_dict(checkpoint['optimizer1'])
        optimizer2.load_state_dict(checkpoint['optimizer2'])
        args.start_epoch = checkpoint['epoch'] + 1
        # Copy the saved multipliers INTO the existing tensors.  Rebinding
        # `lamb` to the checkpoint's list would leave optimizer2 updating the
        # old tensors and hand CPU tensors (map_location='cpu') to training.
        with torch.no_grad():
            for current, saved in zip(lamb, checkpoint['lamb']):
                current.copy_(saved)
        lagsum_pre = checkpoint['lagsum_pre']
        period = checkpoint['period']
        ucs = checkpoint['ucs']
        g = checkpoint['g']
        progress = checkpoint['progress']
        max_test_acc1 = checkpoint['max_test_acc1']
        test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']

    if args.test_only:
        prequantized_checkpoints = {
            'sew_resnet18': './trained_params/SEW_ResNet_18_ImageNet_' + args.quant + '_cbp_prequantized.pth',  # from 'trained_params' directory
            'sew_resnet34': './trained_params/SEW_ResNet_34_ImageNet_' + args.quant + '_cbp_prequantized.pth',  # from 'trained_params' directory
        }
        # map_location='cpu' makes the file loadable regardless of the device
        # it was saved from; load_state_dict copies onto the model's device.
        model_state_dict = torch.load(prequantized_checkpoints[args.model], map_location='cpu')
        model.load_state_dict(model_state_dict)
        evaluate(model, criterion, data_loader_test, device=device, header='Test:')

        return

    if args.tb and utils.is_main_process():
        purge_step_train = args.start_epoch
        purge_step_te = args.start_epoch
        train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
        te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
        with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
            args_txt.write(str(args))

        print(f'purge_step_train={purge_step_train}, purge_step_te={purge_step_te}')

    if not args.resume:
        ### Initialization of unconstrained window ###
        g = 1
        ucs = 1-1/g

        ### Initial update of multiplier ###
        updatelambda(optimizer2, qweight, lamb, scale, factor, b, ucs)

        ### Save epoch, train_acc1, train_acc5, test_acc1, test_acc5, train loss, test loss, cfs, lagsum_pre ###
        progress = np.zeros((1,9))

        ### lagsum_max, period ###
        lagsum_pre = 1e10  # lagsum_max in algorithm 1
        period = 0

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        save_max = False
        if args.distributed:
            train_sampler.set_epoch(epoch)
        lagsum, train_loss, train_acc1, train_acc5 = train_one_epoch(model, criterion, optimizer1, data_loader, device, epoch, args.print_freq, qweight, lamb, scale, factor, b, ucs, scaler)
        # The writers exist only on the main process and only with --tb.
        if train_tb_writer is not None:
            train_tb_writer.add_scalar('train_loss', train_loss, epoch)
            train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
            train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)

        period += 1
        print(f'epoch={epoch}...lagsum_pre={lagsum_pre}...lagsum={lagsum}')

        if lagsum >= lagsum_pre or period == args.period:
            print('lambda update...')

            ### Update of unconstrained window ###
            if g<10:
                g+=1
            else:
                g+=10
            ucs = 1-1/g

            ### Learning rate scheduler ###
            if g==20:
                adjust_lr(optimizer1, 0.1)

            ### Update of lambda ###
            updatelambda(optimizer2, qweight, lamb, scale, factor, b, ucs)

            ### Reset lagsum and period ###
            lagsum_pre = 1e10
            period = 0
        else:
            lagsum_pre = lagsum.item()

        test_loss, test_acc1, test_acc5 = evaluate(model, criterion, data_loader_test, device=device, header='Test:')

        ### Calculate cfs ###
        cfs = CFS(qweight, param_size, scale, factor, b)

        ### Save data ###
        progress=np.append(progress, np.array([[epoch, train_acc1, train_acc5, test_acc1, test_acc5, train_loss, test_loss, cfs, lagsum_pre]]), axis=0)
        # Only one process writes the file; all ranks target the same path.
        if utils.is_main_process():
            progress_data=pd.DataFrame(progress)
            progress_data.to_csv(output_dir+f"/progress_{args.model}_{args.quant}.txt",
            index=False, header=False,sep='\t')

        if te_tb_writer is not None:
            te_tb_writer.add_scalar('test_loss', test_loss, epoch)
            te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
            te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)

        if max_test_acc1 < test_acc1:
            max_test_acc1 = test_acc1
            test_acc5_at_max_test_acc1 = test_acc5
            save_max = True

        checkpoint = {
            'model': model_without_ddp.state_dict(),
            'factor': factor,
            'optimizer1': optimizer1.state_dict(),
            'optimizer2': optimizer2.state_dict(),
            'epoch': epoch,
            'lamb': lamb,
            'lagsum_pre': lagsum_pre,
            'period': period,
            'ucs': ucs,
            'g': g,
            'progress': progress,
            'args': args,
            'max_test_acc1': max_test_acc1,
            'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
        }

        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, 'checkpoint_latest.pth'))

        # Keep a numbered snapshot every 16 epochs and at the final epoch.
        if epoch % 16 == 0 or epoch == args.epochs - 1:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

        if save_max:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, 'checkpoint_max_test_acc1.pth'))
        print(args)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(output_dir)

        print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
              'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1)

    # Flush and release the TensorBoard event files.
    for writer in (train_tb_writer, te_tb_writer):
        if writer is not None:
            writer.close()

def parse_args(argv=None):
    """Parse the command-line options of the training / evaluation script.

    Args:
        argv: optional list of argument strings.  When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, which is the behaviour of the
            script entry point; passing a list allows programmatic use.

    Returns:
        The ``argparse.Namespace`` with all options.
    """
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data-path', default='/home/wfang/datasets/ImageNet', help='dataset')
    parser.add_argument('--model', default='sew_resnet18', help='model')
    parser.add_argument('--gpu', default=0, type=int, help='GPU id to use')
    parser.add_argument('--device', default='cuda', help='device')

    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate of optimzer1')
    parser.add_argument('--lr_lambda', default=0.01, type=float, help='initial learning rate of optimizer2')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='Momentum for SGD. Adam will not use momentum')
    parser.add_argument('--wd', '--weight-decay', default=0, type=float,
                        metavar='W', help='weight decay (default: 0)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )

    # Mixed precision training parameters
    parser.add_argument('--amp', action='store_true',
                        help='Use AMP training')

    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--tb', action='store_true',
                        help='Use TensorBoard to record logs')
    parser.add_argument('--T', default=4, type=int, help='simulation steps')
    parser.add_argument('--adam', action='store_true',
                        help='Use Adam. The default optimizer is SGD.')

    parser.add_argument('--cos_lr_T', default=320, type=int,
                        help='T_max of CosineAnnealingLR.')
    # No default: the value is None unless given on the command line.
    parser.add_argument('--connect_f', type=str, help='spike-element-wise connect function')
    parser.add_argument('--zero_init_residual', action='store_true', help='zero init all residual blocks')

    # quantization parameters
    parser.add_argument('--quant', default='bin', type=str, help='quantization')
    parser.add_argument('--period', type=int, help='max period of lambda update', default=20)

    args = parser.parse_args(argv)
    return args


# Script entry point: parse the command-line options and run training or
# evaluation with them.
if __name__ == "__main__":
    args = parse_args()
    main(args)

