import copy
import os
import argparse
import time
import math
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import optimizers

from conf import settings
from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \
    most_recent_folder, most_recent_weights, last_epoch, best_acc_weights

import utils


class ModelEMA:
    """Exponential moving average (EMA) of a network's weights.

    Holds a deep copy of ``net``. Each call to :meth:`update` sets, for every
    floating-point ``state_dict`` entry,
    ``ema_state = ema * ema_state + (1 - ema) * net_state``. Non-floating-point
    entries (integer/bool buffers) are copied from ``net`` instead: averaging
    them in float and writing back into an integer tensor would truncate.

    Args:
        net: the network being trained; its ``state_dict`` is read on update.
        ema: decay factor, i.e. the weight kept on the previous average.
    """

    def __init__(self, net, ema):
        self.ema_net = copy.deepcopy(net)
        self.net = net
        self.ema = ema

    @torch.no_grad()
    def update(self):
        """Blend the live network's current state into the averaged copy."""
        keep = self.ema
        for ema_v, net_v in zip(self.ema_net.state_dict().values(), self.net.state_dict().values()):
            if ema_v.is_floating_point() or ema_v.is_complex():
                ema_v.mul_(keep).add_(net_v, alpha=1 - keep)
            else:
                # Integer/bool buffers cannot be averaged without truncation;
                # track the live network's value instead.
                ema_v.copy_(net_v)

    def eval(self):
        """Put the averaged copy into evaluation mode."""
        self.ema_net.eval()

    def __call__(self, *args, **kwargs):
        """Run a forward pass through the averaged copy."""
        return self.ema_net(*args, **kwargs)

class ModelAverage:
    """Polynomial-decay average of a network's weights.

    Holds a deep copy of ``net``. The t-th call to :meth:`update` sets, for
    every floating-point ``state_dict`` entry,
    ``avg = (1 - w_t) * avg + w_t * net_state`` with
    ``w_t = (gamma + 1) / (gamma + t)``. Because ``w_1 == 1``, the first update
    copies ``net`` exactly, and ``gamma == 0`` gives the plain running mean of
    all iterates. Non-floating-point entries (integer/bool buffers) are copied
    from ``net`` rather than averaged.

    Args:
        net: the network being trained; its ``state_dict`` is read on update.
        gamma: decay parameter; larger values put more weight on recent
            iterates.
    """

    def __init__(self, net, gamma=8.0):
        self.ema_net = copy.deepcopy(net)
        self.net = net
        self.gamma = gamma
        self.t = 1  # 1-based index of the next update

    @torch.no_grad()
    def update(self):
        """Fold the live network's current state into the running average."""
        weight = (self.gamma + 1) / (self.gamma + self.t)
        for avg_v, net_v in zip(self.ema_net.state_dict().values(), self.net.state_dict().values()):
            if avg_v.is_floating_point() or avg_v.is_complex():
                avg_v.mul_(1 - weight).add_(net_v, alpha=weight)
            else:
                # In-place float arithmetic on integer/bool tensors raises a
                # dtype-cast error; copy these buffers instead.
                avg_v.copy_(net_v)
        self.t += 1

    def eval(self):
        """Put the averaged copy into evaluation mode."""
        self.ema_net.eval()

    def __call__(self, *args, **kwargs):
        """Run a forward pass through the averaged copy."""
        return self.ema_net(*args, **kwargs)


def _current_lr(opt, optim_name):
    """Return the step size to report for optimizer ``opt``.

    Several optimizers used by this script keep their effective step size
    somewhere other than ``param_groups[0]['lr']``. This maps each ``-optim``
    name to the quantity that is logged as the learning rate.
    """
    if optim_name in ('sps', 'pssps'):
        return opt.state['step_size']
    if optim_name in ('dog', 'ldog'):
        return opt.param_groups[0]['eta'][0]
    if optim_name == 'cocob':
        return 0.0  # a constant 0.0 is logged for COCOB
    if optim_name == 'dasgd':
        group = opt.param_groups[0]
        return group['d'] * group['lr'] / group['g0_norm']
    if optim_name in ('daadam', 'prodigy'):
        group = opt.param_groups[0]
        return group['d'] * group['lr']
    if optim_name == 'psdasgd':
        return opt.param_groups[0]['elr']
    return opt.param_groups[0]['lr']


