"""
Mixup for 2D Navier-Stokes data.

This script performs mixup on 2D Navier-Stokes data: it mixes the initial condition `a`
and the solution snapshots `u` (treated as vorticity), then reconstructs the matching
forcing term `f` from the mixed trajectories. The mixed data is saved to an HDF5 file.
"""

import torch
import h5py
import math
import numpy as np
import time

def load(data_path):
    """
    Load data from an HDF5 file.

    Reads the datasets ``a``, ``u`` and ``t`` and, when present, ``f``, and
    prints their shapes.

    Args:
        data_path: Path to the HDF5 file.

    Returns:
        ``(a, u, t, f)`` as NumPy arrays when the file has an ``f`` dataset,
        otherwise ``(a, u, t)``.
    """
    # The file handle gets its own name so it is not shadowed by the forcing
    # array read from the ``f`` dataset.
    with h5py.File(data_path, 'r') as hf:
        a = hf['a'][:]
        u = hf['u'][:]
        t = hf['t'][:]
        if 'f' not in hf:
            print("[load] Loaded data from", data_path, "with keys a, u, t")
            print("a", a.shape, "u", u.shape, "t", t.shape)
            return a, u, t
        forcing = hf['f'][:]
    print("[load] Loaded data from", data_path, "with keys a, u, t, f")
    print("a", a.shape, "u", u.shape, "t", t.shape, "f", forcing.shape)
    return a, u, t, forcing

