import torch


def weighted_reward(x, base_width):
    """Two-bump reward whose peaks are clipped into flat plateaus.

    Every particle is scored against two isotropic Gaussian bumps centred at
    ``(+1.5, 0, ..., 0)`` and ``(-1.5, 0, ..., 0)``. The first coordinate is
    shifted by the centre; all remaining coordinates contribute through their
    squared norm. Let ``d_plus`` and ``d_minus`` be the squared distances to
    the two centres. The reward is the sum of:

    * ``min(100 * exp(-d_plus / base_width), 1.9) - 3.8``: the "global"
      peak at ``x1 = +1.5``, with a plateau at -1.9 that decays to -3.8 far away;
    * ``min(2 * exp(-d_minus), 1.8) - 3.6``: the "local" peak at
      ``x1 = -1.5`` (fixed width 1), with a plateau at -1.8 that decays to -3.6.

    Far from both centres the reward approaches -7.4.

    Args:
        x: Particle positions (tensor or array-like).
            - 2-D input is read as ``[n_particles, n_dims]``.
            - 1-D input is read as a single particle ``[1, n_dims]``.
            - A 0-d scalar is read as one 1-D particle.
            Floating-point tensors keep their dtype and device. Anything else
            is converted to ``torch.get_default_dtype()``.
        base_width: Divisor of the squared distance in the ``+1.5`` peak.
            Larger values widen that peak's plateau.

    Returns:
        Tensor of shape ``[n_particles]``. The final ``squeeze(-1)``
        collapses a single-particle result to a 0-d tensor.
    """
    # Normalise the input without torch.Tensor(), which refuses tensors whose
    # dtype differs from the default (e.g. float64 or integer tensors).
    if isinstance(x, torch.Tensor):
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())
    else:
        x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    if x.dim() == 0:
        x = x.reshape(1, 1)  # scalar -> one particle with one coordinate
    elif x.dim() == 1:
        x = x[None, :]  # a single particle

    x1 = x[:, 0]
    # Squared norm of every coordinate after the first (zeros if there are none).
    other_dims_sq = (x[:, 1:] ** 2).sum(dim=1)

    # "Global" peak at x1 = +1.5, width set by base_width.
    # -relu(1.9 - g) - 1.9 == min(g, 1.9) - 3.8, so the bump g is clipped to a plateau.
    global_peak = (
        -torch.relu(
            -torch.exp(-((x1 - 1.5) ** 2 + other_dims_sq) / base_width) * 100.0 + 1.9
        )
        - 1.9
    )

    # "Local" peak at x1 = -1.5, fixed width 1: min(2 * exp(-d), 1.8) - 3.6.
    local_peak = (
        -torch.relu(-torch.exp(-((x1 - (-1.5)) ** 2 + other_dims_sq) / 1.0) * 2.0 + 1.8)
        - 1.8
    )

    return (global_peak + local_peak).squeeze(-1)


def weighted_reward_nd(x, base_width):
    """Smooth two-bump (Gaussian mixture) reward over particle positions.

    Sum of two isotropic Gaussian bumps centred at ``(+1.5, 0, ..., 0)`` and
    ``(-1.5, 0, ..., 0)``. The first coordinate is shifted by the centre, and
    the remaining coordinates enter through their squared norm. Let
    ``d_plus`` and ``d_minus`` be the squared distances to the two centres.
    The reward is the sum of:

    * ``2.0 * exp(-d_plus / base_width)``: the taller ("global") bump at
      ``x1 = +1.5``, height 2;
    * ``1.5 * exp(-d_minus / 3)``: the lower ("local") bump at
      ``x1 = -1.5``, height 1.5, width fixed at 3.

    Args:
        x: Particle positions (tensor or array-like).
            - 2-D input is read as ``[n_particles, n_dims]``.
            - 1-D input is read as a single particle ``[1, n_dims]``.
            - A 0-d scalar is read as one 1-D particle.
            Floating-point tensors keep their dtype and device. Anything else
            is converted to ``torch.get_default_dtype()``.
        base_width: Divisor of the squared distance in the ``+1.5`` bump.

    Returns:
        Tensor of shape ``[n_particles]``. The final ``squeeze(-1)``
        collapses a single-particle result to a 0-d tensor.
    """
    # Normalise the input without torch.Tensor(), which refuses tensors whose
    # dtype differs from the default (e.g. float64 or integer tensors).
    if isinstance(x, torch.Tensor):
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())
    else:
        x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    if x.dim() == 0:
        x = x.reshape(1, 1)  # scalar -> one particle with one coordinate
    elif x.dim() == 1:
        x = x[None, :]  # a single particle

    x1 = x[:, 0]
    # Squared norm of every coordinate after the first (zeros if there are none).
    other_dims_sq = (x[:, 1:] ** 2).sum(dim=1)

    # "Global" bump at x1 = +1.5: height 2, width base_width.
    global_peak = torch.exp(-((x1 - 1.5) ** 2 + other_dims_sq) / base_width) * 2.0

    # "Local" bump at x1 = -1.5: height 1.5, width 3.
    local_peak = torch.exp(-((x1 + 1.5) ** 2 + other_dims_sq) / 3) * 1.5

    return (global_peak + local_peak).squeeze(-1)


def weighted_reward_1d(x, base_width):
    """One-dimensional smooth two-bump reward.

    Only the first column ``x[:, 0]`` is used; any further columns are
    ignored. For each particle with first coordinate ``x1`` the reward is::

        2.0 * exp(-(x1 - 1.5) ** 2 / base_width) + 1.5 * exp(-(x1 + 1.5) ** 2 / 3)

    That is, a taller ("global") bump of height 2 at ``x1 = +1.5`` and a lower
    ("local") bump of height 1.5 with width 3 at ``x1 = -1.5``.

    Args:
        x: 2-D particle positions ``[n_particles, n_cols]`` (tensor or
            array-like). Floating-point tensors keep their dtype and device.
            Anything else is converted to ``torch.get_default_dtype()``.
        base_width: Divisor of the squared distance in the ``+1.5`` bump.

    Returns:
        Tensor of shape ``[n_particles]``. The final ``squeeze(-1)``
        collapses a single-particle result to a 0-d tensor.
    """
    # Accept lists / numpy arrays / integer tensors, as the sibling rewards do.
    if isinstance(x, torch.Tensor):
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())
    else:
        x = torch.as_tensor(x, dtype=torch.get_default_dtype())

    x1 = x[:, 0]  # only the first coordinate contributes

    # "Global" bump at x1 = +1.5: height 2, width base_width.
    global_peak = torch.exp(-((x1 - 1.5) ** 2) / base_width) * 2.0

    # "Local" bump at x1 = -1.5: height 1.5, width 3.
    local_peak = torch.exp(-((x1 + 1.5) ** 2) / 3) * 1.5

    return (global_peak + local_peak).squeeze(-1)
