import copy
import math
from typing import Any, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

TensorLike = Union[torch.Tensor, Sequence[Sequence[float]], Sequence[float]]


def gamma_schedule(
    t: Union[int, torch.Tensor], gamma0: Union[float, torch.Tensor] = 1.0
) -> torch.Tensor:
    """Return the decayed weight ``gamma0 / sqrt(max(t, 1))`` as a float tensor.

    Args:
        t: Iteration counter. A Python ``int`` is turned into a 0-dim tensor;
            anything else is passed through ``torch.as_tensor``.
        gamma0: Initial weight (scalar or tensor), cast to float32.

    Returns:
        A float tensor; values of ``t`` below 1 are clamped to 1, so the
        result never exceeds ``gamma0`` in magnitude.
    """
    step = torch.tensor(float(t)) if isinstance(t, int) else torch.as_tensor(t).float()
    # Clamp so that t <= 0 does not divide by zero or take sqrt of a negative.
    step = step.clamp(min=1.0)
    scale = torch.as_tensor(gamma0).float()
    return scale / step.sqrt()


def _as_2d(x: torch.Tensor) -> Tuple[torch.Tensor, bool]:
    """Promote a 1-D tensor to a single-row 2-D tensor.

    Returns:
        ``(tensor, was_1d)``: the tensor with a leading axis added when the
        input had exactly one dimension (otherwise the input itself), and a
        flag telling whether that promotion happened.
    """
    was_1d = x.dim() == 1
    promoted = x.unsqueeze(0) if was_1d else x
    return promoted, was_1d


def _standard_normal_cdf(z: torch.Tensor) -> torch.Tensor:
    """Evaluate the standard normal CDF elementwise.

    Uses ``torch.special.ndtr`` when this torch build exposes it and falls
    back to the equivalent ``erf`` formula otherwise.
    """
    ndtr = getattr(torch.special, "ndtr", None)
    if ndtr is None:
        # Phi(z) = (1 + erf(z / sqrt(2))) / 2
        return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
    return ndtr(z)


def _binary_entropy(p: torch.Tensor) -> torch.Tensor:
    """Return ``p * (1 - p)`` elementwise.

    Despite the name this is the Bernoulli variance rather than the Shannon
    entropy: it is 0 at ``p`` in {0, 1} and peaks at 0.25 for ``p = 0.5``.
    """
    complement = 1 - p
    return p * complement


