import numpy as np
import h5py

def load_boussinesq_data(data_path, subsample=1.0, align_utt=False):
    """Load Boussinesq solution data from HDF5 and build scaled derivative invariants.

    Each dataset named in ``variables`` is read from ``data_path``, reshaped to
    ``(shape[0], -1)`` and flattened, giving one column per variable. ``t`` is
    tiled along the second axis so that it lines up point-for-point with ``x``
    (this presumes all datasets share the same leading dimension -- TODO confirm
    against the data generator).

    Args:
        data_path: Path to the HDF5 file.
        subsample: Fraction of points to keep. If it yields at least as many
            points as are available, every point is kept in its original order;
            otherwise a random subset is drawn *without* replacement.
        align_utt: If True, each invariant is additionally multiplied by
            ``u_x**2`` (the integer exponent is lowered by 6 before the cube root).

    Returns:
        ``(X, I, x_names, invar_names)`` where ``X`` holds the raw variables
        (columns named by ``x_names``) and ``I`` the invariants ``I_{a}x{b}``
        (columns named by ``invar_names``). Both are restricted to points with
        ``|u_x| >= 0.1`` and ``|u_tt| <= 1``.
    """
    A, B = 0.5, -1.0
    variables = ['t', 'u', 'x', 'u_t', 'u_tt', 'u_ttt', 'u_tttt', 'u_x', 'u_xt', 'u_xtt', 'u_xttt', 'u_xx', 'u_xxt',
                 'u_xxtt', 'u_xxx', 'u_xxxt', 'u_xxxx']
    # (order in x, order in t) of each derivative turned into an invariant
    alpha_beta = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (3, 0), (3, 1),
                  (4, 0)]
    with h5py.File(data_path, 'r') as f:
        d = {key: np.reshape(np.array(f[key][:]), (f[key].shape[0], -1)) for key in variables}
    d['t'] = np.tile(d['t'], (1, d['x'].shape[1]))
    d = {key: value.flatten() for key, value in d.items()}

    n_points = d['x'].size
    n_samples = int(subsample * n_points)
    if n_samples < n_points:
        # Draw without replacement so no point is duplicated (as in load_rd_data).
        selected = np.random.choice(n_points, n_samples, replace=False)
        d = {key: value[selected] for key, value in d.items()}

    x_names = list(d.keys())
    X = np.stack([d[key] for key in x_names], axis=1)
    column = {name: i for i, name in enumerate(x_names)}

    def get_var(x, t):
        """Return the column of X holding the derivative of u of order x in x and t in t."""
        name = 'u' if x == 0 and t == 0 else 'u_' + 'x' * x + 't' * t
        return X[:, column[name]]

    u_x = get_var(1, 0)
    u_tt = get_var(0, 2)
    # Keep points where |u_x| is bounded away from zero (u_x appears in every
    # denominator below) and |u_tt| is bounded; u_x of either sign is kept.
    mask = np.logical_and(np.abs(u_x) >= 1e-1, np.abs(u_tt) <= 1e0)
    u_x_masked = u_x[mask]

    invars = []
    invar_names = []
    for alpha, beta in alpha_beta:
        # Presumably to avoid numerical issues: an integer power followed by
        # np.cbrt stays real for negative u_x, unlike a fractional power.
        exp = int(round(3 * (B - A * alpha - beta) / (B - A)))
        if align_utt:
            exp -= 6  # cbrt(u_x**-6) == u_x**-2, i.e. multiply result by u_x**2
        den = np.cbrt(u_x_masked ** exp)
        invars.append(get_var(alpha, beta)[mask] / den)
        invar_names.append(f"I_{alpha}x{beta}")
    I = np.stack(invars, axis=1)

    return X[mask], I, x_names, invar_names


def load_rd_data(data_path, subsample=1.0, t_range=(0, 200), x_range=(5, -5), y_range=(5, -5)):
    """Load reaction-diffusion features from HDF5 as a flat (points, features) array.

    Args:
        data_path: Path to the HDF5 file. It must contain ``spatial_grid``
            (x read from ``[:, :, 0]``, y from ``[:, :, 1]``), ``t`` (presumably
            1-D), and one dataset per name in ``invar_names`` other than
            x/y/t, presumably shaped ``(n_x, n_y, n_t)`` -- TODO confirm.
        subsample: Fraction of points to keep, drawn at random without
            replacement. For values >= 1.0 every point is kept, but rows are
            still returned in random order.
        t_range: ``(start, stop)`` slice bounds on the time axis.
        x_range: ``(start, stop)`` slice bounds on the first grid axis.
        y_range: ``(start, stop)`` slice bounds on the second grid axis.
            The defaults drop 5 grid points at each spatial edge and keep time
            indices 0..199.

    Returns:
        ``(I, invar_names)``: ``I`` has one row per selected point and one
        column per entry of ``invar_names``.
    """
    invar_names = ['I_t', 'I_x', 'I_y', 'I_xx', 'I_yy', 'I_xy', 'E_t', 'E_x', 'E_y', 'E_xx', 'E_yy', 'E_xy', 'R', 'x', 'y', 't']
    coord_names = ('x', 'y', 't')
    with h5py.File(data_path, 'r') as f:
        # Read only the datasets actually used instead of every entry in the file.
        dat = {key: np.array(f[key][:]) for key in invar_names if key not in coord_names}
        spatial_grid = np.array(f['spatial_grid'][:])
        t = np.array(f['t'][:])

    n_t = t.shape[0]
    # Expand the coordinates onto the same (grid, grid, time) layout as the fields.
    dat['x'] = np.tile(spatial_grid[:, :, None, 0], (1, 1, n_t))
    dat['y'] = np.tile(spatial_grid[:, :, None, 1], (1, 1, n_t))
    dat['t'] = np.tile(t, (spatial_grid.shape[0], spatial_grid.shape[1], 1))

    window = (slice(x_range[0], x_range[1]), slice(y_range[0], y_range[1]), slice(t_range[0], t_range[1]))
    I = np.stack([dat[key][window].flatten() for key in invar_names], axis=1)

    n_samples = int(subsample * I.shape[0]) if subsample < 1.0 else I.shape[0]
    # Without replacement: each point appears at most once.
    indices = np.random.choice(I.shape[0], n_samples, replace=False)
    return I[indices], invar_names
