import sys
sys.path.append('/mnt/data01/****/****')

import os
os.environ["KERAS_BACKEND"] = "torch"
import keras
import torch

import numpy as np
import matplotlib.pyplot as plt
# from sklearn.manifold import TSNE
from cuml import TSNE
# from tensorflow import keras
from keras import layers, optimizers, metrics
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split

from gnn.useful_utils import visualization_metric




class Inductive_TSNE_DNN():
  """
  Classe pour t-SNE inductif utilisant un réseau de neurones profond.

  Cette classe étend l'algorithme t-SNE standard en utilisant un réseau de neurones
  pour apprendre une fonction de projection, permettant ainsi l'application de
  t-SNE à de nouvelles données sans réentraînement complet.

  Attributs:
    model (keras.Sequential): Le modèle de réseau de neurones.
    rmse (float): L'erreur quadratique moyenne entre la projection t-SNE originale
                  et celle prédite par le réseau de neurones.
    history (keras.callbacks.History): L'historique d'entraînement du modèle.
  """
  def __init__(self, **kwargs):
    """
    Initialise l'objet Inductive_TSNE_DNN.

    Args:
      n_components (int): Nombre de composantes pour la projection de sortie.
      **kwargs: Arguments supplémentaires passés à la classe parente TSNE.
    """
    # super().__init__(**kwargs)
    self.model = None
    self.rmse = None
    self.history = None

    # self.perplexity = kwargs['perplexity']
    self.n_components = kwargs['n_components']

  def _build_model(self, input_dim):
    """
    Construit l'architecture du réseau de neurones.

    Args:
      input_dim (int): Dimension des données d'entrée.

    Returns:
      keras.Sequential: Le modèle de réseau de neurones construit.
    """
    return keras.Sequential([
        layers.InputLayer(shape=(input_dim,)),
        layers.Dense(256, activation='relu', name='layer1'),
        layers.BatchNormalization(),
        layers.Dense(256, activation='relu', name='layer11'),
        layers.BatchNormalization(),
        layers.Dense(256, activation='relu', name='layer12'),
        layers.BatchNormalization(),
        layers.Dense(256, activation='relu', name='layer13'),
        layers.BatchNormalization(),
        layers.Dense(256, activation='relu', name='layer2'),
        layers.BatchNormalization(),
        layers.Dense(128, activation='relu', name='layer3'),
        layers.BatchNormalization(),
        layers.Dense(32, activation='relu', name='layer4'),
        layers.BatchNormalization(),
        layers.Dense(8, activation='relu', name='layer5'),
        layers.BatchNormalization(),
        layers.Dense(self.n_components, activation='linear', name='output')
    ])

  def fit(self, X, y_tsne):
    """
    Ajuste le modèle aux données d'entrée.

    Args:
      X (array-like): Données d'entrée de forme (n_samples, n_features).
      y: Ignoré, présent pour la compatibilité avec l'API scikit-learn.

    Returns:
      self: Retourne l'instance de l'objet.
    """
    # Effectue d'abord la projection t-SNE standard

    # y_tsne = super().fit_transform(X)


    # Construit et compile le modèle
    self.model = self._build_model(X.shape[1])
    self.model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss='mse',
        metrics=['mae', 'mse']
    )

    # Entraîne le modèle
    self.history = self.model.fit(
        X, y_tsne,
        epochs=300,
        batch_size=min(3000, len(X)),
        verbose=1,
        validation_split=0.05
      )

    # Calcule l'erreur RMSE
    y_pred = self.model.predict(X)
    self.rmse = np.sqrt(np.mean((y_tsne - y_pred) ** 2))

    return self

  def transform(self, X):
    """
    Applique la transformation aux nouvelles données.

    Args:
      X (array-like): Nouvelles données à transformer.

    Returns:
      array: Données transformées dans l'espace de faible dimension.
    """
    if self.model is None:
      raise ValueError("Le modèle n'a pas été entraîné. Appelez fit() d'abord.")
    return self.model.predict(X)

  def fit_transform(self, X, y):
    """
    Ajuste le modèle aux données et renvoie les données transformées.

    Args:
      X (array-like): Données d'entrée.
      y: Ignoré, présent pour la compatibilité avec l'API scikit-learn.

    Returns:
      array: Données transformées dans l'espace de faible dimension.
    """
    return self.fit(X, y).transform(X)

  def score(self, X, y=None):
    """
    Calcule le score du modèle (négatif de RMSE).

    Args:
      X (array-like): Données d'entrée.
      y: Ignoré, présent pour la compatibilité avec l'API scikit-learn.

    Returns:
      float: Score du modèle (plus élevé est meilleur).
    """
    if self.rmse is None:
      raise ValueError("Le modèle n'a pas été entraîné. Appelez fit() d'abord.")
    return -self.rmse

  def plot_training_history(self):
    if self.history is None:
      raise ValueError("Le modèle n'a pas été entraîné. Appelez fit() d'abord.")
    train_loss = self.history.history['loss']
    val_loss = self.history.history['val_loss']
    train_mae = self.history.history['mae']
    val_mae = self.history.history['val_mae']
    epochs = range(1, len(train_loss) + 1)
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_loss, 'b', label='Perte d\'entraînement')
    plt.plot(epochs, val_loss, 'r', label='Perte de validation')
    plt.title('Perte d\'entraînement et de validation')
    plt.xlabel('Époques')
    plt.ylabel('Perte')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_mae, 'b', label='MAE d\'entraînement')
    plt.plot(epochs, val_mae, 'r', label='MAE de validation')
    plt.title('MAE d\'entraînement et de validation')
    plt.xlabel('Époques')
    plt.ylabel('MAE')
    plt.legend()
    plt.tight_layout()
    plt.show()


