"""
Evaluate SymPy on a dataset of tokenized expressions and log its success rate
"""

import io
import logging
import sympy as sp
import timeout_decorator
import time
import argparse
import subprocess
import random

# Timestamp (YYYYmmdd-HHMMSS) that gives every run its own log file name.
date_time_idx = time.strftime("%Y%m%d-%H%M%S")

# Route INFO-and-above records of the root logger to a per-run file;
# filemode 'w' truncates the file if one with the same name already exists.
logging.basicConfig(
    filename=f'alpha_integrate/train/evaluatelogs/evaluate_{date_time_idx}.txt',
    filemode='w',
    level=logging.INFO,
)

from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer

# Upper bound, in seconds, on how long converting a token sequence to a SymPy
# expression may take; expressions that convert more slowly are skipped.
sp_conversion_max_secs = 0.05

def main(args):
    """Measure how often SymPy integrates expressions from a tokenized dataset.

    Reads ``{dataset}_{testorval}.txt``, keeps the expressions that convert to
    SymPy within ``sp_conversion_max_secs`` and are shorter than
    ``block_size // 2`` tokens, then runs ``alpha_integrate.train.run_sympy``
    in a subprocess for each of them and logs the running success rate.

    Args:
        args: Parsed command-line namespace providing ``dataset``,
            ``testorval``, ``timeout``, ``num_expressions`` and ``seed``.

    Returns:
        A list with one dict per evaluated expression holding
        ``'Sympy_Success'`` (bool) and ``'Sympy_Result'`` (the parsed SymPy
        expression, or ``None`` when the expression was not solved).
    """
    tokenizer = ExpressionTokenizer()

    @timeout_decorator.timeout(1)  # Timeout after 1 second
    def seq_to_sp_with_timeout(tok, seq_expr) -> sp.Expr:
        return tok.seq_to_sp(seq_expr)

    dataset = args.dataset
    testorval = args.testorval
    timeout = args.timeout
    num_expressions = args.num_expressions

    random.seed(args.seed)

    PATH = f'alpha_integrate/synthetic_data/final_steps_dataset/{testorval}/{dataset}_{testorval}.txt'

    logging.info(f"Using the {dataset} test dataset to evaluate sympy: {PATH}")

    with io.open(PATH, mode='r', encoding='utf-8') as f:
        lines = f.readlines()

    if num_expressions is not None:
        num_expressions = min(num_expressions, len(lines))
        lines = random.sample(lines, num_expressions)

    exprs = []
    sp_exprs = []
    total_exprs = 0

    for total_exprs, line in enumerate(lines, start=1):
        print(f"Count: {total_exprs}", end='\r')
        expr, _ = line.split('\t')
        expr = expr.split()
        # perf_counter is monotonic, so the elapsed time cannot be distorted
        # by system clock adjustments.
        started = time.perf_counter()
        try:
            sp_expr = seq_to_sp_with_timeout(tokenizer, expr)
        except Exception:
            # Conversion failed or timed out: the expression is not eligible.
            # KeyboardInterrupt / SystemExit deliberately propagate.
            continue
        if time.perf_counter() - started > sp_conversion_max_secs:
            continue
        if len(expr) < block_size // 2:
            exprs.append(expr)
            sp_exprs.append(sp_expr)

    num_exprs = len(exprs)
    # Guard against an empty dataset (or --num_expressions 0).
    eligible_ratio = num_exprs / total_exprs if total_exprs else 0.0
    logging.info(f"{num_exprs} out of {total_exprs} expressions eligible for testing. ({eligible_ratio:.2%})")

    sympy_exp_results = []
    sympy_success = 0
    expression_path = f'alpha_integrate/train/temp_files/expression_temp_{dataset}.txt'
    result_path = f'alpha_integrate/train/temp_files/result_temp_{dataset}.txt'
    script = 'alpha_integrate.train.run_sympy'

    # Equivalent to `python -m <script> --dataset=... --timeout=...`, passed as
    # an argument list so no shell parses the interpolated values.
    cmd = ['python', '-m', script, f'--dataset={dataset}', f'--timeout={timeout}']

    for i, (expr, sp_expr) in enumerate(zip(exprs, sp_exprs)):
        logging.info(f"Expression: {sp_expr}")

        # save the expression into expression_path
        with io.open(expression_path, mode='w', encoding='utf-8') as f:
            f.write(' '.join(expr))

        # Empty the result file first so that a worker which dies before
        # writing cannot leave the previous expression's result to be re-read.
        io.open(result_path, mode='w', encoding='utf-8').close()

        # The worker tries to integrate the expression and is expected to
        # save its outcome into result_path.
        subprocess.run(cmd, capture_output=True, text=True)

        # read the first line of the result from result_path
        with io.open(result_path, mode='r', encoding='utf-8') as f:
            tokens = f.readline().split()

        sp_result = None
        sp_solved = False

        if 'TimeoutError' in tokens:
            logging.info("Timeout Error")
        elif not tokens or 'Error' in tokens:
            # An empty result means the worker wrote nothing at all.
            logging.info("Other Error")
        else:
            try:
                candidate = tokenizer.seq_to_sp(tokens)
            except Exception:
                # A malformed result counts as a failure instead of aborting
                # the whole evaluation run.
                logging.exception("Could not parse the sympy result")
            else:
                # A leftover Integral means SymPy returned the input unevaluated.
                if not candidate.has(sp.Integral):
                    sp_result = candidate
                    sp_solved = True

        if sp_solved:
            sympy_success += 1

        sympy_exp_results.append({
            'Sympy_Success': sp_solved,
            'Sympy_Result': sp_result,
        })

        logging.info(f"Sympy Solved: {sp_solved}")
        logging.info(f"Sympy Solution: {sp_result}")
        logging.info(f"Sympy Success Rate: {(sympy_success/(i+1)):.3%}")

    logging.info(f"Total Expressions: {total_exprs}")

    return sympy_exp_results

if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run the evaluation.
    parser = argparse.ArgumentParser()
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ('--dataset', {'type': str, 'default': 'final_steps', 'help': 'Dataset to evaluate on'}),
        ('--testorval', {'type': str, 'default': 'test', 'help': 'Test or validation set'}),
        ('--timeout', {'type': int, 'default': 5}),
        ('--num_expressions', {'type': int, 'default': None}),
        ('--seed', {'type': int, 'default': 0}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
