import os
import copy
import wandb
from spaghettini import quick_register, Configurable
from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.utils.misc import stdlog, is_scalar
from src.dl.fixed_point_solvers.fixed_point_iterator import fixed_point_iterator


class LoggingFunction(ABC):
    """Abstract base class for the logging callables used by the training system.

    Every keyword argument passed at construction time is kept in ``self.initial_kwargs``
    together with ``called_during`` so that concrete loggers can read their configuration
    from there.
    """

    def __init__(self, called_during="training", **kwargs):
        self.initial_kwargs = {**kwargs, "called_during": called_during}

    @abstractmethod
    def __call__(self, metric_logs, pl_system, **kwargs):
        """Update ``metric_logs`` from ``pl_system`` and the call-time ``kwargs``.

        Concrete loggers in this module return the (possibly updated) ``metric_logs``.
        """


@quick_register
class LogGenericTrainingState(LoggingFunction):
    """Record generic trainer state (currently the epoch index, as a float) in ``metric_logs``."""
    def __call__(self, metric_logs, pl_system, **kwargs):
        metric_logs.update(epoch=float(pl_system.current_epoch))
        return metric_logs


@quick_register
class LogOptimizerStats(LoggingFunction):
    """Log the current learning rate under ``metric_logs["learning_rate"]``.

    The value is read from the first parameter group of the (first) optimizer returned by
    ``pl_system.optimizers()``. Logging is skipped (with a message) when no optimizer has
    been loaded yet.
    """
    def __call__(self, metric_logs, pl_system, **kwargs):
        opt = pl_system.optimizers()
        # ``optimizers()`` returns a single optimizer, an empty list when none is loaded yet,
        # or a list when several optimizers are configured. The old ``opt == list()`` check
        # only handled the empty case and crashed on ``.param_groups`` for a non-empty list.
        if isinstance(opt, (list, tuple)):
            if len(opt) == 0:  # Optimizers have not been loaded.
                stdlog(__file__, "Optimizer is not loaded. Skipping optimizer stats logging. ")
                return metric_logs
            opt = opt[0]

        metric_logs["learning_rate"] = opt.param_groups[0]["lr"]

        return metric_logs


@quick_register
class LogClassificationMetrics(LoggingFunction):
    """Log the classification error and a breakdown of the loss.

    Reads ``preds`` (class scores along dim 1), ``ys`` and ``total_loss`` from the call kwargs.
    The classification loss is recomputed with ``pl_system.loss_fn``. Whatever part of
    ``total_loss`` it does not account for is logged as ``loss/regularization``.
    """
    def __call__(self, metric_logs, pl_system, **kwargs):
        preds, ys, total_loss = (kwargs[name] for name in ("preds", "ys", "total_loss"))

        hits = preds.argmax(dim=1) == ys
        metric_logs["error"] = 1.0 - float(hits.float().mean())

        # Loss breakdown.
        cls_loss_value = float(pl_system.loss_fn(preds, ys))
        metric_logs["loss/classification"] = cls_loss_value
        total_value = float(total_loss)
        metric_logs["loss/total"] = total_value
        metric_logs["loss/regularization"] = total_value - cls_loss_value

        return metric_logs


@quick_register
class LogLossMetrics(LoggingFunction):
    """Log ``kwargs["total_loss"]`` (converted to float) under ``"loss/total"``."""
    def __call__(self, metric_logs, pl_system, **kwargs):
        metric_logs.update({"loss/total": float(kwargs["total_loss"])})
        return metric_logs


