import argparse
import os

import torch
import torch.nn as nn
from data_loaders import *  # noqa: F403 -- presumably provides build_dvscifar10 (used below)
from torch.utils.data import DataLoader
from models.vgg import vggsnn
from functions import TET_loss, seed_all, get_logger, split_weights
from pack.util import save_checkpoint

# Default GPU selection for this experiment; an explicit CUDA_VISIBLE_DEVICES in
# the caller's environment takes precedence instead of being clobbered.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,2,3")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('-----------DVSCIFAR10 vggsnn 32w32u------------')


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--TET False`` would silently *enable* TET. This converter keeps the
    ``--TET True`` / ``--TET False`` command-line form working correctly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='PyTorch Temporal Efficient Training')
parser.add_argument('-j',
                    '--workers',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--seed',
                    default=3407,
                    type=int,
                    metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--epochs',
                    default=300,
                    type=int,
                    metavar='N',
                    help='number of total epochs to run (default: %(default)s)')
parser.add_argument('--b',
                    default=64,  # 64 for DVS data; 256 is typical for static datasets
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--means',
                    default=1.0,
                    type=float,
                    metavar='N',
                    help='make all the potential increment around the means (default: %(default)s)')
parser.add_argument('-lr',
                    '--learning_rate',
                    default=1e-1,
                    type=float,
                    metavar='LR',
                    help='initial learning rate (default: %(default)s)',
                    dest='lr')
parser.add_argument('--time',
                    default=10,
                    type=int,
                    metavar='N',
                    help='snn simulation time (default: %(default)s)')
parser.add_argument('--TET',
                    default=False,
                    type=_str2bool,
                    metavar='BOOL',
                    help='if use Temporal Efficient Training (default: %(default)s)')
parser.add_argument('--lamb',
                    default=1e-3,
                    type=float,
                    metavar='N',
                    help='adjust the norm factor to avoid outlier (default: %(default)s)')
parser.add_argument('--dataset',
                    default='DVSCIFAR10',
                    type=str,
                    help='dataset name',
                    choices=['CIFAR10', 'CIFAR100', 'ImageNet', 'TinyImageNet', 'DVSCIFAR10'])
parser.add_argument('--arch',
                    default='vggsnn',
                    type=str,
                    help='model',
                    choices=['res19', 'res18', 'vgg16', 'vggsnn'])
args = parser.parse_args()


def train(model, device, train_loader, criterion, optimizer, epoch, args):
    """Run one training epoch and return the summed loss and the accuracy.

    Args:
        model: network called as ``model(images)``. Its output is averaged over
            dim 1 to get per-sample logits (presumably dim 1 is the SNN time
            axis -- TODO confirm against the model definition).
        device: device that every batch is moved to.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss applied to the time-averaged logits, or passed to
            ``TET_loss`` when ``args.TET`` is true.
        optimizer: optimizer that is zeroed and stepped once per batch.
        epoch: current epoch index. Unused here; kept for interface stability.
        args: namespace providing ``TET`` (bool) and ``lamb`` (float).

    Returns:
        ``(running_loss, accuracy)``. ``running_loss`` is the SUM of the
        per-batch losses (not a mean). ``accuracy`` is a percentage computed
        from the time-averaged logits.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    running_loss = 0
    total = 0
    correct = 0

    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        labels = labels.to(device)
        images = images.to(device)
        outputs = model(images)
        mean_out = outputs.mean(1)
        if args.TET:
            loss = TET_loss(outputs, labels, criterion, 1.0, args.lamb)
        else:
            loss = criterion(mean_out, labels)
        # loss holds a single element (``.item()`` requires it), so it can be
        # back-propagated directly without an extra ``.mean()``.
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        total += float(labels.size(0))
        # Score on the same device as the labels; ``detach`` keeps the argmax
        # out of the autograd graph and avoids a per-batch host copy.
        _, predicted = mean_out.detach().max(1)
        correct += float(predicted.eq(labels).sum().item())

    if total == 0:
        raise ValueError('train_loader yielded no batches')
    return running_loss, 100 * correct / total


@torch.no_grad()
def test(model, test_loader, device):
    correct = 0
    total = 0
    model.eval()
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        mean_out = outputs.mean(1)
        _, predicted = mean_out.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets).sum().item())
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
    final_acc = 100 * correct / total
    return final_acc


if __name__ == '__main__':
    # This script only builds DVS-CIFAR10 and VGGSNN. Reject other --dataset /
    # --arch values instead of silently training the wrong configuration.
    if args.dataset != 'DVSCIFAR10' or args.arch != 'vggsnn':
        parser.error('this script only supports --dataset DVSCIFAR10 --arch vggsnn '
                     '(got --dataset {} --arch {})'.format(args.dataset, args.arch))

    seed_all(args.seed)

    trainset, testset = build_dvscifar10(T=args.time)
    CLASSES = 10

    train_loader = DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.workers, pin_memory=True)

    model = vggsnn(num_classes=CLASSES)

    # Move to the selected device rather than calling .cuda() unconditionally,
    # so CPU-only hosts (where `device` falls back to "cpu") do not crash.
    parallel_model = torch.nn.DataParallel(model).to(device)
    print(model)

    criterion = nn.CrossEntropyLoss().to(device)
    # NOTE(review): split_weights(model) presumably builds parameter groups
    # (e.g. to exempt some params from weight decay) -- confirm in functions.py.
    optimizer = torch.optim.SGD(params=split_weights(model), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs)

    best_acc = 0
    best_epoch = 0  # 1-based epoch at which best_acc was reached

    logger = get_logger('dvscifar10-vggsnn-teacher.log')
    logger.info('start training!')

    for epoch in range(args.epochs):

        loss, acc = train(parallel_model, device, train_loader, criterion, optimizer, epoch, args)
        logger.info('Epoch:[{}/{}]\t loss={:.5f}\t acc={:.3f}'.format(epoch, args.epochs, loss, acc))
        scheduler.step()
        facc = test(parallel_model, test_loader, device)
        logger.info('Epoch:[{}/{}]\t Test acc={:.3f}'.format(epoch, args.epochs, facc))

        if best_acc < facc:
            best_acc = facc
            best_epoch = epoch + 1
            # Save the unwrapped module's weights (``.module``) so the
            # checkpoint does not depend on the DataParallel wrapper.
            save_checkpoint({
                'epoch': epoch,
                'state_dict': parallel_model.module.state_dict(),
                'best_top1_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, './checkpoints/', 'dvscifar10-vggsnn-teacher.pth.tar')
        logger.info('Best Test acc={:.3f}\t best epoch={}'.format(best_acc, best_epoch))
        print('\n')