import sympy as sp
import numpy as np
import h5py

from util import parse_args, set_seed, solve_hs

# Names of the regular (non-invariant) quantities. They serve both as the
# dataset keys and as the column order of the regular feature matrix.
variables = ["u", "u_x", "u_xx", "u_xy", "u_y", "u_yy", "x", "y"]

# Ground-truth expression written in the regular variables:
#   8 * (x*u_x + y*u_y) - (u_xx + u_yy) - exp(4 * (x*x + y*y))
# Presumably interpreted as "expression = 0" (the regression target built in
# xyixiy is all zeros) -- confirm against solve_hs.
gt_regular = (
    8 * sp.sympify("x * u_x + y * u_y")
    - sp.sympify("u_xx + u_yy")
    - sp.exp(sp.sympify("4 * (x * x + y * y)"))
)

# The same expression written in the invariant symbols zeta_2, L and R.
gt_invariant = (
    8 * sp.sympify("zeta_2")
    - sp.sympify("L")
    - sp.exp(sp.sympify("4 * R"))
)

def xyixiy(dataset):
    """Build the regular and invariant feature matrices for a dataset.

    Args:
        dataset: Mapping from every name in the module-level ``variables``
            list to a NumPy array. All arrays must share one shape because
            they are stacked column-wise.

    Returns:
        A tuple ``(X, Y, IX, IY, invar_names)`` where

        * ``X`` stacks the regular variables along axis 1, in the order
          given by ``variables``;
        * ``Y`` is an all-zero target shaped like one column of ``X``;
        * ``IX`` stacks the invariant features along axis 1, in the order
          given by ``invar_names``;
        * ``IY`` is the very same array object as ``Y``;
        * ``invar_names`` lists the invariant names, one per ``IX`` column.
    """
    X = np.stack([dataset[key] for key in variables], axis=1)
    Y = np.zeros_like(X[:, 0])

    x, y, u = dataset['x'], dataset['y'], dataset['u']
    u_x, u_y = dataset['u_x'], dataset['u_y']
    u_xx, u_xy, u_yy = dataset['u_xx'], dataset['u_xy'], dataset['u_yy']

    # Names and columns are declared side by side so they cannot drift apart;
    # dict insertion order fixes the column order of IX.
    invariants = {
        # zeroth order
        "R": x * x + y * y,
        "u": u,
        # first order
        "zeta_1": x * u_y - y * u_x,
        "zeta_2": x * u_x + y * u_y,
        # second order
        "L": u_xx + u_yy,
        "T": u_xx * u_xx + 2 * u_xy * u_xy + u_yy * u_yy,
        # The cross term is 2*x*y*u_xy (a product), matching
        # invar_to_regular["theta_3"]. It used to be written as
        # "2*x*y + u_xy" with u_xy leaking out as an extra, unnamed column.
        # NOTE(review): the u_xx*u_xx / u_yy*u_yy squares follow
        # invar_to_regular; confirm they are intended (x^T H x would use the
        # unsquared second derivatives).
        "theta_3": x * x * u_xx * u_xx + y * y * u_yy * u_yy + 2 * x * y * u_xy,
    }
    invar_names = list(invariants)

    IX = np.stack(list(invariants.values()), axis=1)
    IY = Y

    return X, Y, IX, IY, invar_names

# Textual definition of every invariant in terms of the regular variables.
# The keys are the invariant names returned by xyixiy.
# Not referenced elsewhere in this file -- presumably consumed by importers
# to translate invariant expressions back to regular ones; verify.
invar_to_regular = {
    # zeroth order: position and field value
    "R": "x * x + y * y",
    "u": "u",
    # first order: combinations of the gradient with the position
    "zeta_1": "x * u_y - y * u_x",
    "zeta_2": "x * u_x + y * u_y",
    # second order: combinations of the second derivatives
    "L": "u_xx + u_yy",
    "T": "u_xx * u_xx + 2 * u_xy * u_xy + u_yy * u_yy",
    "theta_3": "x * x * u_xx * u_xx + y * y * u_yy * u_yy + 2 * x * y * u_xy",
}

def solve(dataset, args):
    """Run ``solve_hs`` on the regular and invariant features of ``dataset``.

    Args:
        dataset: Mapping accepted by ``xyixiy`` (one array per name in
            ``variables``).
        args: Run options, forwarded unchanged to ``solve_hs``.

    Returns:
        None. The result of ``solve_hs`` is discarded.
    """
    features, target, inv_features, inv_target, inv_names = xyixiy(dataset)

    settings = {
        "args": args,
        # Search space: sums, products and a single (non-nested) exp.
        "operations": ["+", "*"],
        "unary_operators": ["exp"],
        "nested_constraints": {"exp": {"exp": 0}},
        "penalty": "abs(x - y)",
        "lhs_mask": [],
        # Data in regular variables and in invariant variables.
        "X": features,
        "Y": target,
        "IX": inv_features,
        "IY": inv_target,
        "regular_names": variables,
        "invariant_names": inv_names,
        # Reference equations passed alongside the data.
        "gts_regular_eq": [gt_regular],
        "gts_invariant_eq": [gt_invariant],
    }
    solve_hs(**settings)

def solve_darcy(arguments):
    """Load the Darcy dataset from HDF5, subsample it and run ``solve``.

    Args:
        arguments: Options object. This function reads ``arguments.data``
            (path of the HDF5 file) and ``arguments.samples`` (maximum number
            of points to keep); the whole object is forwarded to ``solve``.

    Returns:
        None.
    """
    # Copy every dataset into memory as a 2-D array (first axis kept, the
    # rest flattened) so the file can be closed before any processing.
    with h5py.File(arguments.data, 'r') as f:
        d = {key: np.reshape(np.array(f[key][:]), (f[key].shape[0], -1)) for key in variables}

    # Expand the coordinate axes to a full grid and repeat it 5 times so the
    # coordinates line up with the 5 x 128 x 128 layout used below.
    x, y = np.meshgrid(d['x'], d['y'], indexing='ij')
    d['x'] = np.tile(x, (5, 1))
    d['y'] = np.tile(y, (5, 1))

    # clip out border pixels (3 on every side of each 128 x 128 field)
    d = {key: value.reshape(5, 128, 128)[:, 3:-3, 3:-3].flatten() for key, value in d.items()}

    # Use the function's own parameter here rather than a module-level
    # ``args``, which only exists when the file is run as a script.
    # np.random.choice draws with replacement by default, so the subsample
    # may contain repeated points.
    indices = np.arange(d['u'].size)
    if arguments.samples < indices.size:
        selected_indices = np.random.choice(indices, arguments.samples)
    else:
        selected_indices = indices
    d = {key: value[selected_indices] for key, value in d.items()}

    solve(d, arguments)


if __name__ == '__main__':
    # Script entry point. `args` holds the parsed options; this file reads
    # its `seed`, `data` and `samples` attributes.
    args = parse_args()
    # Presumably seeds the random number generators so the subsampling in
    # solve_darcy is reproducible -- confirm in util.set_seed. It must run
    # before solve_darcy for that to hold.
    set_seed(args.seed)

    solve_darcy(args)
