import csv
import os
import re
import time

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm import tqdm

from model.noisy_spike import NoisySpike, MutiStepNoisyRateScheduler
from model.Resnet import SResNetX
from model.VGG import VGG
from util.image_augment import CIFAR10Policy, Cutout
from util.util import setup_seed


def extractWidth(net):
    """Gather the ``sig.alpha`` value of every ``NoisySpike`` module in ``net``.

    Each value is detached from the autograd graph and moved to the CPU
    before being packed into one tensor, in ``net.modules()`` order.
    Presumably ``alpha`` is the surrogate "width" parameter the function
    name refers to -- confirm against ``NoisySpike``.

    Args:
        net: object exposing ``modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A tensor with one entry per ``NoisySpike`` module found; empty when
        there are none.
    """
    spike_layers = (layer for layer in net.modules() if isinstance(layer, NoisySpike))
    alphas = [layer.sig.alpha.detach().cpu() for layer in spike_layers]
    return torch.tensor(alphas)


def runTrain(epoch, train_ldr, optimizer, model, evaluator, args=None, encoder=None):
    """Run one training pass over ``train_ldr`` and report accuracy and loss.

    Args:
        epoch: index of the current epoch; only used in the progress message.
        train_ldr: iterable yielding ``(inputs, labels)`` batches. ``len()`` is
            taken of it when a progress message is printed.
        optimizer: optimizer whose ``zero_grad()``/``step()`` are called per batch.
        model: network being trained; switched to training mode here.
        evaluator: loss callable invoked as ``evaluator(output, labels)``.
        args: namespace providing ``device``, ``log_interval``, ``num_epoch``,
            ``start_epoch`` and ``train_batch_size``. Required in practice even
            though the signature defaults it to ``None``.
        encoder: optional callable applied to the inputs before the model.

    Returns:
        ``(train_acc, train_loss)`` as 0-d tensors: the fraction of correctly
        classified samples, and the sum of the per-batch loss values divided
        by the number of samples seen.
    """
    batch_losses = []
    predictions = []
    targets = []
    model.train()
    start_time = time.time()
    for idx, (ptns, labels) in enumerate(train_ldr):
        ptns, labels = ptns.to(args.device), labels.to(args.device)
        if encoder is not None:
            ptns = encoder(ptns)
        optimizer.zero_grad()
        output = model(ptns)
        loss = evaluator(output, labels)
        loss.backward()
        optimizer.step()
        # record results (loss is detached so the graph is not kept alive)
        batch_losses.append(loss.detach().cpu())
        predictions.append(torch.argmax(output, dim=1))
        targets.append(labels)
        if (idx + 1) % args.log_interval == 0:
            # len(train_ldr) is the real number of batches; the former
            # len(dataset) // batch_size under-counted by one whenever the
            # last batch was partial (e.g. "Step [391/390]").
            print('\nEpoch [%d/%d], Step [%d/%d], Loss: %.5f'
                  % (epoch, args.num_epoch + args.start_epoch, idx + 1, len(train_ldr),
                     batch_losses[-1] / args.train_batch_size))
            print('Time elasped:', time.time() - start_time)
    predictions = torch.cat(predictions)
    targets = torch.cat(targets)
    train_acc = torch.mean((predictions == targets).float())
    train_loss = torch.tensor(batch_losses).sum() / len(targets)
    return train_acc, train_loss


def runTest(val_ldr, model, evaluator, args=None, encoder=None):
    """Evaluate ``model`` on ``val_ldr`` without tracking gradients.

    Args:
        val_ldr: iterable yielding ``(inputs, labels)`` batches.
        model: network to evaluate; switched to eval mode here.
        evaluator: loss callable invoked as ``evaluator(output, labels)``.
        args: namespace providing ``device`` (required in practice).
        encoder: optional callable applied to the inputs before the model.
            When the model returns a dict, ``encoder.nb_steps`` selects the
            entry the loss is computed on, so an encoder is required then.

    Returns:
        ``(val_acc, val_loss)``. ``val_loss`` is the sum of the per-batch loss
        values divided by the number of samples. ``val_acc`` is a 0-d tensor
        for tensor outputs, or a dict mapping each output key to its accuracy
        when the model returns a dict.
    """
    model.eval()
    # Bucket name used for plain tensor outputs.
    single_key = 'snn' if encoder is not None else 'ann'
    preds_by_key = {}
    targets = []
    losses = []
    with torch.no_grad():
        for batch, target in val_ldr:
            batch, target = batch.to(args.device), target.to(args.device)
            if encoder is not None:
                batch = encoder(batch)
            out = model(batch)
            if isinstance(out, dict):
                # One prediction list per key of the returned dict.
                for step, step_out in out.items():
                    preds_by_key.setdefault(step, []).append(torch.argmax(step_out, dim=1))
                batch_loss = evaluator(out[encoder.nb_steps], target)
            else:
                preds_by_key.setdefault(single_key, []).append(torch.argmax(out, dim=1))
                batch_loss = evaluator(out, target)
            losses.append(batch_loss)
            targets.append(target)

        targets = torch.cat(targets)
        val_loss = torch.tensor(losses).sum() / len(targets)
        if 'ann' in preds_by_key or 'snn' in preds_by_key:
            hits = torch.cat(preds_by_key[single_key]) == targets
            val_acc = torch.mean(hits.float())
        else:
            val_acc = {}
            for step, step_preds in preds_by_key.items():
                val_acc[step] = torch.mean((torch.cat(step_preds) == targets).float())
        return val_acc, val_loss


