import os
import copy
import numpy as np
import warnings

from scipy.special import softmax
from scipy.optimize import minimize
from sklearn.neighbors import NearestNeighbors

from sklearn.base import BaseEstimator
from sklearn.svm import LinearSVC, SVR
from sklearn.cluster import KMeans

import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp
from utils.others_metrics import score

class _Base:
    """Minimal training-data bookkeeping shared by every estimator in this module.

    ``fit`` stores the training data and records ``n_features``
    (``X.shape[1]``) and, when targets are supplied, ``n_outputs``
    (``y.shape[1]``). Both properties raise ``ValueError`` until then.
    """

    def __init__(self, random_state=None):
        """Optionally seed NumPy's global random generator.

        Parameters
        ----------
        random_state : int or None
            When not ``None``, passed to ``np.random.seed``. This seeds the
            *global* NumPy generator, so it affects every later NumPy draw.
        """
        if random_state is not None:
            np.random.seed(random_state)

        self._n_features = None
        self._n_outputs = None

    def fit(self, X, y=None):
        """Store the training data and record its dimensions.

        Parameters
        ----------
        X : 2-D array-like exposing ``shape``
            Training inputs; ``X.shape[1]`` becomes ``n_features``.
        y : 2-D array-like exposing ``shape``, or None
            Training targets; ``y.shape[1]`` becomes ``n_outputs``.
        """
        self._X = X
        self._y = y
        self._n_features = self._X.shape[1]
        # Reset (rather than keep) the output count when no targets are given,
        # so a value recorded by an earlier fit cannot leak into this one.
        self._n_outputs = None if self._y is None else self._y.shape[1]

    def _not_been_fit(self):
        """Raise the standard "call fit() first" error."""
        raise ValueError("The model has not yet been fit. "
                         "Try to call 'fit()' first with some training data.")

    @property
    def n_features(self):
        """Number of input features seen by the last ``fit``.

        Raises ValueError if ``fit`` has not been called.
        """
        if self._n_features is None:
            self._not_been_fit()
        return self._n_features

    @property
    def n_outputs(self):
        """Number of output labels seen by the last ``fit`` with targets.

        Raises ValueError if no fit with targets has happened.
        """
        if self._n_outputs is None:
            self._not_been_fit()
        return self._n_outputs


