"""
2D Navier-Stokes equation (vorticity form) with additive noise

w_t + u w_x + v w_y = visc (w_xx + w_yy) + f

w' = w + noise

w'_t + u w'_x + v w'_y = visc (w'_xx + w'_yy) + f'

f' = (u - delta_u) noise_x + (v - delta_v) noise_y - visc (noise_xx + noise_yy) + f - delta_u * w_x - delta_v * w_y

delta_u = (delta_phi)_y
delta_v = - (delta_phi)_x

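Δ delta_phi = noise        (Δ = ∂xx + ∂yy, the Laplacian)

The expression for f' contains no noise_t term, so it holds as written only for
noise that is constant in time; a time-dependent noise adds noise_t to f'.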
"""

import torch
import h5py
import math
import numpy as np
import matplotlib.pyplot as plt
import imageio
import time
import os

def load(data_path):
    """
    Load data from an HDF5 file.

    Reads the datasets ``a``, ``u`` and ``t`` (each fully, via ``[:]``) and,
    when the file contains it, the dataset ``f``.

    Returns:
        ``(a, u, t, f)`` if the file has an ``f`` dataset, otherwise
        ``(a, u, t)``. Callers must handle both tuple lengths.
    """
    # The file handle is named ``h5`` so that it is not shadowed by the
    # forcing array ``f`` while the file is still open.
    with h5py.File(data_path, 'r') as h5:
        a = h5['a'][:]
        u = h5['u'][:]
        t = h5['t'][:]
        f = h5['f'][:] if 'f' in h5 else None

    if f is not None:
        print("[load] Loaded data from", data_path, "with keys a, u, t, f")
        print("a", a.shape, "u", u.shape, "t", t.shape, "f", f.shape)
        return a, u, t, f

    print("[load] Loaded data from", data_path, "with keys a, u, t")
    print("a", a.shape, "u", u.shape, "t", t.shape)
    return a, u, t

def get_noise(data, noise_level=0.01, seed=0):
    """
    Draw Gaussian noise with the same shape as ``data``.

    torch's global RNG is reseeded with ``seed`` first, so equal seeds give
    equal noise. The standard deviation is ``noise_level`` times the largest
    absolute entry of ``data``.
    """
    torch.manual_seed(seed)
    sigma = noise_level * data.abs().max().item()
    return sigma * torch.randn_like(data)

def get_random_subu(a, u, scale=1e-2, seed=0):
    """
    Shuffle the samples of ``a`` and ``u`` and scale them by ``scale``.

    One permutation of length ``u.shape[0]`` is shared by ``a`` and ``u``.
    It is drawn right after reseeding torch's global RNG with ``seed``.
    ``u`` must be 4-D; otherwise unpacking its shape raises ``ValueError``.

    Returns:
        ``(scale * a[perm], scale * u[perm])``.
    """
    n_samples, _, _, _ = u.shape
    torch.manual_seed(seed)
    perm = torch.randperm(n_samples)
    return scale * a[perm], scale * u[perm]

def solve_dft(sequence, dt):
    """
    Compute time derivative using forward finite differences.

    sequence: (N, S, S, T+1) time sequence (time on the last axis)
    dt: time step size (Python scalar or 0-d tensor)
    return: (N, S, S, T) time derivative sequence with
        out[..., k] = (sequence[..., k+1] - sequence[..., k]) / dt

    The result keeps the dtype and device of ``sequence``. The previous
    version wrote into a default-dtype (float32) buffer, which silently
    truncated float64 inputs.
    """
    # One vectorized slice difference replaces the per-time-step Python loop.
    return (sequence[..., 1:] - sequence[..., :-1]) / dt

def get_f(w, noise, f, visc, device):
    """
    Compute the new source term f' based on the noisy solution.

    Pseudo-spectral evaluation, on a periodic unit-period grid, of

        f' = (u - delta_u) noise_x + (v - delta_v) noise_y
             - visc (noise_xx + noise_yy) + f - delta_u w_x - delta_v w_y

    (u, v) are derived from ``w`` and (delta_u, delta_v) from ``noise`` through
    the same stream-function relation used by the original code:
        Δ phi = field,  u = phi_y,  v = -phi_x   (mean of phi set to 0).
    NOTE(review): with the common convention w = v_x - u_y one would have
    Δ phi = -w. This keeps the original sign; confirm against the dataset.

    Args:
        w: real tensor of shape (N, Sx, Sy, T). Dim 1 is x, dim 2 is y, and
            each (batch, time) slice is treated independently.
        noise: real tensor with the same shape as ``w``.
        f: forcing, broadcastable to ``w`` (added in physical space).
        visc: viscosity coefficient.
        device: device on which the wavenumber grids are built.

    Returns:
        Real tensor f' with the shape of ``w``.
    """
    # FFTs must run over the spatial axes (1, 2). torch.fft.fft2 defaults to
    # the last two dims, which here would be (y, time).
    dims = (1, 2)
    two_pi = 2 * math.pi

    # Integer wavenumbers in FFT order [0, 1, ..., -1]; also valid for odd sizes.
    nx, ny = w.shape[1], w.shape[2]
    kx = torch.fft.fftfreq(nx, d=1.0 / nx, device=device, dtype=torch.float64).view(1, -1, 1, 1)
    ky = torch.fft.fftfreq(ny, d=1.0 / ny, device=device, dtype=torch.float64).view(1, 1, -1, 1)
    d_dx = 1j * two_pi * kx  # spectral symbol of d/dx
    d_dy = 1j * two_pi * ky  # spectral symbol of d/dy
    # lap = -(symbol of the Laplacian); exactly 0 at the mean mode.
    lap = (two_pi ** 2) * (kx ** 2 + ky ** 2)
    # Divisor for inverting the Laplacian. The mean mode is made non-zero here
    # and is zeroed in the result, so `lap` itself stays exact for noise_xx + noise_yy.
    lap_div = lap.clone()
    lap_div[0, 0, 0, 0] = 1.0

    def _to_spec(x):
        """Forward 2-D FFT over the spatial axes."""
        return torch.fft.fft2(x, dim=dims)

    def _to_phys(x_ft):
        """Inverse 2-D FFT over the spatial axes, keeping the real part."""
        return torch.fft.ifft2(x_ft, dim=dims).real

    def _velocity(field_ft):
        """Return (phi_y, -phi_x) in physical space, where Δ phi = field."""
        phi_ft = field_ft / (-lap_div)
        phi_ft[:, 0, 0, :] = 0.0
        return _to_phys(d_dy * phi_ft), _to_phys(-d_dx * phi_ft)

    noise_ft = _to_spec(noise)
    w_ft = _to_spec(w)

    delta_u, delta_v = _velocity(noise_ft)
    u, v = _velocity(w_ft)

    w_x = _to_phys(d_dx * w_ft)
    w_y = _to_phys(d_dy * w_ft)
    noise_x = _to_phys(d_dx * noise_ft)
    noise_y = _to_phys(d_dy * noise_ft)
    noise_lap = _to_phys(-lap * noise_ft)

    # Products of fields are formed in physical space. Multiplying their
    # spectra would compute a circular convolution instead.
    f_new = ((u - delta_u) * noise_x + (v - delta_v) * noise_y
             - visc * noise_lap
             + f
             - (delta_u * w_x + delta_v * w_y))
    return f_new

def get_data(data, noise_level=0.01, visc=1e-5, device='cpu',seed = 0):
    """
    Produce one noisy copy of a dataset.

    Args:
        data: ``(a, u, t)`` or ``(a, u, t, f)``. A fourth entry is used as the
            forcing only when ``len(data) == 4``.
        noise_level: scale passed to both ``get_random_subu`` and ``get_noise``.
        visc: viscosity passed to ``get_f``.
        device: torch device for every tensor created here.
        seed: seed for both the shuffled perturbation and the Gaussian noise.

    Steps:
        1. Add a shuffled, scaled copy of the samples to ``a`` and ``u``.
        2. If a forcing is given, replace it with ``get_f(u, perturbation, ...)``
           plus the finite-difference time derivative of the perturbation.
           The step uses ``t[0]`` as dt.
        3. Add Gaussian noise (constant in time) to ``a`` and ``u``. The forcing
           is then updated once more with ``get_f``.

    Returns:
        ``(a_noisy, u_noisy, t, f_new)`` as float64 tensors; ``f_new`` is
        ``None`` when no forcing is given.
    """
    def as_f64(x):
        return torch.tensor(x, device=device, dtype=torch.float64)

    raw_a, raw_u, raw_t = data[0], data[1], data[2]
    raw_f = data[3] if len(data) == 4 else None

    a = as_f64(raw_a)
    u = as_f64(raw_u)
    t = as_f64(raw_t)
    dt = t[0]  # assumes t holds the sample times starting one step after t=0
    f = None if raw_f is None else as_f64(raw_f)

    shuffled_a, shuffled_u = get_random_subu(a, u, scale=noise_level, seed=seed)
    a = a + shuffled_a
    u = u + shuffled_u
    if f is not None:
        # Prepend the initial-condition perturbation so the time derivative
        # covers every stored frame of u.
        perturbation = torch.cat((shuffled_a.unsqueeze(-1), shuffled_u), dim=-1)
        f = get_f(u, shuffled_u, f, visc, device) + solve_dft(perturbation, dt)

    noise = get_noise(a, noise_level, seed)
    a_noisy = a + noise
    noise = noise.unsqueeze(-1).repeat(1, 1, 1, u.shape[-1])
    f_new = None if f is None else get_f(u, noise, f, visc, device)
    u_noisy = u + noise
    return a_noisy, u_noisy, t, f_new

def save_h5(a, u, t, f, path):
    """
    Save data to an HDF5 file.

    Writes datasets ``a``, ``u``, ``t`` and, if ``f`` is not None, ``f`` to
    ``path``; an existing file is overwritten (mode ``'w'``).

    Inputs may be torch tensors, including ones on a GPU or attached to an
    autograd graph, or any array-like accepted by ``np.asarray``.
    """
    def _to_numpy(x):
        # Tensor.numpy() raises for CUDA tensors and tensors requiring grad.
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    with h5py.File(path, 'w') as p:
        p.create_dataset('a', data=_to_numpy(a))
        p.create_dataset('u', data=_to_numpy(u))
        p.create_dataset('t', data=_to_numpy(t))
        if f is not None:
            p.create_dataset('f', data=_to_numpy(f))
    print("[save] Data saved to", path)

def visualize_to_gif(a, u, f, save_path="visualization.gif", fps=10, sample=0):
    """
    Visualize a, u, f and save as a GIF.

    One frame is rendered per index ``t`` of the last axis of ``u``. Each frame
    has three panels:
    - ``a[sample]``;
    - ``u[sample, :, :, t]``;
    - ``f[sample, :, :, t]`` when ``f`` is given, otherwise a blank panel.

    Args:
        a: array whose ``a[sample]`` is a 2-D image.
        u: array indexed as ``u[sample, :, :, t]``.
        f: array indexed like ``u``, or None.
        save_path: output GIF path.
        fps: frames per second of the GIF.
        sample: index along the first axis to visualize (default 0, matching
            the previous hard-coded behaviour).
    """
    frames = []
    num_steps = u.shape[-1]

    for t in range(num_steps):
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))

        panels = [(a[sample], "a"), (u[sample, :, :, t], f"u (t={t})")]
        if f is not None:
            panels.append((f[sample, :, :, t], f"f (t={t})"))

        for ax, (image, title) in zip(axes, panels):
            im = ax.imshow(image, cmap="viridis", origin="lower")
            ax.set_title(title)
            ax.axis("off")
            fig.colorbar(im, ax=ax)
        # Blank any panel without data (the f panel when f is None) instead of
        # leaving empty axes with ticks and a frame.
        for ax in axes[len(panels):]:
            ax.axis("off")

        fig.tight_layout()
        fig.canvas.draw()
        frame = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]
        frames.append(frame)
        plt.close(fig)

    imageio.mimsave(save_path, frames, fps=fps, loop=0)
    print(f"GIF saved to {save_path}")

if __name__ == "__main__":
    data_path = "path/to/your/data.h5"
    # load() returns (a, u, t, f) only when the file has an 'f' dataset and
    # (a, u, t) otherwise. Unpacking straight into four names would crash on
    # forcing-free files.
    loaded = load(data_path)
    a, u, t = loaded[:3]
    f = loaded[3] if len(loaded) == 4 else None
    print("a", a.shape, "u", u.shape, "t", t.shape, "f", f.shape if f is not None else None)
    N_per = 100  # samples taken from the file
    batch = 100  # number of noisy copies, one seed each
    base_seed = 42
    a_all = np.zeros((N_per * batch, a.shape[1], a.shape[2]))
    u_all = np.zeros((N_per * batch, u.shape[1], u.shape[2], u.shape[3]))
    if f is not None:
        f_all = np.zeros((N_per * batch, f.shape[1], f.shape[2], f.shape[3]))
    a = a[:N_per]
    u = u[:N_per]
    if f is not None:
        f = f[:N_per]
    print("min max a", a.min(), a.max(), "min max u", u.min(), u.max())
    data = (a, u, t, f) if f is not None else (a, u, t)
    start = time.time()
    for i in range(batch):
        seed = base_seed + i
        a_noisy, u_noisy, t, f_new = get_data(data, noise_level=1e-3, visc=1e-4, device='cpu', seed=seed)
        # Convert once. The metrics then compare NumPy arrays with the NumPy
        # originals instead of mixing torch tensors and ndarrays.
        a_noisy_np = a_noisy.numpy()
        u_noisy_np = u_noisy.numpy()
        f_new_np = f_new.numpy() if f_new is not None else None
        rows = slice(i * N_per, (i + 1) * N_per)
        a_all[rows] = a_noisy_np
        u_all[rows] = u_noisy_np
        if f is not None and f_new_np is not None:
            f_all[rows] = f_new_np
        if i % 10 == 0:
            print("[metric] a rel error", np.linalg.norm(a_noisy_np - a) / np.linalg.norm(a))
            print("[metric] u rel error", np.linalg.norm(u_noisy_np - u) / np.linalg.norm(u))
            if f is not None and f_new_np is not None:
                print("[metric] f rel error", np.linalg.norm(f_new_np - f) / np.linalg.norm(f))
    end = time.time()
    print("time per batch", (end - start) / batch)
    a_noisy = torch.tensor(a_all, dtype=torch.float64)
    u_noisy = torch.tensor(u_all, dtype=torch.float64)
    # After the loop t is a tensor returned by get_data. It is still the NumPy
    # array from load() if the loop never ran; as_tensor handles both.
    t = torch.as_tensor(t, dtype=torch.float64).clone()
    if f is not None:
        f_new = torch.tensor(f_all, dtype=torch.float64)
    else:
        f_new = None
    print("a_noisy", a_noisy.shape, "u_noisy", u_noisy.shape, "t", t.shape)
    if f_new is not None:
        print("f_new", f_new.shape)
    visualize_to_gif(a_noisy.numpy(), u_noisy.numpy(), f_new.numpy() if f_new is not None else None, save_path="visualization_noisy.gif", fps=10)
    save_h5(a_noisy, u_noisy, t, f_new, "path/to/save/noisy_data.h5")