import random
import sympy as sp
import numpy as np
import argparse
from pysr import PySRRegressor
from sympy import simplify

# Numerical tolerance constant; no use of it is visible in this section of the
# module — NOTE(review): confirm where it is consumed.
EPSILON = 1e-3
# Toggles debug-only setup below.
DEBUG = True

if DEBUG:
    # Widen pandas' per-column display so long expression strings in printed
    # DataFrames are not truncated.
    import pandas as pd
    pd.options.display.max_colwidth = 360


def parse_args(argv=None):
    """Parse the command-line options for a symbolic-regression run.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic calls).

    Returns:
        argparse.Namespace holding the parsed options.

    Exits via argparse's usage error (``SystemExit`` with status 2) unless
    exactly one of ``-i/--solve_invariants`` or ``-r/--solve_regular`` is given.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('data', help="data path")
    ap.add_argument('-s', '--samples', type=int, default=10000, help="max number of samples")
    ap.add_argument('-P', '--populations', type=int, default=127, help="number of populations")
    ap.add_argument('-p', '--population_size', type=int, default=27, help="population size")
    ap.add_argument('-I', '--niterations', type=int, default=10, help="number of iterations")
    ap.add_argument('-m', '--max_size', type=int, default=25, help="max equation size")
    ap.add_argument('-f', '--frac_replaced', type=float, default=0.00036, help="fraction replaced")
    ap.add_argument('-i', '--solve_invariants', action='store_true', help="solve using invariants")
    ap.add_argument('-r', '--solve_regular', action='store_true', help="solve using regular")
    ap.add_argument('--seed', type=int, default=0, help="random seed, -1 for random")
    args = ap.parse_args(argv)
    # Use ap.error rather than `assert`: asserts are stripped under `python -O`,
    # which would silently allow both or neither mode.
    if args.solve_invariants == args.solve_regular:
        ap.error("Exactly one of solve invariants (-i) or solve regular (-r) should be provided as flags")
    return args


def set_seed(seed):
    """Seed Python's and NumPy's global RNGs; ``-1`` leaves both untouched."""
    if seed == -1:
        return
    random.seed(seed)
    np.random.seed(seed)
    

def _select_from_hof(hall_of_fame):
    """ currently uses the same as the 'best' scoring strategy from pysr """
    losses = hall_of_fame["loss"].to_list()
    scores = hall_of_fame["score"].to_list()
    best_loss = losses[-1]
    start_ind = next(i for i, l in enumerate(losses) if l < best_loss * 1.5)
    best_ind = start_ind + np.argmax(scores[start_ind:])
    return hall_of_fame["sympy_format"].to_list()[best_ind]


def _prune_low_coeff(expr, cutoff=0.01):
    """Rescale an expression's coefficients and drop negligible terms.

    The expression is expanded into a sum of terms. Every coefficient is
    divided by the largest coefficient magnitude. Terms whose rescaled
    coefficient has magnitude below ``cutoff`` are removed.

    Args:
        expr: sympy expression.
        cutoff: minimum relative coefficient magnitude for a term to be kept.

    Returns:
        sympy expression equal to the sum of the kept, rescaled terms. An
        expression that expands to zero is returned as ``sp.S.Zero``, since
        there is no scale to normalize by.
    """
    expr = expr.expand()
    coeff_dict = expr.as_coefficients_dict()

    # Normalization factor: the largest coefficient magnitude.
    denom = max((abs(coeff) for coeff in coeff_dict.values()), default=0)
    if denom == 0:
        # Zero expression: as_coefficients_dict() yields {1: 0}, and dividing
        # by 0 would produce nan, whose comparison sympy rejects with TypeError.
        return sp.S.Zero

    return sp.Add(*(
        (coeff / denom) * term
        for term, coeff in coeff_dict.items()
        if abs(coeff / denom) >= cutoff
    ))


def _hof_summary(hof, norm, k):
    """Rank per-variable hall-of-fame results and return up to ``k`` distinct equations.

    Args:
        hof: mapping from variable name (the isolated left-hand side) to its
            hall-of-fame table (needs a "loss" column plus the columns read
            by ``_select_from_hof``).
        norm: mapping from variable name to the divisor applied to that
            variable's minimum loss before ranking.
        k: maximum number of equations to return; ``k <= 0`` returns none.

    Returns:
        list of sympy expressions ``e`` (each read as ``e = 0``), ordered by
        normalized loss (best first), skipping any expression that
        ``_expr_matches`` an already-kept one.
    """
    selected = {}
    ranked = []
    for var, table in hof.items():
        raw_loss = np.min(table["loss"].to_numpy())
        score = raw_loss / norm[var]
        # Select once per variable; reused below instead of re-running selection.
        selected[var] = _select_from_hof(table)
        print("var", var, "score", score, "raw", raw_loss, "expr", selected[var])
        ranked.append((score, var))

    # k is an upper bound on the number of equations returned.
    if k <= 0:
        return []

    ranked.sort(key=lambda item: item[0])

    exprs = []
    for _, var in ranked:
        best_expr = simplify(selected[var])
        # convert e.g. ut = uxx + uyy to ut - uxx - uyy = 0, then rescale and prune
        # before norm: u_xxxx = -1.0*u*u_xx - 1.0*u_tt - 1.0*u_x**2 + 2.9717057e-8*u_x*(u + u_x - 2.7151405e-8)
        norm_expr = _prune_low_coeff(sp.sympify(var) - best_expr)

        # Skip equations equivalent (per _expr_matches) to one already kept.
        if any(_expr_matches(norm_expr, prev)[0] for prev in exprs):
            continue
        exprs.append(norm_expr)

        if len(exprs) >= k:
            break

    return exprs


def _expr_matches(prediction, target, term_deviation=0.05, n_points=100):
    """Check whether two expressions are built from the same set of terms.

    Both expressions are expanded into sums of terms, and the coefficients are
    stripped. Each term is evaluated at ``n_points`` standard-normal random
    points drawn from NumPy's global RNG. A target term is matched by the
    first prediction term whose values agree under ``np.allclose``. The
    absolute tolerance is ``term_deviation`` times the target term's largest
    magnitude (plus NumPy's default relative tolerance).

    The expressions match when every target term found a partner and every
    prediction term was the partner of some target term.

    NOTE(review): term coefficients are NOT compared, so e.g. ``x + y``
    matches ``x + 5*y``. An earlier docstring described a "constant multiple
    within some threshold" check that the code never performed — confirm
    which behaviour is intended.

    Args:
        prediction, target: sympy expressions.
        term_deviation: relative (to the target term's max magnitude)
            absolute tolerance for term comparison.
        n_points: number of random sample points; must be positive.

    Returns:
        tuple (bool, str): ``(True, "Matched")`` or ``(False, reason)``.
    """
    pred_expanded = prediction.expand()
    target_expanded = target.expand()

    pred_terms = pred_expanded.as_coefficients_dict()
    target_terms = target_expanded.as_coefficients_dict()

    # evaluate terms numerically on test points
    all_symbols = list(pred_expanded.free_symbols.union(target_expanded.free_symbols))
    test_points = {sym: np.random.randn(n_points) for sym in all_symbols}
    # One substitution dict per sample point, built by index so that a
    # symbol-free (constant-only) comparison still gets n_points samples.
    subs_dicts = [{sym: vals[i] for sym, vals in test_points.items()} for i in range(n_points)]

    def evaluate_expression(expr):
        """ Evaluates an expression over all sampled points """
        # complex() instead of float(): terms that are complex at some points
        # (e.g. sqrt/log of a negative sample) are compared rather than crashing.
        return np.array([complex(expr.subs(subs)) for subs in subs_dicts])

    pred_evals = {term: evaluate_expression(term) for term in pred_terms}
    target_evals = {term: evaluate_expression(term) for term in target_terms}

    # Compare terms numerically; maps each prediction term to the target term
    # it was matched with (None = not matched yet).
    pred_term_matched = dict.fromkeys(pred_terms)
    for target_term, target_values in target_evals.items():
        tolerance = term_deviation * np.abs(target_values).max()
        partner = next(
            (pred_term for pred_term, pred_values in pred_evals.items()
             if np.allclose(target_values, pred_values, atol=tolerance)),
            None,
        )
        if partner is None:
            return False, f"Missing equivalent term for {target_term}"
        pred_term_matched[partner] = target_term

    for pred_term, matched in pred_term_matched.items():
        if matched is None:
            return False, f"Redundant term {pred_term}"

    return True, "Matched"


def _match_pred_with_gt(predictions, gts) -> bool:
    """Match predicted equations against ground-truth equations and print a report.

    Every prediction is compared with every ground truth via ``_expr_matches``.
    A prediction may match several ground truths and vice versa.

    Args:
        predictions: sympy expressions ``p`` (each read as ``p = 0``).
        gts: ground-truth sympy expressions in the same form.

    Returns:
        True iff every ground truth matched at least one prediction and every
        prediction matched at least one ground truth.
    """
    print("* Ground Truths *")
    for gt in gts:
        print(0, "=", gt)

    print("* Predictions *")
    for p in predictions:
        print(0, "=", p)

    # verdicts[i][j] = _expr_matches(prediction j, GT i) -> (matched, reason).
    # Computed exactly once and reused for the report below: _expr_matches
    # samples random points, so re-running it could contradict this verdict.
    verdicts = [[_expr_matches(pred, gt) for pred in predictions] for gt in gts]

    matched_gt = {
        i: [j for j, (match, _) in enumerate(row) if match]
        for i, row in enumerate(verdicts)
    }
    unmatched_gt = [i for i, preds in matched_gt.items() if not preds]
    matched_pred = {j for preds in matched_gt.values() for j in preds}
    unmatched_pred = [j for j in range(len(predictions)) if j not in matched_pred]

    if unmatched_gt:
        print("* Unmatched Ground Truths *")
        for i in unmatched_gt:
            print(f"GT[{i}]: {gts[i]}")
    else:
        print("* All ground truths have been matched with predictions *")

    if len(unmatched_gt) != len(gts):
        print("* Matched Ground Truths and Predictions *")
        for i, matched_preds in matched_gt.items():
            if matched_preds:
                print(f"GT[{i}]: {gts[i]}")
                for j in matched_preds:
                    print(f"  -> Matches with Prediction[{j}]: {predictions[j]}")
    else:
        print("* No matches between Ground Truth and Predictions *")

    if unmatched_pred:
        print("* Unmatched Predictions *")
        for j in unmatched_pred:
            print(f"Prediction[{j}]: {predictions[j]}")
            for i in range(len(gts)):
                print(f"  - Against GT[{i}]: {verdicts[i][j][1]}")
    else:
        print("* All predictions have been matched with ground truths *")

    return not unmatched_gt and not unmatched_pred


def _solve_variable(model: PySRRegressor, X, Y, index, lhs_mask, var_names):
    """Fit ``model`` to express column ``index`` of ``X`` via the other allowed columns.

    The fit target is ``Y + X[:, index]``. The features are ``X`` with column
    ``index`` removed and, additionally, every remaining column whose name is
    listed in ``lhs_mask``.

    Returns:
        Whatever ``model.get_hof()`` returns after fitting.
    """
    fit_target = Y + X[:, index]
    without_target = np.delete(X, index, 1)

    # Names aligned with the columns of `without_target`.
    rhs_names = [name for pos, name in enumerate(var_names) if pos != index]
    print("Solving for ", var_names[index])

    # Variables restricted to the left-hand side may not appear on the right.
    masked_positions = [pos for pos, name in enumerate(rhs_names) if name in lhs_mask]
    rhs_X = np.delete(without_target, masked_positions, 1)
    masked = set(masked_positions)
    rhs_names = [name for pos, name in enumerate(rhs_names) if pos not in masked]

    model.fit(rhs_X, fit_target,
              variable_names=rhs_names)

    return model.get_hof()

def _solve_isolated_variable(model, X, Y, lhs_mask, var_names, gts):
    """Solve for each candidate left-hand-side variable and compare with ground truth.

    Args:
        model: regressor passed to ``_solve_variable``.
        X: data matrix whose columns correspond to ``var_names``.
        Y: offset added to the isolated column to form each fit target.
        lhs_mask: names allowed on the left-hand side. Each is solved for in
            turn and excluded from every right-hand side. An empty mask
            (``[]``, ``()`` or ``None``) means every variable is tried as LHS.
        var_names: column names of ``X``.
        gts: ground-truth sympy expressions (each read as ``expr = 0``).

    Returns:
        bool result of ``_match_pred_with_gt`` on the summarized predictions.
    """
    # Normalize so any empty container / None means "no mask"; the previous
    # `lhs_mask == []` check silently solved nothing for `()`.
    mask = list(lhs_mask) if lhs_mask else []

    hof = {}
    score_norm = {}
    for i, v in enumerate(var_names):
        if not mask or v in mask:
            hof[v] = _solve_variable(model, X, Y, i, mask, var_names)
            # Mean |value| of the isolated column; _hof_summary divides the
            # raw loss by this (presumably to compare differently-scaled variables).
            score_norm[v] = np.abs(X[:, i]).mean()

    norm_exprs = _hof_summary(hof, score_norm, len(gts))

    return _match_pred_with_gt(norm_exprs, gts)

def _solve_hs_helper(model: PySRRegressor,
                     X, Y, IX, IY,
                     lhs_mask,
                     regular_names, invariant_names,
                     gts_regular_eq, gts_invariant_eq,
                     do_regular, do_invariant
                     ):
    """Run the regular and/or invariant search and report correctness.

    Args:
        model: regressor shared by both searches.
        X, Y, regular_names, gts_regular_eq: inputs for the regular search.
        IX, IY, invariant_names, gts_invariant_eq: inputs for the invariant search.
        lhs_mask: names allowed on the left-hand side ([] for no mask).
        do_regular, do_invariant: which searches to run.

    Returns:
        True iff every search that was run fully matched its ground truth.

    Raises:
        ValueError: if neither search is requested.
    """
    if not (do_regular or do_invariant):
        # Previously this surfaced as an UnboundLocalError on `correct`.
        raise ValueError("at least one of do_regular / do_invariant must be set")

    # Combine results so a regular failure is not hidden when both run.
    correct = True
    if do_regular:
        regular_ok = _solve_isolated_variable(model, X, Y, lhs_mask, regular_names, gts_regular_eq)
        print("fully correct!" if regular_ok else "one or more equations incorrect!")
        correct = correct and regular_ok

    if do_invariant:
        invariant_ok = _solve_isolated_variable(model, IX, IY, lhs_mask, invariant_names, gts_invariant_eq)
        print("fully correct!" if invariant_ok else "one or more equations incorrect!")
        correct = correct and invariant_ok

    return correct


# Y and IY should be all zeros; they are only provided for formality's sake.
def solve_hs(args,
             operations: list[str],
             unary_operators: list[str],
             nested_constraints: dict[str, dict],
             penalty: str,
             lhs_mask: list[str],  # or [] if no mask
             X, Y, IX, IY,
             regular_names, invariant_names,
             gts_regular_eq, gts_invariant_eq) -> bool:
    """Build a PySRRegressor from ``args`` and run the selected search.

    Args:
        args: parsed CLI namespace. Reads max_size, niterations, populations,
            population_size, frac_replaced, seed, solve_regular and
            solve_invariants.
        operations: PySR binary operators.
        unary_operators: PySR unary operators.
        nested_constraints: forwarded to PySR's ``nested_constraints``.
        penalty: right-hand side of the Julia elementwise loss
            ``loss(x, y) = <penalty>``.
        lhs_mask: names allowed on the left-hand side, or [] for no mask.
        X, Y: regular data matrix and target. Y is expected to be all zeros
            (passed only for formality's sake).
        IX, IY: invariant data matrix and target. IY is likewise expected to
            be all zeros.
        regular_names, invariant_names: column names of X and IX.
        gts_regular_eq, gts_invariant_eq: ground-truth expressions per mode.

    Returns:
        True iff the selected search recovered the ground-truth equations.
    """
    print("Running for iters", args.niterations)
    seeded = args.seed != -1
    model = PySRRegressor(
        maxsize=args.max_size,
        niterations=args.niterations,
        binary_operators=operations,
        unary_operators=unary_operators,
        nested_constraints=nested_constraints,
        elementwise_loss="loss(x, y) = " + penalty,
        populations=args.populations,
        population_size=args.population_size,
        fraction_replaced=args.frac_replaced,
        random_state=args.seed if seeded else None,
        # Reproducible runs require deterministic mode with serial execution.
        deterministic=seeded,
        parallelism="serial" if seeded else None,
        batching=True
    )

    # Return the verdict instead of discarding it.
    return _solve_hs_helper(model,
                            X, Y, IX, IY, lhs_mask,
                            regular_names, invariant_names,
                            gts_regular_eq, gts_invariant_eq,
                            args.solve_regular, args.solve_invariants)


if __name__ == '__main__':
    # Smoke check 1: a fitted-looking expression against a hand-written reference.
    candidate = sp.simplify("0.125125491483546*L - 1.0*zeta_2 + 0.125282649100849*exp(4.0056014*R)")
    reference = sp.simplify("-L + zeta_2 - exp(4*R)")
    print(_expr_matches(candidate, reference))

    # Smoke check 2: the full matching report for a single prediction / ground truth.
    fitted_linear = sp.sympify("0.125373589790026*L - 1.0*zeta_2")
    fitted_exp = sp.exp(sp.sympify("4.009881*R"))
    candidate = fitted_linear + 0.125580606661487 * fitted_exp

    zeta_2 = sp.sympify("zeta_2")
    L = sp.sympify("L")
    exp_term = sp.exp(sp.sympify("4 * R"))
    reference = 8 * zeta_2 - L - exp_term
    print(_match_pred_with_gt([candidate], [reference]))