# topoformer_synthetic.py
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
import numpy as np
import torch

@dataclass
class TopologyConfig:
    """Node-placement parameters consumed by ``sample_positions``."""

    n_nodes: int = 32  # number of nodes to place
    area_size: float = 1.0  # side length of the square placement area [0, area_size]^2
    layout: Literal["uniform", "clusters", "grid"] = "uniform"  # placement strategy
    n_clusters: int = 4  # cluster count for "clusters" (values < 1 are treated as 1)
    cluster_std: float = 0.07  # per-axis Gaussian std for "clusters", as a fraction of area_size
    grid_jitter: float = 0.02  # total uniform jitter span for "grid", as a fraction of area_size
    min_dist: float = 0.0  # best-effort minimum pairwise spacing; <= 0 disables the rejection pass

@dataclass
class RadioConfig:
    """Link-formation parameters consumed by ``build_graph``."""

    tx_radius: float = 0.3  # pairs farther apart than this are never linked
    beam_width_deg: float = 360.0  # beam width in degrees; >= 359.999 selects the omnidirectional path
    n_beams: int = 1  # maximum number of beams a node may open in directional mode
    max_degree: int = 8  # maximum neighbours each node selects, nearest first
    symmetric: bool = True  # if True, keep only mutually-selected links

def set_seed(seed: int|None):
    """Seed NumPy's and PyTorch's global RNGs; do nothing when ``seed`` is None."""
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

def _reject_min_dist(points: torch.Tensor, min_dist: float, area: float, max_tries: int=2000):
    if min_dist <= 0: return points
    out = points.clone()
    N = out.shape[0]
    for i in range(N):
        tries = 0
        while True:
            if i == 0: break
            d = torch.linalg.norm(out[:i]-out[i], dim=-1)
            if (d >= min_dist).all(): break
            out[i] = torch.rand(2) * area
            tries += 1
            if tries > max_tries: break
    return out

def sample_positions(topo: TopologyConfig, seed: int|None=None) -> torch.Tensor:
    """Sample node positions inside the square ``[0, area_size]^2``.

    Layouts:
      * ``"uniform"``  -- i.i.d. uniform points.
      * ``"clusters"`` -- ``max(1, n_clusters)`` uniform centres; every node picks
        a centre uniformly at random (NumPy RNG), adds Gaussian noise with
        per-axis std ``cluster_std * area_size``, and is clamped to the area.
      * ``"grid"``     -- the first ``n_nodes`` cell centres of a ``g x g`` grid
        (``g = ceil(sqrt(n_nodes))``, row-major with x varying fastest), each
        jittered uniformly by up to ``+/- grid_jitter * area_size / 2`` and
        clamped to the area.

    When ``topo.min_dist > 0`` the result is passed through ``_reject_min_dist``,
    which redraws too-close points uniformly over the whole area.

    Args:
        topo: placement parameters.
        seed: if not None, the global NumPy and PyTorch RNGs are reseeded first.

    Returns:
        Tensor of shape ``(n_nodes, 2)``.

    Raises:
        ValueError: if ``topo.layout`` is not a supported layout.
    """
    set_seed(seed)
    N, A = topo.n_nodes, topo.area_size
    if topo.layout == "uniform":
        pos = torch.rand(N, 2) * A
    elif topo.layout == "clusters":
        k = max(1, topo.n_clusters)
        centers = torch.rand(k, 2) * A
        # NumPy's default integer width is platform-dependent; index with int64.
        assign = torch.from_numpy(np.random.randint(0, k, size=N)).long()
        pos = torch.zeros(N, 2)
        std = topo.cluster_std * A
        # Per-node draws are kept deliberately: a single batched randn call is
        # not guaranteed to reproduce the same values for a given seed.
        for i in range(N):
            pos[i] = centers[assign[i]] + torch.randn(2) * std
        pos = torch.clamp(pos, 0.0, A)
    elif topo.layout == "grid":
        # At least one cell so that 0.5 / g is defined when n_nodes == 0.
        g = max(1, int(math.ceil(math.sqrt(N))))
        xs = torch.linspace(0.5 / g, 1 - 0.5 / g, g) * A
        X, Y = torch.meshgrid(xs, xs, indexing="xy")
        grid = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)[:N]
        jitter = (torch.rand_like(grid) - 0.5) * (topo.grid_jitter * A)
        pos = torch.clamp(grid + jitter, 0.0, A)
    else:
        raise ValueError("unknown layout")
    if topo.min_dist > 0:
        pos = _reject_min_dist(pos, topo.min_dist, A)
    return pos

def _pairwise(pos: torch.Tensor):
    diff = pos.unsqueeze(1) - pos.unsqueeze(0)
    dist = torch.linalg.norm(diff + 1e-12, dim=-1)
    ang  = torch.atan2(diff[...,1], diff[...,0])  # in [-pi,pi]
    return dist, ang

def _angle_diff(a: torch.Tensor, b: torch.Tensor):
    return (a - b + math.pi) % (2*math.pi) - math.pi

