from __future__ import annotations

"""Acquisition‑function base class with EIG variants and baseline."""

import logging
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass

import scipy
import torch
from gpytorch import inv_quad

from xac.applications.applications import BaseApplication
from xac.surrogates.gp_surrogate import GPSurrogate

from linear_operator import to_linear_operator

log = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Abstract base
# -----------------------------------------------------------------------------


@dataclass(frozen=True)
class BaseAcquisitionFunction(ABC):
    """Base class for acquisition functions.

    Concrete subclasses implement ``__call__`` (one utility per candidate) and
    ``plot_name``. The signature below matches the one every concrete
    subclass in this module implements.
    """

    @abstractmethod
    def __call__(
        self,
        X: torch.Tensor,  # (S × D)
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor:  # (S,)
        """Return a utility value for every candidate in *X*.

        Args:
            X (torch.Tensor): A tensor of shape (s, d) representing s candidate configurations in d dimensions.
            candidate_idx_Z (torch.Tensor): Indices into ``application.Z`` identifying the candidates in *X*.
                Implementations compute quantities over all of ``Z`` and select the candidates' entries with it.
            surrogate (GPSurrogate): A GP surrogate model that has been trained on the current data archive.
            application (BaseApplication): The application under study (implementations read e.g.
                ``application.Z`` and ``application.A``).
            iteration_idx (int): Index of the current acquisition iteration.

        Returns:
            torch.Tensor: A tensor of shape (s,) containing the utility value for each candidate configuration.
        """

    @property
    @abstractmethod
    def plot_name(self) -> str:
        """A readable name used for plotting."""


# -----------------------------------------------------------------------------
# EIG variants
# -----------------------------------------------------------------------------


@dataclass(frozen=True)
class Random(BaseAcquisitionFunction):
    """Random baseline.

    Assigns utility 1 to a single, uniformly drawn candidate and 0 to every
    other candidate. A caller that picks the highest utility therefore selects
    a random candidate.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor:
        num_candidates = X.shape[0]
        # One uniform draw over the candidate positions.
        winner = torch.randint(0, num_candidates, (1,))

        one_hot = torch.zeros(num_candidates, dtype=X.dtype)
        one_hot[winner] = 1.0
        return one_hot

    @property
    def plot_name(self) -> str:
        return "Random"


# @dataclass(frozen=True)
# class EIGFunctionProperty(BaseAcquisitionFunction):
#     """Expected Information Gain (EIG) for the function property."""

#     def __call__(
#         self,
#         X: torch.Tensor,
#         candidate_idx_Z: torch.Tensor,
#         surrogate: GPSurrogate,
#         application: BaseApplication,
#     ) -> torch.Tensor:
#         # -----------------------------------------------------------------------------
#         # Vectorized EIG-FP computation
#         # -----------------------------------------------------------------------------
#         def _compute_EIG_FP(cov_fz, cov_yz):
#             try:
#                 # 2) Transform cov_fz to posterior covariance of property
#                 cov_prop = A @ cov_fz @ A.T  # (M × M)

#                 # 3) Decompose via Cholesky
#                 L = torch.linalg.cholesky(cov_prop, upper=False)  # (M x M)

#             except:
#                 # Backup variant with enhanced numerical stability (adds noise to diagonal for Cholesky)
#                 log.info(
#                     f"""Computing EIGFunctionProperty using covariance of F_Z did not work. Trying covariance of Y_Z (with noise on diagonal) for Cholesky instead."""
#                 )

#                 cov_prop = A @ cov_yz @ A.T  # (M × M)
#                 L = torch.linalg.cholesky(cov_prop, upper=False)  # (M x M)

#             # 4) Compute V
#             B = A @ cov_fz  # (M x S)
#             V = torch.linalg.solve_triangular(L, B, upper=False, left=True)  # (M X S)

#             # 5) Compute column-wise, squared 2-norm
#             V_norms = torch.norm(V, p=2, dim=0) ** 2

#             cov_yz_diag = cov_yz.detach().diag()

#             # 6) Compute EIG (for all in Z due to ease of implementation (and comparable cost) and filter to elements in X)
#             EIG_Z = torch.log(cov_yz_diag) - torch.log(cov_yz_diag - V_norms)

#             return EIG_Z

#         A = application.A
#         Z = application.Z

#         # Assert that all candidates (X) and elements in Z are unique
#         assert X.unique(dim=0).size(0) == X.size(0)
#         assert Z.unique(dim=0).size(0) == Z.size(0)

#         try:
#             # 1) Compute posterior covariance of F_Z and Y_Z
#             cov_fz = surrogate.forward(Z).covariance_matrix
#             cov_yz = surrogate.forward(Z, observation_noise=True).covariance_matrix

#             if cov_fz.ndim == 3:
#                 # Fully Bayesian setting
#                 EIG_Z_cols = []

#                 for i in range(cov_fz.shape[0]):
#                     EIG_Z_cols.append(_compute_EIG_FP(cov_fz[i, ...], cov_yz[i, ...]))

#                 # Compute mean
#                 EIG_Z = torch.stack(EIG_Z_cols).mean(dim=0)

#             else:
#                 # MLM/ MAP setting
#                 EIG_Z = _compute_EIG_FP(cov_fz, cov_yz)

#             EIG_X = EIG_Z[candidate_idx_Z]
#             return EIG_X

#         except Exception as e:
#             log.info(
#                 f"CAUTION! Exception occurred in EIGFunctionProperty: {e}. Returning uniform vector to avoid crashes."
#             )

#             return torch.zeros((X.shape[0]), dtype=X.dtype)

#     @property
#     def plot_name(self) -> str:
#         return "EIG-FP"


@dataclass(frozen=True)
class EIGFunctionProperty(BaseAcquisitionFunction):
    """Efficient implementation of the Expected Information Gain (EIG) for the function property.

    For every point ``z`` in ``application.Z`` the utility is

        log Var[y_z] - log(Var[y_z] - c_z),
        c_z = [Σ_f Aᵀ (A Σ_f Aᵀ)⁻¹ A Σ_f]_zz,

    i.e. the log-ratio of the predictive variance of the noisy observation
    ``y_z`` before and after conditioning on the linear property ``A f_Z``.
    ``Σ_f`` / ``Σ_y`` are the posterior covariances of ``f_Z`` (noise-free) and
    ``y_Z`` (with observation noise). Candidate utilities are selected with
    ``candidate_idx_Z``. In the batched (fully Bayesian) setting they are
    averaged over the leading batch dimension.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor:
        """Return the EIG-FP utility for every candidate.

        Args:
            X: Candidate configurations. Only their indices ``candidate_idx_Z`` into ``application.Z`` are used.
            candidate_idx_Z: Indices of the candidates within ``application.Z``.
            surrogate: Trained GP surrogate providing ``forward_lazy_covar``.
            application: Provides the property matrix ``A`` and the evaluation set ``Z``.
            iteration_idx: Unused; kept for the uniform acquisition-function signature.

        Returns:
            Utilities of the candidates, one per entry of ``candidate_idx_Z``.

        If inverting ``A Σ_f Aᵀ`` fails, the computation is retried with
        ``A Σ_y Aᵀ``, whose noise term on the diagonal improves conditioning.
        Exceptions raised by the retry propagate to the caller.
        """
        A = application.A  # presumably (M × |Z|), projecting f_Z onto the property — TODO confirm
        Z = application.Z

        lazy_covar_fz = surrogate.forward_lazy_covar(Z, observation_noise=False)
        lazy_covar_yz = surrogate.forward_lazy_covar(Z, observation_noise=True)

        def _compute_eig_fp(invert_noisy_covar: bool) -> torch.Tensor:
            """EIG for all of Z. If ``invert_noisy_covar``, regularize the inverted matrix with Σ_y."""
            covar_yz_diag = lazy_covar_yz.diagonal(dim1=-2, dim2=-1)  # Var[y_z]

            # Cov(y_Z, A f_Z) = Σ_f Aᵀ: the observation noise is independent of A f_Z,
            # so the noise-free covariance is used here in both variants.
            transformed_covar_fz = lazy_covar_fz.matmul(A.T)

            # Materialize after projecting: the (M × M) matrix is small, whereas keeping
            # the projection lazy led to OOM.
            if invert_noisy_covar:
                quad_form_covar = A @ lazy_covar_yz.matmul(A.T)
            else:
                quad_form_covar = A @ transformed_covar_fz

            correction_term = inv_quad(
                input=quad_form_covar,
                inv_quad_rhs=transformed_covar_fz.transpose(-2, -1),
                reduce_inv_quad=False,  # one quadratic form per column, i.e. per point in Z
            )

            return torch.log(covar_yz_diag) - torch.log(covar_yz_diag - correction_term)

        try:
            EIG_Z = _compute_eig_fp(invert_noisy_covar=False)
        except Exception as err:
            log.debug("EIGFunctionProperty primary computation failed: %s", err)
            # Backup variant with enhanced numerical stability (noise on the diagonal of the inverted matrix)
            log.info(
                f"""Computing EIGFunctionProperty using covariance of F_Z did not work. Trying covariance of Y_Z (with noise on diagonal) for Cholesky instead."""
            )
            EIG_Z = _compute_eig_fp(invert_noisy_covar=True)

        if EIG_Z.ndim == 2:
            # Batched (fully Bayesian) setting: average over the batch dimension.
            EIG_Z = EIG_Z.mean(dim=0)

        return EIG_Z[candidate_idx_Z]

    @property
    def plot_name(self) -> str:
        return "EIG-FP"


@dataclass(frozen=True)
class EPIG(BaseAcquisitionFunction):
    """Expected Predictive Information Gain (EPIG).

    For every point ``x`` in ``application.Z`` the utility is the average, over
    targets ``x*`` in ``Z``, of ``-0.5 * log(1 - rho(x, x*)**2)``. Here ``rho`` is
    the posterior correlation between the noisy observations ``y_x`` and
    ``y_x*``. Candidate utilities are selected with ``candidate_idx_Z``. In the
    batched (fully Bayesian) setting they are averaged over the leading batch
    dimension.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor:
        """Return the EPIG utility for every candidate.

        Args:
            X: Candidate configurations (only used for the shape/dtype of the fallback).
            candidate_idx_Z: Indices of the candidates within ``application.Z``.
            surrogate: Trained GP surrogate providing ``forward_lazy_covar``.
            application: Provides the evaluation/target set ``Z``.
            iteration_idx: Unused; kept for the uniform acquisition-function signature.

        Returns:
            Utilities of the candidates. If the computation raises, the exception
            is logged and a zero vector of length ``X.shape[0]`` is returned.
        """
        # Caution: the squared correlations are clamped away from 0 and 1, so
        # values for (near-)perfectly correlated pairs are approximate.
        Z = application.Z

        lazy_covar_yz = surrogate.forward_lazy_covar(Z, observation_noise=True)

        try:
            # Densify once; the element-wise operations below need the full matrix anyway,
            # and squaring a dense matrix avoids multiplying two lazy operators.
            if isinstance(lazy_covar_yz, torch.Tensor):
                covar_yz = lazy_covar_yz
            else:
                covar_yz = lazy_covar_yz.to_dense()

            covar_yz_diag = covar_yz.diagonal(dim1=-2, dim2=-1)  # (..., N)
            # Batched outer product; torch.outer only accepts 1-D inputs, which made the
            # fully Bayesian (batched) case fail.
            covar_yz_outer = covar_yz_diag.unsqueeze(-1) * covar_yz_diag.unsqueeze(-2)

            # The upper clamp must be representable in the working dtype: in float32,
            # 1 - 1e-12 rounds to 1.0 and log1p(-1) = -inf. For float64 this is 1 - 1e-12.
            tol = max(1e-12, torch.finfo(covar_yz.dtype).eps)
            corr_sq = (covar_yz.square() / covar_yz_outer).clamp(min=1e-12, max=1.0 - tol)

            # corr_sq is symmetric in its last two dims; average over the targets x*.
            EPIG_Z = -0.5 * torch.log1p(-corr_sq).mean(dim=-1)
            if EPIG_Z.ndim == 2:
                # Batched (fully Bayesian) setting: average over the batch dimension.
                EPIG_Z = EPIG_Z.mean(dim=0)

            return EPIG_Z[candidate_idx_Z]

        except Exception as e:
            log.info(
                f"CAUTION! Exception occurred in EPIG: {e}. Returning uniform vector to avoid crashes."
            )

            return torch.zeros((X.shape[0]), dtype=X.dtype)

    @property
    def plot_name(self) -> str:
        return "EPIG"


@dataclass(frozen=True)
class EIGExecutionPath(BaseAcquisitionFunction):
    """Expected Information Gain (EIG) for the execution path.

    The utility of each candidate is the noise-free posterior variance of ``f``
    at its point in ``application.Z``. In the batched (fully Bayesian) setting
    the variances are averaged over the leading batch dimension.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,  # provides Z
        iteration_idx: int
    ) -> torch.Tensor:
        """Return the posterior variance of every candidate.

        Args:
            X: Candidate configurations (only used for the shape/dtype of the fallback).
            candidate_idx_Z: Indices of the candidates within ``application.Z``.
            surrogate: Trained GP surrogate providing ``forward_lazy_covar``.
            application: Provides the evaluation set ``Z``.
            iteration_idx: Unused; kept for the uniform acquisition-function signature.

        Returns:
            Utilities of the candidates. If the computation raises, the exception
            is logged and a zero vector of length ``X.shape[0]`` is returned.
        """
        try:
            Z = application.Z
            lazy_covar_fz = surrogate.forward_lazy_covar(Z, observation_noise=False)

            # Diagonal of the last two dims: (N,) or, if batched, (b, N). The mean of the
            # diagonals equals the diagonal of the mean covariance, so a single code path
            # covers both settings without averaging the lazy operator itself.
            EIG_Z = lazy_covar_fz.diagonal(dim1=-2, dim2=-1)

            if EIG_Z.ndim == 2:
                EIG_Z = EIG_Z.mean(dim=0)

            return EIG_Z[candidate_idx_Z]

        except Exception as e:
            log.info(
                f"CAUTION! Exception occurred in EIGExecutionPath: {e}. Returning uniform vector to avoid crashes."
            )

            return torch.zeros((X.shape[0]), dtype=X.dtype)

    @property
    def plot_name(self) -> str:
        return "EIG-EP"

