"""
Burgers equation with noise

u_t + u*u_x = nu*u_xx + f

u' = u + noise

u'_t + u'*u'_x = nu*u'_xx + f'

f' = noise_t + u*noise_x + noise*u_x + noise*noise_x - nu*noise_xx + f
"""

import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import imageio
from matplotlib.animation import FuncAnimation
import os
import math

def load_burgers(datapath):
    """
    Load a Burgers dataset from an HDF5 file.

    Reads the datasets 'input', 'output', the optional 'f', and the time grid.
    The time grid is read from 'time' if present, otherwise from 't' (the key
    written by ``save_data``), so files produced by this script can be reloaded.

    datapath: path to the HDF5 file
    return: (input, output, f, t) as numpy arrays; f is None when the file has
            no 'f' dataset.
    """
    with h5py.File(datapath, 'r') as p:
        inputs = p['input'][:]
        output = p['output'][:]
        f = p['f'][:] if 'f' in p else None
        t = p['time'][:] if 'time' in p else p['t'][:]
    print("[load] Loaded data from", datapath, "with keys: input, output, f, t")
    print("input", inputs.shape, "output", output.shape,
          "f", f.shape if f is not None else None, "t", t.shape)
    return inputs, output, f, t

def get_noise(data, noise_level=0.01, seed=0, mode='gaussian',
              perlin_cells=32, multi_sine_k=8):
    """
    Generate noise for data.

    Modes:
      gaussian      Standard Gaussian noise
      perlin        1D Perlin procedural noise (smooth, reused for all leading dimensions)
      multi_sine    Low-frequency sine/cosine superposition (non-Gaussian, smooth)
      random_walk   Random walk noise (cumulative uniform steps, strongly correlated)
      zero          Zero noise
    """
    torch.manual_seed(seed)
    scale = data.detach().abs().amax().clamp_min(1e-8)
    amp = noise_level * scale

    if mode == 'gaussian':
        return amp * torch.randn_like(data)

    L = data.shape[-1]
    device, dtype = data.device, data.dtype

    if mode == 'multi_sine':
        ks = torch.arange(1, multi_sine_k + 1, device=device, dtype=dtype)
        x = torch.linspace(0, 1, L, device=device, dtype=dtype)
        phase = 2 * math.pi * torch.rand(multi_sine_k, device=device, dtype=dtype)
        a = 2 * torch.rand(multi_sine_k, device=device, dtype=dtype) - 1
        b = 2 * torch.rand(multi_sine_k, device=device, dtype=dtype) - 1
        waves = (a.unsqueeze(-1) * torch.sin(2 * math.pi * ks.unsqueeze(-1) * x + phase.unsqueeze(-1))
                 + b.unsqueeze(-1) * torch.cos(2 * math.pi * ks.unsqueeze(-1) * x + phase.unsqueeze(-1)))
        pattern = waves.sum(0)
        pattern = pattern / (pattern.abs().amax() + 1e-8) * amp
        while pattern.dim() < data.dim():
            pattern = pattern.unsqueeze(0)
        return pattern.expand_as(data)

    if mode == 'perlin':
        cells = min(perlin_cells, L - 1)
        x = torch.linspace(0, 1, L, device=device, dtype=dtype)
        t = x * cells
        i0 = torch.clamp(t.floor().long(), max=cells - 1)
        local = t - i0
        fade = local**3 * (local * (local * 6 - 15) + 10)
        g = 2 * torch.rand(cells + 1, device=device, dtype=dtype) - 1
        g0 = g[i0]
        g1 = g[i0 + 1]
        v0 = g0 * local
        v1 = g1 * (local - 1.0)
        pattern = v0 + (v1 - v0) * fade
        pattern = pattern / (pattern.abs().amax() + 1e-8) * amp
        while pattern.dim() < data.dim():
            pattern = pattern.unsqueeze(0)
        return pattern.expand_as(data)

    if mode == 'random_walk':
        steps = 2 * torch.rand(L, device=device, dtype=dtype) - 1
        pattern = torch.cumsum(steps, dim=-1)
        pattern = pattern - pattern.mean()
        pattern = pattern / (pattern.abs().amax() + 1e-8) * amp
        while pattern.dim() < data.dim():
            pattern = pattern.unsqueeze(0)
        return pattern.expand_as(data)

    if mode == 'zero':
        return torch.zeros_like(data)

    raise ValueError(f"Unknown noise mode={mode}, available: gaussian | perlin | multi_sine | random_walk | zero")

