from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import scipy
import seaborn as sns
import torch
import torch.nn as nn


class AbstractWeights(nn.Module):
    """Base class for weight modules that adds a heatmap plot of the weights.

    ``plot`` calls the module itself (``self()``) without arguments, so a
    subclass is expected to provide a zero-argument ``forward`` that returns
    the weight tensor.
    """

    def __init__(self):
        super(AbstractWeights, self).__init__()

    def plot(
        self,
        *args,
        ax: Optional[plt.Axes] = None,
        models: Optional[List[str]] = None,
        figsize: Optional[tuple] = None,
        color: Optional = sns.color_palette()[1],
        **kwargs,
    ) -> None:
        """Draw the current weights as a heatmap.

        The weight array is indexed as ``[0, item, time, quantile, model]``.
        The plot puts the model axis on the rows and, when the weights vary
        along one of the item / time / quantile axes, that axis on the
        columns.

        Args:
            *args: Accepted for call compatibility; not used.
            ax: Axes to draw into. When ``None`` a new figure is created and,
                in addition, axis labels / model tick labels are set and
                ``tight_layout`` is applied.
            models: Model names used as y tick labels (only applied when the
                figure is created here).
            figsize: Figure size forwarded to ``plt.subplots`` when a new
                figure is created.
            color: Colour blended from white to build the default colormap.
            **kwargs: Forwarded to ``ax.imshow``; a user-supplied ``cmap``
                takes precedence over the default one.

        Raises:
            AssertionError: If the weights vary along more than one of the
                item / time / quantile axes.
        """
        styling = ax is None
        if styling:
            fig, ax = plt.subplots(figsize=figsize)

        # .cpu() so that modules living on an accelerator can be plotted too;
        # it is a no-op for tensors that are already on the CPU.
        weights = self().detach().cpu().numpy()
        shp = weights.shape

        # Build the layout code: "m" plus one letter per non-singleton axis.
        wp = "m"
        for letter, axis in (("i", 1), ("t", 2), ("q", 3)):
            if shp[axis] > 1:
                wp += letter
        if len(wp) > 2:
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError("Too complex `weights_per`; can only handle tuples")

        # After the check above, wp is one of "m", "mi", "mt", "mq".
        if wp in ("m", "mi"):
            mat = weights[0, :, 0, 0, :]
        elif wp == "mt":
            mat = weights[0, 0, :, 0, :]
        else:  # "mq"
            mat = weights[0, 0, 0, :, :]

        if "cmap" not in kwargs:
            kwargs["cmap"] = sns.blend_palette(["white", color], as_cmap=True)
        ax.imshow(
            mat.T,
            aspect="auto",
            interpolation="none",
            **kwargs,
        )
        ax.tick_params(left=False, bottom=False)
        ax.set_xticks([])
        ax.set_yticks([])

        if styling:
            labels = dict(m="", mi="Item", mq="Quantile", mt="Time")
            ax.set_xlabel(labels[wp])
            if models is not None:
                ax.set_yticks(
                    np.arange(len(models)),
                    labels=models,
                )
            fig.tight_layout()


class SimpleWeights(AbstractWeights):
    """Weights held directly in a single learnable tensor of shape ``(1, *shape)``.

    The stored tensor starts from a softmax of zeros over the last axis
    (i.e. constant along that axis). For the ``"exp"`` and ``"square"``
    transforms its logarithm / square root is stored instead, so that
    ``forward`` initially reproduces that same tensor.
    """

    def __init__(self, *shape, weight_transform="softmax"):
        """Create the parameter tensor.

        Args:
            *shape: Trailing dimensions of the parameter; a leading axis of
                size 1 is prepended. Passing five entries switches on the
                additional division by ``w.shape[-2]`` in ``forward`` for the
                ``"softmax"`` and ``"norm"`` transforms.
            weight_transform: One of ``"softmax"``, ``"exp"``, ``"norm"``,
                ``"square"``, ``"abs"``, ``"softplus"``, ``"elu"`` or ``None``.

        Raises:
            ValueError: If ``weight_transform`` is not one of the values above.
        """
        super(SimpleWeights, self).__init__()
        known_transforms = ("softmax", "exp", "norm", "square", "abs", "softplus", "elu", None)
        if weight_transform not in known_transforms:
            raise ValueError(f"Unknown weight_transform {weight_transform}")
        self.weight_transform = weight_transform

        uniform = torch.zeros(1, *shape, dtype=torch.float64).softmax(dim=-1)
        if weight_transform == "exp":
            initial = uniform.log()
        elif weight_transform == "square":
            initial = uniform.abs().pow(0.5)
        else:
            initial = uniform
        self.w = nn.Parameter(initial)
        self.across_quantile = len(shape) == 5

    def forward(self):
        """Return the stored tensor after applying the configured transform."""
        kind = self.weight_transform
        raw = self.w

        # Normalising transforms: optionally rescaled by the size of the
        # second-to-last axis when `across_quantile` is set.
        if kind == "softmax" or kind == "norm":
            if kind == "softmax":
                normalised = raw.softmax(dim=-1)
            else:
                normalised = raw / raw.sum(dim=-1, keepdim=True)
            if self.across_quantile:
                return normalised / raw.shape[-2]
            return normalised

        # Element-wise transforms.
        if kind == "exp":
            return raw.exp()
        if kind == "abs":
            return raw.abs()
        if kind == "softplus":
            return nn.functional.softplus(raw)
        if kind == "elu":
            return nn.functional.elu(raw)
        if kind == "square":
            return raw.pow(2)

        # No transform: hand back the raw parameter.
        return raw


