#!/usr/bin/env python3

import sys
import re
import json
from tqdm import tqdm
from time import time
from sympy import diff, E, Symbol, log
from sympy import expand, simplify, expand_trig, trigsimp, expand_log, integrate
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr, standard_transformations
from sympy.parsing.sympy_parser import implicit_multiplication
from sympy.utilities.lambdify import lambdify
from scipy.optimize import minimize_scalar
from math import isnan
from utils import nonstandard_operators_pre, nonstandard_operators_post
from utils import timeout
from pebble import ProcessPool
from concurrent.futures import TimeoutError

# Per-task time limits passed as ``timeout=`` to pebble's ``pool.schedule``
# in the symbolic and numeric checking stages (seconds, per pebble's docs).
TIMEOUT_SYMBOLIC=10
TIMEOUT_NUMERIC=10

# The variable 'x', declared real; inputs and answers are parsed against it
# and the checkers differentiate with respect to it.
x_real = Symbol('x', real=True)
# Parser transformations: SymPy's standard set plus implicit multiplication.
trs = standard_transformations + (implicit_multiplication,)

def parse_input(inp):
    """Parse a SymPy-syntax expression string into a SymPy expression.

    The name ``e`` is bound to Euler's number and ``x`` to the module's
    real-valued symbol; the module-level transformations ``trs`` are applied.
    """
    namespace = {'e': E, 'x': x_real}
    return parse_expr(inp, local_dict=namespace, transformations=trs)

# Ordered (needle, replacement) pairs applied one after another.  Order
# matters: e.g. ``\lvert``/``\rvert`` must be handled before ``\vert`` and
# ``\bigl``/``\bigr`` before ``\big(``/``\big)``.
_LATEX_CLEANUPS = (
    (r'\left|', '|'),
    (r'\right|', '|'),
    (r'\lvert', '|'),
    (r'\rvert', '|'),
    (r'\vert', '|'),
    (r'\left(', '('),
    (r'\right)', ')'),
    (r'\bigl', ''),
    (r'\bigr', ''),
    (r'\Bigl', ''),
    (r'\Bigr', ''),
    (r'\big(', '('),
    (r'\big)', ')'),
    (r'\displaystyle', ''),
    (r'\,', ' '),
    (r'\;', ' '),
)


def clean_answer(ans):
    """Normalise a LaTeX answer string before it is handed to the parser.

    Sizing/spacing commands are removed or replaced by their plain
    counterparts, runs of spaces are collapsed to one, surrounding whitespace
    is stripped and a single trailing period is dropped.

    Args:
        ans: raw LaTeX answer string.

    Returns:
        The cleaned string, without leading or trailing whitespace.
    """
    ans = ans.strip()
    for needle, replacement in _LATEX_CLEANUPS:
        ans = ans.replace(needle, replacement)
    ans = re.sub(r' +', ' ', ans).strip()
    if ans.endswith('.'):
        # Strip again after removing the sentence-ending period: inputs such
        # as ``x^2\,.`` would otherwise keep a dangling trailing space.
        ans = ans[:-1].rstrip()
    return ans

def parse_answers(ans):
    """Parse a LaTeX answer string into a SymPy expression in real ``x``.

    The answer is cleaned, run through the project's pre-processing hook,
    parsed as LaTeX, post-processed, and finally re-parsed from its string
    form with ``parse_input`` so that ``x`` and ``e`` get the same meaning
    as in the inputs.

    Args:
        ans: raw LaTeX answer; anything that cannot be processed (including
            ``None`` or other non-string values) counts as a parse failure.

    Returns:
        The parsed SymPy expression, or ``None`` if any step fails.
    """
    try:
        # Pre-processing lives inside the ``try`` as well, so a malformed or
        # missing answer is reported as a parse failure instead of aborting
        # the whole evaluation run.
        latex = nonstandard_operators_pre(clean_answer(ans))
        parsed = nonstandard_operators_post(parse_latex(latex))
        # because for SymPy diff(log(Abs(x))) == sign(x)/Abs(x) != 1/x
        text = str(parsed).replace('log(Abs', 'log(')
        if 'operatorname' in text or 'aux' in text:
            raise Exception("some operator not parsed correctly")
        # Guard the re-parse with the project's timeout decorator.
        guarded_parse = timeout(10)(parse_input)
        return guarded_parse(text)
    except Exception:
        return None

