import os
from src.utils.numpy_dataset import FromNumpyDataset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from src.models.rfphate import RFPHATE
from src.models.base_model import BaseModel
from src.models.torch_models import ProxAETorchModule, EarlyStopping
import numpy as np
from src.utils.set_seeds import seed_everything
import kmedoids
import logging
import wandb


class RFAE(BaseModel):

    """
    Autoencoder that takes the row-normalized RF proximities as input and attempts to reconstruct them.
    A prototype selection step can restrict the input to a subset of the data points that act as
    representatives (prototypes).
    RF-PHATE geometric regularization is applied to the bottleneck layer, weighted by the hyperparameter lam,
    which ranges from 0 (geometric regularization only, no reconstruction) to 1 (reconstruction only).
    """

    def __init__(
        self,
        n_components,
        lr,
        batch_size,
        weight_decay,
        random_state,
        device,
        epochs,
        n_pca,
        hidden_dims,
        embedder_params,
        lam,
        loss_scaling,
        pct_prototypes,
        dropout_prob,
        early_stopping,
        patience,
        delta_factor,
        save_model,
    ):
        """Store the hyperparameters and build the RF-PHATE embedder.

        Args:
            n_components: Bottleneck size of the autoencoder; also passed to
                RFPHATE as ``n_components``.
            lr: Learning rate handed to the AdamW optimizer in ``fit``.
            batch_size: Batch size of every DataLoader built by this class.
            weight_decay: Weight decay handed to the AdamW optimizer.
            random_state: Seed passed to RFPHATE, K-Medoids and
                ``seed_everything``; also used in log keys and checkpoint paths.
            device: Device the network and the batches are moved to.
            epochs: Number of training epochs.
            n_pca: Stored on the instance; not read by the code in this class.
            hidden_dims: Hidden layer sizes forwarded to ProxAETorchModule.
            embedder_params: Extra keyword arguments for RFPHATE, or ``None``
                to use its defaults.
            lam: Weight of the reconstruction loss; the geometric loss is
                weighted by ``1 - lam``.
            loss_scaling: Whether ``compute_loss`` rescales the two loss terms.
            pct_prototypes: Fraction of samples kept as prototypes; values
                outside the open interval (0, 1) keep every sample.
            dropout_prob: Dropout probability forwarded to ProxAETorchModule.
            early_stopping: Whether ``train_loop`` checks for early stopping.
            patience: Forwarded to EarlyStopping.
            delta_factor: Forwarded to EarlyStopping.
            save_model: Forwarded to EarlyStopping.
        """
        # Network architecture
        self.n_components = n_components
        self.hidden_dims = hidden_dims
        self.dropout_prob = dropout_prob
        self.n_pca = n_pca

        # Optimisation settings
        self.lr = lr
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.epochs = epochs
        self.device = device
        self.random_state = random_state

        # Loss composition and prototype selection
        self.lam = lam
        self.loss_scaling = loss_scaling
        self.pct_prototypes = pct_prototypes

        # Early stopping / checkpointing
        self.early_stopping = early_stopping
        self.patience = patience
        self.delta_factor = delta_factor
        self.save_model = save_model

        # ``None`` means "no extra embedder options"; anything else is unpacked as keywords.
        extra_embedder_kwargs = {} if embedder_params is None else embedder_params
        self.embedder = RFPHATE(
            random_state=random_state,
            n_components=self.n_components,
            **extra_embedder_kwargs,
        )

    def init_torch_module(self, input_shape):
        """Create the autoencoder network and store it on ``self.torch_module``.

        ``input_shape`` becomes the network's ``input_dim``; the remaining
        settings come from the hyperparameters stored on the instance, and the
        output activation is fixed to ``"log_softmax"``.
        """
        module_kwargs = {
            "input_dim": input_shape,
            "hidden_dims": self.hidden_dims,
            "z_dim": self.n_components,
            "dropout_prob": self.dropout_prob,
            "output_activation": "log_softmax",
        }
        self.torch_module = ProxAETorchModule(**module_kwargs)
    def fit(self, x, y):
        """Fit the RF-PHATE embedder, pick prototypes and train the autoencoder.

        Steps:
            1. Fit ``self.embedder`` on ``(x, y)``; its output is the geometric
               target ``self.z_target`` and its proximity matrix is densified
               into ``self.training_proximities``.
            2. If ``0 < pct_prototypes < 1``, keep only the proximity columns of
               class-wise K-Medoids prototypes; otherwise keep every column.
            3. Build the network, the AdamW optimizer and the two criteria,
               then train on the L1 row-normalized proximities.

        Args:
            x: Training data; ``x.shape[0]`` is taken as the number of samples.
            y: One label per training sample. Any array-like is accepted for
               the class-wise prototype selection.
        """
        self.input_shape = x.shape[0]
        self.labels = y

        self.z_target = self.embedder.fit_transform(x, y)
        # The embedder stores its proximities as a sparse matrix; densify them.
        self.training_proximities = self.embedder.proximity.toarray()

        if self.random_state is not None:
            seed_everything(self.random_state)

        if 0 < self.pct_prototypes < 1:
            self.prototype_indices = self._select_prototypes(self.training_proximities, y)
            # Reduce proximity matrix to the selected prototype columns
            self.training_proximities = self.training_proximities[:, self.prototype_indices]
        else:
            self.prototype_indices = np.arange(self.input_shape)
        self.init_torch_module(len(self.prototype_indices))

        self.optimizer = torch.optim.AdamW(self.torch_module.parameters(),
                                            lr=self.lr,
                                            weight_decay=self.weight_decay)

        self.criterion_recon = nn.KLDivLoss(reduction="batchmean")
        self.criterion_geo = nn.MSELoss()

        # Row-normalized Tensor proximities
        training_proximities = torch.tensor(self.training_proximities, dtype=torch.float)
        training_proximities = F.normalize(training_proximities, p=1)

        # Training dataset: (proximity row, geometric target) pairs
        tensor_dataset = TensorDataset(training_proximities, torch.tensor(self.z_target, dtype=torch.float))

        train_loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=True)

        self.train_loop(self.torch_module, self.epochs, train_loader, self.optimizer, self.device)

    def _select_prototypes(self, proximities, y):
        """Return indices of class-wise K-Medoids prototypes for a dense proximity matrix."""
        # Coerce labels so that ``labels == cls`` is an element-wise mask even
        # when ``y`` is a plain Python list (a list compares as one scalar).
        labels = np.asarray(y)

        # Proximities scaled by their maximum, turned into dissimilarities
        dist_matrix = 1 - proximities / np.max(proximities)

        # Total number of prototypes, split evenly across classes (at least one each)
        k = int(self.pct_prototypes * self.input_shape)
        classes = np.unique(labels)
        k_per_class = max(1, k // len(classes))

        prototype_indices = []
        for cls in classes:
            cls_indices = np.where(labels == cls)[0]
            if len(cls_indices) <= k_per_class:
                # Use all points if not enough for K-Medoids
                prototype_indices.extend(cls_indices)
                continue

            sub_dists = dist_matrix[np.ix_(cls_indices, cls_indices)]
            km = kmedoids.KMedoids(k_per_class, method='fasterpam', random_state=self.random_state)
            km.fit(sub_dists)
            # Medoid indices are relative to the class subset; map back to dataset rows
            prototype_indices.extend(cls_indices[km.medoid_indices_])

        return np.array(prototype_indices, dtype=int)

    def fit_transform(self, x, y):
        """Fit on ``(x, y)`` and return the latent embedding of the training samples.

        The embedding is computed from the training proximities stored by
        ``fit`` (passed to ``transform`` as precomputed) and is also kept on
        ``self.z_latent``.
        """
        self.fit(x, y)
        latent = self.transform(self.training_proximities, precomputed=True)
        self.z_latent = latent
        return latent


    def compute_loss(self, x, x_hat, z_target, z):

        loss_recon = self.criterion_recon(x_hat, x)
        loss_emb = self.criterion_geo(z_target, z)

        self.recon_loss_temp = loss_recon.item()
        self.emb_loss_temp = loss_emb.item()

        # Dynamic scaling factors for balancing magnitudes
        # Compute scaling factors only once and store them
        if self.loss_scaling and (not hasattr(self, "scale_recon") or not hasattr(self, "scale_emb")):
            self.scale_recon = 1 / (loss_recon.detach().mean() + 1e-8)
            self.scale_emb = 1 / (loss_emb.detach().mean() + 1e-8)
        else:
            self.scale_recon = 1
            self.scale_emb = 1
        balanced_loss = self.lam * loss_recon * self.scale_recon + (1 - self.lam) * loss_emb * self.scale_emb
        self.balanced_loss = balanced_loss.item()
        return balanced_loss

    def train_loop(self, model, epochs, train_loader, optimizer, device = 'cpu'):
        """Train ``model`` and record the per-epoch mean losses.

        Args:
            model: Network returning ``(reconstruction, embedding)`` for a batch.
            epochs: Maximum number of epochs to run.
            train_loader: Iterable of ``(proximity rows, target embedding)`` batches.
            optimizer: Optimizer updating the parameters of ``model``.
            device: Device the model and the batches are moved to.

        Side effects:
            Fills ``epoch_losses_recon``, ``epoch_losses_emb`` and
            ``epoch_losses_balanced`` with one mean value per epoch, logs them
            to wandb once per epoch, and, when early stopping is enabled,
            creates the ``{random_state}/`` directory used as checkpoint folder.
        """
        self.epoch_losses_recon = []
        self.epoch_losses_emb = []
        self.epoch_losses_balanced = []
        best_loss = float("inf")
        counter = 0

        # Drop scaling factors left over from an earlier run so that
        # compute_loss derives them from this run's losses.
        for stale_attr in ("scale_recon", "scale_emb"):
            if hasattr(self, stale_attr):
                delattr(self, stale_attr)

        # Build the early-stopping helper once instead of once per epoch.
        # NOTE(review): assumes EarlyStopping keeps no state between calls
        # beyond best_loss/counter, which are threaded through explicitly below.
        early_stopper = None
        if self.early_stopping and epochs > 0:
            os.makedirs(f"{self.random_state}/", exist_ok=True)
            early_stopper = EarlyStopping(patience = self.patience,
                                          delta_factor = self.delta_factor,
                                          save_model = self.save_model,
                                          save_path = f"{self.random_state}/best_{self.random_state}.pth")

        model.to(device)
        model.train()

        for epoch in range(epochs):
            running_recon_loss = 0
            running_emb_loss = 0
            running_balanced_loss = 0

            for x, z_target in train_loader:
                x = x.to(device)
                z_target = z_target.to(device)

                optimizer.zero_grad()
                recon, z = model(x)

                # Keywords keep the embedding and its target from being swapped.
                loss = self.compute_loss(x, recon, z_target=z_target, z=z)
                loss.backward()
                optimizer.step()

                running_recon_loss += self.recon_loss_temp
                running_emb_loss += self.emb_loss_temp
                running_balanced_loss += self.balanced_loss

            # Track losses per epoch
            self.epoch_losses_recon.append(running_recon_loss / len(train_loader))
            self.epoch_losses_emb.append(running_emb_loss / len(train_loader))
            self.epoch_losses_balanced.append(running_balanced_loss / len(train_loader))

            # One call per epoch so all three metrics share a single wandb step.
            wandb.log({
                f"{self.random_state}: train_recon_loss": self.epoch_losses_recon[-1],
                f"{self.random_state}: train_emb_loss": self.epoch_losses_emb[-1],
                f"{self.random_state}: train_balanced_loss": self.epoch_losses_balanced[-1],
                "Epoch": epoch,
            })

            # Periodic logging of losses
            if epoch % 50 == 0:
                logging.info(f"Epoch {epoch}/{epochs}, Recon Loss: {self.epoch_losses_recon[-1]:.7f}, Geo Loss: {self.epoch_losses_emb[-1]:.7f}")

            # Check for early stopping
            if early_stopper is not None:
                should_stop, best_loss, counter = early_stopper(self.epoch_losses_balanced[-1], best_loss, counter, model)
                if should_stop:
                    logging.info(f"Stopping training early at epoch {epoch}")
                    return

    def evaluate_recon(self, x, precomputed=False):
        self.torch_module.eval()
        total_kl_div  = 0
        total_samples = 0

        if not precomputed:
            x = self.embedder.prox_extend(x, self.prototype_indices).toarray()
        
        x = torch.tensor(x, dtype=torch.float)
        x = F.normalize(x, p=1)
        
        loader = DataLoader(TensorDataset(x), batch_size=self.batch_size, shuffle=False)
        
        with torch.no_grad():
            for x_batch in loader:
                batch_size = x_batch.size(0)
                x_batch.to(self.device)

                recon, _ = self.torch_module(x)

                total_kl_div  += self.criterion_recon(recon, x).item() * batch_size
                total_samples += batch_size
        
        return total_kl_div / total_samples


    def transform(self, x, precomputed=False):
        self.torch_module.eval()

        if not precomputed:
            x = self.embedder.prox_extend(x, self.prototype_indices).toarray()

        x = torch.tensor(x, dtype=torch.float)
        x = F.normalize(x, p=1)
        
        loader = DataLoader(TensorDataset(x), batch_size=self.batch_size, shuffle=False)

        z = []
        with torch.no_grad():
            for batch in loader:
                z_batch = self.torch_module.encoder(batch[0].to(self.device)).cpu().numpy()
                z.append(z_batch)
        
        return np.concatenate(z)


    def inverse_transform(self, x):
        """Decode latent points ``x`` and return the result as a NumPy array.

        Each batch goes through ``torch_module.decoder`` and then
        ``torch_module.final_activation``; the batch outputs are concatenated
        in input order.

        NOTE(review): the network is built with ``output_activation="log_softmax"``,
        so the returned values are presumably log-probabilities rather than
        proximities; confirm against ProxAETorchModule.
        """
        self.torch_module.eval()
        loader = DataLoader(FromNumpyDataset(x), batch_size=self.batch_size, shuffle=False)

        x_hat = []
        # Inference only: skip autograd bookkeeping, as transform does.
        with torch.no_grad():
            for batch in loader:
                decoded = self.torch_module.decoder(batch.to(self.device))
                x_hat.append(self.torch_module.final_activation(decoded).cpu().numpy())

        return np.concatenate(x_hat)
        