"""
Evaluate transformer model
"""
import argparse
import io
import logging
import os
import subprocess
import sys
import time
# import subprocess TimeoutExpired
from subprocess import TimeoutExpired

import timeout_decorator

# Timestamp that gives every evaluation run its own log file.
date_time_idx = time.strftime("%Y%m%d-%H%M%S")
_EVAL_LOG_DIR = 'alpha_integrate/train/evaluatelogs'
# logging.basicConfig(filename=...) does not create missing parent directories;
# without this the script dies at import time with FileNotFoundError on a
# fresh checkout.
os.makedirs(_EVAL_LOG_DIR, exist_ok=True)
logging.basicConfig(filename=f'{_EVAL_LOG_DIR}/evaluate_{date_time_idx}.txt', filemode='w', level=logging.INFO)

from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
from alpha_integrate.synthetic_data.expression_tokenizer import SYMBOLS, INTEGRAL
from alpha_integrate.train.timeout_error import TimeoutErr

# Upper bound, in seconds, on the time allowed to convert one dataset entry
# (expression plus its reference result); slower entries are skipped.
seq_to_sp_time = 0.1
# Single shared tokenizer, used by seq_to_sp_timeout for every conversion.
e = ExpressionTokenizer()

@timeout_decorator.timeout(seconds=1, timeout_exception=TimeoutErr)
def seq_to_sp_timeout(seq):
    """Convert token sequence ``seq`` with the shared tokenizer ``e``.

    Raises ``TimeoutErr`` when the conversion runs for longer than one second.
    """
    converted = e.seq_to_sp(seq)
    return converted