class LowRankWeights(AbstractWeights):
    """Weights built as a rank-averaged product of per-axis factors.

    One factor is kept for each of the item, time, quantile and model axes.
    A factor whose axis has size greater than one is a learnable parameter
    of ones with ``rank`` entries on its last axis; a factor whose axis has
    size one is a constant all-ones placeholder of shape ``(1, 1, 1, 1, 1, 1)``.
    """

    def __init__(self, n_items, n_times, n_quantiles, n_models, rank, weight_transform="softmax"):
        """Create the four factors.

        Args:
            n_items: Size of the item axis (axis 1 of the factors).
            n_times: Size of the time axis (axis 2).
            n_quantiles: Size of the quantile axis (axis 3).
            n_models: Size of the model axis (axis 4).
            rank: Size of the last axis of every learnable factor.
            weight_transform: ``"softmax"`` or ``None``.

        Raises:
            ValueError: If ``weight_transform`` is neither ``"softmax"`` nor
                ``None``.
        """
        super(LowRankWeights, self).__init__()
        if weight_transform not in ["softmax", None]:
            raise ValueError(f"Unknown weight_transform {weight_transform}")
        self.weight_transform = weight_transform

        # (attribute name, axis size); the position in this tuple (+1) is the
        # axis the factor varies along. Registration order is kept as
        # items, times, quantiles, models.
        factors = (
            ("w_items", n_items),
            ("w_times", n_times),
            ("w_quantiles", n_quantiles),
            ("w_models", n_models),
        )
        for axis, (name, size) in enumerate(factors, start=1):
            if size != 1:
                shape = [1, 1, 1, 1, 1, rank]
                shape[axis] = size
                setattr(self, name, nn.Parameter(torch.ones(*shape, dtype=torch.float64)))
            else:
                # Registered as a buffer so the constant follows the module
                # across `.to(device)` / dtype casts like the parameters do;
                # non-persistent so the state_dict keys stay unchanged.
                self.register_buffer(
                    name,
                    torch.ones(1, 1, 1, 1, 1, 1, dtype=torch.float64),
                    persistent=False,
                )

    def forward(self):
        """Return the factor product averaged over the last (rank) axis.

        With ``weight_transform == "softmax"`` a softmax over the last
        remaining axis (the model axis) is applied to the result.
        """
        w = torch.mean(self.w_items * self.w_times * self.w_quantiles * self.w_models, axis=-1)
        if self.weight_transform == "softmax":
            return w.softmax(axis=-1)
        return w


class ClusterWeights(LowRankWeights):
    """Variant of ``LowRankWeights`` that mixes the factors via soft assignments.

    The item, time and quantile factors are turned into temperature-scaled
    softmax distributions over their last axis; the model factor is
    (optionally) soft-maxed over the model axis. The final weights are the
    product of all four, summed over the last axis.
    """

    def _get_individual_weights(
        self,
        noise=1e-10,
        temp=200,
    ):
        """Return the transformed ``(item, time, quantile, model)`` factors.

        Args:
            noise: Scale of the Gaussian noise added to the item, time and
                quantile factors before the softmax.
            temp: Divisor applied to those factors before the softmax.
        """
        model_part = self.w_models
        if self.weight_transform == "softmax":
            model_part = model_part.softmax(axis=4)

        def soft_assign(raw):
            # Perturb, scale by the temperature, then normalise over the last axis.
            perturbed = raw + noise * torch.randn_like(raw)
            return perturbed.div(temp).softmax(axis=-1)

        # Processed in the order items, times, quantiles so the random draws
        # happen in that sequence.
        item_part = soft_assign(self.w_items)
        time_part = soft_assign(self.w_times)
        quantile_part = soft_assign(self.w_quantiles)
        return item_part, time_part, quantile_part, model_part

    def entropy_regularizer(self):
        """Sum of the mean entropies of the item, time and quantile assignments."""
        item_part, time_part, quantile_part, _ = self._get_individual_weights(noise=0)
        return sum(
            # Clamp before the log so zero probabilities do not produce -inf.
            -torch.sum(probs * probs.clamp(min=1e-20).log(), axis=-1).mean()
            for probs in (item_part, time_part, quantile_part)
        )

    def forward(self):
        """Return the product of the four factors summed over the last axis."""
        item_part, time_part, quantile_part, model_part = self._get_individual_weights()
        combined = item_part * time_part * quantile_part * model_part
        return torch.sum(combined, axis=-1)