class _BaseLDL(_Base):
    """Prediction/scoring interface shared by all label-distribution learners."""

    # Metrics reported by ``score`` when the caller passes none. Kept as an
    # immutable tuple and copied per call; the former list default argument
    # was a single object shared (and mutable) across every call.
    _DEFAULT_METRICS = ("canberra", "chebyshev", "clark", "cosine",
                        "intersection", "kl_divergence", "spearman", "kendall")

    def predict(self, _):
        """Return predicted label distributions for the given inputs.

        Concrete learners must override this.

        Raises
        ------
        NotImplementedError
            Always, on this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement 'predict()'.")

    def score(self, X, y, metrics=None):
        """Evaluate ``self.predict(X)`` against the ground truth ``y``.

        Delegates to ``score`` from ``utils.others_metrics`` as
        ``score(y, self.predict(X), metrics=metrics)`` and returns its result.

        Parameters
        ----------
        X : array-like
            Inputs to predict on.
        y : array-like
            Ground-truth label distributions.
        metrics : list of str or None
            Metric names to compute; ``None`` selects canberra, chebyshev,
            clark, cosine, intersection, kl_divergence, spearman and kendall.
        """
        if metrics is None:
            metrics = list(self._DEFAULT_METRICS)
        return score(y, self.predict(X), metrics=metrics)


class _BaseDeep(keras.Model):
    """``keras.Model`` base that records the layer sizes of the deep LDL models.

    Parameters
    ----------
    n_hidden : optional
        Stored as ``self._n_hidden`` for subclasses to interpret.
    n_latent : optional
        Stored as ``self._n_latent`` for subclasses to interpret.
    random_state : int or None
        When given, passed to ``tf.random.set_seed`` (TensorFlow's global seed).
    """

    def __init__(self, n_hidden=None, n_latent=None, random_state=None):
        # NOTE(review): presumably meant to quiet TensorFlow's C++ logging;
        # tensorflow is already imported by the time this runs, so confirm the
        # setting still has an effect.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        keras.Model.__init__(self)

        if random_state is not None:
            tf.random.set_seed(random_state)

        self._n_hidden, self._n_latent = n_hidden, n_latent


class BaseLDL(_BaseLDL, BaseEstimator):
    """Base for the scikit-learn-style (non-deep) LDL estimators."""

    def __init__(self, random_state=None):
        """Seed NumPy (via ``_Base``) and record ``random_state``.

        scikit-learn's ``BaseEstimator.get_params`` (used by ``repr``,
        ``clone`` and ``set_params``) reads every ``__init__`` parameter back
        as an attribute of the same name. Without ``self.random_state`` every
        subclass raised ``AttributeError`` there.

        Parameters
        ----------
        random_state : int or None
            Forwarded to ``_Base`` (global NumPy seed) and stored unchanged.
        """
        super().__init__(random_state)
        self.random_state = random_state


class BaseDeepLDL(_BaseLDL, _BaseDeep):
    """Base for the TensorFlow/Keras LDL learners in this module.

    Combines the data bookkeeping and ``score`` of ``_BaseLDL`` with the
    ``keras.Model`` base ``_BaseDeep``.
    """

    def __init__(self, n_hidden=None, n_latent=None, random_state=None):
        """Initialise both bases explicitly, in a fixed order.

        ``_Base`` (reached through ``_BaseLDL``) does not call ``super()``, so
        each base initialiser is invoked by name: first the NumPy-side
        bookkeeping/seeding, then the Keras model and TensorFlow seed.
        """
        _BaseLDL.__init__(self, random_state)
        _BaseDeep.__init__(self, n_hidden, n_latent, random_state)

    def fit(self, X, y):
        """Store ``X`` and ``y`` and keep float32 tensor copies for TensorFlow.

        Note: this shadows ``keras.Model.fit`` with a different signature.
        Subclasses call it first, then train on ``self._X`` / ``self._y``.
        """
        _BaseLDL.fit(self, X, y)
        # Replace the stored inputs/targets with float32 tensors; shapes were
        # already recorded by ``_Base.fit`` above.
        self._X = tf.cast(self._X, dtype=tf.float32)
        self._y = tf.cast(self._y, dtype=tf.float32)


class _SA(BaseLDL):
    """Common machinery for the SA-* (maximum-entropy) label-distribution learners.

    The model is ``softmax(X @ W)`` with ``W`` of shape
    ``(n_features, n_outputs)``. Subclasses implement ``_specialized_alg``,
    which receives a flat random initial weight vector and returns the
    fitted ``W`` matrix.

    Parameters
    ----------
    maxiter : int
        Iteration cap made available to ``_specialized_alg``.
    convergence_criterion : float
        Tolerance made available to ``_specialized_alg``.
    random_state : int or None
        Forwarded to the base class, which seeds NumPy's global RNG.
    """

    def __init__(self,
                 maxiter=600,
                 convergence_criterion=1e-6,
                 random_state=None):

        super().__init__(random_state)

        self.maxiter = maxiter
        self.convergence_criterion = convergence_criterion

        self._W = None

        # Process-wide filter, re-registered on every instantiation; the text
        # is matched against the start of the warning message.
        # NOTE(review): presumably targets a SciPy solver warning — confirm the
        # solver in use still emits it.
        warnings.filterwarnings('ignore', "The iteration is not making good progress,")

    def _object_fun(self, weights):
        """Return ``(loss, flat_gradient)`` at the flat weight vector ``weights``.

        ``weights`` is laid out output-major (``n_outputs`` blocks of
        ``n_features`` values), matching the flattening in ``_gradient``.
        """
        W = weights.reshape(self._n_outputs, self._n_features).transpose()
        y_pred = softmax(np.dot(self._X, W), axis=1)
        return self._loss(y_pred), self._gradient(y_pred)

    def _gradient(self, y_pred):
        """Flat gradient of ``_loss`` with respect to the weights.

        ``X^T (y_pred - y)`` is the exact cross-entropy gradient when every
        row of ``y`` sums to 1 (a label distribution) — assumed, not checked.
        """
        grad = np.dot(self._X.T, y_pred - self._y)
        return grad.transpose().reshape(-1, )

    def _loss(self, y_pred):
        """Cross-entropy ``-sum(y * log(y_pred))`` between targets and predictions."""
        # Only the predictions are clipped, to keep log() finite; entries with a
        # zero target contribute nothing either way. Clipping the targets as
        # well added terms that ``_gradient`` does not include, so the loss and
        # gradient passed to the optimiser disagreed.
        y_pred = np.clip(y_pred, 1e-7, 1)
        return -1 * np.sum(self._y * np.log(y_pred))

    def _specialized_alg(self, weights):
        """Optimise from the flat initial ``weights``; return ``W`` (n_features, n_outputs).

        Raises
        ------
        NotImplementedError
            On this base class; concrete SA-* learners override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement '_specialized_alg()'.")

    def fit(self, X, y):
        """Fit ``W`` from a uniform(-0.1, 0.1) random start via ``_specialized_alg``."""
        super().fit(X, y)

        weights = np.random.uniform(-0.1, 0.1, self._n_features * self._n_outputs)
        self._W = self._specialized_alg(weights)

    def predict(self, X):
        """Return ``softmax(X @ W)`` row-wise; raises ValueError before ``fit``."""
        # Go through the property so an unfitted model fails with the standard
        # "call fit() first" error instead of np.dot(X, None).
        return softmax(np.dot(X, self.W), axis=1)

    @property
    def W(self):
        """Fitted weight matrix; raises ValueError before ``fit``."""
        if self._W is None:
            self._not_been_fit()
        return self._W


