import torch
from tqdm import trange
from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
from gpytorch.settings import cholesky_jitter
from torch.optim import Adam

from erl_lib.model_based.modules.gaussian_process import enable_cholesky, GPsModel


class GPsTrainer:
    """Fit the hyperparameters of a ``GPsModel`` with Adam and early stopping.

    Each call to :meth:`train` draws one batch from the training iterator and
    one from the evaluation iterator. It performs full-batch optimizer steps on
    the training batch and finally restores the parameters that achieved the
    lowest evaluation score (lower is better).
    """

    def __init__(self, model: "GPsModel", lr, silent=True, logger=None):
        """
        Args:
            model: Wrapper holding the GP (``model.model``), its likelihood
                (``model.likelihood``) and the Cholesky jitter to use.
            lr: Adam learning rate.
            silent: Disable the tqdm progress bar when True.
            logger: Optional object with an ``append(name, index, info)``
                method; when given, a training summary is logged per call.
        """
        self.model = model
        self.jitter = model.jitter

        self.lr = lr
        self.silent = silent
        self.logger = logger

    @staticmethod
    def _safe_inv_transform(constraint, value, min_gap=1e-6):
        """Map ``value`` to raw (unconstrained) space through ``constraint``.

        ``value`` is first nudged strictly inside the constraint's
        ``[lower_bound, upper_bound]`` range, when the constraint exposes those
        attributes. Without this, a target that sits on or outside a bound
        would inverse-transform to ``-inf``/``NaN`` for softplus/sigmoid-style
        constraints and poison training. Examples are a zero standard deviation
        from a constant input column, or a noise target equal to the lower bound.
        """
        lower = getattr(constraint, "lower_bound", None)
        if lower is not None:
            lower = torch.as_tensor(lower, dtype=value.dtype, device=value.device)
            value = torch.maximum(value, lower + min_gap)
        upper = getattr(constraint, "upper_bound", None)
        if upper is not None:
            # An infinite upper bound leaves value untouched (inf - gap == inf).
            upper = torch.as_tensor(upper, dtype=value.dtype, device=value.device)
            value = torch.minimum(value, upper - min_gap)
        return constraint._inv_transform(value)

    @staticmethod
    def smart_param_init(train_x_tensor, train_y_tensor, base_model, likelihood):
        """Data-driven initialization of lengthscale, outputscale and noise.

        The raw (unconstrained) parameters are overwritten in place.

        Args:
            train_x_tensor: Inputs; statistics are taken over dim 0 and
                ``shape[1]`` is the input dimension.
            train_y_tensor: Targets; statistics are taken over dim 0 and
                ``shape[1]`` is the number of outputs.
            base_model: GP whose ``covar_module.base_kernel.raw_lengthscale``
                is written with shape ``(dim, 1, dim_inp)`` and whose
                ``covar_module.raw_outputscale`` holds one entry per output.
            likelihood: Likelihood with a shared ``raw_noise`` and per-output
                ``raw_task_noises``.
        """
        dim = train_y_tensor.shape[1]
        dim_inp = train_x_tensor.shape[1]
        inv = GPsTrainer._safe_inv_transform
        covar_module = base_model.covar_module
        kernel = covar_module.base_kernel
        y_var = train_y_tensor.var(dim=0)

        # Lengthscales: half of the per-input-dimension std of the inputs,
        # replicated for every output model.
        lengthscale = inv(kernel.raw_lengthscale_constraint, train_x_tensor.std(dim=0) / 2)
        kernel.raw_lengthscale.data.copy_(
            lengthscale.repeat([dim, 1]).reshape([dim, 1, dim_inp])
        )
        # Outputscale: the per-output variance of the targets.
        covar_module.raw_outputscale.data.copy_(
            inv(covar_module.raw_outputscale_constraint, y_var)
        )
        # The noise has two components. The shared ``raw_noise`` is added to
        # all outputs; it is initialized to 1/10000 of the mean output variance.
        # The per-output ``raw_task_noises`` are initialized to 1/100 of each
        # output's variance. Ideally the shared parameter would not exist at
        # all, but it seems to be added in GPyTorch by default.
        likelihood.raw_noise.data.copy_(
            inv(likelihood.raw_noise_constraint, y_var.mean() / 10000)
        )
        likelihood.raw_task_noises.data.copy_(
            inv(likelihood.raw_task_noises_constraint, y_var / 100)
        )

    def train(
        self,
        dataset_train,
        dataset_eval,
        env_step,
        num_max_epochs: int,
        keep_epochs=None,
    ):
        """Optimize the GP on one training batch with early stopping.

        Args:
            dataset_train: Iterator of training batches; one batch is drawn.
            dataset_eval: Iterator of evaluation batches; one batch is drawn.
            env_step: Environment step, used only as the logging index.
            num_max_epochs: Maximum number of full-batch optimizer steps.
            keep_epochs: Number of consecutive non-improving epochs tolerated
                before stopping. ``None`` disables early stopping.

        On return the model holds the parameters with the lowest evaluation
        score seen, including the freshly initialized ones, and is in eval mode.
        """
        batch = next(dataset_train)
        batch_eval = next(dataset_eval)
        x_train, y_train = self.model.process_batch(batch)

        if self.model.model is None:
            self.model.init_model(x_train, y_train)
            assert self.model.model is not None
        # Smart initialization of lengthscale, outputscale and noise
        # hyperparameters. This step is crucial. It runs on every call, so
        # hyperparameters learned by a previous call are reset here.
        self.smart_param_init(x_train, y_train, self.model.model, self.model.likelihood)
        base_model = self.model.model
        likelihood = self.model.likelihood

        # Build a fresh optimizer on every call.
        params = [{"params": base_model.parameters()}]
        if self.model.is_sparse:
            # NOTE(review): presumably only the sparse (variational) model does
            # not already own the likelihood's parameters.
            params.append({"params": likelihood.parameters()})
        optimizer = Adam(params, lr=self.lr)

        # ``None`` means "never stop early". Infinite patience expresses that
        # without the stopper ever comparing against None.
        patience = float("inf") if keep_epochs is None else keep_epochs

        iterator = trange(
            num_max_epochs, mininterval=5, disable=self.silent, desc="[Model]"
        )

        losses, eval_scores = [], []

        with enable_cholesky(), cholesky_jitter(self.jitter):
            # Loss for the GP. The exact GP receives the new training data
            # *before* the baseline evaluation, so that the baseline and the
            # per-epoch scores use the same conditioning data.
            if self.model.is_sparse:
                mll = VariationalELBO(likelihood, base_model, num_data=x_train.shape[0])
            else:
                mll = ExactMarginalLogLikelihood(likelihood, base_model)
                base_model.set_train_data(x_train, y_train, strict=False)

            # Baseline: score of the freshly initialized hyperparameters.
            self.model.eval()
            with torch.no_grad():
                eval_score = self.model.eval_score(batch_eval)
            early_stopper = EarlyStopper(self.model, eval_score, patience)
            best_state_dict = early_stopper.best_state_dict
            best_epoch = early_stopper.best_epoch

            with iterator as pbar:
                for _ in pbar:
                    self.model.train()
                    pred = base_model(x_train)
                    loss = -mll(pred, y_train)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    self.model.eval()
                    with torch.no_grad():
                        eval_score = self.model.eval_score(batch_eval)
                    stop, best_state_dict, best_epoch = early_stopper.step(eval_score)

                    pbar.set_postfix(
                        {
                            "Train": float(loss),
                            "Val": float(eval_score),
                            "Best": best_epoch,
                        },
                        refresh=False,
                    )
                    losses.append(loss.detach())
                    eval_scores.append(eval_score)

                    if stop:
                        break

        # NOTE(review): passing through train() before restoring presumably
        # discards eval-mode prediction caches built with the last epoch's
        # hyperparameters; confirm against the GPyTorch version in use.
        self.model.train()
        self.model.load_state_dict(best_state_dict)
        self.model.eval()
        if self.logger is not None:
            index = {"env_step": env_step}

            info = self.model.model_state()
            # With num_max_epochs == 0 no loss exists; eval_score then still
            # holds the baseline score.
            info["train_loss"] = losses[-1] if losses else float("nan")
            info["eval_score"] = eval_score
            info["best_epoch"] = best_epoch
            # Number of epochs actually run: the same 1-based count as
            # best_epoch, so best_epoch <= last_epoch always holds.
            info["last_epoch"] = early_stopper.epoch
            self.logger.append("model_train", index, info)