def symbolic_checker(inp, ans):
    """Symbolically test whether the derivative of ``ans`` equals ``inp``.

    Returns a ``(verdict, comment)`` tuple.  ``verdict`` is ``False`` when
    the answer is rejected outright (zero derivative, more than one free
    symbol, unevaluated derivative, or differentiation error), ``True`` when
    ``inp - d(ans)/dx`` simplifies to zero, and ``None`` when the
    simplification is inconclusive or fails.
    """
    try:
        derivative = diff(ans, x_real)
        if derivative == 0:
            return False, 'deriv zero'
        if len(derivative.free_symbols) > 1:
            return False, f'unknown variables present: {ans}'
        if 'Derivative' in str(derivative):
            return False, f'cannot fully differentiate: {ans}'
    except Exception as err:
        return False, f'failed differentiation: {err}'

    try:
        # Same pipeline as a nested call, applied innermost step first.
        residual = expand_trig(inp - derivative)
        residual = trigsimp(residual)
        residual = expand_log(residual, force=True)
        residual = expand(residual, force=True)
        residual = simplify(residual)
        # A non-zero residual proves nothing, so it maps to None, not False.
        verdict = True if 0 == residual else None
    except Exception as err:
        return None, f'failed simplification: {err}'
    return verdict, 'simplification concluded'

def symbolic_checker_with_integ(inp, ans):
    """Symbolically compare SymPy's own antiderivative of ``inp`` to ``ans``.

    Args:
        inp: integrand as a SymPy expression in ``x_real``.
        ans: candidate antiderivative as a SymPy expression.

    Returns:
        ``(True, comment)`` when ``integrate(inp) - ans`` simplifies to zero,
        otherwise ``(None, comment)``.  This check never returns ``False``:
        a non-zero difference is inconclusive (the two antiderivatives may
        differ by a constant or the simplifier may simply not find zero).
    """
    try:
        # Name the integration variable explicitly: without it SymPy raises
        # for integrands that have no free symbol or more than one.
        integral = integrate(inp, x_real)
        if 'Integral' in str(integral):
            return None, 'failed integration'
    except Exception as e:
        return None, f'failed integration: {e}'
    try:
        difference = expand_trig(integral - ans)
        difference = trigsimp(difference)
        difference = expand_log(difference, force=True)
        difference = expand(difference, force=True)
        difference = simplify(difference)
        correct = True if 0 == difference else None
        return correct, 'simplification concluded'
    except Exception as e:
        return None, f'failed simplification: {e}'

def numeric_checker(inp, ans):
    """Numerically test whether the derivative of ``ans`` matches ``inp``.

    The expression ``-|inp - d(ans)/dx|`` is evaluated on a fixed grid of
    sample points and then minimised with ``minimize_scalar``.

    Args:
        inp: integrand as a SymPy expression in ``x_real``.
        ans: candidate antiderivative as a SymPy expression.

    Returns:
        A ``(verdict, comment)`` tuple: ``False`` as soon as one sample
        differs by more than 1e-4; ``True`` if the minimiser yields NaN but
        more than 100 samples agreed to within 1e-8; otherwise whether the
        minimiser's value is within 1e-8 of zero.  ``(None, comment)`` is
        returned if the expression cannot be lambdified.
    """
    deriv = diff(ans, x_real) # should pass, was computed before
    expr = -abs(inp - deriv)
    # Evaluate logs on |arg| + epsilon so they stay defined for arg <= 0.
    expr = expr.replace(log, lambda arg: log(abs(arg) + 1e-15)) # TODO keep it?
    try:
        expr_numeric = lambdify(x_real, expr, 'numpy')
    except Exception as e:
        return None, f'error when lambdifying {expr}: {e}'

    agreeing_samples = 0
    for i in range(0, 1000, 10):
        for j in [i, -i, 1/(i+1), -1/(1+i)]:
            try:
                value = expr_numeric(j)
                if abs(value) > 1e-4:
                    return False, 'pre-check negative'
                if abs(value) < 1e-8:
                    agreeing_samples += 1
            except Exception:
                # Point outside the expression's domain (division by zero,
                # overflow, ...): skip it.  Only ``Exception`` is caught so
                # KeyboardInterrupt/SystemExit still propagate.
                continue
    try:
        result = minimize_scalar(expr_numeric).fun
    except Exception:
        result = float('nan')
    if isnan(result) and agreeing_samples > 100:
        return True, 'minimum nan, but long positive evidence'
    return bool(abs(result) < 1e-8), 'numeric check concluded'



