import torch
import torch.nn as nn
from datetime import datetime
import pandas as pd
from ignite.metrics import Loss as MLoss
from ignite.metrics import Accuracy as MAccuracy
from ignite.metrics import TopKCategoricalAccuracy as MTopKAccuracy
import numpy as np
import copy
from pathlib import Path
from numba import njit


class ModelCheckpoint():
    """ Object used for tracking the model's weights and saving them to disk

    Parameters:
    -----------
    filepath : str or pathlib.Path
        Path to save the best model file (other models are saved next to it
        as ``<stem>_<suffix>.pt``)
    model : nn.Module
        Model to save
    rewind : int
        The epoch number for which to keep "rewind" weights.
        -1 keeps the weights the model has at construction time,
        -2 makes ``get_rewind`` return the best weights instead.
    period : int
        Interval (number of ``update`` calls) between checkpoint saves on disk
        (use this if slow filesystem or interconnect to avoid bottlenecks)
    save_from : int
        From which epoch number the best score / weights start being tracked
    verbose : int
        Whether to print messages on the best model's statistics (> 0 prints)
    """

    def __init__(self, filepath, model, rewind=-1, period=1, save_from=0, verbose=0):
        self.verbose = verbose
        self.model = model
        # save_suffix() relies on .parent / .stem, so normalise plain strings
        # (the documented input type) to a Path once, here.
        self.filepath = Path(filepath)

        # np.inf rather than np.Inf: the capitalised alias is gone in NumPy 2.
        self.best_score = -np.inf
        self.best_weights = None
        # Filled by save_last(); declared here so the attribute always exists.
        self.last_weights = None

        self.period = period
        self.epochs_since_last_save = 0
        self.save_from = save_from

        self.rewind = rewind
        self.rewind_weights = None
        if self.rewind == -1:  # Save initial weights
            self.rewind_weights = copy.deepcopy(self.model.state_dict())

        self.save_suffix(suffix="init")

    def update(self, epoch, current_score):
        """ Record the score of `epoch` and save the best weights when due.

        Higher scores are considered better; ties replace the stored best.
        Returns True when the periodic save was triggered, False otherwise.
        """
        # Counted on every call, including epochs before `save_from`.
        self.epochs_since_last_save += 1

        if epoch == self.rewind:
            self.rewind_weights = copy.deepcopy(self.model.state_dict())
            self.save_suffix(suffix="rewind")

        if epoch < self.save_from:
            return False

        if current_score >= self.best_score:
            if self.verbose > 0:
                print('\nEpoch %05d: score improved from %0.5f to %0.5f,'
                      % (epoch + 1, self.best_score,
                         current_score))
            self.best_score = current_score
            self.best_weights = copy.deepcopy(self.model.state_dict())
        else:
            if self.verbose > 0:
                print('\nEpoch %05d: score did not improve from %0.5f' %
                      (epoch + 1, self.best_score))

        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            self.save()
            return True
        return False

    def save(self):
        """ Write the best weights to `filepath` (no-op if none were recorded). """
        if self.best_weights is not None:
            torch.save(self.best_weights, self.filepath)

    def save_suffix(self, suffix="last"):
        """ Write the model's current weights to ``<stem>_<suffix>.pt``. """
        target = self.filepath.parent / f"{self.filepath.stem}_{suffix}.pt"
        # The state dict is serialised immediately, so no defensive copy is needed.
        torch.save(self.model.state_dict(), target)

    def save_last(self):
        """ Save the current weights with the "last" suffix and keep a copy in memory. """
        self.save_suffix()
        self.last_weights = copy.deepcopy(self.model.state_dict())

    def reload_best(self):
        """ Load the best recorded weights back into the model.

        Raises AttributeError when no best weights were recorded yet.
        """
        if self.best_weights is None:
            raise AttributeError("No best weights were saved")
        self.model.load_state_dict(self.best_weights)

    def get_rewind(self):
        """ Return the weights to rewind to.

        With ``rewind == -2`` the best weights are returned as-is (possibly
        None); otherwise the stored rewind weights are returned, and
        AttributeError is raised when none were stored.
        """
        if self.rewind == -2:
            return self.best_weights
        if self.rewind_weights is None:
            raise AttributeError("No rewind weights were saved")
        return self.rewind_weights


