#!/usr/bin/env python3
"""
Backward heat equation dataset generator (periodic, spectral stepping).

Generates `n_samples` trajectories of the 1-D backward heat equation
    u_t = -nu * u_xx
on a periodic domain [0, Lx). Initial conditions are Gaussian processes with
covariance C = 625(-d²/dx² + 25I)^{-2} (discrete periodic Laplacian).

Time stepping is done exactly in Fourier space:
    û_k(t+dt) = exp(+nu * k^2 * dt) * û_k(t)       <-- growth, not decay!

Outputs a compressed NumPy `.npz` with:
  - data: shape (n_samples, nt+1, nx), float32
  - x: grid (nx,)
  - t: times (nt+1,)
  - Lx, nu, dt
  - kmax_frac, growth_cap (stored as -1.0 when disabled)

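⚠️ Ill-posedness note:
  High-k modes grow extremely fast: over the whole run mode k is amplified by
  exp(nu * k^2 * T). Because the stepping is exact, shrinking dt does not reduce
  this growth (it only keeps the per-step factor below --growth_cap). Keep T
  small and consider the --kmax_frac safety filter to suppress the very highest
  modes.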

Usage example:
  python backward_heat_dataset.py -n 4 --nx 256 --T 0.2 --dt 5e-4 --nu 0.02 \\
      --kmax_frac 0.8 -o backward_heat_dataset.npz
"""

import argparse
import math

import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path


# ---------- Utilities ----------

def build_covariance(nx: int, dx: float) -> np.ndarray:
    """Return the covariance matrix 625 (-d²/dx² + 25 I)^{-2} on a periodic grid.

    -d²/dx² is discretised with the periodic three-point stencil
    (-u[i-1] + 2 u[i] - u[i+1]) / dx**2, so the result is a symmetric,
    circulant (nx, nx) matrix.

    Args:
      nx: number of grid points.
      dx: grid spacing.

    Returns:
      The (nx, nx) covariance matrix.
    """
    eye = np.eye(nx)
    # right[i, (i+1) % nx] = 1 and right.T[i, (i-1) % nx] = 1. Summing shifted
    # identities (instead of assigning entries) keeps the stencil correct for
    # nx <= 2, where the left and right neighbours are the same grid point.
    right = np.roll(eye, 1, axis=1)
    neg_lap = (2.0 * eye - right - right.T) * (1.0 / dx ** 2)
    operator = neg_lap + 25.0 * eye
    return 625.0 * np.linalg.inv(operator @ operator)


def sample_initial_conditions(n: int, cov: np.ndarray, device: torch.device) -> torch.Tensor:
    """Sample *n* zero-mean Gaussian vectors with covariance *cov*.

    Draws come from NumPy's global RNG (so ``np.random.seed`` controls them);
    the (n, cov.shape[0]) result is returned as a float32 tensor on *device*.
    """
    dim = cov.shape[0]
    zero_mean = np.zeros(dim)
    draws = np.random.multivariate_normal(mean=zero_mean, cov=cov, size=n)
    return torch.as_tensor(draws, dtype=torch.float32, device=device)


# ---------- Backward heat equation (spectral) ----------

@torch.no_grad()
def simulate_backward_heat_spectral(
        u0: torch.Tensor,
        nt: int,
        Lx: float,
        dt: float,
        nu: float,
        kmax_frac: float | None = None,
        growth_cap: float | None = 1e6,
) -> torch.Tensor:
    """
    Exact per-step integration in Fourier space on a periodic grid.
    u_t = -nu * u_xx  =>  û_k(t+dt) = exp(+nu * k^2 * dt) * û_k(t)

    Args:
      u0: (nx,) real tensor on desired device
      nt: number of steps
      Lx: domain length
      dt: time step
      nu: diffusivity coefficient (positive)
      kmax_frac: optional fraction in (0,1] to keep |k| <= kmax_frac * k_Nyquist.
                 Higher modes are zeroed to mitigate catastrophic growth.
      growth_cap: optional per-step multiplicative cap applied to exp(nu k^2 dt)
                  to avoid inf/NaN; set None to disable.

    Returns:
      history: (nt+1, nx), real-valued
    """
    nx = u0.shape[-1]
    dx = Lx / nx  # periodic grid [0, Lx)
    # angular wavenumbers k = 2π * m / Lx, consistent with torch.fft.fftfreq
    k = 2.0 * math.pi * torch.fft.fftfreq(nx, d=dx).to(u0.device, dtype=u0.dtype)
    k_abs = k.abs()

    # Optional spectral safety filter
    if kmax_frac is not None:
        k_ny = math.pi / dx  # Nyquist angular wavenumber
        mask = (k_abs <= (kmax_frac * k_ny)).to(u0.dtype)
    else:
        mask = torch.ones_like(k)

    # Exact growth factor
    growth = torch.exp(nu * (k ** 2) * dt)  # shape (nx,)
    if growth_cap is not None:
        # Hard cap the per-step multiplier to avoid overflow on very high k
        growth = torch.minimum(growth, torch.tensor(float(growth_cap), dtype=growth.dtype, device=growth.device))

    # Apply mask so the highest modes don't explode
    growth = growth * mask + (1.0 - mask)  # masked modes get multiplier 1 (frozen)

    uhat = torch.fft.fft(u0)
    # Also zero masked modes in the initial state (optional, but consistent)
    uhat = uhat * mask

    history = [u0.clone()]
    for _ in range(nt):
        uhat = uhat * growth
        u = torch.fft.ifft(uhat).real

        # Basic numerical hygiene: replace NaN/Inf if they appear
        if not torch.isfinite(u).all():
            u = torch.nan_to_num(u, nan=0.0, posinf=0.0, neginf=0.0)

        history.append(u)
    return torch.stack(history)


def simulate_backward_heat_fd(u0: torch.Tensor, nt: int, Lx: float, dt: float, nu: float) -> torch.Tensor:
    """
    Backward-Euler finite-difference solver for the backward heat equation
    u_t = -nu * u_xx on a periodic grid [0, Lx) with spacing Lx / nx.

    Each step solves
        (I + r * L) u^{n+1} = u^n,   r = nu * dt / dx**2,
    where L is the periodic second-difference matrix tridiag(1, -2, 1). For a
    discrete Fourier mode with angle theta = k * dx the per-step factor is
    1 / (1 - 4 r sin^2(theta / 2)) > 1 while 4 r sin^2(theta / 2) < 1, i.e. the
    mode grows as the backward heat equation demands. The matrix can become
    singular (or badly conditioned) only when r >= 1/4; keep r below that.

    Args:
      u0: (nx,) initial state.
      nt: number of time steps.
      Lx: domain length.
      dt: time step.
      nu: diffusivity (positive).

    Returns:
      (nt+1, nx) tensor; row 0 is a copy of u0.
    """
    nx = u0.shape[-1]
    dx = Lx / nx
    r = nu * dt / dx ** 2

    eye = torch.eye(nx, device=u0.device, dtype=u0.dtype)
    # right[i, (i+1) % nx] = 1 and right.T[i, (i-1) % nx] = 1: summing shifted
    # identities gives the periodic stencil and stays correct for nx <= 2.
    right = torch.roll(eye, shifts=1, dims=1)
    lap = right + right.T - 2.0 * eye  # periodic tridiag(1, -2, 1)
    # Backward Euler for u_t = -nu u_xx:  u^{n+1} + nu*dt*D2 u^{n+1} = u^n.
    A = eye + r * lap

    history = [u0.clone()]
    u = u0
    for _ in range(nt):
        u = torch.linalg.solve(A, u)
        history.append(u)

    return torch.stack(history)


# ---------- Dataset generation ----------

def generate_dataset(
        n_samples: int,
        Lx: float = 2.0,
        nx: int = 64,
        T: float = 0.2,
        dt: float = 1e-3,
        nu: float = 0.02,
        kmax_frac: float | None = 0.9,
        growth_cap: float | None = 1e6,
        outfile: str = "backward_heat_dataset.npz",
        device: str | None = None,
        seed: int | None = 1234,
        ic_scale: float = 10.0,
) -> None:
    """Generate *n_samples* backward-heat trajectories and save them to *outfile*.

    Initial conditions are drawn from N(0, C) with C from ``build_covariance``,
    multiplied by *ic_scale*, and evolved with ``simulate_backward_heat_spectral``.
    The ``.npz`` holds ``data`` (n_samples, nt+1, nx) float32, the grid ``x``,
    the sample times ``t`` and the scalars ``Lx``, ``nu``, ``dt``,
    ``kmax_frac`` and ``growth_cap`` (the last two stored as -1.0 when None).

    Args:
      n_samples: number of trajectories.
      Lx: length of the periodic domain [0, Lx).
      nx: number of grid points (spacing dx = Lx / nx).
      T: target final time; nt = round(T / dt) steps of size dt are taken.
      dt: time step.
      nu: diffusivity.
      kmax_frac: spectral filter fraction forwarded to the solver (None = off).
      growth_cap: per-step growth cap forwarded to the solver (None = off).
      outfile: output path; missing parent directories are created.
      device: torch device name; by default MPS, then CUDA, then CPU.
      seed: when not None, seeds the NumPy and torch global RNGs.
      ic_scale: multiplier applied to the sampled initial conditions.

    Raises:
      ValueError: if T / dt rounds to a negative number of steps.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    dev = torch.device(device) if device is not None else (
        torch.device("mps") if torch.backends.mps.is_available() else
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )

    dx = Lx / nx
    nt = int(round(T / dt))
    if nt < 0:
        raise ValueError(f"T / dt must be non-negative, got T={T!r}, dt={dt!r}")

    # Periodic grid [0, Lx): x = Lx is the same point as x = 0, so the endpoint
    # is excluded and the spacing is exactly dx, as the covariance and solver assume.
    x = (np.arange(nx) * dx).astype(np.float32)
    # Times actually reached by nt steps of size dt (nt * dt can differ from T
    # when T is not a multiple of dt).
    t = (np.arange(nt + 1) * dt).astype(np.float32)

    # GP initial conditions via the discrete periodic operator (CPU NumPy).
    cov = build_covariance(nx, dx)
    u0s = sample_initial_conditions(n_samples, cov, dev) * ic_scale  # (n, nx) on device

    trajectories = [
        simulate_backward_heat_spectral(
            u0, nt, Lx, dt, nu, kmax_frac=kmax_frac, growth_cap=growth_cap
        ).cpu()
        for u0 in tqdm(u0s)
    ]
    if trajectories:
        data = torch.stack(trajectories)  # (n, nt+1, nx), CPU float32
    else:
        # torch.stack rejects an empty list; keep the documented shape instead.
        data = torch.empty((0, nt + 1, nx), dtype=torch.float32)

    Path(outfile).parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        outfile,
        data=data.numpy().astype(np.float32),
        x=x,
        t=t,
        Lx=np.float32(Lx),
        nu=np.float32(nu),
        dt=np.float32(dt),
        kmax_frac=np.float32(kmax_frac if kmax_frac is not None else -1.0),
        growth_cap=np.float32(growth_cap if growth_cap is not None else -1.0),
    )
    print(f"Saved dataset with shape {tuple(data.shape)} to '{outfile}'.")


def generate_backward_heat_data(n_samples=1000, nx=64, nt=50, Lx=2*np.pi, T=0.1, nu=0.01):
    """Generate backward-heat trajectories with the finite-difference solver.

    Each initial condition is 0.5*sin(x) + 0.2*sin(3x) plus 0.05 times
    standard-normal noise from torch's global RNG. It is evolved with
    ``simulate_backward_heat_fd`` for *nt* steps of size ``dt = T / nt``.

    Args:
      n_samples: number of trajectories.
      nx: number of points of the periodic grid [0, Lx).
      nt: number of time steps (must be >= 1).
      Lx: domain length; the sin(x) term is only periodic on the grid when
          Lx is a multiple of 2*pi (the default).
      T: final time.
      nu: diffusion coefficient (default 0.01, the previously hard-coded value).

    Returns:
      (data, x): ``data`` is a float32 array of shape (n_samples, nt+1, nx) and
      ``x`` the float32 grid of shape (nx,).
    """
    dt = T / nt

    # Periodic grid without the duplicated endpoint: the spacing is Lx / nx,
    # exactly what the FD solver assumes, and x = Lx is the same point as x = 0.
    x = torch.arange(nx, dtype=torch.float32) * (Lx / nx)

    data = torch.zeros(n_samples, nt + 1, nx)
    for i in range(n_samples):
        # Smooth initial condition plus small noise (backward heat amplifies noise).
        u0 = 0.5 * torch.sin(x) + 0.2 * torch.sin(3 * x) + 0.05 * torch.randn(nx)
        data[i] = simulate_backward_heat_fd(u0, nt, Lx, dt, nu)

    return data.numpy(), x.numpy()


# ---------- CLI ----------

def _cli() -> None:
    """Command-line entry point: parse arguments and call ``generate_dataset``.

    A negative ``--growth_cap`` disables the per-step cap and a non-positive
    ``--kmax_frac`` disables the spectral filter (both are passed on as None).
    """
    parser = argparse.ArgumentParser(description="Generate backward-heat datasets (periodic, spectral).")
    option = parser.add_argument
    option("-n", "--n_samples", type=int, default=5)
    option("--Lx", type=float, default=2.0)
    option("--nx", type=int, default=256)
    option("--T", type=float, default=0.005, help="total time (keep small; blow-up risk)")
    option("--dt", type=float, default=5e-4)
    option("--nu", type=float, default=0.02)
    option("--kmax_frac", type=float, default=0.9,
           help="keep modes with |k| <= kmax_frac * k_Nyquist; set 1.0 to keep all")
    option("--growth_cap", type=float, default=1e6,
           help="cap exp(nu*k^2*dt) per step; set -1 to disable")
    option("-o", "--outfile", type=str, default="backward_heat_dataset.npz")
    option("--device", type=str, default=None)
    option("--seed", type=int, default=None)
    opts = parser.parse_args()

    # Map the "disabled" sentinels onto None.
    cap = opts.growth_cap
    if cap is not None and cap < 0:
        cap = None
    frac = opts.kmax_frac
    if frac is not None and frac <= 0:
        frac = None

    generate_dataset(
        n_samples=opts.n_samples,
        Lx=opts.Lx,
        nx=opts.nx,
        T=opts.T,
        dt=opts.dt,
        nu=opts.nu,
        kmax_frac=frac,
        growth_cap=cap,
        outfile=opts.outfile,
        device=opts.device,
        seed=opts.seed,
    )


if __name__ == "__main__":
    # Script entry point: parse the command line and generate the dataset.
    _cli()
