import sys
from src.utils.numpy_dataset import FromNumpyDataset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from src.models.rfphate import RFPHATE
from src.models.base_model import BaseModel
from src.models.torch_models import AETorchModule
import numpy as np

# ---------------------------------------------------------------------------
# Default hyper-parameters for RF_GAE_REG.
# They are bound as the constructor's default argument values when the class
# is defined; rebinding these names later does not change those defaults.
# ---------------------------------------------------------------------------

# Architecture
n_components: int = 2                # bottleneck size (z_dim of AETorchModule)
hidden_dims: list = [800, 400, 100]  # hidden layer sizes passed to AETorchModule

# Optimisation (AdamW + DataLoader)
lr: float = 0.001
weight_decay: int = 0
batch_size: int = 128
epochs: int = 100

# Loss configuration
lam: int = 0            # weight of the embedding loss relative to reconstruction
diffusion: bool = True  # True: PHATE diff_potential weights; False: raw RF proximities

# Execution
device: str = 'cpu'

class RF_GAE_REG(BaseModel):
    """
    Supervised Generalized Regularized Autoencoder based on RF proximities.

    Training combines two terms:

    * a proximity-weighted reconstruction loss (see ``loss_recon``). The
      weights come from ``embedder.phate_op.diff_potential`` when
      ``diffusion`` is truthy, otherwise from ``embedder.proximity``;
    * an embedding loss (``criterion``, ``nn.MSELoss`` by default) between the
      bottleneck ``z`` and the ``RFPHATE`` embedding of the training data,
      scaled by ``lam``.
    """

    def __init__(self,
                 n_components = n_components,
                 lr = lr,
                 batch_size = batch_size,
                 weight_decay = weight_decay,
                 random_state = None,
                 device = device,
                 optimizer = None,
                 torch_module = None,
                 epochs = epochs,
                 scheduler = None,
                 criterion = None,
                 data_shape = None,
                 hidden_dims = hidden_dims,
                 embedder_params = None,
                 lam = lam,
                 diffusion = diffusion
                 ):
        """
        Store hyper-parameters and build the RFPHATE embedder.

        Parameters are stored unchanged as attributes of the same name.
        ``embedder_params`` (optional dict) is forwarded as extra keyword
        arguments to ``RFPHATE``. ``torch_module``, ``optimizer`` and
        ``criterion`` are created lazily in ``fit`` when left as ``None``.
        ``scheduler`` is stored but not used by this class.
        """
        self.n_components = n_components
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.random_state = random_state
        self.device = device
        self.torch_module = torch_module
        self.optimizer = optimizer
        self.epochs = epochs
        self.scheduler = scheduler
        self.criterion = criterion
        self.data_shape = data_shape
        self.hidden_dims = hidden_dims
        self.embedder_params = embedder_params
        self.lam = lam
        self.diffusion = diffusion

        extra_params = {} if embedder_params is None else embedder_params
        self.embedder = RFPHATE(random_state = random_state,
                                n_components = self.n_components,
                                **extra_params)

        # NOTE(review): BaseModel.__init__ is not called; confirm BaseModel
        # needs no initialisation of its own.

    def init_torch_module(self, data_shape):
        """Create the autoencoder for inputs with ``data_shape`` features."""
        self.torch_module = AETorchModule(input_dim   = data_shape,
                                          hidden_dims = self.hidden_dims,
                                          z_dim       = self.n_components)

    def fit(self, x, y):
        """
        Fit the RFPHATE embedder on (x, y), then train the autoencoder.

        Parameters
        ----------
        x : 2-D array of shape (n_samples, n_features)
        y : labels passed to the RFPHATE embedder

        Returns
        -------
        self
        """
        self.data_shape = x.shape[1]

        # Target coordinates for the bottleneck (embedding loss).
        self.z_target = self.embedder.fit_transform(x, y)

        # Pairwise weights for the reconstruction loss: a full
        # (n_samples x n_samples) matrix, later sliced per batch.
        if self.diffusion:
            self.embedder.fit_phate_op(x, y, gamma=-1)
            self.training_proximities = self.embedder.phate_op.diff_potential
        else:
            self.training_proximities = self.embedder.proximity.toarray()

        if self.random_state is not None:
            torch.manual_seed(self.random_state)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        if self.torch_module is None:
            self.init_torch_module(self.data_shape)

        # Move the module before constructing the optimizer, as recommended
        # by the torch.optim documentation.
        self.torch_module.to(self.device)

        if self.optimizer is None:
            self.optimizer = torch.optim.AdamW(self.torch_module.parameters(),
                                               lr=self.lr,
                                               weight_decay=self.weight_decay)

        if self.criterion is None:
            self.criterion = nn.MSELoss()

        self.loader = self.get_loader(x, self.z_target, self.training_proximities)

        self.train_loop(self.torch_module, self.epochs, self.loader, self.optimizer, self.device)
        return self

    def fit_transform(self, x, y):
        """Fit the model on (x, y) and return the embedding of x."""
        self.fit(x, y)
        return self.transform(x)

    def get_loader(self, x, z_target, training_proximities):
        """
        Build a shuffled DataLoader yielding (indices, x, z_target, prox_rows).

        The sample indices are included so the training loop can slice the
        batch-by-batch block out of each sample's full proximity row.
        """
        indices = torch.arange(len(x))
        dataset = TensorDataset(indices,
                                torch.tensor(x, dtype=torch.float),
                                torch.tensor(z_target, dtype=torch.float),
                                torch.tensor(training_proximities, dtype=torch.float))
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def loss_recon(self, x, x_hat, prox_weights):
        r"""
        Proximity-weighted reconstruction loss.

        Inputs:
        x: training batch points (batch_size x n_features)
        x_hat: output of the batch (batch_size x n_features)
        prox_weights: weights between each pair of instances in the batch
            (batch_size x batch_size); row i pairs with x_hat_i, column j with x_j

        Computes:
        Loss = \frac{1}{B^2} \sum_{i \in batch} \sum_{j \in batch} w_{ij} \lVert \hat{x}_i - x_j \rVert^2
        where B is the batch size and the squared norm sums over features.
        """
        # (B, 1, F) - (1, B, F) -> (B, B, F): entry [i, j] is x_hat_i - x_j.
        differences = (x_hat.unsqueeze(1) - x.unsqueeze(0)) ** 2
        diff_sum = differences.sum(dim=2)  # (B, B)

        weighted_diff_sum = (prox_weights * diff_sum).sum()

        batch_size = x.size(0)
        return weighted_diff_sum / (batch_size * batch_size)

    def compute_loss(self, x, x_hat, prox_weights, z_target, z):
        """
        Compute the total loss and backpropagate it.

        total = loss_recon + lam * criterion(z, z_target)

        The scalar values of both terms are stored in ``recon_loss_temp`` and
        ``emb_loss_temp`` for logging. Gradients are accumulated via
        ``backward()``; the optimizer step is left to the caller.
        """
        loss_recon = self.loss_recon(x, x_hat, prox_weights)
        loss_emb = self.criterion(z, z_target)  # (input, target) order

        loss = loss_recon + self.lam * loss_emb

        self.recon_loss_temp = loss_recon.item()
        self.emb_loss_temp = loss_emb.item()

        loss.backward()

    def train_loop(self, model, epochs, train_loader, optimizer, device = 'cpu'):
        """
        Run ``epochs`` passes over ``train_loader``.

        Per-epoch mean losses are appended to ``epoch_losses_recon`` and
        ``epoch_losses_emb``. Progress is printed every 10 epochs.
        """
        self.epoch_losses_recon = []
        self.epoch_losses_emb = []

        # The device does not change between epochs, so move once.
        model = model.to(device)
        n_batches = len(train_loader)

        for epoch in range(epochs):
            model.train()

            running_recon_loss = 0
            running_emb_loss = 0

            for indices, x, z_target, prox_weights in train_loader:
                x = x.to(device)
                z_target = z_target.to(device)
                # Each row holds one sample's weights to every training point;
                # keep only the columns of samples in this batch -> (B, B).
                prox_weights = prox_weights[:, indices].to(device)

                optimizer.zero_grad()

                x_hat, z = model(x)

                self.compute_loss(x, x_hat, prox_weights, z_target, z)

                running_recon_loss += self.recon_loss_temp
                running_emb_loss += self.emb_loss_temp

                optimizer.step()

            self.epoch_losses_recon.append(running_recon_loss / n_batches)
            self.epoch_losses_emb.append(running_emb_loss / n_batches)
            if epoch % 10 == 0:
                print(f"Epoch {epoch}/{epochs}, Recon Loss: {self.epoch_losses_recon[-1]:.7f}, Geo Loss: {self.epoch_losses_emb[-1]}")

    def transform(self, x):
        """Encode x (n_samples x n_features) and return the embedding as a numpy array."""
        self.torch_module.eval()

        dataset = TensorDataset(torch.tensor(x, dtype = torch.float, device = self.device))
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                             shuffle=False)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            z = [self.torch_module.encoder(batch.to(self.device)).cpu().numpy()
                 for (batch,) in loader]
        return np.concatenate(z)

    def inverse_transform(self, x):
        """Decode embeddings x back to feature space; returns a numpy array."""
        self.torch_module.eval()
        dataset = FromNumpyDataset(x)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                             shuffle=False)
        with torch.no_grad():
            x_hat = [self.torch_module.decoder(batch.to(self.device)).cpu().numpy()
                     for batch in loader]

        return np.concatenate(x_hat)



