import os
import sys
import glob

import pickle
import torch
import seaborn as sns
from tensorboardX import SummaryWriter
from mlsuite import experiment
from mlsuite.utils import timed, keyboard_stoppable
from mlsuite.pytorch import fix_seed

from datasets import get_dataloader
from models import build_model
from engine import create_trainer_and_evaluator, decorate_trainer, decorate_evaluator
from tasks import get_tasks

import torch_optimizer

from methods import RotoGrad, GradNorm, GradDrop, pcgrad, RotoGradMagnitude, mgda_ub, imtl_G


def validate_args(args):
    """Fill in defaults on ``args`` in place and sanity-check the configuration.

    The following adjustments are made on the ``args`` object itself:

    * ``args.seed`` defaults to ``None`` and ``args.device`` to ``'cpu'``.
    * Any CUDA device (``'cuda'``, ``'cuda:0'``, ...) is replaced by ``'cpu'``,
      with a warning on stderr, when CUDA is not available.
    * For model names containing ``'lenet'``, ``args.model`` gets a ``shape``
      entry chosen from the dataset name (``'mnist'`` or ``'svhn'``).
    * ``args.algorithms`` inherits ``optimizer`` and ``learning_rate`` from
      ``args.training`` when they are unset; ``decay`` defaults to ``1.0``.

    Args:
        args: configuration object with attribute access. ``args.model`` must
            additionally provide a dict-style ``update`` method.

    Raises:
        AssertionError: if ``args.tasks.names`` is missing or empty.
    """
    if not hasattr(args, 'seed'):
        args.seed = None

    if not hasattr(args, 'device'):
        args.device = 'cpu'

    # Match indexed devices such as 'cuda:1' as well as the bare 'cuda'
    # string; going through str() also covers torch.device objects.
    if str(args.device).startswith('cuda') and not torch.cuda.is_available():
        print('CUDA requested but not available.', file=sys.stderr)
        args.device = 'cpu'

    if 'lenet' in args.model.name:
        if args.dataset.name == 'mnist':
            args.model.update({'shape': (1, 28, 28)})
        elif args.dataset.name == 'svhn':
            args.model.update({'shape': (1, 32, 64)})

    # The algorithm-specific optimizer settings fall back to the training ones.
    if not hasattr(args.algorithms, 'optimizer'):
        args.algorithms.optimizer = args.training.optimizer
    if not hasattr(args.algorithms, 'learning_rate'):
        args.algorithms.learning_rate = args.training.learning_rate
    if not hasattr(args.algorithms, 'decay'):
        args.algorithms.decay = 1.0

    if len(getattr(args.tasks, 'names', [])) == 0:
        # Raised explicitly instead of through an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type is the
        # same one the assert used to produce.
        raise AssertionError('select some task!')


def get_optimizers(model, tasks, leaders, args):
    """Build the optimizers and learning-rate schedulers for a run.

    Args:
        model: mapping whose values expose ``parameters()``; all of them are
            handled by a single optimizer, one parameter group per value.
        tasks: unused; kept so existing callers keep working.
        leaders: sized iterable of objects exposing ``parameters()``. They get
            their own optimizer, one parameter group per leader.
        args: configuration. Reads ``dataset.name``, ``training.optimizer``,
            ``training.learning_rate``, ``training.decay`` (only for
            ``'celeba'``), ``algorithms.optimizer``,
            ``algorithms.learning_rate`` and ``algorithms.decay``.
            ``algorithms.learning_rate`` may be a number or a sequence, in
            which case its first element is used.

    Returns:
        A tuple ``(optimizers, schedulers)`` of lists. ``optimizers[0]`` is
        the model optimizer; a leader optimizer (with an exponential
        scheduler) is appended only when there are leaders and their learning
        rate is positive.

    Raises:
        KeyError: if an optimizer name is not one of the supported keys.
    """
    optim_dict = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'adagrad': torch.optim.Adagrad,
        'rmsprop': torch.optim.RMSprop,
        'radam': torch_optimizer.RAdam,
        # Parameter names chosen so they do not shadow the enclosing ``args``.
        'sgd-momentum': lambda *a, **kw: torch.optim.SGD(*a, momentum=0.9, weight_decay=5e-4, **kw),
    }

    optimizer_model = optim_dict[args.training.optimizer]
    optimizer_leader = optim_dict[args.algorithms.optimizer]

    params_model = [{'params': m.parameters()} for m in model.values()]
    params_leader = [{'params': x.parameters()} for x in leaders]

    model_optimizer = optimizer_model(params_model, lr=args.training.learning_rate)
    optimizers = [model_optimizer]
    schedulers = []

    # Dataset-specific schedules for the model optimizer.
    if args.dataset.name == 'celeba':
        schedulers.append(torch.optim.lr_scheduler.ExponentialLR(model_optimizer, args.training.decay))

    if args.dataset.name == 'cifar10':
        schedulers.append(torch.optim.lr_scheduler.CosineAnnealingLR(model_optimizer, T_max=200))

    if len(leaders) > 0:
        leader_lr = args.algorithms.learning_rate
        if not isinstance(leader_lr, (int, float)):
            # A sequence of learning rates: only the first one is used here.
            leader_lr = leader_lr[0]
        # Compare the resolved scalar: comparing the raw config value with 0
        # raises TypeError when it is a sequence.
        if leader_lr > 0:
            leader_optimizer = optimizer_leader(params_leader, lr=leader_lr)
            optimizers.append(leader_optimizer)
            schedulers.append(torch.optim.lr_scheduler.ExponentialLR(leader_optimizer, args.algorithms.decay))

    return optimizers, schedulers


