"""
Evaluate a trained transformer model by running run_transformer_treecount on a random sample of dataset expressions.
"""
import os
import io
import logging
import timeout_decorator
import time
import argparse
import subprocess
import random
# import subprocess TimeoutExpired
from subprocess import TimeoutExpired

# Timestamp (YYYYmmdd-HHMMSS) used to give every run its own log file.
date_time_idx = time.strftime("%Y%m%d-%H%M%S")
# logging.basicConfig(filename=...) raises FileNotFoundError when the log
# directory does not exist (e.g. on a fresh checkout), so create it first.
_log_dir = 'alpha_integrate/train/evaluatelogs'
os.makedirs(_log_dir, exist_ok=True)
logging.basicConfig(filename=f'{_log_dir}/treecount_{date_time_idx}.txt', filemode='w', level=logging.INFO)

from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
from alpha_integrate.synthetic_data.expression_tokenizer import SYMBOLS, INTEGRAL
from alpha_integrate.train.timeout_error import TimeoutErr

# Module-wide tokenizer; seq_to_sp_timeout uses its seq_to_sp method to turn
# token sequences into expression objects.
e = ExpressionTokenizer()
# Upper bound, in seconds, on how long one seq_to_sp conversion may take for
# an expression to stay eligible for evaluation (checked in main()).
seq_to_sp_time = 0.1

@timeout_decorator.timeout(seconds=1, timeout_exception=TimeoutErr)
def seq_to_sp_timeout(seq):
    """Convert a token sequence with the module tokenizer ``e``, giving up after 1 second.

    Args:
        seq: list of expression tokens.

    Returns:
        Whatever ``e.seq_to_sp(seq)`` returns (presumably a SymPy expression,
        judging by the name -- TODO confirm).

    Raises:
        TimeoutErr: if the conversion runs longer than 1 second.
    """
    converted = e.seq_to_sp(seq)
    return converted

def _sample_lines(path, n_samples, seed=19012002):
    """Return a reproducible random sample of at most ``n_samples`` lines of ``path``.

    A private ``random.Random(seed)`` is used so the global RNG is not reseeded;
    for the same seed it draws exactly the same sample as
    ``random.seed(seed); random.sample(lines, n_samples)``. When the file has
    fewer than ``n_samples`` lines, every line is returned (in shuffled order)
    instead of ``random.sample`` raising ``ValueError``.
    """
    with io.open(path, mode='r', encoding='utf-8') as f:
        lines = f.readlines()
    return random.Random(seed).sample(lines, min(n_samples, len(lines)))


def _eligible_expressions(lines):
    """Select the dataset expressions that are eligible for evaluation.

    Each line must be the space-separated expression tokens, a single tab
    character, then the expected result; a line without exactly one tab raises
    ``ValueError``. An expression is eligible when it has fewer than
    ``block_size // 2`` tokens and ``seq_to_sp_timeout`` converts it without
    raising and within ``seq_to_sp_time`` seconds.

    Returns:
        (exprs, sp_exprs): the eligible token lists and their converted forms,
        in the same order.
    """
    exprs, sp_exprs = [], []
    for line in lines:
        expr, _res = line.split('\t')
        expr = expr.split()
        # Cheap length check first so over-long expressions skip the
        # (potentially slow) conversion; the selected set is unchanged.
        if len(expr) >= block_size // 2:
            continue
        start = time.perf_counter()
        try:
            sp_expr = seq_to_sp_timeout(expr)
        except (TimeoutErr, Exception):
            # Conversion failed or hit the 1 s hard limit: not eligible.
            continue
        if time.perf_counter() - start > seq_to_sp_time:
            continue
        exprs.append(expr)
        sp_exprs.append(sp_expr)
    return exprs, sp_exprs


