#!/usr/bin/env python3
"""
Fractional backward diffusion dataset generator (1D, periodic).
PDE: u_t = +σ (-Δ)^α u,  α ∈ (0, 2],  σ ∈ (0, 1]   (the + sign makes this an ill-posed backward diffusion)

Stable-ish implementation:
- Uses the *finite-difference* Laplacian symbol λ_k (periodic), then μ_k = (-λ_k)^α.
- Time stepping: Crank–Nicolson per-mode via FFT:
      û^{n+1}_k = ((1 + 0.5*dt*σ*μ_k) / (1 - 0.5*dt*σ*μ_k)) * û^{n}_k
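  This approximates exp(σ μ_k dt). Backward Euler, û^{n+1} = û^n / (1 - dt*σ*μ_k), is singular at
  dt*σ*μ_k = 1; the CN factor is singular only at dt*σ*μ_k = 2 (where the denominator is clamped) and
  tends to -1 for very high modes, so those oscillate in sign instead of blowing up.
- α tunes the instability spectrum: growth ~ e^{σ|k|^{2α} t} for well-resolved modes (α=1 → backward heat; α=2 → backward bi-heat).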
- σ∈(0,1] scales the growth (smaller σ = slower blow-up).

Applications: inverse anomalous-diffusion / deconvolution models; control/identification tests.

Output: compressed .npz with
  - data: (n_samples, nt+1, nx)
  - x: (nx,), t: (nt+1,)
  - meta: Lx, nx, T, dt, pde="frac_bd_cn", alpha, sigma
"""

import argparse
import numpy as np
import torch
from tqdm import tqdm

# ---------------- Utilities ----------------

def build_covariance(nx: int, dx: float) -> np.ndarray:
    """Return covariance 625(-d2/dx2 + 25 I)^{-2} with periodic BC (FD Laplacian).

    The discrete -d2/dx2 is the periodic second-difference matrix: 2 on the
    diagonal, -1 for each cyclic neighbour, all scaled by 1/dx^2.

    Neighbour contributions are *accumulated*, so the wrap-around is also
    correct for nx = 2, where both neighbours are the same point (coupling -2).
    It is likewise correct for nx = 1, where the point is its own neighbour
    (operator 0).

    The eigenvalues are 4 sin^2(pi k / nx) / dx^2, i.e. the negatives of the
    periodic FD Laplacian symbol used by the solver.

    Args:
        nx: number of grid points (>= 1).
        dx: grid spacing.

    Returns:
        (nx, nx) covariance matrix.
    """
    eye = np.eye(nx)
    # Shifting the identity one column right/left marks each row's cyclic
    # neighbours; subtracting both shifts (instead of assigning -1 to the
    # corner entries) handles nx <= 2 where those entries coincide.
    lap = 2.0 * eye - np.roll(eye, 1, axis=1) - np.roll(eye, -1, axis=1)
    lap *= 1.0 / dx**2
    operator = lap + 25.0 * np.eye(nx)
    cov = 625.0 * np.linalg.inv(operator @ operator)
    return cov

def sample_initial_conditions(n: int, cov: np.ndarray, device: torch.device, seed: int | None) -> torch.Tensor:
    rng = np.random.default_rng(seed)
    samples = rng.multivariate_normal(mean=np.zeros(cov.shape[0]), cov=cov, size=n)
    return torch.as_tensor(samples, dtype=torch.float32, device=device)

def fd_laplacian_eigs(nx: int, dx: float, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return the eigenvalues of the periodic centered-difference Laplacian.

    Entry k (k = 0, ..., nx - 1, i.e. FFT index order) is
    λ_k = -4 sin^2(pi k / nx) / dx^2, so every value is <= 0 and λ_0 = 0.
    """
    modes = torch.arange(nx, device=device, dtype=dtype)
    half_angle = np.pi * modes / nx
    return -4.0 * torch.sin(half_angle) ** 2 / (dx ** 2)

# ---------------- Fractional backward diffusion (CN) ----------------

def simulate_frac_bd_cn(u0: torch.Tensor, nt: int, Lx: float, dt: float, alpha: float, sigma: float) -> torch.Tensor:
    """
    Crank–Nicolson stepping for u_t = (+) sigma * (-Δ)^α u on a periodic grid.

    With FD Laplacian eigenvalues λ_k ≤ 0 (from ``fd_laplacian_eigs``), define
    μ_k = (-λ_k)^α ≥ 0. Each Fourier mode is then multiplied once per step by
        g_k = (1 + 0.5*dt*sigma*μ_k) / (1 - 0.5*dt*sigma*μ_k)  ≈ exp(sigma*μ_k*dt).
    When 0.5*dt*sigma*μ_k > 1 the factor is negative (tending to -1 as μ_k grows),
    so such modes flip sign every step instead of growing without bound.

    Args:
        u0: initial state. FFTs act on the last axis (the nx grid values), so
            any leading batch axes are carried along unchanged.
        nt: number of time steps.
        Lx: periodic domain length; the grid spacing is dx = Lx / nx.
        dt: time-step size.
        alpha: fractional order applied to the Laplacian symbol.
        sigma: growth scale multiplying the operator.

    Returns:
        Tensor stacking u0 followed by the stepped states along a new leading
        axis, i.e. shape (nt + 1, *u0.shape) for nt >= 0.
    """
    nx = u0.shape[-1]
    dx = Lx / nx
    device, dtype = u0.device, u0.dtype

    lam = fd_laplacian_eigs(nx, dx, device, dtype)
    mu = torch.pow((-lam).clamp_min(0.0), alpha)  # (-λ)^α

    num = 1.0 + 0.5 * dt * sigma * mu
    den = 1.0 - 0.5 * dt * sigma * mu

    # Soft clamp to avoid division by zero when dt*sigma*mu ~ 2. The clamped
    # denominator keeps the sign of the true one (copysign), so a mode just past
    # the singularity still gets the large *negative* factor CN prescribes
    # rather than a spurious positive one.
    tiny = 100 * torch.finfo(dtype).eps
    den = torch.where(
        torch.abs(den) < tiny,
        torch.copysign(torch.full_like(den, tiny), den),
        den,
    )

    amp = num / den  # per-mode amplification (≈ exp(sigma*mu*dt))

    uhat = torch.fft.fft(u0)
    hist = [u0.clone()]
    for _ in range(nt):
        uhat = uhat * amp
        # Input is real and the multiplier is (up to rounding) even in k, so the
        # imaginary part is only round-off and is discarded.
        hist.append(torch.fft.ifft(uhat).real)
    return torch.stack(hist, dim=0)

# ---------------- Dataset generation ----------------

def _resolve_device(device: str | None) -> torch.device:
    """Return ``torch.device(device)``, or auto-pick mps, then cuda, then cpu when ``device`` is None."""
    if device is not None:
        return torch.device(device)
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_dataset(n_samples: int, Lx: float, nx: int, T: float, dt: float, alpha: float, sigma: float,
                     outfile: str, device: str | None, seed: int | None, u_scale: float) -> None:
    """Simulate ``n_samples`` trajectories and save them to a compressed ``.npz`` archive.

    Initial conditions are drawn from N(0, C), with C from ``build_covariance``,
    and scaled by ``u_scale``. Each one is advanced ``nt = round(T / dt)``
    Crank–Nicolson steps of size ``dt`` by ``simulate_frac_bd_cn``.

    Saved arrays:
        data: (n_samples, nt + 1, nx) float32 trajectories.
        x: (nx,) periodic grid x_j = j * Lx / nx. The endpoint Lx is excluded
           because it is the same point as x = 0 under periodicity.
        t: (nt + 1,) simulated times n * dt.
        Lx, nx, T, dt, pde, alpha, sigma: run metadata. T is stored as given;
           the last simulated time is t[-1] = nt * dt.

    Raises:
        ValueError: if n_samples, nx, Lx, dt, T, alpha or sigma is out of range.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1.")
    if nx < 1:
        raise ValueError("nx must be >= 1.")
    if not Lx > 0.0:
        raise ValueError("Lx must be positive.")
    if not dt > 0.0:
        raise ValueError("dt must be positive.")
    if not T >= 0.0:
        raise ValueError("T must be non-negative.")
    if not (0.0 < alpha <= 2.0):
        raise ValueError("alpha must be in (0, 2].")
    if not (0.0 < sigma <= 1.0):
        raise ValueError("sigma must be in (0, 1].")
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    dev = _resolve_device(device)
    dx = Lx / nx
    nt = int(round(T / dt))
    # Grids that match what the solver actually uses: spacing dx = Lx/nx with no
    # duplicated periodic endpoint, and nt steps of exactly dt (nt*dt may differ
    # slightly from T when T/dt is not an integer).
    x = np.arange(nx) * dx
    t = np.arange(nt + 1) * dt

    cov = build_covariance(nx, dx)
    u0s = sample_initial_conditions(n_samples, cov, dev, seed) * u_scale

    sims = []
    for i in tqdm(range(n_samples), desc=f"Simulating [FracBD-CN α={alpha}, σ={sigma}]"):
        U = simulate_frac_bd_cn(u0s[i], nt, Lx, dt, alpha, sigma)
        sims.append(U.cpu())
    data = torch.stack(sims, dim=0)

    # Backward diffusion is ill-posed: unresolved modes can overflow float32.
    # Report it instead of silently saving inf/nan.
    n_bad = int((~torch.isfinite(data)).sum())
    if n_bad:
        print(f"Warning: {n_bad} of {data.numel()} saved values are non-finite (inf/nan); "
              "the simulation overflowed for these parameters.")

    np.savez_compressed(
        outfile,
        data=data.numpy().astype(np.float32),
        x=x.astype(np.float32),
        t=t.astype(np.float32),
        Lx=np.float32(Lx), nx=np.int32(nx), T=np.float32(T), dt=np.float32(dt),
        pde="frac_bd_cn", alpha=np.float32(alpha), sigma=np.float32(sigma),
    )
    print(f"Saved Fractional-BD (CN) dataset '{outfile}' with shape data={tuple(data.shape)}")

# ---------------- CLI ----------------

def _cli():
    """Parse command-line options (``sys.argv``) and run ``generate_dataset``.

    Every option's destination name matches a ``generate_dataset`` keyword
    argument, so the parsed namespace is forwarded verbatim.
    """
    parser = argparse.ArgumentParser(
        description="Fractional backward diffusion 1D generator (periodic FD symbol + Crank–Nicolson).",
    )
    # Dataset size and spatial grid.
    parser.add_argument("-n", "--n_samples", type=int, default=8)
    parser.add_argument("--Lx", type=float, default=(1/4)*np.pi)
    parser.add_argument("--nx", type=int, default=256)
    # Time integration.
    parser.add_argument("--T", type=float, default=1.0)
    parser.add_argument("--dt", type=float, default=2e-3)
    # Operator parameters.
    parser.add_argument("--alpha", type=float, default=1.2)
    parser.add_argument(
        "--sigma",
        type=float,
        default=0.5,
        help="Global growth scale in (0,1]; smaller slows blow-up.",
    )
    # Output location, execution device, reproducibility and IC scaling.
    parser.add_argument("-o", "--outfile", type=str, default="fracbd_cn_1d_dataset.npz")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--u_scale", type=float, default=1.0)
    options = vars(parser.parse_args())
    generate_dataset(**options)

if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    _cli()
