import sympy as sp
import numpy as np
import h5py

from util import parse_args, set_seed, solve_hs

# Names of the regular (non-invariant) features. The order of this list is
# the column order of the X matrix assembled in `xyixiy`.
variables = [
    # u and its derivatives
    'u', 'u_t', 'u_x', 'u_y', 'u_xx', 'u_yy', 'u_xy',
    # independent coordinates
    'x', 'y', 't',
    # v and its derivatives
    'v', 'v_t', 'v_x', 'v_y', 'v_xx', 'v_yy', 'v_xy',
]


# Reference expressions in the regular variables (u, v and their derivatives).
# They are handed to `solve_hs` as `gts_regular_eq`; presumably "gt" stands for
# ground truth and each expression is meant to equal zero -- confirm in util.
gt_regular_0 = sp.sympify("u_t - 0.1 * (u_xx + u_yy) - (1 - u * u - v * v) * u - (u * u + v * v) * v")
gt_regular_1 = sp.sympify("v_t - 0.1 * (v_xx + v_yy) - (1 - u * u - v * v) * v + (u * u + v * v) * u")

# Reference expressions in the invariant variables (I, E, R and derivatives),
# handed to `solve_hs` as `gts_invariant_eq`.
gt_invariant_0 = sp.sympify("I_t - 0.1 * (I_xx + I_yy) - (1 - R) * R")
gt_invariant_1 = sp.sympify("E_t - 0.1 * (E_xx + E_yy) + R ** 2")


def xyixiy(dataset):
    """Build the regular and invariant feature matrices from ``dataset``.

    Args:
        dataset: Mapping from feature name to an array. It must contain every
            name in the module-level ``variables`` list and every invariant
            name listed below.

    Returns:
        A tuple ``(X, Y, IX, IY, invar_names)`` where ``X`` stacks the regular
        features along axis 1, ``IX`` stacks the invariant features along
        axis 1, ``Y`` is an all-zero array shaped like one column of ``X``,
        ``IY`` is the same object as ``Y``, and ``invar_names`` lists the
        column names of ``IX`` in order.
    """
    invar_names = [
        'I_t', 'I_x', 'I_y', 'I_xx', 'I_xy', 'I_yy', 'R', 'x', 'y', 't',
        'E_t', 'E_x', 'E_y', 'E_xx', 'E_xy', 'E_yy',
    ]

    # One column per feature, in the order given by the name lists.
    X = np.stack([dataset[name] for name in variables], axis=1)

    # The target is identically zero and shared by both problems.
    Y = np.zeros_like(X[:, 0])

    IX = np.stack([dataset[name] for name in invar_names], axis=1)
    IY = Y

    return X, Y, IX, IY, invar_names

def solve(dataset, args):
    """Assemble the feature matrices and hand them to ``solve_hs``.

    Args:
        dataset: Mapping from feature name to array, as accepted by
            ``xyixiy``.
        args: Parsed options. ``args.solve_invariants`` selects which
            time-derivative names are passed as ``lhs_mask``; the whole
            object is forwarded to ``solve_hs``.
    """
    X, Y, IX, IY, invar_names = xyixiy(dataset)

    # Pick the time-derivative names matching the chosen variable set.
    if args.solve_invariants:
        lhs_mask = ["I_t", "E_t"]
    else:
        lhs_mask = ["u_t", "v_t"]

    solver_options = {
        "args": args,
        "operations": ["+", "*"],
        "unary_operators": [],
        "nested_constraints": {},
        "penalty": "abs(x - y)",
        "lhs_mask": lhs_mask,
        "X": X,
        "Y": Y,
        "IX": IX,
        "IY": IY,
        "regular_names": variables,
        "invariant_names": invar_names,
        "gts_regular_eq": [gt_regular_0, gt_regular_1],
        "gts_invariant_eq": [gt_invariant_0, gt_invariant_1],
    }
    solve_hs(**solver_options)

def solve_reac_diff(arguments):
    """Load the reaction-diffusion dataset, subsample it and run ``solve``.

    Args:
        arguments: Parsed options. This function reads ``arguments.data``
            (path of the HDF5 file) and ``arguments.samples`` (maximum number
            of points to keep); the whole object is forwarded to ``solve``.

    The HDF5 file is expected to provide the datasets read below: the
    invariant fields, ``u`` and its derivatives (last axis indexing the two
    components), ``R``, ``spatial_grid`` and ``t``.
    """
    def reshape(x):
        # Move the third axis to the front and flatten the remaining two, so
        # each row holds one slice along that axis (assumed to be time --
        # TODO confirm against the data generator).
        x = np.array(x)
        x = np.transpose(x, [2, 0, 1])
        return np.reshape(x, [x.shape[0], -1])

    d = {}
    with h5py.File(arguments.data, 'r') as f:
        invariants = ['E_t', 'E_x', 'E_y', 'E_xx', 'E_xy', 'E_yy']
        invariants += ['I_t', 'I_x', 'I_y', 'I_xx', 'I_xy', 'I_yy']
        for name in invariants:
            d[name] = reshape(f[name])

        # regular variables (decomposed into u and v)
        for ext in ["", "_t", "_x", "_y", "_xx", "_xy", "_yy"]:
            d["u" + ext] = reshape(f['u' + ext][..., 0])
            d["v" + ext] = reshape(f['u' + ext][..., 1])

        # Shape of the two leading (spatial) axes, taken from the data instead
        # of being hard-coded.
        spatial_shape = tuple(f['R'].shape[:2])
        d['R'] = reshape(f['R'])

        num_time = d['R'].shape[0]
        num_spatial = d['R'].shape[1]

        spatial_grid = f['spatial_grid']
        x, y = spatial_grid[..., 0], spatial_grid[..., 1]
        d['x'] = np.tile(x, (num_time, 1))
        d['y'] = np.tile(y, (num_time, 1))

        # Force a column vector before tiling: a 1-D `t` would otherwise be
        # promoted to shape (1, T) by np.tile and end up misaligned with the
        # (time, space) layout of every other field.
        t_column = np.asarray(f['t']).reshape(-1, 1)
        d['t'] = np.tile(t_column, (1, num_spatial))

    # Everything in `d` is an in-memory array from here on, so the file does
    # not need to stay open while solving.

    # Keep the first 40 slices along the leading axis and clip a 3-cell border
    # on each spatial side, then flatten every field to 1-D.
    d = {
        key: value.reshape(-1, *spatial_shape)[:40, 3:-3, 3:-3].flatten()
        for key, value in d.items()
    }

    indices = np.arange(d['x'].size)
    if arguments.samples < indices.size:
        # np.random.choice draws with replacement by default.
        selected_indices = np.random.choice(indices, arguments.samples)
    else:
        selected_indices = indices
    d = {key: value[selected_indices] for key, value in d.items()}

    solve(d, arguments)


# Script entry point: parse the options, seed, then run the experiment.
if __name__ == '__main__':
    # NOTE(review): `args` is bound at module scope here; keep the name
    # stable, since functions in this module may look it up as a global.
    args = parse_args()
    # Presumably seeds the random generators used for subsampling -- confirm
    # in util.set_seed.
    set_seed(args.seed)

    solve_reac_diff(args)
