import numpy as np
from numpy.fft import fftn

###

# This file is probably obsolete. Everything has moved to evaluation_functions.py
# TODO: check if file is truly obsolete, then delete

###

def calc_divergence_residual(x, y, u, v):
    """
    Calculates the divergence residual du/dx + dv/dy of a 2D vector field.

    The vector field must be provided on a structured grid whose x-coordinate
    varies along axis 1 and whose y-coordinate varies along axis 0, because the
    derivatives are computed with finite differences along those axes:
    central differences in the interior, first-order one-sided differences on
    the boundary rows/columns.
    NaN's that might appear in the vector field (e.g. due to interpolation)
    propagate into the residual field, but are ignored when computing the RMS.

    Parameters:
    x: 2D x-coordinates of the grid points (differenced along axis 1).
    y: 2D y-coordinates of the grid points (differenced along axis 0).
    u: x-component of the vector field, same shape as x.
    v: y-component of the vector field, same shape as y.

    Returns:
    divergence_residual: The divergence residual as a floating-point field
        with the same shape as u.
    rms_residual_divergence: The root mean square of the divergence residual,
        ignoring NaN entries.

    Raises:
    ValueError: if u has fewer than 2 columns or v has fewer than 2 rows, so
        no finite difference can be formed.
    """
    x, y, u, v = (np.asarray(a) for a in (x, y, u, v))
    if u.shape[1] < 2 or v.shape[0] < 2:
        raise ValueError(
            "calc_divergence_residual needs at least 2 grid points along each axis"
        )

    # Each derivative is assembled from its boundary and interior pieces, so its
    # dtype comes from the division (always floating point). Preallocating with
    # zeros_like(u) would truncate the quotients when u/v are integer arrays.
    du_dx = np.concatenate(
        [
            # forward difference at the first column
            (u[:, 1:2] - u[:, :1]) / (x[:, 1:2] - x[:, :1]),
            # central differences in the interior
            (u[:, 2:] - u[:, :-2]) / (x[:, 2:] - x[:, :-2]),
            # backward difference at the last column
            (u[:, -1:] - u[:, -2:-1]) / (x[:, -1:] - x[:, -2:-1]),
        ],
        axis=1,
    )
    dv_dy = np.concatenate(
        [
            # forward difference at the first row
            (v[1:2, :] - v[:1, :]) / (y[1:2, :] - y[:1, :]),
            # central differences in the interior
            (v[2:, :] - v[:-2, :]) / (y[2:, :] - y[:-2, :]),
            # backward difference at the last row
            (v[-1:, :] - v[-2:-1, :]) / (y[-1:, :] - y[-2:-1, :]),
        ],
        axis=0,
    )

    divergence_residual = du_dx + dv_dy
    rms_residual_divergence = np.sqrt(np.nanmean(divergence_residual**2))

    return divergence_residual, rms_residual_divergence


