import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np

from functools import reduce

from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

import keras
import tensorflow as tf

from base import BaseDeepLDL
from utils import load_dataset


class SLDL(BaseDeepLDL):
    """Deep label-distribution learner supervised by label sub-tasks.

    The network has two encoders and one decoder:

    * ``_encoder1`` maps features to a hidden representation ``rep``.
    * ``_encoder2`` maps ``rep`` to a latent vector with
      ``sum(len(c) for c in combi)`` units. The units are split into
      consecutive groups, one per entry of ``combi``. Each group (after
      softmax) is trained with MAE against the label distribution restricted
      to that group's labels and renormalised.
    * ``_decoder`` maps ``concat(rep, latent)`` to the predicted label
      distribution. It is trained with KL divergence plus an extra term
      selected by ``loss_function``.

    Args:
        combi: Sequence of label-index sequences, one per sub-task.
        loss_function: Extra loss term. One of:
            ``'LRR'``: pairwise label-ranking loss.
            ``'SCL'``: cluster-correlation loss with learned per-sample
            coefficients ``C`` and cluster weights ``W``; ``C @ W`` is added
            to the decoder logits.
            ``'KL'``: summed KL divergence.
        n_hidden: Width of the hidden layers. ``None`` means
            ``n_features * 3 // 2``, resolved in ``fit``.
        random_state: Forwarded to ``BaseDeepLDL``.
    """

    def __init__(self, combi, loss_function='LRR', n_hidden=None, random_state=None):
        # NOTE(review): the second positional argument of BaseDeepLDL is passed
        # as None; its meaning is defined in base.py.
        super().__init__(n_hidden, None, random_state)
        self._combi = combi
        # Total number of latent units: one per label in every sub-task group.
        self._n_latent = sum(len(group) for group in self._combi)
        self._loss_function = loss_function

    @tf.function
    def _SCL(self, y, y_pred, C=None):
        """Cluster-correlation loss: mean over (i, p) of C[i, p] * MSE(y_pred[i], P[p]).

        Args:
            y: Unused; kept so every extra loss shares the ``(y, y_pred)`` shape.
            y_pred: Predicted distributions, shape (batch, n_outputs).
            C: Cluster coefficients for the rows of ``y_pred``, shape
                (batch, n_clusters). Defaults to the full ``self._C``, which
                only lines up when the batch is the whole training set.
        """
        if C is None:
            C = self._C
        # (batch, 1, L) vs (1, n_clusters, L) -> (batch, n_clusters) squared distances.
        dist = keras.losses.mean_squared_error(
            tf.expand_dims(y_pred, 1), tf.expand_dims(self._P, 0)
        )
        return tf.math.reduce_mean(C * dist)

    @tf.function
    def _LRR(self, y_pred, P, W):
        """Weighted pairwise label-ranking cross-entropy (negated sum).

        Args:
            y_pred: Predicted distributions, shape (batch, L).
            P: Target orderings; P[b, j, k] is 1 when label j ranks above label k.
            W: Pair weights with the same shape as ``P``.
        """
        # Phat[b, j, k] ~ sigmoid(100 * (y_pred[b, j] - y_pred[b, k])): a steep,
        # differentiable surrogate of the indicator y_pred[b, j] > y_pred[b, k].
        Phat = tf.math.sigmoid((tf.expand_dims(y_pred, -1) - tf.expand_dims(y_pred, 1)) * 100)
        # Clipping keeps log() finite when Phat saturates at 0 or 1.
        l = ((1 - P) * tf.math.log(tf.clip_by_value(1 - Phat, 1e-9, 1.0)) + P * tf.math.log(tf.clip_by_value(Phat, 1e-9, 1.0))) * W
        return -tf.reduce_sum(l)

    @tf.function
    def _KL(self, y, y_pred):
        """Sum over the batch of KL(y || y_pred)."""
        return tf.math.reduce_sum(keras.losses.kl_divergence(y, y_pred))

    def _loss(self, X, y, now_batch):
        """Total training loss for one batch.

        The loss is ``gamma * sub-task MAE + KL(y || outputs) + extra``.

        Args:
            X: Feature batch.
            y: Label-distribution batch.
            now_batch: Index of the batch within the (unshuffled) epoch. It is
                used to slice the per-sample tensors precomputed in ``fit``.
        """
        batch_start = self._batch_size * now_batch
        batch_end = min(self._y.shape[0], self._batch_size * (now_batch + 1))

        rep = self._encoder1(X)
        latent = self._encoder2(rep)

        # Sub-task loss: each latent group must reproduce its renormalised label subset.
        l1, start = 0., 0
        for group, real_task, weights in zip(self._combi, self._real_task, self._weights):
            stop = start + len(group)
            temp = keras.activations.softmax(latent[:, start:stop])
            mae = keras.losses.mean_absolute_error(real_task[batch_start:batch_end], temp)
            l1 += tf.math.reduce_mean(weights[batch_start:batch_end] * mae)
            start = stop

        features = tf.concat([rep, latent], axis=1)
        logits = self._decoder(features)

        if self._loss_function == 'SCL':
            # Slice C to the rows of this batch; the full matrix does not
            # broadcast against a partial batch.
            C = self._C[batch_start:batch_end]
            # Same output model as predict(): decoder logits + C @ W. Without
            # this term W receives no gradient and stays at its random init.
            logits = logits + tf.matmul(C, self._W)

        outputs = keras.activations.softmax(logits)
        l2 = tf.math.reduce_mean(keras.losses.kl_divergence(y, outputs))

        if self._loss_function == 'LRR':
            extra_loss = self._LRR(outputs, self._P[batch_start:batch_end], self._W[batch_start:batch_end])
            extra_loss *= self._alpha / (2 * X.shape[0])
        elif self._loss_function == 'SCL':
            extra_loss = self._alpha * self._SCL(y, outputs, C)
        else:
            # Dispatch by method name instead of eval()-ing a string.
            extra_fn = getattr(self, f'_{self._loss_function}', None)
            if extra_fn is None:
                raise ValueError(f'Unknown loss_function: {self._loss_function!r}')
            extra_loss = self._alpha * extra_fn(y, outputs)

        return self._gamma * l1 + l2 + extra_loss

    def fit(self, X, y, batch_size=256, epochs=350, learning_rate=2e-3,
            gamma=1e-1, alpha=1e-3):
        """Build the encoders/decoder and train them with Adam.

        Args:
            X: Training features.
            y: Training label distributions.
            batch_size: Mini-batch size.
            epochs: Number of passes over the data.
            learning_rate: Adam learning rate.
            gamma: Weight of the sub-task MAE loss.
            alpha: Weight of the extra ``loss_function`` term.
        """
        # NOTE(review): BaseDeepLDL.fit is expected to set self._X, self._y,
        # self._n_features and self._n_outputs (all read below); confirm in base.py.
        super().fit(X, y)

        self._gamma = gamma
        self._alpha = alpha

        if self._loss_function == 'LRR':
            # diff[n, j, k] = y[n, j] - y[n, k]
            diff = tf.expand_dims(self._y, -1) - tf.expand_dims(self._y, 1)
            self._P = tf.where(tf.nn.sigmoid(diff) > .5, 1., 0.)
            self._W = tf.square(diff)
        elif self._loss_function == 'SCL':
            self._n_clusters = 5
            # Cluster centres of the training label distributions.
            self._P = tf.cast(KMeans(n_clusters=self._n_clusters).fit(self._y).cluster_centers_,
                              dtype=tf.float32)
            # Per-training-sample cluster coefficients, learned jointly.
            self._C = tf.Variable(tf.zeros((self._X.shape[0], self._n_clusters)) + 1e-7,
                                  trainable=True)
            # Maps cluster coefficients to an additive logit offset (C @ W).
            self._W = tf.Variable(tf.random.normal((self._n_clusters, self._n_outputs)),
                                  trainable=True)

        # Sub-task targets: each label subset, clipped away from 0 and renormalised to sum to 1.
        clipped = [tf.clip_by_value(tf.gather(self._y, axis=1, indices=c), 1e-7, 1.0) for c in self._combi]
        self._real_task = [t / tf.reshape(tf.reduce_sum(t, axis=1), (-1, 1)) for t in clipped]

        # Per-sample weight of each sub-task: total (unclipped) label mass of its subset.
        self._weights = [tf.math.reduce_sum(tf.gather(self._y, axis=1, indices=c), axis=1) for c in self._combi]

        if self._n_hidden is None:
            self._n_hidden = self._n_features * 3 // 2

        self._encoder1 = keras.Sequential([keras.layers.InputLayer(input_shape=self._n_features),
                                           keras.layers.BatchNormalization(trainable=False),
                                           keras.layers.Dense(self._n_hidden, activation=keras.layers.LeakyReLU())])
        self._encoder2 = keras.Sequential([keras.layers.Dense(self._n_latent, activation=keras.layers.LeakyReLU())])

        self._decoder = keras.Sequential([keras.layers.InputLayer(input_shape=self._n_hidden+self._n_latent),
                                          keras.layers.BatchNormalization(trainable=False),
                                          keras.layers.Dense(self._n_hidden, activation=keras.layers.LeakyReLU()),
                                          keras.layers.Dense(self._n_outputs, activation=keras.layers.LeakyReLU())])

        self._optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

        self._batch_size = batch_size
        # Deliberately NOT shuffled: _loss finds each batch's rows in the
        # per-sample tensors (_real_task, _weights, _P/_W or _C) by batch
        # index. That lookup is only valid when batches keep the sample order.
        data = tf.data.Dataset.from_tensor_slices((self._X, self._y)).batch(self._batch_size)

        for _ in range(epochs):
            for now_batch, (X_batch, y_batch) in enumerate(data):
                with tf.GradientTape() as tape:
                    loss = self._loss(X_batch, y_batch, now_batch)
                gradients = tape.gradient(loss, self.trainable_variables)
                self._optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    def predict(self, X):
        """Predict label distributions (softmax over outputs) for ``X``.

        For ``'SCL'`` the cluster coefficients ``C`` are only learned for
        training samples. For new samples they are estimated by a linear
        regression from training features to the learned ``C``, and
        ``C @ W`` is added to the decoder logits, as in training.
        """
        rep = self._encoder1(X)
        latent = self._encoder2(rep)
        logits = self._decoder(tf.concat([rep, latent], axis=1))

        if self._loss_function == 'SCL':
            # One multi-output OLS fit is equivalent to one fit per cluster column.
            reg = LinearRegression().fit(np.asarray(self._X), self._C.numpy())
            C = tf.cast(reg.predict(np.asarray(X)), dtype=tf.float32)
            logits = logits + tf.matmul(C, self._W)

        return keras.activations.softmax(logits)


