"""
Schrödinger equation module for hydrogen atom quantum dynamics.
Based on hmodel.py from the action matching repository.
"""

import torch
import numpy as np
import math
from scipy.special import factorial2
from typing import Tuple, List, Dict, Any, Optional


# Physical constants in atomic units: hbar, the electron mass, the elementary
# charge and the Bohr radius are all 1, and eps0 is chosen so that the Coulomb
# constant 1 / (4 * pi * eps0) is also 1.
hbar, m_e, e = 1.0, 1.0, 1.0
eps0 = 1.0 / (4 * np.pi)
a0 = 1.0
# Numerical safety margins.
EPS_COORD = 1e-8   # added inside sqrt / atan2 arguments when converting coordinates
EPS_DENS = 1e-12   # regulariser intended for divisions by |psi|^2


def asslaguerre_torch(n: int, alpha: float, x: torch.Tensor) -> torch.Tensor:
    """Associated (generalized) Laguerre polynomial L_n^alpha(x), elementwise.

    Uses the three-term recurrence

        k * L_k = (2k - 1 + alpha - x) * L_{k-1} - (k - 1 + alpha) * L_{k-2},

    iterated upward from L_0 = 1 and L_1 = 1 + alpha - x. Iterating keeps the
    cost linear in n; a naive two-branch recursion is exponential in n.

    Args:
        n: Polynomial degree, must be >= 0.
        alpha: Laguerre parameter.
        x: Evaluation points; the result has the same shape.

    Returns:
        Tensor of L_n^alpha(x) values.

    Raises:
        ValueError: If n is negative.
    """
    if n < 0:
        raise ValueError(f"Laguerre degree must be non-negative, got n={n}")
    if n == 0:
        return torch.ones_like(x)
    prev = torch.ones_like(x)      # L_{k-2}
    curr = 1.0 + alpha - x         # L_{k-1}
    for k in range(2, n + 1):
        nxt = ((2 * k - 1.0 + alpha - x) * curr - (k - 1 + alpha) * prev) / k
        prev, curr = curr, nxt
    return curr


def asslegendre_torch(m: int, l: int, x: torch.Tensor) -> torch.Tensor:
    """Associated Legendre function P_l^m(x), evaluated elementwise.

    Includes the Condon-Shortley phase via the seed
        P_m^m(x) = (-1)^m (2m - 1)!! (1 - x^2)^(m / 2),
    raises the degree with the upward recurrence
        (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m,
    and maps negative orders through
        P_l^{-m} = (-1)^m (l - m)! / (l + m)! P_l^m.

    Args:
        m: Order; |m| > l gives an identically zero result.
        l: Degree (>= 0).
        x: Evaluation points, presumably cos(theta) in [-1, 1]; the result
            has the same shape as x (0-d inputs stay 0-d).

    Returns:
        Tensor of P_l^m(x) values.
    """
    if m < 0:
        m = -m
        if m > l:
            return torch.zeros_like(x)
        return ((-1) ** m * math.factorial(l - m) / math.factorial(l + m)
                * asslegendre_torch(m, l, x))
    if m > l:
        return torch.zeros_like(x)

    # Exact (2m - 1)!!; the empty product yields 1 for m = 0.  This avoids
    # scipy.special.factorial2(-1), whose value differs across SciPy versions
    # (1 in older releases, 0 in newer ones) and would zero every l = 0 state.
    double_factorial = float(math.prod(range(2 * m - 1, 0, -2)))
    # A Python-float exponent keeps x's shape; a shape-(1,) tensor exponent
    # would broadcast 0-d inputs to shape (1,).
    p_prev = torch.zeros_like(x)  # P_{m-1}^m = 0
    p_curr = (-1) ** m * double_factorial * (1.0 - x ** 2) ** (m / 2.0)
    for deg in range(m + 1, l + 1):
        p_next = ((2 * deg - 1) * x * p_curr - (deg + m - 1) * p_prev) / (deg - m)
        p_prev, p_curr = p_curr, p_next
    return p_curr


