import torch
from math import pi
from types import SimpleNamespace

def gaussian_mixture_from_anchors(xy: torch.Tensor,
                                  anchors: torch.Tensor,
                                  sigma: float,
                                  weights: torch.Tensor | None = None) -> torch.Tensor:

    B, M = xy.shape[0], anchors.shape[0]
    diff = xy[:, None, :] - anchors[None, :, :]             # (B,M,2)
    d2 = (diff * diff).sum(dim=-1)                          # (B,M)
    comp = torch.exp(-0.5 * d2 / (sigma ** 2))              # (B,M)
    if weights is not None:
        comp = comp * weights.view(1, M)
    dens = comp.sum(dim=1) + 1e-30                          # (B,)
    return dens

# ---------- 8 Gaussians ----------
def density_eight_gaussians(xy: torch.Tensor, R: float, sigma: float) -> torch.Tensor:
    """Unnormalized density of 8 Gaussians evenly spaced on a circle.

    Args:
        xy: Query points, indexed as (B, 2).
        R: Radius of the circle (centered at the origin) holding the modes.
        sigma: Standard deviation shared by all modes.

    Returns:
        Tensor of shape (B,) as produced by
        ``gaussian_mixture_from_anchors``.
    """
    n_modes = 8
    # Sample n_modes + 1 angles on [0, 2*pi] and drop the last one, which
    # coincides with the first.
    angles = torch.linspace(0, 2 * pi, steps=n_modes + 1,
                            dtype=xy.dtype, device=xy.device)[:-1]
    centers_x = R * torch.cos(angles)
    centers_y = R * torch.sin(angles)
    centers = torch.stack([centers_x, centers_y], dim=-1)     # (8, 2)
    return gaussian_mixture_from_anchors(xy, centers, sigma)

# ---------- ring(s) radial ----------
def density_rings(xy: torch.Tensor, radii: list[float], sigma_r: float, weights: list[float] | None = None) -> torch.Tensor:
    r = torch.linalg.norm(xy, dim=1)                   # (B,)
    radii_t = torch.tensor(radii, dtype=xy.dtype, device=xy.device).view(1, -1)  # (1,M)
    diff = r.view(-1, 1) - radii_t                     # (B,M)
    comp = torch.exp(-0.5 * (diff * diff) / (sigma_r ** 2))  # (B,M)
    if weights is not None:
        w = torch.tensor(weights, dtype=xy.dtype, device=xy.device).view(1, -1)
        comp = comp * w
    dens = comp.sum(dim=1) + 1e-30
    return dens * 10

# ---------- 2-moons via anchors placed along two half-moons ----------
def density_two_moons(xy: torch.Tensor, R: float, delta: float, gap: float, sigma: float, n_anchors: int = 256) -> torch.Tensor:
    """Unnormalized two-moons density built from Gaussians along two arcs.

    The first arc spans angles [0, pi] on a circle of radius ``R`` centered
    at ``(-delta, 0)``; the second spans [pi, 2*pi] on a circle of radius
    ``R`` centered at ``(+delta, -gap)``. Each arc carries ``n_anchors``
    Gaussian anchors.

    Args:
        xy: Query points, indexed as (B, 2).
        R: Radius of both arcs.
        delta: Horizontal shift of the arc centers (left / right).
        gap: Downward shift of the second arc's center.
        sigma: Standard deviation of every anchor Gaussian.
        n_anchors: Number of anchors per arc.

    Returns:
        Tensor of shape (B,) as produced by
        ``gaussian_mixture_from_anchors``.
    """
    linspace_opts = dict(steps=n_anchors, dtype=xy.dtype, device=xy.device)

    upper_angles = torch.linspace(0, pi, **linspace_opts)
    upper_arc = torch.stack(
        [R * torch.cos(upper_angles) - delta, R * torch.sin(upper_angles)],
        dim=-1,
    )                                                         # (n, 2)

    lower_angles = torch.linspace(pi, 2 * pi, **linspace_opts)
    lower_arc = torch.stack(
        [R * torch.cos(lower_angles) + delta, R * torch.sin(lower_angles) - gap],
        dim=-1,
    )                                                         # (n, 2)

    all_anchors = torch.cat([upper_arc, lower_arc], dim=0)    # (2n, 2)
    return gaussian_mixture_from_anchors(xy, all_anchors, sigma)