class S_KLD(SLDL):
    """``SLDL`` preset whose extra loss term is the summed KL divergence (``'KL'``)."""

    def __init__(self, combi, **kwargs):
        # All remaining options (n_hidden, random_state) pass straight through.
        super().__init__(combi, **kwargs, loss_function='KL')


class S_SCL(SLDL):
    """``SLDL`` preset whose extra loss term is the cluster-correlation loss (``'SCL'``)."""

    def __init__(self, combi, **kwargs):
        # All remaining options (n_hidden, random_state) pass straight through.
        super().__init__(combi, **kwargs, loss_function='SCL')


class S_LRR(SLDL):
    """``SLDL`` preset whose extra loss term is the label-ranking loss (``'LRR'``)."""

    def __init__(self, combi, **kwargs):
        # All remaining options (n_hidden, random_state) pass straight through.
        super().__init__(combi, **kwargs, loss_function='LRR')


class SC(keras.Model):
    """Learn label subsets ("sub-tasks") from label distributions.

    ``fit`` optimises ``t`` real-valued masks over the label axis. The loss
    rewards masks that put weight on labels with high description degree and
    penalises masks that are similar to each other. ``transform`` thresholds
    ``sigmoid(mask) > 0.95``. It returns the selected label indices for each
    distinct mask that selects at least two labels but not all of them.
    """

    def _loss(self):
        """Coverage loss plus ``alpha`` times the mean |cosine| similarity between distinct masks."""
        W = keras.activations.sigmoid(self._W)

        # Coverage term: exp(-mean(w_i * y)) shrinks as mask i covers labels
        # with large description degrees.
        loss = 0.
        for i in range(self._t):
            loss += tf.exp(-tf.reduce_mean(W[i] * self._y))
        loss /= self._t

        # Diversity term over distinct pairs j < i only. Self-pairs add a
        # constant and are not counted in the t(t-1)/2 normaliser. With t == 1
        # there are no pairs, and dividing by 0 would make the loss NaN.
        n_pairs = self._t * (self._t - 1) // 2
        if n_pairs > 0:
            sim = 0.
            for i in range(self._t):
                for j in range(i):
                    sim += tf.abs(tf.keras.losses.cosine_similarity(W[i], W[j]))
            loss += self._alpha * sim / n_pairs

        return loss

    def fit(self, y, t=10, alpha=.2, epochs=100, learning_rate=1.):
        """Optimise ``t`` label masks on label distributions ``y``.

        Args:
            y: Label distributions, shape (n_samples, n_labels).
            t: Number of candidate masks.
            alpha: Weight of the mask-diversity penalty.
            epochs: Number of full-batch Adam steps.
            learning_rate: Adam learning rate.

        Returns:
            self
        """
        # Match the float32 dtype of the mask variable used in _loss.
        self._y = tf.cast(y, tf.float32)
        self._t = t
        self._alpha = alpha
        self._W = tf.Variable(tf.random.normal((self._t, y.shape[1])), trainable=True)

        self._optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        for _ in range(epochs):
            with tf.GradientTape() as tape:
                loss = self._loss()
            gradients = tape.gradient(loss, self.trainable_variables)
            self._optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return self

    def transform(self):
        """Return label-index arrays, one per distinct useful learned mask.

        A mask is kept only if it selects at least two labels and not every
        label. Empty groups would make the downstream sub-task MAE a mean over
        an empty axis (NaN). Single-label groups carry no distribution, and
        an all-label group duplicates the main task.
        """
        masks = tf.where(keras.activations.sigmoid(self._W) > .95, 1, 0).numpy()
        unique_masks = set()
        for row in masks:
            n_selected = int(np.sum(row == 1))
            if n_selected < 2 or n_selected == row.shape[0]:
                continue
            unique_masks.add(tuple(row))
        return [np.where(np.array(mask) == 1)[0] for mask in unique_masks]


if __name__ == '__main__':

    # Explicit name -> class map instead of eval()-ing method names.
    model_classes = {'S_LRR': S_LRR, 'S_SCL': S_SCL, 'S_KLD': S_KLD}
    methods = ['S_LRR']
    # methods = ['S_LRR', 'S_SCL', 'S_KLD']
    metrics = ["chebyshev", "clark", "canberra", "kl_divergence", "cosine", "intersection", "spearman", "kendall"]
    datasets = ['JAFFE']

    for dataset in datasets:

        X, y = load_dataset(f'{dataset}')

        print(f'dataset: {dataset}')
        # Split BEFORE building sub-tasks so SC never sees the test labels.
        # The same split is reused for every method.
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)

        print('construct subtasks...')
        sc = SC()
        sc.fit(y_train)
        combi = sc.transform()

        for method in methods:

            print(f'training {method}...')
            model = model_classes[method](combi)
            model.fit(X_train, y_train)
            scores = model.score(X_test, y_test, metrics=metrics)
            print(metrics)
            print(scores, '')
            del model

        print('done!')
