
"""
ks_rewrite.py

Conventions mirror KS.m:
- u: initial condition on a periodic grid of length L
- L: domain length (called "l" in MATLAB)
- T: final time
- N: number of snapshots to record (uniformly spaced in time)
- h: internal time step

The full solver implements ETDRK4 with M=64 contour points,
as in the original MATLAB code.
"""

from typing import Tuple
import numpy as np
from scipy.fft import fft, ifft
import time
from pdb import set_trace as bp
import h5py
import argparse
from src import utils
import matplotlib.pyplot as plt


def _wavenumbers(s: int, L: float) -> np.ndarray:
    k = (2*np.pi/L) * np.concatenate([np.arange(0, s//2), [0], -np.arange(s//2-1, 0, -1)])
    return k.reshape(-1, 1)  # column vector like MATLAB


def KS_full(u: np.ndarray, L: float, T: float, N: int, h: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Full Kuramoto–Sivashinsky:
        u_t + u_xx + u_xxxx + 0.5 (u^2)_x = 0
    Pseudospectral ETDRK4 rewrite of KS.m (M=64).

    Args:
        u: initial condition on a uniform periodic grid (flattened to 1-D).
        L: domain length.
        T: final time; the solver takes round(T/h) steps of size h.
        N: number of snapshots to return.
        h: internal time step.

    Returns:
        uu: (N, s) real array of snapshots.
        tt: (N,) array of snapshot times.

    Snapshots are recorded every max(1, floor((T/N)/h)) steps, so they are
    uniformly spaced but may end before T. Slots left unfilled are padded with
    the last recorded snapshot (or with the final state if none was recorded).
    """
    u = np.asarray(u).reshape(-1)
    s = u.size
    v = np.fft.fft(u)
    k = _wavenumbers(s, L)                     # (s, 1)
    Lop = k**2 - k**4                          # (s, 1) Fourier symbol of -d_xx - d_xxxx
    E  = np.exp(h*Lop).ravel()
    E2 = np.exp(h*Lop/2).ravel()

    # ETDRK4 coefficients are averaged over M points on the upper half of the
    # unit circle around each h*Lop (Kassam–Trefethen contour trick); the
    # closed-form expressions lose precision when h*Lop is near 0.
    M  = 64
    r  = np.exp(1j*np.pi*((np.arange(1, M+1)-0.5)/M))
    LR = h*Lop + r.reshape(1, -1)              # broadcasting to (s, M)

    Q  = h*np.real(np.mean((np.exp(LR/2)-1)/LR, axis=1))
    f1 = h*np.real(np.mean((-4-LR + np.exp(LR)*(4-3*LR+LR**2))/LR**3, axis=1))
    f2 = h*np.real(np.mean(( 2+LR + np.exp(LR)*(-2+LR))/LR**3,       axis=1))
    f3 = h*np.real(np.mean((-4-3*LR-LR**2 + np.exp(LR)*(4-LR))/LR**3, axis=1))

    # Fourier symbol of the nonlinear term -0.5 (u^2)_x acting on FFT(u^2).
    g = (-0.5j*k).ravel()

    def nonlinear(w: np.ndarray) -> np.ndarray:
        """Return g * FFT(u^2) for the physical field u = Re(IFFT(w))."""
        return g * np.fft.fft(np.real(np.fft.ifft(w))**2)

    uu = np.zeros((N, s), dtype=float)
    tt = np.zeros(N, dtype=float)

    nmax = int(round(T/h))
    # The small tolerance keeps floating-point error from truncating the stride,
    # e.g. (30/100)/0.1 == 2.9999999999999996 must give 3 steps, not 2.
    nrec = max(1, int(np.floor((T/N)/h + 1e-9)))

    q = 0
    for n in range(1, nmax+1):
        Nv = nonlinear(v)
        a  = E2*v + Q*Nv
        Na = nonlinear(a)
        b  = E2*v + Q*Na
        Nb = nonlinear(b)
        c  = E2*a + Q*(2*Nb - Nv)
        Nc = nonlinear(c)
        v  = E*v + Nv*f1 + 2*(Na+Nb)*f2 + Nc*f3

        if n % nrec == 0 and q < N:
            uu[q, :] = np.real(np.fft.ifft(v))
            tt[q] = n*h
            q += 1

    if q < N:   # pad if rounding left some slots
        if q == 0:
            # Nothing was recorded (e.g. fewer than nrec steps): fall back to
            # the final state rather than leaving rows of zeros.
            uu[:, :] = np.real(np.fft.ifft(v))
            tt[:] = nmax*h
        else:
            uu[q:, :] = uu[q-1, :]
            tt[q:]    = tt[q-1]
    return uu, tt

def KS_linear(u: np.ndarray, L: float, T: float, N: int, h: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Linear-only model:
        u_t + u_xx + u_xxxx = 0
    Exact exponential update in Fourier space.

    Args:
        u: initial condition on a uniform periodic grid (flattened to 1-D).
        L: domain length.
        T: final time; the solver takes round(T/h) steps of size h.
        N: number of snapshots to return.
        h: time step between multiplications by exp(h*(k^2 - k^4)).

    Returns:
        uu: (N, s) real array of snapshots.
        tt: (N,) array of snapshot times.

    Snapshots are recorded every max(1, floor((T/N)/h)) steps; slots left
    unfilled are padded with the last recorded snapshot (or with the final
    state if none was recorded).
    """
    u = np.asarray(u).reshape(-1)
    s = u.size
    v = np.fft.fft(u)
    k = _wavenumbers(s, L)
    Lop = k**2 - k**4                          # Fourier symbol of -d_xx - d_xxxx
    E = np.exp(h*Lop).ravel()                  # exact one-step propagator

    uu = np.zeros((N, s), dtype=float)
    tt = np.zeros(N, dtype=float)
    nmax = int(round(T/h))
    # Tolerance guards against floating-point truncation of the stride,
    # e.g. (30/100)/0.1 == 2.9999999999999996 must give 3 steps, not 2.
    nrec = max(1, int(np.floor((T/N)/h + 1e-9)))
    q = 0
    for n in range(1, nmax+1):
        v = E * v
        if n % nrec == 0 and q < N:
            uu[q, :] = np.real(np.fft.ifft(v))
            tt[q] = n*h
            q += 1
    if q < N:
        if q == 0:
            # Nothing was recorded: use the final state instead of zeros.
            uu[:, :] = np.real(np.fft.ifft(v))
            tt[:] = nmax*h
        else:
            uu[q:, :] = uu[q-1, :]
            tt[q:]    = tt[q-1]
    return uu, tt


def make_initial_condition(s: int, Ldom: float, rng: np.random.Generator,
                           a: float = 0.1, b: float = 0.1,
                           k_1: int = 10, k_2: int = 10,
                           noise: float = 0.05) -> np.ndarray:
    """Build an initial condition: two cosine modes plus Gaussian noise.

    Stand-in for GRF1.m. On the grid x_j = j*Ldom/s (j = 0..s-1) returns

        a*cos(k_1*pi*x/Ldom) + b*cos(k_2*pi*x/Ldom) + noise*xi,

    where xi ~ N(0, 1) i.i.d. is drawn from ``rng`` (s draws are always
    consumed, even when ``noise == 0``, so the RNG stream is unchanged).

    Args:
        s: number of grid points.
        Ldom: domain length.
        rng: NumPy random generator used for the noise.
        a, b: amplitudes of the two cosine modes.
        k_1, k_2: mode indices (argument is k*pi*x/Ldom).
        noise: standard deviation of the additive noise.

    Returns:
        Array of shape (s,).
    """
    # NOTE(review): cos(k*pi*x/Ldom) equals (-1)^k at x = Ldom, so it is only
    # Ldom-periodic for even k; odd k creates a jump at the periodic wrap.
    # Confirm whether 2*k*pi was intended.
    x = np.linspace(0, Ldom, s, endpoint=False)
    u0 = a*np.cos(k_1*np.pi*x/Ldom) + b*np.cos(k_2*np.pi*x/Ldom) + noise*rng.standard_normal(s)
    return u0

def spacetime(uu, tt, Ldom, title, filename):
    """Save a space-time heat map of ``uu`` to ``f"{filename}.png"``.

    Args:
        uu: 2-D array whose rows are snapshots; rows are drawn bottom-to-top
            (``origin='lower'``) over the time range ``tt[0]..tt[-1]``.
        tt: snapshot times; only the first and last entries are used.
        Ldom: domain length, used as the x extent.
        title: plot title.
        filename: output path without the ``.png`` extension.
    """
    fig, ax = plt.subplots(figsize=(9, 4.2))
    try:
        im = ax.imshow(
            uu, extent=[0, Ldom, tt[0], tt[-1]],
            aspect='auto', origin='lower'
        )
        ax.set_xlabel("x"); ax.set_ylabel("time"); ax.set_title(title)

        # colorbar
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("u(x,t)")
        cbar.ax.tick_params(labelsize=20)
        fig.tight_layout()
        fig.savefig(f"{filename}.png", dpi=200, bbox_inches="tight")
    finally:
        # Close the figure so repeated calls (one per trajectory) don't
        # accumulate open figures and their memory.
        plt.close(fig)

def _write_trajectory_h5(path, seed_str, uu, x, tt, max_retries=100):
    """Append one trajectory under group ``seed_str`` of the HDF5 file at ``path``.

    Datasets ``{seed_str}/data``, ``{seed_str}/grid/x`` and ``{seed_str}/grid/t``
    are stored as float32 with lzf compression. Opening the file is retried on
    OSError (e.g. another process holds it) up to ``max_retries`` times, 0.1 s
    apart; the last failure is re-raised instead of looping forever.
    """
    for attempt in range(max_retries):
        try:
            data_f = h5py.File(utils.expand_path(path), "a")
        except OSError:
            if attempt == max_retries - 1:
                raise
            time.sleep(0.1)
            continue
        with data_f:
            for name, arr in (("data", uu), ("grid/x", x), ("grid/t", tt)):
                data_f.create_dataset(
                    f"{seed_str}/{name}",
                    data=arr,
                    dtype="float32",
                    compression="lzf",
                )
        return


def main():
    """CLI entry point: simulate KS trajectories, plot each, optionally save to HDF5."""
    p = argparse.ArgumentParser(description="Run KS simulations and save to HDF5.")
    p.add_argument("--output_path", type=str, default='data/ks_1d.h5', help="Output HDF5 filepath.")
    p.add_argument("--save-h5", action="store_true",
                   help="Write every trajectory to --output_path (off by default).")
    p.add_argument("--model", type=str, default="full", choices=["full", "linear"], help="KS model variant.")
    p.add_argument("--s", type=int, default=512, help="Number of spatial grid points.")
    p.add_argument("--T", type=float, default=30.0, help="Final time.")
    p.add_argument("--N", type=int, default=200, help="Number of snapshots to record.")
    p.add_argument("--dt", type=float, default=0.1, help="Internal time step h.")
    p.add_argument("--num-traj", type=int, default=1, help="How many independent trajectories to simulate.")
    # parameter ranges
    # in-distribution weak nonlinear, chaos grow
    p.add_argument("--a-min", type=float, default=0.05)
    p.add_argument("--a-max", type=float, default=0.20)
    p.add_argument("--b-min", type=float, default=0.05)
    p.add_argument("--b-max", type=float, default=0.20)
    p.add_argument("--k1-min", type=int, default=1)
    p.add_argument("--k1-max", type=int, default=32)
    p.add_argument("--k2-min", type=int, default=1)
    p.add_argument("--k2-max", type=int, default=32)

    # OOD, strong nonlinear, chaos grows
    # p.add_argument("--a-min", type=float, default=1)
    # p.add_argument("--a-max", type=float, default=10)
    # p.add_argument("--b-min", type=float, default=1)
    # p.add_argument("--b-max", type=float, default=10)
    # p.add_argument("--k1-min", type=int, default=1)
    # p.add_argument("--k1-max", type=int, default=32)
    # p.add_argument("--k2-min", type=int, default=1)
    # p.add_argument("--k2-max", type=int, default=32)

    # OOD, linear, chaos decay
    # p.add_argument("--a-min", type=float, default=1)
    # p.add_argument("--a-max", type=float, default=10)
    # p.add_argument("--b-min", type=float, default=1)
    # p.add_argument("--b-max", type=float, default=10)
    # p.add_argument("--k1-min", type=int, default=64)
    # p.add_argument("--k1-max", type=int, default=128)
    # p.add_argument("--k2-min", type=int, default=64)
    # p.add_argument("--k2-max", type=int, default=128)
    args = p.parse_args()
    if args.num_traj < 1:
        p.error("--num-traj must be at least 1")

    # Spatial grid
    Ldom = 2*np.pi*32
    x = np.linspace(0, Ldom, args.s, endpoint=False)

    # Choose solver (and a plot title that matches it)
    if args.model == "full":
        solver, title = KS_full, "Full KS (ETDRK4, M=64)"
    else:
        solver, title = KS_linear, "Linear KS (exact exponential)"

    # Run trajectories
    for i in range(args.num_traj):
        rng = np.random.default_rng(i)

        a = rng.uniform(args.a_min, args.a_max)
        b = rng.uniform(args.b_min, args.b_max)
        # NOTE: Generator.integers excludes the upper bound, so k*_max is never drawn.
        k1 = rng.integers(args.k1_min, args.k1_max)
        k2 = rng.integers(args.k2_min, args.k2_max)
        u0 = make_initial_condition(args.s, Ldom, rng, a=a, b=b, k_1=k1, k_2=k2)
        uu, tt = solver(u0, Ldom, args.T, args.N, h=args.dt)  # (N, s), (N,)

        seed_str = str(i).zfill(5)
        # Keep the historical file name for one trajectory; suffix the seed
        # otherwise so later trajectories don't overwrite earlier plots.
        plot_name = f"{args.model}_indist" if args.num_traj == 1 else f"{args.model}_indist_{seed_str}"
        spacetime(uu, tt, Ldom, title, plot_name)

        if args.save_h5:
            _write_trajectory_h5(args.output_path, seed_str, uu, x, tt)

    if args.save_h5:
        print(f"Wrote {args.num_traj} trajectory(ies) to {args.output_path}.")
    else:
        print(f"Simulated {args.num_traj} trajectory(ies); HDF5 output disabled "
              f"(pass --save-h5 to write {args.output_path}).")
    print(f"u shape: {uu.shape}  t shape: {tt.shape}  x shape: {x.shape}")

if __name__ == "__main__":
    main()