class LabelSmoothing(nn.Module):
    """
    Negative log-likelihood loss with label smoothing.

    The loss is a blend of the usual NLL term (weighted by ``1 - smoothing``)
    and the mean negative log-probability over all classes (weighted by
    ``smoothing``), averaged over the batch.
    """

    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        """
        :param x: raw scores, normalised with log-softmax over the last axis
        :param target: class indices, one per row of ``x``
        :return: scalar tensor holding the batch-averaged smoothed loss
        """
        log_p = nn.functional.log_softmax(x, dim=-1)

        # Log-probability assigned to the true class of every sample
        picked = torch.gather(log_p, -1, target.unsqueeze(1))
        nll_loss = (-picked).squeeze(1)
        # Uniform-target term: average over all classes
        smooth_loss = -log_p.mean(dim=-1)

        blended = self.confidence * nll_loss + self.smoothing * smooth_loss
        return blended.mean()


class TrainConfig():
    """ Training configuration: a dict of parameters exposed as attributes.

    Any keyword argument overrides the matching default in ``params``:
    ``n_epochs``, ``accumulation``, ``loss``, ``device``, ``logdir``,
    ``metrics``, ``monitor`` and ``oneshot``. An optional ``logdir_suffix``
    is appended to ``logdir``. The logging directory is created on
    construction.

    ``loss`` may be a loss object or the string "smooth" (label smoothing);
    any other string raises ValueError. The loss is also registered as the
    "loss" metric.

    After construction ``monitor`` is a zero-argument callable returning the
    monitored metric's current value (negated for "loss", so that a larger
    value is always better).
    """

    def __init__(self, **kwargs):
        self.params = {
            "n_epochs": 100,
            "accumulation": 1,  # TODO: unsupported for now
            "loss": torch.nn.CrossEntropyLoss(),
            "device": torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            ),
            "logdir": "logs/{:%Y-%m-%dT%H%M%S}".format(datetime.now()),
            "metrics": {
                "acc": MAccuracy(),
                "acc-5": MTopKAccuracy(k=5)
            },
            "monitor": "loss",
            "oneshot": None,
            **kwargs
        }
        # Make logging directory
        if "logdir_suffix" in self.params:
            self.params["logdir"] += self.params["logdir_suffix"]
        self.params["logdir"] = Path(self.params["logdir"])
        self.logdir.mkdir(parents=True, exist_ok=True)

        # Add loss to metrics
        if isinstance(self.loss, str):
            if self.loss == "smooth":
                print("Using label smoothed loss")
                self.params["loss"] = LabelSmoothing()
            else:
                raise ValueError("Unsupported loss type")
        self.metrics["loss"] = MLoss(self.loss)

        # Save model
        mval = self.monitor  # still the metric name at this point

        def monitor_fx():
            ret = self.metrics[mval].compute()
            if mval == "loss":
                ret = -ret  # loss should decrease
            return ret

        # Instance attribute: from now on it shadows params["monitor"].
        self.monitor = monitor_fx

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. `params` is read
        # from the instance dict directly: going through `self.params` would
        # re-enter __getattr__ forever on an instance that has no `params`
        # yet (e.g. one created by copy/pickle without running __init__).
        params = self.__dict__.get("params")
        if params is not None and name in params:
            return params[name]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    @torch.no_grad()
    def _update_metrics(self, outputs, targets):
        """ Feed one batch of (outputs, targets) to every registered metric. """
        for metric in self.metrics.values():
            metric.update((outputs, targets))

    @torch.no_grad()
    def _reset_metrics(self, fhandle, index=0):
        """ Compute and reset every metric, appending the values as a CSV row.

        The row is written to `fhandle` with `index` as its row label; the
        CSV header is emitted only when `index` is 0.
        """
        # Average out metrics
        metrics_dict = {}
        for name, metric in self.metrics.items():
            metrics_dict[name] = metric.compute()
            metric.reset()

        # Print to file
        print_header = (index == 0)
        df = pd.DataFrame(metrics_dict, index=[index])
        df.to_csv(fhandle, header=print_header)
        fhandle.flush()