def build_graph(pos: torch.Tensor, radio: RadioConfig, seed: int|None=None) -> torch.Tensor:
    """Build a boolean adjacency matrix from node positions and radio limits.

    Node ``i`` considers every other node within ``radio.tx_radius``, visiting
    them nearest-first, and selects at most ``radio.max_degree`` of them:

    * Omnidirectional (``beam_width_deg >= 359.999``): the nearest in-range
      nodes are selected.
    * Directional: a candidate is selected if it lies within
      ``beam_width_deg / 2`` of an already-opened beam centre. Otherwise a new
      beam centred on its direction is opened, provided fewer than
      ``n_beams`` beams exist; if none remain, the candidate is skipped.

    With ``radio.symmetric`` only mutual selections survive (``A & A.T``).
    The diagonal is always False.

    Args:
        pos: node coordinates, one row per node (fed to ``_pairwise``).
        radio: range, beam and degree limits. A negative ``max_degree`` is
            treated as 0.
        seed: forwarded to ``set_seed``. The selection draws no random numbers,
            so this only reseeds the global RNGs.

    Returns:
        ``(N, N)`` bool tensor on ``pos.device``; ``A[i, j]`` is True when i
        selected j (and, if symmetric, j also selected i).
    """
    set_seed(seed)
    N = pos.shape[0]
    device = pos.device
    # Clamp so a negative limit means "no links" in both branches, instead of
    # slicing ``[:-k]`` (all but k neighbours) in the omnidirectional branch.
    max_degree = max(0, radio.max_degree)
    A = torch.zeros(N, N, dtype=torch.bool, device=device)
    dist, ang = _pairwise(pos)
    # The self-loop mask must live on pos.device: a default-device eye makes
    # the ``&`` below fail whenever pos is on another device (e.g. a GPU).
    not_self = ~torch.eye(N, dtype=torch.bool, device=device)
    in_range = (dist <= radio.tx_radius) & not_self
    omnidirectional = radio.beam_width_deg >= 359.999
    half_bw = math.radians(radio.beam_width_deg) / 2
    for i in range(N):
        cand = torch.where(in_range[i])[0]
        if cand.numel() == 0:
            continue
        cand = cand[torch.argsort(dist[i, cand])]  # nearest first
        if omnidirectional:
            A[i, cand[:max_degree]] = True
            continue
        chosen = []
        beams = []  # centre angles (radians) of the beams opened for node i
        for j in cand.tolist():
            if len(chosen) >= max_degree:
                break
            # ang[i, j] is the angle of pos[i] - pos[j]; its constant pi offset
            # from the i->j direction cancels in the relative comparisons here.
            theta = ang[i, j]
            covered = any(abs(float(_angle_diff(theta, c))) <= half_bw for c in beams)
            if not covered:
                if len(beams) >= radio.n_beams:
                    continue  # no beam left to point at j
                beams.append(float(theta))
            chosen.append(j)
        if chosen:
            A[i, torch.tensor(chosen, device=device)] = True
    if radio.symmetric:
        A = A & A.t()
    A.fill_diagonal_(False)
    return A

def dde_features(pos: torch.Tensor, r_bins: torch.Tensor, a_bins: torch.Tensor) -> torch.Tensor:
    """Histogram, per node, the (distance, angle) of every other node.

    For each ordered pair ``(i, j)`` with ``i != j``, ``dist[i, j]`` and
    ``ang[i, j]`` from ``_pairwise(pos)`` (the offset ``pos[i] - pos[j]``) are
    located with ``torch.bucketize`` in ``r_bins`` / ``a_bins``, and
    ``H[i, r_idx, a_idx]`` is incremented. Pairs outside the bin edges are
    dropped.

    Args:
        pos: node coordinates, one row per node.
        r_bins: sorted distance bin edges; ``len(r_bins) - 1`` bins.
        a_bins: sorted angle bin edges; ``len(a_bins) - 1`` bins (``ang`` lies
            in ``[-pi, pi]``).

    Returns:
        float32 tensor of shape ``(N, R, A)`` holding integer-valued counts,
        allocated with ``torch.zeros`` (i.e. on the default device).
    """
    N = pos.shape[0]; R = len(r_bins)-1; A = len(a_bins)-1
    H = torch.zeros(N, R, A, dtype=torch.float32)
    dist, ang = _pairwise(pos)
    # Mirror the per-scalar ``torch.tensor(x.item())`` conversion of the former
    # per-pair loop (default dtype, default device), so bucket boundaries
    # resolve identically, but bucketize every pair in a single call.
    scalar_dtype = torch.get_default_dtype()
    dist = dist.detach().to(device=H.device, dtype=scalar_dtype)
    ang = ang.detach().to(device=H.device, dtype=scalar_dtype)
    r_idx = torch.bucketize(dist, r_bins) - 1
    a_idx = torch.bucketize(ang, a_bins) - 1
    keep = ~torch.eye(N, dtype=torch.bool, device=H.device)
    keep &= (r_idx >= 0) & (r_idx < R) & (a_idx >= 0) & (a_idx < A)
    src = torch.arange(N, device=H.device).unsqueeze(1).expand(N, N)
    # Flatten (i, r, a) into the row-major index of H and count occurrences.
    flat = (src[keep] * R + r_idx[keep]) * A + a_idx[keep]
    counts = torch.bincount(flat, minlength=N * R * A)
    H += counts.view(N, R, A).to(H.dtype)
    return H