def number_parameters(model):
    """Print the total and the per-module parameter counts of ``model``.

    Args:
        model: mapping whose values expose ``parameters()`` yielding tensors
            (e.g. ``torch.nn.Module`` instances).
    """
    # numel() also handles 0-dim (scalar) parameters, whose empty size()
    # cannot be folded with a multiplication that has no initial value.
    total_params = [sum(p.numel() for p in m.parameters()) for m in model.values()]
    print(f'Total Params: ({sum(total_params)}) {total_params}')


# NOTE(review): ``timed`` and ``keyboard_stoppable`` come from mlsuite.utils;
# presumably they time the call and let a keyboard interrupt stop it cleanly —
# confirm against mlsuite.
@timed
@keyboard_stoppable
def train(trainer, loader, max_epochs):
    """Call ``trainer.run`` on ``loader``, forwarding ``max_epochs``.

    Args:
        trainer: object exposing ``run(data, max_epochs=...)``.
        loader: data passed as the first argument to ``trainer.run``.
        max_epochs: forwarded unchanged as the ``max_epochs`` keyword.
    """
    trainer.run(loader, max_epochs=max_epochs)


def load_model_and_leaders(model, leaders, state_dict):
    """Restore every model component and every leader from ``state_dict``.

    Args:
        model: mapping from name to an object with ``load_state_dict``; each
            entry is restored from ``state_dict[name]``.
        leaders: iterable of objects with ``load_state_dict``; the one at
            position ``i`` is restored from ``state_dict['_leader_<i>']``.
        state_dict: mapping holding the saved states under those keys.

    Raises:
        KeyError: if ``state_dict`` lacks one of the required keys.
    """
    # Model components are stored under their own names.
    for name, module in model.items():
        saved = state_dict[name]
        module.load_state_dict(saved)

    # Leaders are stored under positional keys.
    for i, leader in enumerate(leaders):
        saved = state_dict[f'_leader_{i}']
        leader.load_state_dict(saved)