class BasicCallbacks:
    """
    Base class of the callback framework: every hook is a no-op, so the basic
    training loop can be extended by overriding only the hooks that matter
    instead of duplicating the training code.
    """

    def on_iter_start(self, epoch, iteration):
        """Hook receiving the epoch and iteration; does nothing by default."""
        return None

    def on_iter_end(self, epoch, iteration):
        """Hook receiving the epoch and iteration; does nothing by default."""
        return None

    def on_epoch_start(self, epoch):
        """Hook receiving the epoch; does nothing by default."""
        return None

    def on_epoch_end(self, epoch):
        """Hook receiving the epoch; does nothing by default."""
        return None

    def on_model_change(self):
        """Hook without arguments; does nothing by default."""
        return None


# Helper functions for pruning TODO: make this a filter function with yield to avoid future mistakes!
def ok_layer(lname):
    """ Tells whether a parameter name is eligible for pruning: it must
    contain "weight" and must not contain "batchnorm".
    """
    if "batchnorm" in lname:
        return False
    return "weight" in lname


@torch.no_grad()
def compute_threshold_by_layer(weights, prune_perc, pruning_method):
    """
    Computes the layer-wise threshold for every layer

    Parameters
    ----------
    weights : dict
        Maps parameter names to tensors; only the names accepted by
        ``ok_layer`` are processed
    prune_perc : float
        Global pruning percentage, forwarded as-is to the percentile helpers
    pruning_method : str
        One of the following:
        - l1 : global l1 pruning (one threshold shared by all layers)
        - l1_gpu : same as l1 but computed with torch on the tensors' device
        - l1_layer : local l1 pruning (one percentile per layer)
        - l1_normalized : global l1 on weights divided by their layer's
          standard deviation, the threshold being scaled back per layer
        - lamp : (based on the according paper)
        - lamp_gpu : same as lamp but computed with torch
        A "_structured" marker in the name is stripped before dispatching;
        for the numpy-based methods it makes tensors with 4 or more axes be
        scored by their mean absolute value over the last two axes.

    Returns
    -------
    A dictionary with a maximal value threshold for every processed layer.
    It is empty when `pruning_method` matches none of the methods above.
    """
    def flatten_plain(xx):
        return xx.flatten()

    def flatten_structured(xx):
        if len(xx.shape) < 4:
            return xx.flatten()
        # One score per slice over the last two axes, repeated once per
        # element of that slice so the flattened length is unchanged.
        slice_size = xx.shape[-1]*xx.shape[-2]
        return np.repeat(np.abs(xx).mean(axis=(-2, -1)).flatten(), slice_size)

    def concat(chunks):
        # Single concatenation instead of growing an array layer by layer.
        # The leading empty float64 array keeps the float64 promotion.
        return np.concatenate([np.array([])] + chunks)

    process_lweight = flatten_plain
    if "_structured" in pruning_method:
        pruning_method = pruning_method.replace("_structured", "")
        process_lweight = flatten_structured

    # Layers taking part in pruning, in the order of `weights`
    prunable = {n: data for n, data in weights.items() if ok_layer(n)}

    threshold_per_layer = {}
    if pruning_method == "l1":
        chunks = [process_lweight(data.cpu().numpy()) for data in prunable.values()]

        # Set a global percentage of weights to zero
        th = simpler_percentile(np.abs(concat(chunks)), prune_perc)

        for n in prunable:
            threshold_per_layer[n] = th

    elif pruning_method == "l1_gpu":
        flat_model_weights = torch.cat(
            [data.flatten() for data in prunable.values()]
        ).abs()
        # Set a global percentage of weights to zero
        th = simpler_gpu_percentile(flat_model_weights, prune_perc)

        for n in prunable:
            threshold_per_layer[n] = th

    elif pruning_method == "l1_layer":
        for n, data in prunable.items():
            layer_weights = np.abs(process_lweight(data.cpu().numpy()))
            # Set a local percentage of weights to zero
            threshold_per_layer[n] = simpler_percentile(layer_weights, prune_perc)

    elif pruning_method == "l1_normalized":
        std_per_layer = {}
        chunks = []
        for n, data in prunable.items():
            layer_weights = process_lweight(data.cpu().numpy())
            std = np.std(layer_weights)
            chunks.append(layer_weights/std)  # Normalize by std
            std_per_layer[n] = std

        # Set a global percentage of weights to zero
        th = simpler_percentile(np.abs(concat(chunks)), prune_perc)

        # Undo the normalisation to get a threshold in each layer's own scale
        for n, std in std_per_layer.items():
            threshold_per_layer[n] = th*std

    elif pruning_method == "lamp":
        lamp_score_per_layer = {}
        weight_per_layer = {}
        chunks = []
        for n, data in prunable.items():
            # Magnitudes sorted in decreasing order
            layer_weights = np.flip(np.sort(np.abs(process_lweight(data.cpu().numpy()))))
            weight_per_layer[n] = layer_weights
            layer_weights_squared = layer_weights**2
            lamp_score = layer_weights_squared/cumsum(layer_weights_squared)
            lamp_score_per_layer[n] = lamp_score
            chunks.append(lamp_score)

        # Set a global percentage of weights to zero
        th = simpler_percentile(concat(chunks), prune_perc)

        for n in weight_per_layer:
            # The following finds the weight threshold for each layer, it's the
            # same as: layer_th = np.max(weight_per_layer[n][lamp_score_per_layer[n] <= th])
            idx = arg_find_th_sorted(np.flip(lamp_score_per_layer[n]), th)
            if idx < 0:
                layer_th = -1  # below every magnitude
            else:
                layer_th = np.flip(weight_per_layer[n])[idx]

            if prune_perc == 0:
                layer_th = -1  # For safety's sake
            threshold_per_layer[n] = layer_th

    elif pruning_method == "lamp_gpu":
        lamp_score_per_layer = {}
        weight_per_layer = {}
        all_scores = []
        for n, data in prunable.items():
            # Magnitudes sorted in decreasing order
            layer_weights = torch.flip(torch.sort(data.flatten().abs())[0], (0,))
            weight_per_layer[n] = layer_weights
            layer_weights_squared = layer_weights**2
            # Reuse the squares instead of computing them a second time
            lamp_score = layer_weights_squared/torch.cumsum(layer_weights_squared, 0)
            del layer_weights_squared
            lamp_score_per_layer[n] = lamp_score
            all_scores.append(lamp_score)

        # Set a global percentage of weights to zero
        th = simpler_gpu_percentile(torch.cat(all_scores), prune_perc)

        for n in weight_per_layer:
            selected_weights = weight_per_layer[n][lamp_score_per_layer[n] <= th]
            if len(selected_weights) == 0 or prune_perc == 0:
                layer_th = -1  # below every magnitude
            else:
                layer_th = torch.max(selected_weights)
            threshold_per_layer[n] = layer_th

    return threshold_per_layer


