import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import wandb
from src.models.rfphate import RFPHATE
from torch.utils.data import TensorDataset, DataLoader
from src.utils.numpy_dataset import FromNumpyDataset
from src.models.base_model import BaseModel
from src.models.torch_models import AETorchModule
from sklearn.preprocessing import OneHotEncoder
from joblib import Parallel, delayed
import sys




class RFAE_Leaves(BaseModel):
    """
    Categorical Random Forest Autoencoder for one-hot leaf encoded inputs, with RF-PHATE regularization
    """

    def __init__(self,
                 n_components,
                 lr,
                 batch_size,
                 weight_decay,
                 random_state,
                 device,
                 dropout_prob,
                 epochs,
                 hidden_dims,
                 embedder_params,
                 lam,
                 save_model
                 ):
        """
        Store the hyperparameters and build the RF-PHATE embedder.

        Parameters
        ----------
        n_components : latent dimension; passed to the embedder and used as the
            autoencoder's `z_dim`.
        lr, weight_decay : settings handed to the AdamW optimizer in `fit`.
        batch_size : batch size of the data loaders built by this class.
        random_state : seed given to the embedder and, when not None, to torch in `fit`.
        device : device the torch module and the batches are moved to.
        dropout_prob, hidden_dims : forwarded to `AETorchModule`.
        epochs : number of training epochs run by `fit`.
        embedder_params : extra keyword arguments for `RFPHATE`, or None for none.
        lam : weight of the reconstruction loss; the embedding loss gets `1 - lam`.
        save_model : stored as-is; not read anywhere in this class.
        """
        super().__init__()

        # Optimisation settings
        self.n_components = n_components
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.random_state = random_state
        self.device = device

        # Network / training settings
        self.dropout_prob = dropout_prob
        self.epochs = epochs
        self.hidden_dims = hidden_dims
        self.lam = lam
        self.save_model = save_model

        # None means "no extra embedder options"; anything else is unpacked as keywords.
        extra_embedder_kwargs = {} if embedder_params is None else embedder_params
        self.embedder = RFPHATE(
            random_state=random_state,
            n_components=self.n_components,
            **extra_embedder_kwargs
        )

    def init_torch_module(self):
        """
        Create the autoencoder network and store it on `self.torch_module`.

        Reads `self.input_dim`, which `fit` assigns before calling this method.
        """
        module_kwargs = {
            "input_dim": self.input_dim,
            "hidden_dims": self.hidden_dims,
            "z_dim": self.n_components,
            "dropout_prob": self.dropout_prob,
        }
        self.torch_module = AETorchModule(**module_kwargs)


    def _prune_forest(self, forest, ccp_alpha=0.01):
        """
        Applies `_prune_tree` to each tree in a trained RandomForestClassifier.
        """
        for tree in forest.estimators_:
            tree.set_params(ccp_alpha=ccp_alpha)  # Set the pruning parameter
            tree._prune_tree()


    def fit(self, x, y):
        """
        Fit the embedder, prune its trees, and train the autoencoder on one-hot leaf encodings.

        Steps:
          1. Fit the RF-PHATE embedder on (x, y). When `lam < 1` its embedding is
             kept as the geometric target; otherwise the target is all zeros.
          2. Prune the embedder's trees, then one-hot encode the leaf each sample
             falls into (one categorical block per column returned by `apply`).
          3. Build the torch module, optimizer and losses, and run `train_loop`.

        Parameters
        ----------
        x : training inputs; must expose `.shape` and be accepted by the embedder.
        y : training labels passed to the embedder.

        Sets `z_target`, `onehot_encoder`, `cat_sizes`, `input_dim`, `torch_module`,
        `optimizer`, `criterion_recon` and `criterion_geo` on the instance.
        """
        if self.lam < 1:
            self.z_target = self.embedder.fit_transform(x, y)
        else:
            # With lam == 1 the embedding term in compute_loss has zero weight, so
            # only the forest is fitted and a placeholder target is used.
            self.embedder.fit(x, y)
            self.z_target = np.zeros((x.shape[0], self.n_components))

        # Prune first so the leaves are extracted (and densely encoded) only once,
        # on the trees that `transform` will see later.
        self._prune_forest(self.embedder, ccp_alpha=0.02)

        # Leaf index of every sample in every tree, then one-hot encoded per tree.
        leaf_indices = self.embedder.apply(x)
        self.onehot_encoder = OneHotEncoder(sparse_output=False)
        x_leaves_onehot = self.onehot_encoder.fit_transform(leaf_indices)
        logging.info("Shape of one-hot encoded leaf matrix: %s", x_leaves_onehot.shape)

        # Width of each one-hot block, read from the fitted encoder so the splits in
        # compute_loss always line up with the encoded columns.
        self.cat_sizes = [len(categories) for categories in self.onehot_encoder.categories_]

        self.input_dim = sum(self.cat_sizes)  # Total input size (sum of one-hot vectors)

        tensor_dataset = TensorDataset(torch.tensor(x_leaves_onehot, dtype=torch.float),
                                       torch.tensor(self.z_target, dtype=torch.float))

        if self.random_state is not None:
            # Seed torch before the module is created and the loader shuffles.
            torch.manual_seed(self.random_state)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        self.init_torch_module()

        self.optimizer = torch.optim.AdamW(
            self.torch_module.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )

        self.criterion_recon = nn.CrossEntropyLoss(reduction='mean')
        self.criterion_geo = nn.MSELoss(reduction='mean')

        train_loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=True)
        self.train_loop(train_loader, self.optimizer, self.torch_module, self.epochs, self.device)

    def compute_loss(self, x, x_hat, z, z_target):

        # Reconstruction loss
        split_x = torch.split(x, self.cat_sizes, dim=1)
        split_x_hat = torch.split(x_hat, self.cat_sizes, dim=1)
        loss_recon = sum(self.criterion_recon(x_hat_i, torch.argmax(x_i, dim=1)) for x_i, x_hat_i in zip(split_x, split_x_hat)) / len(split_x)
        self.recon_loss_temp = loss_recon.item()

        # Geometric loss
        loss_emb = self.criterion_geo(z_target, z)
        self.emb_loss_temp = loss_emb.item()

        balanced_loss = self.lam * loss_recon + (1 - self.lam) * loss_emb
        self.balanced_loss = balanced_loss.item()
        return balanced_loss

    def train_loop(self, train_loader, optimizer, model, epochs, device='cpu'):
        self.epoch_losses_recon = []
        self.epoch_losses_emb = []
        self.epoch_losses_balanced = []
        best_loss = float("inf")
        counter=0

        model.to(device)
        model.train()

        for epoch in range(epochs):
            running_recon_loss = 0
            running_emb_loss = 0
            running_balanced_loss = 0

            for x, z_target in train_loader:
                x = x.to(device)
                z_target = z_target.to(device)

                recon, z = model(x)

                optimizer.zero_grad()
                self.compute_loss(x, recon, z, z_target).backward()

                running_recon_loss += self.recon_loss_temp
                running_emb_loss += self.emb_loss_temp
                running_balanced_loss += self.balanced_loss

                optimizer.step()

            # Track losses per epoch
            self.epoch_losses_recon.append(running_recon_loss / len(train_loader))
            self.epoch_losses_emb.append(running_emb_loss / len(train_loader))
            self.epoch_losses_balanced.append(running_balanced_loss / len(train_loader))

            wandb.log({f"{self.random_state}: train_recon_loss": self.epoch_losses_recon[-1], "Epoch": epoch})
            wandb.log({f"{self.random_state}: train_emb_loss": self.epoch_losses_emb[-1], "Epoch": epoch})
            wandb.log({f"{self.random_state}: train_balanced_loss": self.epoch_losses_balanced[-1], "Epoch": epoch})
            # Periodic logging of losses
            if epoch % 10 == 0:
                logging.info(f"Epoch {epoch}/{self.epochs}, Recon Loss: {self.epoch_losses_recon[-1]:.7f}, Geo Loss: {self.epoch_losses_emb[-1]:.7f}")

    def transform(self, x):
        self.torch_module.eval()

        leaf_indices = self.embedder.apply(x)  # Shape (n_samples, n_trees), each entry is a leaf index
        x_leaves_onehot = self.onehot_encoder.transform(leaf_indices) # Shape (n_samples, n_leaves)
        print(f"Shape of one-hot encoded leaf matrix (test): {x_leaves_onehot.shape}")

        tensor_dataset = TensorDataset(torch.tensor(x_leaves_onehot, dtype=torch.float))
        loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=False)

        z = [self.torch_module.encoder(batch[0].to(self.device)).cpu().detach().numpy() for batch in loader]
        return np.concatenate(z)

    def inverse_transform(self, x):
        """
        Decode latent points with the trained decoder.

        `x` is wrapped in a `FromNumpyDataset` and passed through
        `self.torch_module.decoder` in batches of `self.batch_size`.

        Returns the concatenated decoder outputs as a numpy array, in input order.
        NOTE(review): these are the raw decoder outputs -- presumably per-tree
        logits over the one-hot leaf blocks, given the cross-entropy loss used in
        training -- not leaf indices or original features; confirm with callers.
        """
        self.torch_module.eval()
        dataset = FromNumpyDataset(x)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Inference only: no autograd graph is needed for decoding.
        with torch.no_grad():
            x_hat = [self.torch_module.decoder(batch.to(self.device)).cpu().numpy() for batch in loader]

        return np.concatenate(x_hat)
    