def calc_ke_spectra_over_wavenumber(rollout, lx=0.23, ly=0.23):
    '''
    Compute the (turbulent) kinetic energy spectra over wavenumbers
    from a rollout of 2D velocity fields.

    Parameters
    ----------
    rollout : np.ndarray
        Array of velocity fields with shape (T, Nx, Ny, 2), where
        - T is the number of time steps,
        - Nx, Ny are spatial grid sizes,
        - the last dimension (2) contains the velocity components (u, v).
    lx, ly : float, optional
        Physical domain lengths along the Nx and Ny axes. Both default to
        0.23, the value that used to be hard-coded here.

    Returns
    -------
    tke_spectrum_all : np.ndarray
        Float array of shape (T, N_k) containing the energy spectrum for
        each time step, with N_k = round(sqrt((Nx/2)**2 + (Ny/2)**2)) + 1.
        For T == 0 an empty (0, N_k) array is returned.
    wave_numbers : np.ndarray
        1D array of physical wave numbers corresponding to each entry
        of the spectrum (length N_k).

    Notes
    -----
    - Each velocity component is transformed with `fftn` and normalized by
      the number of grid points (Nx * Ny). The energy density of a mode is
      0.5 * (|û(k)|² + |v̂(k)|²), so a spectrum sums to 0.5 * mean(u² + v²).
    - The spatial mean is NOT subtracted, so bin 0 holds the energy of the
      mean flow rather than a fluctuation.
    - A mode with integer indices (kx, ky) is assigned to the shell
      round(sqrt(kx² + ky²)). Indices above N/2 are wrapped to negative
      values first.
    - Shells are converted to physical wave numbers with the mean of
      2π/lx and 2π/ly. The binning is therefore only isotropic in physical
      units when lx == ly.
    '''
    rollout = np.asarray(rollout)
    n_steps, nx, ny = rollout.shape[:3]
    nt = nx * ny

    # Signed integer wave-vector components in FFT ordering (indices above
    # N/2 wrap to negative). Only their squares are used below.
    kx_idx = np.arange(nx)
    ky_idx = np.arange(ny)
    rkx = np.where(kx_idx > nx / 2, kx_idx - nx, kx_idx)
    rky = np.where(ky_idx > ny / 2, ky_idx - ny, ky_idx)

    # Shell index of every (kx, ky) mode, flattened in the same C order as
    # the spectral energy field below. It is time-invariant, so compute it once.
    shell_idx = (
        np.rint(np.sqrt(rkx[:, None] ** 2 + rky[None, :] ** 2))
        .astype(np.intp)
        .ravel()
    )

    # Largest shell index; no mode rounds to a larger one.
    nmax = int(np.round(np.sqrt((nx / 2) ** 2 + (ny / 2) ** 2)))
    n_bins = nmax + 1

    k0x = 2.0 * np.pi / lx
    k0y = 2.0 * np.pi / ly
    knorm = (k0x + k0y) / 2.0
    wave_numbers = knorm * np.arange(n_bins)

    tke_spectrum_all = np.empty((n_steps, n_bins))
    for t in range(n_steps):
        uf = fftn(rollout[t, :, :, 0]) / nt
        vf = fftn(rollout[t, :, :, 1]) / nt
        tkeh = 0.5 * (uf * np.conj(uf) + vf * np.conj(vf)).real
        # Sum mode energies into their shells (replaces a Python double loop).
        tke_spectrum_all[t] = np.bincount(
            shell_idx, weights=tkeh.ravel(), minlength=n_bins
        )

    return tke_spectrum_all, wave_numbers

def calc_cell_area(x_grid, y_grid, u, v):
    """
    Compute cell areas and cell-centred velocities on a structured grid.

    A cell is the quadrilateral spanned by four neighbouring grid nodes, so
    every output is one element shorter than the node arrays along both axes.

    Parameters:
    x_grid: x-coordinates of the grid nodes (differenced along axis 1).
    y_grid: y-coordinates of the grid nodes (differenced along axis 0).
    u: x-velocity at the grid nodes.
    v: y-velocity at the grid nodes.

    Returns:
    cell_area: width * height of each cell. The width is taken from the
        cell's first node row and the height from its first node column, so
        the value is exact only for rectilinear (axis-aligned) grids.
    u_center: average of u over the four corners of each cell.
    v_center: average of v over the four corners of each cell.
    """
    def corner_average(field):
        # Mean of the four nodes surrounding each cell.
        return 0.25 * (field[:-1, :-1] + field[1:, :-1] + field[:-1, 1:] + field[1:, 1:])

    widths = np.diff(x_grid, axis=1)[:-1, :]
    heights = np.diff(y_grid, axis=0)[:, :-1]

    return widths * heights, corner_average(u), corner_average(v)