@quick_register
class LogMetricModelLogs(LoggingFunction):
    """Copy the scalar entries of ``kwargs["model_logs"]`` into ``metric_logs``.

    Each accepted value is converted to float and stored under ``model_logs_<key>``.
    Entries for which ``is_scalar`` is false are ignored.
    """
    def __call__(self, metric_logs, pl_system, **kwargs):
        scalar_entries = (
            (name, value) for name, value in kwargs["model_logs"].items() if is_scalar(value)
        )
        for name, value in scalar_entries:
            metric_logs[self.append_model_logs_prefix_to_key(name)] = float(value)
        return metric_logs

    @staticmethod
    def append_model_logs_prefix_to_key(key):
        """Return ``key`` prefixed with ``"model_logs_"``."""
        prefix = "model_logs_"
        return f"{prefix}{key}"


@quick_register
class LogInAndOODSplits(LoggingFunction):
    """Record (as strings) which splits are in-distribution and which are out-of-distribution.

    NOTE(review): the values are written into ``kwargs["model_logs"]``, not ``metric_logs``
    (which is returned unchanged) — confirm this is intended.
    """
    _SPLIT_ATTRIBUTES = (
        ("in_dist_splits", "in_distribution_splits"),
        ("ood_splits", "ood_splits"),
    )

    def __call__(self, metric_logs, pl_system, **kwargs):
        target_logs = kwargs["model_logs"]
        for log_key, attribute_name in self._SPLIT_ATTRIBUTES:
            target_logs[log_key] = str(getattr(pl_system, attribute_name))
        return metric_logs


