import os
from src.utils.numpy_dataset import FromNumpyDataset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, random_split
from src.models.base_model import BaseModel
from src.models.torch_models import AETorchModule, EarlyStopping
import numpy as np
import logging
import wandb
from src.models.rfphate import RFPHATE
from src.utils.set_seeds import seed_everything



class RF_GRAE(BaseModel):
    """
    GRAE with RF-PHATE regularization. This model is only used for toy comparison with RFAE.

    An autoencoder (``AETorchModule``) is trained so that its decoder
    reconstructs the input while its bottleneck codes match an RF-PHATE
    embedding precomputed on the training data. The training objective is::

        loss = lam * MSE(x_hat, x) + (1 - lam) * MSE(z, z_target)

    Parameters
    ----------
    n_components : int
        Bottleneck size of the autoencoder and target dimension of RF-PHATE.
    lr, weight_decay : float
        AdamW learning rate and weight decay.
    batch_size : int
        Batch size used for training and for ``transform``/``inverse_transform``.
    random_state : int or None
        Forwarded to RF-PHATE; if not None also used to seed training via
        ``seed_everything``. Also appears in wandb metric names and in the
        early-stopping checkpoint path.
    device : str or torch.device
        Device the autoencoder is trained and evaluated on.
    dropout_prob : float
        Dropout probability forwarded to ``AETorchModule``.
    epochs : int
        Maximum number of training epochs.
    hidden_dims : sequence
        Hidden layer sizes forwarded to ``AETorchModule``.
    embedder_params : dict or None
        Extra keyword arguments for ``RFPHATE`` (must not repeat
        ``random_state``/``n_components``, which are always passed).
    lam : float
        Weight of the reconstruction loss; ``1 - lam`` weighs the embedding loss.
    early_stopping : bool
        Whether to apply ``EarlyStopping`` on the per-epoch balanced loss.
    patience, delta_factor, save_model :
        Forwarded unchanged to ``EarlyStopping``.
    """

    def __init__(self,
                 n_components,
                 lr,
                 batch_size,
                 weight_decay,
                 random_state,
                 device,
                 dropout_prob,
                 epochs,
                 hidden_dims,
                 embedder_params,
                 lam,
                 early_stopping,
                 patience,
                 delta_factor,
                 save_model
                 ):

        self.n_components = n_components
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.random_state = random_state
        self.device = device
        self.dropout_prob = dropout_prob
        self.epochs = epochs
        self.data_shape = None  # number of input features; set in fit()
        self.hidden_dims = hidden_dims
        self.lam = lam
        self.early_stopping = early_stopping
        self.patience = patience
        self.delta_factor = delta_factor
        self.save_model = save_model

        # A missing params dict is equivalent to "no extra RFPHATE options".
        extra_params = {} if embedder_params is None else embedder_params
        self.embedder = RFPHATE(random_state=random_state,
                                n_components=self.n_components,
                                **extra_params)

    def init_torch_module(self, data_shape):
        """Create a fresh ``AETorchModule`` for inputs with ``data_shape`` features."""
        self.torch_module = AETorchModule(input_dim=data_shape,
                                          hidden_dims=self.hidden_dims,
                                          z_dim=self.n_components,
                                          dropout_prob=self.dropout_prob)

    def fit(self, x, y):
        """Fit RF-PHATE on ``(x, y)`` and train the autoencoder towards it.

        Parameters
        ----------
        x : array-like, 2-D
            Training inputs; ``x.shape[1]`` is taken as the feature count and
            ``x`` must be convertible with ``torch.tensor``.
        y : array-like
            Labels, used only by ``RFPHATE.fit_transform``.

        Returns
        -------
        self
        """
        self.data_shape = x.shape[1]
        # Targets for the bottleneck: the RF-PHATE embedding of the training set.
        self.z_target = self.embedder.fit_transform(x, y)
        tensor_dataset = TensorDataset(torch.tensor(x, dtype=torch.float),
                                       torch.tensor(self.z_target, dtype=torch.float))

        # Seed before weight init and before the shuffling DataLoader is built,
        # so training is reproducible for a given random_state.
        if self.random_state is not None:
            seed_everything(self.random_state)

        self.init_torch_module(self.data_shape)

        self.optimizer = torch.optim.AdamW(self.torch_module.parameters(),
                                           lr=self.lr,
                                           weight_decay=self.weight_decay)
        self.criterion = nn.MSELoss()

        train_loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=True)

        self.train_loop(train_loader, self.optimizer, self.torch_module, self.epochs, self.device)
        return self

    def compute_loss(self, x, x_hat, z, z_target):
        """Return ``lam * MSE(x_hat, x) + (1 - lam) * MSE(z, z_target)``.

        As a side effect, stores the scalar values of the reconstruction,
        embedding and balanced losses in ``self.recon_loss_temp``,
        ``self.emb_loss_temp`` and ``self.balanced_loss`` for epoch averaging.
        """
        loss_recon = self.criterion(x_hat, x)
        loss_emb = self.criterion(z, z_target)

        self.recon_loss_temp = loss_recon.item()
        self.emb_loss_temp = loss_emb.item()

        balanced_loss = self.lam * loss_recon + (1 - self.lam) * loss_emb
        self.balanced_loss = balanced_loss.item()
        return balanced_loss

    def train_loop(self, train_loader, optimizer, model, epochs, device = 'cpu'):
        """Train ``model`` for up to ``epochs`` epochs on ``(x, z_target)`` batches.

        Per-epoch mean losses are appended to ``self.epoch_losses_recon``,
        ``self.epoch_losses_emb`` and ``self.epoch_losses_balanced`` and logged
        to wandb. When ``self.early_stopping`` is set, training may return
        before ``epochs`` epochs have run.
        """
        self.epoch_losses_recon = []
        self.epoch_losses_emb = []
        self.epoch_losses_balanced = []
        best_loss = float("inf")
        counter = 0

        # Build the early-stopping helper (and its checkpoint folder) once,
        # instead of re-creating both every epoch. Its progress is carried in
        # best_loss / counter, which are passed in and returned each call.
        stopper = None
        if self.early_stopping:
            os.makedirs(f"{self.random_state}/", exist_ok=True)
            subfolder_path = f"{self.random_state}/best_{self.random_state}.pth"
            stopper = EarlyStopping(patience=self.patience,
                                    delta_factor=self.delta_factor,
                                    save_model=self.save_model,
                                    save_path=subfolder_path)

        model.to(device)
        model.train()
        n_batches = len(train_loader)

        for epoch in range(epochs):
            running_recon_loss = 0
            running_emb_loss = 0
            running_balanced_loss = 0

            for x, z_target in train_loader:
                x = x.to(device)
                z_target = z_target.to(device)

                recon, z = model(x)

                optimizer.zero_grad()
                self.compute_loss(x, recon, z, z_target).backward()

                # compute_loss stashed the scalar values of this batch's losses.
                running_recon_loss += self.recon_loss_temp
                running_emb_loss += self.emb_loss_temp
                running_balanced_loss += self.balanced_loss

                optimizer.step()

            # Track mean losses per epoch
            self.epoch_losses_recon.append(running_recon_loss / n_batches)
            self.epoch_losses_emb.append(running_emb_loss / n_batches)
            self.epoch_losses_balanced.append(running_balanced_loss / n_batches)

            wandb.log({f"{self.random_state}: train_recon_loss": self.epoch_losses_recon[-1], "Epoch": epoch})
            wandb.log({f"{self.random_state}: train_emb_loss": self.epoch_losses_emb[-1], "Epoch": epoch})
            wandb.log({f"{self.random_state}: train_balanced_loss": self.epoch_losses_balanced[-1], "Epoch": epoch})
            # Periodic logging of losses
            if epoch % 50 == 0:
                logging.info(f"Epoch {epoch}/{self.epochs}, Recon Loss: {self.epoch_losses_recon[-1]:.7f}, Geo Loss: {self.epoch_losses_emb[-1]:.7f}")

            # Check for early stopping on the balanced (training) loss.
            # NOTE(review): nothing visible here reloads the saved "best"
            # checkpoint after stopping, so the module keeps its last-epoch
            # weights unless EarlyStopping handles that itself — confirm.
            if stopper is not None:
                should_stop, best_loss, counter = stopper(self.epoch_losses_balanced[-1], best_loss, counter, model)
                if should_stop:
                    logging.info(f"Stopping training early at epoch {epoch}")
                    return

    def transform(self, x):
        """Encode ``x`` to the learned ``n_components``-dimensional space.

        Returns a NumPy array with one row per input row, in input order.
        """
        self.torch_module.eval()

        dataset = TensorDataset(torch.tensor(x, dtype=torch.float))
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Inference only: skip building an autograd graph for every batch.
        with torch.no_grad():
            z = [self.torch_module.encoder(batch.to(self.device)).cpu().numpy()
                 for (batch,) in loader]
        return np.concatenate(z)

    def inverse_transform(self, x):
        """Decode embeddings ``x`` back to input space with the trained decoder.

        Returns a NumPy array with one row per input row, in input order.
        """
        self.torch_module.eval()
        dataset = FromNumpyDataset(x)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Inference only: skip building an autograd graph for every batch.
        with torch.no_grad():
            x_hat = [self.torch_module.decoder(batch.to(self.device)).cpu().numpy()
                     for batch in loader]

        return np.concatenate(x_hat)