import contextlib

from .. import transformer
from .. import bar_distribution
import torch
import scipy
import math
from sklearn.preprocessing import power_transform, PowerTransformer

def log01(x, eps=.0000001, input_between_zero_and_one=False):
    """Log-transform ``x`` and rescale the result towards the unit interval.

    :param x: torch.Tensor; ``eps`` is added to it before the logarithm is taken.
    :param eps: small offset that keeps the logarithm finite for zero inputs.
    :param input_between_zero_and_one: if True, rescale with the fixed bounds
        ``log(eps)`` and ``log(1 + eps)``, i.e. the values obtained for the
        inputs 0 and 1. Otherwise rescale with the minimum and maximum of the
        log-values found along dim 0.
    :return: tensor with the same shape as ``x``.
    """
    log_values = torch.log(x + eps)
    if not input_between_zero_and_one:
        # data-dependent bounds, computed along dim 0
        lowest = log_values.min(0)[0]
        highest = log_values.max(0)[0]
        return (log_values - lowest) / (highest - lowest)
    # fixed bounds derived from the assumed [0, 1] input range
    log_eps = math.log(eps)
    return (log_values - log_eps) / (math.log(1 + eps) - log_eps)

def log01_batch(x, eps=.0000001, input_between_zero_and_one=False):
    """Build an ensemble in which member ``b`` has feature ``b`` passed through `log01`.

    ``x`` is tiled ``num_features + 1`` times along dim 1; in copy ``b`` only
    feature ``b`` is replaced, and the last copy is left untouched.

    :param x: 3-D tensor; presumably (seq_len, 1, num_features) -- TODO confirm against callers.
    :param eps: forwarded to `log01`.
    :param input_between_zero_and_one: forwarded to `log01`.
    :return: the tiled tensor with one log-scaled feature per ensemble member.
    """
    num_features = x.shape[-1]
    # repeat allocates a new tensor, so the in-place writes below do not touch the argument
    x = x.repeat(1, num_features + 1, 1)
    for dim in range(num_features):
        column = x[:, dim, dim]
        x[:, dim, dim] = log01(column, eps=eps, input_between_zero_and_one=input_between_zero_and_one)
    return x

def lognormed_batch(x, eval_pos, eps=.0000001):
    """Build an ensemble in which member ``b`` has feature ``b`` log-transformed and standardised.

    ``x`` is tiled ``num_features + 1`` times along dim 1. In copy ``b`` feature
    ``b`` is replaced by its logarithm, shifted and scaled with the mean and
    standard deviation of the first ``eval_pos`` rows. The last copy is left
    untouched.

    :param x: 3-D tensor; presumably (seq_len, 1, num_features) -- TODO confirm against callers.
    :param eval_pos: number of leading rows used to compute the statistics.
    :param eps: small offset added before taking the logarithm.
    :return: the tiled tensor with one log-standardised feature per ensemble member.
    """
    num_features = x.shape[-1]
    # repeat allocates a new tensor, so the in-place writes below do not touch the argument
    x = x.repeat(1, num_features + 1, 1)
    for dim in range(num_features):
        log_column = torch.log(x[:, dim, dim] + eps)
        context = log_column[:eval_pos]
        x[:, dim, dim] = (log_column - context.mean(0)) / context.std(0)
    return x

def _rank_transform(x_train, x):
    """Map 1-D values to their fractional rank among the unique training values.

    The interval boundaries are 0, the sorted unique values of ``x_train`` and 1.
    Each entry of ``x`` is located in the interval whose upper boundary is the
    first boundary that is not smaller than it, interpolated linearly inside
    that interval, and the result is divided by the number of intervals.

    :param x_train: 1-D tensor of reference values.
    :param x: 1-D tensor of values to transform.
    :return: 1-D tensor of the same length as ``x``.
    """
    assert len(x_train.shape) == len(x.shape) == 1
    relative_to = torch.cat(
        (torch.zeros_like(x_train[:1]), x_train.unique(sorted=True), torch.ones_like(x_train[-1:])), -1)
    # index of the upper boundary of the interval containing each x (at least 1)
    higher_comparison = (relative_to < x[..., None]).sum(-1).clamp(min=1)
    lower = relative_to[higher_comparison - 1]
    width = relative_to[higher_comparison] - lower
    # When x_train itself contains 0 the first interval is [0, 0]; an x of exactly 0
    # lands there and the plain division would be 0 / 0 = NaN. Dividing by 1 instead
    # yields position 0 inside that degenerate interval.
    safe_width = torch.where(width == 0, torch.ones_like(width), width)
    pos_inside_interval = (x - lower) / safe_width
    x_transformed = higher_comparison - 1 + pos_inside_interval
    return x_transformed / (len(relative_to) - 1.)

