"""
Korteweg-de Vries (KdV) equation with noise

u_t + alpha*u*u_x + beta*u_xxx = f(x,t)
alpha = -1.0, beta = -1.0
u'_t + alpha*u'*u'_x + beta*u'_xxx = f'(x,t)
f' = f + noise_t + beta*noise_xxx + alpha*(noise_x*u + u_x*noise + noise*noise_x)
u' = u + noise
"""

import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import imageio
from matplotlib.animation import FuncAnimation
import os
import math

def load_kdv(datapath):
    """
    Load KdV data from an HDF5 file.

    The datasets 'input' and 'output' are required. The time axis is read
    from 'time' or, when that key is absent, from 't' (the key save_data
    writes), so files produced by this module can be loaded again. The
    source term 'f' is optional because save_data omits it when it is None.

    datapath: path of the HDF5 file to read
    return: (input, output, f, t) as NumPy arrays; f is None when the file
            has no 'f' dataset
    raises: KeyError if 'input' or 'output' is missing, or if the file has
            neither a 'time' nor a 't' dataset
    """
    with h5py.File(datapath, 'r') as p:
        a = p['input'][:]
        u = p['output'][:]
        f = p['f'][:] if 'f' in p else None
        if 'time' in p:
            t = p['time'][:]
        elif 't' in p:
            t = p['t'][:]
        else:
            raise KeyError(f"{datapath} contains neither a 'time' nor a 't' dataset")
    print("[load] Loaded data from", datapath, "with keys: input, output, f, t")
    print("input", a.shape, "output", u.shape,
          "f", f.shape if f is not None else None, "t", t.shape)
    return a, u, f, t

def get_noise(data, noise_level=0.01, seed=0):
    """
    Draw Gaussian noise shaped like `data`.

    The standard deviation is `noise_level` times the largest absolute value
    in `data`. The global torch RNG is re-seeded with `seed` on every call,
    so the same arguments always give the same noise.

    data: tensor whose shape, dtype and device the noise copies
    noise_level: noise amplitude relative to max(|data|)
    seed: value passed to torch.manual_seed
    return: noise tensor with the same shape as data
    """
    torch.manual_seed(seed)
    amplitude = noise_level * data.abs().max().item()
    return amplitude * torch.randn_like(data)

def get_random_subu(a, u, scale=1e-2, seed=0):
    """
    Build a scaled, sample-shuffled copy of the inputs and solutions.

    The same random permutation of the first (sample) axis is applied to
    both `a` and `u`, and both are multiplied by `scale`. The global torch
    RNG is re-seeded with `seed`, so the permutation is reproducible.

    a: tensor indexed by sample along axis 0
    u: 3-D tensor (N, S, T); N sets the permutation length
    scale: multiplicative factor applied to the shuffled tensors
    seed: value passed to torch.manual_seed
    return: (scale * a[perm], scale * u[perm])
    """
    # Unpacking three values also rejects inputs that are not 3-D.
    n_samples, _, _ = u.shape
    torch.manual_seed(seed)
    perm = torch.randperm(n_samples)
    return scale * a[perm], scale * u[perm]

def solve_dft(sequence, dt):
    """
    Compute the time derivative with first-order forward differences.

    Consecutive entries along the last axis are differenced and divided by
    dt. The result keeps the dtype and device of `sequence`, so float64 data
    is not truncated to float32.

    sequence: (N, S, T+1) time sequence, time on the last axis
    dt: time step size (Python number or 0-dim tensor)
    return: (N, S, T) time derivative sequence
    """
    # One vectorised slice difference replaces the per-step Python loop.
    return (sequence[..., 1:] - sequence[..., :-1]) / dt