class SA_BFGS(_SA):
    """SA-BFGS: minimises the ``_SA`` objective with SciPy's L-BFGS-B."""

    def _specialized_alg(self, weights):
        """Run L-BFGS-B from ``weights``; return W shaped (n_features, n_outputs).

        ``convergence_criterion`` is used as the gradient tolerance and
        ``maxiter`` as the iteration cap; ``_object_fun`` supplies both the
        loss and its gradient (``jac=True``).
        """
        solver_options = {
            'gtol': self.convergence_criterion,
            'disp': False,
            'maxiter': self.maxiter,
        }
        result = minimize(
            self._object_fun,
            weights,
            jac=True,
            method='L-BFGS-B',
            options=solver_options,
        )
        flat_w = result.x
        return flat_w.reshape(self._n_outputs, self._n_features).T


class AA_KNN(BaseLDL):
    """AA-kNN: predict the mean label distribution of the k nearest training samples.

    Parameters
    ----------
    k : int
        Number of neighbours averaged per query.
    random_state : int or None
        Forwarded to ``BaseLDL``.
    """

    def __init__(self,
                 k=5,
                 random_state=None):

        super().__init__(random_state)

        self.k = k
        # Built in ``fit`` so that a later change to ``self.k`` (e.g. via
        # scikit-learn's ``set_params``) actually takes effect.
        self._model = None

    def fit(self, X, y):
        """Index the training inputs for neighbour search using the current ``k``."""
        super().fit(X, y)
        self._model = NearestNeighbors(n_neighbors=self.k)
        self._model.fit(self._X)

    def predict(self, X):
        """Return, for each row of ``X``, the average target of its ``k`` neighbours.

        Raises ValueError if ``fit`` has not been called.
        """
        if self._model is None:
            self._not_been_fit()
        _, inds = self._model.kneighbors(X)
        # self._y[inds] has shape (n_queries, k, n_outputs); average over neighbours.
        return np.average(self._y[inds], axis=1)