class SHAPIQAcquisitionFunction(BaseAcquisitionFunction):
    """Samples candidates according to SHAP-IQ implementations.

    This class computes no utilities: ``__call__`` always returns ``None``.
    NOTE(review): presumably the caller treats ``None`` as "let the SHAP-IQ
    sampler choose the next point" — confirm at the call site.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor | None:
        # First entry in candidate_idx_Z is always the next to sample according to SHAP-IQ sampling.
        # The unused one-hot utility vector that used to be built here was discarded (and raised
        # IndexError for an empty X), so nothing is computed.
        return None

    @property
    def plot_name(self) -> str:
        return "SHAP-IQ-Sampler"
    
class KernelSHAPSampler(SHAPIQAcquisitionFunction):
    """Samples candidates according to SHAP-IQ implementation of KernelSHAP.

    Placeholder acquisition function: ``__call__`` computes no utilities and
    always returns ``None``.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor | None:
        # NOTE(review): selection presumably happens in the SHAP-IQ sampler outside this class.
        return None

    @property
    def plot_name(self) -> str:
        return "KernelSHAP"
    
class LeverageSHAPSampler(SHAPIQAcquisitionFunction):
    """Samples candidates according to SHAP-IQ implementation of LeverageSHAP.

    Placeholder acquisition function: ``__call__`` computes no utilities and
    always returns ``None``.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor | None:
        # NOTE(review): selection presumably happens in the SHAP-IQ sampler outside this class.
        return None

    @property
    def plot_name(self) -> str:
        return "LeverageSHAP"
    
class SVARMSampler(SHAPIQAcquisitionFunction):
    """Samples candidates according to SHAP-IQ implementation of SVARM.

    Placeholder acquisition function: ``__call__`` computes no utilities and
    always returns ``None``.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor | None:
        # NOTE(review): selection presumably happens in the SHAP-IQ sampler outside this class.
        return None

    @property
    def plot_name(self) -> str:
        return "SVARM"
    
