import json
import math
import random
import time
import re
from sympy import simplify
import numpy as np
import tqdm
import itertools
from typing import Tuple
from random import shuffle
from collections import Counter
import argparse
from utils import *

def most_frequent(input_list):
    """Return the element that occurs most often in ``input_list``.

    Elements must be hashable. When several elements share the highest
    count, the one that appears first in ``input_list`` is returned.
    An empty ``input_list`` raises ``IndexError``.
    """
    tally = Counter(input_list)
    winner, _count = tally.most_common(1)[0]
    return winner

def constrain_func(function):
    """Snap the coefficients of a linear expression to integers, or reject it.

    ``function`` is split on whitespace. Every token is either a bare
    ``+`` / ``-`` operator or a term written as ``<number>*<variable>``.
    Each coefficient is truncated toward zero with ``int()``.

    The expression is rejected (``""`` is returned) when:

    * any coefficient differs from its truncated value by 0.2 or more,
    * the absolute coefficients of terms whose variable contains ``"x"``
      sum to more than 3,
    * any such coefficient has an absolute value above 1, or
    * the simplified result is ``"0"`` or is exactly two characters long.

    Otherwise the sympy-simplified expression is returned as a string.
    A token that is not an operator and has no numeric prefix or no ``*``
    raises ``ValueError`` / ``IndexError``.
    """
    rebuilt_tokens = []
    truncation_errors = []
    x_magnitudes = []

    for token in function.split():
        if token in ('-', '+'):
            rebuilt_tokens.append(token)
            continue

        pieces = token.split('*')
        value = float(pieces[0])
        symbol = pieces[1]
        # NOTE(review): int() truncates toward zero, so 0.9 becomes 0 (error
        # 0.9) rather than 1 (error 0.1) -- confirm rounding is not intended.
        truncated = int(value)
        truncation_errors.append(abs(value - truncated))

        if "x" in symbol:
            x_magnitudes.append(abs(value))
        rebuilt_tokens.append(f"{truncated}*{symbol}")

    too_imprecise = any(err >= 0.2 for err in truncation_errors)
    if too_imprecise \
            or np.sum(x_magnitudes) > 3 \
            or any(magnitude > 1 for magnitude in x_magnitudes):
        return ""

    # Operators are kept as tokens, so joining with '+' produces strings such
    # as "1*x0+-+2*x1"; the original signs survive as unary operators.
    simplified = str(simplify('+'.join(rebuilt_tokens)))
    if simplified == '0' or len(simplified) == 2:
        return ""
    return simplified