class EigenState:
    """Single bound eigenstate |n, l, m> of the hydrogen atom (atomic units).

    The wavefunction separates as psi(r, theta, phi) = R_nl(r) * Y_lm(theta, phi)
    (see `radial` and `angular`). For numerically stable log-densities the
    radial part is additionally split as
        R_nl(r) = exp(_radial_log(r)) * _radial(r).

    Attributes:
        n, l, m: Principal, orbital and magnetic quantum numbers.
        E: Energy eigenvalue -hbar^2 / (2 m_e a0^2 n^2).
        L2: Eigenvalue of L^2, hbar^2 l (l + 1).
        Lz: Eigenvalue of L_z, hbar m.
    """

    def __init__(self, n: int, l: int, m: int):
        self.n, self.l, self.m = n, l, m
        # Equals -0.5 / n**2 with the atomic-unit constants of this module.
        self.E = -hbar**2 / (2.0 * m_e * a0**2) / self.n**2
        self.L2 = hbar**2 * self.l * (self.l + 1)
        # L_z carries a single power of hbar: L_z |n l m> = hbar m |n l m>.
        self.Lz = hbar * self.m

    def angular(self, theta: torch.Tensor, phi: torch.Tensor) -> torch.Tensor:
        """Angular part Y_lm(theta, phi) as a complex tensor.

        NOTE(review): `asslegendre_torch` already includes a (-1)^m
        Condon-Shortley phase and this method multiplies by (-1)^m again, so the
        two cancel. Confirm this convention is intended; it changes relative
        signs between odd-m states in a superposition.
        """
        l, m = self.l, self.m
        legendre = asslegendre_torch(m, l, torch.cos(theta))
        norm = (-1) ** m * math.sqrt(
            (2.0 * l + 1.0) * math.factorial(l - m)
            / math.factorial(l + m) / 4.0 / math.pi
        )
        return legendre * norm * torch.exp(1.0j * m * phi)

    def radial(self, r: torch.Tensor) -> torch.Tensor:
        """Normalized radial function R_nl(r), real and shaped like r."""
        return torch.exp(self._radial_log(r)) * self._radial(r)

    def _radial(self, r: torch.Tensor) -> torch.Tensor:
        """Polynomial part of R_nl: norm * L_{n-l-1}^{2l+1}(2 r / (n a0)).

        Computed directly rather than as radial(r) * exp(-_radial_log(r)).
        That form yields 0 * inf = NaN wherever the exponential factor
        underflows.
        """
        n, l = int(self.n), int(self.l)
        na0 = float(n) * float(a0)
        norm = math.sqrt(
            (2.0 / na0) ** 3
            * math.factorial(n - l - 1)
            / math.factorial(n + l)
            / 2.0
            / n
        )
        return norm * asslaguerre_torch(n - l - 1, 2 * l + 1, 2.0 * r / na0)

    def _radial_log(self, r: torch.Tensor) -> torch.Tensor:
        """Log of the exponential/power factor: -r / (n a0) + l * log(2 r / (n a0))."""
        n, l = int(self.n), int(self.l)
        na0 = float(n) * float(a0)
        out = -r / na0
        # Skipped for l == 0 so that r == 0 does not produce 0 * log(0) = NaN.
        if l > 0:
            out = out + l * torch.log(2.0 * r) - l * math.log(na0)
        return out