@quick_register
class LogAccuracyWithDifferentForwardDepths(LoggingFunction):
    """Log the sequence-level error obtained with different numbers of forward-solver iterations.

    ``initial_kwargs`` must contain ``all_forward_iters``, the iteration counts to evaluate.
    Call kwargs read: ``prepend_key``, ``batch_idx`` (validation only), optional
    ``dataloader_idx``, ``global_step``, and optionally ``save_locally`` / ``save_dir``.

    For each validation loader, the first batch is taken. The model is run once per entry of
    ``all_forward_iters``. Each error is written into ``metric_logs`` under
    ``seq_error/seq_error_with_{iters}_forward_steps_length_{length}``. A summary plot is
    logged to wandb and, if requested, saved to ``save_dir``.

    The forward solver, the model-kwargs getter and the train/eval mode of ``pl_system`` are
    restored afterwards, even if the evaluation raises.
    """

    def __call__(self, metric_logs, pl_system, **kwargs):
        if not self._should_run(kwargs):
            return metric_logs
        print(f"Running logger {self.__class__.__name__}")

        # Save the state of modified attributes, so that we can revert the modification later.
        original_forward_solver = copy.deepcopy(pl_system.model.forward_solver)
        original_model_kwargs_getter = copy.deepcopy(pl_system.model_kwargs_getter)
        original_training_mode = pl_system.training

        pl_system.eval()
        try:
            results, batch_size = self._evaluate_loaders(metric_logs, pl_system)
        finally:
            pl_system.model.forward_solver = original_forward_solver
            pl_system.model_kwargs_getter = original_model_kwargs_getter
            # ``train(mode)`` recurses into every submodule. Assigning ``pl_system.training``
            # directly would only reset the top-level flag and leave the children (dropout,
            # batch norm, ...) stuck in the eval mode set by ``eval()`` above.
            pl_system.train(original_training_mode)

        if results:
            self._log_plot(results, batch_size, kwargs)

        return metric_logs

    @staticmethod
    def _should_run(kwargs):
        """Decide from the call kwargs whether this logger should run for this batch."""
        prepend_key = kwargs["prepend_key"]
        if "test" in prepend_key:
            # NOTE(review): unlike validation, this runs on every test batch / dataloader.
            return True
        if "validation" not in prepend_key or kwargs["batch_idx"] != 0:
            return False
        return kwargs.get("dataloader_idx", 0) == 0

    @staticmethod
    def _as_loader_list(valid_loader):
        """Normalize ``valid_loader`` (one loader or a list/tuple of loaders) to a list."""
        # A DataLoader is itself ``Iterable``, so an ``Iterable`` check cannot distinguish a
        # single loader from a collection of loaders (it would iterate over batches instead).
        if isinstance(valid_loader, (list, tuple)):
            return list(valid_loader)
        return [valid_loader]

    @staticmethod
    def _iteration_kwarg_name(solver_kwargs):
        """Return the forward-solver kwarg that controls the number of iterations."""
        if "threshold" in solver_kwargs:
            return "threshold"
        if "num_iters" in solver_kwargs:
            return "num_iters"
        raise ValueError("Forward solver has neither 'threshold' nor 'num_iters' in its initial_kwargs.")

    @staticmethod
    def _sequence_error(outs, ys):
        """Fraction of examples whose predicted sequence is not entirely correct.

        Assumes class scores lie along dim 1 of ``outs`` and positions along the last dim of
        ``ys`` (as implied by the argmax/mean axes below).
        """
        mean_correctness_per_example = (torch.argmax(outs, dim=1) == ys).float().mean(dim=-1)
        mean_correctness_along_examples = (mean_correctness_per_example == 1.).float().mean()
        return float(1. - mean_correctness_along_examples)

    @staticmethod
    def _figure_to_rgb_array(fig):
        """Render ``fig`` and return its pixels as a contiguous (height, width, 3) uint8 array."""
        fig.canvas.draw()
        # ``buffer_rgba`` replaces ``tostring_rgb`` (deprecated in matplotlib 3.8, removed in 3.10)
        # and already has the (height, width, 4) shape; copy so the result outlives the canvas.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])

    def _evaluate_loaders(self, metric_logs, pl_system):
        """Run the depth sweep on one batch per loader.

        Fills ``metric_logs`` and returns ``({length: {num_iters: error}}, batch_size)``, where
        ``batch_size`` is that of the last evaluated batch (``None`` if all loaders are empty).
        """
        all_forward_iters = self.initial_kwargs["all_forward_iters"]
        results = dict()
        batch_size = None
        for curr_loader in self._as_loader_list(pl_system.valid_loader):
            # Take the first batch from each validation loader; skip empty loaders.
            batch = next(iter(curr_loader), None)
            if batch is None:
                continue
            curr_xs, curr_ys = batch
            curr_xs = pl_system.reconcile_input_and_model_types(tensor=curr_xs)
            curr_ys = pl_system.reconcile_input_and_model_types(tensor=curr_ys)
            batch_size = curr_xs.shape[0]

            curr_length = curr_xs.shape[-1]
            print(f"Processing batch with length {curr_length}")

            # Loop-invariant solver set-up (reverted by the caller).
            pl_system.model_kwargs_getter.initial_kwargs['num_pretraining_steps'] = 0
            solver_kwargs = pl_system.model.forward_solver.initial_kwargs
            iters_key = self._iteration_kwarg_name(solver_kwargs)

            results_per_loader = dict()
            for num_forward_iter in all_forward_iters:
                solver_kwargs[iters_key] = num_forward_iter
                with torch.no_grad():
                    outs, _ = pl_system.forward(inputs=curr_xs)

                error = self._sequence_error(outs, curr_ys)
                metric_logs[f"seq_error/seq_error_with_{num_forward_iter}_forward_steps_length_{curr_length}"] = error
                results_per_loader[num_forward_iter] = error
            # NOTE(review): loaders sharing the same sequence length overwrite each other here.
            results[curr_length] = results_per_loader
        return results, batch_size

    def _log_plot(self, results, batch_size, kwargs):
        """Plot error vs. number of forward iterations (one line per length) and log it to wandb."""
        fig, axs = plt.subplots(nrows=1, ncols=1, squeeze=False, figsize=(10, 10))
        try:
            ax = axs[0, 0]
            for length, results_per_length in results.items():
                points = sorted(results_per_length.items())
                ax.plot([k for k, _ in points], [v for _, v in points], "o-", label=f"length: {length}")
            ax.set_xlabel("num_forward_iters")
            ax.set_ylabel("error")
            ax.set_title(f"Error vs. Number of Forward Iters\nBatch Size: {batch_size}\nGlobal Step: {kwargs['global_step']}")
            ax.legend()

            # If requested, save.
            if kwargs.get("save_locally"):
                plot_name = "accuracy_per_number_of_forward_iterations.png"
                fig.savefig(os.path.join(kwargs["save_dir"], plot_name))

            wandb.log({f"{kwargs['prepend_key']}accuracy_per_num_forward_iters": wandb.Image(self._figure_to_rgb_array(fig)),
                       "global_step": kwargs["global_step"]})
        finally:
            plt.close(fig)