@experiment
def main(args):
    """Train and/or evaluate a multitask model as described by ``args``.

    Steps, in order:

    1. Validate ``args``, fix the seed and optionally open a tensorboard
       writer.
    2. Build the data loaders, the tasks and the model.
    3. Attach the methods listed in ``args.algorithms.methods``; some add
       trainable "leaders" and/or callbacks.
    4. Move everything to ``args.device`` and build the optimizers, trainer
       and evaluator.
    5. Resume from ``checkpoints/state.pkl`` and ``checkpoints/latest_*.pt``
       when they exist.
    6. Train if ``args.train`` is set and the restored trainer state is not
       marked as completed, then evaluate on the test loader.
    7. Load the best checkpoint (``checkpoints/best_*.pt``), if any, and
       evaluate it on the test loader.
    8. Optionally plot, and close the writer.

    Args:
        args: experiment configuration object.

    Raises:
        KeyError: if ``args.algorithms.methods`` contains an unknown method
            name (the name is carried by the exception).
        AssertionError: if the configuration is invalid, if more than one
            ``latest_*.pt`` checkpoint exists, or if a saved trainer state
            exists without a matching latest checkpoint.
    """
    validate_args(args)
    fix_seed(args.seed)

    writer = SummaryWriter('tensorboard') if args.tensorboard else None

    loaders = get_dataloader(args.dataset.name, args.training.batch_size, **args.dataset.options)
    args.model.input_size = loaders['train'].dataset.input_size

    tasks = get_tasks(args.tasks.names, args.tasks.weights, loaders['train'].dataset, args.device)
    model = build_model(args.model.name, tasks, args.model)

    if not hasattr(args.rotograd, 'latent_size'):
        args.rotograd.latent_size = model['rep'].output_size

    # "Leaders" are extra trainable objects that get their own optimizer;
    # callbacks are handed to the trainer factory below.
    leaders, callbacks = [], []
    for method in args.algorithms.methods:
        if method == 'rotograd':
            rotograd = RotoGrad(len(tasks), args.rotograd.latent_size)
            leaders.append(rotograd)
            callbacks.append(rotograd.callback)
            callbacks.append(RotoGradMagnitude(len(tasks), args.training.burn_in_period))

            # Put the i-th RotoGrad component in front of the i-th task head.
            for i, t_i in enumerate(tasks):
                model[t_i.name] = torch.nn.Sequential(rotograd[i], model[t_i.name])
        elif method == 'gradnorm':
            gradnorm = GradNorm(len(tasks), alpha=args.gradnorm.alpha)
            leaders.append(gradnorm)
            callbacks.append(gradnorm.callback)
        elif method == 'graddrop':
            callbacks.append(GradDrop(len(tasks)).callback)
        elif method == 'pcgrad':
            callbacks.append(pcgrad)
        elif method == 'mgda':
            callbacks.append(mgda_ub)
        elif method == 'imtl-g':
            callbacks.append(imtl_G)
        else:
            # Same exception type as before, now naming the offending method.
            raise KeyError(method)

    print(model)
    number_parameters(model)

    for m in model.values():
        m.to(args.device)

    for leader in leaders:
        leader.to(args.device)

    optimizers, schedulers = get_optimizers(model, tasks, leaders, args)
    trainer, evaluator = create_trainer_and_evaluator(model, leaders, callbacks, tasks, optimizers, loaders, args,
                                                      writer)

    decorate_trainer(trainer, tasks, args, writer, model, schedulers)
    decorate_evaluator(evaluator, tasks, args, writer)

    # Load model and trainer state if they exist already.
    state_paths = glob.glob('checkpoints/state.pkl')
    # A saved trainer state is only usable together with the latest weights.
    latest_has_to_exist = len(state_paths) == 1
    if latest_has_to_exist:
        # NOTE: unpickling can execute arbitrary code; only resume from
        # checkpoint directories written by a trusted run.
        with open(state_paths[0], 'rb') as f:
            trainer.state = pickle.load(f)

    latest_paths = glob.glob('checkpoints/latest_*.pt')  # Load latest if we keep training
    assert len(latest_paths) < 2
    if len(latest_paths) == 1:
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be resumed on the device this run actually uses.
        latest_state = torch.load(latest_paths[0], map_location=args.device)
        load_model_and_leaders(model, leaders, latest_state)
        print('Latest model loaded.')
    else:
        assert not latest_has_to_exist

    if args.train and (trainer.state is None or not trainer.state.times['COMPLETED']):
        train(trainer, loaders['train'], max_epochs=args.training.epochs)
        if len(latest_paths) == 1:
            os.remove(latest_paths[0])  # Remove the checkpoint we resumed from

        print("Latest model metrics")
        evaluator.run(loaders['test'])

    # Load the best model for evaluation.
    best_paths = glob.glob('checkpoints/best_*.pt')
    if len(best_paths) > 0:
        # Highest score first; the score is parsed from the path characters
        # between index 28 and the last '.'.
        # NOTE(review): the fixed offset 28 presumably matches the checkpoint
        # file naming used by the engine — confirm.
        best_paths = sorted(best_paths, key=lambda x: float(x[28:x.rfind('.')]), reverse=True)
        best_state = torch.load(best_paths[0], map_location=args.device)
        load_model_and_leaders(model, leaders, best_state)
        print('Best model loaded.')

        evaluator.run(loaders['test'])
    else:
        print('I could not find a best model to load.')

    if args.plot:
        loaders['train'].dataset.plot(model, tasks, title=str(args.exp_name))

    if writer is not None:
        writer.close()


if __name__ == '__main__':
    # Set seaborn's plotting style to 'white' before running the experiment.
    sns.set_style('white')
    # No arguments are passed here.
    # NOTE(review): presumably the @experiment decorator builds and injects
    # ``args`` — confirm against mlsuite.
    main()