class BaseEnsemble(BaseLDL):
    """Base for ensembles built from copies of a template ``estimator``.

    Fitted members live in ``self._estimators`` (``None`` until a subclass's
    ``fit`` sets it); ``len``, indexing and iteration delegate to it.

    Parameters
    ----------
    estimator : estimator or None
        Template for the members; ``None`` (the default) means a fresh
        ``SA_BFGS()`` created for this instance.
    n_estimators : int or None
        Stored for subclasses.
    random_state : int or None
        Forwarded to ``BaseLDL``.
    """

    def __init__(self, estimator=None, n_estimators=None, random_state=None):
        super().__init__(random_state)
        if estimator is None:
            # Created per instance: an ``SA_BFGS()`` default in the signature is
            # evaluated once at class-definition time and shared by every ensemble.
            estimator = SA_BFGS()
        # Public names so scikit-learn's get_params()/clone()/repr() can read
        # the constructor arguments back.
        self.estimator = estimator
        self.n_estimators = n_estimators
        self._estimator = estimator
        self._n_estimators = n_estimators
        self._estimators = None

    def _members(self):
        """Return the fitted members, raising ValueError if ``fit`` has not run."""
        if self._estimators is None:
            self._not_been_fit()
        return self._estimators

    def __len__(self):
        return len(self._members())

    def __getitem__(self, index):
        return self._members()[index]

    def __iter__(self):
        return iter(self._members())


class DF_LDL(BaseEnsemble):
    """DF-LDL: pairwise label-ordering ensemble with kNN routing.

    For every label pair ``i < j`` two copies of ``estimator`` are trained:
    key ``"i,j"`` on the samples whose target gives label ``i`` at least as
    much mass as label ``j``, and key ``"j,i"`` on the remaining samples. At
    prediction time an ``AA_KNN`` model estimates each sample's ordering of
    ``i`` and ``j``, the matching pairwise model predicts a full distribution,
    and the predictions from all ``c * (c - 1) / 2`` pairs are averaged.

    Parameters
    ----------
    estimator : estimator or None
        Template for the pairwise models; ``None`` means a fresh ``SA_BFGS()``.
    random_state : int or None
        Forwarded to the base classes.
    """

    def __init__(self, estimator=None, random_state=None):
        if estimator is None:
            # Fresh per instance instead of one definition-time SA_BFGS() shared by all.
            estimator = SA_BFGS()
        super().__init__(estimator, None, random_state)

    def _fit_copy(self, X, y):
        """Fit and return an independent deep copy of the template estimator."""
        model = copy.deepcopy(self._estimator)
        model.fit(X, y)
        return model

    def fit(self, X, y):
        """Train the two models for every label pair, plus the routing kNN.

        ``X`` and ``y`` must be NumPy arrays (boolean-mask row indexing is used).
        """
        super().fit(X, y)

        X_train, y_train = self._X, self._y
        c = y_train.shape[1]
        estimators = {}

        for i in range(c):
            for j in range(i + 1, c):
                # Vectorised split: True where label i has at least as much mass as j.
                i_first = y_train[:, i] >= y_train[:, j]
                estimators[f"{i},{j}"] = self._fit_copy(X_train[i_first], y_train[i_first])
                estimators[f"{j},{i}"] = self._fit_copy(X_train[~i_first], y_train[~i_first])

        self._estimators = estimators

        self._knn = AA_KNN()
        self._knn.fit(X_train, y_train)

    def predict(self, X):
        """Return the averaged pairwise predictions for each row of ``X``.

        ``X`` must be a 2-D NumPy array. Raises ValueError before ``fit``.
        """
        if self._estimators is None:
            self._not_been_fit()

        c = self._y.shape[1]
        p_knn = self._knn.predict(X)
        p = np.zeros((X.shape[0], c), dtype=np.float32)

        for i in range(c):
            for j in range(i + 1, c):
                # Route each sample to the model trained on the ordering the kNN
                # predicts for it. Each model only sees the rows routed to it, in
                # one batched call instead of one call per sample.
                prefer_i = p_knn[:, i] >= p_knn[:, j]
                for key, rows in ((f"{i},{j}", prefer_i), (f"{j},{i}", ~prefer_i)):
                    if rows.any():
                        p[rows] += self._estimators[key].predict(X[rows])

        return p / (c * (c - 1) / 2)