def acquisition_value(
    x: torch.Tensor,
    mixture_model: Any,
    anchors: Optional[torch.Tensor],
    beta_t: Union[float, torch.Tensor],
    gamma_t: Union[float, torch.Tensor],
) -> torch.Tensor:
    """Score candidate points with a mean + uncertainty + anchor-comparison bonus.

    All candidates and anchors are sent to ``mixture_model.predict`` in a
    single call, so the covariance it returns is taken as the joint
    covariance over ``[x; anchors]``.

    Args:
        x: Candidate points, ``(n_points, d)`` or a single 1-D point.
        mixture_model: Object exposing ``predict(X) -> (mean, std, cov)``;
            ``cov`` may be ``None``, in which case the candidate/anchor
            cross-covariance is treated as zero.
        anchors: Reference points, ``(n_anchors, d)``, a single 1-D point, or
            ``None``. Non-tensor inputs are converted with ``torch.as_tensor``.
        beta_t: Weight on the predictive standard deviation.
        gamma_t: Weight on the anchor-comparison term.

    Returns:
        One acquisition value per row of ``x``.

    Raises:
        TypeError: If ``x`` is not a tensor or the model has no ``predict``.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError("x must be a torch.Tensor")
    if not hasattr(mixture_model, "predict"):
        raise TypeError("mixture_model must implement predict(X) -> (mean, std, cov)")

    x2d, _ = _as_2d(x)
    device = x2d.device
    dtype = x2d.dtype

    beta_t = torch.as_tensor(beta_t, device=device, dtype=dtype)
    gamma_t = torch.as_tensor(gamma_t, device=device, dtype=dtype)

    if anchors is None:
        anchors2d = x2d.new_zeros((0, x2d.size(-1)))
    else:
        if not isinstance(anchors, torch.Tensor):
            anchors = torch.as_tensor(anchors, dtype=dtype)
        anchors2d, _ = _as_2d(anchors.to(device=device, dtype=dtype))

    n_points = x2d.size(0)
    n_anchors = anchors2d.size(0)

    if n_anchors == 0:
        # No anchors: plain upper-confidence-bound style score.
        # NOTE(review): this branch scales std by sqrt(beta_t) whereas the
        # anchored branch below uses beta_t directly -- confirm which is intended.
        mean_x, std_x, _ = mixture_model.predict(x2d)
        mean_x = mean_x.reshape(-1)
        std_x = std_x.reshape(-1)
        return mean_x + torch.sqrt(beta_t) * std_x

    X_all = torch.cat([x2d, anchors2d], dim=0)
    mean_all, std_all, cov_all = mixture_model.predict(X_all)
    # reshape (not view) so non-contiguous model outputs are accepted as well.
    mean_all = mean_all.reshape(-1)
    std_all = std_all.reshape(-1)

    mean_x = mean_all[:n_points]
    mean_a = mean_all[n_points:]
    std_x = std_all[:n_points]
    var_x = std_x ** 2
    var_a = std_all[n_points:] ** 2

    if cov_all is None:
        cov_xa = x2d.new_zeros((n_points, n_anchors))
    else:
        n_total = n_points + n_anchors
        cov_all = cov_all.reshape(n_total, n_total)
        # Upper-right block: covariance between each candidate and each anchor.
        cov_xa = cov_all[:n_points, n_points:]

    # Variance of f(x) - f(a) for every candidate/anchor pair; clamped because
    # numerical error can push it slightly below zero.
    sigma_diff2 = var_x.unsqueeze(1) + var_a.unsqueeze(0) - 2.0 * cov_xa
    sigma_diff = torch.sqrt(torch.clamp(sigma_diff2, min=1e-12))
    z = (mean_x.unsqueeze(1) - mean_a.unsqueeze(0)) / sigma_diff
    # p[i, j]: Gaussian probability that candidate i outscores anchor j.
    p = _standard_normal_cdf(z)
    ren = _binary_entropy(p).mean(dim=1)

    # |std| has the same value as sqrt(std ** 2) but a finite gradient at
    # std == 0, where the sqrt-of-square form backpropagates inf * 0 = NaN.
    return mean_x + beta_t * std_x.abs() + gamma_t * ren


def _acquisition_value_decoupled_restarts(
    x: torch.Tensor,
    mixture_model: Any,
    anchors: Optional[torch.Tensor],
    beta_t: Union[float, torch.Tensor],
    gamma_t: Union[float, torch.Tensor],
) -> torch.Tensor:
    """Acquisition value per restart, each restart compared only to the anchors.

    Every row of ``x`` is stacked with all anchors into its own block of
    ``1 + n_anchors`` points, and only the covariance inside each block is
    used, so restarts do not interact through the anchor-comparison term.
    The mean/std terms come from a separate prediction on ``x`` alone.

    ``mixture_model.predict`` is first called with a ``block_size`` keyword;
    if that raises ``TypeError`` it is called again without it.

    Returns:
        One acquisition value per row of ``x``.
    """

    def _predict(points: torch.Tensor, block_size: int):
        # Unpacking stays inside the try so the fallback triggers under
        # exactly the same conditions as a direct call would.
        try:
            mu, sd, cv = mixture_model.predict(points, block_size=block_size)
        except TypeError:
            mu, sd, cv = mixture_model.predict(points)
        return mu, sd, cv

    starts, _ = _as_2d(x)
    dev = starts.device
    dt = starts.dtype
    n_dim = starts.size(-1)

    beta = torch.as_tensor(beta_t, device=dev, dtype=dt)
    gamma = torch.as_tensor(gamma_t, device=dev, dtype=dt)

    if anchors is None:
        refs = starts.new_zeros((0, n_dim))
    else:
        anchor_t = (
            anchors
            if isinstance(anchors, torch.Tensor)
            else torch.as_tensor(anchors, dtype=dt)
        )
        refs, _ = _as_2d(anchor_t.to(device=dev, dtype=dt))

    n_restarts = starts.size(0)
    n_anchors = refs.size(0)

    if n_anchors == 0:
        # NOTE(review): sqrt(beta) here vs. plain beta in the anchored path
        # below -- confirm which scaling is intended.
        mu_only, sd_only, _ = mixture_model.predict(starts)
        mu_only = mu_only.view(-1)
        sd_only = sd_only.view(-1)
        return mu_only + torch.sqrt(beta) * sd_only

    # Per-restart mean and std, predicted on the candidates alone.
    mu_solo, sd_solo, _ = _predict(starts, 1)
    mu_solo = mu_solo.view(-1)
    sd_solo = sd_solo.view(-1)

    # Block r is [x_r, anchor_0, ..., anchor_{m-1}].
    block = 1 + n_anchors
    tiled_refs = refs.unsqueeze(0).expand(n_restarts, n_anchors, n_dim)
    stacked = torch.cat([starts.unsqueeze(1), tiled_refs], dim=1)
    flat_inputs = stacked.reshape(n_restarts * block, n_dim)

    mu_flat, sd_flat, cov_flat = _predict(flat_inputs, block)

    mu_flat = mu_flat.view(-1)
    sd_flat = sd_flat.view(-1)
    mu_blk = mu_flat.view(n_restarts, block)
    sd_blk = sd_flat.view(n_restarts, block)

    mu_joint = mu_blk[:, 0]
    mu_anchor = mu_blk[:, 1:]
    var_joint = sd_blk[:, 0].pow(2)
    var_anchor = sd_blk[:, 1:].pow(2)

    if cov_flat is None:
        cross = starts.new_zeros((n_restarts, n_anchors))
    elif cov_flat.dim() == 3:
        # Already one covariance matrix per block.
        cross = cov_flat.view(n_restarts, block, block)[:, 0, 1:]
    else:
        # Full covariance over all stacked points: keep the diagonal blocks.
        rows = torch.arange(n_restarts, device=dev)
        per_block = cov_flat.view(n_restarts, block, n_restarts, block)[
            rows, :, rows, :
        ]
        cross = per_block[:, 0, 1:]

    spread2 = var_joint.unsqueeze(1) + var_anchor - 2.0 * cross
    spread = torch.sqrt(torch.clamp(spread2, min=1e-12))
    prob = _standard_normal_cdf((mu_joint.unsqueeze(1) - mu_anchor) / spread)
    ren = _binary_entropy(prob).mean(dim=1)

    return mu_solo + beta * sd_solo + gamma * ren


def maximize_acquisition_x(
    mixture_model: Any,
    beta_t: Union[float, torch.Tensor],
    gamma_t: Union[float, torch.Tensor],
    step: int,
    lr: float,
    print_every: int,
    n_restarts: int,
    anchors: Optional[torch.Tensor] = None,
    dim: Optional[int] = None,
    bounds: Optional[Tuple[float, float]] = (0.0, 1.0),
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    seed: Optional[int] = None,
) -> torch.Tensor:
    """Maximise the acquisition function with multi-start Adam.

    ``n_restarts`` random starting points are optimised jointly (sum of their
    acquisition values) and the point with the highest final value is returned.

    Args:
        mixture_model: Object exposing ``predict``.
        beta_t: Weight on the predictive standard deviation.
        gamma_t: Weight on the anchor-comparison term.
        step: Number of Adam iterations (must be positive).
        lr: Adam learning rate.
        print_every: Progress is printed on the first step, every
            ``print_every`` steps and on the last step (must be positive).
        n_restarts: Number of random starting points (must be positive).
        anchors: Optional anchor points; also used to infer ``device`` and
            ``dim`` when those are not given.
        dim: Input dimensionality. When ``None`` it is inferred from
            ``anchors``, then ``mixture_model.w``, then ``mixture_model.d``.
        bounds: ``(low, high)`` box; starts are uniform inside it and iterates
            are clamped to it after every step. ``None`` means unconstrained
            with standard-normal starts.
        device: Device for the candidates; inferred from ``anchors`` or
            ``mixture_model.device`` when ``None``, else CPU.
        dtype: Dtype for the candidates.
        seed: If given, ``torch.manual_seed(seed)`` is called, which reseeds
            torch's global RNG.

    Returns:
        The best candidate, a detached tensor of shape ``(dim,)``.

    Raises:
        TypeError: If the model has no ``predict``.
        ValueError: On non-positive ``step``/``n_restarts``/``print_every`` or
            when ``dim`` cannot be inferred.
    """
    if not hasattr(mixture_model, "predict"):
        raise TypeError("mixture_model must implement predict(X) -> (mean, std, cov)")
    if step <= 0:
        raise ValueError("step must be positive")
    if n_restarts <= 0:
        raise ValueError("n_restarts must be positive")
    if print_every <= 0:
        raise ValueError("print_every must be positive")

    if device is None:
        if isinstance(anchors, torch.Tensor):
            device = anchors.device
        elif hasattr(mixture_model, "device"):
            device = mixture_model.device
        else:
            device = torch.device("cpu")

    if dim is None:
        if anchors is not None:
            if not isinstance(anchors, torch.Tensor):
                anchors = torch.as_tensor(anchors, device=device, dtype=dtype)
            dim = int(anchors.shape[-1])
        elif isinstance(getattr(mixture_model, "w", None), torch.Tensor):
            dim = int(mixture_model.w.numel())
        elif hasattr(mixture_model, "d"):
            dim = int(getattr(mixture_model, "d"))
        else:
            raise ValueError("dim is required when it cannot be inferred")

    if seed is not None:
        torch.manual_seed(seed)

    if bounds is None:
        x0 = torch.randn(n_restarts, dim, device=device, dtype=dtype)
    else:
        # Parsed once; reused for the projection inside the loop.
        low, high = float(bounds[0]), float(bounds[1])
        x0 = low + (high - low) * torch.rand(
            n_restarts, dim, device=device, dtype=dtype
        )

    x_param = nn.Parameter(x0)
    optimizer = torch.optim.Adam([x_param], lr=lr)

    for i in range(step):
        optimizer.zero_grad(set_to_none=True)
        acq = _acquisition_value_decoupled_restarts(
            x_param, mixture_model, anchors, beta_t, gamma_t
        )
        loss = -acq.sum()
        # Differentiate with respect to the candidates only. loss.backward()
        # would also accumulate gradients on every model parameter that
        # requires grad, polluting the surrogate's .grad on each step.
        (x_param.grad,) = torch.autograd.grad(loss, [x_param], allow_unused=True)
        optimizer.step()

        if bounds is not None:
            # Projected gradient step: pull iterates back into the box.
            with torch.no_grad():
                x_param.clamp_(min=low, max=high)

        if i == 0 or (i + 1) % print_every == 0 or (i + 1) == step:
            with torch.no_grad():
                # loss is the pre-step value; acq_now reflects the updated x.
                acq_now = _acquisition_value_decoupled_restarts(
                    x_param, mixture_model, anchors, beta_t, gamma_t
                )
                print(
                    f"[maximize_acquisition_x] step={i + 1}/{step} "
                    f"loss={loss.item():.6f} acq_mean={acq_now.mean().item():.6f} acq_best={acq_now.max().item():.6f}"
                )

    with torch.no_grad():
        acq_final = _acquisition_value_decoupled_restarts(
            x_param, mixture_model, anchors, beta_t, gamma_t
        )
        best_idx = int(torch.argmax(acq_final).item())
        return x_param.detach()[best_idx]


def maximize_acquisition_adam_multistart(
    mixture_model: Any,
    beta_t: Union[float, torch.Tensor],
    gamma_t: Union[float, torch.Tensor],
    steps: int,
    lr: float,
    n_restarts: int,
    print_every: int = 10,
    anchors: Optional[torch.Tensor] = None,
    dim: Optional[int] = None,
    bounds: Optional[Tuple[float, float]] = (0.0, 1.0),
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    seed: Optional[int] = None,
) -> torch.Tensor:
    """Maximise the acquisition function with multi-start Adam.

    ``n_restarts`` random starting points are optimised jointly (sum of their
    acquisition values) and the point with the highest final value is returned.

    Args:
        mixture_model: Object exposing ``predict``.
        beta_t: Weight on the predictive standard deviation.
        gamma_t: Weight on the anchor-comparison term.
        steps: Number of Adam iterations (must be positive).
        lr: Adam learning rate.
        n_restarts: Number of random starting points (must be positive).
        print_every: Progress is printed on the first step, every
            ``print_every`` steps and on the last step (must be positive).
        anchors: Optional anchor points; also used to infer ``device`` and
            ``dim`` when those are not given.
        dim: Input dimensionality. When ``None`` it is inferred from
            ``anchors``, then ``mixture_model.w``, then ``mixture_model.d``.
        bounds: ``(low, high)`` box; starts are uniform inside it and iterates
            are clamped to it after every step. ``None`` means unconstrained
            with standard-normal starts.
        device: Device for the candidates; inferred from ``anchors`` or
            ``mixture_model.device`` when ``None``, else CPU.
        dtype: Dtype for the candidates.
        seed: If given, ``torch.manual_seed(seed)`` is called, which reseeds
            torch's global RNG.

    Returns:
        The best candidate, a detached tensor of shape ``(dim,)``.

    Raises:
        TypeError: If the model has no ``predict``.
        ValueError: On non-positive ``steps``/``n_restarts``/``print_every``
            or when ``dim`` cannot be inferred.
    """
    if not hasattr(mixture_model, "predict"):
        raise TypeError("mixture_model must implement predict(X) -> (mean, std, cov)")
    if steps <= 0:
        raise ValueError("steps must be positive")
    if n_restarts <= 0:
        raise ValueError("n_restarts must be positive")
    if print_every <= 0:
        raise ValueError("print_every must be positive")

    if device is None:
        if isinstance(anchors, torch.Tensor):
            device = anchors.device
        elif hasattr(mixture_model, "device"):
            device = mixture_model.device
        else:
            device = torch.device("cpu")

    if dim is None:
        if anchors is not None:
            if not isinstance(anchors, torch.Tensor):
                anchors = torch.as_tensor(anchors, device=device, dtype=dtype)
            dim = int(anchors.shape[-1])
        elif isinstance(getattr(mixture_model, "w", None), torch.Tensor):
            dim = int(mixture_model.w.numel())
        elif hasattr(mixture_model, "d"):
            dim = int(getattr(mixture_model, "d"))
        else:
            raise ValueError("dim is required when it cannot be inferred")

    if seed is not None:
        torch.manual_seed(seed)

    if bounds is None:
        x0 = torch.randn(n_restarts, dim, device=device, dtype=dtype)
    else:
        # Parsed once; reused for the projection inside the loop.
        low, high = float(bounds[0]), float(bounds[1])
        x0 = low + (high - low) * torch.rand(
            n_restarts, dim, device=device, dtype=dtype
        )

    x_param = nn.Parameter(x0)
    optimizer = torch.optim.Adam([x_param], lr=lr)

    for i in range(steps):
        optimizer.zero_grad(set_to_none=True)
        acq = _acquisition_value_decoupled_restarts(
            x_param, mixture_model, anchors, beta_t, gamma_t
        )
        loss = -acq.sum()
        # Differentiate with respect to the candidates only. loss.backward()
        # would also accumulate gradients on every model parameter that
        # requires grad, polluting the surrogate's .grad on each step.
        (x_param.grad,) = torch.autograd.grad(loss, [x_param], allow_unused=True)
        optimizer.step()

        if bounds is not None:
            # Projected gradient step: pull iterates back into the box.
            with torch.no_grad():
                x_param.clamp_(min=low, max=high)

        if i == 0 or (i + 1) % print_every == 0 or (i + 1) == steps:
            with torch.no_grad():
                # loss is the pre-step value; acq_now reflects the updated x.
                acq_now = _acquisition_value_decoupled_restarts(
                    x_param, mixture_model, anchors, beta_t, gamma_t
                )
                print(
                    f"[maximize_acquisition_adam_multistart] step={i + 1}/{steps} "
                    f"loss={loss.item():.6f} acq_mean={acq_now.mean().item():.6f} acq_best={acq_now.max().item():.6f}"
                )

    with torch.no_grad():
        acq_final = _acquisition_value_decoupled_restarts(
            x_param, mixture_model, anchors, beta_t, gamma_t
        )
        best_idx = int(torch.argmax(acq_final).item())
        return x_param.detach()[best_idx]


class DemoGaussianModel(nn.Module):
    """Toy surrogate with a fixed random linear-tanh mean and an RBF covariance.

    Exposes ``predict(X) -> (mean, std, cov)`` so it can stand in for a real
    model when exercising the acquisition utilities.
    """

    def __init__(
        self,
        d: int,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
        lengthscale: float = 0.3,
        sigma2: float = 0.5,
        jitter: float = 1e-6,
    ):
        """Draw a random weight vector of size ``d`` and store the kernel constants.

        When ``device`` is falsy, CUDA is used if available, otherwise CPU.
        """
        super().__init__()
        fallback = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device if device else torch.device(fallback)
        self.dtype = dtype

        # Random direction scaled by 1/sqrt(d) (max(d, 1) guards d == 0).
        weight = torch.randn(d, device=self.device, dtype=self.dtype)
        self.register_buffer("w", weight / math.sqrt(max(d, 1)))

        # Scalar kernel constants, stored as 0-dim buffers.
        scalars = (
            ("lengthscale", lengthscale),
            ("sigma2", sigma2),
            ("jitter", jitter),
        )
        for buf_name, buf_value in scalars:
            self.register_buffer(
                buf_name,
                torch.tensor(float(buf_value), device=self.device, dtype=self.dtype),
            )

    def predict(self, X: torch.Tensor):
        """Return ``(mean, std, cov)`` for the rows of ``X``.

        ``mean`` is ``tanh(X @ w)``; ``cov`` is the squared-exponential kernel
        matrix over the rows of ``X`` with ``jitter`` added on the diagonal;
        ``std`` is the square root of that diagonal.
        """
        pts = X.to(device=self.device, dtype=self.dtype)
        mu = torch.tanh(pts.matmul(self.w))

        sq_dist = torch.cdist(pts, pts) ** 2
        kernel = self.sigma2 * torch.exp(-0.5 * sq_dist / (self.lengthscale**2))
        identity = torch.eye(pts.size(0), device=pts.device, dtype=pts.dtype)
        kernel = kernel + self.jitter * identity

        diag = torch.diagonal(kernel, dim1=-2, dim2=-1)
        sd = torch.sqrt(torch.clamp(diag, min=0.0))
        return mu, sd, kernel


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Part 1: smoke test on the self-contained DemoGaussianModel.
    # ------------------------------------------------------------------
    torch.manual_seed(0)
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    d = 4

    model = DemoGaussianModel(d=d, device=dev)
    anchors = torch.rand(8, d, device=dev)
    x = torch.rand(16, d, device=dev, requires_grad=True)

    beta_t = 2.0
    # gamma_schedule clamps t to at least 1, so t=0 yields gamma0 itself.
    gamma_t = gamma_schedule(0, gamma0=1).to(dev)

    # Check output shapes and that gradients reach the candidate points.
    mean, std, cov = model.predict(torch.cat([x.detach(), anchors], dim=0))
    acq = acquisition_value(x, model, anchors, beta_t, gamma_t)
    acq.sum().backward()

    print("mean/std/cov:", mean.shape, std.shape, cov.shape)
    print("acq:", acq.shape)
    print("grad_sum:", x.grad.abs().sum().item())

    # Optimise inside the unit box and re-score the winner.
    x_best = maximize_acquisition_x(
        mixture_model=model,
        beta_t=beta_t,
        gamma_t=gamma_t,
        step=50,
        lr=0.05,
        print_every=10,
        n_restarts=16,
        anchors=anchors,
        bounds=(0.0, 1.0),
        device=dev,
    )
    acq_best = acquisition_value(x_best, model, anchors, beta_t, gamma_t).item()
    print("x_best:", x_best.shape)
    print("x_best_range:", x_best.min().item(), x_best.max().item())
    print("acq_best:", acq_best)

    # ------------------------------------------------------------------
    # Part 2: same checks against the project's ModelMixture.
    # The parent directory is put on sys.path so the project-local
    # ``models`` package can be imported when this file is run directly.
    # ------------------------------------------------------------------
    import os
    import sys

    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
    from models.gp_model import GPModel
    from models.mixture import ModelMixture
    from models.deep_ensemble import DeepEnsemble

    # Experiment configuration.
    total_pool_size = 1000
    d_mix = 10
    list_length = 50
    batch_size = 128
    num_batches_per_epoch = 10
    epochs_rank = 400
    hidden_dim1 = 64
    hidden_dim2 = 64
    num_models = 3
    lr_mix = 1e-3
    use_amp = True

    # Synthetic regression pool: noisy linear function of Gaussian inputs.
    torch.manual_seed(42)
    true_w = torch.randn(d_mix, 1, device=dev)
    X_pool = torch.randn(total_pool_size, d_mix, device=dev)
    y_pool = X_pool @ true_w + 0.1 * torch.randn(total_pool_size, 1, device=dev)

    # The models below are fitted on CPU copies of the data.
    X_cpu = X_pool.cpu()
    y_cpu = y_pool.cpu()

    # Ensemble fitted with loss_type="mse".
    mse_ens = DeepEnsemble(
        input_dim=d_mix,
        hidden_dims=[hidden_dim1, hidden_dim2],
        activation="gelu",
        num_models=num_models,
        seeds=None,
        out_dim=1,
        device=dev,
        dtype=torch.float32,
    )
    mse_ens.fit(
        X_cpu,
        y_cpu,
        steps=400,
        lr=lr_mix,
        batch_size=256,
        loss_type="mse",
        list_size=None,
        lists_per_step=None,
        use_amp=use_amp,
        log_every=200,
    )

    # GP fitted on the flattened targets.
    gp = GPModel(device=dev)
    gp.fit(
        X_cpu,
        y_cpu.squeeze(-1),
        n_iter=400,
        lr=0.05,
        n_restarts=2,
        use_robust_init=True,
    )

    # Second ensemble, same architecture, fitted with loss_type="listnet".
    list_ens = DeepEnsemble(
        input_dim=d_mix,
        hidden_dims=[hidden_dim1, hidden_dim2],
        activation="gelu",
        num_models=num_models,
        seeds=None,
        out_dim=1,
        device=dev,
        dtype=torch.float32,
    )
    list_ens.fit(
        X_cpu,
        y_cpu,
        steps=epochs_rank,
        lr=lr_mix,
        batch_size=batch_size,
        loss_type="listnet",
        list_size=list_length,
        lists_per_step=num_batches_per_epoch,
        use_amp=use_amp,
        log_every=200,
    )

    # NOTE(review): presumably aligns the output scales of the three models
    # before mixing -- confirm against the models package.
    gp.set_prediction_calibration(X_cpu)
    mse_ens.set_prediction_calibration(X_cpu)
    list_ens.set_prediction_calibration(X_cpu)

    # Equal initial weights for the three component models.
    weights = torch.ones(3, device=dev)
    mixture = ModelMixture(models=[gp, mse_ens, list_ens], weights=weights, device=dev)

    # First 8 pool points act as anchors, the next 16 as candidates.
    anchors_mix = X_pool[:8].detach()
    x_mix = X_pool[8:24].clone().detach().requires_grad_(True)
    acq_mix = acquisition_value(x_mix, mixture, anchors_mix, beta_t, gamma_t)
    print("mixture_acq_requires_grad:", bool(acq_mix.requires_grad))
    # backward() raises RuntimeError if the mixture output is detached from
    # x_mix; report that instead of aborting the script.
    try:
        acq_mix.sum().backward()
        print("mixture_grad_sum:", x_mix.grad.abs().sum().item())
    except RuntimeError as e:
        print("mixture_backward_error:", str(e))
    # NOTE(review): ``copy`` is already imported at the top of the file, and
    # ``mixture_model_cp`` is never used below -- candidates for removal.
    import copy

    mixture_model_cp = copy.deepcopy(mixture)
    # Unconstrained optimisation (bounds=None) with standard-normal starts.
    x_best = maximize_acquisition_adam_multistart(
        mixture_model=mixture,
        beta_t=beta_t,
        gamma_t=gamma_t,
        steps=1000,
        lr=0.01,
        n_restarts=16,
        anchors=anchors_mix,
        print_every=200,
        dim=d_mix,
        bounds=None,
        device=dev,
        dtype=torch.float32,
        seed=0,
    )
    # Re-score the winner with the same objective the optimiser used.
    acq_best = _acquisition_value_decoupled_restarts(
        x_best, mixture, anchors_mix, beta_t, gamma_t
    )
    print(
        "x_best_shape:",
        tuple(x_best.shape),
        "acq_best:",
        float(acq_best.view(-1)[0].item()),
    )
