import logging
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from benchmarking.utils.experiment import Experiment
from .model import PCA
from ...train_utils.utils import save_checkpoint


class KernelPCATrainer:
    """
    Trainer for the KernelPCA.

    The wrapped model exposes an sklearn-style interface
    (``fit_transform`` / ``transform`` / ``inverse_transform``). Training is a
    single ``fit_transform`` call on the full training set; quality is reported
    as the reconstruction MSE on the train and test sets.
    """

    def __init__(self, model: PCA) -> None:
        self.model = model

    def train(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader,
        experiment: Optional[Experiment] = None,
        device_id: Optional[int] = None,
    ) -> None:
        """
        Fit the model on the training data and evaluate reconstruction error.

        Args:
            train_loader: Loader whose batches are concatenated and used to fit
                the model.
            test_loader: Loader whose batches are concatenated, projected with
                the fitted model and reconstructed for evaluation.
            experiment: Optional experiment record; when given, the train and
                test MSE values are appended to ``train_loss_mse`` and
                ``test_loss_mse``.
            device_id: Forwarded unchanged to ``save_checkpoint``.
        """
        model = self.model

        # The model is fit in one call, so the whole training set is
        # materialised up front instead of being iterated batch by batch.
        train_data = dataloader_to_numpy(train_loader)
        train_recon = model.inverse_transform(model.fit_transform(train_data))
        train_loss = self._reconstruction_mse(train_recon, train_data)
        logging.info(f"Kernel PCA fitting completed! The train loss is {train_loss}.")

        # Evaluate on held-out data using the already-fitted projection.
        test_data = dataloader_to_numpy(test_loader)
        test_recon = model.inverse_transform(model.transform(test_data))
        test_loss = self._reconstruction_mse(test_recon, test_data)
        logging.info(
            f"Kernel PCA test transformation and reconstruction completed! The test loss is {test_loss}."
        )

        # Report loss
        if experiment is not None:
            experiment.train_loss_mse.append(train_loss)
            experiment.test_loss_mse.append(test_loss)

        # Save model
        save_checkpoint(
            model=model, experiment=experiment, device_id=device_id, best=True
        )

    @staticmethod
    def _reconstruction_mse(recon: np.ndarray, target: np.ndarray) -> float:
        """Return the mean squared error between ``recon`` and ``target``.

        Both arrays are converted with ``torch.Tensor`` (float32) before the
        ``nn.MSELoss`` computation, matching the original loss calculation.
        """
        return nn.MSELoss()(torch.Tensor(recon), torch.Tensor(target)).item()


def dataloader_to_numpy(dataloader: DataLoader) -> np.ndarray:
    """
    Materialise every batch yielded by ``dataloader`` into one NumPy array.

    Batches are concatenated along axis 0 in iteration order.

    Args:
        dataloader: Any iterable of batches (typically a ``DataLoader``). A
            batch may be a tensor, a NumPy array, or a tuple/list whose first
            element holds the inputs (e.g. ``(inputs, labels)`` batches from a
            ``TensorDataset``); the remaining elements are ignored.

    Returns:
        The concatenated data.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    arrays = [_batch_to_numpy(batch) for batch in dataloader]
    if not arrays:
        raise ValueError("dataloader_to_numpy: the dataloader yielded no batches.")
    return np.concatenate(arrays, axis=0)


def _batch_to_numpy(batch) -> np.ndarray:
    """Convert one batch to a NumPy array, keeping only the first element of tuple/list batches."""
    if isinstance(batch, (tuple, list)):
        # Datasets that return (inputs, targets) collate into a list; keep inputs.
        batch = batch[0]
    if isinstance(batch, torch.Tensor):
        # detach() so tensors that require grad can still be converted.
        return batch.detach().cpu().numpy()
    return np.asarray(batch)