# input a subset of original fp_results
def reverse_function(fp_results, num_x, numbers_orig) -> Tuple[str, str]:
    """Fit an equation to perturbation results and evaluate it on the original operands.

    One matrix row is built per usable entry of ``fp_results`` with
    ``compute_expand_variable``. A final all-ones row with a randomly chosen
    right-hand side (1, 10 or 100) is appended; every other right-hand side
    is zero. The system is solved exactly when square and with the
    pseudo-inverse otherwise. The solution is rendered to an equation with
    ``show_computed_funx``, rearranged to isolate ``y``, checked by
    ``constrain_func`` and finally evaluated on ``numbers_orig``.

    Args:
        fp_results: Mapping whose keys are strings that ``eval`` turns into
            the perturbed input and whose values are convertible to
            ``float``. Entries that fail either conversion are skipped.
        num_x: Number of ``x`` variables; ``x0`` .. ``x{num_x - 1}`` must
            all appear in the fitted equation.
        numbers_orig: Original operands passed to ``eval_function_from_str``.

    Returns:
        ``(prediction, function)`` on success. Despite the annotation, the
        prediction is the value from ``eval_function_from_str`` rounded to
        5 decimals, converted to ``int`` when it is integral. ``("", "")``
        is returned when no row could be built, the fitted equation lacks
        ``y`` or one of the ``x`` variables, ``constrain_func`` rejects it,
        or any step of solving/evaluating raises.
    """
    num_op = num_x - 1
    K = 1
    para_precision = 0.1
    expand_variables = expand_variables_funx(num_x, num_op, K, show_expand_variables=False)

    # STEP1 generate Matrix A, and B, we want solve AX = B, X = (A^-1)B, X is combination of coefficients
    A = []
    for fp_key, fp_value in fp_results.items():
        try:
            fp_output = float(fp_value)
            # SECURITY NOTE(review): eval() runs arbitrary code taken from the
            # keys of the input data. Only use trusted files; consider
            # ast.literal_eval if the keys are always plain literals.
            fp_input = eval(fp_key)
        except Exception:
            # Malformed entries are skipped on purpose (best effort).
            continue
        compute_ret = compute_expand_variable(
            fp_input, num_op, K, fp_output)
        A.append(compute_ret)

    if not A:
        # Nothing to fit; previously this surfaced as a swallowed IndexError.
        return "", ""

    # the value of b can be sensitive for constant
    b_candi_list = [1, 10, 100]
    b = [0] * len(A) + [random.choice(b_candi_list)]
    # Extra all-ones row: pins the sum of the coefficients to the chosen
    # constant so the homogeneous system has a non-trivial solution.
    A.append([1] * len(A[0]))

    # STEP2 analytical solver
    try:
        # The builtin float replaces the np.float alias, which was removed in
        # NumPy 1.24 and made every call fail inside this try block.
        A = np.array(A).astype(float)
        b = np.array(b).astype(float)
        if len(A) == len(A[0]):
            x = np.linalg.solve(A, b).astype(float)
        else:
            x = np.linalg.pinv(A).dot(b).astype(float)
        # presumably index 1 is the coefficient of y in expand_variables --
        # TODO confirm against expand_variables_funx
        y_coeff_before_normalize = x[1]
        if abs(y_coeff_before_normalize) > 0.1:
            x = x / y_coeff_before_normalize

        analytic_eqtion = show_computed_funx(x, expand_variables, para_precision)
        analytic_symp = str(simplify(analytic_eqtion))

        # simplify and remove "y term"
        if "y" not in analytic_symp:
            return "", ""

        # if there are less than original variables, pass them
        for x_serial in range(num_x):
            if "x" + str(x_serial) not in analytic_symp:
                return "", ""

        # Relies on the y term being printed last as "<coeff>*y": the final
        # token minus its last two characters is the coefficient, and the
        # token before it is the sign.
        splits = analytic_symp.split()
        y_coeff = float(splits[-1][:-2])
        y_flag = 1 if splits[-2] == '+' else -1

        # Remove the y term and divide by its negated coefficient, leaving
        # the expression that y equals.
        function = str(simplify(f"({analytic_symp} - {y_flag} * {y_coeff} * y) * ({y_flag} * -1) / {y_coeff}"))
        # NOTE(review): the constrained (integer-coefficient) string is only
        # used as an accept/reject gate; the unconstrained `function` is what
        # gets evaluated and returned. Confirm this is intended.
        function_ret = constrain_func(function)
        if function_ret == "":
            return "", ""

        # if c in function, replace it with 1
        if "c" in function:
            function = function.replace("*c", "")

        fp_pred = eval_function_from_str(function, numbers_orig)
        fp_pred = round(fp_pred, 5)
        # Report whole-number predictions as int (e.g. 3 instead of 3.0).
        if int(fp_pred * 10 // 10) == fp_pred:
            fp_pred = int(fp_pred)

        return fp_pred, function
    except Exception:
        # Best effort: singular systems, unparsable equations and similar
        # failures all mean "no prediction". KeyboardInterrupt and SystemExit
        # are no longer swallowed.
        return "", ""


def try_analytic_solve(try_calibration_path, output_path, *, max_items=101, max_combinations=10):
    """Run the analytic solver over a calibration file and append predictions to a TSV file.

    Each record of the JSON list is read for ``query_id``, ``golds``,
    ``operand`` and ``fp_results``. When a record has more perturbation
    results than ``len(operand) + 1``, up to ``max_combinations`` random
    subsets of that size are solved and the most frequent function and the
    most frequent prediction are kept; otherwise all results are solved at
    once. Records with an empty winning function are skipped.

    Args:
        try_calibration_path: Path of the JSON file to read.
        output_path: Path of the file to append to. One line per kept record:
            ``query_id``, prediction, gold answers, function (tab separated).
        max_items: Number of leading records to process (default 101, the
            previously hard-coded limit).
        max_combinations: Maximum number of random subsets tried per record
            (default 10, the previously hard-coded limit).
    """
    # The context manager closes the input handle, which used to be leaked.
    with open(try_calibration_path) as cali_file:
        cali_results = json.load(cali_file)

    for idx, item in enumerate(tqdm.tqdm(cali_results)):

        if idx >= max_items:
            break

        key = item["query_id"]
        if isinstance(item["golds"], dict):
            ground_truths = item["golds"]["spans"]
        else:
            ground_truths = [str(int(ex)) for ex in item["golds"]]
        numbers_orig = item["operand"]
        fp_results = item["fp_results"]

        pred_fp, pred_func = "", ""
        # generate variables
        num_x = len(numbers_orig)
        # optimal operators
        optimal_operand = num_x + 1

        # randomly take optimal operand samples to solve constant, y, x1, x2, ...
        if len(fp_results) > optimal_operand:
            # there shall be multiple predictions, and we should take them carefully
            # NOTE(review): every combination is materialised before sampling,
            # which grows combinatorially with the number of results.
            key_combinations = list(itertools.combinations(fp_results, r=optimal_operand))
            shuffle(key_combinations)
            key_combinations = key_combinations[:max_combinations]

            all_pred = []
            all_pred_functions = []
            for combo in key_combinations:
                # construct new fp_results; combinations keep the original
                # key order, so a direct lookup yields the same subset
                subset_fp_results = {fp_key: fp_results[fp_key] for fp_key in combo}
                combo_pred, combo_func = reverse_function(fp_results=subset_fp_results,
                                                          num_x=num_x,
                                                          numbers_orig=numbers_orig)
                all_pred.append(combo_pred)
                all_pred_functions.append(combo_func)

            if all_pred_functions:
                # NOTE(review): the function and the prediction are voted on
                # independently, and failed subsets ("") take part in the
                # vote -- confirm both are intended.
                pred_func = most_frequent(all_pred_functions)
                pred_fp = most_frequent(all_pred)
        else:
            # there is only one prediction
            pred_fp, pred_func = reverse_function(fp_results=fp_results,
                                                  num_x=num_x,
                                                  numbers_orig=numbers_orig)

        if pred_func == "":
            continue

        # Appending per record keeps partial results on disk if the run is
        # interrupted.
        with open(output_path, 'a') as f:
            f.write(f"{key}\t{pred_fp}\t{ground_truths}\t{pred_func}\n")

if __name__ == "__main__":
    # Command-line entry point: parse the two path options, empty the output
    # file, then run the solver, which appends its results to that file.
    cli = argparse.ArgumentParser(description="analytical paths")
    cli.add_argument("--substitution_path",
                     type=str,
                     required=False,
                     default="./substitution_outputs/BART_Large_fp_all_DROP_cases.json",
                     help="location of the substitution file")
    cli.add_argument("--output_path",
                     type=str,
                     required=False,
                     default="./BART_analytical_results.txt",
                     help="location of the output file")
    options = cli.parse_args()

    # Opening in 'w' mode truncates (or creates) the output file.
    with open(options.output_path, 'w'):
        pass

    try_analytic_solve(options.substitution_path, options.output_path)
