import torch
from typing import Dict, Any, Optional, Tuple, List
from mechanics import euler_maruyama_generalized
from functions import get_potential_grad_as_torch
import numpy as np
import argparse

# ============================================================
# Mode 1: Orbit generator
# ============================================================
def gaussian_blob(
    N: int,
    d: int,
    device: torch.device,
    mean: torch.Tensor,
    var: float,
) -> torch.Tensor:
    """
    Sample ``N`` points from the isotropic Gaussian N(mean, var * I_d).

    Randomness comes from torch's global generator (presumably seeded
    elsewhere via set_seed(args.seed) — confirm with the caller).

    Args:
        N: Number of samples.
        d: Dimension; ``mean`` must contain exactly ``d`` elements.
        device: Device on which the samples are created.
        mean: Centre of the blob; cast to float32 on ``device``.
        var: Per-coordinate variance (the standard deviation is ``var ** 0.5``).

    Returns:
        An (N, d) tensor on ``device``.
    """
    centre = mean.to(device=device, dtype=torch.float32).view(1, d)
    scale = var ** 0.5
    eps = torch.randn((N, d), device=device)
    # (1, d) centre broadcasts against the (N, d) noise.
    return centre + scale * eps


def get_rotational_drift(omega: float):
    """
    Build the drift of a rigid rotation with angular velocity ``omega``.

    The returned callable has signature ``drift(x, t)`` and maps each row
    ``(x0, x1)`` of ``x`` to ``omega * (-x1, x0)``. ``t`` is accepted for
    interface compatibility with time-dependent drifts and is ignored.
    Only columns 0 and 1 of ``x`` are read, so ``x`` is expected to be (N, 2).
    """
    def drift(x: torch.Tensor, t: float):
        first = x[:, 0]
        second = x[:, 1]
        rotated = (-omega * second, omega * first)
        return torch.stack(rotated, dim=1)

    return drift


def generate_orbit(
        *,
        N: int,
        steps: int,
        dt: float,
        d: int,
        num_p0: int,
        radius: float,
        omega: float,
        init_var: float,
        sigma: float,
        device: torch.device,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Dict[str, Any]]:
    """
    Simulate a rigid-rotation SDE started from several Gaussian blobs.

    Rotational drift SDE:
      dX_t = b(X_t) dt + sigma dW_t,
    with b(x) = omega * (-x_1, x_0) the rigid rotation field.

    Blob i (0 <= i < num_p0) is centred at radius * (cos a_i, sin a_i) with
    a_i = 2*pi*i / num_p0, has isotropic variance ``init_var``, and holds N
    particles that are integrated with ``euler_maruyama_generalized``.

    Returns:
        X_em_torch: (num_p0, N, steps+1, d) simulated trajectories.
        V_em_torch: (num_p0, N, steps+1, d) analytic rotation velocities
            evaluated along the trajectories. It is always computed here;
            the Optional in the signature is kept for interface compatibility.
        time_grid: (steps+1,) times k * dt.
        meta: dict of generation parameters. ``has_vel`` is a bool flag.

    Raises:
        ValueError: if d != 2 or num_p0 < 1.
    """
    # Validate before doing any work so bad configs fail fast and clearly.
    if d != 2:
        raise ValueError("orbit mode currently assumes d=2.")
    n_pops = int(num_p0)
    if n_pops < 1:
        # Without this, torch.stack on an empty list raises an opaque error.
        raise ValueError("num_p0 must be >= 1.")

    omega_f = float(omega)
    drift_gt = get_rotational_drift(omega_f)

    all_pops: List[torch.Tensor] = []
    for i in range(n_pops):
        # Evenly spaced blob centres on the circle of the given radius.
        angle = (i / max(1, n_pops)) * 2.0 * np.pi
        center = torch.tensor(
            [float(radius) * np.cos(angle), float(radius) * np.sin(angle)],
            device=device,
            dtype=torch.float32,
        )
        x0 = gaussian_blob(N=int(N), d=2, device=device, mean=center, var=float(init_var))
        X_pop = euler_maruyama_generalized(
            x0,
            drift_gt,
            sigma=float(sigma),
            dt=float(dt),
            steps=int(steps),
        )  # (N, steps+1, 2)
        all_pops.append(X_pop)

    X_em_torch = torch.stack(all_pops, dim=0)  # (num_p0, N, steps+1, 2)
    time_grid = torch.arange(int(steps) + 1, device=device, dtype=torch.float32) * float(dt)

    # Analytic velocity of the rotation field at every simulated position.
    V_em_torch = torch.stack(
        [compute_orbit_velocity(X_pop, omega=omega_f) for X_pop in X_em_torch],
        dim=0,
    )  # (num_p0, N, steps+1, 2)

    meta = {
        "mode": "orbit",
        "N": int(N),
        "steps": int(steps),
        "dt": float(dt),
        "d": int(d),
        "num_p0": n_pops,
        "radius": float(radius),
        "omega": omega_f,
        "init_var": float(init_var),
        "sigma": float(sigma),
        # A flag, not the tensor itself: truth-testing a multi-element tensor
        # raises, and keeping meta tensor-free keeps it serializable.
        "has_vel": V_em_torch is not None,
    }

    return X_em_torch, V_em_torch, time_grid, meta