def loadData(name, root, cutout=False, auto_aug=False):
    """Build the training and validation datasets for a supported benchmark.

    The training pipeline is, in order: random crop + horizontal flip (CIFAR
    only), ``CIFAR10Policy`` (if ``auto_aug``), ``ToTensor``, ``Cutout`` (if
    ``cutout``), then the per-dataset normalization. The validation pipeline
    is ``ToTensor`` followed by the same normalization.

    Args:
        name: one of ``'CIFAR10'``, ``'CIFAR100'`` or ``'MNIST'``.
        root: directory handed to the torchvision dataset; data is downloaded
            there when missing (``download=True``).
        cutout: add ``Cutout(n_holes=1, length=16)`` to the training pipeline.
        auto_aug: add ``CIFAR10Policy()`` to the training pipeline. It is
            applied for every dataset name, not only the CIFAR ones.

    Returns:
        ``(train_data, val_data, num_class)``.

    Raises:
        ValueError: if ``name`` is not a supported dataset. Previously this
            surfaced as an ``UnboundLocalError`` at the return statement.
    """
    supported = ('CIFAR10', 'CIFAR100', 'MNIST')
    if name not in supported:
        raise ValueError('Unsupported dataset %r; expected one of %s' % (name, ', '.join(supported)))

    train_transform = []
    if name in ('CIFAR10', 'CIFAR100'):
        train_transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if auto_aug:
        train_transform.append(CIFAR10Policy())
    train_transform.append(transforms.ToTensor())
    if cutout:
        # Cutout is placed after ToTensor, i.e. it receives the tensor image.
        train_transform.append(Cutout(n_holes=1, length=16))

    # Per-dataset normalization statistics, class count and dataset class.
    if name == 'CIFAR10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        num_class = 10
        dataset_cls = datasets.CIFAR10
    elif name == 'CIFAR100':
        normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        num_class = 100
        dataset_cls = datasets.CIFAR100
    else:  # 'MNIST'
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        num_class = 10
        dataset_cls = datasets.MNIST

    train_transform.append(normalize)
    train_transform = transforms.Compose(train_transform)
    val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_data = dataset_cls(root=root, train=True, download=True, transform=train_transform)
    val_data = dataset_cls(root=root, train=False, download=True, transform=val_transform)
    return train_data, val_data, num_class


