import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import wandb
import kmedoids

from torch.utils.data import TensorDataset, DataLoader
from src.models.rfphate import RFPHATE
from src.models.rfgap import RFGAP
from umap import ParametricUMAP, UMAP
from cne import CNE
from sklearn.preprocessing import MinMaxScaler
from src.models.base_model import BaseModel
from src.models.torch_models import ProxAETorchModule
from src.utils.numpy_dataset import FromNumpyDataset
from src.utils.set_seeds import seed_everything


class JSDivLoss(nn.Module):
    """Jensen-Shannon divergence between two batches of probability vectors.

    Both inputs are clamped to ``[eps, 1]`` so that zero entries never reach a
    logarithm. With ``m = (p + q) / 2`` the loss is
    ``0.5 * (KL(p || m) + KL(q || m))``, where each KL term is computed by
    ``F.kl_div`` using the given ``reduction``.

    Args:
        reduction: reduction mode forwarded to ``F.kl_div`` (default ``'batchmean'``).
        eps: lower clamp bound applied to both inputs.
    """

    def __init__(self, reduction='batchmean', eps=1e-8):
        super().__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, p, q):
        """Return the JS divergence between ``p`` and ``q`` (symmetric in its arguments)."""
        p_safe, q_safe = (t.clamp(min=self.eps, max=1.0) for t in (p, q))
        # The mixture's log is shared by both KL terms.
        log_mix = (0.5 * (p_safe + q_safe)).log()
        kl_p = F.kl_div(log_mix, p_safe, reduction=self.reduction)
        kl_q = F.kl_div(log_mix, q_safe, reduction=self.reduction)
        return 0.5 * (kl_p + kl_q)


