# -*- coding: utf-8 -*-
import sys

sys.path.append('.')
sys.path.append('..')
sys.path.append('../..')

import csv
import time
import torch
from model.vgg import VGG
import os
import re
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from util.util import setup_seed, Logger
from torch.utils.tensorboard import SummaryWriter
from util.image_augment import CIFAR10Policy, Cutout
from torchvision import datasets, transforms, models
from util.data import CIFAR10_DVS, CIFAR10_DVS_Aug


def TET_loss(outputs, labels, criterion, means, lamb):
    T = outputs.size(0)
    Loss_es = 0
    for t in range(T):
        Loss_es += criterion(outputs[t, ...], labels)
    Loss_es = Loss_es / T  # L_TET
    if lamb != 0:
        MMDLoss = torch.nn.MSELoss()
        y = torch.zeros_like(outputs).fill_(means)
        Loss_mmd = MMDLoss(outputs, y)  # L_mse
    else:
        Loss_mmd = 0
    return (1 - lamb) * Loss_es + lamb * Loss_mmd  # L_Total


def TET_loss_rate(outputs, rate, labels, criterion, means, lamb):
    T = outputs.size(0)
    Loss_es = 0
    # out_list = []
    for t in range(T):
        outputs[t] = outputs[t].detach() + rate - rate.detach()
    # out = torch.stack(out_list, dim=0)
    for t in range(T):
        Loss_es += criterion(outputs[t, ...], labels)
    Loss_es = Loss_es / T  # L_TET
    if lamb != 0:
        MMDLoss = torch.nn.MSELoss()
        y = torch.zeros_like(outputs).fill_(means)
        Loss_mmd = MMDLoss(outputs, y)  # L_mse
    else:
        Loss_mmd = 0
    return ((1 - lamb) * Loss_es + lamb * Loss_mmd), ((1 - lamb) * Loss_es + lamb * Loss_mmd).detach()  # L_Total


def runTrain(epoch, train_ldr, optimizer, model, evaluator, args=None, encoder=None):
    loss_record = []
    predict_tot = []
    label_tot = []
    model.train()
    start_time = time.time()
    for idx, (ptns, labels) in enumerate(train_ldr):
        ptns, labels = ptns.to(args.device), labels.to(args.device)
        if encoder is not None:
            ptns = encoder(ptns)
        # optimizer.zero_grad()
        if isinstance(optimizer, list):
            for optim in optimizer:
                optim.zero_grad()
        else:
            optimizer.zero_grad()
        output, rate = model(ptns)
        if not args.rate_flag:
            loss = TET_loss(output, labels, evaluator, args.means, args.lamb)
            loss.backward()
        else:
            if args.ce:
                loss = evaluator(rate, labels)
                loss.backward()
            else:
                loss2, loss = TET_loss_rate(output, rate, labels, evaluator, args.means, args.lamb)
                loss2.backward()
        # optimizer.step()
        if isinstance(optimizer, list):
            for optim in optimizer:
                optim.step()
        else:
            optimizer.step()
        predict = torch.argmax(output.mean(0), axis=1)
        # record results
        loss_record.append(loss.detach().cpu())
        predict_tot.append(predict)
        label_tot.append(labels)
        if (idx + 1) % args.log_interval == 0:
            log.info('\nEpoch [%d/%d], Step [%d/%d], Loss: %.5f'
                     % (
                         epoch, args.num_epoch + args.start_epoch, idx + 1,
                         len(train_ldr.dataset) // args.train_batch_size,
                         loss_record[-1] / args.train_batch_size))
            log.info('Time elasped: {}'.format(time.time() - start_time))
    predict_tot = torch.cat(predict_tot)
    label_tot = torch.cat(label_tot)
    train_acc = torch.mean((predict_tot == label_tot).float())
    train_loss = torch.tensor(loss_record).sum() / len(label_tot)
    return train_acc, train_loss