def main(args):
    """Train and validate a model as configured by ``args``.

    Steps: build the datasets and loaders, construct the network (VGG or
    SResNet), optionally restore a checkpoint, then run ``num_epoch`` epochs
    of training and validation. Metrics are logged to TensorBoard, the model
    with the best validation accuracy is checkpointed, and optionally the
    final model is saved and a row is appended to a grid-search CSV file.

    Configuration is read in two ways: as attributes of ``args`` and as bare
    module-level names. The bare names (``log_path``, ``dataset``,
    ``init_lr``, ...) are injected by ``globals().update(vars(args))``;
    ``momentum`` and ``model_name`` are read as bare names too and are set by
    the script entry point at the bottom of this file.

    Args:
        args: configuration namespace. ``start_epoch`` and ``start_acc`` are
            overwritten here.

    Raises:
        ValueError: if ``args.architecture`` or ``optim_name`` is not one of
            the supported values (previously an ``UnboundLocalError``).
    """
    # Expose every field of ``args`` as a module-level name; the bare names
    # used below depend on this.
    globals().update(vars(args))
    # initial variables both for record and visualization
    args.start_epoch = start_epoch = 0
    args.start_acc = best_acc = 0
    train_trace, val_trace = dict(), dict()
    train_trace['acc'], train_trace['loss'] = [], []
    val_trace['acc'], val_trace['loss'] = [], []
    if args.train_spike_record:
        train_trace['spike_acc'] = []
        train_trace['spike_loss'] = []
    width_trace = []
    writer = SummaryWriter(log_path)
    train_data, val_data, num_class = loadData(dataset, data_path, cutout=cutout, auto_aug=auto_aug)
    train_ldr = torch.utils.data.DataLoader(dataset=train_data, batch_size=train_batch_size, shuffle=True,
                                            pin_memory=False, num_workers=0)
    val_ldr = torch.utils.data.DataLoader(dataset=val_data, batch_size=val_batch_size, shuffle=False, pin_memory=False,
                                          num_workers=0)
    act_fun = NoisySpike(p=noisy_rate)
    kwargs_spikes = {'nb_steps': args.T, 'Act': act_fun, 'trainable_thresh': False,
                     'trainable_decay': train_decay}
    if args.architecture == 'CIFAR':
        model = VGG(vgg_name=architecture, labels=num_class, dataset=dataset, **kwargs_spikes)
        model.to(device, dtype)
    elif 'SResNet' in args.architecture:
        # The first run of digits in the architecture name is the depth.
        depth = int(re.findall(r"\d+", args.architecture)[0])
        model = SResNetX(depth, num_class=num_class, **kwargs_spikes).to(device, dtype)
    else:
        raise ValueError('Unsupported architecture: {!r}'.format(args.architecture))
    evaluator = torch.nn.CrossEntropyLoss()
    if pretrain is not None:
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(pretrain, map_location=device)
        model.load_state_dict(state['best_net'])
        args.start_epoch = start_epoch = state['best_epoch']
        train_trace = state['traces']['train']
        val_trace = state['traces']['val']
        best_acc = state['best_acc']
        print('The record best acc: ', best_acc)
        print('Validation for the pretrained model: ', runTest(val_ldr, model, evaluator, args))

    print(model)
    if optim_name == 'SGDM':
        optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=momentum, weight_decay=weight_decay)
    elif optim_name == 'ADAM':
        optimizer = optim.Adam(model.parameters(), lr=init_lr, amsgrad=True, weight_decay=weight_decay)
    else:
        raise ValueError('Unsupported optimizer: {!r}'.format(optim_name))

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epoch)
    rate_scheduler = MutiStepNoisyRateScheduler(init_p=noisy_rate, reduce_ratio=reduce_ratio, num_epoch=num_epoch,
                                                start_epoch=start_epoch, milestones=milestones)
    best_epoch = 0
    # NOTE(review): this first entry is a float (the mean) while the per-epoch
    # entries appended in the loop are full tensors -- confirm which is intended.
    width_trace.append(extractWidth(model).mean().item())
    best_ckpt_file = os.path.join(ckpt_path, model_name + '.pth')
    for epoch in tqdm(range(start_epoch, start_epoch + num_epoch)):
        train_acc, train_loss = runTrain(epoch, train_ldr, optimizer, model, evaluator, args=args)
        scheduler.step()
        val_acc, val_loss = runTest(val_ldr, model, evaluator, args=args)
        width = extractWidth(model)
        print('validation record:', val_loss, val_acc)
        rate_scheduler(epoch, model)
        # saving checkpoint; the traces are stored before this epoch's
        # metrics are appended below.
        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            print('Saving model..  with acc {0} in the epoch {1}'.format(best_acc, epoch))
            state = {
                'best_acc': best_acc,
                'best_epoch': epoch,
                'best_net': model.state_dict(),
                'traces': {'train': train_trace, 'val': val_trace, 'width': width_trace},
                'config': args
            }
            torch.save(state, best_ckpt_file)
        # record and log
        train_trace['acc'].append(train_acc)
        train_trace['loss'].append(train_loss)
        val_trace['acc'].append(val_acc)
        val_trace['loss'].append(val_loss)
        if args.train_spike_record:
            # Re-evaluate on the training set in eval mode.
            spike_train_acc, spike_train_loss = runTest(train_ldr, model, evaluator, args=args)
            train_trace['spike_acc'].append(spike_train_acc)
            train_trace['spike_loss'].append(spike_train_loss)
        width_trace.append(width)
        # record in tensorboard
        if args.train_spike_record:
            writer.add_scalars('loss', {'val': val_loss, 'train': train_loss, 'train_spike': spike_train_loss},
                               epoch + 1)
            writer.add_scalars('acc', {'val': val_acc, 'train': train_acc, 'train_spike': spike_train_acc},
                               epoch + 1)
        else:
            writer.add_scalars('loss', {'val': val_loss, 'train': train_loss},
                               epoch + 1)
            writer.add_scalars('acc', {'val': val_acc, 'train': train_acc},
                               epoch + 1)
        writer.add_scalar('width', width.mean().item(), epoch)
        print('Epoch %d: train acc %.5f, test acc %.5f ' % (epoch, train_acc, val_acc))
    # Flush and release the TensorBoard event file; it was previously left
    # open and its name later shadowed by the csv writer.
    writer.close()
    print('Finish training: the best traning accuray is {} in epoch {}. \n The relate checkpoint path: {}'.format(
        best_acc,
        best_epoch,
        best_ckpt_file))
    if args.save_last:
        # NOTE(review): 'best_epoch' stores the last epoch index here, not
        # the ``best_epoch`` variable -- confirm this is intended.
        state = {
            'best_acc': best_acc,
            'best_epoch': epoch,
            'final_model': model.state_dict(),
            'traces': {'train': train_trace, 'val': val_trace, 'width': width_trace},
            'config': args
        }
        torch.save(state, os.path.join(ckpt_path, model_name + '_final.pth'))
    if args.grid_search is not None:
        # newline='' is what the csv module expects of files it writes to.
        with open(args.grid_search, 'a+', newline='') as f:
            csv_writer = csv.DictWriter(f, fieldnames=['best acc', 'model name'])
            # No header row is written; rows are only appended.
            csv_writer.writerow({'best acc': best_acc, 'model name': model_name})