def train(epoch):
    """Train ``net`` for one epoch, logging loss and step size per iteration.

    Uses module-level globals created under ``__main__``: ``net``, ``net_ema``,
    ``optimizer``, ``loss_function``, ``warmup_scheduler``, ``writer``,
    ``cifar100_training_loader`` and ``args``.

    Args:
        epoch: 1-based epoch number; gates warmup and offsets the global
            TensorBoard step.
    """
    epoch_start = time.time()
    net.train()

    for step, (images, labels) in enumerate(cifar100_training_loader):
        # Warmup is stepped once per batch during the first ``args.warm`` epochs.
        # NOTE(review): this runs before optimizer.step(); PyTorch's docs
        # recommend stepping schedulers after the optimizer — confirm intended.
        if warmup_scheduler is not None and epoch <= args.warm:
            warmup_scheduler.step()

        if args.gpu:
            labels, images = labels.cuda(), images.cuda()

        optimizer.zero_grad()
        loss = loss_function(net(images), labels)
        loss.backward()

        # The SPS variants take the current loss as an argument to step().
        if args.optim in ('sps', 'pssps'):
            optimizer.step(loss=loss)
        else:
            optimizer.step()
        net_ema.update()

        global_step = (epoch - 1) * len(cifar100_training_loader) + step + 1
        lr = _current_lr(optimizer, args.optim)
        loss_value = loss.item()

        print('Training Epoch: {epoch:3d} [{trained_samples:5d}/{total_samples:5d}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format(
            loss_value,
            lr,
            epoch=epoch,
            trained_samples=step * args.b + len(images),
            total_samples=len(cifar100_training_loader.dataset)
        ))

        # per-iteration TensorBoard scalars
        writer.add_scalar('Train/loss', loss_value, global_step)
        writer.add_scalar('Train/lr', lr, global_step)

    elapsed = time.time() - epoch_start
    print('Epoch {} training time consumed: {:.2f}s'.format(epoch, elapsed))

@torch.no_grad()
def eval_training(epoch=0, tb=True):
    """Evaluate ``net`` (and ``net_ema`` when it is not None) on the test set.

    Uses module-level globals created under ``__main__``: ``net``, ``net_ema``,
    ``args``, ``loss_function``, ``cifar100_test_loader`` and, when ``tb`` is
    true, ``writer``.

    Args:
        epoch: epoch number shown in the printed report and used as the
            TensorBoard step.
        tb: when true, log average loss and accuracies to TensorBoard.

    Returns:
        Top-1 accuracy of ``net`` on the test set, as a Python float.
    """
    start = time.time()
    net.eval()
    if net_ema is not None:
        net_ema.eval()

    test_loss = 0.0  # sum of per-sample losses
    correct = 0
    correct_ema = 0

    for images, labels in cifar100_test_loader:
        if args.gpu:
            images = images.cuda()
            labels = labels.cuda()

        outputs = net(images)
        loss = loss_function(outputs, labels)
        # loss_function (nn.CrossEntropyLoss() in __main__) returns the batch
        # mean; weight it by batch size so dividing by the dataset size gives a
        # true per-sample average, even when the last batch is smaller.
        test_loss += loss.item() * labels.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()

        if net_ema is not None:
            _, preds = net_ema(images).max(1)
            correct_ema += preds.eq(labels).sum().item()

    finish = time.time()
    # max(..., 1) avoids ZeroDivisionError on an empty test set.
    num_samples = max(len(cifar100_test_loader.dataset), 1)
    avg_loss = test_loss / num_samples
    acc = correct / num_samples

    print('Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'.format(
        epoch,
        avg_loss,
        acc,
        finish - start
    ))
    print('')

    # add information to tensorboard
    if tb:
        writer.add_scalar('Test/Average loss', avg_loss, epoch)
        writer.add_scalar('Test/Accuracy', acc, epoch)
        if net_ema is not None:
            writer.add_scalar('Test/EMA_Accuracy', correct_ema / num_samples, epoch)

    return acc

def _build_optimizer(name, params, lr, wd):
    """Construct the optimizer selected by ``-optim``.

    Only ``sgd`` and ``adam`` use ``lr``. The others are built with ``lr=1.0``
    or with their own defaults.

    Raises:
        ValueError: if ``name`` is not a supported optimizer.
    """
    builders = {
        'sgd': lambda: optim.SGD(params, lr=lr, momentum=0.9, weight_decay=wd),
        'sps': lambda: optimizers.Sps(params, weight_decay=wd),
        'dog': lambda: optimizers.DoG(params, lr=1.0, weight_decay=wd),
        'dasgd': lambda: optimizers.DAdaptSGD(params, lr=1.0, momentum=0.9, weight_decay=wd),
        'adam': lambda: optim.Adam(params, lr=lr, weight_decay=wd),
        'cocob': lambda: optimizers.COCOB(params, weight_decay=wd),
        'ldog': lambda: optimizers.LDoG(params, lr=1.0, weight_decay=wd),
        'daadam': lambda: optimizers.DAdaptAdam(params, lr=1.0, weight_decay=wd, decouple=True),
        'prodigy': lambda: optimizers.Prodigy(params, lr=1.0, weight_decay=wd),
        'pssps': lambda: optimizers.PSSps(params, weight_decay=wd),
        'psdasgd': lambda: optimizers.PSDASGD(params, lr=1.0, weight_decay=wd),
    }
    if name not in builders:
        raise ValueError('unknown optimizer: {!r} (expected one of {})'.format(name, sorted(builders)))
    return builders[name]()