class LinearStackerRegressor(nn.Module):
    """Combine model predictions with a learned weighted sum over the models.

    The weights come from one of the weight modules in this file
    (``LowRankWeights``, ``ClusterWeights`` or ``SimpleWeights``) and are
    contracted with the predictions in ``forward``.
    """

    def __init__(
        self,
        model_prediction_shape: tuple,
        n_output_quantiles: Optional[int] = None,
        weights_per: str = "m",
        rank=None,
        clusters=None,
        weight_transform="softmax",
    ):
        """Build the weight module matching the requested layout.

        Args:
            model_prediction_shape: Shape of the predictions,
                ``(n_folds, n_items, n_times, n_quantiles, n_models)``.
            n_output_quantiles: Number of output quantiles; defaults to the
                number of input quantiles (``model_prediction_shape[3]``).
            weights_per: String of flags; ``"i"``, ``"t"`` and ``"q"`` give
                the weights a separate entry per item, time step and
                quantile respectively, ``"qq"`` additionally adds an output
                quantile axis. Weights always differ per model.
            rank: If given, use ``LowRankWeights`` with this rank.
            clusters: If given (and ``rank`` is ``None``), use
                ``ClusterWeights`` with this many clusters.
            weight_transform: Forwarded to the weight module.

        Raises:
            ValueError: If the number of output quantiles differs from the
                number of input quantiles and ``weights_per`` has no ``"qq"``.
        """
        super(LinearStackerRegressor, self).__init__()
        self.weights_per = weights_per
        self.rank = rank
        self.clusters = clusters
        if n_output_quantiles is None:
            # assume we have as many output quantiles as input quantiles
            n_output_quantiles = model_prediction_shape[3]
        weight_shape = self._compute_weights_shape(model_prediction_shape, n_output_quantiles)
        # NOTE: LowRankWeights / ClusterWeights take exactly four shape
        # entries, so they only fit weight shapes without the "qq" axis.
        if rank is not None:
            self.weight_model = LowRankWeights(*weight_shape, rank, weight_transform=weight_transform)
        elif clusters is not None:
            self.weight_model = ClusterWeights(*weight_shape, clusters, weight_transform=weight_transform)
        else:
            self.weight_model = SimpleWeights(*weight_shape, weight_transform=weight_transform)

    def _compute_weights_shape(self, model_predictions_shape: tuple, n_output_quantiles: int):
        """Return the weight shape (without the fold axis) as a list.

        Also sets ``self.across_quantiles`` to whether ``"qq"`` is part of
        ``self.weights_per``.

        Raises:
            ValueError: If ``n_output_quantiles`` differs from the number of
                input quantiles and ``self.weights_per`` has no ``"qq"``.
        """
        n_folds, n_items, n_times, n_quantiles, n_models = model_predictions_shape
        across_quantiles = "qq" in self.weights_per
        if not across_quantiles and n_output_quantiles != n_quantiles:
            raise ValueError(
                "The linear regressor can only work with different numbers of input "
                "and output quantiles if it is a 'qq'-type model."
            )

        # same weights for each fold, hence no fold entry
        weights_shape = [
            size if flag in self.weights_per else 1
            for flag, size in (("i", n_items), ("t", n_times), ("q", n_quantiles))
        ]
        if across_quantiles:
            # output-quantile axis goes right before the input-quantile axis
            weights_shape.insert(-1, n_output_quantiles)
        self.across_quantiles = across_quantiles
        weights_shape.append(n_models)  # always use different weights per model
        return weights_shape

    def forward(self, model_predictions: torch.Tensor):
        """Return the weighted combination of ``model_predictions`` over models.

        ``model_predictions`` is of shape
        ``(N_FOLDS, N_ITEMS, HORIZON, N_QUANTILES, N_MODELS)``; the result
        drops the model axis. In the ``"qq"`` case the input quantile axis is
        summed out as well and replaced by the output quantile axis.
        """
        weights = self.weight_model()
        if self.across_quantiles:
            return torch.einsum("fihqcm,fihcm -> fihq", weights, model_predictions)
        return torch.einsum("fihqm,fihqm -> fihq", weights, model_predictions)
