import math
import os
import warnings
from dataclasses import dataclass, field

import gpytorch
import torch
from gpytorch.constraints import Interval
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from torch import Tensor
from torch.quasirandom import SobolEngine

from botorch.fit import fit_gpytorch_mll
# Constrained Max Posterior Sampling is a new sampling class, similar to MaxPosteriorSampling,
# which implements the constrained version of Thompson Sampling described in [1].
from botorch.generation.sampling import ConstrainedMaxPosteriorSampling, MaxPosteriorSampling
from botorch.models import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.outcome import Standardize
from botorch.utils.transforms import unnormalize


@dataclass
class ScboState:
    """Mutable bookkeeping for one SCBO (constrained TuRBO) trust region.

    The counters, incumbent and length are updated in place by
    ``update_state`` / ``update_tr_length``.
    """

    dim: int  # dimensionality of the search space
    batch_size: int  # number of points proposed per optimization step
    length: float = 1.  # current trust-region side length
    length_min: float = 0.5**4  # a length below this sets restart_triggered
    length_max: float = 2.  # upper cap applied when the region expands
    failure_counter: int = 0  # consecutive non-improving steps
    # Fixed default. A dimension-aware alternative (used by the BoTorch TuRBO
    # tutorial) is math.ceil(max(4.0 / batch_size, dim / batch_size)).
    failure_tolerance: int = 2
    success_counter: int = 0  # consecutive improving steps
    success_tolerance: int = 2  # successes needed before the region expands
    best_value: float = -float("inf")  # objective value of the incumbent
    # Constraint values of the incumbent. Starting at +inf (infeasible) makes
    # the first feasible point always count as an improvement. default_factory
    # gives each instance its own tensor instead of one shared module-level
    # tensor created at import time.
    best_constraint_values: Tensor = field(
        default_factory=lambda: torch.ones(1) * torch.inf
    )
    # Set when the region shrinks below length_min or has to be enlarged to
    # find enough candidates.
    restart_triggered: bool = False


def update_tr_length(state: ScboState):
    """Resize the trust region from its success/failure counters (TuRBO rule).

    * ``success_counter`` reached ``success_tolerance``: double ``length``
      (capped at ``length_max``) and reset the success counter.
    * otherwise, ``failure_counter`` reached ``failure_tolerance``: halve
      ``length`` and reset the failure counter.

    Afterwards, a ``length`` below ``length_min`` sets ``restart_triggered``.
    The state is modified in place and also returned.
    """
    expand = state.success_counter == state.success_tolerance
    shrink = state.failure_counter == state.failure_tolerance

    if expand:
        state.length = min(state.length * 2.0, state.length_max)
        state.success_counter = 0
    elif shrink:
        state.length = state.length / 2.0
        state.failure_counter = 0

    # The region has collapsed: signal the caller to restart.
    if state.length < state.length_min:
        state.restart_triggered = True

    return state


def get_best_index_for_batch(Y: Tensor, C: Tensor):
    """Return the index for the best point.

    A point is feasible when every value along the last dimension of ``C``
    is <= 0.

    * If any point is feasible, return the feasible point with the largest
      ``Y``. The index is taken over the flattened ``Y``, which is the row
      index when ``Y`` holds one value per point.
    * Otherwise, return the point with the smallest total violation, i.e. the
      sum of the positive entries of ``C``.
    """
    feasible = torch.all(C <= 0, dim=-1)

    if not bool(feasible.any()):
        # Nothing feasible: the least-violating point wins.
        violation = torch.clamp(C, min=0).sum(dim=-1)
        return torch.argmin(violation)

    # Hide infeasible points so argmax only sees feasible ones.
    masked_scores = Y.clone()
    masked_scores[~feasible] = float("-inf")
    return torch.argmax(masked_scores)


