import random
import warnings

import botorch
import gpytorch

import numpy as np
import torch
from gpytorch.means import ZeroMean
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior, UniformPrior
from torch import nn
from botorch.models.transforms.input import *
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from gpytorch.constraints import GreaterThan

from ..utils import default_device, to_tensor

from . import utils
from .prior import Batch
from .utils import get_batch_to_dataloader


class Warp(gpytorch.Module):
    r"""A transform that uses learned input warping functions.

    Each specified input dimension is warped using the CDF of a
    Kumaraswamy distribution. Typically, MAP estimates of the
    parameters of the Kumaraswamy distribution, for each input
    dimension, are learned jointly with the GP hyperparameters.

    This class derives directly from `gpytorch.Module` and implements the
    input-transform interface itself (`transform`, `untransform`, `equals`,
    `preprocess_transform` and `forward`).

    TODO: support independent warping functions
    for each output in batched multi-output and multi-task models.

    For now, ModelListGPs should be used to learn independent warping
    functions for each output.
    """

    # Lower bound enforced on both concentration parameters.
    _min_concentration_level = 1e-4

    def __init__(
        self,
        indices: List[int],
        transform_on_train: bool = True,
        transform_on_eval: bool = True,
        transform_on_fantasize: bool = True,
        reverse: bool = False,
        eps: float = 1e-7,
        concentration1_prior: Optional[Prior] = None,
        concentration0_prior: Optional[Prior] = None,
        batch_shape: Optional[torch.Size] = None,
    ) -> None:
        r"""Initialize transform.

        Args:
            indices: The indices of the inputs to warp.
            transform_on_train: A boolean indicating whether to apply the
                transforms in train() mode. Default: True.
            transform_on_eval: A boolean indicating whether to apply the
                transform in eval() mode. Default: True.
            transform_on_fantasize: A boolean indicating whether to apply the
                transform when called from within a `fantasize` call. Default: True.
            reverse: A boolean indicating whether the forward pass should untransform
                the inputs.
            eps: A small value used to clip values to be in the interval (0, 1).
            concentration1_prior: A prior distribution on the concentration1 parameter
                of the Kumaraswamy distribution.
            concentration0_prior: A prior distribution on the concentration0 parameter
                of the Kumaraswamy distribution.
            batch_shape: The batch shape.
        """
        super().__init__()
        self.register_buffer("indices", torch.tensor(indices, dtype=torch.long))
        self.transform_on_train = transform_on_train
        self.transform_on_eval = transform_on_eval
        self.transform_on_fantasize = transform_on_fantasize
        self.reverse = reverse
        self.batch_shape = batch_shape or torch.Size([])
        # Inputs are affinely mapped into [eps, 1 - eps] before warping.
        self._X_min = eps
        self._X_range = 1 - 2 * eps
        if len(self.batch_shape) > 0:
            # Note: this follows the gpytorch shape convention for lengthscales
            # There is ongoing discussion about the extra `1`.
            # https://github.com/cornellius-gp/gpytorch/issues/1317
            batch_shape = self.batch_shape + torch.Size([1])
        else:
            batch_shape = self.batch_shape
        # Both concentrations start at 1.0, one value per warped index.
        for i in (0, 1):
            p_name = f"concentration{i}"
            self.register_parameter(
                p_name,
                nn.Parameter(torch.full(batch_shape + self.indices.shape, 1.0)),
            )
        if concentration0_prior is not None:

            def _get_concentration0(m):
                return m.concentration0

            self.register_prior(
                "concentration0_prior",
                concentration0_prior,
                _get_concentration0,
                lambda m, v: m._set_concentration(i=0, value=v),
            )
        if concentration1_prior is not None:

            def _get_concentration1(m):
                return m.concentration1

            self.register_prior(
                "concentration1_prior",
                concentration1_prior,
                _get_concentration1,
                lambda m, v: m._set_concentration(i=1, value=v),
            )
        # Constraints are registered after the parameters they refer to exist.
        for i in (0, 1):
            p_name = f"concentration{i}"
            constraint = GreaterThan(
                self._min_concentration_level,
                transform=None,
                # set the initial value to be the identity transformation
                initial_value=1.0,
            )
            self.register_constraint(param_name=p_name, constraint=constraint)

    def _set_concentration(self, i: int, value: Union[float, Tensor]) -> None:
        r"""Set the `concentration{i}` parameter to `value`.

        Args:
            i: Which concentration parameter to set (0 or 1).
            value: The new value. Non-tensor values are converted to a tensor
                matching `self.concentration0` in dtype and device.
        """
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.concentration0)
        self.initialize(**{f"concentration{i}": value})

    def _transform(self, X: Tensor) -> Tensor:
        r"""Warp the inputs through the Kumaraswamy CDF.

        Args:
            X: A `input_batch_shape x (batch_shape) x n x d`-dim tensor of inputs.
                batch_shape here can either be self.batch_shape or 1's such that
                it is broadcastable with self.batch_shape if self.batch_shape is set.

        Returns:
            A `input_batch_shape x (batch_shape) x n x d`-dim tensor of transformed
                inputs.
        """
        X_tf = expand_and_copy_tensor(X=X, batch_shape=self.batch_shape)
        k = Kumaraswamy(
            concentration1=self.concentration1, concentration0=self.concentration0
        )
        # normalize to [eps, 1-eps]
        X_tf[..., self.indices] = k.cdf(
            torch.clamp(
                X_tf[..., self.indices] * self._X_range + self._X_min,
                self._X_min,
                1.0 - self._X_min,
            )
        )
        return X_tf

    def _untransform(self, X: Tensor) -> Tensor:
        r"""Warp the inputs through the Kumaraswamy inverse CDF.

        Args:
            X: A `input_batch_shape x batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `input_batch_shape x batch_shape x n x d`-dim tensor of transformed
                inputs.

        Raises:
            BotorchTensorDimensionError: If `self.batch_shape` is non-empty and
                does not match the right-most batch dimensions of `X`.
        """
        if len(self.batch_shape) > 0:
            if self.batch_shape != X.shape[-2 - len(self.batch_shape) : -2]:
                raise BotorchTensorDimensionError(
                    "The right most batch dims of X must match self.batch_shape: "
                    f"({self.batch_shape})."
                )
        X_tf = X.clone()
        k = Kumaraswamy(
            concentration1=self.concentration1, concentration0=self.concentration0
        )
        # unnormalize from [eps, 1-eps] to [0,1]
        X_tf[..., self.indices] = (
            (k.icdf(X_tf[..., self.indices]) - self._X_min) / self._X_range
        ).clamp(0.0, 1.0)
        return X_tf

    def transform(self, X: Tensor) -> Tensor:
        r"""Transform the inputs.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of transformed inputs.
        """
        return self._untransform(X) if self.reverse else self._transform(X)

    def untransform(self, X: Tensor) -> Tensor:
        r"""Un-transform the inputs.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-transformed inputs.
        """
        return self._transform(X) if self.reverse else self._untransform(X)

    def equals(self, other: InputTransform) -> bool:
        r"""Check if another input transform is equivalent.

        Two transforms are considered equivalent when they have the same type,
        the same `transform_on_*` flags, the same `reverse` direction and
        element-wise close state-dict entries.

        Note: The reason that a custom equals method is defined rather than
        defining an __eq__ method is because defining an __eq__ method sets
        the __hash__ method to None. Hashing modules is currently used in
        pytorch. See https://github.com/pytorch/pytorch/issues/7733.

        Args:
            other: Another input transform.

        Returns:
            A boolean indicating if the other transform is equivalent.
        """
        # Check the type first so that objects without a state dict or without
        # the flags below are rejected instead of raising.
        if type(self) is not type(other):
            return False
        # `reverse` swaps which direction `transform` applies, so it is part of
        # the transform's identity alongside the `transform_on_*` flags.
        if (
            self.transform_on_train != other.transform_on_train
            or self.transform_on_eval != other.transform_on_eval
            or self.transform_on_fantasize != other.transform_on_fantasize
            or self.reverse != other.reverse
        ):
            return False
        other_state_dict = other.state_dict()
        return all(
            torch.allclose(v, other_state_dict[k].to(v))
            for k, v in self.state_dict().items()
        )

    def preprocess_transform(self, X: Tensor) -> Tensor:
        r"""Apply transforms for preprocessing inputs.

        The main use cases for this method are 1) to preprocess training data
        before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
        acquisition functions so that `X_baseline` is "preprocessed" with the
        same transformations as the cached training inputs.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of (transformed) inputs.
        """
        if not self.transform_on_train:
            return X
        if not hasattr(self, "learn_bounds"):
            return self.transform(X)
        # We need to disable learning of bounds here.
        # See why: https://github.com/pytorch/botorch/issues/1078.
        learn_bounds = self.learn_bounds
        self.learn_bounds = False
        try:
            return self.transform(X)
        finally:
            # Restore the flag even if the transform raises.
            self.learn_bounds = learn_bounds

    def forward(self, X: Tensor) -> Tensor:
        r"""Transform the inputs to a model.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n' x d`-dim tensor of transformed inputs.
        """
        if self.training:
            if self.transform_on_train:
                return self.transform(X)
        elif self.transform_on_eval:
            if fantasize.off() or self.transform_on_fantasize:
                return self.transform(X)
        return X