class RFAE(BaseModel):
    """Autoencoder over RF-GAP proximity rows, regularized toward a target embedding.

    Each training input is a row of random-forest proximities from one training
    sample to a set of prototype samples, L1-normalized per row. The encoder
    output is pulled toward a precomputed target embedding (``z_target``) with an
    MSE loss, and the decoder reconstructs the proximity row. The two losses are
    mixed as ``lam * recon + (1 - lam) * emb``.

    Parameters
    ----------
    n_components : int
        Dimension of the latent embedding.
    lr, weight_decay : float
        AdamW learning rate and weight decay.
    batch_size : int
        Batch size for training and for ``transform`` / ``inverse_transform``.
    random_state : int or None
        Seed passed to ``seed_everything``, the embedders, ``RFGAP`` and k-medoids.
    device : str or torch.device
        Device the torch module is moved to.
    epochs : int
        Number of training epochs.
    hidden_dims, dropout_prob
        Forwarded to ``ProxAETorchModule``.
    emb_constraint : {'rfphate', 'sumap', 'umap', 'rfumap'}
        Which embedding method produces ``z_target``.
    embedder_params : dict or None
        Keyword arguments forwarded to ``RFPHATE`` (for ``'rfphate'``) or to
        ``RFGAP`` (for the other constraints).
    lam : float
        Weight of the reconstruction loss; ``1 - lam`` weights the embedding loss.
    pct_prototypes : float
        If strictly between 0 and 1, the fraction of training samples kept as
        class-wise k-medoids prototypes; otherwise all samples are used.
    save_model
        Stored on the instance; not read by the methods below.
    recon_loss_type : {'kl', 'jsd', 'mse'}
        Reconstruction loss (case-insensitive).

    Raises
    ------
    ValueError
        If ``emb_constraint`` is not supported.
    """

    def __init__(self,
                 n_components,
                 lr,
                 batch_size,
                 weight_decay,
                 random_state,
                 device,
                 epochs,
                 hidden_dims,
                 emb_constraint,
                 embedder_params,
                 lam,
                 pct_prototypes,
                 dropout_prob,
                 save_model,
                 recon_loss_type):

        self.n_components = n_components
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.random_state = random_state
        self.device = device
        self.epochs = epochs
        self.hidden_dims = hidden_dims
        self.emb_constraint = emb_constraint
        self.lam = lam
        self.pct_prototypes = pct_prototypes
        self.dropout_prob = dropout_prob
        self.save_model = save_model
        self.recon_loss_type = recon_loss_type.lower()

        self.embedder_params = embedder_params if embedder_params is not None else {}
        if self.emb_constraint == 'rfphate':
            self.embedder = RFPHATE(random_state=random_state, n_components=n_components, **self.embedder_params)

        # No support for other embeddings' params for now. Use defaults.
        # NOTE(review): 'sumap' maps to ParametricUMAP while 'umap' maps to a
        # parametric CNE with InfoNCE loss -- confirm this naming is intended.
        elif self.emb_constraint == 'sumap':
            self.embedder = ParametricUMAP(n_components=n_components, random_state=random_state, verbose=False)
        elif self.emb_constraint == 'umap':
            self.embedder = CNE(embd_dim=n_components, loss_mode='infonce', s=1, parametric=True, device=device, seed=random_state)
        elif self.emb_constraint == 'rfumap':
            self.embedder = UMAP(n_components=n_components, metric='precomputed', random_state=random_state, verbose=False)
        else:
            raise ValueError(f"Unsupported embedding constraint: {self.emb_constraint}")

    def init_torch_module(self, input_shape):
        """Build the autoencoder and the reconstruction criterion.

        Parameters
        ----------
        input_shape : int
            Length of each proximity row (number of prototypes).

        Raises
        ------
        ValueError
            If ``recon_loss_type`` is not one of ``'kl'``, ``'jsd'``, ``'mse'``.
        """
        # recon_loss_type -> (decoder output activation, criterion factory).
        # KLDivLoss expects log-probabilities as input, hence 'log_softmax' for 'kl'.
        recon_configs = {
            'kl': ('log_softmax', lambda: nn.KLDivLoss(reduction="batchmean")),
            'jsd': ('softmax', lambda: JSDivLoss(reduction='batchmean')),
            'mse': ('softmax', lambda: nn.MSELoss(reduction="mean")),
        }
        # Validate before lookup so an unknown type raises ValueError, not KeyError.
        if self.recon_loss_type not in recon_configs:
            raise ValueError(f"Unknown recon_loss_type: {self.recon_loss_type}")
        output_activation, make_criterion = recon_configs[self.recon_loss_type]

        self.torch_module = ProxAETorchModule(
            input_dim=input_shape,
            hidden_dims=self.hidden_dims,
            z_dim=self.n_components,
            dropout_prob=self.dropout_prob,
            output_activation=output_activation
        )
        self.criterion_recon = make_criterion()

    def _fit_embedding_target(self, x, y):
        """Set ``self.z_target`` and dense asymmetric ``self.training_proximities``."""
        if self.emb_constraint == 'rfphate':
            self.z_target = self.embedder.fit_transform(x, y)
            # Asymmetric proximities for consistency with OOS RF-GAP vectors
            self.training_proximities = self.embedder.proximity_asym.toarray()
            return

        if self.emb_constraint not in ('umap', 'sumap', 'rfumap'):
            raise ValueError(f"Unsupported embedding constraint: {self.emb_constraint}")

        # Compute RF-GAP proximities separately from the embedder.
        self.rfgap = RFGAP(random_state=self.random_state, **self.embedder_params)
        self.rfgap.fit(x, y)
        # Only the asymmetric proximities are used. get_proximities() is still
        # called in case it populates ``proximity_asym`` as a side effect --
        # TODO confirm against RFGAP and drop the call if it does not.
        self.rfgap.get_proximities()
        self.training_proximities = self.rfgap.proximity_asym.toarray()

        if self.emb_constraint in ('umap', 'sumap'):
            self.z_target = self.embedder.fit_transform(x, y)
        else:
            # MinMaxScaler rescales each column to [0, 1]; 1 - K turns the
            # scaled proximities into dissimilarities for precomputed UMAP.
            scaler = MinMaxScaler()
            K = scaler.fit_transform(self.training_proximities)
            D = np.sqrt(1 - K)
            self.z_target = self.embedder.fit_transform(D)

    def _select_prototypes(self, y):
        """Return indices of class-wise k-medoids prototypes (``pct_prototypes`` of samples)."""
        # Max-normalize each row, symmetrize, set diagonal to 1
        row_max = self.training_proximities.max(axis=1, keepdims=True)
        row_max[row_max == 0] = 1.0  # avoid division by zero for all-zero rows
        prox_matrix = self.training_proximities / row_max
        prox_matrix = 0.5 * (prox_matrix + prox_matrix.T)
        np.fill_diagonal(prox_matrix, 1.0)

        dist_matrix = 1 - prox_matrix

        k = int(self.pct_prototypes * self.input_shape)
        classes = np.unique(y)
        k_per_class = max(1, k // len(classes))

        prototype_indices = []
        for cls in classes:
            cls_indices = np.where(y == cls)[0]
            if len(cls_indices) <= k_per_class:
                # Too few samples to cluster: keep every sample of this class.
                prototype_indices.extend(cls_indices)
                continue

            sub_dists = dist_matrix[np.ix_(cls_indices, cls_indices)]
            km = kmedoids.KMedoids(k_per_class, method='fasterpam', random_state=self.random_state)
            km.fit(sub_dists)
            prototype_indices.extend(cls_indices[km.medoid_indices_])

        return np.array(prototype_indices, dtype=int)

    def fit(self, x, y):
        """Fit the target embedding, the RF proximities and the autoencoder.

        Parameters
        ----------
        x : array-like
            Training features; ``x.shape[0]`` is the number of samples.
        y : array-like
            Training labels. They are passed to the embedder / random forest and
            used to pick class-wise prototypes.
        """
        self.input_shape = x.shape[0]
        self.labels = y

        if self.random_state is not None:
            seed_everything(self.random_state)

        self._fit_embedding_target(x, y)

        if 0 < self.pct_prototypes < 1:
            # asarray: an element-wise `y == cls` mask needs an ndarray (a list
            # compared to a scalar is just False, which would select no prototypes).
            self.prototype_indices = self._select_prototypes(np.asarray(y).ravel())
            # Keep all rows (samples) but only the prototype columns.
            self.training_proximities = self.training_proximities[:, self.prototype_indices]
            self.init_torch_module(len(self.prototype_indices))
        else:
            self.prototype_indices = np.arange(self.input_shape)
            self.init_torch_module(self.input_shape)

        self.optimizer = torch.optim.AdamW(self.torch_module.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.criterion_geo = nn.MSELoss()

        # L1-normalize each proximity row (F.normalize defaults to dim=1).
        training_proximities = F.normalize(torch.tensor(self.training_proximities, dtype=torch.float), p=1)
        z_target = torch.tensor(self.z_target, dtype=torch.float)

        train_loader = DataLoader(TensorDataset(training_proximities, z_target),
                                  batch_size=self.batch_size, shuffle=True)

        self.train_loop(self.torch_module, self.epochs, train_loader, self.optimizer, self.device)

    def fit_transform(self, x, y):
        """Fit the model and return the embedding of the training samples."""
        self.fit(x, y)
        return self.transform(self.training_proximities, precomputed=True)

    def compute_loss(self, x, x_hat, z_target, z):
        """Return ``lam * recon + (1 - lam) * emb`` and record each term as a float.

        Parameters
        ----------
        x : torch.Tensor
            Normalized proximity rows (reconstruction target).
        x_hat : torch.Tensor
            Decoder output.
        z_target : torch.Tensor
            Target embedding for the batch.
        z : torch.Tensor
            Encoder output.
        """
        loss_recon = self.criterion_recon(x_hat, x)
        loss_emb = self.criterion_geo(z, z_target)

        self.recon_loss_temp = loss_recon.item()
        self.emb_loss_temp = loss_emb.item()

        balanced_loss = self.lam * loss_recon + (1 - self.lam) * loss_emb
        self.balanced_loss = balanced_loss.item()
        return balanced_loss

    def train_loop(self, model, epochs, train_loader, optimizer, device='cpu'):
        """Train ``model`` and record per-epoch mean losses.

        Per-epoch averages are appended to ``epoch_losses_recon``,
        ``epoch_losses_emb`` and ``epoch_losses_balanced``, logged to W&B when a
        run is active, and written to ``logging`` every 50 epochs.
        """
        self.epoch_losses_recon = []
        self.epoch_losses_emb = []
        self.epoch_losses_balanced = []

        model.to(device)
        model.train()
        n_batches = len(train_loader)

        for epoch in range(epochs):
            running_recon_loss = 0.0
            running_emb_loss = 0.0
            running_balanced_loss = 0.0

            for x, z_target in train_loader:
                x = x.to(device)
                z_target = z_target.to(device)

                recon, z = model(x)

                optimizer.zero_grad()
                self.compute_loss(x, recon, z_target, z).backward()
                optimizer.step()

                running_recon_loss += self.recon_loss_temp
                running_emb_loss += self.emb_loss_temp
                running_balanced_loss += self.balanced_loss

            # Track losses per epoch
            self.epoch_losses_recon.append(running_recon_loss / n_batches)
            self.epoch_losses_emb.append(running_emb_loss / n_batches)
            self.epoch_losses_balanced.append(running_balanced_loss / n_batches)

            # One wandb.log call per epoch so the W&B step advances once per
            # epoch; skipped when no run is active (wandb.log would raise).
            if wandb.run is not None:
                wandb.log({
                    f"{self.random_state}: train_recon_loss": self.epoch_losses_recon[-1],
                    f"{self.random_state}: train_emb_loss": self.epoch_losses_emb[-1],
                    f"{self.random_state}: train_balanced_loss": self.epoch_losses_balanced[-1],
                    "Epoch": epoch,
                })

            # Periodic logging of losses
            if epoch % 50 == 0:
                logging.info("Epoch %d/%d, Recon Loss: %.7f, Geo Loss: %.7f",
                             epoch, epochs, self.epoch_losses_recon[-1], self.epoch_losses_emb[-1])

    def transform(self, x, precomputed=False):
        """Encode samples into the latent space.

        Parameters
        ----------
        x : array-like
            Raw features, or proximity rows to the prototypes if ``precomputed``.
        precomputed : bool
            If False, ``x`` is first mapped to proximities to the prototype
            samples with the fitted RFPHATE / RFGAP ``prox_extend``.

        Returns
        -------
        numpy.ndarray
            Encoder outputs, one row per input row.
        """
        self.torch_module.eval()

        if not precomputed:
            if self.emb_constraint == 'rfphate':
                x = self.embedder.prox_extend(x, self.prototype_indices).toarray()
            elif self.emb_constraint in ['umap', 'sumap', 'rfumap']:
                x = self.rfgap.prox_extend(x, self.prototype_indices).toarray()
            else:
                raise ValueError(f"Unsupported embedding constraint: {self.emb_constraint}")

        # Same per-row L1 normalization as used for training inputs.
        x = torch.tensor(x, dtype=torch.float)
        x = F.normalize(x, p=1)

        loader = DataLoader(TensorDataset(x), batch_size=self.batch_size, shuffle=False)

        z = []
        with torch.no_grad():
            for batch in loader:
                z_batch = self.torch_module.encoder(batch[0].to(self.device)).cpu().numpy()
                z.append(z_batch)

        return np.concatenate(z)

    def inverse_transform(self, x):
        """Decode latent points back to the proximity space.

        NOTE(review): the decoder output goes through ``final_activation``; for
        ``recon_loss_type == 'kl'`` the configured activation is 'log_softmax',
        so the outputs are presumably log-probabilities -- confirm downstream use.
        """
        self.torch_module.eval()
        x = FromNumpyDataset(x)
        loader = DataLoader(x, batch_size=self.batch_size, shuffle=False)
        with torch.no_grad():
            x_hat = [self.torch_module.final_activation(self.torch_module.decoder(batch.to(self.device)))
                     .cpu().numpy() for batch in loader]

        return np.concatenate(x_hat)
        