def rank_transform(x_train, x):
    """Apply `_rank_transform` independently to every feature column.

    :param x_train: 2-D tensor of reference values; every entry must lie in [0, 1].
    :param x: 2-D tensor with the same number of columns as ``x_train``;
        every entry must lie in [0, 1].
    :return: a new tensor shaped like ``x`` holding the rank-transformed columns.
    """
    assert x.shape[1] == x_train.shape[1], f"{x.shape=} and {x_train.shape=}"
    # the transform pads the boundaries with 0 and 1, so both inputs must stay inside that range
    assert (x_train >= 0.).all() and (x_train <= 1.).all(), f"{x_train=}"
    assert (x >= 0.).all() and (x <= 1.).all(), f"{x=}"
    ranked_columns = [_rank_transform(x_train[:, col], x[:, col]) for col in range(x.shape[1])]
    ranked = x.clone()
    for col, values in enumerate(ranked_columns):
        ranked[:, col] = values
    return ranked



def general_power_transform(x_train, x_apply, eps, less_safe=False):
    """Fit a sklearn ``PowerTransformer`` on ``x_train`` and apply it to ``x_apply``.

    * ``eps > 0``: Box-Cox is fitted on ``x_train + eps``. If that raises a
      ``ValueError`` the result falls back to ``x_apply`` minus the training mean.
    * otherwise: Yeo-Johnson is used. Unless ``less_safe`` is set, inputs whose
      overall std or absolute mean exceeds 1000 are standardised first. If the
      fit raises a ``ValueError`` the inputs are standardised (``less_safe``) or
      mean-centred (default) and the fit is retried.

    If the final output contains NaN or inf, it is replaced by ``x_apply`` minus
    the training mean.

    :param x_train: 2-D tensor the transformer is fitted on.
    :param x_apply: 2-D tensor the fitted transformer is applied to.
    :param eps: offset added before Box-Cox; a value ``<= 0`` selects Yeo-Johnson.
    :param less_safe: skips the pre-normalisation of large inputs and switches
        the retry strategy from mean-centring to standardising.
    :return: tensor with the dtype and device of ``x_apply``.
    """
    if eps > 0:
        try:
            pt = PowerTransformer(method='box-cox')
            pt.fit(x_train.cpu()+eps)
            x_out = torch.tensor(pt.transform(x_apply.cpu()+eps), dtype=x_apply.dtype, device=x_apply.device)
        except ValueError as e:
            print(e)
            x_out = x_apply - x_train.mean(0)
    else:
        pt = PowerTransformer(method='yeo-johnson')
        if not less_safe and (x_train.std() > 1_000 or x_train.mean().abs() > 1_000):
            train_mean, train_std = x_train.mean(0), x_train.std(0)
            x_apply = (x_apply - train_mean) / train_std
            x_train = (x_train - train_mean) / train_std
            print('inputs are LAARGEe, normalizing them')
        try:
            pt.fit(x_train.cpu().double())
        except ValueError as e:
            print('caught this errrr', e)
            # Take the statistics BEFORE x_train is overwritten: both tensors must be
            # shifted/scaled with the statistics of the original training data.
            train_mean = x_train.mean(0)
            if less_safe:
                train_std = x_train.std(0)
                x_train = (x_train - train_mean) / train_std
                x_apply = (x_apply - train_mean) / train_std
            else:
                x_train = x_train - train_mean
                x_apply = x_apply - train_mean
            pt.fit(x_train.cpu().double())
        x_out = torch.tensor(pt.transform(x_apply.cpu()), dtype=x_apply.dtype, device=x_apply.device)
    if torch.isnan(x_out).any() or torch.isinf(x_out).any():
        print('WARNING: power transform failed')
        print(f"{x_train=} and {x_apply=}")
        x_out = x_apply - x_train.mean(0)
    return x_out



