import torch
import gpytorch
import math
import copy
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils.Kendall import calculate_kendall_tau


class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled ARD Matern kernel.

    The Matern kernel uses gpytorch's default smoothness and one lengthscale
    per input dimension (``ard_num_dims``), wrapped in a ``ScaleKernel``.

    Args:
        train_x: Training inputs. The size of the last axis is taken as the
            input dimensionality ``d``; a 1-D tensor is treated as ``n``
            one-dimensional points (``d = 1``).
        train_y: Training targets, forwarded to ``ExactGP``.
        likelihood: Likelihood module, forwarded to ``ExactGP``.
        use_robust_init: When True, every lengthscale is initialised to
            ``sqrt(d)`` instead of the kernel's default initial value.
    """

    def __init__(self, train_x, train_y, likelihood, use_robust_init=True):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        # A 1-D train_x holds n scalar inputs, not a single n-dimensional
        # point, so its length must not be used as the number of ARD dims.
        d = train_x.shape[-1] if train_x.dim() > 1 else 1
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.MaternKernel(ard_num_dims=d)
        )

        if use_robust_init:
            # Scale the starting lengthscale with sqrt(d) so it grows with
            # the input dimensionality.
            initial_lengthscale = math.sqrt(d)
            self.covar_module.base_kernel.lengthscale = (
                initial_lengthscale * torch.ones(1, d).to(train_x.device)
            )

    def forward(self, x):
        """Return the GP prior at ``x`` as a ``MultivariateNormal``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class GPModel:
    def __init__(self, device=None):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.model = None
        self.likelihood = None
        self.train_x = None
        self.train_y = None

        self._calib_mean = None
        self._calib_std = None
        self._use_calib = False

    def fit(self, X, y, n_iter=100, lr=0.1, n_restarts=5, use_robust_init=False):
        if not isinstance(X, torch.Tensor):
            X = torch.tensor(X, dtype=torch.float32)
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.float32)

        self.train_x = X.to(self.device)
        self.train_y = y.to(self.device)

        best_loss = float("inf")
        best_model_state = None
        best_likelihood_state = None

        print(f"Starting training with {n_restarts} restarts...")

        for restart in range(n_restarts):
            likelihood = gpytorch.likelihoods.GaussianLikelihood().to(self.device)
            model = ExactGPModel(
                self.train_x, self.train_y, likelihood, use_robust_init=use_robust_init
            ).to(self.device)

            if not use_robust_init:
                # robust init
                d = self.train_x.shape[-1]
                rand_val = torch.rand(1, d).to(self.device)  # [0, 1]
                new_lengthscale = (0.5 + 1.5 * rand_val) * math.sqrt(d)
                model.covar_module.base_kernel.lengthscale = new_lengthscale
            else:
                pass

            model.train()
            likelihood.train()

            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

            final_loss = 0.0

            for i in range(n_iter):
                optimizer.zero_grad()
                output = model(self.train_x)
                # Add extra jitter to handle numerical instability with very small noise
                with gpytorch.settings.cholesky_jitter(1e-3):
                    loss = -mll(output, self.train_y)
                loss.backward()
                optimizer.step()

                # Constrain noise to be at least 1e-5 to prevent singular matrices
                # This is critical when data is noiseless (or user sets noise_std=0)
                if hasattr(model.likelihood, "noise_covar"):
                    model.likelihood.noise_covar.raw_noise.data.clamp_(
                        min=-10.0
                    )  # roughly corresponds to noise > 1e-5

                final_loss = loss.item()

                print_interval = max(1, n_iter // 5)
                if (i + 1) % print_interval == 0:
                    print(
                        f"[Restart {restart + 1}/{n_restarts}] Iter {i + 1}/{n_iter} - Loss: {final_loss:.3f}"
                    )

            print(f"Restart {restart + 1} finished with Loss: {final_loss:.3f}")

            # Check if this is the best model
            if final_loss < best_loss:
                best_loss = final_loss
                best_model_state = copy.deepcopy(model.state_dict())
                best_likelihood_state = copy.deepcopy(likelihood.state_dict())

        print(f"Best Loss after {n_restarts} restarts: {best_loss:.3f}")

        # Load best model
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood().to(self.device)
        self.model = ExactGPModel(
            self.train_x, self.train_y, self.likelihood, use_robust_init=use_robust_init
        ).to(self.device)
        self.model.load_state_dict(best_model_state)
        self.likelihood.load_state_dict(best_likelihood_state)

    # @torch.no_grad()
    def set_prediction_calibration(self, X_calib):
        if self.model is None:
            raise RuntimeError(
                "Model has not been trained yet. Please call fit() first."
            )

        if not isinstance(X_calib, torch.Tensor):
            X_calib = torch.tensor(X_calib, dtype=torch.float32)
        X_calib = X_calib.to(self.device)

        mean_c, _, _ = self.predict(X_calib, normalize=False)

        mu_calib = mean_c.mean()
        var_calib = mean_c.var(unbiased=True)

        eps = torch.tensor(1e-12, device=self.device, dtype=mean_c.dtype)
        var_calib = torch.clamp(var_calib, min=eps)
        sigma_calib = torch.sqrt(var_calib)

        self._calib_mean = mu_calib.detach()
        self._calib_std = sigma_calib.detach()
        self._use_calib = True

    def predict(self, X_test, normalize: bool = True, enable_grad: bool = True):
        if self.model is None:
            raise RuntimeError(
                "Model has not been trained yet. Please call fit() first."
            )

        if not isinstance(X_test, torch.Tensor):
            X_test = torch.tensor(X_test, dtype=torch.float32)

        X_test = X_test.to(self.device)

        self.model.eval()
        self.likelihood.eval()

        with (
            torch.set_grad_enabled(bool(enable_grad)),
            gpytorch.settings.fast_pred_var(),
        ):
            observed_pred = self.likelihood(self.model(X_test))
            mean = observed_pred.mean
            std = observed_pred.stddev
            cov = observed_pred.covariance_matrix

        if normalize and self._use_calib:
            mu_calib = self._calib_mean.to(self.device, dtype=mean.dtype)
            sigma_calib = self._calib_std.to(self.device, dtype=mean.dtype)

            var_calib = sigma_calib**2
            mean = (mean - mu_calib) / sigma_calib
            std = std / sigma_calib
            cov = cov / var_calib

        return mean, std, cov

    def calculate_mre(self, X_test, y_test):
        mean, _, _ = self.predict(X_test)

        if not isinstance(y_test, torch.Tensor):
            y_test = torch.tensor(y_test, dtype=torch.float32)
        y_test = y_test.to(self.device)

        epsilon = 1e-8
        relative_errors = torch.abs((y_test - mean) / (y_test + epsilon))
        mre = torch.mean(relative_errors).item()
        return mre
