import logging
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from syne_tune import Reporter
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from benchmarking.utils.experiment import Experiment
from models.data_utils.util import get_device
from .model import TransformerAE
from ...train_utils.utils import EarlyStopper, save_checkpoint


class TransformerAETrainer:
    """
    Trainer for the transformer autoencoder.

    Trains the model with Adam on an MSE reconstruction objective and
    evaluates it on a held-out loader after every epoch. Optionally it reports
    to syne-tune, an ``Experiment`` record, TensorBoard and an early stopper,
    and it can write checkpoints.
    """

    def __init__(self, model: TransformerAE, print_loss_every: int = 10) -> None:
        """
        Args:
            model: The autoencoder to train. It is moved to the default CUDA
                device immediately when CUDA is available.
            print_loss_every: Log the training loss every this many optimizer
                steps.
        """
        self.use_cuda = torch.cuda.is_available()
        self.model = model
        self.print_loss_every = print_loss_every
        # Global optimizer-step counter. It is not reset between calls to
        # ``train``, so step indices keep increasing across runs.
        self.num_steps = 0
        self.writer = SummaryWriter()

        if self.use_cuda:
            self.model.cuda()

    def train(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader,
        epochs: int = 30,
        lr: float = 1e-4,
        hyperparametertuning: bool = False,
        experiment: Optional[Experiment] = None,
        early_stopping: bool = True,
        patience: int = 30,
        min_delta: float = 1.02,
        checkpointing: bool = True,
        device_id: Optional[int] = None,
    ) -> None:
        """
        Train the model, evaluating on ``test_loader`` after every epoch.

        Args:
            train_loader: Yields batches of input tensors. Each batch is fed to
                the model as ``batch.unsqueeze(-1)`` and is also the
                reconstruction target.
            test_loader: Held-out batches in the same format as
                ``train_loader``.
            epochs: Maximum number of epochs to run.
            lr: Adam learning rate.
            hyperparametertuning: If True, report the mean test loss to
                syne-tune after every epoch.
            experiment: If given, append per-epoch mean train and test losses
                to ``experiment.train_loss_mse`` and
                ``experiment.test_loss_mse``. Otherwise the losses are logged.
            early_stopping: Enable the ``EarlyStopper``.
            patience: Passed to ``EarlyStopper``.
            min_delta: Passed to ``EarlyStopper``. Its interpretation is
                defined there; NOTE(review): the default of 1.02 suggests a
                multiplicative tolerance — confirm.
            checkpointing: Call ``save_checkpoint`` after every epoch.
            device_id: Forwarded to ``get_device`` and ``save_checkpoint``.

        Raises:
            ValueError: If either loader yields no batches.

        The TensorBoard writer is flushed and closed when training ends, even
        if an exception is raised.
        """
        device = get_device(device_id)
        model = self.model
        model.to(device)
        model.train()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        # Performance reporting to syne-tune.
        report = Reporter() if hyperparametertuning else None
        early_stopper = (
            EarlyStopper(patience=patience, min_delta=min_delta)
            if early_stopping
            else None
        )

        try:
            for epoch in range(epochs):
                train_loss = self._train_epoch(
                    model, train_loader, optimizer, criterion, device
                )
                self.writer.add_scalar(
                    "[Train] Total epoch loss", train_loss, epoch + 1
                )

                test_loss, last_test_batch_loss = self._evaluate(
                    model, test_loader, criterion, device
                )

                # Report the mean test loss to syne-tune.
                if report is not None:
                    report(step=epoch, mean_loss=test_loss, epoch=epoch + 1)

                # Report to the received experiment instance for model comparison.
                if experiment:
                    experiment.train_loss_mse.append(train_loss)
                    experiment.test_loss_mse.append(test_loss)
                else:
                    logging.info(
                        f"Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.6f}, Test Loss: {test_loss:.6f}\n"
                    )

                # This logs the loss of the *last* test batch only, not a
                # total, despite the tag name; the tag is kept unchanged so
                # existing dashboards keep working.
                self.writer.add_scalar(
                    "[Test] Total loss", last_test_batch_loss, epoch + 1
                )
                self.writer.add_scalar(
                    "[Test] Total epoch loss", test_loss, epoch + 1
                )

                # Both early-stopper calls receive the same per-batch mean, so
                # they compare values on one scale.
                if checkpointing:
                    save_checkpoint(
                        model=model,
                        experiment=experiment,
                        device_id=device_id,
                        epoch=epoch,
                        best=(
                            early_stopper.is_best(test_loss)
                            if early_stopper is not None
                            else None
                        ),
                    )

                if early_stopper is not None and early_stopper.early_stop(test_loss):
                    logging.info("Training was stopped early.")
                    break
        finally:
            # Finally close the tensorboard writer.
            self.writer.flush()
            self.writer.close()

    def _train_epoch(
        self,
        model: nn.Module,
        loader: DataLoader,
        optimizer: optim.Optimizer,
        criterion: nn.Module,
        device,
    ) -> float:
        """Run one gradient-update pass over ``loader`` and return the mean batch loss."""
        model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in loader:
            batch = batch.to(device)

            optimizer.zero_grad()
            recon, _ = model(batch.unsqueeze(-1))
            loss = criterion(recon, batch)
            loss.backward()
            optimizer.step()

            # Take a Python float so this batch's autograd graph is not
            # retained by the running sum.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            self.num_steps += 1

            # Print loss if step is eligible.
            if self.num_steps % self.print_loss_every == 0:
                logging.info(f"Step {self.num_steps} loss: {batch_loss:.5f}")

            # Submit loss to tensor board.
            self.writer.add_scalar("[Train] Total loss", batch_loss, self.num_steps)

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        return total_loss / num_batches

    @torch.no_grad()
    def _evaluate(
        self,
        model: nn.Module,
        loader: DataLoader,
        criterion: nn.Module,
        device,
    ) -> Tuple[float, float]:
        """Return ``(mean batch loss, last batch loss)`` over ``loader`` without gradients.

        The model is put in eval mode for the pass, and its previous
        train/eval mode is restored afterwards.
        """
        was_training = model.training
        model.eval()
        total_loss = 0.0
        num_batches = 0
        last_loss = 0.0
        try:
            for batch in loader:
                batch = batch.to(device)
                recon, _ = model(batch.unsqueeze(-1))
                last_loss = criterion(recon, batch).item()
                total_loss += last_loss
                num_batches += 1
        finally:
            model.train(was_training)

        if num_batches == 0:
            raise ValueError("test_loader yielded no batches")
        return total_loss / num_batches, last_loss
