import functools
import warnings
from dataclasses import dataclass
from typing import Optional, Dict, Callable, Union, Iterable

import torch
from ignite.engine import Engine, Events
from ignite.handlers import Timer
from ignite.metrics import MetricsLambda, Accuracy, Average
from ignite.utils import convert_tensor
from tqdm import tqdm

from XXX.uib import information_quantities as iq
from XXX.uib.losses import very_approx_regularizers
from XXX.uib.modules.categorical_entropies_summarizer import CategoricalEntropiesSummarizer
from XXX.uib.modules.continuous_encoding_summarizer import (
    Estimate,
    ContinuousLatentLabelEntropiesSummarizer,
)

from XXX.uib.modules.summarizer import EntropySummarizer, IqBase
from experiments.datasets import DataLoaders
from experiments.dynamics.dynamics import Dynamics, Output
from experiments.utils import experiment_YYY
from experiments.utils.ignite_output import IgniteOutput
from experiments.utils.ignite_progress_bar import ignite_progress_bar
from experiments.utils.p_correct import PCorrect

import operator


@dataclass
class EarlyExitCriterion:
    """Condition under which a training run is terminated early.

    Consumed by `install_early_exit_criterion`: once the run has reached at least
    `epoch_percentage` percent of its epochs, training is terminated if the training
    accuracy is still below `training_accuracy_threshold`.
    """

    # Value the trainer's "accuracy" metric is compared against.
    training_accuracy_threshold: float
    # Progress through the run in percent; compared against `epoch * 100 / max_epochs`.
    epoch_percentage: int


def _prepare_batch(batch, device=None, non_blocking=False):
    """Unpack ``batch`` into ``(x, y)`` and pass both through ``convert_tensor``.

    ``device`` and ``non_blocking`` are forwarded unchanged to ``convert_tensor``.
    """
    inputs, targets = batch
    move = functools.partial(convert_tensor, device=device, non_blocking=non_blocking)
    return move(inputs), move(targets)


def no_grad_wrapper(func):
    """Return a wrapper around ``func`` that runs every call inside ``torch.no_grad()``.

    The wrapper keeps the metadata of ``func`` (via ``functools.wraps``) and returns
    whatever ``func`` returns.
    """

    def without_grad(*call_args, **call_kwargs):
        with torch.no_grad():
            result = func(*call_args, **call_kwargs)
        return result

    return functools.wraps(func)(without_grad)


def create_supervised_dynamics_trainer(
    dynamics, device=None, non_blocking=False, metrics=None, prepare_batch=_prepare_batch
):
    """
    Factory function for a trainer engine whose update step is ``dynamics.fit``.

    Args:
        dynamics: object exposing ``model`` (put into train mode at every iteration) and
            ``fit(x, y)``, whose result provides ``prediction``, ``latent``, ``loss`` and
            ``cross_entropies``.
        device (optional): forwarded to ``prepare_batch`` (default: None).
        non_blocking (bool, optional): forwarded to ``prepare_batch``.
        metrics (dict, optional): maps names to metrics; each metric is attached to the engine
            under its name.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking`
            and returns the pair `(x, y)`.

    Note: `engine.state.output` for this engine is an `IgniteOutput` built from the targets (moved to
        the device of the prediction), the prediction, the latent, `loss.item()` and the cross entropies.

    Returns:
        Engine: a trainer engine with the epoch duration metric and ``metrics`` attached.
    """

    def _train_step(engine, batch):
        dynamics.model.train()

        inputs, targets = prepare_batch(batch, device=device, non_blocking=non_blocking)
        fit_result: Output = dynamics.fit(inputs, targets)
        targets = targets.to(device=fit_result.prediction.device)

        return IgniteOutput(
            targets,
            fit_result.prediction,
            fit_result.latent,
            fit_result.loss.item(),
            fit_result.cross_entropies,
        )

    trainer = Engine(_train_step)
    add_epoch_duration(trainer)

    for metric_name, metric in (metrics or {}).items():
        metric.attach(trainer, metric_name)

    return trainer