# ---------- wrapper: builds a log-reward function from a density ----------
def build_log_reward_fn(kind: str, **params):
    """Build a function mapping an environment to per-sample log-rewards.

    Args:
        kind: Which density to use: ``'8g'`` (eight Gaussians), ``'rings'``
            (concentric rings) or ``'moons'`` (two moons).
        **params: Optional overrides, read lazily on every call:

            * ``'8g'``: ``R`` (default ``0.8 * min(width, height)``),
              ``sigma`` (default 1.0).
            * ``'rings'``: ``radii`` (default
              ``[0.2 * width, 0.8 * width]``), ``sigma_r`` (default 1.0),
              ``weights`` (default None).
            * ``'moons'``: ``R`` (default ``0.6 * min(width, height)``),
              ``delta`` (default ``0.5 * R``), ``gap`` (default
              ``0.3 * R``), ``sigma`` (default 1.0), ``n_anchors``
              (default 256).
            * all kinds: ``lam`` (default 1e-6), the weight of the uniform
              floor mixed into the density, and ``multiplier`` (default 1).

    Returns:
        A callable ``log_reward(env)`` that reads ``env.pos`` (plus
        ``env.width`` / ``env.height`` for the defaults) and returns a
        tensor with one log-reward per row of ``env.pos``. It raises
        ``ValueError`` when called with an unknown ``kind``.
    """
    def log_reward(env) -> torch.Tensor:
        xy = env.pos  # (B,2): coordinates of the current batch
        if kind == '8g':
            dens = density_eight_gaussians(
                xy,
                R=params.get('R', 0.8 * min(env.width, env.height)),
                sigma=params.get('sigma', 1.0),
            )
        elif kind == 'rings':
            # Honor caller-supplied radii; previously the `radii` parameter
            # was silently ignored and the width-based default always used.
            radii = params.get('radii')
            if radii is None:
                radii = [0.2 * env.width, 0.8 * env.width]
            dens = density_rings(
                xy,
                radii=radii,
                sigma_r=params.get('sigma_r', 1.0),
                weights=params.get('weights', None),
            )
        elif kind == 'moons':
            # Resolve R once so that the delta/gap defaults derive from it.
            R = params.get('R', 0.6 * min(env.width, env.height))
            dens = density_two_moons(
                xy,
                R=R,
                delta=params.get('delta', 0.5 * R),
                gap=params.get('gap', 0.3 * R),
                sigma=params.get('sigma', 1.0),
                n_anchors=params.get('n_anchors', 256),
            )
        else:
            raise ValueError(f"kind desconhecido: {kind}")

        # Small uniform floor to avoid zeros and for numerical stability.
        lam = params.get('lam', 1e-6)
        dens = (1 - lam) * dens + lam
        # NOTE: despite its name, `multiplier` is ADDED to the log-reward,
        # i.e. the reward itself is scaled by exp(multiplier); with the
        # default of 1 the reward is e * dens. Kept as-is for compatibility.
        return torch.log(dens) + params.get("multiplier", 1)
    return log_reward


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import numpy as np

    def eval_on_grid(kind: str, N: int):
        """Evaluate the reward of `kind` on the integer grid [-N, N] x [-N, N].

        Returns a (2N+1, 2N+1) tensor indexed as [x, y] ('ij' convention).
        """
        W = H = N
        dtype = torch.get_default_dtype()
        grid_x, grid_y = torch.meshgrid(
            torch.arange(-W, W + 1, dtype=dtype),
            torch.arange(-H, H + 1, dtype=dtype),
            indexing='ij',
        )                                                      # (2N+1, 2N+1) each
        pos = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)  # (B,2)

        # Pick the keyword arguments for the requested density.
        side = min(W, H)
        if kind == '8g':
            fn_kwargs = dict(R=0.8 * side, sigma=1., lam=1e-6)
        elif kind == 'rings':
            fn_kwargs = dict(radii=[0.2 * side, 0.8 * side],
                             sigma_r=1.0,
                             weights=None)
        elif kind == 'moons':
            Rm = 0.6 * side
            fn_kwargs = dict(R=Rm,
                             delta=0.5 * Rm,
                             gap=0.3 * Rm,
                             sigma=1.0,
                             n_anchors=256,
                             lam=1e-6)
        else:
            raise ValueError(f"kind desconhecido: {kind}")
        logR_fn = build_log_reward_fn(kind, **fn_kwargs)

        # Evaluate on a minimal stand-in for the environment.
        dummy_env = SimpleNamespace(pos=pos, width=W, height=H)
        rewards = torch.exp(logR_fn(dummy_env))  # (B,)
        return rewards.view(2 * W + 1, 2 * H + 1)  # (x,y) in the 'ij' convention

    sizes = [18]
    kinds = ['rings']

    for N in sizes:
        # Cell edges: need one more entry than the number of cells.
        edges = np.arange(-N - 0.5, N + 1.5, 1.0)
        Xe, Ye = np.meshgrid(edges, edges, indexing='ij')
        lower, upper = -N - 0.5, N + 0.5

        for kind in kinds:
            Rmat = eval_on_grid(kind, N)
            plt.figure(figsize=(5, 4))

            # pcolormesh draws each cell and allows an outline around it.
            plt.pcolormesh(Xe, Ye, Rmat.numpy(), shading='flat',
                           edgecolors='k', linewidth=0.4)

            axes = plt.gca()
            axes.set_aspect('equal')
            plt.xlim([lower, upper])
            plt.ylim([lower, upper])
            plt.title(f"{kind} — N={N}")
            plt.colorbar()
            plt.tight_layout()

    plt.show()