def _integrated_flow_quantities(u, v, x_grid, y_grid):
    """Return area-integrated (Px, Py, Ke) of one velocity snapshot, using cell-centred values."""
    cell_area, u_center, v_center = calc_cell_area(x_grid, y_grid, u, v)

    px = np.sum(u_center * cell_area)
    py = np.sum(v_center * cell_area)
    ke = 0.5 * np.sum((u_center**2 + v_center**2) * cell_area)

    return px, py, ke


def calc_conservation_variables(roll_list, true_list, x_grid, y_grid):
    """
    Track area-integrated momentum and kinetic energy over time for a rollout
    and its ground truth.

    For each time index, the snapshot is reduced to cell-centred velocities
    with `calc_cell_area`. The code then integrates u, v and
    0.5 * (u**2 + v**2) over the cell areas.

    Parameters:
    roll_list: predicted fields indexed as [time, :, :, component], where
        component 0 is u and component 1 is v.
    true_list: ground-truth fields with the same indexing. Only the first
        roll_list.shape[0] time steps are used.
    x_grid, y_grid: node coordinates passed through to `calc_cell_area`.

    Returns:
    time: np.arange over the rollout time indices.
    Px_all_true, Py_all_true, Ke_all_true: lists of ground-truth x-momentum,
        y-momentum and kinetic energy per time step.
    Px_all_roll, Py_all_roll, Ke_all_roll: the same quantities for the rollout.
    """
    time = np.arange(0, roll_list.shape[0])

    Ke_all_roll, Px_all_roll, Py_all_roll = [], [], []
    Ke_all_true, Px_all_true, Py_all_true = [], [], []

    for i in time:
        px, py, ke = _integrated_flow_quantities(
            roll_list[i, :, :, 0], roll_list[i, :, :, 1], x_grid, y_grid
        )
        Px_all_roll.append(px)
        Py_all_roll.append(py)
        Ke_all_roll.append(ke)

        px, py, ke = _integrated_flow_quantities(
            true_list[i, :, :, 0], true_list[i, :, :, 1], x_grid, y_grid
        )
        Px_all_true.append(px)
        Py_all_true.append(py)
        Ke_all_true.append(ke)

    return time, Px_all_true, Py_all_true, Ke_all_true, Px_all_roll, Py_all_roll, Ke_all_roll


def calc_mean_variance_field(roll_list, true_list):
    """
    Compare the temporal mean and variance fields of a rollout and its ground truth.

    Statistics are taken along axis 0 (the snapshot axis) of each input.

    Returns:
    (mean_roll, var_roll, mean_true, var_true, diff_mean, diff_var), where
    diff_mean = mean_roll - mean_true and diff_var = var_true - var_roll.
    """
    # NOTE(review): diff_mean is rollout - truth but diff_var is truth - rollout.
    # The opposite sign conventions look unintentional; confirm with callers
    # (e.g. plotting code) before unifying them.
    def moments(snapshots):
        return np.mean(snapshots, axis=0), np.var(snapshots, axis=0)

    mean_roll, var_roll = moments(roll_list)
    mean_true, var_true = moments(true_list)

    return (
        mean_roll,
        var_roll,
        mean_true,
        var_true,
        mean_roll - mean_true,
        var_true - var_roll,
    )

def find_closest_point_idx(mesh_centers, coords):
    """
    Return the row index of the mesh point nearest (Euclidean) to a 2D location.

    Parameters:
    mesh_centers: array of point coordinates, one (x, y) point per row.
    coords: indexable target location; only coords[0] and coords[1] are used.

    Returns:
    Index (from np.argmin) of the closest row; ties resolve to the first one.
    """
    offsets = mesh_centers - np.array([coords[0], coords[1]])
    return np.argmin(np.linalg.norm(offsets, axis=1))

