import os
from src.utils.numpy_dataset import FromNumpyDataset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, random_split
from src.models.base_model import BaseModel
from src.models.torch_models import AETorchModule, EarlyStopping
import numpy as np
import logging
import wandb
from src.models.rfphate import RFPHATE
import graphtools
from src.utils.set_seeds import seed_everything
import sys



class RF_GAE(BaseModel):
    """
    Generalized Random Forest Autoencoder.

    Trains an ``AETorchModule`` autoencoder whose reconstruction loss is
    weighted by the diffusion potential of an ``RFPHATE`` operator fitted on
    the same data (see ``fit`` and ``compute_loss``).

    Parameters
    ----------
    n_components
        Bottleneck size; passed as ``z_dim`` to ``AETorchModule`` and as
        ``n_components`` to ``RFPHATE``.
    lr
        Learning rate passed to ``torch.optim.AdamW``.
    batch_size
        Batch size of every ``DataLoader`` built by this class.
    weight_decay
        Weight decay passed to ``torch.optim.AdamW``.
    random_state
        Seed given to ``RFPHATE`` and, when not ``None``, to
        ``seed_everything`` before the torch module is created. Also used to
        name the checkpoint folder when early stopping is enabled.
    device
        Device the torch module and every batch are moved to.
    dropout_prob
        Forwarded to ``AETorchModule``.
    epochs
        Number of training epochs used by ``fit``.
    hidden_dims
        Forwarded to ``AETorchModule``.
    embedder_params
        Extra keyword arguments for ``RFPHATE``, or ``None`` for none.
    early_stopping
        When truthy, ``EarlyStopping`` is consulted after every epoch.
    patience, delta_factor, save_model
        Forwarded to ``EarlyStopping``.
    """

    def __init__(self,
                 n_components,
                 lr,
                 batch_size,
                 weight_decay,
                 random_state,
                 device,
                 dropout_prob,
                 epochs,
                 hidden_dims,
                 embedder_params,
                 early_stopping,
                 patience,
                 delta_factor,
                 save_model
                 ):

        self.n_components = n_components
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.random_state = random_state
        self.device = device
        self.dropout_prob = dropout_prob
        self.epochs = epochs
        self.input_shape = None  # set in fit() from x.shape[1]
        self.hidden_dims = hidden_dims
        self.early_stopping = early_stopping
        self.patience = patience
        self.delta_factor = delta_factor
        self.save_model = save_model

        # A missing parameter dict simply means "no extra RFPHATE arguments".
        extra_params = {} if embedder_params is None else embedder_params
        self.embedder = RFPHATE(random_state = random_state,
                                n_components = self.n_components,
                                **extra_params)

    def init_torch_module(self, input_shape):
        """Create ``self.torch_module`` for inputs with ``input_shape`` features."""
        self.torch_module = AETorchModule(input_dim   = input_shape,
                                          hidden_dims = self.hidden_dims,
                                          z_dim       = self.n_components,
                                          dropout_prob=self.dropout_prob)

    def fit(self, x, y):
        """
        Fit the RFPHATE operator on ``(x, y)`` and train the autoencoder.

        ``x`` must expose ``shape[0]`` (samples) and ``shape[1]`` (features)
        and be convertible with ``torch.tensor``; ``y`` is only passed on to
        ``self.embedder.fit_phate_op``.
        """
        self.input_shape = x.shape[1]
        self.n_samples = x.shape[0]

        self.embedder.fit_phate_op(x, y, gamma=-1, t=1)
        self.diffused_prox = self.embedder.phate_op.diff_potential  # Shape (n_samples, min(n_samples, n_landmarks))
        if isinstance(self.embedder.phate_op.graph, graphtools.graphs.LandmarkGraph):
            self.cluster_labels = self.embedder.phate_op.graph.clusters  # Landmark associated with each data point
        else:
            # No landmarks: every sample indexes its own column.
            self.cluster_labels = np.arange(self.n_samples)

        tensor_dataset = TensorDataset(torch.tensor(x, dtype = torch.float),
                                       torch.tensor(self.diffused_prox, dtype = torch.float),
                                       torch.tensor(self.cluster_labels, dtype = torch.long))

        # Seed before the module is built so weight initialisation and batch
        # shuffling are reproducible.
        if self.random_state is not None:
            seed_everything(self.random_state)

        self.init_torch_module(self.input_shape)

        self.optimizer = torch.optim.AdamW(self.torch_module.parameters(),
                                            lr=self.lr,
                                            weight_decay=self.weight_decay)

        train_loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=True)

        self.train_loop(train_loader, self.optimizer, self.torch_module, self.epochs, self.device)


    def compute_loss(self, x, x_hat, diff_prox, cluster_labels):
        """
        Weighted reconstruction loss for one batch.

        For every pair ``(i, j)`` in the batch the squared Euclidean distance
        between ``x_hat[i]`` and ``x[j]`` is multiplied by
        ``diff_prox[i, cluster_labels[j]]``; the products are summed over
        ``j`` and averaged over ``i``.

        The scalar value is also stored in ``self.recon_loss_temp`` so the
        training loop can accumulate it.

        Returns
        -------
        torch.Tensor
            Scalar loss tensor (still attached to the autograd graph).
        """

        # Pairwise squared differences between each reconstructed x_hat[i] and original x[j]
        diff = x_hat.unsqueeze(1) - x.unsqueeze(0)  # Shape (batch_size, batch_size, input_dim)
        squared_diff = diff.pow(2).sum(dim=-1)  # Shape (batch_size, batch_size)

        # Column j of the weights is the diff_prox column selected by cluster_labels[j]
        weights = diff_prox[:, cluster_labels]  # Shape (batch_size, batch_size)

        # Apply weights and sum over all j for each i
        weighted_loss = (squared_diff * weights).sum(dim=1).mean()

        self.recon_loss_temp = weighted_loss.item()
        return weighted_loss


    def train_loop(self, train_loader, optimizer, model, epochs, device = 'cpu'):
        """
        Train ``model`` for up to ``epochs`` epochs over ``train_loader``.

        The mean batch loss of every epoch is appended to
        ``self.epoch_losses_recon``. When ``self.early_stopping`` is truthy,
        ``EarlyStopping`` is consulted after each epoch with that value and
        training returns as soon as it signals a stop.
        """
        self.epoch_losses_recon = []
        best_loss = float("inf")
        counter = 0

        # The stopper and its checkpoint folder do not depend on the epoch, so
        # they are prepared once. NOTE(review): this assumes EarlyStopping
        # keeps no state between calls (best_loss/counter are passed in and
        # returned explicitly) - confirm against its implementation.
        early_stopper = None
        if self.early_stopping:
            os.makedirs(f"{self.random_state}/", exist_ok=True)
            subfolder_path = f"{self.random_state}/best_{self.random_state}.pth"
            early_stopper = EarlyStopping(patience = self.patience,
                                          delta_factor = self.delta_factor,
                                          save_model = self.save_model,
                                          save_path = subfolder_path)

        model.to(device)
        model.train()

        for epoch in range(epochs):
            running_recon_loss = 0

            for x, diff_prox, cluster_labels in train_loader:
                x = x.to(device)
                diff_prox = diff_prox.to(device)
                cluster_labels = cluster_labels.to(device)

                optimizer.zero_grad()
                recon, _ = model(x)
                self.compute_loss(x, recon, diff_prox, cluster_labels).backward()

                # compute_loss stored the scalar value of this batch's loss.
                running_recon_loss += self.recon_loss_temp

                optimizer.step()

            # Track losses per epoch
            self.epoch_losses_recon.append(running_recon_loss / len(train_loader))

            # wandb.log({f"{self.random_state}: train_recon_loss": self.epoch_losses_recon[-1], "Epoch": epoch})
            # Periodic logging of losses
            if epoch % 50 == 0:
                message = f"Epoch {epoch}/{epochs}, Recon Loss: {self.epoch_losses_recon[-1]:.7f}"
                print(message)
                logging.info(message)

            # Check for early stopping on the loss this loop actually tracks.
            if early_stopper is not None:
                should_stop, best_loss, counter = early_stopper(self.epoch_losses_recon[-1], best_loss, counter, model)
                if should_stop:
                    logging.info(f"Stopping training early at epoch {epoch}")
                    return

    def transform(self, x):
        """
        Encode ``x`` with the trained encoder.

        ``x`` is converted to a float tensor and processed in batches of
        ``self.batch_size`` without shuffling; the per-batch outputs are
        concatenated into one NumPy array.
        """
        self.torch_module.eval()

        dataset = TensorDataset(torch.tensor(x, dtype = torch.float))
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            z = [self.torch_module.encoder(batch[0].to(self.device)).cpu().numpy()
                 for batch in loader]
        return np.concatenate(z)


    def inverse_transform(self, x):
        """
        Decode ``x`` with the trained decoder.

        ``x`` is wrapped in ``FromNumpyDataset`` and processed in batches of
        ``self.batch_size`` without shuffling; the per-batch outputs are
        concatenated into one NumPy array.
        """
        self.torch_module.eval()
        dataset = FromNumpyDataset(x)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            x_hat = [self.torch_module.decoder(batch.to(self.device)).cpu().numpy()
                     for batch in loader]

        return np.concatenate(x_hat)