@quick_register
class LogEvolutionOfJacobianNorm(LoggingFunction):
    """Log how the sensitivity of the logits to the initial solver state evolves with unroll depth.

    ``initial_kwargs`` must contain ``num_examples``, ``forward_iters``, ``sequence_positions``,
    ``num_classes`` and ``z_init_dim``.

    Runs only for ``batch_idx == 0`` (and ``dataloader_idx == 0`` if given) of validation calls.
    For each validation loader, the first ``num_examples`` examples of its first batch are used.
    For each example and each count in ``forward_iters``, the model is unrolled with a
    fixed-point iterator that starts from an all-zeros external state ``external_zs0``. The
    gradients of the logits at every (class, position) pair w.r.t. ``external_zs0`` are stacked.

    Their Frobenius norm is plotted against the number of unroll steps, with one subplot per
    example and one line per loader. The figure is logged to wandb as ``sensitivity_plot``.

    Modified model attributes are restored and parameter gradients cleared afterwards, even if
    an exception is raised.
    """

    def __call__(self, metric_logs, pl_system, **kwargs):
        if "validation" not in kwargs["prepend_key"] or kwargs["batch_idx"] != 0:
            return metric_logs
        if kwargs.get("dataloader_idx", 0) != 0:
            return metric_logs
        print(f"Running logger {self.__class__.__name__}")

        num_examples = self.initial_kwargs["num_examples"]
        num_forward_iters = self.initial_kwargs["forward_iters"]

        # Save the state of modified attributes, so that we can revert the modification later.
        original_forward_solver = copy.deepcopy(pl_system.model.forward_solver)
        original_backward_solver = copy.deepcopy(pl_system.model.backward_solver)
        original_model_kwargs_getter = copy.deepcopy(pl_system.model_kwargs_getter)
        original_z0_init_method = copy.deepcopy(pl_system.model.z0_init_method)

        fig, axs = plt.subplots(nrows=1, ncols=num_examples, figsize=(8, 8), squeeze=False)
        try:
            for loader_idx, curr_loader in enumerate(self._as_loader_list(pl_system.valid_loader)):
                # Take the first batch from each validation loader; skip empty loaders.
                batch = next(iter(curr_loader), None)
                if batch is None:
                    continue
                curr_xs, _ = batch
                curr_length = curr_xs.shape[-1]

                for example_idx in range(num_examples):
                    xs = curr_xs[example_idx:example_idx + 1]

                    grad_norms = list()
                    for num_forward_iter in num_forward_iters:
                        print(f"Loader idx: {loader_idx}, num_forward_iter: {num_forward_iter}")
                        grad_norms.append(self._jacobian_frobenius_norm(pl_system, xs, num_forward_iter))

                    ax = axs[0, example_idx]
                    ax.plot(num_forward_iters, grad_norms, '-o', markersize=10, label=f"length: {curr_length}")
                    ax.set_yscale('log')
                    ax.set_xscale('linear')
                    ax.set_xlabel("Number of unroll steps")
                    ax.set_ylabel("Frobenius norm of Jacobian")
                    ax.set_title("Sensitivity of the Logits \n wrt. State Initialization")

            # One legend per subplot: ``plt.legend()`` only decorated the current (last) axes.
            for ax in axs[0]:
                if ax.get_legend_handles_labels()[1]:
                    ax.legend()
            fig.tight_layout()

            # Log the plot.
            log_payload = {"sensitivity_plot": wandb.Image(self._figure_to_rgb_array(fig))}
            if "global_step" in kwargs:
                log_payload["global_step"] = kwargs["global_step"]
            wandb.log(log_payload)
        finally:
            # Revert the changes made to the pl_system.
            pl_system.model.forward_solver = original_forward_solver
            pl_system.model.backward_solver = original_backward_solver
            pl_system.model_kwargs_getter = original_model_kwargs_getter
            pl_system.model.z0_init_method = original_z0_init_method
            # The sensitivity backward passes accumulate gradients on the model parameters;
            # clear them so they cannot leak into a subsequent optimizer step.
            pl_system.model.zero_grad()
            plt.close(fig)

        return metric_logs

    @staticmethod
    def _as_loader_list(valid_loader):
        """Normalize ``valid_loader`` (one loader or a list/tuple of loaders) to a list."""
        # A DataLoader is itself ``Iterable``, so an ``Iterable`` check cannot distinguish a
        # single loader from a collection of loaders (it would iterate over batches instead).
        if isinstance(valid_loader, (list, tuple)):
            return list(valid_loader)
        return [valid_loader]

    @staticmethod
    def _figure_to_rgb_array(fig):
        """Render ``fig`` and return its pixels as a contiguous (height, width, 3) uint8 array."""
        fig.canvas.draw()
        # ``buffer_rgba`` replaces ``tostring_rgb`` (deprecated in matplotlib 3.8, removed in 3.10)
        # and already has the (height, width, 4) shape; copy so the result outlives the canvas.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])

    def _jacobian_frobenius_norm(self, pl_system, xs, num_forward_iter):
        """Return the Frobenius norm of d outs[0, classes, positions] / d external_zs0 for ``xs``.

        Reconfigures ``pl_system`` so that the model starts from an externally supplied
        initial state and runs ``num_forward_iter`` fixed-point iterations with no backward
        solver. The caller is responsible for restoring the original configuration.
        """
        all_positions = self.initial_kwargs["sequence_positions"]
        num_classes = self.initial_kwargs["num_classes"]
        z_init_dim = self.initial_kwargs["z_init_dim"]

        # Prepare for providing external zs0 and set the number of forward pass steps.
        # The iterator is freshly built with ``num_iters`` so no "threshold" lookup is needed.
        pl_system.model.z0_init_method = "external"
        pl_system.model.forward_solver = Configurable(fixed_point_iterator, num_iters=num_forward_iter)
        pl_system.model.backward_solver = None
        pl_system.model_kwargs_getter.initial_kwargs['num_pretraining_steps'] = 0

        # Construct the Jacobian, indexed as [class][position].
        jac_tensor = [[None for _ in all_positions] for _ in range(num_classes)]
        for pos_idx, pos in enumerate(all_positions):
            for cls_idx in range(num_classes):
                pl_system.model.zero_grad()

                # Leaf tensor with requires_grad=True: ``backward()`` populates its ``.grad``
                # without any ``retain_grad()`` call.
                external_zs0 = torch.zeros(size=(1, z_init_dim, xs.shape[-1]), requires_grad=True)
                # NOTE(review): if this hook runs under ``torch.inference_mode`` (the default for
                # validation in recent Lightning versions), ``enable_grad`` alone may not suffice.
                with torch.enable_grad():
                    outs, _ = pl_system(xs, external_zs0=external_zs0)
                    # Sensitivity of the output at position (cls_idx, pos).
                    outs[0, cls_idx, pos].backward()

                jac_tensor[cls_idx][pos_idx] = external_zs0.grad.detach().cpu().numpy()[0]

        return np.linalg.norm(np.array(jac_tensor))