def _load_eligible_expressions(path):
    """Read a dataset file and return the expressions eligible for evaluation.

    Each non-blank line of *path* must hold two whitespace-tokenised fields
    separated by a single tab: the expression and its reference result.
    Malformed lines are logged and skipped. An expression is kept when:

    * its token list is shorter than ``block_size // 2``, and
    * both fields convert through ``seq_to_sp_timeout`` without raising, and
    * both conversions together take at most ``seq_to_sp_time`` seconds.

    Returns:
        ``(exprs, sp_exprs, sp_results, total_exprs)``. ``exprs`` holds the
        token lists of kept expressions. ``sp_exprs``/``sp_results`` hold the
        converted expression/result for each kept entry. ``total_exprs``
        counts the non-blank lines examined.
    """
    exprs = []
    sp_exprs = []
    sp_results = []
    total_exprs = 0
    with io.open(path, mode='r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Blank (e.g. trailing) lines are not dataset entries.
                continue
            total_exprs += 1
            fields = line.split('\t')
            if len(fields) != 2:
                logging.info(f"Skipping malformed dataset line {total_exprs}: {line!r}")
                continue
            expr = fields[0].split()
            res = fields[1].split()
            # Cheap length filter first, so over-long expressions are never converted.
            if len(expr) >= block_size // 2:
                continue
            t1 = time.time()
            try:
                sp_expr = seq_to_sp_timeout(expr)
                sp_res = seq_to_sp_timeout(res)
            except (TimeoutErr, Exception):
                # Unconvertible or too slow: not eligible for testing.
                continue
            if time.time() - t1 > seq_to_sp_time:
                continue
            exprs.append(expr)
            sp_exprs.append(sp_expr)
            sp_results.append(sp_res)
    return exprs, sp_exprs, sp_results, total_exprs


def _read_transformer_result(res_file):
    """Parse the status file written by ``run_transformer``.

    Lines are scanned in order:

    * a line containing 'Error' or 'Failure' means the attempt failed;
    * a line containing 'Success' marks it solved;
    * the first other line is taken as the predicted result, as a token list,
      and ends parsing.

    Returns:
        ``(solved, prediction)``; ``prediction`` is None when no result line
        was read.
    """
    solved = False
    prediction = None
    with io.open(res_file, mode='r', encoding='utf-8') as f:
        for x in f:
            if 'Error' in x or 'Failure' in x:
                return False, None
            if 'Success' in x:
                solved = True
            else:
                prediction = x.split()
                break
    return solved, prediction


def main(args):
    """Evaluate ``args.model`` on one dataset split and log per-expression results.

    For every eligible expression, the script:

    1. writes the tokens to a temporary file;
    2. runs ``alpha_integrate.train.run_transformer`` in a subprocess, bounded
       by ``args.timeout`` seconds;
    3. reads the outcome back from the subprocess's result file.

    The running accuracy is logged after every expression.
    """
    model = args.model
    beam_size = args.beam_size
    dataset = args.dataset
    testorval = args.testorval
    timeout = args.timeout

    @timeout_decorator.timeout(timeout, timeout_exception=TimeoutErr)
    def call_run_transformer(expr_file, model, beam_size, res_file, save_file, timeout):
        """Run run_transformer on ``expr_file``; it reports through ``res_file``."""
        # Use an argument list rather than a shell string. On timeout the actual
        # Python child is killed, not just an intermediate /bin/sh. Paths and
        # model names are also never interpreted by a shell.
        cmd = [
            sys.executable, "-m", "alpha_integrate.train.run_transformer",
            f"--expr_file={expr_file}",
            f"--model={model}",
            f"--beam_size={beam_size}",
            f"--res_file={res_file}",
            f"--save_file={save_file}",
        ]
        logging.info(f"Calling: {' '.join(cmd)}")
        # Output is captured rather than echoed; stderr is only surfaced on failure.
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
        if result.returncode != 0:
            stderr_text = result.stderr.decode('utf-8', errors='replace').strip()
            logging.info(f"run_transformer exited with code {result.returncode}: {stderr_text}")

    expr_res_dir = "alpha_integrate/train/temp_files"
    os.makedirs(expr_res_dir, exist_ok=True)

    expr_file = expr_res_dir + f"/expr_{model}_{dataset}_{testorval}.txt"
    res_file = expr_res_dir + f"/res_{model}_{dataset}_{testorval}.txt"

    save_dir = f"alpha_integrate/train/integrals/{model}/{dataset}/{testorval}"
    os.makedirs(save_dir, exist_ok=True)

    dataset_path = f'alpha_integrate/synthetic_data/final_steps_dataset/{testorval}/{dataset}_{testorval}.txt'
    logging.info(f"Testing model {model} using test dataset: {dataset_path}")
    exprs, sp_exprs, sp_results, total_exprs = _load_eligible_expressions(dataset_path)

    num_exprs = len(exprs)
    # Guard against an empty dataset file (would otherwise divide by zero).
    eligible_ratio = num_exprs / total_exprs if total_exprs else 0.0
    logging.info(f"\n{num_exprs} out of {total_exprs} expressions eligible for testing. ({eligible_ratio:.2%})")

    correct = 0
    for i, (expr, sp_expr, sp_result) in enumerate(zip(exprs, sp_exprs, sp_results)):
        logging.info(f"Expression {i+1}: {sp_expr}")

        # Save the expression where run_transformer will read it.
        with io.open(expr_file, mode='w', encoding='utf-8') as f:
            f.write(' '.join(expr))

        save_file = save_dir + f"/integral_{i}.pkl"

        # res_file is shared by all iterations. Delete the previous one so a
        # crashed or killed run cannot be scored from stale output.
        try:
            os.remove(res_file)
        except FileNotFoundError:
            pass

        # Reset per expression so no verdict/prediction leaks from a previous one.
        solved = False
        prediction = None
        try:
            # NOTE(review): keep ``timeout`` positional. timeout_decorator appears
            # to intercept a ``timeout=`` keyword as its own override — confirm
            # against the library docs before changing this call.
            call_run_transformer(expr_file, model, beam_size, res_file, save_file, timeout)
            solved, prediction = _read_transformer_result(res_file)
        except (TimeoutErr, TimeoutExpired):
            logging.info('Timed out')
        except Exception:
            # Includes a missing res_file when the subprocess failed to write it.
            logging.info('Error', exc_info=True)

        if solved:
            correct += 1
            try:
                logging.info(f"Result: {seq_to_sp_timeout(prediction)}")
            except (TimeoutErr, Exception):
                logging.info(f"Result: {prediction}")
            logging.info(f"Solution: {sp_result}")
        else:
            logging.info('Transformer could not solve.')

        logging.info(f"Accuracy of transformer: {(correct/(i+1)):.3%}")
        logging.info('-----------------')
        



if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate a transformer model on an integration dataset split.")
    # model/dataset/testorval are all interpolated into file paths. Without them
    # the run fails later on a '.../None/None_None.txt' path, so require them.
    parser.add_argument('--model', type=str, required=True,
                        help="Model name forwarded to run_transformer and used in output paths.")
    parser.add_argument('--beam_size', type=int, default=None,
                        help="Beam size forwarded to run_transformer.")
    parser.add_argument('--dataset', type=str, required=True,
                        help="Dataset name: the <dataset> part of <dataset>_<testorval>.txt.")
    parser.add_argument('--testorval', type=str, required=True,
                        help="Dataset split sub-directory, e.g. test or val.")
    # float so fractional limits are accepted; integer values still parse as before.
    parser.add_argument('--timeout', type=float, default=None,
                        help="Per-expression time limit in seconds (default: None, i.e. no limit).")

    args = parser.parse_args()
    main(args)