def main(args):
    """Run the transformer on a fixed random sample of dataset expressions.

    Samples up to 5000 lines of
    ``alpha_integrate/synthetic_data/final_steps_dataset/<testorval>/<dataset>_<testorval>.txt``,
    keeps the eligible ones (see ``_eligible_expressions``) and, for each one,
    writes its tokens to a temp file and runs
    ``alpha_integrate.train.run_transformer_treecount`` on it in a subprocess.
    This function itself only logs progress, timeouts and failures.

    Args:
        args: parsed CLI namespace with ``model``, ``beam_size``, ``dataset``,
            ``testorval`` and ``timeout`` (seconds per expression; ``None``
            means no limit).
    """
    import sys  # only needed to launch the child with this same interpreter

    model = args.model
    beam_size = args.beam_size
    dataset = args.dataset
    testorval = args.testorval
    timeout = args.timeout

    @timeout_decorator.timeout(timeout, timeout_exception=TimeoutErr)
    def call_run_transformer(expr_file, model, beam_size, timeout):
        # Argument list with shell=False: no shell quoting/injection issues
        # with paths or names, and on timeout subprocess.run kills the Python
        # child itself rather than an intermediate shell. sys.executable runs
        # the child under the same interpreter/environment as this script.
        cmd = [
            sys.executable, "-m", "alpha_integrate.train.run_transformer_treecount",
            f"--expr_file={expr_file}",
            f"--model={model}",
            f"--beam_size={beam_size}",
        ]
        logging.info(f"Calling: {' '.join(cmd)}")
        # Child output is captured (kept off the console); stderr is logged on failure.
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
        if result.returncode != 0:
            logging.warning(
                "run_transformer_treecount exited with code %d: %s",
                result.returncode,
                result.stderr.decode('utf-8', errors='replace').strip(),
            )
        return result

    expr_res_dir = "alpha_integrate/train/temp_files"
    os.makedirs(expr_res_dir, exist_ok=True)
    expr_file = expr_res_dir + f"/expr_{model}_{dataset}_{testorval}.txt"

    # Nothing is written here by this script; presumably the child stores its
    # results in it -- TODO confirm.
    save_dir = f"alpha_integrate/train/integrals/{model}/{dataset}/{testorval}"
    os.makedirs(save_dir, exist_ok=True)

    dataset_path = f'alpha_integrate/synthetic_data/final_steps_dataset/{testorval}/{dataset}_{testorval}.txt'
    n_samples = 5000
    logging.info(f"Testing model {model} using test dataset: {dataset_path}")
    lines = _sample_lines(dataset_path, n_samples)

    total_exprs = len(lines)
    exprs, sp_exprs = _eligible_expressions(lines)
    num_exprs = len(exprs)
    ratio = num_exprs / total_exprs if total_exprs else 0.0
    logging.info(f"\n{num_exprs} out of {total_exprs} expressions eligible for testing. ({ratio:.2%})")

    for i, (expr, sp_expr) in enumerate(zip(exprs, sp_exprs), start=1):
        print(f"Processing expression {i}/{num_exprs}", end='\r')
        logging.info(f"Expression {i}: {sp_expr}")

        # Hand the expression to the child via --expr_file (overwritten each time).
        with io.open(expr_file, mode='w', encoding='utf-8') as f:
            f.write(' '.join(expr))

        try:
            # NOTE: pass ``timeout`` positionally; timeout_decorator treats a
            # ``timeout=`` keyword argument as a per-call override of its limit.
            call_run_transformer(expr_file, model, beam_size, timeout)
        except (TimeoutErr, TimeoutExpired):
            logging.info('Timed out')
        except Exception:
            # Not a bare ``except``: Ctrl-C must still be able to stop the run.
            logging.exception('Error')

    print()



if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate a transformer model on a random sample of dataset expressions.")
    # model/dataset/testorval are interpolated into file paths; without them the
    # script would only fail later trying to open '.../None/None_None.txt'.
    parser.add_argument('--model', type=str, required=True,
                        help="model name, forwarded to run_transformer_treecount")
    parser.add_argument('--beam_size', type=int, default=None,
                        help="beam size forwarded to run_transformer_treecount")
    parser.add_argument('--dataset', type=str, required=True,
                        help="dataset name (file <dataset>_<testorval>.txt)")
    parser.add_argument('--testorval', type=str, required=True,
                        help="dataset split directory, e.g. 'test' or 'val'")
    parser.add_argument('--timeout', type=float, default=None,
                        help="per-expression time limit in seconds (no limit if omitted)")

    args = parser.parse_args()
    main(args)