from typing import Tuple

import numpy as np
import pandas as pd
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import MultiTaskGP, SingleTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import InputTransform
from botorch.posteriors import GPyTorchPosterior
from gpytorch.distributions import MultivariateNormal
from gpytorch.mlls import ExactMarginalLogLikelihood


class SimpleGPProxyModel:
    """Single-task GP proxy: fits a ``SingleTaskGP`` on one DataFrame of observations."""

    def __init__(
        self,
        input_names: list,
        target_col: str,
        input_transform: InputTransform,
    ):
        """
        Initializes a Gaussian Process model.

        Args:
            input_names: Names of the feature columns used as model inputs.
            target_col: Name of the target column to optimize.
            input_transform: Input transform passed to ``SingleTaskGP``.
        """
        self.model = None  # populated by ``fit``
        self.input_transform = input_transform
        self.target_col = target_col
        self.input_names = input_names
        self.multi_task = False

    def fit(self, train_df: pd.DataFrame) -> SingleTaskGP:
        """
        Fits a Gaussian Process model to the training data.

        Hyperparameters are fitted by maximizing the exact marginal log
        likelihood via ``fit_gpytorch_mll``.

        Args:
            train_df: DataFrame containing the ``input_names`` feature columns
                and the target column.

        Returns:
            The fitted ``SingleTaskGP`` (also stored on ``self.model``).
        """
        train_X = torch.tensor(train_df[self.input_names].values, dtype=torch.float64)
        train_Y = torch.tensor(
            train_df[self.target_col].values.reshape(-1, 1), dtype=torch.float64
        )

        self.model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_Y,
            input_transform=self.input_transform,
        )
        mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
        fit_gpytorch_mll(mll=mll)
        return self.model

    def fit_transform(
        self, train_df: pd.DataFrame, test_df: pd.DataFrame
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fits a Gaussian Process model to the training data.
        Returns posterior mean and variance on the test data.

        Args:
            train_df: DataFrame containing features and target column.
            test_df: DataFrame containing the ``input_names`` feature columns
                for predictions.

        Returns:
            Tuple of (posterior_mean, posterior_variance) as numpy arrays, in
            the shape produced by the model posterior.
        """
        self.fit(train_df)
        # Match the float64 dtype used for the training tensors; otherwise
        # integer or float32 feature columns yield a tensor whose dtype does
        # not match the model.
        test_X = torch.tensor(test_df[self.input_names].values, dtype=torch.float64)
        posterior = self.model.posterior(test_X)
        mean, var = posterior.mean.detach().numpy(), posterior.variance.detach().numpy()
        return mean, var


class MultiTaskGPProxyModel:
    """Two-task GP proxy combining prior data (task 0) with target data (task 1).

    ``input_transform`` is applied to the feature columns manually (it is not
    handed to ``MultiTaskGP``), and the task index is appended as the last
    input column.
    """

    def __init__(
        self,
        input_names: list,
        target_col: str,
        input_transform: InputTransform,
    ):
        """
        Initializes a MultiTask Gaussian Process model.

        Args:
            input_names: Names of the feature columns used as model inputs.
            target_col: Name of the target column to optimize.
            input_transform: Transform applied to the feature columns before
                the task-index column is appended.
        """
        self.model = None  # populated by ``fit``
        self.input_transform = input_transform
        self.target_col = target_col
        self.input_names = input_names
        self.multi_task = True

    def fit(self, train_df: pd.DataFrame, train_prior_df: pd.DataFrame) -> MultiTaskGP:
        """
        Fits a two-task Gaussian Process model.

        Rows of ``train_prior_df`` are assigned task 0 and rows of ``train_df``
        task 1; the task index is appended as the last input column and
        hyperparameters are fitted with ``fit_gpytorch_mll``.

        Args:
            train_df: DataFrame with features and target column for the target
                task (task 1).
            train_prior_df: DataFrame with features and target column for the
                prior task (task 0).

        Returns:
            The fitted ``MultiTaskGP`` (also stored on ``self.model``).
        """
        train_X0 = torch.tensor(
            train_prior_df[self.input_names].values, dtype=torch.float64
        )
        train_X1 = torch.tensor(train_df[self.input_names].values, dtype=torch.float64)
        train_Y0 = torch.tensor(
            train_prior_df[self.target_col].values.reshape(-1, 1), dtype=torch.float64
        )
        train_Y1 = torch.tensor(
            train_df[self.target_col].values.reshape(-1, 1), dtype=torch.float64
        )
        N0, N1 = train_X0.shape[0], train_X1.shape[0]

        # Transform both tasks in a single call. Transforms that learn their
        # coefficients while in training mode (e.g. botorch ``Normalize`` with
        # learned bounds) re-learn them on each call, so transforming the tasks
        # separately would scale them with different coefficients.
        train_X_feat = self.input_transform(torch.cat([train_X0, train_X1]))

        # Task-index column (last feature): 0 for prior rows, 1 for target rows.
        # Created with the features' dtype/device instead of torch's float32 default.
        opts = {"dtype": train_X_feat.dtype, "device": train_X_feat.device}
        task = torch.cat([torch.zeros(N0, 1, **opts), torch.ones(N1, 1, **opts)])
        train_X = torch.cat([train_X_feat, task], dim=-1)
        train_Y = torch.cat([train_Y0, train_Y1])
        self.model = MultiTaskGP(
            train_X=train_X,
            train_Y=train_Y,
            task_feature=-1,
        )
        mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
        fit_gpytorch_mll(mll=mll)
        return self.model

    def fit_transform(
        self,
        train_df: pd.DataFrame,
        train_prior_df: pd.DataFrame,
        test_df: pd.DataFrame,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fits the two-task model, then returns posterior mean and variance of
        the target task (task 1) on the test data.

        Args:
            train_df: DataFrame with features and target column for the target
                task.
            train_prior_df: DataFrame with features and target column for the
                prior task.
            test_df: DataFrame containing the ``input_names`` feature columns
                for predictions.

        Returns:
            Tuple of (posterior_mean, posterior_variance) as numpy arrays.
        """
        self.fit(train_df, train_prior_df)
        test_X = torch.tensor(test_df[self.input_names].values, dtype=torch.float64)
        # Apply the transform in eval mode so coefficients learned from the
        # training data are reused instead of being re-learned from the test
        # points (this mirrors how botorch models apply input transforms at
        # prediction time). The caller's train/eval mode is restored afterwards.
        was_training = self.input_transform.training
        self.input_transform.eval()
        try:
            test_X = self.input_transform(test_X)
        finally:
            self.input_transform.train(was_training)
        # Every test point is queried for the target task (index 1).
        task = torch.ones(test_X.shape[0], 1, dtype=test_X.dtype, device=test_X.device)
        test_X = torch.cat([test_X, task], dim=-1)
        posterior = self.model.posterior(test_X)
        mean, var = posterior.mean.detach().numpy(), posterior.variance.detach().numpy()
        return mean, var


class GroundedPairwiseGP(PairwiseGP):
    """PairwiseGP that conditions the posterior on f(x_ref) = 0.

    ``x_ref`` is the datapoint at index ``ref_idx`` (along the point dimension)
    of the ``datapoints`` given at construction. Each posterior is computed
    jointly with ``x_ref`` and then conditioned on the latent value at
    ``x_ref`` being exactly zero.
    """

    def __init__(self, datapoints, comparisons, ref_idx=0, **kwargs):
        """
        Args:
            datapoints: Datapoints tensor forwarded to ``PairwiseGP``; its
                second-to-last dimension indexes points.
            comparisons: Comparisons forwarded to ``PairwiseGP``.
            ref_idx: Index of the reference datapoint; negative values count
                from the end.
            **kwargs: Forwarded to ``PairwiseGP.__init__``.
        """
        super().__init__(datapoints, comparisons, **kwargs)
        self.ref_idx = ref_idx
        # Integer-index the point dimension, then re-insert it with unsqueeze.
        # A ``ref_idx : ref_idx + 1`` slice would be empty for ref_idx == -1,
        # and an out-of-range index now raises IndexError instead of silently
        # producing an empty reference point.
        x_ref = datapoints[..., ref_idx, :].unsqueeze(-2)
        self.register_buffer("x_ref", x_ref.clone())

    def posterior(
        self, X, output_indices=None, observation_noise=False, posterior_transform=None
    ):
        """Posterior at ``X`` conditioned on f(x_ref) = 0.

        Args:
            X: Tensor of shape ``batch_shape x n x d``.
            output_indices: Forwarded to ``PairwiseGP.posterior``.
            observation_noise: Forwarded to ``PairwiseGP.posterior``.
            posterior_transform: If given, applied to the conditioned
                posterior (never to the intermediate joint posterior).

        Returns:
            A ``GPyTorchPosterior`` over the ``n`` points of ``X``, or the
            result of ``posterior_transform`` applied to it.
        """
        batch_shape = X.shape[:-2]
        d = X.shape[-1]
        n = X.shape[-2]

        # Append x_ref as an extra (n+1)-th point so its joint covariance
        # with X is available.
        x_ref_expanded = self.x_ref.expand(*batch_shape, 1, d)
        X_aug = torch.cat([X, x_ref_expanded], dim=-2)

        joint = super().posterior(
            X_aug, output_indices, observation_noise, posterior_transform=None
        )
        mu = joint.mean
        Sigma = joint.covariance_matrix

        # Drop a trailing singleton output dimension so mu is indexed by point.
        if mu.shape[-1] == 1:
            mu = mu.squeeze(-1)

        # Partition the joint Gaussian into the X block (first n points) and
        # the x_ref block (last point).
        mu_X = mu[..., :n]
        mu_ref = mu[..., n : n + 1]
        Sigma_XX = Sigma[..., :n, :n]
        Sigma_Xr = Sigma[..., :n, n : n + 1]
        Sigma_rr = Sigma[..., n : n + 1, n : n + 1]

        # Gaussian conditioning on the observation f(x_ref) = 0:
        #   mu_cond    = mu_X + Sigma_Xr / Sigma_rr * (0 - mu_ref)
        #   Sigma_cond = Sigma_XX - Sigma_Xr Sigma_rX / Sigma_rr
        mu_cond = mu_X - (Sigma_Xr / Sigma_rr).squeeze(-1) * mu_ref
        Sigma_cond = Sigma_XX - Sigma_Xr @ Sigma_Xr.transpose(-1, -2) / Sigma_rr

        # Jitter keeps Sigma_cond positive definite; the conditional variance
        # is exactly zero at any row of X equal to x_ref.
        eye = torch.eye(n, dtype=Sigma.dtype, device=Sigma.device)
        eye = eye.expand_as(Sigma_cond)
        Sigma_cond = Sigma_cond + self._jitter * eye

        cond_posterior = GPyTorchPosterior(MultivariateNormal(mu_cond, Sigma_cond))

        if posterior_transform is not None:
            return posterior_transform(cond_posterior)
        return cond_posterior