#@torch.inference_mode()
def general_acq_function(model: transformer.TransformerModel, x_given, y_given, x_eval, apply_power_transform=True,
                    rand_sample=False, znormalize=False, pre_normalize=False, pre_znormalize=False, predicted_mean_fbest=False,
                    input_znormalize=False, max_dataset_size=10_000, remove_features_with_one_value_only=False,
                    return_actual_ei=False, acq_function='ei', ucb_rest_prob=.05, ensemble_log_dims=False,
                    ensemble_type='mean_probs', # in ('mean_probs', 'max_acq')
                    input_power_transform=False, power_transform_eps=.0, input_power_transform_eps=.0,
                    input_rank_transform=False, ensemble_input_rank_transform=False,
                    ensemble_power_transform=False, ensemble_feature_rotation=False,
                    style=None, outlier_stretching_interval=0.0, verbose=False, unsafe_power_transform=False,
                         ):
    """
    Score the candidates ``x_eval`` with an acquisition function computed from the model's predictions.

    Differences to HEBO:
        - The noise can't be set in the same way, as it depends on the tuning of HPs via VI.
        - Log EI and PI are always used directly instead of using the approximation.

    This is a stochastic function, relying on torch.randn

    :param model: called as ``model((style, x, y), single_eval_pos=len(x_given))``; its ``criterion``
        attribute must provide ``ei``, ``pi``, ``ucb`` and ``mean``.
    :param x_given: torch.Tensor of shape (N, D)
    :param y_given: torch.Tensor of shape (N, 1) or (N,)
    :param x_eval: torch.Tensor of shape (M, D)
    :param apply_power_transform: power-transform ``y_given`` with `general_power_transform`
        (using ``power_transform_eps`` and ``less_safe=unsafe_power_transform``).
    :param rand_sample: if not False, all-zero scores are returned while there are too few observations
        (fewer than ``D + 1`` for None, otherwise fewer than ``max(2, rand_sample)``).
    :param znormalize: standardise ``y_given``; only used when ``apply_power_transform`` is off.
    :param pre_normalize: divide ``y_given`` by its std before the power transform (skipped if it yields nan/inf).
    :param pre_znormalize: standardise ``y_given`` before the power transform (skipped if it yields nan/inf);
        ignored when ``pre_normalize`` is set.
    :param predicted_mean_fbest: use the model's predicted mean at the best observed y as threshold
        instead of ``max(y_given)``.
    :param input_znormalize: standardise the inputs with the statistics of ``x_given``.
    :param max_dataset_size: number of rows fed to the model per forward pass.
    :param remove_features_with_one_value_only: drop features that are constant over ``x_given`` and ``x_eval``.
    :param return_actual_ei: return the acquisition values instead of the mask of maximisers.
    :param acq_function: 'ei', 'ei_or_rand', 'pi', 'ucb', 'mean', 'hebo_<noise>_<upsi>_<delta>_<eps>',
        or a dict/list of these that is indexed with ``style``.
    :param ucb_rest_prob: forwarded as ``rest_prob`` to ``criterion.ucb`` unless None.
    :param ensemble_log_dims: False, '01', 'global01' (or True), '01-10' or 'norm'; selects the log-scaled
        input ensemble built by `log01_batch` / `lognormed_batch`.
    :param ensemble_type: 'mean_probs' averages the predicted probabilities over the ensemble members,
        'max_acq' keeps them separate so that the maximum acquisition value over members is used.
    :param input_power_transform: power-transform the inputs, fitted on ``x_given``
        (with ``input_power_transform_eps``).
    :param input_rank_transform: False, True/'full' (ranks over ``x_given`` and ``x_eval`` together),
        'train' (ranks relative to ``x_given``) or 'train_<p>' (as 'train', applied with probability p).
    :param ensemble_input_rank_transform: True/'train' adds rank-transformed copies of the inputs to the ensemble.
    :param ensemble_power_transform: adds a power-transformed copy of y to the ensemble
        (requires ``apply_power_transform=False``).
    :param ensemble_feature_rotation: adds cyclically rotated feature orders to the ensemble.
    :param style: None, a callable returning the style, a tensor, or a scalar that is broadcast
        over the ensemble members.
    :param outlier_stretching_interval: if > 0, inputs closer than this to 0 or 1 are stretched on a log scale.
    :param verbose: print the transformed targets.
    :return: tensor of length M; the acquisition values if ``return_actual_ei`` else a boolean mask marking
        the candidates that attain the maximum. The early exits (too few observations, nan logits) return zeros.
    """
    assert ensemble_type in ('mean_probs', 'max_acq')
    if rand_sample is not False \
        and (len(x_given) == 0 or
             ((1 + x_given.shape[1] if rand_sample is None else max(2, rand_sample)) > x_given.shape[0])):
        print('rando')
        return torch.zeros_like(x_eval[:,0])
    y_given = y_given.reshape(-1)
    assert len(y_given) == len(x_given)
    if apply_power_transform:
        if pre_normalize:
            y_normed = y_given / y_given.std()
            if not torch.isinf(y_normed).any() and not torch.isnan(y_normed).any():
                y_given = y_normed
        elif pre_znormalize:
            y_znormed = (y_given - y_given.mean()) / y_given.std()
            if not torch.isinf(y_znormed).any() and not torch.isnan(y_znormed).any():
                y_given = y_znormed
        y_given = general_power_transform(y_given.unsqueeze(1), y_given.unsqueeze(1), power_transform_eps, less_safe=unsafe_power_transform).squeeze(1)
        if verbose:
            print(f"{y_given=}")
    y_given_std = torch.tensor(1., device=y_given.device, dtype=y_given.dtype)
    if znormalize and not apply_power_transform:
        if len(y_given) > 1:
            y_given_std = y_given.std()
        y_given_mean = y_given.mean()
        y_given = (y_given - y_given_mean) / y_given_std

    if remove_features_with_one_value_only:
        x_all = torch.cat([x_given, x_eval], dim=0)
        only_one_value_feature = torch.tensor([len(torch.unique(x_all[:,i])) for i in range(x_all.shape[1])]) == 1
        x_given = x_given[:,~only_one_value_feature]
        x_eval = x_eval[:,~only_one_value_feature]

    if outlier_stretching_interval > 0.:
        tx = torch.cat([x_given, x_eval], dim=0)
        m = outlier_stretching_interval
        eps = 1e-10
        small_values = (tx < m) & (tx > 0.)
        tx[small_values] = m * (torch.log(tx[small_values] + eps) - math.log(eps)) / (math.log(m + eps) - math.log(eps))

        large_values = (tx > 1. - m) & (tx < 1.)
        tx[large_values] = 1. - m * (torch.log(1 - tx[large_values] + eps) - math.log(eps)) / (
                    math.log(m + eps) - math.log(eps))
        x_given = tx[:len(x_given)]
        x_eval = tx[len(x_given):]

    if input_znormalize:
        std = x_given.std(dim=0)
        std[std == 0.] = 1.
        mean = x_given.mean(dim=0)
        x_given = (x_given - mean) / std
        x_eval = (x_eval - mean) / std

    if input_power_transform:
        # Both calls must be fitted on the untransformed observations; the tuple assignment
        # evaluates the right-hand side before x_given is rebound.
        x_given, x_eval = (general_power_transform(x_given, x_given, input_power_transform_eps),
                           general_power_transform(x_given, x_eval, input_power_transform_eps))

    if input_rank_transform is True or input_rank_transform == 'full': # uses test set x statistics...
        x_all = torch.cat((x_given,x_eval), dim=0)
        # work on copies so that the caller's tensors are not overwritten in place
        x_given, x_eval = x_given.clone(), x_eval.clone()
        for feature_dim in range(x_all.shape[-1]):
            uniques = torch.sort(torch.unique(x_all[..., feature_dim])).values
            x_eval[...,feature_dim] = torch.searchsorted(uniques,x_eval[..., feature_dim]).float() / (len(uniques)-1)
            x_given[...,feature_dim] = torch.searchsorted(uniques,x_given[..., feature_dim]).float() / (len(uniques)-1)
    elif input_rank_transform is False:
        pass
    elif input_rank_transform == 'train':
        # rank both sets relative to the untransformed observations (right-hand side is evaluated first)
        x_given, x_eval = rank_transform(x_given, x_given), rank_transform(x_given, x_eval)
    elif input_rank_transform.startswith('train'):
        # 'train_<p>': apply the train rank transform with probability p
        likelihood = float(input_rank_transform.split('_')[-1])
        if torch.rand(1).item() < likelihood:
            print('rank transform')
            x_given, x_eval = rank_transform(x_given, x_given), rank_transform(x_given, x_eval)
    else:
        raise NotImplementedError


    # compute logits
    criterion: bar_distribution.BarDistribution = model.criterion
    x_predict = torch.cat([x_given, x_eval], dim=0)


    logits_list = []
    for x_feed in torch.split(x_predict, max_dataset_size, dim=0):
        # every chunk is preceded by the observations, which serve as the model's context
        x_full_feed = torch.cat([x_given, x_feed], dim=0).unsqueeze(1)
        y_full_feed = y_given.unsqueeze(1)
        if ensemble_log_dims == '01':
            x_full_feed = log01_batch(x_full_feed)
        elif ensemble_log_dims == 'global01' or ensemble_log_dims is True:
            x_full_feed = log01_batch(x_full_feed, input_between_zero_and_one=True)
        elif ensemble_log_dims == '01-10':
            x_full_feed = torch.cat((log01_batch(x_full_feed)[:, :-1], log01_batch(1. - x_full_feed)), 1)
        elif ensemble_log_dims == 'norm':
            x_full_feed = lognormed_batch(x_full_feed, len(x_given))
        elif ensemble_log_dims is not False:
            raise NotImplementedError

        if ensemble_feature_rotation:
            x_full_feed = torch.cat([x_full_feed[:, :, (i+torch.arange(x_full_feed.shape[2])) % x_full_feed.shape[2]] for i in range(x_full_feed.shape[2])], dim=1)

        if ensemble_input_rank_transform == 'train' or ensemble_input_rank_transform is True:
            x_full_feed = torch.cat([rank_transform(x_given, x_full_feed[:,i,:])[:,None] for i in range(x_full_feed.shape[1])] + [x_full_feed], dim=1)

        if ensemble_power_transform:
            assert apply_power_transform is False
            y_full_feed = torch.cat((general_power_transform(y_full_feed, y_full_feed, power_transform_eps), y_full_feed), dim=1)


        if style is not None:
            if callable(style):
                style = style()

            if isinstance(style, torch.Tensor):
                style = style.to(x_full_feed.device)
            else:
                style = torch.tensor(style, device=x_full_feed.device).view(1, 1).repeat(x_full_feed.shape[1], 1)


        # pair every x ensemble member with every y ensemble member
        logits = model(
            (style,
             x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1]),
             y_full_feed.repeat(1,x_full_feed.shape[1])),
            single_eval_pos=len(x_given)
        )
        if ensemble_type == 'mean_probs':
            logits = logits.softmax(-1).mean(1, keepdim=True).log_()  # (num given + num eval, 1, num buckets)

        logits_list.append(logits)  # (< max_dataset_size, 1 , num_buckets)
    logits = torch.cat(logits_list, dim=0) # (num given + num eval, 1 or (num_features+1), num buckets)
    del logits_list, x_full_feed
    if torch.isnan(logits).any():
        print('nan logits')
        print(f"y_given: {y_given}, x_given: {x_given}, x_eval: {x_eval}")
        print(f"logits: {logits}")
        return torch.zeros_like(x_eval[:,0])

    logits_given = logits[:len(x_given)]
    logits_eval = logits[len(x_given):]

    # threshold that the acquisition functions try to improve upon
    if predicted_mean_fbest:
        tau = criterion.mean(logits_given)[torch.argmax(y_given)].squeeze(0)
    else:
        tau = torch.max(y_given)

    def acq_ensembling(acq_values): # (points, ensemble dim)
        return acq_values.max(1).values

    if isinstance(acq_function, (dict,list)):
        acq_function = acq_function[style]

    if acq_function == 'ei':
        acq_value = acq_ensembling(criterion.ei(logits_eval, tau))
    elif acq_function == 'ei_or_rand':
        if torch.rand(1).item() < 0.5:
            acq_value = torch.rand(len(x_eval))
        else:
            acq_value = acq_ensembling(criterion.ei(logits_eval, tau))
    elif acq_function == 'pi':
        acq_value = acq_ensembling(criterion.pi(logits_eval, tau))
    elif acq_function == 'ucb':
        acq_function = criterion.ucb
        if ucb_rest_prob is not None:
            acq_function = lambda *args: criterion.ucb(*args, rest_prob=ucb_rest_prob)
        acq_value = acq_ensembling(acq_function(logits_eval, tau))
    elif acq_function == 'mean':
        acq_value = acq_ensembling(criterion.mean(logits_eval))
    elif acq_function.startswith('hebo'):
        # 'hebo_<noise>_<upsi>_<delta>_<eps>'
        noise, upsi, delta, eps = (float(v) for v in acq_function.split('_')[1:])
        noise = y_given_std * math.sqrt(2 * noise)
        kappa = math.sqrt(upsi * 2 * ((2.0 + x_given.shape[1] / 2.0) * math.log(max(1, len(x_given))) + math.log(
            3 * math.pi ** 2 / (3 * delta))))
        rest_prob = 1. - .5 * (1 + torch.erf(torch.tensor(kappa / math.sqrt(2), device=logits.device)))
        ucb = acq_ensembling(criterion.ucb(logits_eval, None, rest_prob=rest_prob)) \
              + torch.randn(len(logits_eval), device=logits_eval.device) * noise
        noisy_best_f = tau + eps + \
                       noise * torch.randn(len(logits_eval), device=logits_eval.device)[:, None].repeat(1, logits_eval.shape[1])

        log_pi = acq_ensembling(criterion.pi(logits_eval, noisy_best_f).log())
        log_ei = acq_ensembling(criterion.ei(logits_eval, noisy_best_f).log())

        acq_values = torch.stack([ucb, log_ei, log_pi], dim=1)

        def is_pareto_efficient(costs):
            """
            Find the pareto-efficient points
            :param costs: An (n_points, n_costs) array
            :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient
            """
            is_efficient = torch.ones(costs.shape[0], dtype=bool, device=costs.device)
            for i, c in enumerate(costs):
                if is_efficient[i]:
                    is_efficient[is_efficient.clone()] = (costs[is_efficient] < c).any(
                        1)  # Keep any point with a lower cost
                    is_efficient[i] = True  # And keep self
            return is_efficient

        # the acquisition values are maximised, so they are negated to act as costs
        acq_value = is_pareto_efficient(-acq_values)
    else:
        raise ValueError(f'Unknown acquisition function: {acq_function}')

    max_acq = acq_value.max()

    return acq_value if return_actual_ei else (acq_value == max_acq)


