from functools import partial
import itertools

import pickle
import torch
import torch.optim

import geotorch

from ignite.engine import Engine, Events
from ignite.metrics import Average, RunningAverage
from ignite.handlers import TerminateOnNan, ModelCheckpoint
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor

import matplotlib.pyplot as plt
from plot import setup_grid, plot_toy


def create_trainer(model, leaders, callbacks, tasks, optims, loaders, args, writer=None):
    """Build the ignite ``Engine`` that runs multi-task training.

    Each step computes the shared representation ``model['rep'](x)``, evaluates
    every task head on a detached copy of it to obtain the per-task gradient
    with respect to the representation, passes those gradients through
    ``callbacks`` and back-propagates their sum through the shared network
    before stepping every optimizer.

    The returned engine also has attached: ``TerminateOnNan``, running-average
    loss metrics, a progress bar, a validation run after every epoch, optional
    tensorboard logging and periodic checkpoints under ``checkpoints/``.

    Args:
        model: Mapping from name to module. Must contain ``'rep'`` plus one
            entry per ``task.name``.
        leaders: Iterable of objects exposing ``train()``/``eval()``. The first
            one that has a ``weight`` attribute provides the logged task
            weights. All of them are included in the checkpoints.
        callbacks: Callables invoked as ``callback(list_losses, list_grads, rep)``
            that return the (possibly modified) list of gradients.
        tasks: Task descriptors. The attributes read here are ``name``,
            ``index``, ``loss``, ``metric`` and ``weight_original``.
        optims: Iterable of optimizers (``zero_grad()`` / ``step()``).
        loaders: Mapping of data loaders. ``'val'`` is always used; ``'train'``
            is used for the mnist/density special case.
        args: Configuration object. Attributes read here: ``dataset.name``,
            ``device``, ``training.normalize``, ``training.burn_in_period``,
            ``training.save_every``, ``log_every``, ``tensorboard``,
            ``short_pbar`` and ``exp_name``.
        writer: Object with ``add_scalar``; required when ``args.tensorboard``
            is truthy.

    Returns:
        The configured training ``Engine``.
    """
    # Buffers that are only filled for the 'dummy' dataset and handed to plot_toy.
    zt = []
    zt_task = {t.name: [] for t in tasks}

    if args.dataset.name == 'dummy':
        lim = 2.5
        lims = [[-lim, lim], [-lim, lim]]
        grid = setup_grid(lims, 1000)

    def trainer_step(engine, batch):
        """Run one optimisation step and return ``(training_loss, losses)``."""

        def is_time(rate):
            # True on iterations 1, 1 + rate, 1 + 2 * rate, ...
            return (engine.state.iteration - 1) % rate == 0

        # contextual variables
        accum_data = args.dataset.name in ['dummy', 'fair'] or 'mnist' in args.dataset.name

        for m in model.values():
            m.train()

        for t in leaders:
            t.train()

        for optim in optims:
            optim.zero_grad()

        # Batch data
        x, y = batch
        x = convert_tensor(x.float(), args.device)
        y = [convert_tensor(y_, args.device) for y_ in y]

        training_loss = 0.
        losses, grads = {}, {}

        # Intermediate representation
        with geotorch.parametrize.cached():
            rep = model['rep'](x)
            if args.dataset.name == 'dummy':
                zt.append(rep.detach().clone())

            for task_i in tasks:
                # Copy and detach the intermediate representation so that the
                # backward pass below stops at rep_i and yields d(loss_i)/d(rep).
                rep_i = rep.detach().clone()
                rep_i.requires_grad = True

                y_hat = model[task_i.name](rep_i)
                loss_i = task_i.loss(y_hat, y[task_i.index])

                if args.dataset.name == 'dummy':
                    loss_i = loss_i.mean(dim=0)
                    zt_task[task_i.name].append(y_hat.detach().clone())

                # Normalize loss functions (definition of balanced learning)
                if args.training.normalize:
                    if engine.state.initial_loss[task_i.name] is None or engine.state.iteration == args.training.burn_in_period:
                        engine.state.initial_loss[task_i.name] = loss_i.item()

                        if args.dataset.name == 'mnist' and task_i.name == 'density':  # special case
                            engine.state.initial_loss[task_i.name] = loaders['train'].dataset.target[-1].mean().item()

                    loss_i = loss_i / engine.state.initial_loss[task_i.name]

                # Track losses (read the scalar once)
                loss_value = loss_i.item()
                losses[task_i.name] = loss_value
                training_loss += loss_value * task_i.weight_original  # use original weight to compare

                # Compute and store gradients
                loss_i.backward()
                grads[task_i.name] = rep_i.grad.detach().clone()

                # Kept from the original: clears the gradient buffer of the
                # temporary leaf, whose value was already cloned above.
                rep_i.grad.detach_()
                rep_i.grad.zero_()

            list_losses = [losses[t.name] for t in tasks]
            list_grads = [grads[t.name] for t in tasks]

            for callback in callbacks:
                list_grads = callback(list_losses, list_grads, rep)

            grads = {t.name: g for t, g in zip(tasks, list_grads)}

        # Toy dataset only: plot on every iteration of the last epoch.
        if args.dataset.name == 'dummy' and engine.state.epoch == engine.state.max_epochs:
            fig = plot_toy(grid, model, tasks, [zt, *zt_task.values()], grads,
                           engine.state.iteration - 1, levels=20, lims=lims,
                           title=str(args.exp_name))
            fig.savefig(f'plots/step_{engine.state.iteration - 1}.png')
            plt.show()
            plt.close(fig)

        # Calculate grad_z and backpropagate
        output_grad = sum(list_grads)
        rep.backward(output_grad)

        # Run the optimizers
        for optim in optims:
            optim.step()

        with torch.no_grad():
            if is_time(args.log_every):
                for task_i in tasks:
                    # Computed once and shared by the state buffer and the writer.
                    cos_sim = torch.cosine_similarity(output_grad.flatten(), grads[task_i.name].flatten(), dim=0)
                    engine.state.cos_sim[task_i.name].append(cos_sim)
                    if args.tensorboard:
                        writer.add_scalar(f'cos_sim_{task_i.name}', cos_sim, engine.state.iteration)

            if args.tensorboard:
                if accum_data:
                    # These lists are expected to already exist in
                    # engine.state.metrics (they are not created here).
                    for task_i in tasks:
                        engine.state.metrics[f'norm_{task_i.name}'].append(grads[task_i.name].norm(p=2, dim=-1))
                        engine.state.metrics[f'cos_sim_{task_i.name}'].append(
                            torch.cosine_similarity(output_grad, grads[task_i.name], dim=1))

                if args.dataset.name == 'dummy':
                    for task_i in tasks:
                        R = model[task_i.name][0].R
                        # Build the identity on R's device/dtype so the
                        # subtraction also works when the model is not on CPU.
                        eye = torch.eye(R.size(0), device=R.device, dtype=R.dtype)
                        diff_i = torch.norm(eye - R)
                        writer.add_scalar(f'diff_R_{task_i.name}', diff_i, engine.state.iteration)

        return training_loss, losses

    trainer = Engine(trainer_step)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    @trainer.on(Events.STARTED)
    def create_variables(engine):
        # Only create what is missing so that a restored state is not overwritten.
        if not hasattr(engine.state, 'initial_loss'):
            engine.state.initial_loss = dict.fromkeys([t.name for t in tasks], None)

        if not hasattr(engine.state, 'val_metric'):
            engine.state.val_metric = {t.name: [] for t in tasks}

        if not hasattr(engine.state, 'cos_sim'):
            engine.state.cos_sim = {t.name: [] for t in tasks}

        if not hasattr(engine.state, 'val_loss'):
            engine.state.val_loss = []

    # trainer_step returns (training_loss, losses): x[0] is the total, x[1] the per-task dict.
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')
    for task_i in tasks:
        # partial binds the task name now (avoids the late-binding closure pitfall).
        output_transform = partial(lambda name, x: x[1][name], task_i.name)
        RunningAverage(output_transform=output_transform).attach(trainer, f'train_{task_i.name}')

    if args.tensorboard:
        @trainer.on(Events.ITERATION_COMPLETED(every=args.log_every))
        def write_metrics(trainer):
            metrics = trainer.state.metrics

            writer.add_scalar('loss', metrics['loss'], trainer.state.iteration)
            for i, task_i in enumerate(tasks):
                name = task_i.name
                writer.add_scalar(f'train_{name}', metrics[f'train_{name}'], trainer.state.iteration)

                # Default to the static weight; the first leader exposing
                # `weight` overrides it.
                w = task_i.weight_original
                for leader in leaders:
                    if hasattr(leader, 'weight'):  # gradnorm
                        w = leader.weight[i].detach()
                        break

                writer.add_scalar(f'weight_{name}', w, trainer.state.iteration)

    pbar = ProgressBar()
    pbar.attach(trainer,
                metric_names=['loss'] + ([f'train_{t.name}' for t in tasks] if not args.short_pbar else []))

    # Validation
    validator = create_evaluator(model, leaders, tasks, args)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validator(trainer):
        validator.run(loaders['val'])
        metrics = validator.state.metrics
        loss = 0.
        for task_i in tasks:
            if args.tensorboard:
                writer.add_scalar(f'val_loss_{task_i.name}_{task_i.loss.__name__}',
                                  metrics[f'loss_{task_i.name}'], trainer.state.epoch)
                writer.add_scalar(f'val_metric_{task_i.name}_{task_i.metric.__name__}',
                                  metrics[f'metric_{task_i.name}'], trainer.state.epoch)

            trainer.state.val_metric[task_i.name].append(metrics[f'metric_{task_i.name}'])

            loss += metrics[f'loss_{task_i.name}'] * task_i.weight_original

        trainer.state.val_loss.append(loss)
        # Read by the 'best' checkpoint score function below.
        trainer.state.metrics['val_loss'] = loss
        if args.tensorboard:
            writer.add_scalar('val_loss', loss, trainer.state.epoch)

    # Checkpoints
    model_checkpoint = {k: v for k, v in model.items()}
    model_checkpoint.update({f'_leader_{i}': v for i, v in enumerate(leaders)})

    latest_handler = ModelCheckpoint('checkpoints', 'latest', require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.training.save_every),
                              latest_handler, model_checkpoint)

    def save_state(engine):
        with open('checkpoints/state.pkl', 'wb') as f:
            pickle.dump(engine.state, f)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.training.save_every), save_state)
    # Registered without an `every` filter (like the 'best' checkpoint below)
    # so the final state is always written when the run completes.
    trainer.add_event_handler(Events.COMPLETED, save_state)

    best_handler = ModelCheckpoint('checkpoints', 'best', require_empty=False,
                                   score_function=(lambda e: -e.state.metrics['val_loss'])
                                   )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.training.save_every),
                              best_handler, model_checkpoint)
    trainer.add_event_handler(Events.COMPLETED, best_handler, model_checkpoint)

    return trainer