def create_supervised_dynamics_evaluator(
    dynamics: Dynamics, train_dataloader, metrics=None, device=None, non_blocking=False, prepare_batch=_prepare_batch
):
    """
    Factory function for an evaluator engine whose inference step is ``dynamics.predict``.

    Args:
        dynamics: object exposing ``model``, ``pre_validation(dataloader)`` and ``predict(x, y)``.
        train_dataloader: handed to ``dynamics.pre_validation`` at the start of every epoch.
        metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics.
        device (optional): forwarded to ``prepare_batch`` (default: None).
        non_blocking (bool, optional): forwarded to ``prepare_batch``.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking`
            and returns the pair `(x, y)`.

    Note: `engine.state.output` for this engine is an `IgniteOutput` built from the targets (moved to
        the device of the prediction), the prediction, the latent, the loss and the cross entropies.
        The inference step runs under ``torch.no_grad()``.

    Returns:
        Engine: an evaluator engine with the epoch duration metric and ``metrics`` attached.
    """

    def _enter_validation(engine):
        # Runs on EPOCH_STARTED: switch the model to eval mode, then let the dynamics prepare.
        dynamics.model.eval()
        dynamics.pre_validation(train_dataloader)

    def _eval_step(engine, batch):
        with torch.no_grad():
            inputs, targets = prepare_batch(batch, device=device, non_blocking=non_blocking)
            predicted = dynamics.predict(inputs, targets)
            targets = targets.to(device=predicted.prediction.device)
            return IgniteOutput(
                targets,
                predicted.prediction,
                predicted.latent,
                predicted.loss,
                predicted.cross_entropies,
            )

    evaluator = Engine(_eval_step)

    add_epoch_duration(evaluator)
    evaluator.add_event_handler(Events.EPOCH_STARTED, _enter_validation)

    for metric_name, metric in (metrics or {}).items():
        metric.attach(evaluator, metric_name)

    return evaluator


def str_metric(metric):
    """Format one metric value: ``str()`` for an ``Estimate``, otherwise ``float`` with ``.4`` precision."""
    return str(metric) if isinstance(metric, Estimate) else f"{float(metric):.4}"


def str_metrics(metrics):
    """Render ``metrics`` as indented ``key: value`` lines, sorted by key.

    Keys containing ``"iq_base"`` are left out. Each line is indented by four spaces.
    """
    indent = "    "
    lines = []
    for key in sorted(metrics.keys()):
        if "iq_base" in key:
            continue
        lines.append(f"{key}: {str_metric(metrics[key])}")
    return indent + ("\n" + indent).join(lines)


def install_metric(engine, name, getter: Callable):
    """Wrap ``getter`` in a ``MetricsLambda`` and attach it to ``engine`` under ``name``."""
    MetricsLambda(getter).attach(engine, name)


def install_entropy_summarizer_metrics(summarizer: EntropySummarizer, engine, prefix, iqs: Dict[str, torch.Tensor]):
    """Expose the summarizer's information quantities as metrics on ``engine``.

    On every EPOCH_COMPLETED the result of ``summarizer.get_iq_base()`` is cached. The cached
    object is published as metric ``f"{prefix}iq_base"``, and for every ``name, iq`` pair in ``iqs``
    a metric ``f"{prefix}{name}"`` is installed that evaluates ``iq_base.get_iq_estimate(iq)``.
    """
    # Mutable holder shared by the epoch handler (writer) and the metric getters (readers).
    cache = {"iq_base": None}

    # Registered before the metrics below so the cache is refreshed first.
    @engine.on(Events.EPOCH_COMPLETED)
    def refresh_iq_base(_engine):
        cache["iq_base"] = summarizer.get_iq_base()

    def current_iq_base():
        return cache["iq_base"]

    install_metric(engine, f"{prefix}iq_base", current_iq_base)

    def make_estimate_getter(quantity):
        # The factory binds `quantity` per metric; the cache is read lazily at call time.
        def estimate():
            return cache["iq_base"].get_iq_estimate(quantity)

        return estimate

    for quantity_name, quantity in iqs.items():
        install_metric(engine, f"{prefix}{quantity_name}", make_estimate_getter(quantity))