class EarlyStopper:
    """Track the best evaluation score and decide when training should stop.

    Lower scores are better. A cloned copy of ``model.state_dict()`` is kept
    for the best score seen so far; initially that is the score passed to the
    constructor, paired with the model's current parameters.
    """

    def __init__(self, model, eval_score, keep_epochs=1):
        """
        Args:
            model: Module whose ``state_dict()`` is snapshotted on improvement.
            eval_score: Baseline score of the model's current parameters.
            keep_epochs: Number of consecutive non-improving steps tolerated;
                ``step`` reports ``stop=True`` once this is exceeded. ``None``
                disables early stopping.
        """
        self.model = model
        self.best_eval_score = eval_score
        self.keep_epochs = keep_epochs

        self.best_state_dict = {}
        self.clone_state_dict()
        self.epoch_kept = 0
        self.best_epoch = 0
        self.epoch = 0

    def step(self, eval_score):
        """Record the evaluation score of one more epoch.

        Returns:
            ``(stop, best_state_dict, best_epoch)``. ``best_epoch`` is a
            1-based epoch count; 0 means the baseline is still the best.
        """
        self.epoch += 1

        if self._is_improvement(eval_score):
            self.clone_state_dict()
            self.best_eval_score = eval_score
            self.best_epoch = self.epoch
            self.epoch_kept = 0
        else:
            self.epoch_kept += 1

        stop = self.keep_epochs is not None and self.epoch_kept > self.keep_epochs
        return stop, self.best_state_dict, self.best_epoch

    def _is_improvement(self, eval_score):
        """Return True if ``eval_score`` beats the best score so far.

        A NaN best score (e.g. a failed baseline evaluation) is beaten by any
        non-NaN score; otherwise plain ``<`` decides, so a NaN ``eval_score``
        never counts as an improvement.
        """
        import math

        if math.isnan(float(self.best_eval_score)):
            return not math.isnan(float(eval_score))
        return bool(eval_score < self.best_eval_score)

    def clone_state_dict(self):
        """Snapshot the model's current state dict into ``best_state_dict``."""
        for key, value in self.model.state_dict().items():
            self.best_state_dict[key] = value.clone()