import numpy as np
import torch
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.preprocessing import OneHotEncoder
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import pandas as pd
import sys
import os

# Resolve the repository root (two directories above this file) and append it
# to sys.path so that the `src` package can be imported below.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../../")
)
sys.path.append(project_root)

# This import relies on the sys.path entry added above.
from src.utils.dataprep import dataprep

# Load the raw table and let the project helper split it into features/labels.
data = pd.read_csv('/NOBACKUP/aumona/data/rf-autoencoders/treeData.csv')
X, y = dataprep(data, label_col_idx=0, transform='normalize', cat_to_numeric=True)
X = np.array(X)
y = np.array(y)

# Prefer the first CUDA device when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Small configuration for a quick run with a 2-D bottleneck.
rfgae = RF_GAE(
    n_components=2,
    lr=1e-3,
    batch_size=32,
    weight_decay=1e-5,
    random_state=42,
    device=device,
    dropout_prob=0,
    epochs=100,
    hidden_dims=[400, 200, 100],
    embedder_params=None,
    early_stopping=False,
    patience=5,
    delta_factor=0.01,
    save_model=False,
)

rfgae.fit(X, y)

# Encode the training data, then decode the embedding back to input space.
X_encoded = rfgae.transform(X)
X_reconstructed = rfgae.inverse_transform(X_encoded)

# Scatter plot of the two embedding dimensions, coloured by label.
plt.figure(figsize=(8, 6))
scatter = plt.scatter(
    X_encoded[:, 0],
    X_encoded[:, 1],
    c=y,
    cmap='tab10',
    edgecolor='k',
    alpha=0.7,
)

# One legend entry per distinct label, using the colour the scatter assigned.
classes = np.unique(y)
handles = [
    plt.Line2D(
        [0], [0],
        marker='o',
        color='w',
        markerfacecolor=scatter.cmap(scatter.norm(cls)),
        markersize=10,
    )
    for cls in classes
]
plt.legend(handles, classes, title="Class Label", loc="best", frameon=True)

# Axis labels and title.
plt.xlabel('Embedding Dimension 1')
plt.ylabel('Embedding Dimension 2')
plt.title('Training Embeddings Colored by Labels')

# Write the figure to disk.
plt.savefig('scatter_plot_rfae_leaf.png')