def _run_checker_stage(checker, pairs, limit, label):
    """Run ``checker`` on every ``(inp, ans)`` pair in a process pool.

    Args:
        checker: module-level function taking ``(inp, ans)`` and returning a
            ``(correct, comment)`` tuple.
        pairs: list of ``(inp, ans)`` keys to evaluate.
        limit: per-task timeout handed to ``pool.schedule``.
        label: checker family name used in timeout/error comments.

    Returns:
        dict mapping each pair to its ``(correct, comment)`` outcome.  A task
        that times out or fails in the worker is recorded as
        ``(None, <reason>)`` so that a later stage can still try it.
    """
    outcomes = {}
    with ProcessPool() as pool:
        futures = [pool.schedule(checker, args=(inp, ans), timeout=limit)
                   for inp, ans in pairs]
        for pair, future in tqdm(list(zip(pairs, futures))):
            try:
                correct, comment = future.result()
            except TimeoutError:
                correct, comment = None, f'{label} checker timeout: {limit}'
            except Exception as e:
                # E.g. the worker process died; keep going instead of losing
                # every result gathered so far.
                correct, comment = None, f'{label} checker error: {e}'
            outcomes[pair] = (correct, comment)
    return outcomes


if __name__ == "__main__":
    # Usage: script EXAMPLES.jsonl  -- one JSON object per line, each with a
    # 'sympy' integrand and a list of 'responses' carrying 'final_answer'.
    # The annotated examples are written to stdout, progress to stderr.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} EXAMPLES.jsonl')
    with open(sys.argv[1], encoding='utf-8') as f:
        # Iterate the file line by line (splits on newlines only) and skip
        # blank lines, which are not valid JSON documents.
        examples = [json.loads(line) for line in f if line.strip()]
    eval_results = {}
    inp_parsings, ans_parsings = {}, {}
    print('parsing data...', file=sys.stderr)
    for example in tqdm(examples):
        inp = example['sympy']
        inp_parsed = parse_input(inp)
        inp_parsings[inp] = inp_parsed
        for response in example['responses']:
            ans = response['final_answer']
            if ans not in ans_parsings:
                ans_parsings[ans] = parse_answers(ans)
            ans_parsed = ans_parsings[ans]
            if ans_parsed is None:
                eval_results[(inp_parsed, ans_parsed)] = False, 'failed parsing'
            else:
                eval_results[(inp_parsed, ans_parsed)] = None, None
    remaining = [c for c in eval_results if eval_results[c][0] is None]
    print(f'{len(eval_results)} pairs', file=sys.stderr)
    print(f'{len(remaining)} pairs remaining to evaluate', file=sys.stderr)

    # Each stage only sees the pairs every earlier stage left undecided.
    stages = (
        ('symbolic check...', symbolic_checker,
         TIMEOUT_SYMBOLIC, 'symbolic'),
        ('symbolic check with integration...', symbolic_checker_with_integ,
         TIMEOUT_SYMBOLIC, 'symbolic'),
        ('numeric check...', numeric_checker,
         TIMEOUT_NUMERIC, 'numeric'),
    )
    for banner, checker, limit, label in stages:
        print(banner, file=sys.stderr)
        eval_results.update(
            _run_checker_stage(checker, remaining, limit, label))
        remaining = [c for c in eval_results if eval_results[c][0] is None]
        print(f'{len(remaining)} pairs remaining to evaluate', file=sys.stderr)

    for example in examples:
        inp_parsed = inp_parsings[example['sympy']]
        for response in example['responses']:
            ans_parsed = ans_parsings[response['final_answer']]
            correct, comment = eval_results[(inp_parsed, ans_parsed)]
            response['correct'] = correct
            response['comment'] = comment
        print(json.dumps(example), flush=True)