def _build_ema(net, optim_name):
    """Return the weight-averaging wrapper that is evaluated alongside ``net``.

    DoG/LDoG runs use polynomial-decay averaging (``ModelAverage``). All other
    optimizers use an exponential moving average with decay 0.99.
    """
    if optim_name in ('dog', 'ldog'):
        return ModelAverage(net)
    return ModelEMA(net, 0.99)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
    parser.add_argument('-b', type=int, default=128, help='batch size for dataloader')
    parser.add_argument('-warm', type=int, default=1, help='warm up training phase')
    parser.add_argument('-optim', type=str, default='sgd')
    parser.add_argument('-lr', type=float, default=0.1, help='initial learning rate')
    parser.add_argument('-wd', type=float, default=0., help='weight decay')
    parser.add_argument('-resume', action='store_true', default=False, help='resume training')
    parser.add_argument('-pretrained', type=str, default=None, help='path to pretrained model')
    args = parser.parse_args()
    # Use the GPU whenever one is available, but fall back to the CPU instead
    # of crashing on machines without CUDA.
    args.gpu = args.gpu or torch.cuda.is_available()
    device = 'cuda' if args.gpu else 'cpu'

    net = get_network(args)

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    if args.pretrained is not None:
        net.load_state_dict(torch.load(args.pretrained, map_location=device)['net'])

    net_ema = _build_ema(net, args.optim)

    # Build the optimizer before the data loaders so an unknown -optim fails fast.
    optimizer = _build_optimizer(args.optim, net.parameters(), args.lr, args.wd)
    loss_function = nn.CrossEntropyLoss()

    # data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
        shuffle=True
    )

    # Evaluation order does not affect the metrics, so keep it deterministic.
    cifar100_test_loader = get_test_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
        shuffle=False
    )

    if args.optim in ['sps', 'dog', 'cocob', 'ldog', 'pssps']:
        # These optimizers run without a learning-rate schedule or warmup.
        train_scheduler, warmup_scheduler = None, None
    else:
        iter_per_epoch = len(cifar100_training_loader)
        milestones = settings.MILESTONES
        train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.2)  # learning rate decay
        warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    # Prefix the run name with the optimizer so each optimizer gets its own folder.
    args.net = f'{args.optim}_{args.net}'
    run_root = os.path.join(settings.CHECKPOINT_PATH, args.net)

    if args.resume:
        recent_folder = most_recent_folder(run_root, fmt=settings.DATE_FORMAT)
        if not recent_folder:
            raise Exception('no recent folder were found')
        run_dir = os.path.join(run_root, recent_folder)
    else:
        run_dir = os.path.join(run_root, settings.TIME_NOW)

    # since tensorboard can't overwrite old values
    # the only way is to create a new tensorboard log
    writer = SummaryWriter(log_dir=os.path.join(run_dir, 'tb_logs'))

    # create checkpoint folder to save model
    os.makedirs(run_dir, exist_ok=True)
    checkpoint_path = os.path.join(run_dir, '{net}-{epoch}-{type}.pth')

    try:
        best_acc = 0.0
        if args.resume:
            best_weights = best_acc_weights(run_dir)
            if best_weights:
                weights_path = os.path.join(run_dir, best_weights)
                print('found best acc weights file:{}'.format(weights_path))
                print('load best training file to test acc...')
                net.load_state_dict(torch.load(weights_path, map_location=device))
                best_acc = eval_training(tb=False)
                print('best acc is {:0.2f}'.format(best_acc))

            recent_weights_file = most_recent_weights(run_dir)
            if not recent_weights_file:
                raise Exception('no recent weights file were found')
            weights_path = os.path.join(run_dir, recent_weights_file)
            print('loading weights file {} to resume training.....'.format(weights_path))
            net.load_state_dict(torch.load(weights_path, map_location=device))
            # net_ema was copied from net before these weights were loaded, and
            # its state is not checkpointed; restart averaging from the
            # resumed weights.
            net_ema = _build_ema(net, args.optim)

            resume_epoch = last_epoch(run_dir)

        for epoch in range(1, settings.EPOCH + 1):
            if epoch > args.warm and train_scheduler is not None:
                train_scheduler.step()

            # Skip epochs that were already trained (the scheduler still advances).
            if args.resume and epoch <= resume_epoch:
                continue

            train(epoch)
            eval_training(epoch)

            if not epoch % settings.SAVE_EPOCH:
                weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='regular')
                print('saving weights file to {}'.format(weights_path))
                torch.save(net.state_dict(), weights_path)
    finally:
        writer.close()