class WaveFunction:
    """Normalized superposition psi = sum_k c_k |n_k, l_k, m_k> of hydrogen eigenstates."""

    def __init__(self, n: torch.Tensor, l: torch.Tensor, m: torch.Tensor,
                 c0: torch.Tensor, device: torch.device):
        """
        Args:
            n: Principal quantum numbers (integer tensor, presumably 1-D).
            l: Angular momentum quantum numbers, 0 <= l < n elementwise.
            m: Magnetic quantum numbers, -l <= m <= l elementwise.
            c0: Complex coefficients, one per state; rescaled to unit norm.
            device: torch device on which the coefficients are stored.
        """
        # Quantum-number validation (asserts are stripped under `python -O`).
        assert (n < 0).sum() == 0
        assert (l < 0).sum() == (l >= n).sum() == 0
        assert (m > l).sum() == (m < -l).sum() == 0

        self.n, self.l, self.m = n, l, m
        c0 = c0.to(device)
        # Unit norm assumes the listed (n, l, m) are distinct, i.e. orthonormal states.
        self.c0 = c0 / torch.sqrt(torch.sum(c0.abs() ** 2))
        self.states = [
            EigenState(int(qn), int(ql), int(qm))
            for qn, ql, qm in zip(n.detach().cpu().tolist(),
                                  l.detach().cpu().tolist(),
                                  m.detach().cpu().tolist())
        ]
        # Exactly one coefficient per state; extra coefficients would silently
        # distort the normalization.
        assert self.c0.numel() == len(self.states)

        self.dim = 3
        self.device = device

    def evolve_to(self, t: float) -> 'WaveFunction':
        """Return the state evolved by time t from this one: c_k -> exp(-i E_k t / hbar) c_k."""
        E = torch.tensor([state.E for state in self.states], device=self.device)
        return WaveFunction(self.n, self.l, self.m,
                            torch.exp(-1j * E * t / hbar) * self.c0,
                            self.device)

    def _expectation(self, values: List[float]) -> torch.Tensor:
        """sum_k |c_k|^2 * values[k], computed on the coefficients' device."""
        weights = self.c0.abs() ** 2
        return torch.sum(weights * torch.tensor(values, device=weights.device))

    def avgH(self) -> torch.Tensor:
        """Average energy <H>."""
        return self._expectation([state.E for state in self.states])

    def avgL2(self) -> torch.Tensor:
        """Average squared angular momentum <L^2>."""
        return self._expectation([state.L2 for state in self.states])

    def avgLz(self) -> torch.Tensor:
        """Average z-component of angular momentum <L_z>."""
        return self._expectation([state.Lz for state in self.states])

    @staticmethod
    def _to_spherical(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert Cartesian rows x[:, :3] to flattened (r, theta, phi).

        EPS_COORD keeps r and the cylindrical radius strictly positive, so the
        sqrt/atan2 calls and their gradients stay finite at the origin and on
        the z-axis.
        """
        rho2 = x[:, 0] ** 2 + x[:, 1] ** 2
        r = torch.sqrt(rho2 + x[:, 2] ** 2 + EPS_COORD).flatten()
        theta = torch.atan2(torch.sqrt(rho2 + EPS_COORD), x[:, 2]).flatten()
        x_coord = torch.sign(x[:, 0]) * (torch.abs(x[:, 0]) + EPS_COORD)
        phi = torch.atan2(x[:, 1], x_coord).flatten()
        return r, theta, phi

    def at(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate psi at Cartesian points; presumably x is (K, 3), result is complex (K,)."""
        return self.at_polar(*self._to_spherical(x))

    def at_polar(self, r: torch.Tensor, theta: torch.Tensor,
                 phi: torch.Tensor) -> torch.Tensor:
        """Evaluate psi at spherical coordinates of identical shape."""
        assert r.shape == theta.shape == phi.shape
        output = 1j * torch.zeros_like(r)
        # Out-of-place accumulation (safe under autograd and torch.func transforms).
        for coeff, state in zip(self.c0, self.states):
            output = output + coeff * state.radial(r) * state.angular(theta, phi)
        return output

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """Compute log |psi(x)|^2 at Cartesian points.

        Each state's exponential factor is handled in log space and rescaled by
        the per-point maximum (log-sum-exp style) to avoid underflow far from
        the nucleus.
        """
        r, theta, phi = self._to_spherical(x)
        log_env = torch.stack([state._radial_log(r) for state in self.states])
        rest = torch.stack([state._radial(r) * state.angular(theta, phi)
                            for state in self.states])
        coeffs = self.c0.view(-1, 1)
        max_log, _ = torch.max(log_env, dim=0)
        psi = (torch.exp(log_env - max_log) * rest * coeffs).sum(0)
        return 2 * torch.log(psi.abs()) + 2 * max_log


class BohmianDynamics:
    """Bohmian mechanics: point particles guided by the phase of the wavefunction.

    The guidance velocity is v = (hbar / m_e) * grad arg(psi). This module uses
    atomic units (hbar = m_e = 1), so the prefactor is omitted. Time stepping
    is explicit Euler.
    """

    def __init__(self, wave_function: "WaveFunction", samples: torch.Tensor):
        """
        Args:
            wave_function: Wavefunction at the current time; must provide
                `at(x)` and `evolve_to(dt)`.
            samples: Particle positions, presumably (N, 3) and real-valued.
        """
        self.psi = wave_function
        self.samples = samples

    def propagate(self, dt: float):
        """Advance particles and wavefunction by one Euler step of length dt.

        The caller's position tensor is left untouched; `self.samples` is
        replaced by a new tensor that does not require grad.
        """
        # Gradients are needed even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            x = self.samples.detach().requires_grad_(True)
            phase = self.psi.at(x).angle()
            # Assumes psi at row i depends only on x[i] (true for WaveFunction.at),
            # so the gradient of the sum is the per-particle phase gradient.
            velocity = torch.autograd.grad(phase.sum(), x)[0]
        self.samples = x.detach() + dt * velocity
        self.psi = self.psi.evolve_to(dt)


# ============================================================
# Data generator integration
# ============================================================

def generate_schrodinger(
        *,
        N: int,
        steps: int,
        dt: float,
        d: int,
        num_p0: int,
        quantum_numbers: List[Tuple[int, int, int]],
        coefficients: Optional[List[complex]] = None,
        initial_sampling: str = "wavefunction",
        init_sphere_radius: float = 5.0,
        device: torch.device,
        **kwargs,  # Absorb unused args
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Dict[str, Any]]:
    """
    Generate trajectories following Bohmian mechanics for the hydrogen atom.

    Args:
        N: Number of particles per initial condition
        steps: Number of time steps
        dt: Time step size
        d: Dimension (must be 3 for hydrogen atom)
        num_p0: Number of different initial conditions (>= 1)
        quantum_numbers: Non-empty list of (n, l, m) tuples defining the wavefunction
        coefficients: Complex coefficients for superposition, one per triplet
            (normalized automatically); None means equal weights
        initial_sampling: How to initialize particles ("wavefunction" or "sphere")
        init_sphere_radius: Radius for sphere sampling if initial_sampling="sphere"
        device: torch device

    Returns:
        X_em_torch: (num_p0, N, steps+1, 3) trajectories
        V_em_torch: (num_p0, N, steps+1, 3) Bohmian velocities at the trajectory points
        time_grid: (steps+1,) time values
        meta: metadata dictionary

    Raises:
        ValueError: On d != 3, empty quantum_numbers, a coefficient count that
            does not match quantum_numbers, an unknown initial_sampling, or
            num_p0 < 1.
    """
    # Validate everything up front, before any expensive sampling/propagation.
    if d != 3:
        raise ValueError("Schrödinger mode requires d=3 (hydrogen atom in 3D).")
    if not quantum_numbers:
        raise ValueError("quantum_numbers must contain at least one (n, l, m) triplet.")
    if coefficients is not None and len(coefficients) != len(quantum_numbers):
        raise ValueError(
            f"Got {len(coefficients)} coefficients for "
            f"{len(quantum_numbers)} quantum-number triplets."
        )
    if initial_sampling not in ("wavefunction", "sphere"):
        raise ValueError(f"Unknown initial_sampling: {initial_sampling}")
    if num_p0 < 1:
        raise ValueError(f"num_p0 must be >= 1, got {num_p0}")

    # Parse quantum numbers
    ns = torch.tensor([qn[0] for qn in quantum_numbers], dtype=torch.int32)
    ls = torch.tensor([qn[1] for qn in quantum_numbers], dtype=torch.int32)
    ms = torch.tensor([qn[2] for qn in quantum_numbers], dtype=torch.int32)

    # Un-normalized coefficients; WaveFunction rescales them to unit norm.
    if coefficients is None:
        c0 = torch.ones(len(quantum_numbers), dtype=torch.complex64)
    else:
        c0 = torch.tensor(coefficients, dtype=torch.complex64)

    psi_t0 = WaveFunction(ns, ls, ms, c0, device)

    all_pops: List[torch.Tensor] = []
    for _ in range(num_p0):
        if initial_sampling == "wavefunction":
            samples = _sample_from_wavefunction(psi_t0, N, device)
        else:  # "sphere": random directions scaled to init_sphere_radius
            samples = torch.randn(N, 3, device=device)
            samples = samples / samples.norm(dim=1, keepdim=True) * init_sphere_radius

        bohmian = BohmianDynamics(psi_t0, samples)

        X_traj = torch.empty((N, steps + 1, 3), device=device, dtype=torch.float32)
        X_traj[:, 0, :] = bohmian.samples.real.to(torch.float32)
        for t in range(steps):
            bohmian.propagate(dt)
            X_traj[:, t + 1, :] = bohmian.samples.real.to(torch.float32)

        all_pops.append(X_traj)

    X_em_torch = torch.stack(all_pops, dim=0)
    time_grid = torch.arange(steps + 1, device=device, dtype=torch.float32) * dt

    coeff_list = [complex(c) for c in c0.cpu().numpy()]

    # Velocities are re-evaluated from the wavefunction at the stored positions.
    meta_for_vel = {
        "quantum_numbers": quantum_numbers,
        "coefficients": list(coeff_list),
    }
    V_em_torch = torch.stack(
        [
            compute_schrodinger_bohmian_velocity(
                X_traj=X_em_torch[i],
                time_grid=time_grid,
                meta=meta_for_vel,
            )
            for i in range(num_p0)
        ],
        dim=0,
    )

    meta = {
        "mode": "schrodinger",
        "N": N,
        "steps": steps,
        "dt": dt,
        "d": d,
        "num_p0": num_p0,
        "quantum_numbers": quantum_numbers,
        "coefficients": list(coeff_list),
        "initial_sampling": initial_sampling,
        "init_sphere_radius": init_sphere_radius,
        "a0": a0,
        "avg_energy": float(psi_t0.avgH().real),
        "has_vel": True,
    }

    return X_em_torch, V_em_torch, time_grid, meta


def _sample_from_wavefunction(
    psi: "WaveFunction",
    N: int,
    device: torch.device,
    max_attempts: int = 10_000_000,
    box_size: float = 20.0,
    proposal_chunk: int = 50_000,
) -> torch.Tensor:
    """Draw N positions distributed as |psi|^2 by rejection sampling.

    Proposals are uniform in the cube [-box_size, box_size]^3. Each is accepted
    with probability |psi(x)|^2 / M, where log M is the largest log-density
    seen on 20,000 Gaussian probe points (std 5) plus a margin of 2. If the
    true maximum exceeds M the acceptance probability is clipped at 1 and the
    samples are biased. Mass outside the cube is never sampled.

    Args:
        psi: Object exposing `log_prob(x)` -> (K,) for x of shape (K, 3).
        N: Number of samples; N <= 0 returns an empty (0, 3) tensor.
        device: Device for proposals and the returned samples.
        max_attempts: Upper bound on the total number of proposals.
        box_size: Half-width of the proposal cube.
        proposal_chunk: Number of proposals drawn and scored per iteration.

    Returns:
        Tensor of shape (N, 3).

    Raises:
        RuntimeError: If the envelope estimate is not finite, or fewer than N
            proposals are accepted within max_attempts.
    """
    if N <= 0:
        return torch.empty((0, 3), device=device)

    # Envelope height estimated from Gaussian probe points around the origin.
    probe = torch.randn(20_000, 3, device=device) * 5.0
    log_envelope = psi.log_prob(probe).max().item() + 2.0
    if not math.isfinite(log_envelope):
        raise RuntimeError(
            f"Cannot build a rejection envelope: max log|psi|^2 on probe "
            f"points is {log_envelope}."
        )

    accepted_chunks = []
    accepted_total = 0
    attempts = 0
    while accepted_total < N and attempts < max_attempts:
        n_propose = min(proposal_chunk, max_attempts - attempts)
        x_prop = (torch.rand(n_propose, 3, device=device) - 0.5) * 2 * box_size

        accept_prob = torch.exp(psi.log_prob(x_prop) - log_envelope).clamp(max=1.0)
        u = torch.rand(n_propose, device=device)

        accepted = x_prop[u < accept_prob]
        if accepted.numel() > 0:
            accepted_chunks.append(accepted)
            accepted_total += accepted.shape[0]

        attempts += n_propose

    if accepted_total < N:
        raise RuntimeError(
            f"Failed to sample N={N} points from |psi|^2. "
            f"Got {accepted_total}. Increase max_attempts/box_size or change proposal."
        )

    return torch.cat(accepted_chunks, dim=0)[:N]


def _parse_coefficients(coeffs_raw: Any, count: int) -> torch.Tensor:
    """Convert serialized superposition coefficients into a complex64 tensor.

    Each entry may be a Python complex, a dict with "real"/"imag" keys, a
    2-element list/tuple (real, imag), or anything whose str() parses with
    complex() once spaces are removed (e.g. "1+0.5j" or 0.3). None means equal
    weights for all `count` states.

    Raises:
        ValueError: If an entry cannot be parsed or the number of entries
            differs from `count`.
    """
    if coeffs_raw is None:
        return torch.ones(count, dtype=torch.complex64)
    coeffs = []
    for c in coeffs_raw:
        if isinstance(c, complex):
            coeffs.append(c)
        elif isinstance(c, dict) and ("real" in c) and ("imag" in c):
            coeffs.append(complex(float(c["real"]), float(c["imag"])))
        elif isinstance(c, (list, tuple)) and len(c) == 2:
            coeffs.append(complex(float(c[0]), float(c[1])))
        else:
            coeffs.append(complex(str(c).replace(" ", "")))
    if len(coeffs) != count:
        raise ValueError(
            f"Got {len(coeffs)} coefficients for {count} quantum-number triplets."
        )
    return torch.tensor(coeffs, dtype=torch.complex64)


def compute_schrodinger_bohmian_velocity(
        X_traj: torch.Tensor,  # (N, T+1, 3)
        time_grid: torch.Tensor,  # (T+1,)
        meta: dict,
        eps_dens: float = 1e-12,
) -> torch.Tensor:
    """
    Compute the Bohmian velocity along trajectories from the wavefunction.

    v(x,t) = Im(psi*(x,t) grad psi(x,t)) / (|psi(x,t)|^2 + eps_dens)

    Atomic units (hbar = m_e = 1) are assumed, so no hbar/m_e prefactor appears.
    Gradients come from batched autograd: psi at row i depends only on x[i], so
    the gradient of the summed real/imaginary parts is the per-point gradient.

    Args:
        X_traj: Particle trajectories (N, T+1, 3)
        time_grid: Time values (T+1,)
        meta: Dict with "quantum_numbers" (list of (n, l, m)) and optionally
            "coefficients" (see `_parse_coefficients`)
        eps_dens: Regularizer added to |psi|^2 to keep nodes finite

    Returns:
        V_traj: Velocity trajectories (N, T+1, 3)

    Raises:
        ValueError: If the last dimension of X_traj is not 3 or the coefficient
            count does not match the quantum numbers.
    """
    device = X_traj.device
    _, num_times, d = X_traj.shape
    if d != 3:
        raise ValueError("Schrödinger velocity requires d=3")

    qnums = meta["quantum_numbers"]
    ns = torch.tensor([q[0] for q in qnums], dtype=torch.int32, device="cpu")
    ls = torch.tensor([q[1] for q in qnums], dtype=torch.int32, device="cpu")
    ms = torch.tensor([q[2] for q in qnums], dtype=torch.int32, device="cpu")
    c0 = _parse_coefficients(meta.get("coefficients", None), len(qnums))

    # WaveFunction is the class defined in this module.
    wf_t0 = WaveFunction(ns, ls, ms, c0.to(device), device)

    V_traj = torch.zeros_like(X_traj)
    for t_idx in range(num_times):
        t = float(time_grid[t_idx].item())
        # Evolve from t0 directly (not cumulatively) to time t.
        wf_t = wf_t0.evolve_to(t) if t != 0.0 else wf_t0

        # Gradients are required even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            x = X_traj[:, t_idx, :].detach().requires_grad_(True)
            psi_x = wf_t.at(x)
            grad_re = torch.autograd.grad(psi_x.real.sum(), x, retain_graph=True)[0]
            grad_im = torch.autograd.grad(psi_x.imag.sum(), x)[0]

        psi_x = psi_x.detach()
        psi_re = psi_x.real.unsqueeze(1)
        psi_im = psi_x.imag.unsqueeze(1)
        # Im(conj(psi) * (grad_re + i grad_im)) = Re(psi) grad_im - Im(psi) grad_re
        num = psi_re * grad_im - psi_im * grad_re
        den = (psi_x.abs() ** 2).unsqueeze(1) + eps_dens
        V_traj[:, t_idx, :] = num / den

    return V_traj


# ============================================================
# Add to data_generator.py CLI
# ============================================================

def add_schrodinger_parser(subparsers):
    """Register the ``schrodinger`` sub-command and return its parser.

    Args:
        subparsers: The object returned by ``ArgumentParser.add_subparsers()``.

    Returns:
        The newly created sub-parser.
    """
    parser = subparsers.add_parser(
        "schrodinger",
        help="Hydrogen atom quantum dynamics (Bohmian mechanics)",
    )

    # Run-size options default to None (presumably filled in by the caller).
    for flag, cast in (("--N", int), ("--steps", int), ("--dt", float)):
        parser.add_argument(flag, type=cast, required=False, default=None)
    parser.add_argument("--d", type=int, default=3, help="Must be 3 for hydrogen atom")
    parser.add_argument("--num-p0", type=int, default=1)

    # Wavefunction specification, passed through as raw strings.
    parser.add_argument(
        "--quantum-numbers",
        type=str,
        required=False,
        default="2,1,0",
        help="Semicolon-separated list of n,l,m triplets, e.g., '2,1,0;2,1,1'",
    )
    parser.add_argument(
        "--coefficients",
        type=str,
        default=None,
        help="Semicolon-separated complex coefficients, e.g., '1+0j;0.5+0.5j'",
    )
    parser.add_argument(
        "--initial-sampling",
        type=str,
        default="wavefunction",
        choices=["wavefunction", "sphere"],
    )
    parser.add_argument("--init-sphere-radius", type=float, default=5.0)

    return parser