# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os

import h5py
import numpy as np
import torch
from joblib import Parallel, delayed
from phi.flow import (  # SoftGeometryMask,; Sphere,; batch,; tensor,
    Box,
    CenteredGrid,
    Noise,
    StaggeredGrid,
    advect,
    diffuse,
    extrapolation,
    fluid,
    Field,
    UniformGrid,
    tensor,
    spatial
)
from phi.math import reshaped_native
from phi.math import seed as phi_seed
from tqdm import tqdm
from skimage.transform import resize
import matplotlib.pyplot as plt

from pdearena import utils

from .pde import PDEConfig

logger = logging.getLogger(__name__)

# Spatial down-sampling factor applied to the 256 x 256 ground-truth grid,
# keyed by the ``coarse_level`` argument of generate_trajectories_smoke.
# Reference grid sizes (nx, ny, nt) for each level:
#   gt:      (256, 256, 256)
#   fine:    (128, 128, 64)
#   medium:  (64, 64, 32)
#   coarse:  (32, 32, 16)
#   xcoarse: (16, 16, 8)
RESOLUTION_FACTOR = dict(
    xcoarse=16,
    coarse=8,
    medium=4,
    fine=2,
    gt=1,
)

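# Helpers for random initial conditions. Both draw from NumPy's global RNG, so
# seed it first for reproducible output. generate_trajectories_smoke draws its
# noise at the ground-truth 256x256 resolution and then downsamples it, which is
# what keeps the initial condition identical across resolutions.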
def generate_noise(resolution, scale, smoothness):
    """
    Generate a 2D noise array with a power-law spectrum.

    Complex Gaussian Fourier coefficients are drawn from NumPy's global RNG
    (seed it beforehand for reproducible output), weighted by
    ``(1 / |k|^2) ** smoothness`` with the mean (k = 0) mode removed,
    transformed back to real space and normalized to zero mean and unit
    standard deviation.

    Args:
        resolution (tuple): Resolution of the grid (height, width).
        scale (float): Accepted for interface compatibility but currently
            ignored; the spectrum is defined over integer grid wavenumbers.
        smoothness (float): Determines how quickly high frequencies die out;
            larger values give smoother noise.

    Returns:
        numpy.ndarray: A 2D array of shape ``resolution``. If the field has
        zero variance (e.g. a 1x1 grid) the division by the standard
        deviation is skipped instead of producing NaNs.
    """
    height, width = resolution

    # Create random noise in the Fourier domain (real part drawn first, then
    # imaginary part -- the order matters for reproducibility under a seed).
    rnd_real = np.random.normal(size=resolution)
    rnd_imag = np.random.normal(size=resolution)
    rnd_complex = rnd_real + 1j * rnd_imag

    # Integer wavenumbers in FFT order. Axis 0 is height (ky) and axis 1 is
    # width (kx), so the grid has the same shape as rnd_complex for any
    # (height, width) -- not only for square grids.
    kx = np.fft.fftfreq(width) * width
    ky = np.fft.fftfreq(height) * height
    ky, kx = np.meshgrid(ky, kx, indexing='ij')
    k2 = kx**2 + ky**2

    # 1 / k^2, evaluated only where k^2 > 0 so nothing divides by zero.
    inv_k2 = np.divide(1.0, k2, out=np.zeros_like(k2), where=k2 > 0)
    # Wavenumbers are integers, so this threshold only removes the k = 0 (mean) mode.
    lowest_frequency = 0.1
    fft_weights = np.where(k2 > lowest_frequency, inv_k2**smoothness, 0.0)

    # Inverse FFT back to the spatial domain.
    noise = np.fft.ifft2(rnd_complex * fft_weights).real

    # Normalize to zero mean and unit standard deviation.
    noise -= np.mean(noise)
    std = np.std(noise)
    if std > 0:
        noise /= std

    return noise

def generate_random_incompressible_field_2d(
    Nx=128, Ny=128, spectrum_exponent=4
):
    """
    Draw a random, divergence-free 2D velocity field with a power-law spectrum.

    Random complex Fourier coefficients (drawn from NumPy's global RNG) are
    scaled by |k|**(-spectrum_exponent), projected onto the plane orthogonal
    to k (so that k . v_hat = 0), stripped of the mean (k = 0) mode, and
    transformed back to real space.

    Parameters
    ----------
    Nx, Ny : int
        Number of grid points along x and y.
    spectrum_exponent : float
        Exponent of the amplitude power law, amplitude ~ |k|**(-spectrum_exponent).
        Larger values concentrate energy in low wavenumbers (large scales).

    Returns
    -------
    vx, vy : numpy.ndarray
        Real-valued velocity components, each of shape (Nx, Ny).
    """
    # Integer wavenumbers in FFT order; axis 0 <-> x, axis 1 <-> y.
    wave_x, wave_y = np.meshgrid(
        np.fft.fftfreq(Nx) * Nx, np.fft.fftfreq(Ny) * Ny, indexing='ij'
    )
    k_sq = wave_x**2 + wave_y**2
    # Overwrite the k = 0 entry with 1 purely to avoid dividing by zero below;
    # that mode is explicitly zeroed out before transforming back.
    k_sq[0, 0] = 1.0

    amplitude = k_sq**(-spectrum_exponent / 2.0)

    # Random coefficients. Draw order is Re(x), Im(x), Re(y), Im(y), which keeps
    # the output reproducible for a given NumPy seed.
    shape = (Nx, Ny)
    coeffs = []
    for _ in range(2):
        real_part = np.random.normal(size=shape)
        imag_part = np.random.normal(size=shape)
        coeffs.append(amplitude * (real_part + 1j * imag_part))
    ux_hat, uy_hat = coeffs

    # Remove the component parallel to k (the compressible part).
    parallel = ux_hat * wave_x + uy_hat * wave_y
    ux_hat = ux_hat - parallel * wave_x / k_sq
    uy_hat = uy_hat - parallel * wave_y / k_sq

    # No net flow: discard the mean mode.
    ux_hat[0, 0] = 0.0
    uy_hat[0, 0] = 0.0

    return np.fft.ifft2(ux_hat).real, np.fft.ifft2(uy_hat).real


def generate_trajectories_smoke(
    pde: PDEConfig,
    mode: str,
    num_samples: int,
    batch_size: int,
    device: torch.device = torch.device("cpu"),
    dirname: str = "data",
    n_parallel: int = 1,
    seed: int = 42,
    coarse_level: str = "gt",
) -> None:
    """
    Generate smoke-inflow trajectories in a bounded domain and save them to HDF5.

    For each sample:
      * A smoke density field is initialised from ``generate_noise``, drawn on a
        256x256 grid and downsampled by ``RESOLUTION_FACTOR[coarse_level]``.
      * The velocity starts as a zero staggered field.
      * Both are advanced for ``pde.nt + pde.skip_nt`` steps: semi-Lagrangian
        advection, buoyancy forcing, explicit diffusion, then pressure projection.
      * The first ``pde.skip_nt`` steps are discarded, and every
        ``pde.sample_rate``-th remaining step is stored.

    Results go to ``<dirname>/<pde>_<mode>_<seed>_<buoyancy_y>[_<num_samples>].h5``
    under the group ``mode``.

    Args:
        pde (PDEConfig): pde at hand [NS2D]
        mode (str): [train, valid, test]; also the HDF5 group name. For "train"
            the number of samples is appended to the file name.
        num_samples (int): how many trajectories to create
        batch_size (int): batch size (not used by this function)
        device: device (cpu/gpu) (not used by this function)
        dirname (str): output directory
        n_parallel (int): number of joblib workers
        seed (int): only used in the output file name. Per-sample seeds are
            drawn from NumPy's global RNG, so seed that before calling for
            reproducible data.
        coarse_level (str): one of the RESOLUTION_FACTOR keys
    Returns:
        None
    Raises:
        KeyError: if ``coarse_level`` is not a RESOLUTION_FACTOR key.
    """
    factor = RESOLUTION_FACTOR[coarse_level]
    # Noise is always drawn at the ground-truth resolution and then downsampled,
    # so a given seed yields the same initial condition at every level.
    gt_res = 256
    smoke_res = gt_res // factor

    pde_string = str(pde)
    logger.info(f"Equation: {pde_string}")
    logger.info(f"Mode: {mode}")
    logger.info(f"Number of samples: {num_samples}")

    save_name = os.path.join(dirname, "_".join([pde_string, mode, str(seed), f"{pde.buoyancy_y:.5f}"]))
    if mode == "train":
        save_name = save_name + "_" + str(num_samples)

    nt, nx, ny = pde.grid_size[0], pde.grid_size[1], pde.grid_size[2]

    def genfunc(idx, s):
        """Simulate trajectory ``idx`` seeded with ``idx + s``; return (smoke, velocity)
        sampled every ``pde.sample_rate`` steps."""
        try:
            # Add as Python ints so a NumPy int32 ``s`` cannot wrap around.
            sample_seed = int(idx) + int(s)
            logger.debug("Seed: %d", sample_seed)
            phi_seed(sample_seed)
            np.random.seed(sample_seed)

            noise = generate_noise(resolution=(gt_res, gt_res), scale=11.0, smoothness=9.0)
            noise = resize(noise, (smoke_res, smoke_res), anti_aliasing=True)
            smoke = abs(
                CenteredGrid(
                    values=tensor(noise, spatial('x,y')),
                    boundary=extrapolation.BOUNDARY,
                    x=smoke_res,
                    y=smoke_res,
                    bounds=Box['x,y', 0 : pde.Lx, 0 : pde.Ly],
                )
            )
            logger.debug("Smoke shape: %s", smoke.values.shape)

            # NOTE(review): the smoke grid has 256 // factor cells per side, while the
            # velocity grid and the HDF5 datasets use pde.nx / pde.ny (pde.grid_size);
            # these presumably agree for the chosen coarse_level -- confirm.
            velocity = StaggeredGrid(
                0, extrapolation.ZERO, x=pde.nx, y=pde.ny, bounds=Box['x,y', 0 : pde.Lx, 0 : pde.Ly]
            )
            smoke_steps, velocity_steps = [], []
            for _ in range(pde.nt + pde.skip_nt):
                smoke = advect.semi_lagrangian(smoke, velocity, pde.dt)
                # Buoyancy acts along y only; .at(velocity) resamples it onto the velocity grid.
                buoyancy_force = (smoke * (0, pde.buoyancy_y)).at(velocity)
                velocity = advect.semi_lagrangian(velocity, velocity, pde.dt) + pde.dt * buoyancy_force
                velocity = diffuse.explicit(velocity, pde.nu, pde.dt)
                velocity, _ = fluid.make_incompressible(velocity)
                smoke_steps.append(reshaped_native(smoke.values, groups=("x", "y", "vector"), to_numpy=True))
                velocity_steps.append(
                    reshaped_native(
                        velocity.staggered_tensor(),
                        groups=("x", "y", "vector"),
                        to_numpy=True,
                    )
                )

            # Drop the warm-up steps. squeeze() removes singleton axes (presumably
            # the size-1 channel axis of the scalar smoke field).
            smoke_traj = np.asarray(smoke_steps[pde.skip_nt :]).squeeze()
            # Drop the last entry along x and y (presumably the extra face of the
            # staggered layout) so velocity matches the cell-centred grid.
            velocity_traj = np.asarray(velocity_steps[pde.skip_nt :]).squeeze()[:, :-1, :-1, :]
            return smoke_traj[:: pde.sample_rate, ...], velocity_traj[:: pde.sample_rate, ...]
        except Exception as e:
            logger.error(f"Error in genfunc for idx={idx}, seed={s}: {e}", exc_info=True)
            raise

    # The context manager guarantees the file is closed even if generation fails.
    with h5py.File(save_name + ".h5", "a") as h5f:
        # Created up front so an already-existing group fails fast, before any
        # expensive simulation runs.
        dataset = h5f.create_group(mode)
        # The scalar field u, the velocity components vx, vy, the coordinates
        # (t, x, y), the spacings dt, dx, dy and the physical parameters are saved.
        h5f_u = dataset.create_dataset("u", (num_samples, nt, nx, ny), dtype=float)
        h5f_vx = dataset.create_dataset("vx", (num_samples, nt, nx, ny), dtype=float)
        h5f_vy = dataset.create_dataset("vy", (num_samples, nt, nx, ny), dtype=float)
        tcoord = dataset.create_dataset("t", (num_samples, nt), dtype=float)
        dt = dataset.create_dataset("dt", (num_samples,), dtype=float)
        xcoord = dataset.create_dataset("x", (num_samples, nx), dtype=float)
        dx = dataset.create_dataset("dx", (num_samples,), dtype=float)
        ycoord = dataset.create_dataset("y", (num_samples, ny), dtype=float)
        dy = dataset.create_dataset("dy", (num_samples,), dtype=float)
        buo_y = dataset.create_dataset("buo_y", (num_samples,), dtype=float)
        viscosity = dataset.create_dataset("viscosity", (num_samples,), dtype=float)

        with utils.Timer() as gentime:
            rngs = np.random.randint(np.iinfo(np.int32).max, size=num_samples)
            # One (smoke, velocity) tuple per sample. Keep it as a list rather than
            # zip(*...)-unpacking it, so num_samples == 0 does not raise.
            trajectories = Parallel(n_jobs=n_parallel)(
                delayed(genfunc)(idx, rngs[idx]) for idx in tqdm(range(num_samples))
            )
        logger.info(f"Took {gentime.dt:.3f} seconds")
        print(f"Took {gentime.dt:.3f} seconds")

        # The coordinates are identical for every sample, so build them once.
        x_axis = np.linspace(0, pde.Lx, pde.nx)
        y_axis = np.linspace(0, pde.Ly, pde.ny)
        t_axis = np.linspace(pde.tmin, pde.tmax, pde.trajlen)
        with utils.Timer() as writetime:
            for idx, (smoke_traj, velocity_traj) in enumerate(trajectories):
                h5f_u[idx] = smoke_traj
                h5f_vx[idx] = velocity_traj[..., 0]
                h5f_vy[idx] = velocity_traj[..., 1]
                xcoord[idx] = x_axis
                dx[idx] = pde.dx
                ycoord[idx] = y_axis
                dy[idx] = pde.dy
                tcoord[idx] = t_axis
                # Stored steps are pde.sample_rate solver steps apart.
                dt[idx] = pde.dt * pde.sample_rate
                buo_y[idx] = pde.buoyancy_y
                viscosity[idx] = pde.nu
        logger.info(f"Took {writetime.dt:.3f} seconds writing to disk")

    print("Data saved")
