import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from syne_tune import Reporter
import logging
import numpy as np
import math

from benchmarking.utils.experiment import Experiment
from .model import DeepCAE
from ...train_utils.utils import EarlyStopper, save_checkpoint
from models.data_utils.util import get_device


class DeepCAETrainer:
    """
    Trainer for the DeepCAE using the trick proposed in
    https://arxiv.org/abs/2402.18164
    for joint loss calculation for the entire encoder at once instead of using stacking.

    The per-batch objective is ``MSE(recon, batch) + lambda_c * ||J||``, where
    ``J`` is the chained product of the per-layer encoder Jacobians
    (see ``get_joint_jacobian_loss``).
    """

    def __init__(
        self,
        model: "DeepCAE",
        print_loss_every: int = 50,
        lambda_c: float = 0.005,
    ) -> None:
        """
        Args:
            model: The DeepCAE to train. It must expose ``encoder`` (indexed by
                ``get_joint_jacobian_loss``) and ``encoder_spec``, and its forward
                pass must return ``(reconstruction, latents)``.
            print_loss_every: Log the training loss every this many optimizer steps.
            lambda_c: Weight of the contractive (Jacobian) penalty.
        """
        self.use_cuda = torch.cuda.is_available()
        self.model = model
        self.print_loss_every = print_loss_every
        # Global optimizer-step counter; it is never reset, so it keeps counting
        # across repeated train() calls.
        self.num_steps = 0
        self.gamma = lambda_c
        # Presumably the input feature dimension (first entry of encoder_spec) — TODO confirm.
        self.input_dim = self.model.encoder_spec[0]
        self.writer = SummaryWriter()

        if self.use_cuda:
            self.model.cuda()

    def train(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader,
        epochs: int = 10,
        lr: float = 1e-4,
        hyperparametertuning: bool = False,
        experiment: "Experiment" = None,
        early_stopping: bool = True,
        patience: int = 30,
        min_delta: float = 1.02,
        checkpointing: bool = True,
        device_id: int = None,
    ) -> None:
        """
        Train the model with Adam and evaluate on ``test_loader`` after every epoch.

        Args:
            train_loader: Loader yielding input tensors used for training.
            test_loader: Loader yielding input tensors used for evaluation.
            epochs: Maximum number of epochs.
            lr: Adam learning rate.
            hyperparametertuning: If True, report the mean test loss to syne-tune
                after every epoch.
            experiment: If given, per-epoch mean train/test MSE are appended to
                ``experiment.train_loss_mse`` / ``experiment.test_loss_mse``.
            early_stopping: If True, stop when ``EarlyStopper`` (fed the mean test
                MSE) says so.
            patience, min_delta: Forwarded to ``EarlyStopper``; their exact
                semantics are defined there (min_delta=1.02 looks like a ratio —
                confirm).
            checkpointing: If True, call ``save_checkpoint`` after every epoch.
            device_id: Passed to ``get_device`` to choose the training device.

        Raises:
            ValueError: If either loader yields no batches in an epoch.
            AssertionError: If the test loss reported to syne-tune is NaN.
        """
        device = get_device(device_id)
        self.model.to(device)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Performance reporting to syne-tune
        report = Reporter() if hyperparametertuning else None
        early_stopper = (
            EarlyStopper(patience=patience, min_delta=min_delta)
            if early_stopping
            else None
        )

        try:
            for epoch in range(epochs):
                train_loss, train_mse, _ = self._run_epoch(
                    train_loader, device, optimizer=optimizer
                )
                self.writer.add_scalar(
                    "[Train] Avg. epoch loss",
                    train_loss,
                    epoch + 1,
                )

                # Now get test loss
                test_loss, test_mse, last_test_loss = self._run_epoch(
                    test_loader, device
                )

                # Report test loss to syne-tune
                if report is not None:
                    assert not np.isnan(test_loss), "Loss cannot be NaN"
                    report(step=epoch, mean_loss=test_loss, epoch=epoch + 1)

                # Report to the received experiment instance for model comparison
                if experiment:
                    experiment.train_loss_mse.append(train_mse)
                    experiment.test_loss_mse.append(test_mse)

                logging.info(
                    f"Epoch [{epoch + 1}/{epochs}]: Avg. Train Loss: {train_loss:.8f}, Avg. Test Loss: {test_loss:.8f}\n"
                )

                # NOTE: this is the loss of the *last* test batch only; kept so
                # existing TensorBoard dashboards stay comparable.
                self.writer.add_scalar("[Test] Total loss", last_test_loss, epoch + 1)
                self.writer.add_scalar(
                    "[Test] Avg. epoch loss",
                    test_loss,
                    epoch + 1,
                )

                if checkpointing:
                    save_checkpoint(
                        model=self.model,
                        experiment=experiment,
                        device_id=device_id,
                        epoch=epoch,
                        # is_best must be queried before early_stop below, as before.
                        best=(
                            early_stopper.is_best(test_mse)
                            if early_stopper is not None
                            else None
                        ),
                    )

                if early_stopper is not None and early_stopper.early_stop(test_mse):
                    logging.info("Training was stopped early.")
                    break
        finally:
            # Leave the model in training mode (the evaluation passes switch it to
            # eval) and always release the TensorBoard writer, even on failure.
            self.model.train()
            self.writer.flush()
            self.writer.close()

    def _run_epoch(
        self,
        loader: DataLoader,
        device: torch.device,
        optimizer: "optim.Optimizer | None" = None,
    ) -> tuple[float, float, float]:
        """
        Make one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

        Returns:
            ``(mean total loss, mean MSE, total loss of the last batch)``. The means
            are taken over the number of batches actually yielded, so
            ``drop_last`` and custom batch samplers are handled correctly.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        training = optimizer is not None
        # eval() disables dropout / batch-norm statistic updates while evaluating,
        # if the model has any.
        self.model.train(training)

        loss_sum = 0.0
        mse_sum = 0.0
        last_loss = float("nan")
        num_batches = 0

        # The contractive term is computed analytically from the latents, so no
        # autograd graph is needed during evaluation.
        with torch.set_grad_enabled(training):
            for batch in loader:
                batch = batch.to(device)
                loss, mse, contractive_loss = self._compute_loss(batch)

                # Accumulate Python floats so each batch's autograd graph is freed
                # after backward() instead of being kept alive for the epoch.
                last_loss = loss.item()
                loss_sum += last_loss
                mse_sum += mse.item()
                num_batches += 1

                if not training:
                    continue

                self.num_steps += 1
                # Print loss if step is eligible
                if self.num_steps % self.print_loss_every == 0:
                    logging.info(
                        f"Step {self.num_steps} loss: {loss.item():.8f}\n >> MSE: {mse.item():.8f}, Contractive loss: {contractive_loss.item():.8f}"
                    )
                # Submit loss to tensorboard
                self.writer.add_scalar("[Train] Total loss", last_loss, self.num_steps)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if num_batches == 0:
            raise ValueError(
                "DataLoader yielded no batches; cannot compute an average loss."
            )
        return loss_sum / num_batches, mse_sum / num_batches, last_loss

    def _compute_loss(
        self, batch: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(total, mse, contractive)`` losses for one input batch."""
        recon, latents = self.model(batch)
        mse = nn.functional.mse_loss(recon, batch)
        contractive_loss = self.gamma * self.get_joint_jacobian_loss(latents=latents)
        return mse + contractive_loss, mse, contractive_loss

    def get_joint_jacobian_loss(self, latents: list[torch.Tensor]) -> torch.Tensor:
        """
        Calculate joint Jacobian for the entire encoder at once.
        Also applies the Frobenius norm in the end.

        ``latents[i]`` is paired with ``self.model.encoder[2 * i].weight``. The
        step of two assumes the encoder alternates weight layers and (tanh)
        activation modules — TODO confirm against DeepCAE. The per-layer
        Jacobians are chained as ``J_L @ ... @ J_2 @ J_1``.

        The norm is taken over every element of the chained Jacobian. If latents
        are (batch, features), this is sqrt(sum over samples of ||J_sample||_F^2),
        not a per-sample mean.

        Raises:
            ValueError: If ``latents`` is empty.
        """
        if not latents:
            raise ValueError("latents must contain at least one encoder activation")

        joint_jacobian = None
        for layer_idx, latent in enumerate(latents):
            # Step by two to skip the activation modules between weight layers.
            weights = self.model.encoder[2 * layer_idx].weight
            layer_jacobian = self.get_jacobian(latent=latent, weights=weights)
            # Chain rule: the deeper layer's Jacobian multiplies from the left.
            joint_jacobian = (
                layer_jacobian
                if joint_jacobian is None
                else layer_jacobian @ joint_jacobian
            )
        # Same value as the deprecated torch.norm(..., p="fro") over all elements.
        return torch.linalg.vector_norm(joint_jacobian)

    def get_jacobian(self, latent: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """
        Jacobian of a tanh layer w.r.t. its input, given the layer output ``latent``.

        Uses the trick from the DeepCAE paper, which applies to TanH activations:
        d tanh(z)/dz = 1 - tanh(z)^2, so the Jacobian is ``diag(1 - latent**2) @ W``.
        Scaling the rows of ``W`` by broadcasting gives the same values without
        materialising the diagonal matrix.
        """
        return (1 - latent**2).unsqueeze(-1) * weights
