# -*- coding: utf-8 -*-
import sys

sys.path.append('.')
sys.path.append('..')
sys.path.append('../..')

import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from util import setup_seed, Logger
from util import Bar, AverageMeter, accuracy, get_dataset, warp_decay, get_model_name
from model import *


def train(train_dataloader, optimizer, model, evaluator, args=None):
    """Run one training epoch and return the running top-1 accuracy and loss.

    Args:
        train_dataloader: Sized iterable yielding ``(X, y)`` batches.
        optimizer: Optimizer stepped (then zeroed) once per batch.
        model: Module called as ``model(X)`` returning ``(output, rate)``;
            its ``timestep`` attribute scales the loss before backward.
        evaluator: Loss callable applied as ``evaluator(rate, y)``.
        args: Optional namespace. When given, batches are moved to
            ``args.device``. When omitted, batches follow the device of the
            model's first parameter (or stay where they are if the model has
            no parameters).

    Returns:
        ``(top1.avg, losses.avg)`` accumulated over the epoch.
    """
    model.train()

    # Resolve the target device once; the signature advertises ``args`` as
    # optional, so the default must not be dereferenced blindly.
    if args is not None:
        device = args.device
    else:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    batch_time, data_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()

    num_batches = len(train_dataloader)
    bar = Bar('Processing', max=num_batches)
    for idx, (X, y) in enumerate(train_dataloader):
        # measure data loading time
        data_time.update(time.time() - end)

        if device is not None:
            X, y = X.to(device), y.to(device)
        # Only the second model output is used for the training loss.
        _, rate = model(X)
        loss = evaluator(rate, y)
        # Back-propagate the loss divided by the model's timestep count; the
        # unscaled loss is what gets recorded below.
        scaled_loss = loss / model.timestep
        scaled_loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # measure accuracy and record loss (detached: metrics need no graph)
        batch_size = X.size(0)
        prec1, prec5 = accuracy(rate.detach(), y, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
            batch=idx + 1,
            size=num_batches,
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg,
            top5=top5.avg,
        )
        bar.next()
    bar.finish()

    return top1.avg, losses.avg


def test(val_dataloader, model, evaluator, args=None):
    """Evaluate the model on a dataloader without tracking gradients.

    Args:
        val_dataloader: Sized iterable yielding ``(X, y)`` batches.
        model: Module called as ``model(X)`` returning ``(output, rate)``;
            only ``output`` is used here.
        evaluator: Loss callable applied to the averaged output and ``y``.
        args: Optional namespace. When given, batches are moved to
            ``args.device``. When omitted, batches follow the device of the
            model's first parameter (or stay where they are if the model has
            no parameters).

    Returns:
        ``(top1.avg, losses.avg)`` accumulated over the dataloader.
    """
    model.eval()

    # Resolve the target device once; the signature advertises ``args`` as
    # optional, so the default must not be dereferenced blindly.
    if args is not None:
        device = args.device
    else:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    batch_time, data_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()

    end = time.time()
    num_batches = len(val_dataloader)
    bar = Bar('Processing', max=num_batches)
    with torch.no_grad():
        for idx, (X, y) in enumerate(val_dataloader):
            # measure data loading time (the progress line prints this meter,
            # so it has to be fed here just as in train())
            data_time.update(time.time() - end)

            if device is not None:
                X, y = X.to(device), y.to(device)

            output, _ = model(X)
            # Average over the leading axis of the raw output
            # (presumably the time-step axis -- TODO confirm against the model).
            avg_fr = output.mean(dim=0)
            loss = evaluator(avg_fr, y)

            # measure accuracy and record loss
            batch_size = X.size(0)
            prec1, prec5 = accuracy(avg_fr, y, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                batch=idx + 1,
                size=num_batches,
                data=data_time.avg,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            )
            bar.next()
        bar.finish()

    return top1.avg, losses.avg