def create_evaluator(model, leaders, tasks, args):
    """Build an ignite ``Engine`` that evaluates every task on a data loader.

    All modules in ``model`` are moved to ``args.device`` when this function is
    called. Each evaluation step runs under ``torch.no_grad()`` and returns
    ``(values, targets, preds)`` where ``values`` maps ``loss_<task>`` and
    ``metric_<task>`` to the batch value reduced with ``mean(dim=0)``,
    ``targets`` is the list of targets on ``args.device`` and ``preds`` maps
    each task name to a copy of the head output.

    An ``Average`` metric is attached for every ``metric_<task>`` and
    ``loss_<task>`` key, in that order for each task.
    """
    for module in model.values():
        module.to(args.device)

    @torch.no_grad()
    def evaluator_step(engine, batch):
        # Switch every network and leader to evaluation mode.
        for module in model.values():
            module.eval()
        for leader in leaders:
            leader.eval()

        inputs, targets = batch
        inputs = inputs.to(args.device)
        targets = [t.to(device=args.device) for t in targets]

        shared = model['rep'](inputs)
        values = {}
        preds = {}
        for task in tasks:
            name = task.name
            target = targets[task.index]
            y_hat = model[name](shared)

            values[f'loss_{name}'] = task.loss(y_hat, target).mean(dim=0)
            values[f'metric_{name}'] = task.metric(y_hat, target).mean(dim=0)
            preds[name] = y_hat.detach().clone()

        return values, targets, preds

    def pick(key, output):
        # The step output is a tuple; its first element is the values dict.
        return output[0][key]

    evaluator = Engine(evaluator_step)
    for task in tasks:
        for prefix in ['metric', 'loss']:
            key = f'{prefix}_{task.name}'
            Average(output_transform=partial(pick, key)).attach(evaluator, key)

    return evaluator