def reconstruct_forcing_cn(a, u, delta_t, visc, device=None):
    """
    Reconstruct the forcing term `f` using a Crank-Nicolson scheme.

    Each pair of consecutive snapshots ``w_n -> w_{n+1}`` (with ``w_0 = a`` and
    ``w_{n+1} = u[..., n]``) is inverted for the forcing. Diffusion is treated
    with Crank-Nicolson and the nonlinear term explicitly at ``w_n``. In
    Fourier space, with ``L = 4*pi^2*|k|^2``:

        (1 + dt*nu*L/2) w_{n+1} - (1 - dt*nu*L/2) w_n = dt * (f_n - F_n),
        F_n = u_n * dw_n/dx + v_n * dw_n/dy,

    The velocity ``(u_n, v_n)`` is derived from the stream function
    ``psi = w_n / L``, with its mean mode set to zero. Derivatives are computed
    spectrally as ``2*pi*i*k`` over the two spatial axes (dims 1 and 2), i.e.
    on a periodic domain of unit period.

    NOTE(review): every snapshot interval is treated as a single scheme step of
    size ``delta_t``. If the snapshots were recorded every several solver steps,
    the reconstructed forcing is only an approximation.

    Args:
        a: (N, s, s) Initial condition ``w_0``.
        u: (N, s, s, steps) Snapshots ``w_1..w_steps``; treated as vorticity.
        delta_t: Time step between consecutive snapshots.
        visc: Viscosity. May be a Python number, a single-element tensor/array
            (shared by all N samples), or N per-sample values.
        device: Device to perform computation on (defaults to ``u.device``).

    Returns:
        f_rec: (N, s, s, steps) float64 tensor; ``f_rec[..., n]`` is the
        forcing that takes ``w_n`` to ``w_{n+1}``.
    """
    if device is None:
        device = u.device
    device = torch.device(device)
    opts = dict(device=device, dtype=torch.float64)

    a = a.to(**opts)
    u = u.to(**opts)
    N, s, _, steps = u.shape

    # Viscosity as an (N, 1, 1) tensor; a single value is shared by all samples.
    visc_vec = torch.as_tensor(visc, **opts).reshape(-1)
    if visc_vec.numel() == 1:
        visc_vec = visc_vec.expand(N)
    visc_view = visc_vec.reshape(N, 1, 1)

    # Integer wavenumbers in FFT order (non-negative first, then negative).
    # (s + 1) // 2 + s // 2 == s, so this is also valid for odd s.
    k = torch.cat((torch.arange(0, (s + 1) // 2, **opts),
                   torch.arange(-(s // 2), 0, **opts)))
    ky = k.repeat(s, 1)  # ky[i, j] = k[j]: varies along the last spatial axis
    kx = ky.t()          # kx[i, j] = k[i]: varies along the first spatial axis
    lap = 4 * (math.pi ** 2) * (kx ** 2 + ky ** 2)

    # The Poisson solve needs a non-zero divisor at k = 0 (psi's mean mode is
    # zeroed below). The diffusion factors, however, must use the true value 0
    # there, otherwise the mean mode is spuriously damped.
    lap_safe = lap.clone()
    lap_safe[0, 0] = 1.0

    # Precompute (1 ± 0.5 * dt * ν * L) and the spectral derivative operators.
    A = 1.0 + 0.5 * delta_t * visc_view * lap
    B = 1.0 - 0.5 * delta_t * visc_view * lap
    d_dx = 2j * math.pi * kx
    d_dy = 2j * math.pi * ky

    def ifft_real(z_h):
        # Inverse transform over the spatial axes, keeping the real part.
        return torch.fft.ifftn(z_h, dim=(1, 2)).real

    f_rec = torch.zeros(N, s, s, steps, **opts)
    w_prev_h = torch.fft.fftn(a, dim=(1, 2))
    for n in range(steps):
        # Velocity field from the stream function of w_n.
        psi_h = w_prev_h / lap_safe
        psi_h[..., 0, 0] = 0.0
        u_vel = ifft_real(d_dy * psi_h)
        v_vel = ifft_real(-d_dx * psi_h)

        # Nonlinear term F_n = (u, v) . grad w_n, evaluated in physical space.
        w_x = ifft_real(d_dx * w_prev_h)
        w_y = ifft_real(d_dy * w_prev_h)
        F_n_h = torch.fft.fftn(u_vel * w_x + v_vel * w_y, dim=(1, 2))

        # Solve the Crank-Nicolson update for f_n.
        w_curr_h = torch.fft.fftn(u[..., n], dim=(1, 2))
        f_n_h = (A * w_curr_h - B * w_prev_h + delta_t * F_n_h) / delta_t
        f_rec[..., n] = ifft_real(f_n_h)

        # Reuse this transform as w_n for the next step instead of recomputing it.
        w_prev_h = w_curr_h

    return f_rec

def mixup2d(a, u, delta_t, visc, alpha=0.2, device=None, seed=42):
    """
    Perform mixup on 2D data.

    Draws N weights from ``Beta(alpha, alpha)`` and normalizes them to sum to
    one. It then forms a single convex combination of all N samples of ``a``
    and ``u``, repeats it N times, and reconstructs the matching forcing with
    ``reconstruct_forcing_cn``.

    NOTE(review): all N rows of each output are identical copies of one mixed
    sample. Confirm this is intended; pairwise mixup would instead produce N
    distinct samples.

    Args:
        a: (N, S, S) Initial condition.
        u: (N, S, S, T) Solution snapshots.
        delta_t: Time step between snapshots.
        visc: Viscosity, forwarded to ``reconstruct_forcing_cn``.
        alpha: Beta-distribution parameter for the mixup weights.
        device: Device to perform computation on (defaults to ``u.device``).
        seed: Random seed for the mixup weights. Only a forked copy of the CPU
            RNG is seeded, so the caller's global RNG state is left untouched.

    Returns:
        a_mix: (N, S, S) mixed initial condition.
        u_mix: (N, S, S, T) mixed snapshots.
        f_mix: Reconstructed forcing term for the mixed samples.

    Raises:
        ValueError: If ``a`` does not have shape ``(N, S, S)`` matching ``u``.
    """
    N, S, _, T = u.shape
    if tuple(a.shape) != (N, S, S):
        raise ValueError(
            f"a must have shape {(N, S, S)} to match u of shape {tuple(u.shape)}, "
            f"got {tuple(a.shape)}"
        )

    if device is None:
        device = u.device
    device = torch.device(device)

    # Sample the weights on the CPU from a seeded, forked RNG state. The same
    # seed yields the same weights as torch.manual_seed(seed) would, but the
    # global RNG state is restored when the block exits.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(int(seed))
        weights = torch.distributions.Beta(alpha, alpha).sample((N,))
    weights = weights.to(device)
    weights = weights / (weights.sum() + 1e-12)

    # One convex combination over the sample axis, repeated for every row.
    a_mix = torch.sum(a.to(device) * weights.view(N, 1, 1), dim=0, keepdim=True).repeat(N, 1, 1)
    u_mix = torch.sum(u.to(device) * weights.view(N, 1, 1, 1), dim=0, keepdim=True).repeat(N, 1, 1, 1)

    if isinstance(visc, (int, float)):
        # All rows are identical and share one viscosity, so reconstruct the
        # forcing once and tile it instead of repeating the work N times.
        f_one = reconstruct_forcing_cn(a_mix[:1], u_mix[:1], delta_t, visc, device=device)
        f_mix = f_one.repeat(N, 1, 1, 1)
    else:
        f_mix = reconstruct_forcing_cn(a_mix, u_mix, delta_t, visc, device=device)
    return a_mix, u_mix, f_mix

def main(
    data_path="path/to/your/data.h5",
    out_path="path/to/save/mixed_data.h5",
    visc=1e-4,
    n_per=100,
    batch=10,
    seed=42,
    alpha=0.2,
    device="cpu",
):
    """
    Main function to perform mixup and save the mixed data.

    Loads ``data_path`` and takes the first ``n_per`` samples. It then calls
    ``mixup2d`` ``batch`` times with seeds ``seed, seed + 1, ...`` and writes
    the stacked ``a``, ``u``, ``f`` plus the original ``t`` to ``out_path``.

    Args:
        data_path: Input HDF5 file with datasets ``a``, ``u``, ``t``; a stored
            ``f`` is optional and not used.
        out_path: Output HDF5 file.
        visc: Viscosity used for the forcing reconstruction.
        n_per: Maximum number of source samples mixed per batch.
        batch: Number of mixup batches.
        seed: Base random seed.
        alpha: Beta-distribution parameter for the mixup weights.
        device: Device used by ``mixup2d``.

    Raises:
        ValueError: If ``t`` has fewer than two entries.
    """
    # `load` returns 3 or 4 arrays. A stored forcing is not needed, because the
    # forcing of the mixed samples is reconstructed from a and u.
    loaded = load(data_path)
    a, u, t = loaded[0], loaded[1], loaded[2]
    if len(t) < 2:
        raise ValueError(f"need at least two time stamps to infer delta_t, got {len(t)}")
    delta_t = t[1] - t[0]
    print("delta_t", delta_t, "visc", visc)

    # Slice before converting so only the samples that are used get copied.
    a_t = torch.tensor(a[:n_per])
    u_t = torch.tensor(u[:n_per])
    print("a_t", a_t.shape)
    print("u_t", u_t.shape)
    # The file may contain fewer than n_per samples.
    n_per = a_t.shape[0]

    # The reconstructed forcing has one field per snapshot, i.e. the shape of u.
    n_total = n_per * batch
    f_mix_all = torch.zeros(n_total, *u_t.shape[1:], dtype=torch.float64)
    a_mix_all = torch.zeros(n_total, *a_t.shape[1:], dtype=torch.float64)
    u_mix_all = torch.zeros(n_total, *u_t.shape[1:], dtype=torch.float64)

    start_time = time.perf_counter()
    for b in range(batch):
        print(f"Mixup batch {b+1}/{batch}")
        a_mix, u_mix, f_mix = mixup2d(a_t, u_t, delta_t, visc, alpha=alpha, device=device, seed=seed + b)
        rows = slice(b * n_per, (b + 1) * n_per)
        a_mix_all[rows] = a_mix.cpu()
        u_mix_all[rows] = u_mix.cpu()
        f_mix_all[rows] = f_mix.cpu()
    elapsed = time.perf_counter() - start_time

    print(f"Mixup completed in {elapsed:.2f} seconds")
    print("a_mix_all", a_mix_all.shape)
    print("u_mix_all", u_mix_all.shape)
    print("f_mix_all", f_mix_all.shape)

    with h5py.File(out_path, "w") as hf:
        hf.create_dataset("a", data=a_mix_all.numpy())
        hf.create_dataset("u", data=u_mix_all.numpy())
        hf.create_dataset("f", data=f_mix_all.numpy())
        hf.create_dataset("t", data=t)
    print("Saved mixed data to", out_path)

if __name__ == "__main__":
    # Run the mixup pipeline only when executed as a script, not on import.
    main()