def main():
    """Train and validate a model described by the module-level ``args``.

    Reads the global ``args`` namespace (bound by the ``__main__`` guard),
    builds the dataloaders, model, optimizer and LR scheduler, optionally
    resumes from ``args.resume``, then alternates ``train`` and ``test`` for
    ``args.num_epoch`` epochs. Whenever the validation accuracy improves, a
    checkpoint is written to ``<args.ckpt_path>/<model_name>.pth``.

    Side effects on ``args``: ``log_path`` is extended with the model name,
    ``device`` is replaced by the resolved device string, ``alpha`` is
    replaced by its reciprocal and ``start_epoch`` is set.

    Raises:
        NotImplementedError: For an unsupported ``args.optim`` or
            ``args.scheduler`` value.
    """
    setup_seed(args.seed)
    dtype = torch.float
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    # train()/test() move every batch to ``args.device``. Keep it in step with
    # the CPU fallback above, otherwise inputs would be sent to CUDA while the
    # model sits on the CPU.
    args.device = str(device)

    model_name = get_model_name('', args)

    args.log_path = os.path.join(args.log_path, model_name)
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(args.log_path, exist_ok=True)
    log = Logger(args, args.log_path)
    log.info_config(args)

    checkpoint_file = os.path.join(args.ckpt_path, model_name + '.pth')
    checkpoint_dir = os.path.dirname(checkpoint_file)
    if checkpoint_dir:
        # torch.save does not create directories; do it up front so the first
        # improving epoch cannot fail on a missing folder.
        os.makedirs(checkpoint_dir, exist_ok=True)

    start_epoch = 0
    best_epoch = 0
    best_acc = 0
    train_trace, val_trace = dict(), dict()
    # 'temp' is never appended to in this function; it is kept so the layout
    # of the checkpointed traces stays unchanged.
    train_trace['acc'], train_trace['loss'], train_trace['temp'] = [], [], []
    val_trace['acc'], val_trace['loss'] = [], []

    train_data, val_data, num_class = get_dataset(args.dataset, args.data_path, cutout=args.cutout,
                                                  auto_aug=args.auto_aug)
    train_dataloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=args.train_batch_size, shuffle=True,
                                                   pin_memory=True, num_workers=args.num_workers)
    val_dataloader = torch.utils.data.DataLoader(dataset=val_data, batch_size=args.val_batch_size, shuffle=False,
                                                 pin_memory=True, num_workers=args.num_workers)
    decay = nn.Parameter(warp_decay(args.decay))
    thresh = nn.Parameter(torch.tensor(args.thresh)) if args.train_thresh else args.thresh
    # Stored back on args and not read again in this function -- presumably
    # consumed by code sharing the same namespace (TODO confirm).
    args.alpha = 1 / args.alpha

    kwargs_spikes = {'timestep': args.T, 'vreset': 0.0, 'thresh': thresh,
                     'decay': decay, 'detach_reset': args.detach_reset, "rate_flag": args.rate_flag}
    # NOTE(security): eval() executes ``args.arch`` as Python source. This is
    # only acceptable while the configuration comes from a trusted operator.
    model = eval(args.arch + f'(num_classes={num_class}, **kwargs_spikes)')
    model.to(device, dtype)
    print(model)

    optim_name = args.optim.lower()
    if optim_name == 'sgdm':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    elif optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=False)
    else:
        raise NotImplementedError('Unsupported optimizer: {}'.format(args.optim))

    evaluator = torch.nn.CrossEntropyLoss()
    if args.resume is not None:
        state = torch.load(args.resume, map_location=device)
        model.load_state_dict(state['best_net'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['best_epoch']
        # Restore the epoch that produced best_acc as well; otherwise the final
        # report would pair the restored accuracy with epoch 0.
        best_epoch = state['best_epoch']
        best_acc = state['best_acc']
        train_trace = state['traces']['train']
        val_trace = state['traces']['val']
        log.info('Load checkpoint from epoch {}'.format(start_epoch))
        log.info('Best accuracy so far {}.'.format(best_acc))
        log.info('Test the checkpoint: {}'.format(test(val_dataloader, model, evaluator, args=args)))

    args.start_epoch = start_epoch
    scheduler_name = args.scheduler.lower()
    if scheduler_name == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.num_epoch)
    elif scheduler_name == 'step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestone, gamma=0.8)
    elif scheduler_name == 'none':
        # The epoch loop has always guarded against a 'None' scheduler, but
        # that value used to be rejected here; accept it as "constant LR".
        scheduler = None
    else:
        raise NotImplementedError('Unsupported scheduler: {}'.format(args.scheduler))

    for epoch in range(start_epoch, start_epoch + args.num_epoch):
        train_acc, train_loss = train(train_dataloader, optimizer, model, evaluator, args=args)
        if scheduler is not None:
            scheduler.step()
        val_acc, val_loss = test(val_dataloader, model, evaluator, args=args)

        if val_acc > best_acc:
            best_acc = val_acc
            best_train_acc = train_acc
            best_epoch = epoch
            # The traces are appended to only after this save, so a checkpoint
            # carries the metrics of the epochs before the one that produced it.
            state = {
                'best_acc': best_acc,
                'best_epoch': epoch,
                'best_net': model.state_dict(),
                'best_train_acc': best_train_acc,
                'optimizer': optimizer.state_dict(),
                'traces': {'train': train_trace, 'val': val_trace},
            }
            torch.save(state, checkpoint_file)
        log.info(
            'Epoch %03d: train loss %.5f, test loss %.5f, train acc %.5f, test acc %.5f, Saved custom_model..  with acc %.5f in the epoch %03d' % (
                epoch, train_loss, val_loss, train_acc, val_acc, best_acc, best_epoch))

        # record and log
        train_trace['acc'].append(train_acc)
        train_trace['loss'].append(train_loss)
        val_trace['acc'].append(val_acc)
        val_trace['loss'].append(val_loss)

    log.info(
        'Finish training: the best validation accuracy of SNN is {} in epoch {}. \n The relate checkpoint path: {}'.format(
            best_acc, best_epoch, checkpoint_file))


if __name__ == '__main__':
    # Bind the configuration namespace to the module-level name `args`.
    # main() reads `args` as a global and forwards it to train()/test(), so
    # this import has to run before main() is called.
    from config.config import args

    main()