class LDLF(BaseDeepLDL):
    """Forest of differentiable decision trees for label distribution learning.

    A shared network maps inputs to ``n_latent`` sigmoid units. Tree ``i``
    reads a random subset ``_phi[i]`` (``n_leaves`` of those units) as the
    split probabilities of a full binary tree of depth ``n_depth``, and holds
    a per-leaf label distribution ``_pi[i]`` of shape ``(n_leaves, n_outputs)``.
    A tree predicts ``mu @ pi``; the forest averages its trees.

    Parameters
    ----------
    n_estimators : int
        Number of trees.
    n_depth : int
        Tree depth; each tree has ``2 ** n_depth`` leaves.
    n_hidden : int or None
        Hidden-layer width; ``None`` means ``n_features * 3 // 2``, resolved
        afresh on every ``fit``.
    n_latent : int
        Number of latent decision units; must be at least ``2 ** n_depth``
        because ``_phi`` samples that many of them without replacement.
    random_state : int or None
        Seeds NumPy and TensorFlow through the base classes.
    """

    def __init__(self, n_estimators=5, n_depth=6, n_hidden=None, n_latent=64, random_state=None):
        super().__init__(n_hidden, n_latent, random_state)
        self._n_estimators = n_estimators
        self._n_depth = n_depth
        self._n_leaves = 2 ** n_depth

    def _call(self, X, i):
        """Return tree ``i``'s leaf-routing weights for ``X``, shape (n_samples, n_leaves).

        Tree node ``k`` uses decision unit ``k`` of ``_phi[i]``; index 0 is
        unused and depth ``l`` occupies indices ``[2**l, 2**(l+1))``. At each
        level the mass reaching a node is split ``d`` / ``1 - d`` between its
        two children.
        """
        decisions = tf.gather(self._model(X), self._phi[i], axis=1)
        decisions = tf.expand_dims(decisions, axis=2)
        decisions = tf.concat([decisions, 1 - decisions], axis=2)
        mu = tf.ones([X.shape[0], 1, 1])

        begin_idx = 1
        end_idx = 2

        for level in range(self._n_depth):
            mu = tf.reshape(mu, [X.shape[0], -1, 1])
            mu = tf.tile(mu, (1, 1, 2))
            level_decisions = decisions[:, begin_idx:end_idx, :]
            mu = mu * level_decisions

            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (level + 1)

        mu = tf.reshape(mu, [X.shape[0], self._n_leaves])

        return mu

    def fit(self, X, y, learning_rate=5e-2, epochs=3000):
        """Train the decision network with Adam, updating leaf distributions each epoch.

        Parameters
        ----------
        X, y : array-like
            Training inputs and target label distributions.
        learning_rate : float
            Adam learning rate for the decision network.
        epochs : int
            Number of full-batch training iterations.
        """
        super().fit(X, y)
        # float32 tensor copy stored by BaseDeepLDL.fit, used for every forward pass.
        X_train = self._X

        self._phi = [np.random.choice(
            np.arange(self._n_latent), size=self._n_leaves, replace=False
        ) for _ in range(self._n_estimators)]

        self._pi = [tf.Variable(
            initial_value=tf.constant_initializer(1 / self.n_outputs)(
                shape=[self._n_leaves, self._n_outputs]
            ),
            dtype="float32", trainable=True,
        ) for _ in range(self._n_estimators)]

        # Resolve the default width locally: overwriting self._n_hidden made a
        # later refit on data with a different feature count reuse a stale width.
        n_hidden = self._n_hidden
        if n_hidden is None:
            n_hidden = self._n_features * 3 // 2

        self._model = keras.Sequential([keras.layers.InputLayer(input_shape=(self._n_features,)),
                                        keras.layers.Dense(n_hidden, activation='sigmoid'),
                                        keras.layers.Dense(self._n_latent, activation="sigmoid")])
        self._optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

        # Loop-invariant: targets broadcast as (n_samples, 1, n_outputs).
        y_wide = tf.expand_dims(self._y, axis=1)

        for _ in range(epochs):
            with tf.GradientTape() as model_tape:
                loss = 0.
                for i in range(self._n_estimators):
                    mu = self._call(X_train, i)
                    prob = tf.matmul(mu, self._pi[i])

                    loss += tf.math.reduce_mean(keras.losses.kl_divergence(self._y, prob))

                    # Leaf update: accumulate y * pi * mu / prob over samples.
                    # Shapes: y_wide (n,1,c), pi_wide (1,L,c), mu_wide (n,L,1),
                    # prob_wide (n,1,c) -> summed over samples to (L,c).
                    pi_wide = tf.expand_dims(self._pi[i], axis=0)
                    mu_wide = tf.expand_dims(mu, axis=2)
                    prob_wide = tf.clip_by_value(
                        tf.expand_dims(prob, axis=1), clip_value_min=1e-6, clip_value_max=1.0)
                    new_pi = tf.multiply(tf.multiply(y_wide, pi_wide), mu_wide) / prob_wide
                    new_pi = tf.reduce_sum(new_pi, axis=0)
                    # NOTE(review): normalised with softmax over labels rather than by
                    # dividing by the row sum; confirm this is the intended update.
                    new_pi = keras.activations.softmax(new_pi)
                    self._pi[i].assign(new_pi)

                loss /= self._n_estimators

            gradients = model_tape.gradient(loss, self._model.trainable_variables)
            self._optimizer.apply_gradients(zip(gradients, self._model.trainable_variables))

    def predict(self, X):
        """Return the forest's mean predicted label distribution for each row of ``X``."""
        res = np.zeros([X.shape[0], self._n_outputs], dtype=np.float32)
        for i in range(self._n_estimators):
            res += tf.matmul(self._call(X, i), self._pi[i])
        return res / self._n_estimators


