import sys
sys.path.append('/mnt/data01/****/****')

import os
os.environ["KERAS_BACKEND"] = "torch"
import keras
import torch

import numpy as np
import matplotlib.pyplot as plt
# from sklearn.manifold import TSNE
from cuml import TSNE
# from tensorflow import keras
from keras import layers, optimizers, metrics
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split

from gnn.useful_utils import visualization_metric




class Inductive_TSNE_DNN():
  """
  Classe pour t-SNE inductif utilisant un réseau de neurones profond.

  Cette classe étend l'algorithme t-SNE standard en utilisant un réseau de neurones
  pour apprendre une fonction de projection, permettant ainsi l'application de
  t-SNE à de nouvelles données sans réentraînement complet.

  Attributs:
    model (keras.Sequential): Le modèle de réseau de neurones.
    rmse (float): L'erreur quadratique moyenne entre la projection t-SNE originale
                  et celle prédite par le réseau de neurones.
    history (keras.callbacks.History): L'historique d'entraînement du modèle.
  """
  def __init__(self, **kwargs):
    """
    Initialise l'objet Inductive_TSNE_DNN.

    Args:
      n_components (int): Nombre de composantes pour la projection de sortie.
      **kwargs: Arguments supplémentaires passés à la classe parente TSNE.
    """
    # super().__init__(**kwargs)
    self.model = None
    self.rmse = None
    self.history = None

    self.perplexity = kwargs['perplexity']
    self.n_components = kwargs['n_components']

  def _build_model(self, input_dim):
    """
    Construit l'architecture du réseau de neurones.

    Args:
      input_dim (int): Dimension des données d'entrée.

    Returns:
      keras.Sequential: Le modèle de réseau de neurones construit.
    """
    return keras.Sequential([
        layers.InputLayer(shape=(input_dim,)),
        layers.Dense(256, activation='relu', name='layer1'),
        layers.BatchNormalization(),
        layers.Dense(256, activation='relu', name='layer2'),
        layers.BatchNormalization(),
        layers.Dense(128, activation='relu', name='layer3'),
        layers.BatchNormalization(),
        layers.Dense(32, activation='relu', name='layer4'),
        layers.BatchNormalization(),
        layers.Dense(8, activation='relu', name='layer5'),
        layers.BatchNormalization(),
        layers.Dense(self.n_components, activation='linear', name='output')
    ])

  def fit(self, X, y_tsne):
    """
    Ajuste le modèle aux données d'entrée.

    Args:
      X (array-like): Données d'entrée de forme (n_samples, n_features).
      y: Ignoré, présent pour la compatibilité avec l'API scikit-learn.

    Returns:
      self: Retourne l'instance de l'objet.
    """
    # Effectue d'abord la projection t-SNE standard

    # y_tsne = super().fit_transform(X)


    # Construit et compile le modèle
    self.model = self._build_model(X.shape[1])
    self.model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss='mse',
        metrics=['mae', 'mse']
    )

    # Entraîne le modèle
    self.history = self.model.fit(
        X, y_tsne,
        epochs=300,
        batch_size=min(128, len(X)),
        verbose=0,
        validation_split=0.1
      )

    # Calcule l'erreur RMSE
    y_pred = self.model.predict(X)
    self.rmse = np.sqrt(np.mean((y_tsne - y_pred) ** 2))

    return self

  def transform(self, X):
    """
    Applique la transformation aux nouvelles données.

    Args:
      X (array-like): Nouvelles données à transformer.

    Returns:
      array: Données transformées dans l'espace de faible dimension.
    """
    if self.model is None:
      raise ValueError("Le modèle n'a pas été entraîné. Appelez fit() d'abord.")
    return self.model.predict(X)

  def fit_transform(self, X, y):
    """
    Ajuste le modèle aux données et renvoie les données transformées.

    Args:
      X (array-like): Données d'entrée.
      y: Ignoré, présent pour la compatibilité avec l'API scikit-learn.

    Returns:
      array: Données transformées dans l'espace de faible dimension.
    """
    return self.fit(X, y).transform(X)

  def score(self, X, y=None):
    """
    Calcule le score du modèle (négatif de RMSE).

    Args:
      X (array-like): Données d'entrée.
      y: Ignoré, présent pour la compatibilité avec l'API scikit-learn.

    Returns:
      float: Score du modèle (plus élevé est meilleur).
    """
    if self.rmse is None:
      raise ValueError("Le modèle n'a pas été entraîné. Appelez fit() d'abord.")
    return -self.rmse

  def plot_training_history(self):
    if self.history is None:
      raise ValueError("Le modèle n'a pas été entraîné. Appelez fit() d'abord.")
    train_loss = self.history.history['loss']
    val_loss = self.history.history['val_loss']
    train_mae = self.history.history['mae']
    val_mae = self.history.history['val_mae']
    epochs = range(1, len(train_loss) + 1)
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_loss, 'b', label='Perte d\'entraînement')
    plt.plot(epochs, val_loss, 'r', label='Perte de validation')
    plt.title('Perte d\'entraînement et de validation')
    plt.xlabel('Époques')
    plt.ylabel('Perte')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_mae, 'b', label='MAE d\'entraînement')
    plt.plot(epochs, val_mae, 'r', label='MAE de validation')
    plt.title('MAE d\'entraînement et de validation')
    plt.xlabel('Époques')
    plt.ylabel('MAE')
    plt.legend()
    plt.tight_layout()
    plt.show()



def load_data(d_name):
    features, labels = torch.load(f"/home/****/autovisual/prepare_data/data/{d_name}_features_clip.tar",
                                  weights_only=False)
    return features, labels

if __name__ == '__main__':
    train_x, train_y = load_data('mnist')
    test_x, test_y = load_data('cifar10')

    print(123)
    train_x, train_y = train_x[:3000], train_y[:3000]
    test_x, test_y = test_x[:3000], test_y[:3000]

    y_tsne = TSNE(perplexity=30, n_components=2, verbose=0, init='random', random_state=42).fit_transform(
        train_x)
    nmi, sc = visualization_metric.get_nmi_sc(y_tsne, train_y[:3000].tolist())
    print('train true', nmi, sc)

    ind_tsne = Inductive_TSNE_DNN(n_components=2, perplexity=30)
    train_z = ind_tsne.fit_transform(train_x, y_tsne)
    nmi, sc = visualization_metric.get_nmi_sc(train_z, train_y[:3000].tolist())
    print('train', nmi, sc)

    test_y_tsne = TSNE(perplexity=30, n_components=2, verbose=0, init='random', random_state=42).fit_transform(
        test_x)
    nmi, sc = visualization_metric.get_nmi_sc(test_y_tsne, test_y[:3000].tolist())
    print('test true', nmi, sc)

    test_z = ind_tsne.transform(test_x)
    nmi, sc = visualization_metric.get_nmi_sc(test_z, test_y[:3000].tolist())
    print('test', nmi, sc)

