from abc import ABC
from functools import reduce
from numbers import Integral, Real

import numpy as np
import torch
from sklearn.base import ClusterMixin, BaseEstimator
from sklearn.metrics import pairwise_distances
from sklearn.utils import check_array, check_random_state
from sklearn.utils._param_validation import Interval, StrOptions
from ._utils import mmd_ova, mmd_ovo, wasserstein_ova, wasserstein_ovo
from sklearn.utils.extmath import softmax
from sklearn.utils.validation import check_is_fitted


class TorchDouglas(ClusterMixin, BaseEstimator, ABC):
    """
    Implementation of the `DNDTs optimised using GEMINI leveraging apprised splits` tree algorithm. This model learns
    clusters by optimising learnable parameters to perform feature-wise soft-binnings and recombine those bins
    into a single cluster prediction. The parameters are optimised to maximise a generalised mutual information score.

    Parameters
    ----------

    n_clusters : int, default=3
        The number of clusters to form as well as the number of output neurons in the neural network.

    n_cuts : int, default=1
        The number of cuts to consider per feature in the soft binning function of the DNDT.

    feature_mask : ndarray of bool of shape (n_features,), default=None
        A boolean vector indicating whether a feature should be considered or not among splits. If None,
        all features are considered during training. At least one feature must be selected.

    temperature : float, default=0.1
        The logits of every feature-wise soft-binning are divided by the temperature before the softmax. A high
        temperature smooths the differences in probability whereas a low temperature produces distributions closer
        to Dirac delta distributions.

    n_epochs : int, default=100
        The number of epochs for training the model parameters.

    batch_size : int, default=None
        The number of samples per batch during an epoch. If set to `None`, all samples will be considered in a single
        batch.

    learning_rate : float, default=1e-2
        The learning rate hyperparameter for the Adam optimiser's update rule.

    gemini : {'mmd_ova', 'mmd_ovo', 'wasserstein_ova', 'wasserstein_ovo'}, default=None
        The generalised mutual information objective to maximise w.r.t. the tree parameters. If set to `None`, the
        one-vs-one Wasserstein is chosen. MMD objectives use the linear kernel ``X @ X.T`` as affinity, Wasserstein
        objectives use Euclidean pairwise distances.

    use_cuda : bool, default=True
        Whether to use GPU acceleration from torch. It is only used when CUDA is available.

    verbose : bool, default=False
        Whether to print progress messages to stdout.

    random_state : int, RandomState instance, default=None
        Determines random number generation for the parameter initialisation and the batch shuffling.
        Pass an int for reproducible results across multiple function calls.

    Attributes
    ----------
    labels_ : ndarray of shape (n_samples,)
        The cluster in which each sample of the training data was put.
    cut_points_list_ : list of (int, torch.Tensor)
        For each feature used in splits, its column index and its learnable cut points of shape (n_cuts,).
    leaf_scores_ : torch.Tensor of shape (n_leaves, n_clusters)
        Learnable logits of each leaf for each cluster, with ``n_leaves = (n_cuts + 1) ** n_selected_features``.
    optimiser_ : torch.optim.Adam
        The optimiser used to train ``leaf_scores_`` and the cut points.
    """
    _parameter_constraints: dict = {
        "n_clusters": [Interval(Integral, 1, None, closed="left")],
        "n_cuts": [Interval(Integral, 1, None, closed="left"), None],
        "feature_mask": [np.ndarray, None],
        "temperature": [Interval(Real, 0, None, closed="neither")],
        "n_epochs": [Interval(Integral, 1, None, closed="left")],
        "batch_size": [Interval(Integral, 1, None, closed="left"), None],
        "learning_rate": [Interval(Real, 0, None, closed="neither"), None],
        "gemini": [StrOptions({"mmd_ova", "mmd_ovo", "wasserstein_ova", "wasserstein_ovo"}), None],
        "use_cuda": [bool],
        "verbose": [bool],
        "random_state": ["random_state"]
    }

    def __init__(self, n_clusters=3, n_cuts=1, feature_mask=None, temperature=0.1, n_epochs=100, batch_size=None,
                 learning_rate=1e-2, gemini=None, use_cuda=True, verbose=False, random_state=None):
        self.n_clusters = n_clusters
        self.n_cuts = n_cuts
        self.feature_mask = feature_mask
        self.temperature = temperature
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.gemini = gemini
        # Store the argument unchanged (sklearn convention); it was previously hard-coded to True
        self.use_cuda = use_cuda
        self.verbose = verbose
        self.random_state = random_state

    def _leaf_binning(self, X, cut_points, temperature=0.1):
        """Soft-bin a single feature column into ``len(cut_points) + 1`` bins.

        Parameters
        ----------
        X : torch.Tensor of shape (n_samples, 1)
            A single feature column.
        cut_points : torch.Tensor of shape (n_cuts,)
            Learnable cut points of the feature. They are sorted before use.
        temperature : float, default=0.1
            Divides the bin logits before the softmax; lower values give sharper bin assignments.

        Returns
        -------
        torch.Tensor of shape (n_samples, n_cuts + 1)
            Probability of each sample to fall in each bin.
        """
        n = len(cut_points)
        # Slopes 1, 2, ..., n+1 as a (1, n+1) row so that X @ W spreads the column over all bins
        W = torch.linspace(1.0, n + 1.0, n + 1, dtype=torch.float64, device=X.device).reshape(1, -1)
        # Make sure cut_points is monotonically increasing
        cut_points, _ = torch.sort(cut_points)
        # Biases 0, -c1, -c1-c2, ...: consecutive logits differ by (x - c_k), so bin k has the largest logit
        # exactly between the sorted cuts c_k and c_{k+1}
        zero = torch.zeros(1, dtype=torch.float64, device=cut_points.device)
        b = torch.cumsum(torch.cat([zero, -cut_points], dim=0), dim=0)
        h = X @ W + b
        return torch.softmax(h / temperature, dim=1)

    def _merge_leaf(self, leaf_res1, leaf_res2):
        """Combine two soft-binnings into their joint binning through a per-sample Kronecker (outer) product.

        Inputs of shape (n_samples, a) and (n_samples, b) give an output of shape (n_samples, a * b).
        """
        product = torch.einsum("ij,ik->ijk", leaf_res1, leaf_res2)
        return product.flatten(start_dim=1)

    def _get_gemini(self):
        """Return the GEMINI function selected by ``self.gemini`` (``None`` selects the one-vs-one Wasserstein)."""
        if self.gemini is None:
            return wasserstein_ovo
        return {
            "mmd_ova": mmd_ova,
            "mmd_ovo": mmd_ovo,
            "wasserstein_ova": wasserstein_ova,
            "wasserstein_ovo": wasserstein_ovo,
        }[self.gemini]

    def _compute_affinity(self, X):
        """Return the (n_samples, n_samples) affinity used by the GEMINI.

        MMD objectives use the linear kernel ``X @ X.T``; Wasserstein objectives (including the default ``None``)
        use Euclidean pairwise distances.
        """
        if self.gemini is not None and "mmd" in self.gemini:
            return X @ X.T
        return pairwise_distances(X, metric="euclidean")

    def fit(self, X, y=None):
        """Performs the DOUGLAS algorithm by optimising feature-wise soft-binnings leafs to maximise a chosen GEMINI
        objective.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training instances to cluster. Sparse matrices are not supported.
        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If ``feature_mask`` does not have one entry per feature or selects no feature.
        """
        self._validate_params()

        # Dense float64 input is required: samples and affinities are converted to dense torch tensors below
        X = self._validate_data(X, accept_sparse=False, dtype=np.float64, ensure_min_samples=self.n_clusters)
        n_samples, n_features = X.shape

        # Create the random state
        random_state = check_random_state(self.random_state)

        batch_size = self.batch_size if self.batch_size is not None else n_samples
        batch_size = min(batch_size, n_samples)

        device = torch.device("cuda" if torch.cuda.is_available() and self.use_cuda else "cpu")

        # Select the features that take part in the splits
        if self.feature_mask is None:
            selected_features = list(range(n_features))
        else:
            if len(self.feature_mask) != n_features:
                raise ValueError("The boolean feature mask must have as many entries as the number of features "
                                 f"({len(self.feature_mask)} != {n_features})")
            selected_features = [i for i in range(n_features) if self.feature_mask[i]]
            if not selected_features:
                raise ValueError("The boolean feature mask must select at least one feature")

        # Create the parameters: per selected feature, its column index and n_cuts random cut points
        cut_points_list = [(i, random_state.normal(size=(self.n_cuts,))) for i in selected_features]
        num_leaf = int((self.n_cuts + 1) ** len(cut_points_list))

        if self.verbose:
            print(cut_points_list)
            print(f"Total will be {num_leaf} values per sample")
        leaf_scores = random_state.normal(size=(num_leaf, self.n_clusters))

        # Convert to torch parameters
        self.leaf_scores_ = torch.tensor(leaf_scores, requires_grad=True, dtype=torch.float64, device=device)
        self.cut_points_list_ = [(i, torch.tensor(cuts, requires_grad=True, dtype=torch.float64, device=device))
                                 for i, cuts in cut_points_list]

        self.optimiser_ = torch.optim.Adam([self.leaf_scores_] + [cuts for _, cuts in self.cut_points_list_],
                                           lr=self.learning_rate)

        gemini = self._get_gemini()
        # Full data and affinity stay on CPU; only the current batch is moved to the device
        X_torch = torch.tensor(X, dtype=torch.float64)
        affinity_torch = torch.tensor(self._compute_affinity(X), dtype=torch.float64)

        # Training algorithm
        for epoch in range(self.n_epochs):
            epoch_batch_order = random_state.permutation(n_samples)
            avg_loss = 0
            n_batches = 0
            for start in range(0, n_samples, batch_size):
                batch_indices = epoch_batch_order[start:start + batch_size]
                X_batch = X_torch[batch_indices].to(device)
                affinity_batch = affinity_torch[batch_indices][:, batch_indices].to(device)

                # Get probabilities from tree logits
                y_pred = torch.softmax(self._infer(X_batch), dim=1)

                # The GEMINI value is used directly as the quantity to minimise
                loss = gemini(y_pred, affinity_batch)

                # Compute backpropagation
                self.optimiser_.zero_grad()
                loss.backward()
                self.optimiser_.step()

                n_batches += 1
                avg_loss += loss.item()

                if self.verbose:
                    print(f"\tBatch {n_batches}: {loss.item()}")

            if self.verbose:
                print(f"Epoch: {epoch}, Loss: {avg_loss/n_batches}")

        # Final hard assignment of the training samples, in their original order
        with torch.no_grad():
            self.labels_ = np.concatenate([
                torch.argmax(self._infer(X_torch[start:start + batch_size].to(device)).cpu(), dim=1).numpy()
                for start in range(0, n_samples, batch_size)
            ])

        return self

    def _infer(self, X):
        """Return the cluster logits of shape (n_samples, n_clusters) for a float64 tensor ``X``.

        Each selected feature is soft-binned, the binnings are merged into leaf memberships, and these memberships
        weight the learnable ``leaf_scores_``.
        """
        all_binnings = [self._leaf_binning(X[:, feature:feature + 1], cuts, self.temperature)
                        for feature, cuts in self.cut_points_list_]

        leaf = reduce(self._merge_leaf, all_binnings)

        return leaf @ self.leaf_scores_

    def fit_predict(self, X, y=None):
        """Performs the DOUGLAS algorithm by optimising feature-wise soft-binnings leafs to maximise a chosen GEMINI
        objective. Returns the predicted cluster memberships of the data samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training instances to cluster. Sparse matrices are not supported.
        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Vector containing the cluster label for each sample.
        """
        return self.fit(X, y).labels_

    def predict_proba(self, X):
        """ Passes the data samples `X` through the tree structure to assign the probability of belonging to each
        cluster.
        This method can be called only once `fit` or `fit_predict` was performed.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Instances to cluster. They must have the same number of features as the training data.

        Returns
        -------
        y_pred : ndarray of shape (n_samples, n_clusters)
            Vector containing on each row the cluster membership probability of its matching sample.
        """
        # Check is fit had been called
        check_is_fitted(self)

        # Input validation, including the number of features seen during fit
        X = self._validate_data(X, reset=False, dtype=np.float64)

        with torch.no_grad():
            # Run on the device where the fitted parameters live
            X_torch = torch.tensor(X, dtype=torch.float64, device=self.leaf_scores_.device)
            return torch.softmax(self._infer(X_torch), dim=1).cpu().numpy()

    def predict(self, X):
        """ Passes the data samples `X` through the tree structure to assign cluster membership.
        This method can be called only once `fit` or `fit_predict` was performed.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Instances to cluster. They must have the same number of features as the training data.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Vector containing the cluster label for each sample.
        """
        return self.predict_proba(X).argmax(axis=1)

    def score(self, X, y=None):
        """
        Return the value of the GEMINI evaluated on the given test data.

        The GEMINI is computed on the cluster membership probabilities of ``X`` (softmax of the tree logits) with the
        same affinity and objective as during `fit`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        score : float
            Value returned by the selected GEMINI function, i.e. the quantity minimised during `fit`.
            NOTE(review): confirm its sign convention against ``._utils``.
        """
        check_is_fitted(self)

        X = self._validate_data(X, reset=False, dtype=np.float64)
        device = self.leaf_scores_.device

        with torch.no_grad():
            y_pred = torch.softmax(self._infer(torch.tensor(X, dtype=torch.float64, device=device)), dim=1)
            affinity = torch.tensor(self._compute_affinity(X), dtype=torch.float64, device=device)
            return float(self._get_gemini()(y_pred, affinity))