class LDL_SCL(BaseDeepLDL):
    """Linear LDL model with a learned per-sample cluster coding.

    Training fits ``softmax(model(X) + C @ W)``:
    - ``model`` is a linear layer.
    - ``C`` holds one learnable coding row per training sample.
    - ``W`` maps codings to label logits.
    - ``P`` holds the k-means centres of the training label distributions.

    For new samples, codings are regressed from ``X`` with one SVR per
    cluster, fitted at the end of ``fit``.
    """

    def __init__(self, n_hidden=None, n_latent=None, random_state=None):
        super().__init__(n_hidden, n_latent, random_state)

    def _loss(self, X, y):
        """KL divergence + ``alpha`` * coding-weighted centre distance + ``beta`` * barrier on C.

        Deliberately not decorated with ``tf.function``: ``fit`` wraps it afresh
        on every call. A method-level ``tf.function`` kept its first trace, with
        the first fit's model, variables and alpha/beta baked in. Refitting the
        same instance on same-shaped data then used stale state.
        """
        y_pred = keras.activations.softmax(self._model(X) + tf.matmul(self._C, self._W))

        kl = tf.math.reduce_mean(keras.losses.kl_divergence(y, y_pred))

        # MSE between every prediction and every centre, broadcast to
        # (n_samples, n_clusters) and weighted by the codings C.
        corr = tf.math.reduce_mean(self._C * keras.losses.mean_squared_error(
            tf.expand_dims(y_pred, 1), tf.expand_dims(self._P, 0)
        ))

        # Large when entries of C are near zero (C starts at 1e-6 in fit).
        barr = tf.math.reduce_mean(1 / self._C)

        return kl + self._alpha * corr + self._beta * barr

    def fit(self, X, y, n_clusters=5, learning_rate=5e-2, epochs=3000, alpha=1e-3, beta=1e-6):
        """Jointly train the linear model, codings ``C`` and ``W`` with Adam, then fit the SVRs.

        Parameters
        ----------
        X, y : array-like
            Training inputs and target label distributions.
        n_clusters : int
            Number of k-means clusters (coding dimension).
        learning_rate : float
            Adam learning rate.
        epochs : int
            Number of full-batch training iterations.
        alpha, beta : float
            Weights of the centre-distance and barrier terms in ``_loss``.
        """
        super().fit(X, y)

        self._n_clusters = n_clusters
        self._alpha = alpha
        self._beta = beta

        self._P = tf.cast(KMeans(n_clusters=self._n_clusters).fit(self._y).cluster_centers_,
                          dtype=tf.float32)

        self._C = tf.Variable(tf.zeros((self._X.shape[0], self._n_clusters)) + 1e-6,
                              trainable=True)

        self._W = tf.Variable(tf.random.normal((self._n_clusters, self._n_outputs)),
                              trainable=True)

        self._model = keras.Sequential([keras.layers.InputLayer(input_shape=(self._n_features,)),
                                        keras.layers.Dense(self._n_outputs, activation=None)])
        self._optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

        # Compiled per fit so the trace captures this fit's variables and hyper-parameters.
        loss_fn = tf.function(self._loss)
        variables = self.trainable_variables

        for _ in range(epochs):
            with tf.GradientTape() as tape:
                loss = loss_fn(self._X, self._y)

            gradients = tape.gradient(loss, variables)
            self._optimizer.apply_gradients(zip(gradients, variables))

        # The coding regressors depend only on the training inputs and the final
        # C, so fit them once here instead of on every predict() call.
        X_train = self._X.numpy()
        codes = self._C.numpy()
        self._regressors = [SVR().fit(X_train, codes[:, i].reshape(-1))
                            for i in range(self._n_clusters)]

    def predict(self, X):
        """Return ``softmax(model(X) + C_hat @ W)``, with ``C_hat`` regressed from ``X``."""
        # One column per cluster: each SVR predicts that cluster's coding.
        C = np.column_stack([regressor.predict(X) for regressor in self._regressors])
        C = tf.cast(C, dtype=tf.float32)

        return keras.activations.softmax(self._model(X) + tf.matmul(C, self._W))


