import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from syne_tune import Reporter
import logging
import numpy as np
import math

from benchmarking.utils.experiment import Experiment
from .model import StackedCAE
from ...train_utils.utils import EarlyStopper, save_checkpoint
from models.data_utils.util import get_device


class CAETrainer:
    """Trainer class for the single layer CAE as originally proposed.

    This is a placeholder: the constructor unconditionally raises, so the
    class cannot be instantiated yet.
    """

    def __init__(self) -> None:
        # No implementation exists so far; fail loudly on any instantiation.
        raise NotImplementedError()


class StackedCAETrainer:
    """
    Trainer for the StackedCAE using the original CAE-wise loss calculation.

    The optimised objective is ``MSE(reconstruction, input) + lambda_c * J`` where
    ``J`` is the summed per-layer contractive penalty returned by
    :meth:`get_stacked_jacobian_loss`. Losses are written to TensorBoard through
    a ``SummaryWriter`` created in the constructor.
    """

    def __init__(
        self,
        model: StackedCAE,
        print_loss_every: int = 50,
        lambda_c: float = 0.005,
    ) -> None:
        """
        Args:
            model: The stacked CAE to train. It is moved to CUDA if available.
            print_loss_every: Log the step loss every this many optimisation steps.
            lambda_c: Weight of the contractive penalty in the total loss.
        """
        self.use_cuda = torch.cuda.is_available()
        self.model = model
        self.print_loss_every = print_loss_every
        # Global optimisation-step counter; used as the TensorBoard x-axis.
        self.num_steps = 0
        self.gamma = lambda_c
        self.input_dim = self.model.encoder_spec[0]
        self.writer = SummaryWriter()

        if self.use_cuda:
            self.model.cuda()

    def train(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader,
        epochs: int = 10,
        lr: float = 1e-4,
        hyperparametertuning: bool = False,
        experiment: Experiment = None,
        early_stopping: bool = True,
        patience: int = 30,
        min_delta: float = 1.02,
        checkpointing: bool = True,
        device_id: int = None,
    ) -> None:
        """
        Train the model with Adam and evaluate it on ``test_loader`` after every epoch.

        Args:
            train_loader: Loader yielding training batches (plain tensors).
            test_loader: Loader yielding evaluation batches (plain tensors).
            epochs: Maximum number of epochs.
            lr: Learning rate for the Adam optimiser.
            hyperparametertuning: If True, report the average test loss of each
                epoch to syne-tune through a ``Reporter``.
            experiment: Optional experiment object; the per-epoch average train
                and test MSE are appended to its ``train_loss_mse`` and
                ``test_loss_mse`` lists.
            early_stopping: If True, an ``EarlyStopper`` monitors the average
                test MSE and may end training early.
            patience: Forwarded to ``EarlyStopper``.
            min_delta: Forwarded to ``EarlyStopper``.
            checkpointing: If True, ``save_checkpoint`` is called after every epoch.
            device_id: Forwarded to ``get_device`` to select the training device.

        The TensorBoard writer is flushed and closed when this method exits,
        including when training aborts with an exception.
        """
        self.model.train()
        device = get_device(device_id)
        self.model.to(device)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Performance reporting to syne-tune
        if hyperparametertuning:
            report = Reporter()

        if early_stopping:
            early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)

        # len(loader) is the number of batches actually yielded; unlike
        # ceil(len(dataset) / batch_size) it respects drop_last and custom samplers.
        num_train_batches = len(train_loader)
        num_test_batches = len(test_loader)

        try:
            for epoch in range(epochs):
                epoch_loss, epoch_loss_mse = self._run_training_epoch(
                    train_loader, optimizer, device
                )
                self.writer.add_scalar(
                    "[Train] Avg. epoch loss",
                    epoch_loss / num_train_batches,
                    epoch + 1,
                )

                # Now get test loss
                epoch_test_loss, epoch_test_loss_mse, last_test_loss = (
                    self._run_evaluation(test_loader, device)
                )
                avg_test_loss = epoch_test_loss / num_test_batches
                avg_test_mse = epoch_test_loss_mse / num_test_batches

                # Report test loss to syne-tune
                if hyperparametertuning:
                    assert not np.isnan(avg_test_loss), "Loss cannot be NaN"
                    report(step=epoch, mean_loss=avg_test_loss, epoch=epoch + 1)

                # Report to the received experiment instance for model comparison
                if experiment:
                    experiment.train_loss_mse.append(epoch_loss_mse / num_train_batches)
                    experiment.test_loss_mse.append(avg_test_mse)
                logging.info(
                    f"Epoch [{epoch + 1}/{epochs}]: Avg. Train Loss: {epoch_loss / num_train_batches:.8f}, Avg. Test Loss: {avg_test_loss:.8f}\n"
                )

                # Loss of the last evaluated test batch (kept for continuity of
                # the existing TensorBoard tag).
                self.writer.add_scalar("[Test] Total loss", last_test_loss, epoch + 1)
                self.writer.add_scalar(
                    "[Test] Avg. epoch loss",
                    avg_test_loss,
                    epoch + 1,
                )

                if checkpointing:
                    save_checkpoint(
                        model=self.model,
                        experiment=experiment,
                        device_id=device_id,
                        epoch=epoch,
                        best=(
                            early_stopper.is_best(avg_test_mse)
                            if early_stopping
                            else None
                        ),
                    )

                if early_stopping and early_stopper.early_stop(avg_test_mse):
                    logging.info("Training was stopped early.")
                    break
        finally:
            # Always flush and close the tensorboard writer, even on failure,
            # so already-logged scalars are not lost.
            self.writer.flush()
            self.writer.close()

    def _compute_losses(
        self, batch: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the model on ``batch`` and return ``(total, mse, contractive)`` losses."""
        recon, latents = self.model(batch)
        mse = nn.MSELoss()(recon, batch)
        jacobian_loss = self.get_stacked_jacobian_loss(latents=latents)
        contractive_loss = self.gamma * jacobian_loss
        return mse + contractive_loss, mse, contractive_loss

    def _run_training_epoch(
        self, train_loader: DataLoader, optimizer: optim.Optimizer, device
    ) -> tuple[float, float]:
        """Run one optimisation pass over ``train_loader``; return summed (total, MSE) losses."""
        total_loss = 0.0
        total_mse = 0.0
        for batch in train_loader:
            # The model always lives on `device`, so the batch must follow it
            # regardless of whether that device is CUDA.
            batch = batch.to(device)

            optimizer.zero_grad()
            loss, mse, contractive_loss = self._compute_losses(batch)

            # Accumulate plain floats: summing the tensors themselves would keep
            # every batch's autograd graph alive until the end of the epoch.
            loss_value = loss.item()
            mse_value = mse.item()
            total_loss += loss_value
            total_mse += mse_value
            self.num_steps += 1

            # Print loss if step is eligible
            if self.num_steps % self.print_loss_every == 0:
                logging.info(
                    f"Step {self.num_steps} loss: {loss_value:.8f}\n >> MSE: {mse_value:.8f}, Contractive loss: {contractive_loss.item():.8f}"
                )

            # Submit loss to tensorboard
            self.writer.add_scalar("[Train] Total loss", loss_value, self.num_steps)

            loss.backward()
            optimizer.step()
        return total_loss, total_mse

    def _run_evaluation(self, test_loader: DataLoader, device):
        """
        Evaluate on ``test_loader`` without gradients.

        Returns the summed total loss, the summed MSE and the total loss of the
        last batch (``None`` if the loader yielded no batch).
        """
        total_loss = 0.0
        total_mse = 0.0
        last_loss = None
        # Evaluate in eval mode so layers such as dropout or batch norm (if the
        # model has any) behave deterministically; train mode is restored after.
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in test_loader:
                    batch = batch.to(device)
                    loss, mse, _ = self._compute_losses(batch)
                    last_loss = loss.item()
                    total_loss += last_loss
                    total_mse += mse.item()
        finally:
            self.model.train()
        return total_loss, total_mse, last_loss

    def get_stacked_jacobian_loss(self, latents: list[torch.Tensor]) -> torch.Tensor:
        """
        Calculates the jacobian loss of the encoder as if the CAE was stacked by summing the jacobian loss of
        each sub-encoder, i.e. each layer. This is what all known existing multi-layer CAE architectures do so far,
        although it contradicts the original contractive idea.

        Args:
            latents: One activation tensor per sub-encoder, in the same order as
                ``self.model.caes``. Each is used as a 2-D matrix whose second
                dimension matches the rows of that layer's weight.

        Returns:
            A tensor of shape ``(1,)`` holding the penalty summed over all
            samples in the batch and over all layers.
        """
        jacobian_loss = torch.tensor([0.0], device=latents[0].device)
        for idx, latent in enumerate(latents):
            # 1 - h^2 is the tanh derivative written in terms of its output.
            # NOTE(review): this presumes tanh activations - confirm against the model.
            dh = 1 - latent**2
            # Squared L2 norm of each row of the layer's first encoder module weight.
            w_sum = torch.sum(self.model.caes[idx].encoder[0].weight ** 2, dim=1)
            w_sum = w_sum.unsqueeze(1)
            current_jacobian_loss = torch.sum(torch.mm(dh**2, w_sum), 0)
            # Out-of-place add so dtype promotion applies (an in-place add would
            # silently downcast e.g. float64 penalties to the accumulator dtype).
            jacobian_loss = jacobian_loss + current_jacobian_loss
        return jacobian_loss