class PermutationSampler(SHAPIQAcquisitionFunction):
    """Samples candidates according to SHAP-IQ implementation of Permutation Sampling.

    Placeholder acquisition function: ``__call__`` computes no utilities and
    always returns ``None``.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor | None:
        # NOTE(review): selection presumably happens in the SHAP-IQ sampler outside this class.
        return None

    @property
    def plot_name(self) -> str:
        return "Permutation Sampling"
    
class RegressionMSRSampler(SHAPIQAcquisitionFunction):
    """Samples candidates according to SHAP-IQ implementation of RegressionMSRSampler.

    Placeholder acquisition function: ``__call__`` computes no utilities and
    always returns ``None``.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor | None:
        # NOTE(review): selection presumably happens in the SHAP-IQ sampler outside this class.
        return None

    @property
    def plot_name(self) -> str:
        return "Regression MSR"
    
class LeverageGPSampler(SHAPIQAcquisitionFunction):
    """Samples candidates according to SHAP-IQ implementation of LeverageSHAP (and fits GP on this).

    Placeholder acquisition function: ``__call__`` computes no utilities and
    always returns ``None``.
    """

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor | None:
        # NOTE(review): selection presumably happens in the SHAP-IQ sampler outside this class.
        return None

    @property
    def plot_name(self) -> str:
        return "LeverageSHAP-GP"