def create_trainer_and_evaluator(model, leaders, callbacks, tasks, optim, loaders, args, writer):
    """Return ``(trainer, evaluator)`` built from the same model and tasks.

    The trainer is created first via ``create_trainer`` (``optim`` is passed
    through as its optimizers argument); the evaluator is a separate engine
    created via ``create_evaluator``.
    """
    return (
        create_trainer(model, leaders, callbacks, tasks, optim, loaders, args, writer),
        create_evaluator(model, leaders, tasks, args),
    )


def decorate_trainer(trainer, tasks, args, writer, model, schedulers):
    """Attach scheduler stepping and optional per-epoch histograms to ``trainer``.

    Args:
        trainer: Engine to decorate.
        tasks: Task descriptors; only ``name`` is read here.
        args: Configuration object; ``dataset.name`` and ``tensorboard`` are read.
        writer: Object with ``add_histogram`` and ``add_scalar``; used only when
            ``args.tensorboard`` is truthy.
        model: Unused here; kept for interface compatibility.
        schedulers: Iterable of schedulers stepped after every iteration.
    """
    @trainer.on(Events.ITERATION_COMPLETED)
    def apply_schedulers(engine):
        for i, sched in enumerate(schedulers):
            if i == 0 and args.dataset.name == 'celeba':
                # On celeba the first scheduler only steps on iterations
                # 1, 2401, 4801, ... instead of every iteration.
                if (engine.state.iteration - 1) % 2400 == 0:
                    sched.step()
            else:
                sched.step()

    if not args.tensorboard:
        return

    if not (args.dataset.name in ['dummy', 'fair'] or 'mnist' in args.dataset.name):
        return

    # Registered once: a single call already resets the buffers of every task,
    # so registering it per task would only repeat the same work.
    @trainer.on(Events.EPOCH_STARTED)
    def reset_buffers(engine):
        for task_i in tasks:
            engine.state.metrics[f'cos_sim_{task_i.name}'] = []
            engine.state.metrics[f'norm_{task_i.name}'] = []

    def plot_distribution(engine, name):
        # Cosine similarity histogram
        values = torch.cat(engine.state.metrics[f'cos_sim_{name}'], dim=0)
        writer.add_histogram(f'cos sim hist {name}', values, engine.state.epoch)
        writer.add_scalar(f'cosine sim median {name}', values.median(dim=0).values, engine.state.epoch)

        # Norm histogram
        values = torch.cat(engine.state.metrics[f'norm_{name}'], dim=0)
        writer.add_histogram(f'norm hist {name}', values, engine.state.epoch)
        writer.add_scalar(f'norm2 median {name}', values.median(dim=0).values, engine.state.epoch)

    # One handler per task; the task name is bound as a handler argument.
    for task_i in tasks:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, plot_distribution, task_i.name)


def decorate_evaluator(evaluator, tasks, args, writer):
    """Print the per-task validation metrics when ``evaluator`` completes a run.

    Each task's value is read from ``evaluator.state.metrics['metric_<name>']``
    and reported under the key ``'<name>_<metric function name>'``. Tasks whose
    metric key is absent are skipped. ``args`` and ``writer`` are unused here.
    """
    @evaluator.on(Events.COMPLETED)
    def log_metrics(evaluator):
        state_metrics = evaluator.state.metrics
        # Look each task up by key instead of pairing tasks with the metrics
        # dict by position, which mislabels entries if the dict holds anything
        # other than exactly two entries per task in attach order.
        metrics = {
            f'{t.name}_{t.metric.__name__}': state_metrics[f'metric_{t.name}']
            for t in tasks
            if f'metric_{t.name}' in state_metrics
        }

        print(f'Metrics: {metrics}')