def constraint_based_on_distribution_support(
    prior: torch.distributions.Distribution, device, sample_from_path
):
    """Build a gpytorch constraint that mirrors the support of ``prior``.

    :param prior: distribution whose ``support`` supplies the bounds
    :param device: device passed to ``to_tensor`` for the bound tensors
    :param sample_from_path: if truthy, no constraint is built
    :return: ``None`` when ``sample_from_path`` is truthy; otherwise an
        ``Interval`` when the support has an ``upper_bound`` attribute and a
        ``GreaterThan`` on the lower bound when it does not
    """
    if sample_from_path:
        return None

    support = prior.support
    lower = to_tensor(support.lower_bound, device=device)
    if not hasattr(support, "upper_bound"):
        # Support is only bounded from below.
        return gpytorch.constraints.GreaterThan(lower)

    upper = to_tensor(support.upper_bound, device=device)
    return gpytorch.constraints.Interval(lower, upper)


# Process-wide cache used by `torch_load`: maps a path to the object loaded from it.
loaded_things = {}


def torch_load(path):
    """
    Memoized wrapper around ``torch.load``.

    Caution: the cached object itself is handed out, not a copy. Mutating the
    return value therefore also changes what later calls with the same path
    return.
    :param path: file to load; also used as the cache key
    :return: the object stored at ``path``
    """
    try:
        return loaded_things[path]
    except KeyError:
        pass

    print(f"loading {path}")
    with open(path, "rb") as handle:
        loaded = torch.load(handle)
    loaded_things[path] = loaded
    return loaded