class SHAPKernelSampler(BaseAcquisitionFunction):
    """Samples candidates according to SHAP kernel.

    Each coalition in ``application.Z`` (binary-encoded by the surrogate's input
    transform) gets the Shapley kernel weight of Theorem 2 in Lundberg & Lee
    (2017), ``(m - 1) / (C(m, |z|) * |z| * (m - |z|))``. The empty and the full
    coalition get weight 0. One candidate is drawn with probability
    proportional to its weight and receives utility 1; all others receive 0.
    Coalitions with few or many players have a higher per-coalition weight,
    but there are fewer of them.
    """

    # https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf

    def __call__(
        self,
        X: torch.Tensor,
        candidate_idx_Z: torch.Tensor,
        surrogate: GPSurrogate,
        application: BaseApplication,
        iteration_idx: int
    ) -> torch.Tensor:
        """Return a one-hot utility vector selecting one SHAP-kernel-sampled candidate.

        Args:
            X: Candidate configurations (only their dtype is used).
            candidate_idx_Z: Indices of the candidates within ``application.Z``.
            surrogate: Provides ``_model.input_transform`` to binarize ``Z``.
            application: Provides the coalition set ``Z``.
            iteration_idx: Unused; kept for the uniform acquisition-function signature.

        Returns:
            Utilities of the candidates (1.0 for the sampled one, 0.0 otherwise).
            If no candidate has positive kernel weight, the candidate is drawn
            uniformly instead.
        """
        Z = application.Z
        Z_binary = surrogate._model.input_transform.transform(Z)

        # Number of players; assumes Z enumerates all 2**m coalitions.
        m = int(math.log2(Z.shape[0]))

        def shap_kernel_weight(m: int, z: int) -> float:
            # According to Theorem 2 in the SHAP paper
            if z == 0 or z == m:
                return 0.0  # Assign 0 weight to empty set and full set
            return (m - 1) / (math.comb(m, z) * z * (m - z))

        utils_Z = torch.zeros(Z.shape[0], dtype=X.dtype)
        if candidate_idx_Z.numel() == 0:
            return utils_Z[candidate_idx_Z]

        # The kernel weight depends only on the coalition size: evaluate it once per size
        # (0..m) and gather, instead of once per row of Z.
        size_weights = torch.tensor(
            [shap_kernel_weight(m, z) for z in range(m + 1)], dtype=X.dtype
        )
        subset_sizes = Z_binary.sum(dim=-1).double().round().long().cpu()  # Assuming binary features
        shap_weights = size_weights[subset_sizes]

        # Sampling among the candidates with weights proportional to their kernel weight is the
        # same distribution as rejection-sampling from all of Z until a candidate is hit, but it
        # terminates even when the candidates carry no weight.
        candidate_weights = shap_weights[candidate_idx_Z]
        if not candidate_weights.sum() > 0:
            log.info(
                "SHAPKernelSampler: no candidate has positive SHAP kernel weight. Sampling uniformly among candidates."
            )
            candidate_weights = torch.ones_like(candidate_weights)

        sampled_pos = torch.multinomial(candidate_weights, num_samples=1)
        sampled_idx = candidate_idx_Z[sampled_pos]

        utils_Z[sampled_idx] = 1.0

        return utils_Z[candidate_idx_Z]  # Return utils for candidates in X only

    @property
    def plot_name(self) -> str:
        return "SHAP-Kernel-Sampler"