def POD(u):
    """
    Snapshot proper orthogonal decomposition of a time series of fields.

    Parameters:
    u: array of shape (timesteps, nrofpoints, components).

    Returns:
    mean_field: temporal mean, shape (nrofpoints, components).
    spatial_modes: shape (n_modes, nrofpoints, components) with
        n_modes == timesteps, ordered by decreasing eigenvalue. Each resolved
        mode has squared norm `timesteps`. Modes whose eigenvalue is zero to
        round-off are returned as all zeros.
    time_coeffs: shape (timesteps, n_modes); column k is the time series of
        mode k, so u[t] == mean_field + sum_k time_coeffs[t, k] * spatial_modes[k].
    eigvals: eigenvalues of the temporal correlation matrix in decreasing
        order, with round-off negatives clipped to zero.
    """
    timesteps, nrofpoints, components = u.shape

    # One column per snapshot: shape (nrofpoints * components, timesteps).
    data_matrix = u.reshape(timesteps, -1).T

    mean_field_flat = np.mean(data_matrix, axis=1, keepdims=True)
    data_fluct = data_matrix - mean_field_flat

    # Temporal correlation matrix, shape (timesteps, timesteps).
    C = np.dot(data_fluct.T, data_fluct) / timesteps

    eigvals, eigvecs = np.linalg.eigh(C)

    # eigh returns eigenvectors as COLUMNS; reorder columns (not rows) so each
    # eigenvector stays paired with its eigenvalue.
    idx = np.argsort(eigvals)[::-1]
    eigvals = eigvals[idx]
    eigvecs = eigvecs[:, idx]

    # C is positive semi-definite; negative eigenvalues are round-off.
    eigvals = np.clip(eigvals, 0.0, None)
    sqrt_eigvals = np.sqrt(eigvals)

    # Mean subtraction makes C singular (a constant time vector is in its null
    # space). Normalizing by sqrt of such ~zero eigenvalues would divide by zero
    # or amplify round-off, so those modes are left as zeros.
    tol = eigvals[0] * max(data_fluct.shape) * np.finfo(eigvals.dtype).eps
    resolved = eigvals > tol

    spatial_modes_flat = np.zeros(
        (data_fluct.shape[0], eigvals.size),
        dtype=np.result_type(data_fluct, eigvecs),
    )
    spatial_modes_flat[:, resolved] = (
        np.dot(data_fluct, eigvecs[:, resolved]) / sqrt_eigvals[resolved]
    )
    time_coeffs = eigvecs * sqrt_eigvals

    n_modes = spatial_modes_flat.shape[1]
    spatial_modes = spatial_modes_flat.T.reshape(n_modes, nrofpoints, components)

    mean_field = mean_field_flat.reshape(nrofpoints, components)

    return mean_field, spatial_modes, time_coeffs, eigvals

def extract_probe_velocity(mesh_centers, probe_point, rollout, groundtruth):
    """
    Extract magnitude time series at the mesh point nearest a probe location.

    The probe location and the chosen mesh point are printed for inspection.

    Parameters:
    mesh_centers: point coordinates, one (x, y) point per row.
    probe_point: target (x, y) location.
    rollout, groundtruth: arrays indexed as [time, point, component]
        (presumably velocity components; confirm with callers).

    Returns:
    timespan: np.arange over the rollout's time axis.
    rollout_probe_norm: Euclidean norm over components at the probe point,
        one value per time step.
    groundtruth_probe_norm: the same for the ground truth.
    """
    probe_idx = find_closest_point_idx(mesh_centers, probe_point)
    print("Coordinate of probe point:", probe_point)
    print("Coordinate of closest mesh point: ", mesh_centers[probe_idx])

    def magnitude_at_probe(series):
        # Norm over the last (component) axis at the selected point.
        return np.linalg.norm(series[:, probe_idx, :], axis=-1)

    rollout_probe_norm = magnitude_at_probe(rollout)
    groundtruth_probe_norm = magnitude_at_probe(groundtruth)

    return np.arange(np.shape(rollout)[0]), rollout_probe_norm, groundtruth_probe_norm