def optimize_acq(model, known_x, known_y, num_grad_steps=10, num_random_samples=100, lr=.01, **kwargs):
    """
    Optimise the acquisition function with Adam, starting from uniformly random candidates.

    intervals are assumed to be between 0 and 1
    only works with ei
    recommended extra kwarg: ensemble_input_rank_transform=='train'

    :param model: model to optimize, should already handle different num_features with its encoder
    You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)`
    :param known_x: (N, num_features)
    :param known_y: (N,)
    :param num_grad_steps: int
    :param num_random_samples: int
    :param lr: float
    :param kwargs: will be given to `general_acq_function`
    :return: the candidate (num_features,) with the highest acquisition value seen at any step
    """
    # create the candidates on the device of the observations so that they can be concatenated
    x_eval = torch.rand(num_random_samples, known_x.shape[1], device=known_x.device).requires_grad_(True)
    opt = torch.optim.Adam(params=[x_eval], lr=lr)
    # x_eval[i] is a view into x_eval, which the optimiser and clamp_ update in place;
    # clone so that the remembered candidate keeps the value it was scored with.
    best_acq, best_x = -float('inf'), x_eval[0].detach().clone()
    for grad_step in range(num_grad_steps):
        acq = general_acq_function(model, known_x, known_y, x_eval, return_actual_ei=True, **kwargs)
        max_acq = acq.detach().max().item()
        if max_acq > best_acq:
            best_x = x_eval[acq.argmax()].detach().clone()
            best_acq = max_acq

        (-acq.mean()).backward()
        assert (x_eval.grad != 0.).any()
        # skip the update when the gradient contains nan/inf
        if torch.isfinite(x_eval.grad).all():
            opt.step()
        opt.zero_grad()
        with torch.no_grad():
            x_eval.clamp_(min=0., max=1.)

    return best_x