def get_random_subu(a, u, scale=1e-2, seed=0):
    """
    Return a shuffled and rescaled copy of the dataset.

    ``a`` and ``u`` are permuted along their first axis with the same random
    permutation and multiplied by ``scale``. The permutation comes from
    ``torch.manual_seed(seed)``, so the global torch RNG is reset.

    a: tensor indexed by sample along dim 0
    u: 3-D tensor (N, S, T); unpacking its shape also rejects other ranks
    return: (scale * a[perm], scale * u[perm])
    """
    n_samples, _, _ = u.shape
    torch.manual_seed(seed)
    perm = torch.randperm(n_samples)
    return scale * a[perm], scale * u[perm]

def solve_dft(sequence, dt):
    """
    Compute the time derivative of a sequence by forward finite differences.

    sequence: (N, S, T+1) time sequence (time on the last axis)
    dt: time step size (float or 0-d tensor)
    return: (N, S, T) tensor with
            out[..., k] = (sequence[..., k+1] - sequence[..., k]) / dt,
            in the same dtype and on the same device as ``sequence``.
            The original implementation forced the default dtype (float32),
            which silently truncated float64 inputs.

    Raises ValueError if ``sequence`` is not 3-D.
    """
    if sequence.dim() != 3:
        raise ValueError(f"expected a 3-D (N, S, T+1) sequence, got shape {tuple(sequence.shape)}")
    return torch.diff(sequence, dim=-1) / dt