def default_latent_getter(engine):
    """Return the ``z`` attribute of the engine's current ``state.output``."""
    return engine.state.output.z


def install_summarizer(summarizer, engine, *, stochastic, latent_getter: Callable = None):
    """Feed ``summarizer`` with the latents produced by ``engine``.

    The summarizer is reset on EPOCH_STARTED. After every iteration the latent is fetched with
    ``latent_getter()`` (defaults to ``default_latent_getter(engine)``). When ``stochastic`` is
    false a singleton axis is inserted at position 1. The latent is then flattened from axis 2
    onwards and passed to ``summarizer.fit`` together with ``engine.state.output.y``.
    A warning is emitted instead when the latent is ``None``.
    """
    if latent_getter is None:

        def latent_getter():
            return default_latent_getter(engine)

    @engine.on(Events.EPOCH_STARTED)
    def _reset_summarizer(_):
        summarizer.reset()

    @engine.on(Events.ITERATION_COMPLETED)
    def _feed_summarizer(_):
        step_output: IgniteOutput = engine.state.output
        z = latent_getter()
        if z is None:
            warnings.warn("update_summarizer but no Z available!")
            return

        if not stochastic:
            z = z[:, None, ...]
        summarizer.fit(z.flatten(2), step_output.y)


def get_trainer_epoch(trainer):
    """Return ``trainer.state.epoch``, or ``0`` while the trainer has no (truthy) state."""
    if not trainer.state:
        return 0
    return trainer.state.epoch