@torch.no_grad()
def update_masks(weights, masks, prune_perc, pruning_method):
    """
    Recomputes, in place, the pruning masks of every prunable layer

    Parameters
    ----------
    weights : dict
        Maps parameter names to tensors; only the names accepted by
        ``ok_layer`` are processed
    masks : dict
        Maps the same names to the mask tensors to update in place
    prune_perc : float
        Pruning percentage forwarded to ``compute_threshold_by_layer``
    pruning_method : str
        Method name forwarded to ``compute_threshold_by_layer``; when it
        contains "_structured", tensors with 4 or more axes are scored by
        their mean absolute value over the last two axes

    A mask entry becomes 0 where the score is <= the layer threshold and 1
    where it is greater.
    """
    threshold_per_layer = compute_threshold_by_layer(weights, prune_perc, pruning_method)
    structured = "_structured" in pruning_method
    for n, p in weights.items():
        if not ok_layer(n):
            continue
        th = threshold_per_layer[n]

        # Compute the score once and reuse it for both comparisons
        score = p.abs()
        if structured and len(p.shape) >= 4:
            score = score.mean(axis=(-2, -1))

        mask = masks[n]
        mask[score <= th] = 0
        mask[score > th] = 1


@torch.no_grad()
def apply_threshold_to_model(model, threshold_per_layer, prune_perc, force=False):
    """
    Hands every layer its pruning threshold

    Only modules exposing a ``th_prune`` attribute are considered. For each of
    their parameters named "weight" whose full name ("<module>.weight") is a
    key of `threshold_per_layer`, the module's ``force_prune`` (when `force`
    is true) or ``set_prune`` method is called with the threshold and
    `prune_perc`.
    """
    for module_name, module in model.named_modules():
        if not hasattr(module, "th_prune"):
            continue
        for param_name, _ in module.named_parameters():
            if param_name != "weight":
                continue
            full_name = module_name + "." + param_name
            if full_name not in threshold_per_layer:
                continue
            threshold = threshold_per_layer[full_name]
            if force:
                module.force_prune(threshold, prune_perc)
            else:
                module.set_prune(threshold, prune_perc)

