import os
from src.utils.numpy_dataset import FromNumpyDataset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, random_split
from src.models.base_model import BaseModel
from src.models.torch_models import AETorchModule, EarlyStopping
import numpy as np
import logging
import wandb

class AE(BaseModel):
    """
    Vanilla Autoencoder.

    Builds an ``AETorchModule`` on ``fit``, trains it with ``AdamW`` on a
    mean-squared reconstruction loss, and exposes the module's encoder and
    decoder through ``transform`` and ``inverse_transform``.
    """

    def __init__(self,
                 n_components,
                 lr,
                 batch_size,
                 weight_decay,
                 random_state,
                 device,
                 dropout_prob,
                 epochs,
                 hidden_dims,
                 early_stopping,
                 patience,
                 delta_factor,
                 save_model
                 ):
        """
        Store the hyper-parameters; no torch objects are created here.

        Args:
            n_components: Forwarded to ``AETorchModule`` as ``z_dim``.
            lr: Learning rate given to ``AdamW``.
            batch_size: Batch size of every ``DataLoader`` built by this class.
            weight_decay: Weight decay given to ``AdamW``.
            random_state: Seed for ``torch.manual_seed``; ``None`` skips
                seeding. Also used in the wandb metric name and in the
                early-stopping checkpoint path.
            device: Device the module and the batches are moved to.
            dropout_prob: Forwarded to ``AETorchModule``.
            epochs: Maximum number of training epochs.
            hidden_dims: Forwarded to ``AETorchModule``.
            early_stopping: When truthy, ``EarlyStopping`` is consulted after
                every epoch.
            patience: Forwarded to ``EarlyStopping``.
            delta_factor: Forwarded to ``EarlyStopping``.
            save_model: Forwarded to ``EarlyStopping``.
        """
        self.n_components = n_components
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        self.random_state = random_state
        self.device = device
        self.dropout_prob = dropout_prob
        self.epochs = epochs
        # Filled in by ``fit`` from the second dimension of the training data.
        self.data_shape = None
        self.hidden_dims = hidden_dims
        self.early_stopping = early_stopping
        self.patience = patience
        self.delta_factor = delta_factor
        self.save_model = save_model

    def init_torch_module(self, data_shape):
        """
        Create ``self.torch_module`` for inputs of size ``data_shape``.

        Args:
            data_shape: Value passed to ``AETorchModule`` as ``input_dim``.
        """
        self.torch_module = AETorchModule(input_dim=data_shape,
                                          hidden_dims=self.hidden_dims,
                                          z_dim=self.n_components,
                                          dropout_prob=self.dropout_prob)

    def fit(self, x, y):
        """
        Build the module, optimizer and loss, then train on ``x``.

        Args:
            x: Training data; ``x.shape[1]`` is used as the input size and the
                values are converted to a float tensor.
            y: Unused by this model.
        """
        self.data_shape = x.shape[1]
        tensor_dataset = TensorDataset(torch.tensor(x, dtype=torch.float))

        if self.random_state is not None:
            # Seed before the module is created so weight initialisation and
            # the DataLoader shuffling below are reproducible.
            torch.manual_seed(self.random_state)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        self.init_torch_module(self.data_shape)

        self.optimizer = torch.optim.AdamW(self.torch_module.parameters(),
                                           lr=self.lr,
                                           weight_decay=self.weight_decay)
        self.criterion = nn.MSELoss()

        train_loader = DataLoader(tensor_dataset, batch_size=self.batch_size, shuffle=True)

        self.train_loop(train_loader, self.optimizer, self.torch_module, self.epochs, self.device)

    def compute_loss(self, x, x_hat):
        """
        Compute the reconstruction loss and back-propagate it.

        The scalar loss value is stored in ``self.recon_loss_temp``.

        Args:
            x: Original batch.
            x_hat: Reconstruction of ``x`` produced by the module.
        """
        # nn.MSELoss expects (input, target): the prediction goes first.
        loss_recon = self.criterion(x_hat, x)

        self.recon_loss_temp = loss_recon.item()

        loss_recon.backward()

    def train_loop(self, train_loader, optimizer, model, epochs, device='cpu'):
        """
        Train ``model`` for up to ``epochs`` epochs.

        The mean per-batch reconstruction loss of every epoch is appended to
        ``self.epoch_losses_recon`` and logged to wandb. When
        ``self.early_stopping`` is truthy, ``EarlyStopping`` is called after
        each epoch and training returns as soon as it signals a stop.

        Args:
            train_loader: Iterable of batches; element ``0`` of each batch is
                the input tensor.
            optimizer: Optimizer stepping the parameters of ``model``.
            model: Module returning ``(reconstruction, _)`` when called.
            epochs: Maximum number of epochs.
            device: Device the model and the batches are moved to.
        """
        self.epoch_losses_recon = []
        best_loss = float("inf")
        counter = 0

        stopper = None
        if self.early_stopping:
            # Directory, path and callback do not depend on the epoch, so they
            # are built once instead of on every iteration.
            # NOTE(review): assumes EarlyStopping keeps no per-epoch internal
            # state, which matches its call signature (best loss and counter
            # are passed in and returned) -- confirm against its definition.
            os.makedirs(f"{self.random_state}/", exist_ok=True)
            subfolder_path = f"{self.random_state}/best_{self.random_state}.pth"
            stopper = EarlyStopping(patience=self.patience,
                                    delta_factor=self.delta_factor,
                                    save_model=self.save_model,
                                    save_path=subfolder_path)

        for epoch in range(epochs):
            # Re-asserted every epoch because the early-stopping callback
            # receives the model; this is a no-op when nothing changed.
            model.to(device)
            model.train()

            running_recon_loss = 0

            for batch in train_loader:
                x = batch[0].to(device)
                optimizer.zero_grad()
                recon, _ = model(x)
                self.compute_loss(x, recon)
                running_recon_loss += self.recon_loss_temp
                optimizer.step()

            self.epoch_losses_recon.append(running_recon_loss / len(train_loader))
            wandb.log({f"{self.random_state}: train_recon_loss": self.epoch_losses_recon[-1], "Epoch": epoch})

            if epoch % 50 == 0:
                logging.info(f"Epoch {epoch}/{self.epochs}, Recon Loss: {self.epoch_losses_recon[-1]:.7f}")

            # Check for early stopping
            if stopper is not None:
                should_stop, best_loss, counter = stopper(self.epoch_losses_recon[-1], best_loss, counter, model)
                if should_stop:
                    logging.info(f"Stopping training early at epoch {epoch}")
                    return

    def transform(self, x):
        """
        Encode ``x`` with the trained module's encoder.

        Args:
            x: Data convertible to a float tensor.

        Returns:
            NumPy array with the encoder outputs of all batches concatenated
            in input order.
        """
        self.torch_module.eval()

        dataset = TensorDataset(torch.tensor(x, dtype=torch.float))
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            z = [self.torch_module.encoder(batch[0].to(self.device)).cpu().numpy()
                 for batch in loader]
        return np.concatenate(z)

    def inverse_transform(self, x):
        """
        Decode ``x`` with the trained module's decoder.

        Args:
            x: Data wrapped in ``FromNumpyDataset`` before batching.

        Returns:
            NumPy array with the decoder outputs of all batches concatenated
            in input order.
        """
        self.torch_module.eval()
        dataset = FromNumpyDataset(x)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            x_hat = [self.torch_module.decoder(batch.to(self.device)).cpu().numpy()
                     for batch in loader]

        return np.concatenate(x_hat)