import h5py
import numpy as np
import sympy as sp
import os, sys
import torch
import symbolicregression
import requests
import argparse
from joblib import Parallel, delayed
# Hide every GPU from CUDA so this script runs on CPU only
# (an empty CUDA_VISIBLE_DEVICES exposes no devices).
# NOTE(review): this is assigned after `import torch`; presumably it still takes
# effect because CUDA is initialised lazily - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# HDF5 file that the __main__ section passes to load_boussinesq_data
# (relative path, resolved against the current working directory).
DATA_PATH = "../data/boussinesq/boussinesq_1_dt_1e-3.h5"

def load_boussinesq_data(data_path, subsample=1.0, align_utt=False):
    """Load Boussinesq samples from an HDF5 file and build u_x-normalised features.

    Every dataset listed in ``variables`` is read, reshaped to 2-D, flattened to
    one value per grid point and optionally subsampled.  Each derivative listed
    in ``alpha_beta`` is then divided by a power of ``u_x`` and the points that
    fail the ``u_x`` / ``u_tt`` magnitude filter are dropped.

    Args:
        data_path: Path of the HDF5 file; it must contain one dataset per entry
            of ``variables``.
        subsample: Fraction of the flattened points to keep.  When this yields
            fewer points than are available, indices are drawn with
            ``np.random.choice`` (global NumPy RNG, default settings, i.e. with
            replacement); otherwise all points are kept in their original order.
        align_utt: If true, the divisor exponent is lowered by 6 before the cube
            root, which multiplies every feature by ``u_x ** 2``.

    Returns:
        Tuple ``(X, I, x_names, invar_names)`` where ``X`` holds the filtered raw
        columns (ordered as ``x_names``), ``I`` holds the normalised features
        (ordered as ``invar_names``, named ``"I_{alpha}x{beta}"``), and both
        arrays have one row per retained point.

    Raises:
        ValueError: If a requested derivative name is not in ``variables``.
    """
    # Constants entering the divisor exponent; with these values
    # 3 * (B - A*alpha - beta) / (B - A) reduces to 2 + alpha + 2*beta.
    A, B = 0.5, -1.0
    variables = ['t', 'u', 'x', 'u_t', 'u_tt', 'u_ttt', 'u_tttt', 'u_x', 'u_xt', 'u_xtt', 'u_xttt', 'u_xx', 'u_xxt',
                 'u_xxtt', 'u_xxx', 'u_xxxt', 'u_xxxx']
    # (x-order, t-order) of every derivative that is turned into a normalised feature
    alpha_beta = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (3, 0), (3, 1),
                  (4, 0)]

    with h5py.File(data_path, 'r') as f:
        d = {key: np.reshape(np.array(f[key][:]), (f[key].shape[0], -1)) for key in variables}

    # Repeat t along the second axis so it has one entry per column of x, then
    # flatten everything to one value per grid point.
    d['t'] = np.tile(d['t'], (1, d['x'].shape[1]))
    d = {key: value.flatten() for key, value in d.items()}

    indices = np.arange(d['x'].size)
    n_samples = int(subsample * indices.size)
    # NOTE(review): np.random.choice samples with replacement by default, so a
    # subsample may contain duplicate points - confirm this is intended.
    selected_indices = np.random.choice(indices, n_samples) if n_samples < indices.size else indices
    d = {key: value[selected_indices] for key, value in d.items()}

    # Column order of X follows the insertion order of d, i.e. `variables`.
    X = np.stack([d[key] for key in d], axis=1)
    x_names = list(d.keys())

    def get_var(x, t):
        # Column of X holding the derivative of order x in space and t in time.
        # list.index raises ValueError for an unknown name, so no further check is needed.
        name = 'u' if x == 0 and t == 0 else 'u_' + ('x' * x) + ('t' * t)
        return X[:, variables.index(name)]

    u_x = get_var(1, 0)
    u_tt = get_var(0, 2)
    # Keep points with |u_x| >= 0.1 (u_x is the divisor below) and |u_tt| <= 1.
    # Both signs of u_x are retained.
    mask = np.logical_and(np.abs(u_x) >= 1e-1, np.abs(u_tt) <= 1e0)
    # mask = np.abs(u_tt / u_x ** 2) <= 100
    u_x_kept = u_x[mask]

    invars = []
    invar_names = []
    for (alpha, beta) in alpha_beta:
        # Integer power followed by a cube root avoids a fractional power of
        # negative u_x values (which would give NaN).
        exp = int(round(3 * (B - A * alpha - beta) / (B - A)))
        if align_utt:
            exp -= 6  # -6: scale by u_x^2
        den = np.cbrt(u_x_kept ** exp)

        invars.append(get_var(alpha, beta)[mask] / den)
        invar_names.append(f"I_{alpha}x{beta}")
    I = np.stack(invars, axis=1)

    return X[mask], I, x_names, invar_names


import random
from functools import partial

