import sympy as sp
import numpy as np
import h5py

from util import parse_args, set_seed, solve_hs 

# Feature columns, in the order used for the regular sample matrix: the
# coordinates t and x, the field u, and its partial derivatives
# (x-subscripts written before t-subscripts).
variables = [
    't', 'u', 'x',
    'u_t', 'u_tt', 'u_ttt', 'u_tttt',
    'u_x', 'u_xt', 'u_xtt', 'u_xttt',
    'u_xx', 'u_xxt', 'u_xxtt',
    'u_xxx', 'u_xxxt',
    'u_xxxx',
]
# (alpha, beta) = number of x- and t-derivatives in the numerator of each
# invariant feature I_{alpha}x{beta}; this order fixes the invariant column
# order. (1, 0) is absent: u_x is the normaliser, so I_1x0 would be identically 1.
alpha_beta = [
    (0, 1), (0, 2), (0, 3), (0, 4), (0, 0),
    (1, 1), (1, 2), (1, 3),
    (2, 0), (2, 1), (2, 2),
    (3, 0), (3, 1),
    (4, 0),
]

# Weights entering the invariant normalisation exponent
# (B - A*alpha - beta) / (B - A). Presumably the scaling weights of x and u
# relative to t (weight 1) — confirm against the derivation.
A = 0.5
B = -1

# Ground-truth equations (expressions implicitly set equal to zero).
# Regular form: written in the raw derivative columns of `variables`.
# Invariant form: the same PDE divided through by u_x**2, expressed in the
# I_{alpha}x{beta} features built by xyixiy (e.g. I_0x0*I_2x0 = u*u_xx/u_x**2,
# and the constant 1 is u_x*u_x/u_x**2).
gt_regular, gt_invariant = map(sp.sympify, (
    "u_tt + u*u_xx + u_x*u_x + u_xxxx",
    "I_0x0*I_2x0 + I_0x2 + I_4x0 + 1",
))

def xyixiy(dataset):
    """Build regular and invariant feature matrices from a derivative dataset.

    Args:
        dataset: mapping from every name in ``variables`` to a 1-D array of
            samples (presumably all of equal length — not checked here).

    Returns:
        X: ``(n_samples, len(variables))`` array, columns ordered as ``variables``.
        Y: zero target vector with one entry per row of ``X``.
        IX: ``(n_kept, len(alpha_beta))`` array of invariant features, computed
            only on rows where ``|u_x| >= 0.1`` and ``|u_tt| <= 1``.
        IY: zero target vector with one entry per row of ``IX``.
        invar_names: column names of ``IX`` in the form ``"I_{alpha}x{beta}"``.

    Raises:
        ValueError: if a derivative required by ``alpha_beta`` is missing from
            ``variables``, or if ``A``/``B`` yield a non-integer normalisation
            exponent (the cube-root trick below would then be inexact).
    """
    X = np.stack([dataset[key] for key in variables], axis=1)
    Y = np.zeros_like(X[:, 0])

    def get_var(x, t):
        """Return the column of X for the derivative with x x- and t t-subscripts."""
        name = 'u' if x == 0 and t == 0 else 'u_' + 'x' * x + 't' * t
        # list.index raises ValueError for an unknown name (it never returns -1).
        return X[:, variables.index(name)]

    u_x = get_var(1, 0)
    u_tt = get_var(0, 2)
    # Keep rows where dividing by powers of u_x is well conditioned and u_tt is bounded.
    mask = (np.abs(u_x) >= 1e-1) & (np.abs(u_tt) <= 1e0)
    u_x_kept = u_x[mask]  # hoisted: identical for every invariant

    invars = []
    invar_names = []
    for alpha, beta in alpha_beta:
        # Normalising exponent (B - A*alpha - beta)/(B - A), scaled by 3 so it
        # is an integer and the denominator can be taken as a cube root.
        exp_exact = 3 * (B - A * alpha - beta) / (B - A)
        exp = int(round(exp_exact))
        if not np.isclose(exp_exact, exp):
            raise ValueError(
                f"non-integer normalisation exponent {exp_exact} for "
                f"(alpha, beta) = ({alpha}, {beta}); check A and B"
            )
        # cbrt of an integer power keeps the sign of u_x for odd powers,
        # whereas a fractional np.power would give NaN for negative bases.
        den = np.cbrt(u_x_kept ** exp)
        invars.append(get_var(alpha, beta)[mask] / den)
        invar_names.append(f"I_{alpha}x{beta}")
    IX = np.stack(invars, axis=1)
    IY = Y[mask]

    return X, Y, IX, IY, invar_names


def solve(dataset, args):
    """Derive regular and invariant features from *dataset* and run ``solve_hs``.

    Args:
        dataset: mapping from each name in ``variables`` to a 1-D sample array
            (forwarded to ``xyixiy``).
        args: parsed CLI namespace, forwarded unchanged to ``solve_hs``.
    """
    X, Y, IX, IY, invar_names = xyixiy(dataset)

    # Fixed search settings for this equation; their exact semantics are
    # defined by solve_hs (not visible here).
    search_settings = dict(
        operations=["+", "*"],
        unary_operators=[],
        nested_constraints={},
        penalty="abs(x - y)",
        lhs_mask=[],
    )
    solve_hs(
        args=args,
        X=X,
        Y=Y,
        IX=IX,
        IY=IY,
        regular_names=variables,
        invariant_names=invar_names,
        gts_regular_eq=[gt_regular],
        gts_invariant_eq=[gt_invariant],
        **search_settings,
    )


def solve_boussinesq(arguments):
    """Load the HDF5 dataset at ``arguments.data``, optionally subsample it, and solve.

    Args:
        arguments: parsed CLI namespace. Reads ``data`` (path to an HDF5 file
            with one dataset per entry of ``variables``) and ``samples``
            (maximum number of grid points to keep); it is also forwarded to
            ``solve``.
    """
    # Copy every field into memory, then close the file so the handle is not
    # held open while solve() runs.
    with h5py.File(arguments.data, 'r') as f:
        d = {key: np.reshape(np.array(f[key][:]), (f[key].shape[0], -1)) for key in variables}

    # Expand t across the x grid. This assumes t has one value per row of x
    # (the column case); broadcast_to also accepts an already-gridded t.
    d['t'] = np.broadcast_to(d['t'], d['x'].shape)
    d = {key: value.flatten() for key, value in d.items()}

    n_points = d['x'].size
    if arguments.samples < n_points:
        # Draw distinct points: sampling with replacement would duplicate rows.
        selected = np.random.choice(n_points, arguments.samples, replace=False)
        d = {key: value[selected] for key, value in d.items()}

    solve(d, arguments)


if __name__ == '__main__':
    # Script entry point: parse CLI options, seed via set_seed (presumably the
    # global numpy RNG used for subsampling — confirm), then load and solve.
    # NOTE(review): `args` is bound at module scope here; check for other
    # readers of this global before renaming it.
    args = parse_args()
    set_seed(args.seed)

    solve_boussinesq(args)