def optimize_acq_w_lbfgs(model, known_x, known_y, num_grad_steps=15_000, num_candidates=100, pre_sample_size=100_000, device='cpu',
                         verbose=False, dims_wo_gradient_opt=None, rand_sample_func=None, **kwargs):
    """
    Random search followed by L-BFGS-B refinement of the best candidates.

    intervals are assumed to be between 0 and 1
    only works with deterministic acq
    recommended extra kwarg: ensemble_input_rank_transform=='train'

    :param model: model to optimize, should already handle different num_features with its encoder
    You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)`
    :param known_x: (N, num_features)
    :param known_y: (N,)
    :param num_grad_steps: int: how many steps to take inside of scipy, can be left high, as it stops most of the time automatically early
    :param num_candidates: int: how many candidates to optimize with LBFGS, increases costs when higher
    :param pre_sample_size: int: how many settings to try first with a random search, before optimizing the best with grads
    :param device: device the model, the observations and the candidates are moved to
    :param verbose: if True, plot the candidates with matplotlib when there are exactly two features
    :param dims_wo_gradient_opt: iterable of int or None: which dimensions to not optimize with gradients, but with
    random search only; None means that all dimensions are optimized with gradients
    :param rand_sample_func: function: how to sample random points, should be a function that takes a number of samples and returns a tensor
    For example `lambda n: torch.rand(n, known_x.shape[1])`.
    :param kwargs: will be given to `general_acq_function`
    :return: tuple ``(best_x, refined_x, refined_acq, initial_x, initial_acq)``; the refined candidates and the
    initial random-search candidates are each sorted by descending acquisition value
    """
    if dims_wo_gradient_opt is None:
        dims_wo_gradient_opt = []
    num_features = known_x.shape[1]
    dims_w_gradient_opt = sorted(set(range(num_features)) - set(dims_wo_gradient_opt))
    known_x = known_x.to(device)
    known_y = known_y.to(device)
    pre_sample_size = max(pre_sample_size, num_candidates)
    rand_sample_func = rand_sample_func or (lambda n: torch.rand(n, num_features, device=device))
    # the observed points take part in the random search whenever there is room for them
    if len(known_x) < pre_sample_size:
        x_initial = torch.cat((rand_sample_func(pre_sample_size-len(known_x)).to(device), known_x), 0)
    else:
        x_initial = rand_sample_func(pre_sample_size)
    x_initial = x_initial.clamp(min=0., max=1.)
    x_initial_all = x_initial
    model.to(device)

    with torch.no_grad():
        acq = general_acq_function(model, known_x, known_y, x_initial.to(device), return_actual_ei=True, **kwargs)
        if verbose:
            import matplotlib.pyplot as plt
            if x_initial.shape[1] == 2:
                plt.title('initial acq values, red -> blue')
                plt.scatter(x_initial[:, 0][:100], x_initial[:, 1][:100], c=acq.cpu().numpy()[:100], cmap='RdBu')
        x_initial = x_initial[acq.argsort(descending=True)[:num_candidates].cpu()].detach() # num_candidates x num_features

    x_initial_all_ei = acq.cpu().detach()

    def opt_f(x):
        """Objective for scipy: returns the negative mean acquisition value and its gradient w.r.t. ``x``."""
        x_eval = torch.tensor(x).view(-1, len(dims_w_gradient_opt)).float().to(device).requires_grad_(True)
        x_eval_new = x_initial.clone().detach().to(device)
        # only the gradient-optimized columns come from scipy, the others stay at their sampled values
        x_eval_new[:, dims_w_gradient_opt] = x_eval

        assert x_eval_new.requires_grad
        assert not torch.isnan(x_eval_new).any()
        model.requires_grad_(False)
        acq = general_acq_function(model, known_x, known_y, x_eval_new, return_actual_ei=True, **kwargs)
        neg_mean_acq = -acq.mean()
        neg_mean_acq.backward()
        with torch.no_grad():
            # nan != nan, so this zeroes the nan entries of the gradient
            x_eval.grad[x_eval.grad != x_eval.grad] = 0.
        return neg_mean_acq.detach().cpu().to(torch.float64).numpy(), \
            x_eval.grad.detach().view(*x.shape).cpu().to(torch.float64).numpy()

    # Optimize best candidates with LBFGS
    if num_grad_steps > 0 and len(dims_w_gradient_opt) > 0:
        # the columns not in dims_wo_gradient_opt will be optimized with gradients
        x_initial_for_gradient_opt = x_initial[:, dims_w_gradient_opt].detach().cpu().flatten().numpy()
        res = scipy.optimize.minimize(opt_f, x_initial_for_gradient_opt, method='L-BFGS-B', jac=True,
                                      bounds=[(0, 1)]*x_initial_for_gradient_opt.size,
                                      options={'maxiter': num_grad_steps})
        # clone: .cpu() returns the same tensor when it already lives on the cpu, and the
        # in-place write below must not alter x_initial, which opt_f closes over
        results = x_initial.cpu().clone()
        results[:, dims_w_gradient_opt] = torch.tensor(res.x).float().view(-1, len(dims_w_gradient_opt))

    else:
        results = x_initial.cpu()

    results = results.clamp(min=0., max=1.)

    # Recalculate the acq values for the best candidates
    with torch.no_grad():
        acq = general_acq_function(model, known_x, known_y, results.to(device), return_actual_ei=True, verbose=verbose, **kwargs)
        if verbose:
            from scipy.stats import rankdata
            import matplotlib.pyplot as plt
            if results.shape[1] == 2:
                plt.scatter(results[:, 0], results[:, 1], c=rankdata(acq.cpu().numpy()), marker='x', cmap='RdBu')
                plt.show()
        best_x = results[acq.argmax().item()].detach()

    acq_order = acq.argsort(descending=True).cpu()
    all_order = x_initial_all_ei.argsort(descending=True).cpu()

    return best_x.detach(), results[acq_order].detach(), acq.cpu()[acq_order].detach(), x_initial_all.cpu()[all_order].detach(), x_initial_all_ei.cpu()[all_order].detach()