import numpy as np
import torch
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.preprocessing import OneHotEncoder
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import pandas as pd
import sys
import os

DATA_PATH = '/NOBACKUP/aumona/data/rf-autoencoders/treeData.csv'


def main(data_path=DATA_PATH, output_path='scatterplot.png'):
    """
    Train RF_GAE_REG on a CSV dataset and save a 2-D scatter plot of the embedding.

    Parameters
    ----------
    data_path : str
        CSV file to load; it is passed to ``dataprep`` with ``label_col_idx=0``.
    output_path : str
        File the scatter plot is written to.
    """
    # Make the project root importable so ``src.*`` resolves when this file is
    # executed directly as a script.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
    if project_root not in sys.path:
        sys.path.append(project_root)

    from src.utils.dataprep import dataprep

    data = pd.read_csv(data_path)
    X, y = dataprep(data, label_col_idx=0, transform='normalize', cat_to_numeric=True)
    X, y = np.array(X), np.array(y)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    rfgae = RF_GAE_REG(
        n_components=2,  # bottleneck dimension; the plot below needs exactly 2
        lr=1e-3,
        batch_size=32,
        weight_decay=1e-5,
        embedder_params=None,
        random_state=42,
        device=device,
        epochs=100,
        hidden_dims=[400, 200, 100],
    )

    rfgae.fit(X, y)

    # Encode, then run the decoder once as a smoke test of inverse_transform.
    X_encoded = rfgae.transform(X)
    rfgae.inverse_transform(X_encoded)

    fig, ax = plt.subplots(figsize=(8, 6))

    classes = np.unique(y)
    scatter = ax.scatter(X_encoded[:, 0], X_encoded[:, 1], c=y, cmap='tab10',
                         edgecolor='k', alpha=0.7)
    # One legend entry per class, coloured exactly as the scatter maps it.
    handles = [plt.Line2D([0], [0], marker='o', color='w',
                          markerfacecolor=scatter.cmap(scatter.norm(cls)), markersize=10)
               for cls in classes]
    ax.legend(handles, [str(cls) for cls in classes], title="Class Label",
              loc="best", frameon=True)

    ax.set_xlabel('Embedding Dimension 1')
    ax.set_ylabel('Embedding Dimension 2')
    ax.set_title('Training Embeddings Colored by Labels')

    fig.savefig(output_path)
    plt.close(fig)


if __name__ == "__main__":
    main()