from __future__ import annotations

"""Abstract application class and various concrete implementations.

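* ``BaseApplication`` enforces that every concrete application exposes the
  following tensors as read-only properties:

  - **Z** : execution path inputs (the points on which the blackbox enters
    the property)
  - **A** : affine transformation matrix; the property of interest is
    ``A @ f(Z)``
  - **X0**: initial design / archive
  - **candidate_set** / **candidate_idx_Z**: candidate points and their row
    indices in ``Z``

These attributes are *derived* from the constructor arguments (and, for lazy
applications, from the blackbox function), so they are not constructor
parameters. Eager applications fill them in ``__post_init__``; lazy
applications fill them in ``run_lazy_setup()``, which must be called before use.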
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional

import gpytorch
import numpy as np
import scipy
import torch

import shapiq
from shapiq.approximator import KernelSHAP, SVARM, PermutationSamplingSV
from .regressionMSR import RegressionMSR
from .polyshap import PolySHAP, ExplanationFrontierGenerator

# from xac.acquisition_functions import KernelSHAPSampler, SVARMSampler

from xac.blackbox_functions import (BaseBlackboxFunction, BotorchTestFunction,
                                    TabRepoBenchmark)

import numpy as np

log = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Abstract parent
# -----------------------------------------------------------------------------


@dataclass(frozen=True)
class BaseApplication(ABC):
    """Abstract base class for all applications.

    A concrete application describes a property of a blackbox function ``f``
    that is an affine map of ``f`` evaluated on an execution path: the
    property is ``A @ f(Z)``. Subclasses expose ``Z``, ``A``, ``X0``,
    ``candidate_set`` and ``candidate_idx_Z`` as read-only properties; the
    concrete subclasses in this module back them with the private attributes
    ``_Z``, ``_A``, ``_X0``, ``_candidate_set`` and ``_candidate_idx_Z``.

    The dataclass is frozen, so every attribute written after construction
    goes through ``object.__setattr__``.
    """

    # Whether the blackbox may only be evaluated on points of Z. Declared
    # without a default here; concrete subclasses give it one.
    eval_bb_only_on_Z: bool = field(init=False)
    # If True, Z / A / X0 are not built in __post_init__ and run_lazy_setup()
    # must be called before the application is used.
    lazy_setup: bool = field(init=False, default=False)

    # ---------------------------------------------------------------------
    # properties every subclass must implement
    # ---------------------------------------------------------------------
    @property
    @abstractmethod
    def Z(self) -> torch.Tensor:
        """Execution path inputs on which ``f`` enters the property."""

    @property
    @abstractmethod
    def A(self) -> torch.Tensor:
        """Affine transformation matrix mapping ``f(Z)`` to the property."""

    @property
    @abstractmethod
    def X0(self) -> torch.Tensor:
        """Initial design / archive."""

    @property
    @abstractmethod
    def candidate_set(self) -> torch.Tensor:
        """Candidate points (in the subclasses here: rows of Z not in the initial design)."""

    @property
    @abstractmethod
    def candidate_idx_Z(self) -> torch.Tensor:
        """Indices of the elements of the candidate set in Z."""

    @abstractmethod
    def termination_criterion(self, property_posterior) -> bool:
        """Return True if the termination condition for this application is met."""

    @abstractmethod
    def run_lazy_setup(self, blackbox_function):
        """Setup class lazily. Must be called from outside."""

    # ------------------------------------------------------------------
    # post-init & lazy setup
    # ------------------------------------------------------------------
    def __post_init__(self):
        """Build Z / A / X0 eagerly, or mark the lazy setup as pending."""
        if not self.lazy_setup:
            # Eager applications must implement _generate_Z/_generate_A/
            # _generate_X0. The results are stored under the backing names the
            # concrete properties read: writing "Z" etc. directly would hit the
            # read-only (setter-less) properties and raise AttributeError.
            object.__setattr__(self, "_Z", self._generate_Z())
            object.__setattr__(self, "_A", self._generate_A())
            object.__setattr__(self, "_X0", self._generate_X0())
        else:
            object.__setattr__(self, "lazy_setup_conducted", False)

    # ------------------------------------------------------------------
    # Compute posterior of the property for a given surrogate
    # ------------------------------------------------------------------
    def property_posterior(self, surrogate, noisy_variant: bool = False):
        """Return the posterior over the property ``A @ f(Z)`` using a fitted surrogate.

        Parameters
        ----------
        surrogate
            Fitted surrogate providing ``forward(Z, observation_noise=...)``
            (whose result has a ``.mean``) and
            ``forward_lazy_covar(Z, observation_noise=...)``.
        noisy_variant
            Forwarded as ``observation_noise`` to the surrogate.

        Returns
        -------
        gpytorch.distributions.MultivariateNormal
            Mean ``A @ mean(f(Z))`` and covariance ``A @ Cov(f(Z)) @ A.T``.

        Notes
        -----
        Should not be stored as part of the dataclass, since it may change over time.
        If the computation fails with ``noisy_variant=False`` it is retried once
        with observation noise (jitter on the diagonal), which can avoid
        non-positive-definiteness issues. Failures with ``noisy_variant=True``
        are re-raised directly.
        """
        if self.lazy_setup:
            assert self.lazy_setup_conducted, "run_lazy_setup() must be called first."

        def _property_posterior(_surrogate, _noisy_variant):
            mvn_mean = _surrogate.forward(self.Z, observation_noise=_noisy_variant).mean
            mvn_lazy_covar = _surrogate.forward_lazy_covar(
                self.Z, observation_noise=_noisy_variant
            )

            if mvn_lazy_covar.ndim == 3:
                # Batched covariance: apply A to each batch element of the mean.
                mean = (self.A @ mvn_mean.T).T
            else:
                mean = self.A @ mvn_mean

            # NOTE: for large Z this product can exhaust memory (an allocation of
            # ~64 GiB was observed to fail here in DefaultCPUAllocator).
            cov = self.A @ mvn_lazy_covar.matmul(self.A.T)

            # With a 3-d covariance this is a batch of MVNs (possibly to be
            # interpreted as a Gaussian mixture).
            return gpytorch.distributions.MultivariateNormal(mean, cov)

        try:
            return _property_posterior(surrogate, noisy_variant)
        except Exception:
            if noisy_variant:
                # Retrying would repeat exactly the same computation.
                raise
            # Force jitter on diagonal, as this can avoid non positive-definiteness issues
            log.info(
                "Added jitter on diagonal (leading to PPD for y-Z) to avoid positive-definiteness issues."
            )
            return _property_posterior(surrogate, True)

    # ------------------------------------------------------------------
    # utility: dtype / device transfer
    # ------------------------------------------------------------------
    def to(
        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
    ):
        """Move the stored Z, A and X0 tensors to *device* / *dtype*.

        The instance is frozen and ``Z``/``A``/``X0`` are read-only
        properties, so the converted tensors are written to the backing
        attributes ``_Z``/``_A``/``_X0`` via ``object.__setattr__``
        (assigning ``self.Z = ...`` would raise). Other attributes such as
        ``candidate_set`` and ``candidate_idx_Z`` are not touched.

        Returns
        -------
        The same instance, to allow chaining.
        """
        if self.lazy_setup:
            assert self.lazy_setup_conducted, "run_lazy_setup() must be called first."

        if device is not None or dtype is not None:
            object.__setattr__(self, "_Z", self.Z.to(device=device, dtype=dtype))
            object.__setattr__(self, "_A", self.A.to(device=device, dtype=dtype))
            object.__setattr__(self, "_X0", self.X0.to(device=device, dtype=dtype))
        return self


# -----------------------------------------------------------------------------
# Application – Partial dependence plots
# -----------------------------------------------------------------------------


@dataclass(frozen=True)
class PDPApplication(BaseApplication):
    """Partial dependence (1-d marginal effect) of a single input dimension.

    ``run_lazy_setup`` builds an execution path ``Z`` consisting of
    ``dim_of_interest_samples`` equidistant grid values of ``dim_of_interest``
    (bounds excluded), each combined with the same ``marginalized_dims_samples``
    scrambled Sobol points for all other dimensions. Row ``g`` of ``A``
    averages the blackbox values over the Sobol points of grid value ``g``, so
    ``A @ f(Z)`` is the PDP curve on the grid.
    """

    # Index of the input dimension whose marginal effect is estimated.
    dim_of_interest: int = 0
    # How many grid points to use for the dimension of interest.
    dim_of_interest_samples: int = 20
    # How many samples to draw in total (per grid point) from all marginalized
    # dimensions; the same samples are reused for every grid point.
    marginalized_dims_samples: int = 50
    # Initial design size = init_design_dim_factor * blackbox dimension.
    init_design_dim_factor: int = 5

    eval_bb_only_on_Z: bool = field(init=False, default=True)
    lazy_setup: bool = field(init=False, default=True)

    def run_lazy_setup(self, blackbox_function: BotorchTestFunction, seed: int, amount_iterations: int = None):
        """Build Z, A, X0 and the candidate set for the PDP of ``dim_of_interest``.

        Parameters
        ----------
        blackbox_function
            Blackbox providing ``dim``, ``bounds`` (row 0: lower, row 1: upper),
            ``get_bounds_for_dim`` and, if ``is_pseudo_expensive``, ``evaluate``.
        seed
            Seed of the scrambled Sobol engine for the marginalized dimensions.
            The X0 / candidate split uses ``torch.randperm`` with torch's global
            RNG and is not controlled by this seed.
        amount_iterations
            Unused.
        """
        object.__setattr__(self, "blackbox_function", blackbox_function)

        # Equidistant grid for the dimension of interest; linspace includes the
        # bounds, which are dropped via [1:-1]. Shape: (dim_of_interest_samples, 1).
        dim_of_interest_bounds = self.blackbox_function.get_bounds_for_dim(
            dim=self.dim_of_interest
        )
        dim_of_interest_grid = torch.unsqueeze(
            torch.linspace(
                dim_of_interest_bounds[0],
                dim_of_interest_bounds[1],
                self.dim_of_interest_samples + 2,
            )[1:-1],
            dim=1,
        )

        # Sobol samples in [0, 1]^(dim - 1) for all other dimensions.
        amount_marginalized_dims = self.blackbox_function.dim - 1

        sobol = torch.quasirandom.SobolEngine(
            dimension=amount_marginalized_dims, scramble=True, seed=seed
        )  # To ensure different grids across seeds
        marginalized_dims_unscaled_grid = sobol.draw(self.marginalized_dims_samples)

        marginalized_dims_idx = torch.arange(self.blackbox_function.dim)
        marginalized_dims_idx = marginalized_dims_idx[
            marginalized_dims_idx != self.dim_of_interest
        ]

        # Rescale the unit-cube Sobol points to the blackbox bounds.
        marginalized_dims_lower_bounds = torch.unsqueeze(
            self.blackbox_function.bounds[0, marginalized_dims_idx], dim=0
        )
        marginalized_dims_upper_bounds = torch.unsqueeze(
            self.blackbox_function.bounds[1, marginalized_dims_idx], dim=0
        )
        marginalized_dims_grid = (
            marginalized_dims_lower_bounds
            + (marginalized_dims_upper_bounds - marginalized_dims_lower_bounds)
            * marginalized_dims_unscaled_grid
        )
        amount_marginal_points = marginalized_dims_grid.shape[0]

        # One block of execution-path inputs per grid value of the dimension of interest.
        execution_path_separated = []
        for dim_of_interest_grid_idx in range(dim_of_interest_grid.shape[0]):
            execution_path_dim_subset = torch.zeros(
                (amount_marginal_points, self.blackbox_function.dim),
                dtype=torch.float64,
            )

            execution_path_dim_subset[:, self.dim_of_interest] = dim_of_interest_grid[
                dim_of_interest_grid_idx, :
            ]
            execution_path_dim_subset[:, marginalized_dims_idx] = marginalized_dims_grid

            execution_path_separated.append(execution_path_dim_subset)

            # All values are identical for dim_of_interest ...
            assert torch.all(
                execution_path_dim_subset[:, self.dim_of_interest]
                == execution_path_dim_subset[:, self.dim_of_interest].flatten()[0]
            )
            # ... while the other values vary.
            assert torch.any(
                execution_path_dim_subset[:, marginalized_dims_idx]
                != execution_path_dim_subset[:, marginalized_dims_idx].flatten()[0]
            )

            # Every coordinate lies within its bounds.
            for temp_dim_idx in torch.arange(self.blackbox_function.dim):
                temp_slice = execution_path_dim_subset[:, temp_dim_idx]
                lower, upper = self.blackbox_function.get_bounds_for_dim(
                    dim=temp_dim_idx
                )
                assert torch.all((temp_slice >= lower) & (temp_slice <= upper))

        # Rows [g * S, (g + 1) * S) of Z belong to grid value g (S = amount_marginal_points).
        Z = torch.stack(execution_path_separated).reshape(
            -1, self.blackbox_function.dim
        )

        # Affine transformation: row g averages the S rows of grid value g.
        A = torch.zeros((self.dim_of_interest_samples, Z.shape[0]), dtype=torch.float64)
        for dim_of_interest_grid_idx in range(A.shape[0]):
            block_start = dim_of_interest_grid_idx * amount_marginal_points
            block_end = (dim_of_interest_grid_idx + 1) * amount_marginal_points
            A[dim_of_interest_grid_idx, block_start:block_end] = 1 / amount_marginal_points

        # Select size of initial design based on dimensionality; the remaining
        # rows of Z form the candidate set.
        init_design_size = self.init_design_dim_factor * self.blackbox_function.dim
        rand_perm = torch.randperm(Z.shape[0])
        X0 = Z[rand_perm[:init_design_size], :]
        candidate_set = Z[rand_perm[init_design_size:], :]
        candidate_idx_Z = rand_perm[init_design_size:]

        # Required overrides
        object.__setattr__(self, "_Z", Z)
        object.__setattr__(self, "_A", A)
        object.__setattr__(self, "_X0", X0)
        object.__setattr__(self, "_candidate_set", candidate_set)
        object.__setattr__(self, "_candidate_idx_Z", candidate_idx_Z)

        if self.blackbox_function.is_pseudo_expensive:
            # Ground truth on Z and the corresponding ground-truth property.
            f_Z_gt = self.blackbox_function.evaluate(X=self.Z)
            object.__setattr__(self, "f_Z_gt", f_Z_gt)

            prop_gt = self.A @ self.f_Z_gt[0]
            object.__setattr__(self, "prop_gt", prop_gt)

        object.__setattr__(self, "lazy_setup_conducted", True)

    # ---------------- required overrides -----------------------------
    @property
    def Z(self) -> torch.Tensor:
        return self._Z

    @property
    def A(self) -> torch.Tensor:
        return self._A

    @property
    def X0(self) -> torch.Tensor:
        return self._X0

    @property
    def candidate_set(self) -> torch.Tensor:
        return self._candidate_set

    @property
    def candidate_idx_Z(self) -> torch.Tensor:
        return self._candidate_idx_Z

    def termination_criterion(self, property_posterior) -> bool:
        """Not implemented: never signals termination."""
        if self.lazy_setup:
            assert self.lazy_setup_conducted, "run_lazy_setup() must be called first."

        return False  # Not implemented


# -----------------------------------------------------------------------------
# Application – Efficient Benchmarking
# -----------------------------------------------------------------------------


@dataclass(frozen=True)
class TabRepoBenchmarkApplication(BaseApplication):
    """Concrete implementation for the TabRepo Benchmarking setting.

    ``run_lazy_setup`` randomly selects ``amount_init_design_configs`` +
    ``amount_challenger_configs`` configurations. All but the last one form
    the initial design. The execution path ``Z`` consists of the rows of the
    last two selected configurations (the last initial-design config followed
    by the challenger). The candidate set is the challenger's rows.
    """

    # Number of challenger configurations (only 1 is currently supported).
    amount_challenger_configs: int = 1
    # Number of fully evaluated configurations in the initial design (>= 1).
    amount_init_design_configs: int = 5
    # As opposed to fully evaluated configs in init design, simply add random
    # samples from the remaining dataset.
    amount_random_init_design_samples: int = 0

    # The evaluation scope of the blackbox function is restricted to Z.
    eval_bb_only_on_Z: bool = field(init=False, default=True)
    lazy_setup: bool = field(init=False, default=True)

    # ---------------- required overrides -----------------------------
    def run_lazy_setup(self, blackbox_function: TabRepoBenchmark, seed: int, amount_iterations: int = None):
        """Select configurations and build X0/Y0, Z, f_Z_gt, the candidate set and A.

        Parameters
        ----------
        blackbox_function
            Benchmark exposing ``dataset`` and the column indices
            ``config_id_idx``, ``dataset_id_idx``, ``indep_attr_idx``,
            ``perf_metric_idx`` and ``cost_metric_idx``.
        seed, amount_iterations
            Currently unused; configuration and row selection use torch's
            global RNG (``torch.randperm``).
        """
        object.__setattr__(self, "blackbox_function", blackbox_function)

        # ---------------- asserts -----------------------------
        assert (
            self.amount_challenger_configs == 1
        ), "Currently only supports a single challenger config."

        assert (
            self.amount_init_design_configs > 0
        ), "Currently an initial design is required (at least 1, as this contains default)."

        config_id_col = blackbox_function.dataset[:, blackbox_function.config_id_idx]

        # ------------------------------------------------------------------
        # 1. Randomly select configs
        # ------------------------------------------------------------------
        candidate_config_ids = config_id_col.unique()

        config_ids = candidate_config_ids[
            torch.randperm(len(candidate_config_ids))[
                : self.amount_init_design_configs + self.amount_challenger_configs
            ]
        ]
        config_data = [
            blackbox_function.dataset[config_id_col == temp_config_id]
            for temp_config_id in config_ids
        ]

        # Ensure that all selected configs were evaluated on identical datasets.
        # torch.equal also returns False for differing lengths, where an
        # element-wise comparison would raise a broadcasting error instead.
        config_dataset_ids = [
            data[:, blackbox_function.dataset_id_idx] for data in config_data
        ]
        reference_dataset_ids = config_dataset_ids[0].sort()[0]
        assert all(
            torch.equal(reference_dataset_ids, dataset_ids.sort()[0])
            for dataset_ids in config_dataset_ids
        ), "Selected configs were not evaluated on identical sets of datasets."

        init_design_data = torch.concat(config_data[:-1])

        # ------------------------------------------------------------------
        # 2. If specified, add random samples from remaining dataset to initial design
        # ------------------------------------------------------------------
        if self.amount_random_init_design_samples > 0:
            config_data_complement = blackbox_function.dataset[
                ~torch.isin(config_id_col, config_ids)
            ]

            rows = config_data_complement.size(0)
            random_init_design_samples = config_data_complement[
                torch.randperm(rows)[: self.amount_random_init_design_samples]
            ]

            init_design_data = torch.concat(
                [init_design_data, random_init_design_samples]
            )

        init_design_x = init_design_data[:, blackbox_function.indep_attr_idx]
        init_design_y_perf = init_design_data[:, blackbox_function.perf_metric_idx]
        init_design_y_cost = init_design_data[:, blackbox_function.cost_metric_idx]

        # Execution path: rows of the last initial-design config, then the challenger.
        execution_path_data = torch.concat(config_data[-2:])
        execution_path_x = execution_path_data[:, blackbox_function.indep_attr_idx]
        execution_path_y_perf = execution_path_data[
            :, blackbox_function.perf_metric_idx
        ]
        execution_path_y_cost = execution_path_data[
            :, blackbox_function.cost_metric_idx
        ]

        candidate_set_x = config_data[-1][:, blackbox_function.indep_attr_idx]

        object.__setattr__(self, "_X0", init_design_x)
        object.__setattr__(self, "Y0", (init_design_y_perf, init_design_y_cost))

        object.__setattr__(self, "_Z", execution_path_x)
        object.__setattr__(
            self, "f_Z_gt", (execution_path_y_perf, execution_path_y_cost)
        )

        # The challenger rows occupy the second block of Z.
        amount_first_block = config_data[-2].shape[0]
        object.__setattr__(self, "_candidate_set", candidate_set_x)
        object.__setattr__(
            self,
            "_candidate_idx_Z",
            torch.arange(amount_first_block, amount_first_block + config_data[-1].shape[0]),
        )

        # Both blocks have the same length (checked above), so each is half of Z.
        object.__setattr__(
            self, "_A", self._generate_A(execution_path_data.shape[0] // 2)
        )

        if blackbox_function.is_pseudo_expensive:
            prop_gt = self.A @ self.f_Z_gt[0]
            object.__setattr__(self, "prop_gt", prop_gt)

        object.__setattr__(self, "lazy_setup_conducted", True)

    @property
    def Z(self) -> torch.Tensor:
        return self._Z

    @property
    def A(self) -> torch.Tensor:
        return self._A

    @property
    def X0(self) -> torch.Tensor:
        return self._X0

    @property
    def candidate_set(self) -> torch.Tensor:
        return self._candidate_set

    @property
    def candidate_idx_Z(self) -> torch.Tensor:
        return self._candidate_idx_Z

    def _generate_A(self, amount_instances) -> torch.Tensor:
        """Return the float64 contrast row vector of shape (1, 2 * amount_instances).

        ``A @ y`` equals the mean of the first ``amount_instances`` entries of
        ``y`` minus the mean of the last ``amount_instances`` entries. With Z
        as built in ``run_lazy_setup``, this is "last initial-design config
        minus challenger".

        NOTE(review): an earlier comment called this "challenger minus
        incumbent, >0 corresponds to challenger is better". The ">0 means
        challenger is better" reading holds only if the performance metric is
        an error (lower is better). Confirm against the benchmark's metric.
        """
        assert (
            self.amount_challenger_configs == 1
        ), "Currently only supports comparing two configs."

        return torch.cat(
            [
                torch.full((1, amount_instances), 1.0 / amount_instances),
                torch.full((1, amount_instances), -1.0 / amount_instances),
            ],
            dim=-1,
        ).to(torch.float64)

    def termination_criterion(self, property_posterior) -> bool:
        """Not implemented: never signals termination."""
        if self.lazy_setup:
            assert self.lazy_setup_conducted, "run_lazy_setup() must be called first."

        return False  # Not implemented


#Define game
#Calls blackbox function internally on sampled coalitions
class ShapIqGame(shapiq.Game):
    """shapiq game whose worth of a coalition is a blackbox evaluation.

    Binary coalition vectors are mapped to blackbox inputs with
    ``surrogate._model.input_transform.untransform`` and then evaluated by
    ``blackbox_fn``.
    """

    def __init__(self,
                 m,
                 surrogate,
                 archive_size,
                 blackbox_fn,
                 exact=False) -> None:
        """Create a game with ``m`` players named "0" .. "m-1".

        Parameters
        ----------
        m
            Number of players (input dimensions).
        surrogate
            Object whose ``_model.input_transform`` maps between binary
            coalitions and blackbox inputs.
        archive_size
            Stored only; not used by ``value_function`` (PermutationSamplingSV
            evaluates coalitions one by one, so it cannot be checked against
            the number of coalitions per call).
        blackbox_fn
            Callable evaluated on the untransformed coalitions; its first
            returned element is used as the coalition value.
        exact
            Stored only; not used by ``value_function``.
        """
        super().__init__(
            n_players=m,
            player_names=[str(i) for i in range(m)],
        )
        # TODO: Decide whether a normalization_value (e.g. value of the empty
        # coalition) should be passed to shapiq.Game.

        self.surrogate = surrogate
        self.archive_size = archive_size
        self.blackbox_fn = blackbox_fn
        self.exact = exact

    def value_function(self, coalitions: np.ndarray) -> np.ndarray:
        """Evaluate the blackbox function on each coalition.

        Args:
            coalitions: A 2D array where each row represents a coalition as a binary
                vector (1 for present, 0 for absent).

        Returns:
            A 1D array with one value per coalition: the first element of the
            blackbox output, squeezed (presumably the performance values).
        """
        coalitions_int = torch.tensor(coalitions, dtype=torch.int64)
        input_transform = self.surrogate._model.input_transform
        coalitions_numeric = input_transform.untransform(coalitions_int)

        # Check the round trip before the (potentially expensive) blackbox call.
        assert (input_transform.transform(coalitions_numeric) == coalitions_int).all(), "Siq coalition transformation does not match input."

        output = self.blackbox_fn(coalitions_numeric)[0].squeeze()
        if isinstance(output, torch.Tensor):
            # numpy cannot convert tensors that require grad or live on a GPU.
            output = output.detach().cpu().numpy()

        # A single coalition squeezes to a 0-d value; shapiq expects a 1D array.
        return np.atleast_1d(np.array(output))

@dataclass(frozen=True)
class ShapleyApplication(BaseApplication):
    """Concrete implementation for Shapley value estimation."""

    # The evaluation scope of the blackbox function is restricted to Z.
    eval_bb_only_on_Z: bool = field(init=False, default=True)
    lazy_setup: bool = field(init=False, default=True)

    init_design_factor: int = 2 #Initial design size is init_design_factor * m

    # ---------------- required overrides -----------------------------
    def run_lazy_setup(self, blackbox_function: BaseBlackboxFunction, seed: int, amount_iterations: int = None):
        """Build Z, the Shapley matrix A, the initial design and the candidate set.

        ``Z`` contains all ``2**m`` coalitions, each realized as an input that
        takes the candidate config's value for features in the coalition and
        the baseline config's value otherwise. ``A`` maps ``f(Z)`` to the
        ``m`` Shapley values. The initial design is the set of coalitions
        drawn by the shapiq approximator's sampler; all other rows of ``Z``
        form the candidate set.

        Parameters
        ----------
        blackbox_function
            Blackbox to explain; ``evaluate`` is used if ``is_pseudo_expensive``.
        seed
            ``random_state`` of the shapiq approximator.
        amount_iterations
            Unused.
        """
        object.__setattr__(self, "blackbox_function", blackbox_function)

        baseline_config, candidate_config = self.sample_configs()
        m = self.get_blackbox_dim()

        # Initialize the shapiq approximator (LeverageSHAP via PolySHAP). Keep
        # this configuration in sync with the "LeverageSHAPSampler" branch of
        # get_siq_values().
        frontier_generator = ExplanationFrontierGenerator(N=[i for i in range(m)])
        explanation_frontier = frontier_generator.generate_kadd(max_order=1)
        sampling_weights_1 = np.ones(m + 1)

        siq_approximator = PolySHAP(
            n=m,
            explanation_frontier=explanation_frontier,
            sampling_weights=sampling_weights_1,
            pairing_trick=True,
            random_state=seed,
        )

        object.__setattr__(self, "siq_approximator", siq_approximator)
        object.__setattr__(self, "seed", seed)
        object.__setattr__(self, "m", m)

        init_design_size = max(self.init_design_factor * m, m + 1)  # At least m+1 samples
        object.__setattr__(self, "init_design_size", init_design_size)

        self.siq_approximator._sampler.sample(init_design_size)
        object.__setattr__(self, "siq_init_design", self.siq_approximator._sampler.coalitions_matrix)
        object.__setattr__(self, "siq_init_design_binary", torch.tensor(self.siq_init_design, dtype=torch.float64))

        # Z_binary: all 2^m coalitions; row i is the m-bit binary representation
        # of i, most significant bit first.
        Z_binary = torch.tensor(
            np.array([list(map(int, np.binary_repr(i, width=m))) for i in range(2**m)]),
            dtype=torch.float64,
        )

        # ---------------- Shapley value transformation matrix A ----------------
        A = torch.zeros((m, Z_binary.shape[0]), dtype=torch.float64)

        Z_sum = Z_binary.sum(axis=1)

        def get_shapley_weight(
            amount_players_in_coalition: int,  # count(T)
            feature_in_coalition: bool,  # \i
        ) -> float:
            # comb(...) is 0 for the empty coalition (with feature) and the full
            # coalition (without feature), giving inf; those entries are never
            # used below, so the divide warnings are suppressed.
            with np.errstate(divide="ignore", invalid="ignore"):
                if feature_in_coalition:
                    # Case: Feature in coalition => w(T-1)
                    weight = 1 / (
                        m * scipy.special.comb(m - 1, amount_players_in_coalition - 1)
                    )
                else:
                    # Case: Feature not in coalition => -w(T)
                    weight = -1 / (
                        m * scipy.special.comb(m - 1, amount_players_in_coalition)
                    )
                return torch.tensor(weight, dtype=torch.float64)

        weights_w = get_shapley_weight(
            amount_players_in_coalition=Z_sum, feature_in_coalition=True
        )
        weights_wo = get_shapley_weight(
            amount_players_in_coalition=Z_sum, feature_in_coalition=False
        )

        for feature_idx in range(m):
            # First set all values to -w(T) ...
            A[feature_idx, :] = weights_wo

            # ... then override entries where the feature is in the coalition with w(T-1).
            feature_in = Z_binary[:, feature_idx] == 1
            A[feature_idx, feature_in] = weights_w[feature_in]

        assert not torch.any(torch.isnan(A)) and not torch.any(
            torch.isinf(A)
        ), "NaN or Inf values in Shapley A matrix."
        assert torch.allclose(
            torch.sum(A, dim=1), torch.zeros(m, dtype=torch.float64)
        ), "Row sums of Shapley A matrix are not zero."

        # Map binary Z to actual Z (candidate config wherever 1, baseline config
        # wherever 0), feature by feature.
        Z = torch.zeros_like(Z_binary)
        for feature_idx in range(Z_binary.shape[-1]):
            Z[:, feature_idx] = torch.where(
                Z_binary[:, feature_idx] == 1,
                candidate_config[feature_idx],
                baseline_config[feature_idx],
            )

        # ---------------- initial design from the shapiq sampler ----------------
        # Because row i of Z_binary is the binary representation of i, the row
        # index of a coalition follows directly from its bits (no search needed).
        init_design_binary = self.siq_init_design_binary
        if not torch.all((init_design_binary == 0) | (init_design_binary == 1)):
            raise ValueError("Could not map siq sample to Z_binary index.")
        place_values = 2.0 ** torch.arange(m - 1, -1, -1, dtype=torch.float64)
        siq_init_design_indices_in_Z = (
            (init_design_binary @ place_values).round().to(torch.int64).tolist()
        )
        assert torch.equal(
            Z_binary[siq_init_design_indices_in_Z, :], init_design_binary
        ), "Could not map siq sample to Z_binary index."

        # Candidate indices: complement of the initial-design indices, in Z order.
        init_design_index_set = set(siq_init_design_indices_in_Z)
        siq_candidate_indices_in_Z = torch.tensor(
            [i for i in range(Z.shape[0]) if i not in init_design_index_set],
            dtype=torch.int64,
        )

        X0 = Z[siq_init_design_indices_in_Z, :]
        candidate_set = Z[siq_candidate_indices_in_Z, :]
        candidate_idx_Z = siq_candidate_indices_in_Z

        X0_binary = Z_binary[siq_init_design_indices_in_Z, :]
        candidate_set_binary = Z_binary[siq_candidate_indices_in_Z, :]
        candidate_idx_Z_binary = siq_candidate_indices_in_Z

        assert not ((X0_binary[:, None, :] == candidate_set_binary[None, :, :]).all(dim=2)).any(), "X0_binary and candidate_set_binary have overlapping rows."

        # The candidate set is already "ordered" => the shapiq baselines iterate
        # through it in acquisition-function maximization.

        # Required overrides
        object.__setattr__(self, "_A", A)

        object.__setattr__(self, "baseline_config", baseline_config)
        object.__setattr__(self, "candidate_config", candidate_config)

        object.__setattr__(self, "_Z", Z)
        object.__setattr__(self, "_X0", X0)
        object.__setattr__(self, "_candidate_set", candidate_set)
        object.__setattr__(self, "_candidate_idx_Z", candidate_idx_Z)

        object.__setattr__(self, "_Z_binary", Z_binary)
        object.__setattr__(self, "_X0_binary", X0_binary)
        object.__setattr__(self, "_candidate_set_binary", candidate_set_binary)
        object.__setattr__(self, "_candidate_idx_Z_binary", candidate_idx_Z_binary)

        if self.blackbox_function.is_pseudo_expensive:
            f_Z_gt = self.blackbox_function.evaluate(X=self._Z)
            object.__setattr__(self, "f_Z_gt", f_Z_gt)

            prop_gt = self.A @ self.f_Z_gt[0]
            object.__setattr__(self, "prop_gt", prop_gt)

        object.__setattr__(self, "lazy_setup_conducted", True)





    def get_siq_values(self,
                       amount_coalitions,
                       blackbox_fn,
                       surrogate,
                       acquisition_fn_name) -> np.ndarray:
        """Return shapiq Shapley value approximations of ``blackbox_fn``.

        Parameters
        ----------
        amount_coalitions
            Evaluation budget passed to the approximator's ``approximate()``.
        blackbox_fn
            Evaluated on coalitions through a ``ShapIqGame``.
        surrogate
            Provides the input transform used by ``ShapIqGame`` to map binary
            coalitions to blackbox inputs.
        acquisition_fn_name
            Selects the approximator: "KernelSHAPSampler", "LeverageSHAPSampler",
            "SVARMSampler", "PermutationSampler" or "RegressionMSRSampler".

        Returns
        -------
        np.ndarray
            The approximation values with the entry at the sampler's
            empty-coalition index (the baseline value) removed.

        Raises
        ------
        ValueError
            If ``acquisition_fn_name`` is not supported.

        Notes
        -----
        The coalitions sampled here are independent of those of the initial
        design and of earlier iterations, so common random numbers are not
        applied and no variance reduction is achieved. Even with a fixed seed,
        sampling n vs. n+1 coalitions does not make the first n identical.
        See https://shapiq.readthedocs.io/en/latest/_modules/shapiq/approximator/regression/base.html#Regression.kernel_shap_iq_routine
        """
        game = ShapIqGame(m=self.m, surrogate=surrogate, archive_size=amount_coalitions, blackbox_fn=blackbox_fn, exact=False)
        sampling_weights_1 = np.ones(self.m + 1)

        if acquisition_fn_name == "KernelSHAPSampler":
            temp_siq_approximator = KernelSHAP(n=self.m,
                                               index='SV',
                                               max_order=1,
                                               pairing_trick=True,
                                               sampling_weights=sampling_weights_1,
                                               random_state=self.seed)

        elif acquisition_fn_name == "LeverageSHAPSampler":
            # Same configuration as the approximator built in run_lazy_setup().
            frontier_generator = ExplanationFrontierGenerator(N=[i for i in range(self.m)])
            explanation_frontier = frontier_generator.generate_kadd(max_order=1)

            temp_siq_approximator = PolySHAP(n=self.m,
                                             explanation_frontier=explanation_frontier,
                                             sampling_weights=sampling_weights_1,
                                             pairing_trick=True,
                                             random_state=self.seed)

        elif acquisition_fn_name == "SVARMSampler":
            temp_siq_approximator = SVARM(n=self.m,
                                          index='SV',
                                          pairing_trick=True,
                                          max_order=1,
                                          sampling_weights=sampling_weights_1,
                                          random_state=self.seed)

        elif acquisition_fn_name == "PermutationSampler":
            # The number of evaluated coalitions is still set via ``budget``.
            temp_siq_approximator = PermutationSamplingSV(n=self.m,
                                                          random_state=self.seed)

        elif acquisition_fn_name == "RegressionMSRSampler":
            temp_siq_approximator = RegressionMSR(n=self.m,
                                                  pairing_trick=True,
                                                  replacement=False,
                                                  sampling_weights=sampling_weights_1,
                                                  random_state=self.seed)

        else:
            raise ValueError("Acquisition function type not supported for ShapIQ value approximation.")

        approximate_kwargs = {"budget": amount_coalitions, "game": game}
        if acquisition_fn_name == "PermutationSampler":
            # Evaluate coalitions one at a time.
            approximate_kwargs["batch_size"] = 1
        temp_ks = temp_siq_approximator.approximate(**approximate_kwargs)
        temp_ks_values = temp_ks.values

        empty_coalition_index = temp_siq_approximator._sampler.empty_coalition_index
        if empty_coalition_index is None:
            empty_coalition_index = 0

        # NOTE(review): this mixes the sampler's row index of the empty coalition
        # with a position in the interaction-values array; the assert guards
        # against the two disagreeing.
        assert temp_ks.baseline_value == temp_ks_values[empty_coalition_index]

        return np.delete(temp_ks_values, empty_coalition_index, axis=0)
        
        # import copy
        # temp_siq_approximator= copy.deepcopy(self.siq_approximator)

        #Set coalitions to subset of current archive
        #temp_siq_approximator._sampler.sample(amount_coalitions)
        #print(temp_siq_approximator._sampler.coalitions_matrix)



        # assert torch.tensor(temp_siq_approximator._sampler.coalitions_matrix, dtype= torch.float64).equal(transformed_archive_X), "Siq samples do not match transformed archive X."

        # kernel_weights_dict = {}
        # for interaction_size in range(1, temp_siq_approximator.max_order + 1):
        #     kernel_weights_dict[interaction_size] = temp_siq_approximator._init_kernel_weights(interaction_size)

        # game_values= np.array(archive_Y.squeeze())

        # sv_approximations_wint= temp_siq_approximator.regression_routine(
        #     kernel_weights=kernel_weights_dict[1],
        #     game_values= game_values,
        #     index_approximation= temp_siq_approximator.approximation_index
        # )
        
        # #Select all entries from sv_approximations_wint except temp_siq_approximator._sampler.empty_coalition_index
        # sv_approximations= np.concatenate([sv_approximations_wint[:temp_siq_approximator._sampler.empty_coalition_index],
        #                                 sv_approximations_wint[temp_siq_approximator._sampler.empty_coalition_index +1:]], axis=0)

        # #assert that this is identical to archive

    def get_exact_siq_values(self,
                             amount_coalitions,
                             blackbox_fn,
                             surrogate) -> np.ndarray:
        """Compute exact order-1 Shapley values with ``shapiq.ExactComputer``.

        A ``ShapIqGame`` is built in exact mode from ``surrogate`` /
        ``blackbox_fn`` and evaluated with index ``"SV"`` and order 1.

        Args:
            amount_coalitions: Forwarded to ``ShapIqGame`` as ``archive_size``.
            blackbox_fn: Forwarded to ``ShapIqGame``.
            surrogate: Forwarded to ``ShapIqGame``.

        Returns:
            The Shapley value array with the empty-coalition (baseline) entry
            removed.
        """
        game = ShapIqGame(m=self.m, surrogate=surrogate, archive_size=amount_coalitions,
                          blackbox_fn=blackbox_fn, exact=True)

        exact_computer = shapiq.ExactComputer(n_players=game.n_players, game=game)
        sv_exact = exact_computer(index="SV", order=1)

        values = np.asarray(sv_exact.values)
        # Locate the empty-coalition entry through the interaction lookup
        # instead of assuming it is stored at position 0.
        empty_coalition_index = sv_exact.interaction_lookup.get((), None)
        if empty_coalition_index is None:
            return values
        return np.delete(values, empty_coalition_index)

    def get_levgp_siq_value(self,
                            amount_coalitions,
                            blackbox_fn,
                            partial_gp,
                            acquisition_fn_name) -> np.ndarray:
        """Hybrid estimator: fit a GP surrogate on LeverageSHAP-sampled coalitions.

        1. Draw ``amount_coalitions`` coalitions with a ``PolySHAP`` sampler
           (1-additive explanation frontier, uniform sampling weights,
           pairing trick, seeded with ``self.seed``).
        2. Map each sampled coalition to its row in ``self._Z_binary`` and fit
           ``partial_gp(train_X, train_Y)`` on the matching rows of ``self.Z``
           and ``self.f_Z_gt[0]``.
        3. Return the posterior mean of ``self.property_posterior(gp)``.

        Note: ``blackbox_fn`` and ``acquisition_fn_name`` are not used in the
        body; they are kept for interface compatibility.

        Args:
            amount_coalitions: Number of coalitions requested from the sampler.
            blackbox_fn: Unused.
            partial_gp: Callable invoked as ``partial_gp(train_X, train_Y)``;
                the returned object must provide ``fit()``.
            acquisition_fn_name: Unused.

        Returns:
            ``np.array`` of the property posterior mean.

        Raises:
            ValueError: If a sampled coalition does not occur in ``self._Z_binary``.
        """
        # 1. Sample coalitions according to LeverageSHAP (shapiq PolySHAP).
        sampling_weights_1 = np.ones(self.m + 1)
        frontier_generator = ExplanationFrontierGenerator(N=list(range(self.m)))
        explanation_frontier = frontier_generator.generate_kadd(max_order=1)

        siq_approximator = PolySHAP(n=self.m,
                                    explanation_frontier=explanation_frontier,
                                    sampling_weights=sampling_weights_1,
                                    pairing_trick=True,  # replacement= False,
                                    random_state=self.seed)

        siq_approximator._sampler.sample(amount_coalitions)

        Z_coals_bool = self._Z_binary.bool()
        # Convert explicitly so the comparison below is a pure torch operation
        # on one device, rather than relying on implicit numpy/torch interop.
        train_coals_bool = torch.as_tensor(siq_approximator._sampler.coalitions_matrix,
                                           dtype=torch.bool,
                                           device=Z_coals_bool.device)

        # matches[i, j] is True iff sampled coalition i equals row j of Z.
        matches = (train_coals_bool[:, None, :] == Z_coals_bool[None, :, :]).all(dim=-1)

        # argmax over an all-False row silently returns 0, which would train
        # the GP on an unrelated point; reject that case explicitly.
        found = matches.any(dim=1)
        if not bool(found.all()):
            n_missing = int((~found).sum())
            raise ValueError(
                f"{n_missing} sampled coalition(s) do not occur in Z; "
                "cannot map them to training inputs."
            )
        train_idx = matches.int().argmax(dim=1)

        train_X = self.Z[train_idx]
        train_Y = self.f_Z_gt[0][train_idx]

        # 2. Fit GP surrogate on these samples.
        gp = partial_gp(train_X, train_Y)
        gp.fit()

        # 3. Get SVs according to the GP surrogate.
        return np.array(self.property_posterior(gp).mean)

    @property
    def Z(self) -> torch.Tensor:
        """Execution-path input sequence (the value stored in ``self._Z``)."""
        return self._Z

    @property
    def A(self) -> torch.Tensor:
        """Affine transformation matrix (the value stored in ``self._A``)."""
        return self._A

    @property
    def X0(self) -> torch.Tensor:
        """Initial design / archive (the value stored in ``self._X0``)."""
        return self._X0

    @property
    def candidate_set(self) -> torch.Tensor:
        """Candidate set (the value stored in ``self._candidate_set``)."""
        return self._candidate_set

    @property
    def candidate_idx_Z(self) -> str:
        """Expose the value stored in ``self._candidate_idx_Z``.

        NOTE(review): the ``-> str`` annotation looks inaccurate; the name
        suggests indices of the candidates within ``Z`` — confirm the real type.
        """
        stored_indices = self._candidate_idx_Z
        return stored_indices

    def termination_criterion(self, property_posterior) -> bool:
        """Decide whether the acquisition loop should stop.

        Currently a stub: after the lazy-setup check it always returns ``False``.

        Args:
            property_posterior: Unused by the current implementation.

        Returns:
            Always ``False``.

        Raises:
            AssertionError: (when assertions are enabled) if ``self.lazy_setup``
                is truthy but ``self.lazy_setup_conducted`` is falsy.
        """
        if self.lazy_setup:
            assert self.lazy_setup_conducted, "run_lazy_setup() must be called first."

        return False  # Not implemented


@dataclass(frozen=True)
class BotorchShapleyApplication(ShapleyApplication):
    """Shapley application over the box domain of a BoTorch test function.

    Baseline and candidate configurations are drawn uniformly at random inside
    ``self.blackbox_function._bounds`` (row 0 = lower, row 1 = upper bound).
    """

    def _sample_uniform_in_bounds(self) -> torch.Tensor:
        """Draw one point uniformly from ``[lower_bound, upper_bound]``."""
        lower_bound = self.blackbox_function._bounds[0, :]
        upper_bound = self.blackbox_function._bounds[1, :]
        # Draw on the bounds' device so the arithmetic below never mixes CPU
        # and accelerator tensors (identical values to torch.rand on CPU).
        unit_sample = torch.rand(lower_bound.shape, device=lower_bound.device)
        return lower_bound + unit_sample * (upper_bound - lower_bound)

    def sample_configs(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample baseline and candidate configurations.

        Returns:
            ``(baseline_config, candidate_config)``, each drawn independently
            and uniformly inside the function's bounds (baseline drawn first).
        """
        baseline_config = self._sample_uniform_in_bounds()
        candidate_config = self._sample_uniform_in_bounds()

        return baseline_config, candidate_config

    def get_blackbox_dim(self) -> int:
        """Return ``self.blackbox_function.dim``."""
        return self.blackbox_function.dim
    
@dataclass(frozen=True)
class ShapiqShapleyApplication(ShapleyApplication):
    """Shapley application with a fixed all-zeros baseline and all-ones candidate."""

    def sample_configs(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the deterministic baseline/candidate pair.

        Returns:
            ``(zeros, ones)``, both float64 tensors of length
            ``self.blackbox_function.dim``.
        """
        dim = self.blackbox_function.dim
        baseline_config = torch.zeros(dim, dtype=torch.float64)
        candidate_config = torch.ones(dim, dtype=torch.float64)

        return baseline_config, candidate_config

    def get_blackbox_dim(self) -> int:
        """Return ``self.blackbox_function.dim``."""
        return self.blackbox_function.dim


@dataclass(frozen=True)
class YahpoShapleyApplication(ShapleyApplication):
    """Shapley application over a YAHPO Gym configuration space.

    Configurations are sampled from ``self.blackbox_function.yahpo_opt_space``.
    Their numeric ``get_array()`` representation has the task-id entry removed
    (it is not a model input; it has to be re-added at evaluation time). For
    ``rbv2_xgboost`` the booster entry is removed too, because sampling is
    restricted to ``booster == 'dart'``.
    """

    # Upper bound on candidate re-draws in ``sample_configs``. This guards
    # against an endless loop when some feature can never differ from the
    # baseline. Plain class attribute, not a dataclass field.
    _MAX_CANDIDATE_ATTEMPTS = 100_000

    def _sample_raw_config(self):
        """Draw one configuration from the YAHPO space.

        For ``rbv2_xgboost`` the draw is repeated until ``booster == 'dart'``
        (simple workaround to pin the booster).
        """
        space = self.blackbox_function.yahpo_opt_space
        config = space.sample_configuration(1)
        if self.blackbox_function.yahpo_name == 'rbv2_xgboost':
            while config['booster'] != 'dart':
                config = space.sample_configuration(1)
        return config

    def _config_to_numeric(self, config) -> np.ndarray:
        """Return ``config.get_array()`` without the non-feature entries.

        Always drops ``task_id_index``; for ``rbv2_xgboost`` also drops
        ``booster_index``.
        """
        values = config.get_array()
        dropped = {self.blackbox_function.task_id_index}
        if self.blackbox_function.yahpo_name == 'rbv2_xgboost':
            dropped.add(self.blackbox_function.booster_index)
        return np.delete(values, sorted(dropped))

    def sample_configs(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample a baseline and a candidate that differ in every feature.

        The baseline is drawn first. Candidates are then re-drawn until every
        numeric feature differs from the baseline. If ``get_array()`` yields
        NaN entries, they always count as differing, since ``nan != nan``.

        Returns:
            ``(baseline, candidate)`` tensors built from the numeric
            representations with the non-feature entries removed.

        Raises:
            RuntimeError: If no fully-differing candidate is found within
                ``_MAX_CANDIDATE_ATTEMPTS`` draws.
            AssertionError: (when assertions are enabled) if either sampled
                configuration carries a task id other than
                ``self.blackbox_function.instance``.
        """
        baseline_config = self._sample_raw_config()
        baseline_config_numeric = self._config_to_numeric(baseline_config)

        for _ in range(self._MAX_CANDIDATE_ATTEMPTS):
            candidate_config = self._sample_raw_config()
            candidate_config_numeric = self._config_to_numeric(candidate_config)
            if (baseline_config_numeric != candidate_config_numeric).all():
                break
        else:
            raise RuntimeError(
                "Could not sample a candidate configuration that differs from "
                f"the baseline in every feature after "
                f"{self._MAX_CANDIDATE_ATTEMPTS} attempts "
                f"({self.blackbox_function.yahpo_name})."
            )

        # TODO: check which other features are nuisance parameters (e.g. repl).

        # Filtering to only the differing features (or perturbing equal ones)
        # is not done, because the result could no longer be mapped back
        # onto the full configuration.

        task_id_column = self.blackbox_function.task_id_column_name
        assert (int(baseline_config[task_id_column]) ==
                int(candidate_config[task_id_column]) ==
                self.blackbox_function.instance), "Baseline and candidate config have wrong dataset IDs."

        return torch.tensor(baseline_config_numeric), torch.tensor(candidate_config_numeric)

    def get_blackbox_dim(self) -> int:
        """Return the number of numeric input features.

        This is the length of ``get_array()`` minus the task-id entry, and
        minus the booster entry for ``rbv2_xgboost``. The task id is not an
        input here, but it has to be added back during evaluation.

        NOTE(review): this draws a configuration only to read its length,
        which presumably advances the configuration space's RNG state.
        """
        temp_config = (
            self.blackbox_function.yahpo_opt_space.sample_configuration(1)
        )

        if self.blackbox_function.yahpo_name == 'rbv2_xgboost':
            return temp_config.get_array().shape[0] - 2  # drop task id and booster

        return temp_config.get_array().shape[0] - 1  # drop task id