def update_state(state, Y_next, C_next):
    """Update the SCBO trust-region state after one optimization step.

    Success and failure counters are updated from the objective values
    (``Y_next``) and constraint values (``C_next``) of the batch of candidate
    points evaluated in this step. The trust-region length is then adjusted by
    ``update_tr_length``.

    As in the original TuRBO paper, a success is counted whenever the best new
    candidate improves upon the incumbent best point. The key difference for
    SCBO is how two points are compared:

    * Both points feasible (all constraint values <= 0): compare by objective
      value, with a small relative tolerance.
    * Exactly one point feasible: the feasible point is better.
    * Both points infeasible: compare by total constraint violation (the sum
      of the positive constraint values). The smaller violation is better.

    Args:
        state: ScboState-like object. Its counters, ``best_value`` and
            ``best_constraint_values`` are modified in place.
        Y_next: objective values of the evaluated batch.
        C_next: constraint values of the evaluated batch, presumably shaped
            ``(batch, num_constraints)`` (TODO confirm against callers).

    Returns:
        The same ``state`` object after ``update_tr_length`` has run.
    """
    # Pick the best point from the batch.
    best_ind = get_best_index_for_batch(Y=Y_next, C=C_next)
    y_next, c_next = Y_next[best_ind], C_next[best_ind]

    if (c_next <= 0).all():
        # At least one new candidate is feasible.
        if math.isfinite(state.best_value):
            # Require a small relative improvement to count as a success.
            improvement_threshold = state.best_value + 1e-3 * math.fabs(state.best_value)
        else:
            # With best_value == -inf, "-inf + 1e-3 * inf" is nan and every
            # comparison against it is False, so use the raw value instead.
            improvement_threshold = state.best_value
        incumbent_infeasible = bool((state.best_constraint_values > 0).any())
        improved = bool(y_next > improvement_threshold) or incumbent_infeasible
    else:
        # No new candidate is feasible: compare total constraint violations.
        total_violation_next = c_next.clamp(min=0).sum(dim=-1)
        total_violation_center = state.best_constraint_values.clamp(min=0).sum(dim=-1)
        improved = bool(total_violation_next < total_violation_center)

    if improved:
        state.success_counter += 1
        state.failure_counter = 0
        state.best_value = y_next.item()
        state.best_constraint_values = c_next
    else:
        state.success_counter = 0
        state.failure_counter += 1

    # Update the length of the trust region according to the success and failure counters.
    return update_tr_length(state)


def get_initial_points(dim, n_pts, seed=0):
    """Draw ``n_pts`` scrambled Sobol points in the unit cube ``[0, 1]^dim``.

    Args:
        dim: dimensionality of each point.
        n_pts: number of points to draw.
        seed: seed for the Sobol scrambling, making the draw reproducible.

    Returns:
        A tensor of shape ``(n_pts, dim)``.
    """
    return SobolEngine(dimension=dim, scramble=True, seed=seed).draw(n=n_pts)


def generate_batch(
    state: ScboState,
    model,  # GP model
    x_center,
    batch_size,
    constraint_model,
    X_possible,
):
    """Pick the next batch from a finite candidate set inside the trust region.

    Selection proceeds as follows:

    1. Start from a box of side ``state.length`` centred at ``x_center``,
       clipped to ``[0, 1]``.
    2. Double the box until it contains at least ``batch_size`` rows of
       ``X_possible``. Every enlargement sets ``state.restart_triggered``.
    3. If exactly ``batch_size`` candidates fall in the box, return them
       as-is.
    4. Otherwise choose the batch by Thompson sampling without replacement:
       ``ConstrainedMaxPosteriorSampling`` when ``constraint_model`` is given,
       else ``MaxPosteriorSampling``.

    Args:
        state: trust-region state. ``length`` is read and
            ``restart_triggered`` may be set to True.
        model: objective model handed to the sampler.
        x_center: trust-region centre, presumably a point in [0, 1]^d
            (TODO confirm with callers).
        batch_size: number of points to return.
        constraint_model: constraint model, or ``None`` for unconstrained
            sampling.
        X_possible: candidate points, one per row.

    Returns:
        ``(X_next, aux)``, where ``aux["scaleup"]`` is the final box scale
        factor and ``aux["candidate_num"]`` is the number of candidates
        inside the final box.

    Raises:
        ValueError: if ``state.length`` is not positive (the box could never
            grow), or if even the whole unit cube holds fewer than
            ``batch_size`` candidates.
    """
    if not state.length > 0:
        raise ValueError(f"trust-region length must be positive, got {state.length}")

    scale = 1.0
    while True:
        half_width = scale * state.length / 2.0
        tr_lb = torch.clamp(x_center - half_width, 0.0, 1.0)
        tr_ub = torch.clamp(x_center + half_width, 0.0, 1.0)
        in_region = ((X_possible >= tr_lb) & (X_possible <= tr_ub)).all(dim=-1)
        n_in_region = int(in_region.sum())
        if n_in_region >= batch_size:
            break
        # Once the box covers the whole unit cube, enlarging it cannot add
        # candidates. Fail loudly instead of looping forever.
        if bool((tr_lb <= 0.0).all()) and bool((tr_ub >= 1.0).all()):
            raise ValueError(
                f"Only {n_in_region} candidate points lie in the unit cube, "
                f"fewer than batch_size={batch_size}."
            )
        # Too few candidates left in the current (zoomed-in) region: enlarge
        # the box and flag that a restart is warranted.
        state.restart_triggered = True
        scale *= 2.0

    X_cand = X_possible[in_region]
    if X_cand.shape[0] == batch_size:
        # Exactly enough candidates: no sampling needed.
        X_next = X_cand
    else:
        if constraint_model is not None:
            # Sample on the candidate points using Constrained Max Posterior Sampling.
            ts = ConstrainedMaxPosteriorSampling(
                model=model, constraint_model=constraint_model, replacement=False
            )
        else:
            ts = MaxPosteriorSampling(model=model, replacement=False)
        with torch.no_grad():
            X_next = ts(X_cand, num_samples=batch_size)

    aux = {
        'scaleup': scale,
        'candidate_num': n_in_region,
    }
    return X_next, aux