def log_training_pbar(trainer, log_interval):
    """Install a progress bar on ``trainer`` and write its metrics via ``tqdm.write`` after each epoch."""

    def _title(engine):
        return f"Training [{get_trainer_epoch(trainer)}]"

    ignite_progress_bar(trainer, _title, log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _write_training_summary(engine):
        epoch = get_trainer_epoch(trainer)
        summary = str_metrics(trainer.state.metrics)
        tqdm.write(f"Training Results - Epoch: {epoch}\n{summary}")


def log_evaluator_pbar(trainer, evaluator, title: str, log_interval):
    """Install a progress bar on ``evaluator`` and write its metrics via ``tqdm.write`` after each run.

    The epoch number shown in both places is taken from ``trainer``, not from ``evaluator``.
    """

    def _title(engine):
        return f"{title} [{get_trainer_epoch(trainer)}]"

    ignite_progress_bar(evaluator, _title, log_interval)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def _write_evaluation_summary(engine):
        epoch = get_trainer_epoch(trainer)
        summary = str_metrics(evaluator.state.metrics)
        tqdm.write(f"{title} Results - Epoch: {epoch}\n{summary}")


def ensure_float(dict):
    """Return a new mapping with the same keys and every value converted with ``float``."""
    # The parameter name shadows the builtin; it is kept so keyword callers keep working.
    converted = {}
    for key, value in dict.items():
        converted[key] = float(value)
    return converted


def install_early_exit_criterion(trainer: Engine, early_exit_criterion: EarlyExitCriterion, store):
    """Record the criterion in ``store`` and terminate ``trainer`` when it is hit.

    After every epoch the progress ``epoch * 100 / max_epochs`` is computed. If it has reached
    ``early_exit_criterion.epoch_percentage`` and the "accuracy" metric is below
    ``early_exit_criterion.training_accuracy_threshold``, the trainer is terminated and a
    message is printed.
    """
    store["early_exit_criterion"] = early_exit_criterion

    @trainer.on(Events.EPOCH_COMPLETED)
    def _terminate_if_underperforming(engine):
        state = trainer.state
        accuracy = state.metrics["accuracy"]
        progress_percent = state.epoch * 100 / state.max_epochs

        late_enough = progress_percent >= early_exit_criterion.epoch_percentage
        if late_enough and accuracy < early_exit_criterion.training_accuracy_threshold:
            trainer.terminate()
            print(f"Early exit criterion hit: {early_exit_criterion}!")


def store_metrics(engine: Engine, prefix: str, store):
    """Append ``engine.state.metrics`` after every epoch to the log created for ``prefix``.

    The log is obtained once from ``experiment_YYY.create_log_epochs(store, prefix)``.
    """
    epoch_log = experiment_YYY.create_log_epochs(store, prefix)

    @engine.on(Events.EPOCH_COMPLETED)
    def _append_epoch_metrics(completed_engine):
        epoch_log.append(completed_engine.state.metrics)


def UnwrapTensor(inner_metric):
    """Wrap ``inner_metric`` so that a ``torch.Tensor`` result is converted with ``.item()``.

    Results that are not tensors are passed through unchanged.
    """

    def _to_python(value: torch.Tensor):
        if isinstance(value, torch.Tensor):
            return value.item()
        return value

    return MetricsLambda(_to_python, inner_metric)


def create_default_metrics(
    *, incl_loss: bool, stochastic: bool, mean_l2_squared: bool, covariance_traces: bool, minibatch_entropies
):
    """Build the dict of metrics that is attached to trainer and evaluator engines.

    Always contains "accuracy", "correct_prob", "xe_decoder" and "xe_prediction". The flags add
    further groups: ``mean_l2_squared`` adds "mean_squared_Z", ``covariance_traces`` adds the
    "covariance_trace_*" metrics, ``minibatch_entropies`` adds the "mb_*" metrics and
    ``incl_loss`` adds "loss". ``stochastic`` is forwarded to the ``IgniteOutput`` /
    ``very_approx_regularizers`` factories that take it.

    Returns:
        dict: metric name -> metric, in the insertion order listed above.
    """

    def averaged(output_transform):
        return UnwrapTensor(Average(output_transform=output_transform))

    def averaged_trace(trace_fn):
        return averaged(IgniteOutput.get_covariance_trace(trace_fn, stochastic))

    def averaged_entropy(trace_fn):
        return averaged(IgniteOutput.get_entropy_estimate(trace_fn, stochastic))

    # NOTE: DONT EVER USE RunningAverage here!!!
    metrics = {
        "accuracy": Accuracy(IgniteOutput.get_y_pred_y),
        "correct_prob": UnwrapTensor(PCorrect()),
        "xe_decoder": averaged(IgniteOutput.get_decoder_cross_entropy),
        "xe_prediction": averaged(IgniteOutput.get_prediction_cross_entropy),
    }

    if mean_l2_squared:
        metrics["mean_squared_Z"] = averaged(IgniteOutput.get_mean_squared_z(stochastic))
        # "mean_squared_mean_Z__X": UnwrapTensor(
        #     Average(output_transform=IgniteOutput.get_mean_squared_mean_z_given_x(stochastic))
        # ),

    if covariance_traces:
        metrics["covariance_trace_z"] = averaged_trace(very_approx_regularizers.covariance_trace)
        metrics["covariance_trace_Z__X"] = averaged_trace(very_approx_regularizers.covariance_trace_given_X)
        metrics["covariance_trace_Z__Y"] = averaged_trace(very_approx_regularizers.covariance_trace_given_Y)
        metrics["covariance_trace_mean_Z__X"] = averaged_trace(very_approx_regularizers.covariance_trace_mean_by_X)

    if minibatch_entropies:
        metrics["mb_H_Z"] = averaged_entropy(very_approx_regularizers.covariance_trace)
        metrics["mb_global_H_Z__X"] = averaged_entropy(very_approx_regularizers.covariance_trace_given_X)
        metrics["mb_global_H_Z__Y"] = averaged_entropy(very_approx_regularizers.covariance_trace_given_Y)
        metrics["mb_H_mean_Z__X"] = averaged_entropy(very_approx_regularizers.covariance_trace_mean_by_X)
        metrics["mb_H_Z__X"] = averaged(
            IgniteOutput.call_z_y(very_approx_regularizers.estimate_entropy_Z__X(stochastic=stochastic))
        )
        metrics["mb_H_Z__Y"] = averaged(
            IgniteOutput.call_z_y(very_approx_regularizers.estimate_entropy_Z__Y(stochastic=stochastic))
        )

    if incl_loss:
        metrics["loss"] = averaged(IgniteOutput.get_loss)

    return metrics


def evaluate_after_training(trainer, evaluator, dataloader, seed):
    """Run ``evaluator`` on ``dataloader`` (with ``seed``) every time ``trainer`` completes an epoch."""

    def _run_evaluator(engine):
        evaluator.run(dataloader, seed=seed)

    trainer.on(Events.EPOCH_COMPLETED)(_run_evaluator)


def interpret_summarizer_arg(eval_type, summarizer_arg):
    """Resolve a summarizer setting for one evaluator type.

    A dict is looked up by ``eval_type`` (a missing key raises ``KeyError``); any other value
    is returned as is.
    """
    return summarizer_arg[eval_type] if isinstance(summarizer_arg, dict) else summarizer_arg


def run_common_experiment(
    *,
    dynamics,
    dataloaders: DataLoaders,
    max_epochs,
    in_capacity,
    out_capacity,
    log_interval,
    store,
    seed,
    stochastic=False,
    discrete_summary=False,
    continuous_summary: Union[bool, dict] = False,
    device=None,
    mean_l2_squared=False,
    covariant_traces=False,
    minibatch_entropies=False,
    early_exit_criterion: EarlyExitCriterion = None,
    extra_hooks: Callable = None,
    train_eval: bool = False,
    zero_eval: bool = False,
    validation_lr_schedulers=None,
    train_lr_schedulers=None,
    extra_eval_dynamics: Optional[Dict[str, Dynamics]] = None,
    test_eval: bool = True,
):
    """Create trainer and evaluators for ``dynamics``, wire them together and run training.

    Args:
        dynamics: training dynamics handed to the trainer and all standard evaluators.
        dataloaders: provides ``train``, ``train_eval``, ``validation`` and ``test``.
        max_epochs: number of epochs passed to ``trainer.run``.
        in_capacity, out_capacity: forwarded to ``install_common_summarizers``.
        log_interval: forwarded to the progress bar helpers.
        store: receives the per-epoch metric logs (and the early exit criterion, if any).
        seed: base seed; the trainer and the train evaluator run with ``seed + 1``, the
            validator with ``seed + 2`` and the test evaluator with ``seed + 3``.
        stochastic, mean_l2_squared, covariant_traces, minibatch_entropies: forwarded to the
            metric/summarizer factories (``covariant_traces`` as ``covariance_traces``).
        discrete_summary: install the discrete summarizer on every evaluator.
        continuous_summary: setting for the continuous summarizer; a dict is resolved per
            evaluator through ``interpret_summarizer_arg``.
        device: forwarded to the engine factories.
        early_exit_criterion: if given, installed on the trainer.
        extra_hooks: if given, called with ``trainer``, ``evaluator``, ``train_evaluator`` and
            ``validator`` as keyword arguments before schedulers are installed.
        train_eval: also create an evaluator for ``dataloaders.train_eval``.
        zero_eval: run the evaluators once before training starts.
        validation_lr_schedulers: schedulers stepped by the validator; must be empty when
            there is no validator.
        train_lr_schedulers: schedulers stepped by the trainer.
        extra_eval_dynamics: name -> dynamics for which additional evaluators are created.
        test_eval: evaluate on the test set after every training epoch.
    """
    extra_eval_dynamics = extra_eval_dynamics or {}

    trainer, evaluator, train_evaluator, validator = create_trainer_evaluator(
        dataloaders,
        dynamics,
        stochastic,
        mean_l2_squared=mean_l2_squared,
        covariance_traces=covariant_traces,
        minibatch_entropies=minibatch_entropies,
        device=device,
        train_eval=train_eval,
    )

    def create_extra_evaluators(dataloader):
        # One evaluator per extra dynamics, each with its own set of default metrics.
        return {
            name: create_supervised_dynamics_evaluator(
                extra_dynamics,
                dataloader,
                device=device,
                metrics=create_default_metrics(
                    incl_loss=True,
                    stochastic=stochastic,
                    mean_l2_squared=mean_l2_squared,
                    covariance_traces=covariant_traces,
                    minibatch_entropies=minibatch_entropies,
                ),
            )
            for name, extra_dynamics in extra_eval_dynamics.items()
        }

    def install_summarizers(engine, eval_type):
        install_common_summarizers(
            engine,
            in_capacity,
            out_capacity,
            discrete_summary=discrete_summary,
            continuous_summary=interpret_summarizer_arg(eval_type, continuous_summary),
            stochastic=stochastic,
        )

    # TODO: add extra evals for the validation set as well...
    # NOTE(review): the extra evaluators get summarizers and metric logs below, but nothing in
    # this function runs them (and `extra_hooks` does not receive them) - confirm where they
    # are meant to be executed.
    extra_test_evals = create_extra_evaluators(dataloaders.test)
    extra_train_evals = create_extra_evaluators(dataloaders.train_eval) if train_eval else {}

    if early_exit_criterion:
        install_early_exit_criterion(trainer, early_exit_criterion, store)

    install_summarizers(evaluator, "evaluator")

    # `.items()` is required: iterating the dict itself yields only the names.
    for name, extra_eval in extra_test_evals.items():
        install_summarizers(extra_eval, f"{name}_evaluator")

    if validator:
        install_summarizers(validator, "validator")

    if train_evaluator:
        install_summarizers(train_evaluator, "train_evaluator")

        # The train-eval variants live in `extra_train_evals` (not `extra_test_evals`).
        for name, extra_eval in extra_train_evals.items():
            install_summarizers(extra_eval, f"{name}_train_evaluator")

    if extra_hooks:
        extra_hooks(trainer=trainer, evaluator=evaluator, train_evaluator=train_evaluator, validator=validator)

    install_schedulers(trainer, train_lr_schedulers)

    if validator:
        install_schedulers(validator, validation_lr_schedulers)
    else:
        assert not validation_lr_schedulers, "validation_lr_schedulers given but there is no validation set"

    store_metrics(trainer, "training", store)

    store_metrics(evaluator, "test", store)
    for name, extra_eval in extra_test_evals.items():
        store_metrics(extra_eval, f"{name}_test", store)

    if train_evaluator:
        store_metrics(train_evaluator, "train_eval", store)
        for name, extra_eval in extra_train_evals.items():
            store_metrics(extra_eval, f"{name}_train_eval", store)

    if validator:
        store_metrics(validator, "validation", store)

    # Install progress bars.
    log_training_pbar(trainer, log_interval)
    log_evaluator_pbar(trainer, evaluator, "Test Set", log_interval)
    if validator:
        log_evaluator_pbar(trainer, validator, "Validation Set", log_interval)
    if train_evaluator:
        log_evaluator_pbar(trainer, train_evaluator, "Training Set Eval", log_interval)

    # Set up execution chain.
    if train_evaluator:
        evaluate_after_training(trainer, train_evaluator, dataloaders.train_eval, seed + 1)

    if validator:
        evaluate_after_training(trainer, validator, dataloaders.validation, seed + 2)

    if test_eval:
        evaluate_after_training(trainer, evaluator, dataloaders.test, seed + 3)

    if zero_eval:
        if train_evaluator:
            train_evaluator.run(dataloaders.train_eval, seed=seed + 1)

        if validator:
            validator.run(dataloaders.validation, seed=seed + 2)

        evaluator.run(dataloaders.test, seed=seed + 3)

    trainer.run(dataloaders.train, max_epochs=max_epochs, seed=seed + 1)


def install_schedulers(engine, schedulers):
    """Install one scheduler or an iterable of schedulers on ``engine``.

    Each scheduler is handed to ``lr_step_after_epoch``. A falsy non-iterable value
    (e.g. ``None``) installs nothing.
    """
    if not isinstance(schedulers, Iterable):
        schedulers = [schedulers] if schedulers else []

    for scheduler in schedulers:
        lr_step_after_epoch(engine, scheduler)


def create_trainer_evaluator(
    dataloaders: DataLoaders,
    dynamics,
    stochastic,
    minibatch_entropies,
    covariance_traces,
    mean_l2_squared,
    device=None,
    train_eval: bool = False,
):
    """Create the trainer and the standard evaluators for ``dynamics``.

    Every engine gets its own freshly created set of default metrics (including the loss).

    Returns:
        tuple: ``(trainer, evaluator, train_evaluator, validator)``. ``train_evaluator`` is
        ``None`` unless ``train_eval`` is set, and ``validator`` is ``None`` when
        ``dataloaders.validation`` is falsy.
    """

    def default_metrics():
        return create_default_metrics(
            incl_loss=True,
            stochastic=stochastic,
            mean_l2_squared=mean_l2_squared,
            covariance_traces=covariance_traces,
            minibatch_entropies=minibatch_entropies,
        )

    def evaluator_for(pre_validation_dataloader):
        return create_supervised_dynamics_evaluator(
            dynamics, pre_validation_dataloader, device=device, metrics=default_metrics()
        )

    trainer = create_supervised_dynamics_trainer(dynamics, device=device, metrics=default_metrics())

    # NOTE(review): this evaluator is handed `dataloaders.train` for `pre_validation`, while the
    # two evaluators below are handed the split named after them - confirm this is intended.
    evaluator = evaluator_for(dataloaders.train)
    train_evaluator = evaluator_for(dataloaders.train_eval) if train_eval else None
    validator = evaluator_for(dataloaders.validation) if dataloaders.validation else None

    return trainer, evaluator, train_evaluator, validator


def add_epoch_duration(trainer):
    """Attach a ``Timer`` to ``trainer`` and publish its value as metric "epoch_duration".

    The timer starts on EPOCH_STARTED, resumes on ITERATION_STARTED, and pauses and steps on
    ITERATION_COMPLETED. It is reset on EPOCH_COMPLETED by a handler registered after the metric.
    """
    epoch_timer = Timer()
    timer_events = {
        "start": Events.EPOCH_STARTED,
        "resume": Events.ITERATION_STARTED,
        "pause": Events.ITERATION_COMPLETED,
        "step": Events.ITERATION_COMPLETED,
    }
    epoch_timer.attach(trainer, **timer_events)

    # Keep this order: the metric is installed before the reset handler is registered.
    install_metric(trainer, "epoch_duration", epoch_timer.value)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, epoch_timer.reset)


