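"""
Evaluate a transformer model on randomly generated integrals of the form k*x.

Each eligible expression is written to a temporary file, passed to
alpha_integrate.train.run_transformer_llm in a subprocess, and the result file
it produces is parsed; per-expression outcomes and the running accuracy are
written to a timestamped log file.
"""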
import os
import io
import logging
import timeout_decorator
import time
import argparse
import subprocess
import random
import sympy as sp
# import subprocess TimeoutExpired
from subprocess import TimeoutExpired

date_time_idx = time.strftime("%Y%m%d-%H%M%S")
# basicConfig(filename=...) opens the log file immediately and raises
# FileNotFoundError if its parent directory is missing, so create it first.
_log_dir = 'alpha_integrate/train/evaluatelogs'
os.makedirs(_log_dir, exist_ok=True)
logging.basicConfig(filename=f'{_log_dir}/evaltransformerllm_{date_time_idx}.txt', filemode='w', level=logging.INFO)

from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
from alpha_integrate.synthetic_data.expression_tokenizer import SYMBOLS, INTEGRAL
from alpha_integrate.train.timeout_error import TimeoutErr

# Maximum number of seconds a seq_to_sp conversion may take for an expression
# to be kept as a test case; slower expressions are skipped in main().
seq_to_sp_time = 0.1
# Shared tokenizer used for SymPy <-> token-sequence conversions.
e = ExpressionTokenizer()

@timeout_decorator.timeout(1, timeout_exception=TimeoutErr)
def seq_to_sp_timeout(seq):
    """Convert a token sequence back to a SymPy expression.

    Delegates to the shared tokenizer ``e``; the call is aborted after one
    second, in which case ``TimeoutErr`` is raised.
    """
    converted = e.seq_to_sp(seq)
    return converted

def _select_eligible(lines):
    """Keep the tokenized lines that are usable as test cases.

    Each line is ``"<space-separated tokens>\\t<anything>"``. A line is kept
    when its token list is shorter than ``block_size // 2`` and converts back
    to SymPy via ``seq_to_sp_timeout`` within ``seq_to_sp_time`` seconds.

    Returns:
        (exprs, sp_exprs): parallel lists of token lists and SymPy expressions.
    """
    exprs, sp_exprs = [], []
    for line in lines:
        tokens = line.split('\t')[0].split()
        # Cheap length check first so over-long inputs skip the costly conversion.
        if len(tokens) >= block_size // 2:
            continue
        start = time.time()
        try:
            sp_expr = seq_to_sp_timeout(tokens)
        except (TimeoutErr, Exception):
            # Unparseable or too slow to convert: not a usable test case.
            continue
        if time.time() - start > seq_to_sp_time:
            continue
        exprs.append(tokens)
        sp_exprs.append(sp_expr)
    return exprs, sp_exprs


def _run_transformer(expr_file, model, beam_size, res_file, timeout):
    """Run ``run_transformer_llm`` on ``expr_file`` in a child process.

    The command is passed as an argument list (no shell), so CLI values are
    never interpreted by a shell and, on timeout, the Python child itself is
    killed rather than only an intermediate shell.

    Returns:
        The ``subprocess.CompletedProcess`` (stdout/stderr captured as bytes).
    Raises:
        subprocess.TimeoutExpired: if the child runs longer than ``timeout``
            seconds (``None`` means no limit).
    """
    import sys  # local import: only needed to locate the current interpreter

    cmd = [
        sys.executable, "-m", "alpha_integrate.train.run_transformer_llm",
        f"--expr_file={expr_file}",
        f"--model={model}",
        f"--beam_size={beam_size}",
        f"--res_file={res_file}",
    ]
    logging.info(f"Calling: {' '.join(cmd)}")
    return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)


def _read_result(res_file):
    """Parse the result file written by ``run_transformer_llm``.

    A line containing 'Error' or 'Failure' marks the attempt as failed; a
    line containing 'Success' marks it solved; the first other line is taken
    as the tokenized result and ends parsing.

    Returns:
        (solved, res): ``solved`` is a bool; ``res`` is the list of result
        tokens, or None if no result line was read.
    """
    solved, res = False, None
    with io.open(res_file, mode='r', encoding='utf-8') as f:
        for line in f:
            if 'Error' in line or 'Failure' in line:
                return False, res
            if 'Success' in line:
                solved = True
            else:
                res = line.split()
                break
    return solved, res


def main(args):
    """Evaluate a transformer on 500 random integrands ``k*x`` (k in [1, 50]).

    Expressions that pass ``_select_eligible`` are, one at a time, written to
    a temp file, fed to ``run_transformer_llm`` in a subprocess, and scored
    from the result file it writes. Per-expression outcomes and the running
    accuracy are logged.

    Args:
        args: namespace with ``model``, ``beam_size`` and ``timeout`` (seconds
            per subprocess call, or None for no limit).
    """
    model = args.model
    beam_size = args.beam_size
    timeout = args.timeout

    expr_res_dir = "alpha_integrate/train/temp_files"
    os.makedirs(expr_res_dir, exist_ok=True)
    expr_file = expr_res_dir + f"/expr_{model}_llm5.txt"
    res_file = expr_res_dir + f"/res_{model}_llm5.txt"

    # Create 500 random integrands k*x with k drawn uniformly from [1, Nmax].
    x = sp.Symbol('x')
    Nmax = 50
    expressions = [random.randint(1, Nmax) * x for _ in range(500)]
    lines = [' '.join(e.sp_to_seq(expr)) + '\t ' for expr in expressions]

    exprs, sp_exprs = _select_eligible(lines)
    total_exprs = len(lines)
    num_exprs = len(exprs)
    ratio = num_exprs / total_exprs if total_exprs else 0.0
    logging.info(f"\n{num_exprs} out of {total_exprs} expressions eligible for testing. ({ratio:.2%})")

    correct = 0
    for i, (expr, sp_expr) in enumerate(zip(exprs, sp_exprs), start=1):
        print(f"Processing expression {i}/{num_exprs}", end='\r')
        logging.info(f"Expression {i}: {sp_expr}")

        with io.open(expr_file, mode='w', encoding='utf-8') as f:
            f.write(' '.join(expr))

        # Delete the previous result so a run that fails to write one cannot
        # be scored using a stale file from an earlier expression.
        if os.path.exists(res_file):
            os.remove(res_file)

        # Reset per expression: never inherit the previous iteration's verdict.
        solved, res = False, None
        try:
            proc = _run_transformer(expr_file, model, beam_size, res_file, timeout)
            if proc.returncode != 0:
                stderr_tail = proc.stderr.decode('utf-8', errors='replace').strip()[-500:]
                logging.info(f"run_transformer_llm exited with code {proc.returncode}: {stderr_tail}")
            solved, res = _read_result(res_file)
        except TimeoutExpired:
            logging.info('Timed out')
        except Exception:
            # Missing/unreadable result file or failure to launch the child.
            logging.info('Error')

        if solved:
            correct += 1
            try:
                logging.info(f"Result: {seq_to_sp_timeout(res)}")
            except (TimeoutErr, Exception):
                logging.info(f"Result: {res}")
        else:
            logging.info('Transformer could not solve.')

        logging.info(f"Accuracy of transformer: {correct / i:.3%}")
        logging.info('-----------------')

    print()



if __name__ == '__main__':
    # Command-line entry point; every option defaults to None and is passed
    # through to main() unchanged.
    cli = argparse.ArgumentParser()
    for flag, kind in (('--model', str), ('--beam_size', int), ('--timeout', int)):
        cli.add_argument(flag, type=kind, default=None)
    main(cli.parse_args())