#!/usr/bin/env python3
"""
Reaction–diffusion (linear, mildly unstable) dataset generator (periodic, exact spectral stepping).

PDE: u_t = nu * u_xx + lam * u   on [0, Lx), periodic.

Fourier mode-wise solution:
    d/dt û_k = (lam - nu * k^2) û_k
    => û_k(t+dt) = exp((lam - nu k^2) dt) * û_k(t)

Initial condition:
  - u0 ~ GP(0, C), with C = 625(-d²/dx² + 25I)^{-2} from the discrete periodic Laplacian.

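Outputs a compressed NumPy `.npz` with:
  - data: (n_samples, nt+1, nx) float32 array, nt = round(T / dt)  # u(t, x)
  - x: (nx,) spatial grid coordinates
  - t: (nt+1,) output times
  - Lx, nu, lam, dt: 0-d float32 scalars recording the run parameters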

Usage:
  python reaction_diffusion1d_dataset.py -n 8 --nx 256 --T 2.0 --dt 2e-3 --nu 1e-2 --lam 0.05 -o rd1d_dataset.npz
"""

import argparse
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm


# ---------- Utilities ----------

def build_covariance(nx: int, dx: float) -> np.ndarray:
    """Return the covariance ``625 (-d²/dx² + 25 I)^{-2}`` on a periodic grid.

    ``-d²/dx²`` is discretised with the periodic second-difference stencil
    ``(-u[j-1] + 2 u[j] - u[j+1]) / dx²`` (indices taken modulo ``nx``).

    Args:
        nx: Number of grid points (must be >= 1).
        dx: Grid spacing.

    Returns:
        Symmetric ``(nx, nx)`` covariance matrix.

    Raises:
        ValueError: If ``nx < 1``.
    """
    if nx < 1:
        raise ValueError(f"nx must be >= 1, got {nx}")
    eye = np.eye(nx)
    # Summing the two cyclic shifts (instead of writing the corner entries by hand)
    # keeps the stencil correct for nx <= 2, where the left and right neighbours
    # coincide and their contributions must add up.
    lap = (2.0 * eye - np.roll(eye, 1, axis=1) - np.roll(eye, -1, axis=1)) / dx ** 2
    operator = lap + 25.0 * eye
    cov = 625.0 * np.linalg.inv(operator @ operator)
    # Remove round-off asymmetry from the inversion so samplers see an exactly
    # symmetric matrix.
    return 0.5 * (cov + cov.T)


def sample_initial_conditions(n: int, cov: np.ndarray, device: torch.device) -> torch.Tensor:
    """Draw ``n`` zero-mean Gaussian vectors with covariance ``cov``.

    The draws come from NumPy's global legacy RNG (so ``np.random.seed`` controls
    them) and are returned as a float32 tensor of shape ``(n, cov.shape[0])`` on
    ``device``.
    """
    dim = cov.shape[0]
    zero_mean = np.zeros(dim)
    draws = np.random.multivariate_normal(zero_mean, cov, n)
    return torch.as_tensor(draws, dtype=torch.float32, device=device)


# ---------- Reaction–diffusion (spectral, exact) ----------

def simulate_rd_spectral(u0: torch.Tensor, nt: int, Lx: float, dt: float, nu: float, lam: float) -> torch.Tensor:
    """
    Exact integration of u_t = nu * u_xx + lam * u in Fourier space on a periodic grid.

    Each Fourier mode evolves as û_k(t) = exp((lam - nu k^2) t) û_k(0). Every output
    level t_n = n * dt is evaluated directly from the initial spectrum rather than
    by repeated per-step multiplication. This avoids compounding round-off over many
    steps (notably in float32) and replaces the Python time loop with one batched
    inverse FFT.

    u0: real tensor on the desired device; last axis is space, shape (nx,) or (..., nx).
        The grid is assumed to be x_j = j * Lx / nx.
    nt: number of time steps; nt <= 0 returns only the initial level.
    Returns: U of shape (nt+1, nx) (or (nt+1, ..., nx) for batched u0), with U[0]
        an exact copy of u0.
    """
    if nt <= 0:
        return u0.unsqueeze(0).clone()

    nx = u0.shape[-1]
    dx = Lx / nx
    k = 2.0 * np.pi * torch.fft.fftfreq(nx, d=dx).to(u0.device, dtype=u0.dtype)  # (nx,)
    growth = lam - nu * (k ** 2)  # per-mode growth rate

    times = dt * torch.arange(1, nt + 1, device=u0.device, dtype=u0.dtype)  # t_1 .. t_nt
    factors = torch.exp(times[:, None] * growth)  # (nt, nx)
    # Insert singleton axes so the factors broadcast over any leading batch dims of u0.
    factors = factors.reshape((nt,) + (1,) * (u0.dim() - 1) + (nx,))

    uhat = torch.fft.fft(u0)
    later = torch.fft.ifft(uhat * factors).real  # levels n = 1..nt
    # Row 0 is the initial condition itself, free of FFT round-off.
    return torch.cat((u0.unsqueeze(0), later), dim=0)


def simulate_rd_fd(u0: torch.Tensor, nt: int, Lx: float, dt: float, nu: float, lam: float) -> torch.Tensor:
    """
    Finite difference solver for reaction-diffusion: u_t = nu * u_xx + lam * u
    Using operator splitting: diffusion (implicit) + reaction (explicit).

    Each step applies backward Euler to the diffusion term with the periodic
    second-difference Laplacian. It then multiplies by exp(lam * dt), the exact
    solution of u_t = lam * u over one step.

    u0: (nx,) real tensor; the grid spacing is taken as Lx / nx (periodic).
    Returns: tensor of shape (nt+1, nx) whose row 0 is a copy of u0.
    """
    nx = u0.shape[-1]
    dx = Lx / nx
    r = nu * dt / (dx ** 2)

    # Periodic system matrix (1 + 2r) I - r (S + S^T), with S the cyclic shift.
    # Summing the shifts keeps the stencil correct for nx <= 2, where the left
    # and right neighbours coincide.
    eye = torch.eye(nx, device=u0.device, dtype=u0.dtype)
    A = (1 + 2 * r) * eye - r * (torch.roll(eye, 1, dims=1) + torch.roll(eye, -1, dims=1))
    # A is loop-invariant and diagonally dominant (cond(A) <= 1 + 4r), so invert it
    # once instead of re-solving the dense system at every step.
    A_inv = torch.linalg.inv(A)
    # Python float: evaluated in double precision and adopts u's dtype on multiply
    # (a default-dtype tensor would compute the factor in float32).
    reaction = float(np.exp(lam * dt))

    u = u0.clone()
    history = [u]
    for _ in range(nt):
        # Step 1: implicit diffusion; Step 2: reaction (exact exponential factor).
        u = (A_inv @ u) * reaction
        history.append(u)

    return torch.stack(history)


# ---------- Dataset generation ----------

def generate_dataset(
        n_samples: int,
        Lx: float = 2.0,
        nx: int = 256,
        T: float = 2.0,
        dt: float = 2e-3,
        nu: float = 1e-2,
        lam: float = 0.05,
        outfile: str = "rd1d_dataset.npz",
        device: str | None = None,
        seed: int | None = 1234,
        u_scale: float = 1.0,
) -> None:
    """Generate n_samples RD trajectories and save them to *outfile*.

    Initial conditions are drawn via ``build_covariance`` /
    ``sample_initial_conditions`` and multiplied by ``u_scale``. Each one is
    advanced with ``simulate_rd_spectral`` for ``nt = round(T / dt)`` steps.

    The saved ``x`` holds the periodic grid points ``j * Lx / nx`` (``Lx`` itself is
    excluded because it coincides with ``x = 0``). The saved ``t`` holds the actual
    simulated times ``n * dt``. Both match the spacing used by the solver and the
    covariance.

    Args:
        n_samples: Number of trajectories (0 writes an empty ``data`` array).
        Lx: Domain length.
        nx: Number of grid points.
        T: Target final time.
        dt: Time step.
        nu: Diffusion coefficient.
        lam: Linear reaction rate.
        outfile: Output path passed to ``np.savez_compressed``.
        device: Torch device name; when None, MPS, then CUDA, then CPU is used.
        seed: Seeds NumPy and torch when not None.
        u_scale: Multiplier applied to the sampled initial conditions.

    Raises:
        ValueError: If ``n_samples < 0``, ``nx < 1``, ``dt <= 0`` or ``T < 0``.
    """
    if n_samples < 0:
        raise ValueError(f"n_samples must be >= 0, got {n_samples}")
    if nx < 1:
        raise ValueError(f"nx must be >= 1, got {nx}")
    if dt <= 0:
        raise ValueError(f"dt must be > 0, got {dt}")
    if T < 0:
        raise ValueError(f"T must be >= 0, got {T}")

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    if device is not None:
        dev = torch.device(device)
    elif torch.backends.mps.is_available():
        dev = torch.device("mps")
    elif torch.cuda.is_available():
        dev = torch.device("cuda")
    else:
        dev = torch.device("cpu")

    dx = Lx / nx
    nt = int(round(T / dt))
    # Periodic grid x_j = j * dx (j < nx), the same spacing the solver assumes.
    x = np.arange(nx, dtype=np.float64) * dx
    # Actual output times; nt * dt differs from T when T / dt is not an integer.
    t = np.arange(nt + 1, dtype=np.float64) * dt

    cov = build_covariance(nx, dx)  # CPU numpy
    u0s = sample_initial_conditions(n_samples, cov, dev) * u_scale

    U_list = []
    for u0 in tqdm(u0s, total=n_samples):
        U_list.append(simulate_rd_spectral(u0, nt, Lx, dt, nu, lam).cpu())

    # torch.stack rejects an empty list, so build the empty result explicitly.
    if U_list:
        U_all = torch.stack(U_list)  # (n, nt+1, nx)
    else:
        U_all = torch.empty((0, nt + 1, nx), dtype=torch.float32)

    np.savez_compressed(
        outfile,
        data=U_all.numpy().astype(np.float32),
        x=x.astype(np.float32),
        t=t.astype(np.float32),
        Lx=np.float32(Lx),
        nu=np.float32(nu),
        lam=np.float32(lam),
        dt=np.float32(dt),
    )
    print(f"Saved dataset with shape data={tuple(U_all.shape)} to '{outfile}'.")


def generate_reaction_diffusion_data(n_samples=1000, nx=64, nt=100, Lx=2 * np.pi, T=1.0):
    """Generate reaction-diffusion equation data using finite differences.

    Each sample starts from ``sin(x) + 0.3 sin(2x) + 0.1 * noise``, with the noise
    drawn from torch's global RNG. It is advanced with ``simulate_rd_fd`` using
    ``nu = 0.01``, ``lam = 0.1`` and ``dt = T / nt``.

    The grid is the periodic one the solver assumes: ``x_j = j * Lx / nx`` for
    ``j = 0..nx-1``. ``x = Lx`` is excluded because it is the same point as
    ``x = 0``. With the default ``Lx = 2π`` the sine terms are therefore exactly
    periodic on the grid.

    Returns:
        (data, x): ``data`` is a float32 array of shape (n_samples, nt+1, nx) and
        ``x`` is the (nx,) grid.

    Raises:
        ValueError: If ``nt < 1`` or ``nx < 1``.
    """
    if nt < 1:
        raise ValueError(f"nt must be >= 1, got {nt}")
    if nx < 1:
        raise ValueError(f"nx must be >= 1, got {nx}")

    dt = T / nt
    nu = 0.01  # diffusion coefficient
    lam = 0.1  # reaction rate

    # Periodic grid with spacing Lx / nx (matches dx inside simulate_rd_fd).
    x = torch.arange(nx, dtype=torch.float32) * (Lx / nx)

    data = torch.zeros(n_samples, nt + 1, nx)
    for i in range(n_samples):
        # Random initial condition: two smooth modes plus small white noise.
        u0 = torch.sin(x) + 0.3 * torch.sin(2 * x) + 0.1 * torch.randn(nx)
        data[i] = simulate_rd_fd(u0, nt, Lx, dt, nu, lam)

    return data.numpy(), x.numpy()


# ---------- CLI ----------

def _cli() -> None:
    """Parse command-line options and pass them, unchanged, to ``generate_dataset``."""
    parser = argparse.ArgumentParser(
        description="Generate reaction–diffusion datasets (periodic, spectral, exact)."
    )
    # (flags, add_argument keyword arguments), registered in this order.
    option_table = (
        (("-n", "--n_samples"), dict(type=int, default=5)),
        (("--Lx",), dict(type=float, default=2.0)),
        (("--nx",), dict(type=int, default=256)),
        (("--T",), dict(type=float, default=2.0)),
        (("--dt",), dict(type=float, default=2e-3)),
        (("--nu",), dict(type=float, default=1e-2)),
        (("--lam",), dict(type=float, default=0.05)),
        (("-o", "--outfile"), dict(type=str, default="rd1d_dataset.npz")),
        (("--device",), dict(type=str, default=None)),
        (("--seed",), dict(type=int, default=None)),
        (("--u_scale",), dict(type=float, default=1.0, help="Scale for u0 GP samples.")),
    )
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)
    namespace = parser.parse_args()
    generate_dataset(**vars(namespace))


if __name__ == "__main__":
    # Run the command-line interface only when this file is executed as a script,
    # not when it is imported as a module.
    _cli()