def get_f(w, noise, f, visc, device):
    """
    Compute the new source term f' for a perturbed field (without noise_t).

    Returns  f + w*noise_x + noise*w_x + noise*noise_x - visc*noise_xx.
    This is the module-level formula for f' minus the noise_t term; callers add
    that term themselves when the noise depends on time. ``w`` must be the
    unperturbed solution u, not u + noise.

    Spatial derivatives are spectral along the second-to-last axis. They assume
    the field is periodic in x with period 1, which is what the 2*pi*k factors
    encode. Only that axis is transformed.

    w, noise, f: real tensors with the spatial axis at dim -2, e.g. (N, S, T);
                 they must broadcast against each other
    visc: viscosity nu
    device: device on which the wavenumber vector is created
    return: f' as a real tensor
    """
    S = w.shape[-2]
    # Integer wavenumbers in FFT order [0, 1, ..., -1]. This is valid for both
    # even and odd S. The old floor(S/2)-based construction had S-1 entries
    # for odd S.
    k = torch.cat((torch.arange(0, (S + 1) // 2, device=device, dtype=torch.float64),
                   torch.arange(-(S // 2), 0, device=device, dtype=torch.float64)))
    k = k.unsqueeze(-1)  # (S, 1): broadcasts over the trailing (time) axis
    ik = 1j * 2 * math.pi * k  # first-derivative multiplier (zero at k = 0)
    k2 = (2 * math.pi * k) ** 2  # minus the second-derivative multiplier

    noise_ft = torch.fft.fft(noise, dim=-2)
    noise_x = torch.fft.ifft(ik * noise_ft, dim=-2).real
    noise_xx = torch.fft.ifft(-k2 * noise_ft, dim=-2).real
    w_x = torch.fft.ifft(ik * torch.fft.fft(w, dim=-2), dim=-2).real
    return f + w * noise_x + noise * w_x + noise * noise_x - visc * noise_xx

def get_data(data, noise_level=0.01, visc=1e-5, device='cpu', seed=0, noisy_mode='gaussian',
             additive_noise_level=0.001):
    """
    Build one perturbed copy of a Burgers dataset with a consistent source term.

    The data is perturbed in two stages:
      1. Sample mixing: ``noise_level`` times a randomly shuffled copy of the
         dataset (``get_random_subu``) is added. This perturbation varies in
         time, so its finite-difference time derivative is added to f.
      2. Additive noise from ``get_noise`` at level ``additive_noise_level``,
         drawn on the initial condition and held constant in time. No
         time-derivative term is needed for this stage.
    After each stage, f is updated with ``get_f`` evaluated at the solution
    *before* that stage, following the module-level formula for f'.

    data: (a, u, t) or (a, u, t, f). a is (N, S) and u is (N, S, T); these
          shapes are required by the concatenation and broadcasting below.
          f is assumed to match u; t is the time grid.
    noise_level: scale of the sample-mixing perturbation
    visc: viscosity nu used in the source-term correction
    device: torch device for all computations
    seed: seed for both the shuffle and ``get_noise`` (resets global torch RNG)
    noisy_mode: ``get_noise`` mode
    additive_noise_level: level passed to ``get_noise`` in stage 2
    return: (a_noisy, u_noisy, t, f_new) as float64 tensors; f_new is None when
            data carries no f
    """
    if len(data) == 4:
        a, u, t, f = data
    else:
        a, u, t = data[0], data[1], data[2]
        f = None
    a = torch.tensor(a, device=device, dtype=torch.float64)
    u = torch.tensor(u, device=device, dtype=torch.float64)
    t = torch.tensor(t, device=device, dtype=torch.float64)
    # NOTE(review): the factor 10 presumably reflects temporal subsampling of the
    # stored trajectories relative to t; confirm against the dataset generator.
    dt = 10 * (t[1] - t[0])
    if f is not None:
        f = torch.tensor(f, device=device, dtype=torch.float64)
    _, _, T = u.shape

    # Stage 1: mix in a scaled, shuffled copy of the dataset.
    suba, subu = get_random_subu(a, u, scale=noise_level, seed=seed)
    if f is not None:
        # Evaluate the correction at the clean u (before mixing). Using u + subu
        # would count the subu * subu_x term three times.
        f_mix = get_f(u, subu, f, visc, device)
        sequence = torch.cat((suba.unsqueeze(-1), subu), dim=-1)  # (N, S, T+1)
        f = f_mix + solve_dft(sequence, dt)  # add the perturbation's time derivative
    a = a + suba
    u = u + subu

    # Stage 2: time-independent additive noise drawn on the initial condition.
    noise = get_noise(a, additive_noise_level, seed, mode=noisy_mode)
    a_noisy = a + noise
    noise = noise.unsqueeze(-1).repeat(1, 1, T)  # (N, S) -> (N, S, T)
    f_new = get_f(u, noise, f, visc, device) if f is not None else None
    u_noisy = u + noise
    return a_noisy, u_noisy, t, f_new

def save_data(datapath, data):
    """
    Write a dataset to an HDF5 file, overwriting any existing file.

    data: (a, u, f, t), stored under the keys 'input', 'output', 'f' and 't'.
          The 'f' dataset is omitted when f is None.
    """
    inputs, outputs, forcing, times = data
    entries = [('input', inputs), ('output', outputs)]
    if forcing is not None:
        entries.append(('f', forcing))
    entries.append(('t', times))
    with h5py.File(datapath, 'w') as h5:
        for key, value in entries:
            h5.create_dataset(key, data=value)
    print("[save] save data to", datapath, "with keys input, output, f, t")


def vis(input, output, f, save_path='.', idx=0):
    """
    Render sample ``idx`` as an animated GIF at ``<save_path>/visualization.gif``.

    input: (N, s) initial conditions
    output: (N, steps, s) trajectories
    f: (N, steps, s) source terms, or None to omit the source-term panel

    Each frame shows the input curve, the output at one time index and, if
    given, the source term at that index. The y-limits span the whole array,
    so frames are directly comparable.
    """
    os.makedirs(save_path, exist_ok=True)
    x = np.linspace(0, 1, input.shape[-1])
    steps = output.shape[1]
    has_f = f is not None
    fig, axes = plt.subplots(3 if has_f else 2, 1, figsize=(8, 6))
    ax1, ax2 = axes[0], axes[1]
    ax3 = axes[2] if has_f else None
    fig.tight_layout()

    # The limits are the same for every frame, so compute them once.
    input_ylim = (np.min(input) - 0.1, np.max(input) + 0.1)
    output_ylim = (np.min(output) - 0.1, np.max(output) + 0.1)
    f_ylim = (np.min(f) - 0.1, np.max(f) + 0.1) if has_f else None

    # Render every frame.
    frames = []
    for t in range(steps):
        ax1.cla()
        ax1.set_title('Input u(x,0)')
        ax1.plot(x, input[idx], color='blue')
        ax1.set_ylabel('u(x,0)')
        ax1.set_xlabel('x')
        ax1.set_ylim(*input_ylim)

        ax2.cla()
        ax2.set_title(f'Output u(x,t) at t={t}')
        ax2.plot(x, output[idx, t], color='red')
        ax2.set_ylabel('u(x,t)')
        ax2.set_xlabel('x')
        ax2.set_ylim(*output_ylim)

        if has_f:
            ax3.cla()
            ax3.set_title(f'Source term f(x,t) at t={t}')
            ax3.plot(x, f[idx, t], color='green')
            ax3.set_ylabel('f(x,t)')
            ax3.set_xlabel('x')
            ax3.set_ylim(*f_ylim)

        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and later removed. It also reports the true pixel size
        # on HiDPI canvases. Copy the frame because the canvas reuses its buffer
        # on the next draw.
        frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
    plt.close(fig)

    # Save the frames as a GIF.
    gif_path = os.path.join(save_path, 'visualization.gif')
    imageio.mimsave(gif_path, frames, fps=10)
    print(f'GIF {gif_path}')
if __name__ == "__main__":
    data_path =""
    a, u, f, t = load_burgers(data_path)
    u = np.transpose(u, (0, 2, 1))  # (N, s, steps)
    f = np.transpose(f, (0, 2, 1))  # (N, s, steps)
    print("a shape:", a.shape)
    print("u shape:", u.shape)
    print("f shape:", f.shape)
    print("t shape:", t.shape)
    # print(t)
    noisy_mode = 'random_walk' # 'gaussian' | 'perlin' | 'multi_sine' | 'zero'
    N_per = 500
    batch = 2
    # batch2 = 10
    seed = 42
    a_all = np.zeros((N_per*batch, a.shape[1]))
    u_all = np.zeros((N_per*batch, u.shape[1], u.shape[2]))
    if f is not None:
        f_all = np.zeros((N_per*batch, f.shape[1], f.shape[2]))
    a = a[:N_per]
    u = u[:N_per]
    if f is not None:
        f = f[:N_per]
    print("min max a", a.min(), a.max(), "min max u", u.min(), u.max())
    data = (a, u, t, f) if f is not None else (a, u, t)
    import time 
    noisy_level = 1e-3
    start = time.time()
    for i in range(batch):
        seed = 42 + i
        a_noisy, u_noisy, t, f_new = get_data(data, noise_level=noisy_level, visc=1e-3, device='cpu', seed=seed, noisy_mode=noisy_mode)
        a_all[i * N_per:(i + 1) * N_per] = a_noisy.numpy()
        u_all[i * N_per:(i + 1) * N_per] = u_noisy.numpy()
        if f is not None and f_new is not None:
            f_all[i * N_per:(i + 1) * N_per] = f_new.numpy()
        if i % 10 == 0:
            print("[metric] a rel error", np.linalg.norm(a_noisy - a) / np.linalg.norm(a))
            print("[metric] u rel error", np.linalg.norm(u_noisy - u) / np.linalg.norm(u))
            if f is not None and f_new is not None:
                print("[metric] f rel error", np.linalg.norm(f_new - f) / np.linalg.norm(f))
    end = time.time()
    print(start - end)
    a_noisy = torch.tensor(a_all, dtype=torch.float64)
    u_noisy = torch.tensor(u_all, dtype=torch.float64)
    t = t.clone().detach().to(torch.float64)
    if f is not None:
        f_new = torch.tensor(f_all, dtype=torch.float64)
    else:
        f_new = None
    u_noisy = u_noisy.permute(0, 2, 1)  # (N, steps, s)
    if f_new is not None:
        f_new = f_new.permute(0, 2, 1)
    print("a_noisy", a_noisy.shape, "u_noisy", u_noisy.shape, "t", t.shape)
    if f_new is not None:
        print("f_new", f_new.shape)
    # random choose 1000 for save
    indices = np.random.choice(a_noisy.shape[0], size=1000, replace=False)
    a_noisy = a_noisy[indices]
    u_noisy = u_noisy[indices]
    if f_new is not None:
        f_new = f_new[indices]
    print("after choose 1000, a_noisy", a_noisy.shape, "u_noisy", u_noisy.shape)
    print("min max a_noisy", a_noisy.min().item(), a_noisy.max().item(), "min max u_noisy", u_noisy.min().item(), u_noisy.max().item())
    if f_new is not None:
        print("min max f_new", f_new.min().item(), f_new.max().item())
    vis(a_noisy.numpy(), u_noisy.numpy(), f_new.numpy() if f_new is not None else None, save_path='.', idx=0)
    datapath = ""
    save_data(datapath, (a_noisy.numpy(), u_noisy.numpy(), f_new.numpy() if f_new is not None else None, t.numpy()))
    print(f"noisy data saved to {datapath}")