def get_model_name(model_name, args):
    """Compose the run name from the configuration and append a case index.

    The name is ``model_name`` followed by the lower-cased dataset, an
    augmentation tag, ``'ann'``, the lower-cased architecture, ``T``, the
    activation name, the lower-cased optimizer name, optional ``_bn`` /
    ``_bias`` flags and the weight decay. The trailing ``_cas_<n>`` uses the
    number of entries in ``log_path`` that already start with that name.

    Side effect: every field of ``args`` is published as a module-level
    global, which is how the bare configuration names below are resolved.

    Args:
        model_name: prefix for the generated name.
        args: configuration namespace.

    Returns:
        The generated model name (also printed).
    """
    globals().update(vars(args))

    # Augmentation tag: one of '_', '_cut_', '_aug-' or '_cut_aug-'.
    cut_tag = 'cut' if cutout else ''
    aug_tag = 'aug' if auto_aug else ''
    aug_str = cut_tag + '_' + aug_tag
    if not aug_str.startswith('_'):
        aug_str = '_' + aug_str
    if not aug_str.endswith('_'):
        aug_str += '-'

    pieces = [
        dataset.lower(), aug_str, 'ann', '_', architecture.lower(),
        '_T_', str(args.T),
        '_act_', act_fun_name,
        '-opt_', optim_name.lower(),
        '_bn' if use_bn else '',
        '_bias' if bias else '',
        '_wd_', str(weight_decay),
    ]
    model_name += ''.join(pieces)

    # Count earlier runs sharing this name.  TODO (from original): update back
    cas_num = sum(1 for entry in os.listdir(log_path) if entry.startswith(model_name))
    model_name = model_name + '_cas_' + str(cas_num)
    print('model name: ' + model_name)
    return model_name


if __name__ == '__main__':
    # global config: the star import is expected to provide ``args``
    # (presumably a parsed argument namespace -- verify against vgg_config).
    from vgg_config import *

    # set random seed, device, data type
    setup_seed(args.seed)
    dtype = torch.float
    # Fall back to the CPU when CUDA is unavailable.
    device = torch.device(args.device_name if torch.cuda.is_available() else "cpu")
    print(device)
    args.dtype = dtype
    args.device = device

    # runtime params
    best_acc = 0
    start_epoch = 0
    args.start_epoch = start_epoch
    args.log_interval = 100
    # practical config for learning scheduler
    # ``momentum`` is read by main() as a module-level name.
    momentum = 0.9
    lr_scale = 0.1
    # Convert the interval ratios into absolute epoch indices.
    args.lr_interval = [int(ratio * args.num_epoch) for ratio in args.lr_interval]

    # get_model_name lists the base log directory, so it must run before
    # args.log_path is extended with the run name.
    model_name = get_model_name('', args)
    args.log_path = os.path.join(args.log_path, model_name)

    # exist_ok replaces the separate exists() check, which could race with
    # another process creating the same directory.
    os.makedirs(args.log_path, exist_ok=True)
    main(args)