from ..utils import to_tensor

class TransformerBOMethod:
    """Bayesian-optimisation suggester that scores pending candidates with an acquisition function.

    :param model: model handed to the acquisition function; moved to ``device`` on every call.
    :param acq_f: callable invoked as ``acq_f(model, X_obs, y_obs, X_pen, return_actual_ei=..., **kwargs)``.
    :param device: device for the model and the data; a string, a ``torch.device`` or anything else
        accepted by ``torch.device``.
    :param fit_encoder: optional callable ``fit_encoder(model, X_obs, y_obs)`` returning a function
        that is applied to both ``X_obs`` and ``X_pen`` before scoring.
    :param kwargs: forwarded to ``acq_f`` on every call.
    """

    def __init__(self, model, acq_f=general_acq_function, device='cpu:0', fit_encoder=None, **kwargs):
        self.model = model
        self.device = device
        self.kwargs = kwargs
        self.acq_function = acq_f
        self.fit_encoder = fit_encoder

    @torch.no_grad()
    def observe_and_suggest(self, X_obs, y_obs, X_pen, return_actual_ei=False):
        """Pick the index of the next candidate in ``X_pen`` to evaluate.

        :param X_obs: observed inputs, shape (n_samples, n_features)
        :param y_obs: observed targets, shape (n_samples,)
        :param X_pen: pending candidates, shape (n_samples_left, n_features)
        :param return_actual_ei: forwarded to the acquisition function; if True the
            acquisition values are returned as well.
        :return: the chosen index into ``X_pen`` (int), or ``(index, acq_values)`` if
            ``return_actual_ei`` is set. Ties between maximisers are broken at random.
        """
        X_obs = to_tensor(X_obs, device=self.device).to(torch.float32)
        y_obs = to_tensor(y_obs, device=self.device).to(torch.float32).view(-1)
        X_pen = to_tensor(X_pen, device=self.device).to(torch.float32)

        assert len(X_obs) == len(y_obs), "make sure both X_obs and y_obs have the same length."

        self.model.to(self.device)

        if self.fit_encoder is not None:
            w = self.fit_encoder(self.model, X_obs, y_obs)
            X_obs = w(X_obs)
            X_pen = w(X_pen)

        # Normalise through torch.device so that torch.device objects and device indices work
        # as well as strings; slicing self.device only works for strings.
        use_autocast = torch.device(self.device).type != 'cpu'
        with (torch.cuda.amp.autocast() if use_autocast else contextlib.nullcontext()):
            acq_values = self.acq_function(self.model, X_obs, y_obs,
                                           X_pen, return_actual_ei=return_actual_ei, **self.kwargs).cpu().clone()  # bool array
            acq_mask = acq_values.max() == acq_values
        possible_next = torch.arange(len(X_pen))[acq_mask]
        # the mask is empty when the maximum is nan; fall back to choosing among all candidates
        if len(possible_next) == 0:
            possible_next = torch.arange(len(X_pen))

        r = possible_next[torch.randperm(len(possible_next))[0]].cpu().item()


        if return_actual_ei:
            return r, acq_values
        else:
            return r







