import collections.abc
import functools
from typing import Type, Union

import torch
import torch.utils.data
from sacred import Ingredient

from data.dataset import collate_fn
from optim.loss import TorchLossWrapper, Loss
from optim.trainer import Trainer, CheckpointHook, EarlyStoppingHook
from utils.torch_utils import run_deterministic, run_fast
from utils.utils import objspec2constructor
from .experiment_functions import make_experiment_tempfile, open_artifact_from_run

training_ingredient = Ingredient('training')


@training_ingredient.config
def config():
    # Default values for the 'training' ingredient. Every name assigned below matches a
    # parameter of the captured train_model function in this module (get_dataloader uses a
    # subset of them).
    # NOTE(review): presumably every local assigned in a sacred config scope becomes a config
    # entry, so avoid introducing helper variables here -- confirm against the sacred docs.

    # General Training parameters

    # Optimizer spec ('class' plus its 'args'); train_model resolves it with objspec2constructor.
    optimizer = {
        'class': torch.optim.Adam,
        'args': dict(
            lr=1e-3,
            # Note: setting weight decay with Adam slows down the optimizer considerably after a
            # certain number of steps on CPU
            # see https://github.com/pytorch/pytorch/issues/58220
            weight_decay=0
        )
    }
    # Batch size used for the DataLoaders built in this module.
    batch_size = 128
    # Passed to collate_fn and to the trainer as batch_dimension.
    batch_dim = 0
    # Number of DataLoader worker processes.
    num_workers = 0
    # Learning-rate scheduler spec, same 'class'/'args' layout as the optimizer.
    scheduler = {
        'class': torch.optim.lr_scheduler.MultiStepLR,
        'args': dict(
            milestones=[20],
            gamma=0.1,
        )
    }
    # Epoch count handed to trainer.train.
    epochs = 70
    # Device handed to the trainer.
    device = 'cpu'
    loss = torch.nn.MSELoss  # can also be a list of different loss classes, if you need more than one loss
    # Trainer spec; resolved with objspec2constructor and then called to build the trainer.
    trainer = Trainer
    trainer_hooks = [
        # A list of tuples. First element should be the event, second an objspec.
        # ('post_validation', EarlyStoppingHook)
    ]
    # Forwarded to CheckpointHook as its checkpoint_interval.
    checkpoint_interval = 15
    # True -> train_model calls run_deterministic(), False -> run_fast().
    deterministic = False


def instantiate_loss(loss: Union[str, Loss, Type[Loss], torch.nn.modules.loss._Loss, Type[torch.nn.modules.loss._Loss]]) \
        -> Loss:
    """Turn a loss specification into a ready-to-use ``Loss`` object.

    :param loss: Either an already constructed loss (a ``Loss`` or a PyTorch ``_Loss``
        instance) or anything ``objspec2constructor`` accepts, which is then called
        without arguments to build the loss.
    :return: The loss itself if it is a ``Loss``; otherwise the loss wrapped in a
        ``TorchLossWrapper``.
    """
    # Instances are used as they are; everything else is treated as a spec and built first.
    is_instance = isinstance(loss, (Loss, torch.nn.modules.loss._Loss))
    built = loss if is_instance else objspec2constructor(loss)()

    if isinstance(built, Loss):
        return built

    # Wrap the standard pytorch classes to support multiple inputs
    return TorchLossWrapper(built)


@training_ingredient.capture
def train_model(_run, model, train_ds, val_ds, optimizer, batch_size, batch_dim, num_workers, scheduler, epochs, device,
                loss, trainer, trainer_hooks, checkpoint_interval, deterministic):
    """Build the data loaders, trainer, hooks and losses, then train ``model``.

    :param _run: Run object; its ``log_scalar`` is used as the trainer's log function and it
        is bound as ``run`` to the checkpoint read/write helpers.
    :param model: The model handed to ``trainer.train``.
    :param train_ds: Dataset for the (shuffled) training loader.
    :param val_ds: Dataset for the (unshuffled) validation loader.
    :param optimizer: Optimizer spec, resolved with ``objspec2constructor``.
    :param batch_size: Batch size of both data loaders.
    :param batch_dim: Passed to ``collate_fn`` and to the trainer as ``batch_dimension``.
    :param num_workers: Number of worker processes of both data loaders.
    :param scheduler: Scheduler spec, resolved with ``objspec2constructor``.
    :param epochs: Epoch count handed to ``trainer.train``.
    :param device: Device handed to the trainer.
    :param loss: A single loss spec/instance or a sequence of them (see ``instantiate_loss``).
        A string is always treated as one single spec.
    :param trainer: Trainer spec, resolved with ``objspec2constructor``.
    :param trainer_hooks: Iterable of ``(event, hook_spec)`` pairs; each hook is constructed
        without arguments and registered for its event.
    :param checkpoint_interval: Forwarded to ``CheckpointHook``.
    :param deterministic: If true ``run_deterministic()`` is called, otherwise ``run_fast()``.
    :return: The trainer after training has finished.
    """
    if deterministic:
        run_deterministic()
    else:
        run_fast()

    # Load training data
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, shuffle=True,
                                               collate_fn=collate_fn(batch_dim))
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers, shuffle=False,
                                             collate_fn=collate_fn(batch_dim))

    optimizer = objspec2constructor(optimizer)
    scheduler = objspec2constructor(scheduler)
    trainer = objspec2constructor(trainer)

    # The trainer's own checkpointing is switched off; the hook below takes care of it.
    trainer = trainer(train_loader, val_loader, optimizer, scheduler, device=device, checkpoints=False,
                      batch_dimension=batch_dim)

    # Checkpoint hook on the 'post_validation' event, configured with checkpoint_interval
    checkpoints = CheckpointHook(checkpoint_interval=checkpoint_interval,
                                 file_write_fn=functools.partial(make_experiment_tempfile, run=_run),
                                 file_read_fn=functools.partial(open_artifact_from_run, run=_run))
    trainer.add_hook(checkpoints, 'post_validation')

    # Add additional hooks
    for event, hook in trainer_hooks:
        trainer.add_hook(objspec2constructor(hook)(), event)

    # A str is a Sequence too, but here it is one single loss spec and must not be
    # iterated character by character.
    if isinstance(loss, str) or not isinstance(loss, collections.abc.Sequence):
        loss = [loss]

    # Create the loss objects
    loss = [instantiate_loss(loss_spec) for loss_spec in loss]

    trainer.train(model, loss, epochs, log_fn=_run.log_scalar)

    return trainer


@training_ingredient.capture
def get_dataloader(dataset, batch_size, num_workers, batch_dim):
    """Create a shuffling ``DataLoader`` for ``dataset``.

    :param dataset: The dataset to load from.
    :param batch_size: Batch size of the loader.
    :param num_workers: Number of worker processes of the loader.
    :param batch_dim: Passed to ``collate_fn`` to build the loader's collate function.
    :return: The configured ``torch.utils.data.DataLoader``.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=collate_fn(batch_dim),
    )
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)