def runTest(val_ldr, model, evaluator, args=None, encoder=None):
    model.eval()
    with torch.no_grad():
        predict_tot = {}
        label_tot = []
        loss_record = []
        key = 'ann' if encoder is None else 'snn'
        for idx, (ptns, labels) in enumerate(val_ldr):

            ptns, labels = ptns.to(args.device), labels.to(args.device)
            if encoder is not None:
                ptns = encoder(ptns)
            output, rate = model(ptns)
            avg_fr = torch.mean(output, dim=0)
            if isinstance(output, dict):
                for t in output.keys():
                    if t not in predict_tot.keys():
                        predict_tot[t] = []
                    predict = torch.argmax(output[t], axis=1)
                    predict_tot[t].append(predict)
                loss = evaluator(output[encoder.nb_steps], labels)

            else:
                if key not in predict_tot.keys():
                    predict_tot[key] = []
                loss = evaluator(avg_fr, labels)

                predict = torch.argmax(output.mean(0), axis=1)
                predict_tot[key].append(predict)
            loss_record.append(loss)
            label_tot.append(labels)

        label_tot = torch.cat(label_tot)
        val_loss = torch.tensor(loss_record).sum() / len(label_tot)
        if 'ann' not in predict_tot.keys() and 'snn' not in predict_tot.keys():
            val_acc = {}
            for t in predict_tot.keys():
                val_acc[t] = torch.mean((torch.cat(predict_tot[t]) == label_tot).float())

        else:
            predict_tot = torch.cat(predict_tot[key])
            val_acc = torch.mean((predict_tot == label_tot).float())
        return val_acc, val_loss


def loadData(name, root, cutout=False, auto_aug=False):
    num_class, normalize, train_data, test_data = None, None, None, None
    train_transform = []
    if name == 'CIFAR10' or name == 'CIFAR100':
        train_transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if auto_aug:
        train_transform.append(CIFAR10Policy())
    train_transform.append(transforms.ToTensor())
    if cutout:
        train_transform.append(Cutout(n_holes=1, length=16))
    if name == 'CIFAR10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        num_class = 10
    elif name == 'CIFAR100':
        normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        num_class = 100

    train_transform.append(normalize)
    train_transform = transforms.Compose(train_transform)
    val_transform = transforms.Compose([transforms.ToTensor(),
                                        normalize
                                        ])
    if name == 'CIFAR10_DVS_Aug':
        train_path = root + '/train'
        val_path = root + '/test'
        train_data = CIFAR10_DVS_Aug(root=train_path, transform=True)
        val_data = CIFAR10_DVS_Aug(root=val_path)
        num_class = 10
    else:
        raise NotImplementedError()
    return train_data, val_data, num_class


def warp_decay(decay):
    import math
    return torch.tensor(math.log(decay / (1 - decay)))