def install_common_summarizers(
    evaluator,
    in_capacity,
    out_capacity,
    *,
    discrete_summary: bool,
    continuous_summary: Union[bool, dict],
    stochastic: bool,
):
    """Install the discrete and/or continuous entropy summarizers on ``evaluator``.

    ``discrete_summary`` enables a ``CategoricalEntropiesSummarizer`` (metric prefix
    "discrete_"). Any ``continuous_summary`` other than the literal ``False`` enables a
    ``ContinuousLatentLabelEntropiesSummarizer`` (prefix "continuous_"); a dict is passed
    to its constructor as keyword arguments.
    """
    if discrete_summary:
        install_common_summarizer_iqs(
            CategoricalEntropiesSummarizer(in_capacity, out_capacity),
            evaluator,
            stochastic=stochastic,
            prefix="discrete_",
        )

    if continuous_summary is False:
        return

    # TODO: merge the two classes...
    summarizer_kwargs = continuous_summary if isinstance(continuous_summary, dict) else {}
    install_common_summarizer_iqs(
        ContinuousLatentLabelEntropiesSummarizer(**summarizer_kwargs),
        evaluator,
        stochastic=stochastic,
        prefix="continuous_",
    )


def install_common_summarizer_iqs(summarizer, trainer, *, stochastic: bool, prefix=""):
    """Hook ``summarizer`` into ``trainer`` and publish the common information quantities.

    The quantities are looked up on the ``iq`` module and installed as metrics named
    ``f"{prefix}{quantity_name}"``.
    """
    quantity_names = (
        "decoder_uncertainty",
        "encoding_entropy",
        "preserved_information",
        "H_Z__X",
        "H_Z__Y",
        "H_Z",
    )

    install_summarizer(summarizer, trainer, stochastic=stochastic)

    quantities = {name: getattr(iq, name) for name in quantity_names}
    install_entropy_summarizer_metrics(summarizer, trainer, prefix, quantities)