if __name__ == "__main__":

    # Command-line options:
    #   -S/--subsample     fraction of points forwarded to load_boussinesq_data
    #   -r/--regular_vars  regress on the raw derivative columns instead of the invariants
    #   -N/--n_runs        number of runs (one seed each)
    #   -O/--seed_offset   added to each run index to obtain its seed
    parser = argparse.ArgumentParser()
    parser.add_argument('-S', '--subsample', type=float, default=0.001)
    parser.add_argument('-r', '--regular_vars', action='store_true', default=False)
    parser.add_argument('-N', '--n_runs', type=int, default=10)
    parser.add_argument('-O', '--seed_offset', type=int, default=0)
    args = parser.parse_args()

    # One output directory per feature set; exist_ok avoids the check-then-create race.
    method = "regular" if args.regular_vars else "invars"
    os.makedirs(f"./results/boussinesq/{method}", exist_ok=True)

    model_path = "../data/symmetry-sr/pretrained-symbolic-transformer.pt"
    try:
        if not os.path.isfile(model_path):
            print("Model not found, downloading...")
            url = "https://dl.fbaipublicfiles.com/symbolicregression/model1.pt"
            r = requests.get(url, allow_redirects=True)
            # Do not cache an HTTP error page as if it were the checkpoint.
            r.raise_for_status()
            model_dir = os.path.dirname(model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            # Context manager guarantees the file is flushed and closed before it is loaded.
            with open(model_path, 'wb') as model_file:
                model_file.write(r.content)
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code - only use a file from a trusted source.
        if not torch.cuda.is_available():
            print("Warning: CUDA is not available, loading model on CPU.")
            model = torch.load(model_path, map_location=torch.device('cpu'))
        else:
            model = torch.load(model_path)
            model = model.cuda()
        print("Model successfully loaded!")

    except Exception as e:
        print("ERROR: model not loaded! path was: {}".format(model_path))
        print(e)
        # Nothing below can work without `model`; stop here instead of letting
        # every run fail later with an unrelated NameError.
        sys.exit(1)

    def _run(seed):
        """Fit the symbolic regressor once for ``seed`` and save the resulting expression.

        Seeds ``random`` and ``np.random``, reloads the data with the requested
        subsampling, fits the pretrained transformer on either the invariant or
        the raw feature set (``args.regular_vars``), converts the predicted tree
        into a SymPy expression with readable variable names, zeroes numeric
        constants below 0.2, writes the expression to
        ``./results/boussinesq/{method}/boussinesq_seed{seed}.txt`` and returns it.
        """
        random.seed(seed)
        np.random.seed(seed)
        X, I, x_names, invar_names = load_boussinesq_data(DATA_PATH, subsample=args.subsample)

        regressor = symbolicregression.model.SymbolicTransformerRegressor(
            model=model,
            max_input_points=200,
            n_trees_to_refine=100,
            rescale=True,
        )

        if args.regular_vars:
            # Raw columns: target is column 4; mixed derivatives are excluded.
            table, names = X, x_names
            target_col = 4
            feature_cols = [0, 1, 2, 3, 5, 6, 7, 11, 14, 16]
        else:
            # Invariants: target is column 1; mixed derivatives are excluded.
            # Other selections that were tried here: every invariant except the
            # target (np.delete(I, 1, axis=1)), or only columns [4, 8, 11, 13]
            # (no time derivatives at all).
            table, names = I, invar_names
            target_col = 1
            feature_cols = [0, 2, 3, 4, 8, 11, 13]

        # The regressor calls its inputs x_0, x_1, ...; remember the real names.
        replace_vars = {f"x_{pos}": names[col] for pos, col in enumerate(feature_cols)}
        regressor.fit(table[:, feature_cols], table[:, target_col])

        # Turn the operator words of the infix string into Python operators.
        predicted_tree = regressor.retrieve_tree(with_infos=True)["relabed_predicted_tree"]
        model_str = predicted_tree.infix()
        for word, symbol in (("add", "+"), ("mul", "*"), ("sub", "-"), ("pow", "**"), ("inv", "1/")):
            model_str = model_str.replace(word, symbol)

        def zero_small_consts(e, threshold=0.01):
            # Numbers with magnitude below the threshold become exact zero.
            return sp.S.Zero if (e.is_Number and abs(e) < threshold) else e

        expr = sp.parse_expr(model_str)
        for placeholder, real_name in replace_vars.items():
            expr = expr.subs(sp.Symbol(placeholder), sp.Symbol(real_name))
        expr = sp.expand(expr)

        # Thresholding after expansion; a NaN constant makes the comparison raise
        # TypeError, in which case the expanded expression is kept unchanged.
        try:
            expr = expr.replace(lambda e: e.is_Number, partial(zero_small_consts, threshold=0.2))
        except TypeError:  # invalid NaN comparison
            pass

        out_path = f"./results/boussinesq/{method}/boussinesq_seed{seed}.txt"
        with open(out_path, "w") as out_file:
            out_file.write(str(expr))
        return expr

    # One job per run.  joblib rejects n_jobs == 0, so clamp to at least 1:
    # with `-N 0` the seed list is empty and nothing is run or printed.
    seeds = [run_idx + args.seed_offset for run_idx in range(args.n_runs)]
    results = Parallel(n_jobs=max(1, args.n_runs))(delayed(_run)(seed) for seed in seeds)
    for eq in results:
        print(eq)