def get_model(x, y, hyperparameters: dict, sample=True):
    """Build a `botorch.models.SingleTaskGP` configured from `hyperparameters`.

    The model uses a Gaussian likelihood with a log-normal noise prior, a scaled
    Matern-3/2 ARD kernel, optionally an added scaled linear kernel, optionally a
    `Warp` input transform, and a zero mean.

    Hyperparameter keys read here: `sample_from_extra_prior`,
    `hebo_noise_logmean`, `hebo_noise_std`, `lengthscale_concentration`,
    `lengthscale_rate`, `outputscale_concentration`, `outputscale_rate`,
    `add_linear_kernel`, `hebo_warping`, `hebo_input_warping_c1_std`,
    `hebo_input_warping_c0_std`, `sigmoid`, `y_minmax_norm`.

    :param x: input tensor; its last dimension is the number of features and its
        device is used for all priors and the model
    :param y: target tensor; a trailing dimension is appended when it has fewer
        dimensions than `x`
    :param hyperparameters: configuration dict, see the keys above
    :param sample: if True, return the result of `model.pyro_sample_from_prior()`
        (optionally overwritten with values loaded from `sample_from_extra_prior`);
        if False, return the model as constructed
    :return: `(model, model.likelihood)` when `sample` is True, otherwise
        `(model, likelihood)`
    """
    sample_from_path = hyperparameters.get("sample_from_extra_prior", None)
    device = x.device
    num_features = x.shape[-1]
    likelihood = gpytorch.likelihoods.GaussianLikelihood(
        noise_constraint=gpytorch.constraints.Positive()
    )
    likelihood.register_prior(
        "noise_prior",
        LogNormalPrior(
            torch.tensor(
                hyperparameters.get("hebo_noise_logmean", -4.63), device=device
            ),
            torch.tensor(hyperparameters.get("hebo_noise_std", 0.5), device=device),
        ),
        "noise",
    )
    # Gamma prior when `lengthscale_concentration` is set to a truthy value
    # (`lengthscale_rate` is then required), otherwise Uniform(0, 1).
    lengthscale_prior = (
        GammaPrior(
            torch.tensor(hyperparameters["lengthscale_concentration"], device=device),
            torch.tensor(hyperparameters["lengthscale_rate"], device=device),
        )
        if hyperparameters.get("lengthscale_concentration", None)
        else UniformPrior(
            torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
        )
    )
    covar_module = gpytorch.kernels.MaternKernel(
        nu=3 / 2,
        ard_num_dims=num_features,
        lengthscale_prior=lengthscale_prior,
        lengthscale_constraint=constraint_based_on_distribution_support(
            lengthscale_prior, device, sample_from_path
        ),
    )
    # ORIG DIFF: orig lengthscale has no prior
    # covar_module.register_prior("lengthscale_prior",
    # UniformPrior(.000000001, 1.),
    # GammaPrior(concentration=hyperparameters.get('lengthscale_concentration', 1.),
    #           rate=hyperparameters.get('lengthscale_rate', .1)),
    # skewness is controlled by concentration only, want something like concentration in [0.1,1.], rate around [.05,1] seems reasonable
    # "lengthscale")
    outputscale_prior = GammaPrior(
        concentration=hyperparameters.get("outputscale_concentration", 0.5),
        rate=hyperparameters.get("outputscale_rate", 1.0),
    )
    covar_module = gpytorch.kernels.ScaleKernel(
        covar_module,
        outputscale_prior=outputscale_prior,
        outputscale_constraint=constraint_based_on_distribution_support(
            outputscale_prior, device, sample_from_path
        ),
    )

    # `add_linear_kernel` acts as a probability: the default True becomes 1.0, so
    # the linear kernel is always added unless a smaller value is configured.
    if random.random() < float(hyperparameters.get("add_linear_kernel", True)):
        # ORIG DIFF: added priors for variance and outputscale of linear kernel
        var_prior = UniformPrior(
            torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
        )
        out_prior = UniformPrior(
            torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
        )
        lincovar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.LinearKernel(
                variance_prior=var_prior,
                variance_constraint=constraint_based_on_distribution_support(
                    var_prior, device, sample_from_path
                ),
            ),
            outputscale_prior=out_prior,
            outputscale_constraint=constraint_based_on_distribution_support(
                out_prior, device, sample_from_path
            ),
        )
        covar_module = covar_module + lincovar_module

    if hyperparameters.get("hebo_warping", False):
        # initialize input_warping transformation
        warp_tf = Warp(
            indices=list(range(num_features)),
            # use a prior with median at 1.
            # when a=1 and b=1, the Kumaraswamy CDF is the identity function
            concentration1_prior=LogNormalPrior(
                torch.tensor(0.0, device=device),
                torch.tensor(
                    hyperparameters.get("hebo_input_warping_c1_std", 0.75),
                    device=device,
                ),
            ),
            concentration0_prior=LogNormalPrior(
                torch.tensor(0.0, device=device),
                torch.tensor(
                    hyperparameters.get("hebo_input_warping_c0_std", 0.75),
                    device=device,
                ),
            ),
        )
    else:
        warp_tf = None
    # assume mean 0 always!
    if len(y.shape) < len(x.shape):
        y = y.unsqueeze(-1)
    model = botorch.models.SingleTaskGP(
        x, y, likelihood=likelihood, covar_module=covar_module, input_transform=warp_tf
    )
    # Replace the model's mean module with a zero mean over the batch dims of x.
    model.mean_module = ZeroMean(x.shape[:-2])
    model.to(device)
    likelihood.to(device)

    if sample:
        # NOTE(review): `model` is rebound to whatever `pyro_sample_from_prior`
        # returns, while the local `likelihood` still refers to the object
        # created above. If the returned model is a copy, names starting with
        # `likelihood.` in the loop below would not touch the returned
        # likelihood - TODO confirm.
        model = model.pyro_sample_from_prior()
        if sample_from_path:
            parameter_sample_distribution = torch_load(
                sample_from_path
            )  # dict with entries for each parameter
            # Index tensors are shared between parameters that have the same
            # number of stored values, so those parameters use the same draws.
            idx_for_len = {}
            for (
                parameter_name,
                parameter_values,
            ) in parameter_sample_distribution.items():
                assert len(parameter_values.shape) == 1
                try:
                    # SECURITY: `eval` executes the key as Python code in this
                    # function's scope; only load files from trusted sources.
                    # The key is resolved against local names here (presumably
                    # `model...` / `likelihood...` attribute paths - verify
                    # against the file contents), so these locals must not be
                    # renamed.
                    p = eval(parameter_name)
                    if len(parameter_values) in idx_for_len:
                        idx = idx_for_len[len(parameter_values)].view(p.shape)
                    else:
                        idx = torch.randint(len(parameter_values), p.shape)
                        idx_for_len[len(parameter_values)] = idx
                    new_sample = parameter_values[idx].to(device).view(p.shape)  # noqa
                    assert new_sample.shape == p.shape
                    with torch.no_grad():
                        p.data = new_sample
                except AttributeError:
                    # The attribute path does not exist on this model; skip it.
                    utils.print_once(
                        f"could not find parameter {parameter_name} in model for `sample_from_extra_prior`"
                    )
            model.requires_grad_(False)
            likelihood.requires_grad_(False)
        return model, model.likelihood
    else:
        assert not (hyperparameters.get("sigmoid", False)) and not (
            hyperparameters.get("y_minmax_norm", False)
        ), "Sigmoid and y_minmax_norm can only be used to sample models..."
        return model, likelihood