class LDL_LRR(BaseDeepLDL):
    """Softmax linear LDL model trained with KL divergence plus a pairwise ranking penalty.

    Optimised with TensorFlow Probability's L-BFGS over the flattened model weights.
    """

    def __init__(self, n_hidden=None, n_latent=None, random_state=None):
        super().__init__(n_hidden, n_latent, random_state)

    @tf.function
    def _ranking_loss(self, y_pred, P, W):
        """Weighted pairwise cross-entropy between predicted and target label orderings.

        ``Phat[n, a, b]`` is a steep sigmoid of ``y_pred[n, a] - y_pred[n, b]``.
        ``P`` holds the 0/1 targets for the same pairs and ``W`` the per-pair
        weights, both built in ``fit``. Log arguments are clipped to
        ``[1e-9, 1]``.
        """
        Phat = tf.math.sigmoid((tf.expand_dims(y_pred, -1) - tf.expand_dims(y_pred, 1)) * 100)
        l = ((1 - P) * tf.math.log(tf.clip_by_value(1 - Phat, 1e-9, 1.0)) + P * tf.math.log(tf.clip_by_value(Phat, 1e-9, 1.0))) * W
        return -tf.reduce_sum(l)

    def _loss(self, X, y):
        """KL divergence plus ``alpha / (2 * n_samples)`` times the ranking penalty.

        Deliberately not a ``tf.function`` of its own: it is traced inside the
        fresh objective that ``_get_obj_func`` builds on every ``fit``. As a
        method-level ``tf.function`` it kept its first trace. That trace held
        the first fit's model variables, ``_P``/``_W`` and ``_alpha``, so
        refitting on same-shaped data used stale state.

        NOTE(review): ``self._model.losses`` (the L2 kernel penalty configured
        from ``beta`` in ``fit``) is not added here, so ``beta`` currently has
        no effect on training — confirm whether it should be included.
        """
        y_pred = self._model(X)
        kl = tf.math.reduce_mean(keras.losses.kl_divergence(y, y_pred))
        rank = self._ranking_loss(y_pred, self._P, self._W)
        return kl + self._alpha / (2 * X.shape[0]) * rank

    def _get_obj_func(self):
        """Build a value-and-gradient function over the model's flattened weights.

        ``tfp.optimizer.lbfgs_minimize`` works on one 1-D position vector, so
        the model's trainable variables are packed into and unpacked from a
        single flat tensor. The returned ``tf.function`` maps ``params_1d`` to
        ``(loss, flat_gradient)``. It carries these attributes:
        - ``idx``: flat positions of each variable, for ``tf.dynamic_stitch``.
        - ``part``: variable index of every flat position.
        - ``shapes``: the variables' shapes.
        - ``assign_new_model_parameters``: writes a flat vector back into the model.
        """
        shapes = tf.shape_n(self._model.trainable_variables)
        n_tensors = len(shapes)

        count = 0
        idx = []
        part = []

        for i, shape in enumerate(shapes):
            # np.prod, not np.product: that alias was removed in NumPy 2.0.
            n = int(np.prod(shape))
            idx.append(tf.reshape(tf.range(count, count + n, dtype=tf.int32), shape))
            part.extend([i] * n)
            count += n

        part = tf.constant(part)

        @tf.function
        def _assign_new_model_parameters(params_1d):
            params = tf.dynamic_partition(params_1d, part, n_tensors)
            for i, (shape, param) in enumerate(zip(shapes, params)):
                self._model.trainable_variables[i].assign(tf.reshape(param, shape))

        @tf.function
        def _f(params_1d):

            with tf.GradientTape() as tape:
                _assign_new_model_parameters(params_1d)
                loss = self._loss(self._X, self._y)

            grads = tape.gradient(loss, self._model.trainable_variables)
            grads = tf.dynamic_stitch(idx, grads)

            return loss, grads

        _f.idx = idx
        _f.part = part
        _f.shapes = shapes
        _f.assign_new_model_parameters = _assign_new_model_parameters
        return _f

    def fit(self, X, y, alpha=1e-2, beta=1.):
        """Fit the softmax layer with L-BFGS (at most 500 iterations).

        Parameters
        ----------
        X, y : array-like
            Training inputs and target label distributions.
        alpha : float
            Weight of the ranking penalty.
        beta : float
            L2 factor given to the layer's ``kernel_regularizer`` (see the
            review note on ``_loss``).
        """
        super().fit(X, y)

        self._alpha = alpha
        self._beta = beta

        # Target pair relations: P[n, a, b] = 1 where y[n, a] > y[n, b].
        P = tf.nn.sigmoid(tf.expand_dims(self._y, -1) - tf.expand_dims(self._y, 1))
        self._P = tf.where(P > .5, 1., 0.)

        # Pair weights: squared gap between the two target label values.
        self._W = tf.square(tf.expand_dims(self._y, -1) - tf.expand_dims(self._y, 1))

        self._model = keras.Sequential(
            [keras.layers.InputLayer(input_shape=(self._n_features,)),
             keras.layers.Dense(self._n_outputs,
                                activation="softmax",
                                kernel_regularizer=keras.regularizers.L2(self._beta))]
        )

        func = self._get_obj_func()
        init_params = tf.dynamic_stitch(func.idx, self._model.trainable_variables)

        results = tfp.optimizer.lbfgs_minimize(
            value_and_gradients_function=func, initial_position=init_params,
            max_iterations=500, tolerance=1.4901161193847656e-08
        )

        func.assign_new_model_parameters(results.position)

    def predict(self, X):
        """Return the model's softmax output (a TensorFlow tensor) for ``X``."""
        return self._model(X)
    