def get_f(w, noise, f, alpha, beta, device):
    """
    Compute the new source term f' based on the noisy solution.

    Returns
        f + beta*noise_xxx + alpha*(noise_x*w + w_x*noise + noise*noise_x)
    with all spatial derivatives taken spectrally (FFT along axis 1). The
    noise_t contribution is NOT added here; the caller adds it when the
    perturbation varies in time.

    Wavenumbers are the integers in FFT order, i.e. derivatives are taken as
    if the S grid points span one period of length 2*pi.

    w: (N, S, T) solution the perturbation is added to
    noise: (N, S, T) perturbation
    f: source term, broadcastable against (N, S, T)
    alpha: coefficient of the nonlinear term u*u_x
    beta: coefficient of the dispersive term u_xxx
    device: device on which the wavenumber vector is created
    return: (N, S, T) updated source term
    """
    _, S, _ = w.shape
    # Integer wavenumbers 0..ceil(S/2)-1, -floor(S/2)..-1. Built explicitly so
    # that odd S also yields exactly S entries.
    kx = torch.cat((
        torch.arange(0, (S + 1) // 2, device=device, dtype=torch.float64),
        torch.arange(-(S // 2), 0, device=device, dtype=torch.float64),
    ), 0).unsqueeze(0).unsqueeze(-1)
    ik = 1j * kx
    # Fourier symbol of d^3/dx^3 is (i*k)^3 = -i*k^3.
    ik3 = -1j * kx ** 3

    noise_hat = torch.fft.fft(noise, dim=1)
    noise_x = torch.fft.ifft(ik * noise_hat, dim=1).real
    noise_xxx = torch.fft.ifft(ik3 * noise_hat, dim=1).real
    w_x = torch.fft.ifft(ik * torch.fft.fft(w, dim=1), dim=1).real

    return f + beta * noise_xxx + alpha * (noise_x * w + w_x * noise + noise * noise_x)

def get_data(data, noise_level=0.01, beta=-1.0, alpha=-1.0, device='cpu', seed=0):
    """
    Add noise to the data and compute the new source term f'.

    Two perturbations are applied in sequence:
      1. a shuffled copy of the data scaled by noise_level (time dependent,
         so its finite-difference time derivative is added to f);
      2. Gaussian noise drawn for the initial condition and repeated
         unchanged over every time step (so it has no time derivative).
    For each perturbation the source term is corrected with get_f, evaluated
    against the solution *before* that perturbation is added.

    data: (a, u, t) or (a, u, t, f) array-likes; a is (N, S), u is (N, S, T),
          t holds at least two time stamps, f (optional, may be None) is
          broadcastable against (N, S, T)
    noise_level: relative amplitude used for both perturbations
    beta: coefficient of u_xxx
    alpha: coefficient of u*u_x
    device: torch device for all created tensors
    seed: RNG seed shared by both perturbations
    return: (a_noisy, u_noisy, t, f_new) as float64 tensors; f_new is None
            when no source term was supplied
    """
    a, u, t = data[0], data[1], data[2]
    f = data[3] if len(data) == 4 else None

    a = torch.tensor(a, device=device, dtype=torch.float64)
    u = torch.tensor(u, device=device, dtype=torch.float64)
    t = torch.tensor(t, device=device, dtype=torch.float64)
    dt = t[1] - t[0]
    if f is not None:
        f = torch.tensor(f, device=device, dtype=torch.float64)
    N, S, T = u.shape

    # First perturbation: scaled, shuffled copy of the data.
    suba, subu = get_random_subu(a, u, scale=noise_level, seed=seed)
    if f is not None:
        # Evaluate the correction with the unperturbed u. Passing u + subu
        # here would count the quadratic term subu*subu_x three times.
        f = get_f(u, subu, f, beta=beta, alpha=alpha, device=device)
        # Prepend the perturbed initial condition so the difference quotient
        # yields one value per time step of subu.
        sequence = torch.cat((suba.unsqueeze(-1), subu), dim=-1)
        f = f + solve_dft(sequence, dt)
    a = a + suba
    u = u + subu

    # Second perturbation: Gaussian noise, constant in time.
    noise = get_noise(a, noise_level, seed)
    a_noisy = a + noise
    noise = noise.unsqueeze(-1).repeat(1, 1, T)
    f_new = None
    if f is not None:
        f_new = get_f(u, noise, f, beta=beta, alpha=alpha, device=device)
    u_noisy = u + noise
    return a_noisy, u_noisy, t, f_new

def save_data(datapath, data):
    """
    Write a dataset tuple to an HDF5 file.

    datapath: destination file; it is opened in 'w' mode
    data: (a, u, f, t); stored under the keys 'input', 'output', 'f' and
          't' respectively. The 'f' dataset is skipped when f is None.
    """
    a, u, f, t = data
    entries = [('input', a), ('output', u)]
    if f is not None:
        entries.append(('f', f))
    entries.append(('t', t))

    with h5py.File(datapath, 'w') as p:
        for key, values in entries:
            p.create_dataset(key, data=values)
    print("[save] Data saved to", datapath)

def vis(input, output, f, save_path='.', idx=0):
    """
    Visualize input, output, and source term f as a GIF.

    One frame is rendered per index along axis 1 of `output`; each frame
    shows three stacked panels (input, output at that step, f at that step)
    for sample `idx`. The animation is written to
    <save_path>/visualization.gif at 10 fps.

    input: array indexed as input[idx] -> values over x
    output: array indexed as output[idx, t] -> values over x
    f: array indexed as f[idx, t] -> values over x
    save_path: output directory, created if missing
    idx: sample to plot

    NOTE(review): time is expected on axis 1 here, whereas get_data in this
    file keeps time on the last axis of u -- confirm callers transpose.
    """
    os.makedirs(save_path, exist_ok=True)
    x = np.linspace(0, 1, input.shape[-1])
    steps = output.shape[1]

    # The y-limits cover the whole array and do not change between frames,
    # so compute them once instead of on every redraw.
    input_ylim = (np.min(input) - 0.1, np.max(input) + 0.1)
    output_ylim = (np.min(output) - 0.1, np.max(output) + 0.1)
    f_ylim = (np.min(f) - 0.1, np.max(f) + 0.1)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 6))
    plt.tight_layout()
    frames = []

    def draw_panel(ax, title, values, ylabel, ylim, color):
        # Clear and redraw one panel.
        ax.cla()
        ax.set_title(title)
        ax.plot(x, values, color=color)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('x')
        ax.set_ylim(*ylim)
        return ax

    try:
        for t in range(steps):
            draw_panel(ax1, 'Input u(x,0)', input[idx], 'u(x,0)', input_ylim, 'blue')
            draw_panel(ax2, f'Output u(x,t) at t={t}', output[idx, t], 'u(x,t)',
                       output_ylim, 'red')
            draw_panel(ax3, f'Source term f(x,t) at t={t}', f[idx, t], 'f(x,t)',
                       f_ylim, 'green')
            fig.canvas.draw()
            # tostring_rgb() was removed in Matplotlib 3.10; buffer_rgba() is
            # the supported replacement. Drop alpha and copy, because the
            # canvas buffer is reused by the next draw.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            frames.append(rgba[..., :3].copy())
    finally:
        # Release the figure even if rendering fails part-way.
        plt.close(fig)

    gif_path = os.path.join(save_path, 'visualization.gif')
    imageio.mimsave(gif_path, frames, fps=10)
    print(f'[save] GIF saved to {gif_path}')

if __name__ == "__main__":
    # Placeholder locations: edit these before running the script.
    data_path = "path/to/your/data.h5"
    save_path = "path/to/save/noisy_data.h5"
    noise_level = 0.001

    # Load the clean data, perturb it, and write the noisy set back out.
    a, u, f, t = load_kdv(data_path)
    a_noisy, u_noisy, t, f_new = get_data(
        (a, u, t, f), noise_level=noise_level, beta=-1.0, alpha=-1.0
    )
    # f_new is None when no source term was available.
    f_out = None if f_new is None else f_new.numpy()
    save_data(save_path, (a_noisy.numpy(), u_noisy.numpy(), f_out, t.numpy()))