# A faster Cumulative sum implementation for LAMP pruning
@njit
def cumsum(x):
    """ Cumulative sum of a 1d np.ndarray, compiled with numba

    Returns an array of the same length and dtype as `x` whose element i is
    x[0] + ... + x[i]. An empty input yields an empty output.
    (The original author reports it as ~2x faster than np.cumsum on smaller
    arrays.)
    """
    n = len(x)
    y = np.zeros(n, dtype=x.dtype)
    # Guard the empty case: indexing element 0 of an empty array would be an
    # out-of-bounds access, which compiled code does not check by default.
    if n == 0:
        return y
    y[0] = x[0]
    for i in range(1, n):
        y[i] = y[i-1]+x[i]
    return y


# A faster way to find the correct index for LAMP pruning
@njit
def arg_find_th_sorted(x, th):
    """ Locates the last position of a sorted array whose value is <= th

    Parameters
    ----------
    x: 1d np.ndarray
        Values sorted in increasing order
    th: float
        Threshold to look for

    Returns
    -------
    One less than the position of the first element strictly greater than
    `th`, or the last position when there is no such element, i.e.
    x[j] <= th  if j <= index
    x[j] >  th  otherwise

    A negative result means that no element of the array is <= th

    """
    n = len(x)
    pos = 0
    # Walk forward until the first element exceeding the threshold
    while pos < n and not (x[pos] > th):
        pos += 1
    return pos - 1


def simpler_percentile(x, th):
    """ Percentile of `x` for a fraction `th` given on a 0-1 scale
    (e.g. 0.5 is the median), delegating to np.percentile.
    """
    # FIXME -> put this back into the code
    percent = th*100
    return np.percentile(x, percent)


def simpler_gpu_percentile(x, th):
    """ Torch counterpart of a percentile with linear interpolation

    `x` is a 1d tensor and `th` a fraction on a 0-1 scale. The tensor is
    sorted, the fractional position ``(len - 1) * th`` is located, and the
    two neighbouring values are linearly interpolated. The result is
    returned as a Python number.
    """
    ordered = torch.sort(x)[0]
    position = (len(ordered)-1)*th
    lower_idx = int(np.floor(position))
    upper_idx = int(np.ceil(position))

    lower_val = ordered[lower_idx]
    upper_val = ordered[upper_idx]
    # Weight of the upper neighbour is the fractional part of the position
    interpolated = lower_val + (upper_val-lower_val)*(position-lower_idx)
    return interpolated.item()