def lr_step_after_epoch(trainer, scheduler):
    """Register a handler on ``trainer`` that steps ``scheduler``.

    How and when the scheduler is stepped depends on its type:

    * ``ReduceLROnPlateauWrapper``: after every epoch, with the engine itself.
    * plain ``ReduceLROnPlateau`` (deprecated here, emits a ``DeprecationWarning``): after every
      epoch, with ``engine.state.output.loss``.
    * ``CosineAnnealingWarmRestarts``: after every fetched batch, with the fractional epoch.
    * anything else: after every epoch, via ``scheduler.step()`` without arguments.
    """
    # The wrapper subclasses ReduceLROnPlateau, so it has to be checked first.
    if isinstance(scheduler, ReduceLROnPlateauWrapper):

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_step(engine):
            scheduler.step(engine)

    elif isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        warnings.warn(
            "Use ReduceLROnPlateauWrapper instead of torch.optim.lr_scheduler.ReduceLROnPlateau!", DeprecationWarning
        )

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_step(engine):
            scheduler.step(engine.state.output.loss)

    elif isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):

        @trainer.on(Events.GET_BATCH_COMPLETED)
        def lr_step(engine):
            # The scheduler expects the fractional number of epochs elapsed so far. ignite's
            # `state.iteration` counts iterations over the whole run (it is not reset per
            # epoch) and `state.epoch` is 1-based, so `iteration / epoch_length` already is
            # that value; adding `state.epoch` on top would count every whole epoch twice.
            scheduler.step(engine.state.iteration / engine.state.epoch_length)

    else:

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_step(engine):
            scheduler.step()