def load_data(d_name):
    features, labels = torch.load(f"/home/****/autovisual/prepare_data/data/{d_name}_features_clip.tar",
                                  weights_only=False)
    return features, labels


def load_large_data(root, data_names):
    # ori_features, ori_labels = features_n_labels

    features = []
    labels = []
    blocks = []
    z_targets = []
    for d_name in data_names:
        print(d_name)
        s_ = torch.load(root + '/' + d_name + '_clip_cdist_3000.tar', weights_only=False)
        cdist, x, y = s_[:3]

        # selected_indices = np.isin(ori_labels, selected_labels)
        #
        # x_selected = ori_features[selected_indices]
        # y_selected = ori_labels[selected_indices]
        # x = x_selected[ind]

        features.append(x)
        labels.append(y)
        blocks.append(x.shape[0])

        selected_emb, hps = torch.load(
            '/home/****/autovisual/prepare_data/bo/res-2' + '/' + f'visual-method-TSNE_dataset-{d_name}_selected_emb.tar',
            weights_only=False)
        z_targets.append(selected_emb)

    return np.concatenate(features), np.concatenate(labels), blocks,  np.concatenate(z_targets)


def get_evaluation(zs, ys, blocks):
    nmis = []
    scs = []
    blocks = np.cumsum(blocks)
    # if len(blocks) == 1:
    blocks = np.concatenate([np.array([0]), blocks])
    for i in range(len(blocks)-1):
        z = zs[blocks[i]: blocks[i+1]]
        y = ys[blocks[i]: blocks[i+1]]
        print(z.shape)
        nmi, sc = visualization_metric.get_nmi_sc(z, y)
        nmis.append(nmi)
        scs.append(sc)
    return np.array(nmis), np.array(scs)

def get_relative_precision(gt_v, pred_v):
    precision = pred_v / gt_v
    return precision


if __name__ == '__main__':
    train_names = ['mnist_group2', 'mnist_group1', 'fmnist_group2', 'fmnist_group1', 'cifar10_group1'] #+ [
        # f'mnist_comb{i}' for i in range(252)]
    root = '/home/****/autovisual/prepare_data/clip/features'


    train_x, train_y, blocks, train_tsne_z = load_large_data(root, train_names)
    print(blocks)
    print(train_y)
    print(train_x[:100])
    test_x, test_y, test_blocks, test_tsne_z = load_large_data(root, ['cifar10_group2'])

    train_true_nmi, train_true_sc = get_evaluation(train_tsne_z, train_y, blocks)
    print('train true', np.mean(train_true_nmi), np.mean(train_true_sc))

    ind_tsne = Inductive_TSNE_DNN(n_components=2)
    train_z = ind_tsne.fit_transform(train_x, train_tsne_z)
    train_nmi, train_sc = get_evaluation(train_z, train_y, blocks)
    print('train', np.mean(train_nmi), np.mean(train_sc))

    print('train precision', np.mean(get_relative_precision(train_true_nmi, train_nmi)),
          np.mean(get_relative_precision(train_true_sc, train_sc)))

    torch.save(ind_tsne, 'inductive_tsne_large_257.save')

    test_true_nmi, test_true_sc = get_evaluation(test_tsne_z, test_y, test_blocks)
    print(test_true_nmi)
    print('test true', np.mean(test_true_nmi), np.mean(test_true_sc))

    test_z = ind_tsne.transform(test_x)
    test_nmi, test_sc = get_evaluation(test_z, test_y, test_blocks)
    print('test', np.mean(test_nmi), np.mean(test_sc))
    print('test precision', np.mean(get_relative_precision(test_true_nmi, test_nmi)),
          np.mean(get_relative_precision(test_true_sc, test_sc)))