@torch.no_grad()
def get_batch(
    batch_size,
    seq_len,
    num_features,
    device=default_device,
    hyperparameters=None,
    batch_size_per_gp_sample=None,
    single_eval_pos=None,
    fix_to_range=None,
    equidistant_x=False,
    verbose=False,
    **kwargs,
):
    """
    This function is very similar to the equivalent in .fast_gp. The only difference is that this function operates over
    a mixture of GP priors.

    The batch is produced in groups of `batch_size_per_gp_sample` sequences; each group uses one model obtained
    from `get_model`. If `fix_to_range` is given, twice as many candidate sequences are drawn per group and only
    those whose noisy samples all lie in `[fix_to_range[0], fix_to_range[1])` are kept.

    :param batch_size: number of sequences in the returned batch
    :param seq_len: number of points per sequence
    :param num_features: number of input features per point
    :param device: device on which inputs and models are created
    :param hyperparameters: dict forwarded to `get_model`; the keys `fast_computations`,
        `unused_feature_likelihood` and `observation_noise` are also read here
    :param batch_size_per_gp_sample: number of sequences per model; defaults to `max(batch_size // 4, 1)` and
        must divide `batch_size`
    :param single_eval_pos: not used by this function
    :param fix_to_range: optional `(lower, upper)` pair that all kept samples must lie in
    :param equidistant_x: use an evenly spaced grid on [0, 1] as inputs (requires `num_features == 1`)
    :param verbose: print the parameters of every model that is created
    :param kwargs: ignored
    :return: a `Batch` with `x` of shape `(seq_len, batch_size, num_features)`, `y` of shape
        `(seq_len, batch_size, 1)`, and `target_y` equal to `y` or, if `observation_noise` is falsy, the
        samples without likelihood noise
    """
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(
        *hyperparameters.get("fast_computations", (True, True, True))
    ):
        batch_size_per_gp_sample = batch_size_per_gp_sample or max(batch_size // 4, 1)
        assert batch_size % batch_size_per_gp_sample == 0

        # When a target range is requested, draw twice as many candidates so that
        # enough of them are expected to survive the range filter.
        oversampling = 2 if fix_to_range is not None else 1
        total_num_candidates = batch_size * oversampling
        num_candidates = batch_size_per_gp_sample * oversampling
        unused_feature_likelihood = hyperparameters.get(
            "unused_feature_likelihood", False
        )
        if equidistant_x:
            assert num_features == 1
            assert not unused_feature_likelihood
            # Created on `device` so it matches the targets passed to `get_model`.
            x = (
                torch.linspace(0, 1.0, seq_len, device=device)
                .unsqueeze(0)
                .repeat(total_num_candidates, 1)
                .unsqueeze(-1)
            )
        else:
            x = torch.rand(total_num_candidates, seq_len, num_features, device=device)
        samples = []
        samples_wo_noise = []
        # Inputs that survived the range filter, one entry per group
        # (only filled when `fix_to_range` is set).
        kept_x = []
        # Accumulated over all groups; averaged over the number of groups below.
        throwaway_share = 0.0
        for i in range(0, total_num_candidates, num_candidates):
            local_x = x[i : i + num_candidates]
            if unused_feature_likelihood:
                # Randomly hide features from the model, but always keep at least one.
                r = torch.rand(num_features)
                unused_feature_mask = r < unused_feature_likelihood
                if unused_feature_mask.all():
                    unused_feature_mask[r.argmin()] = False
                used_local_x = local_x[..., ~unused_feature_mask]
            else:
                used_local_x = local_x

            def get_model_and_likelihood(inputs=used_local_x):
                # The targets are placeholders: all zeros, one per input point.
                return get_model(
                    inputs,
                    torch.zeros(num_candidates, x.shape[1], device=device),
                    hyperparameters,
                )

            model, likelihood = get_model_and_likelihood()
            if verbose:
                print(
                    list(model.named_parameters()),
                    (
                        (
                            list(model.input_transform.named_parameters()),
                            model.input_transform.concentration1,
                            model.input_transform.concentration0,
                        )
                        if model.input_transform is not None
                        else None
                    ),
                )

            # `successful_sample` counts failed attempts as negative numbers and
            # is set to True (which ends the loop) once a group was accepted.
            successful_sample = 0
            while successful_sample < 1:
                with gpytorch.settings.prior_mode(True):
                    d = model(used_local_x)
                    try:
                        with warnings.catch_warnings():
                            warnings.simplefilter("ignore")
                            sample_wo_noise = d.sample()
                            d = likelihood(sample_wo_noise)
                    except (RuntimeError, ValueError) as e:
                        # Sampling failed: retry with a freshly created model and
                        # give up after more than 100 failed attempts.
                        successful_sample -= 1
                        model, likelihood = get_model_and_likelihood()
                        if successful_sample < -100:
                            print(
                                f"Could not sample from model {i} after {successful_sample} attempts. {e}"
                            )
                            raise
                        continue
                    sample = d.sample()  # bs_per_gp_s x T
                    if fix_to_range is None:
                        samples.append(sample.transpose(0, 1))
                        samples_wo_noise.append(sample_wo_noise.transpose(0, 1))
                        break
                    smaller_mask = sample < fix_to_range[0]
                    larger_mask = sample >= fix_to_range[1]
                    # A candidate is kept only if all of its points are in range.
                    in_range_mask = ~(smaller_mask | larger_mask).any(1)
                    throwaway_share += (
                        ~in_range_mask[:batch_size_per_gp_sample]
                    ).sum() / batch_size_per_gp_sample
                    if in_range_mask.sum() < batch_size_per_gp_sample:
                        successful_sample -= 1
                        # Only warn after many failures, matching the threshold
                        # used for sampling errors above.
                        if successful_sample < -100:
                            print(
                                "Please change hyper-parameters (e.g. decrease outputscale_mean) it "
                                "seems like the range is set too tight for your hyper-parameters."
                            )
                        continue

                    # Apply the same selection to inputs, noisy samples and
                    # noise-free samples so that all three stay aligned.
                    kept_x.append(local_x[in_range_mask][:batch_size_per_gp_sample])
                    sample = sample[in_range_mask][:batch_size_per_gp_sample]
                    sample_wo_noise = sample_wo_noise[in_range_mask][
                        :batch_size_per_gp_sample
                    ]
                    samples.append(sample.transpose(0, 1))
                    samples_wo_noise.append(sample_wo_noise.transpose(0, 1))
                    successful_sample = True

        if random.random() < 0.01:
            print(
                "throwaway share",
                throwaway_share / (batch_size // batch_size_per_gp_sample),
            )

        sample = torch.cat(samples, 1)[..., None]
        sample_wo_noise = torch.cat(samples_wo_noise, 1)[..., None]
        if fix_to_range is not None:
            # Keep exactly the inputs whose samples were kept, in the same order.
            x = torch.cat(kept_x, 0)
        x = x.view(-1, batch_size, seq_len, num_features)[0]
        # TODO think about enabling the line below
        # sample = sample - sample[0, :].unsqueeze(0).expand(*sample.shape)
        x = x.transpose(0, 1)
        assert x.shape[:2] == sample.shape[:2]
    return Batch(
        x=x,
        y=sample,
        target_y=(
            sample
            if hyperparameters.get("observation_noise", True)
            else sample_wo_noise
        ),
    )


# Module-level data loader obtained by passing `get_batch` to
# `get_batch_to_dataloader` (imported from `.utils`).
DataLoader = get_batch_to_dataloader(get_batch)