def compute_orbit_velocity(
        X_traj: torch.Tensor,  # (N, T+1, 2)
        omega: float,
) -> torch.Tensor:
    """
    Evaluate the rotational velocity field v(x) = omega * (-y, x).

    Args:
        X_traj: Positions whose last axis holds the (x, y) coordinates.
            Typically particle trajectories of shape (N, T+1, 2); any leading
            shape is accepted.
        omega: Angular velocity.

    Returns:
        V_traj: Velocities with the same shape and dtype as ``X_traj``.

    Raises:
        ValueError: if ``X_traj`` is 0-dimensional or its last dimension is not 2.
    """
    if X_traj.dim() == 0 or X_traj.shape[-1] != 2:
        raise ValueError("Orbit velocity requires d=2")

    # Vectorized over all particles and time steps at once, replacing a
    # per-time-step Python loop. Writing into an X-shaped buffer keeps the
    # output dtype identical to the input's.
    V_traj = torch.empty_like(X_traj)
    V_traj[..., 0] = -omega * X_traj[..., 1]
    V_traj[..., 1] = omega * X_traj[..., 0]

    return V_traj

def add_orbit_parser(subparsers) -> argparse.ArgumentParser:
    """
    Register the ``orbit`` (rigid rotation) subcommand on ``subparsers``.

    Args:
        subparsers: The object returned by ``ArgumentParser.add_subparsers()``.

    Returns:
        The newly created ``orbit`` sub-parser.
    """
    orbit_parser = subparsers.add_parser("orbit", help="Rigid rotation (orbit) SDE dataset")

    # Simulation sizes: left as None when omitted (presumably resolved by
    # the caller — confirm).
    for flag, kind in (("--N", int), ("--steps", int), ("--dt", float)):
        orbit_parser.add_argument(flag, type=kind, required=False, default=None)

    # Geometry / dynamics of the rotating blobs.
    geometry_args = (
        ("--d", int, 2),
        ("--num-p0", int, 1),
        ("--radius", float, 3.0),
        ("--omega", float, 2.0),
        ("--init-var", float, 0.1),
        ("--sigma", float, 0.0),
    )
    for flag, kind, fallback in geometry_args:
        orbit_parser.add_argument(flag, type=kind, default=fallback)

    # How velocities are produced.
    orbit_parser.add_argument(
        "--vel-mode",
        type=str,
        default="analytic",
        choices=["analytic", "zero", "gradient_flow"],
        help="Velocity computation mode",
    )

    # Options read only when --vel-mode is gradient_flow.
    orbit_parser.add_argument(
        "--score-method",
        type=str,
        default="kernel",
        choices=["kernel", "neural"],
        help="Score estimation method for gradient flow",
    )
    orbit_parser.add_argument(
        "--score-hidden",
        type=int,
        default=64,
        help="Hidden dim for neural score estimator",
    )
    orbit_parser.add_argument(
        "--score-train-steps",
        type=int,
        default=5,
        help="Training steps per timestep for neural score",
    )

    return orbit_parser