class ReduceLROnPlateauWrapper(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """``ReduceLROnPlateau`` that extracts the monitored value from an ignite engine.

    Exactly like the base scheduler, except that ``step`` takes the engine and derives the
    monitored value from it: from ``engine.state.output`` via ``output_transform`` if that is
    set, otherwise from ``engine.state.metrics`` via ``metrics_transform``. At least one of
    the two transforms has to be provided.
    """

    output_transform: Optional[Callable]
    metrics_transform: Optional[Callable]

    def __init__(
        self,
        optimizer,
        *,
        output_transform: Optional[Callable] = None,
        metrics_transform: Optional[Callable] = None,
        mode="min",
        factor=0.1,
        patience=10,
        verbose=False,
        threshold=1e-4,
        threshold_mode="rel",
        cooldown=0,
        min_lr=0,
        eps=1e-8,
    ):
        # Forward by keyword: the positional order of the base class parameters is not stable
        # across torch releases (`verbose` was deprecated and moved), so positional forwarding
        # can silently assign values to the wrong parameter.
        scheduler_kwargs = dict(
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
        )
        if verbose:
            # Only forwarded when requested, so the default keeps working on torch releases
            # that deprecate or no longer accept the argument.
            scheduler_kwargs["verbose"] = verbose
        super().__init__(optimizer, **scheduler_kwargs)

        self.metrics_transform = metrics_transform
        self.output_transform = output_transform
        assert (
            self.metrics_transform or self.output_transform
        ), "either output_transform or metrics_transform is required"

    # noinspection PyMethodOverriding
    def step(self, engine: "Engine"):
        """Step the scheduler with the value extracted from ``engine``.

        ``output_transform`` takes precedence over ``metrics_transform``.

        Raises:
            NotImplementedError: if neither transform is set.
        """
        if self.output_transform:
            super().step(self.output_transform(engine.state.output))
        elif self.metrics_transform:
            super().step(self.metrics_transform(engine.state.metrics))
        else:
            raise NotImplementedError()