def main():
    start_epoch = 0
    best_epoch = 0
    best_acc = 0
    train_trace, val_trace = dict(), dict()
    train_trace['acc'], train_trace['loss'], train_trace['temp'] = [], [], []
    val_trace['acc'], val_trace['loss'] = [], []
    writer = SummaryWriter(args.log_path)
    train_data, val_data, num_class = loadData(args.dataset, args.data_path, cutout=args.cutout, auto_aug=args.auto_aug)
    train_ldr = torch.utils.data.DataLoader(dataset=train_data, batch_size=args.train_batch_size, shuffle=True,
                                            pin_memory=True, num_workers=args.num_workers)
    val_ldr = torch.utils.data.DataLoader(dataset=val_data, batch_size=args.val_batch_size, shuffle=False,
                                          pin_memory=True,
                                          num_workers=args.num_workers)
    # decay = nn.Parameter(torch.tensor(args.decay)) if args.train_decay else args.decay
    decay = nn.Parameter(warp_decay(args.decay)) if args.train_decay else warp_decay(args.decay)
    thresh = nn.Parameter(torch.tensor(args.thresh)) if args.train_thresh else args.thresh
    args.alpha = 1 / args.alpha

    kwargs_spikes = {'timestep': args.T, 'vreset': 0.0, 'thresh': thresh,
                     'decay': decay, 'detach_reset': args.detach_reset, "rate_flag": args.rate_flag}

    if 'vgg' in args.arch.lower():
        model = VGG(architecture=args.arch, use_bias=args.bias, in_channel=2, bn_type=args.bn_type, **kwargs_spikes).to(
            device, dtype)
    else:
        raise NotImplementedError()

    log.info(model)

    if args.optim.lower() == 'sgdm':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    elif args.optim.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=False)
    else:
        raise NotImplementedError()

    evaluator = torch.nn.CrossEntropyLoss()
    if args.resume is not None:
        state = torch.load(args.resume, map_location=device)
        model.load_state_dict(state['best_net'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['best_epoch']
        best_acc = 0
        train_trace = state['traces']['train']
        val_trace = state['traces']['val']
        log.info('Load checkpoint from epoch {}'.format(start_epoch))
        log.info('Best accuracy so far {}.'.format(best_acc))
        log.info('Test the checkpoint: {}'.format(runTest(val_ldr, model, evaluator, args=args)))

    args.start_epoch = start_epoch

    if args.scheduler.lower() == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.num_epoch)
    elif args.scheduler.lower() == 'step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestone, gamma=0.8)
    else:
        raise NotImplementedError()

    for epoch in tqdm(range(start_epoch, start_epoch + args.num_epoch)):
        train_acc, train_loss = runTrain(epoch, train_ldr, optimizer, model, evaluator, args=args)
        if args.scheduler != 'None':
            scheduler.step()
        val_acc, val_loss = runTest(val_ldr, model, evaluator, args=args)

        # saving checkpoint
        log.info('Epoch %d: train loss %.5f, train acc %.5f, test acc %.5f ' % (epoch, train_loss, train_acc, val_acc))
        if val_acc > best_acc:
            best_acc = val_acc
            best_train_acc = train_acc
            best_epoch = epoch
            log.info('Saving custom_model..  with acc {0} in the epoch {1}'.format(best_acc, epoch))
            state = {
                'best_acc': best_acc,
                'best_epoch': epoch,
                'best_net': model.state_dict(),
                'best_train_acc': best_train_acc,
                'optimizer': optimizer.state_dict(),
                'traces': {'train': train_trace, 'val': val_trace},
                'config': args
            }
            torch.save(state, os.path.join(args.ckpt_path, model_name + '.pth'))
        # record and log
        train_trace['acc'].append(train_acc)
        train_trace['loss'].append(train_loss)
        val_trace['acc'].append(val_acc)
        val_trace['loss'].append(val_loss)

        # record in tensorboard
        writer.add_scalars('Loss', {'val': val_loss, 'train': train_loss},
                           epoch + 1)
        writer.add_scalars('Acc', {'val': val_acc, 'train': train_acc},
                           epoch + 1)

    log.info('Finish training: the best training accuracy is {} in epoch {}. \n The relate checkpoint path: {}'.format(
        best_acc, best_epoch, os.path.join(args.ckpt_path, model_name + '.pth')))
    if args.save_last:
        state = {
            'best_acc': best_acc,
            'best_epoch': epoch,
            'final_model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'traces': {'train': train_trace, 'val': val_trace},
            'config': args
        }
        torch.save(state, os.path.join(args.ckpt_path, model_name + '_final.pth'))
    if args.grid_search is not None:
        write_head = not os.path.exists(args.grid_search)
        with open(args.grid_search, 'a+') as f:
            args.cmd = ['python ' + os.path.basename(__file__)] + args.cmd
            writer = csv.DictWriter(f, fieldnames=['custom_model name', 'best acc', 'cmd'])
            if write_head:
                writer.writeheader()
            writer.writerow({'best acc': best_acc, 'custom_model name': model_name, 'cmd': ' '.join(args.cmd)})


def get_model_name(model_name, args):
    aug_str = '_'.join(['cut' if args.cutout else ''] + ['aug' if args.auto_aug else ''])
    if aug_str[0] != '_': aug_str = '_' + aug_str
    if aug_str[-1] != '_': aug_str = aug_str + '-'
    model_name += args.dataset.lower() + aug_str + 'snn' + '_t' + str(
        args.T) + '_' + args.arch.lower() + '_width_' + str(
        args.alpha) + '-opt_' + args.optim.lower() + (
                          '_' + args.bn_type) + ('_bias' if args.bias else '') + '_wd_' + str(args.wd)
    cas_num = len([one for one in os.listdir(args.log_path) if one.startswith(model_name)])
    model_name += '_cas_' + str(cas_num)
    print('Generate custom_model name: ' + model_name)
    return model_name


if __name__ == '__main__':
    # set random seed, device, data type
    from config.config import args

    setup_seed(args.seed)
    dtype = torch.float
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # runtime configures
    args.log_interval = 100
    # practical config for learning scheduler
    model_name = get_model_name('', args)

    # redirect the output
    args.log_path = os.path.join(args.log_path, model_name)
    if not os.path.exists(args.log_path):
        os.mkdir(args.log_path)
    log = Logger(args, args.log_path)
    log.info_config(args)
    # sys.stdout = Logger(os.path.join(args.log_path, model_name + '.log'))
    # sys.stderr = Logger(os.path.join(args.log_path, model_name + '.log'))
    main()
