import os
import io
import numpy as np
import sympy as sp
import torch
import signal
from tqdm import tqdm
import sys
sys.path.insert(0, 'SymbolicMathematics')

from src.utils import AttrDict
from src.envs import build_env
from src.model import build_modules

from src.utils import to_cuda
from src.envs.sympy_utils import simplify

class TimeoutException(Exception):
    """Raised from the SIGALRM handler when a timed section runs past its deadline."""

def timeout_handler(signum, frame):
    """Signal handler that interrupts the running code with TimeoutException.

    Args:
        signum: Number of the signal being delivered (not used).
        frame: Stack frame that was executing when the signal arrived (not used).

    Raises:
        TimeoutException: Always.
    """
    raise TimeoutException()

# Route SIGALRM to timeout_handler so every signal.alarm(...) deadline used
# below surfaces as a TimeoutException.
signal.signal(signal.SIGALRM, timeout_handler)

# Trained model checkpoint. The one loaded here can be fetched with
#   wget https://dl.fbaipublicfiles.com/SymbolicMathematics/models/fwd.pth
# (other checkpoints, e.g. fwd_bwd.pth, are published under the same URL prefix).
model_path = './fwd.pth'
# Explicit check instead of `assert`: assertions are stripped under `python -O`,
# which would defer the failure to model loading with a less useful message.
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

# environment parameters
_env_settings = {
    'env_name': 'char_sp_lc',
    'int_base': 10,
    'balanced': False,
    'positive': True,
    'precision': 10,
    'n_variables': 1,
    'n_coefficients': 0,
    'leaf_probs': '0.75,0,0.25,0',
    'max_len': 512,
    'max_int': 5,
    'max_ops': 15,
    'max_ops_G': 15,
    'clean_prefix_expr': True,
    'rewrite_functions': '',
    'tasks': 'prim_fwd',
    'operators': 'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',
}

# model parameters
_model_settings = {
    'cpu': False,
    'emb_dim': 1024,
    'n_enc_layers': 6,
    'n_dec_layers': 6,
    'n_heads': 8,
    'dropout': 0,
    'attention_dropout': 0,
    'sinusoidal_embeddings': False,
    'share_inout_emb': True,
    'reload_model': model_path,
}

# Single parameter object handed to both build_env() and build_modules();
# key order is the same as listing every entry in one literal.
params = AttrDict({**_env_settings, **_model_settings})

# Build the environment and keep its 'x' entry, which is later used as the
# differentiation variable when checking predicted integrals.
env = build_env(params)
x = env.local_dict['x']

# Build the modules for this environment and pull out the two that are used.
# NOTE(review): presumably this is where params['reload_model'] is loaded — confirm in src.model.
modules = build_modules(env, params)
encoder, decoder = modules['encoder'], modules['decoder']

# Load the evaluation set. Each line is split on '|' into exactly two fields;
# the second field is split on a tab into the pair stored in `data`.
path = 'prim_bwd.test'
train = False
# Defined up front so later code never sees an unbound `data` when no path is set.
data = []
if path is not None:
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset file not found: {path}")
    with io.open(path, mode='r', encoding='utf-8') as f:
        # either reload the entire file, or the first N lines (for the training set)
        if not train:
            lines = [line.rstrip().split('|') for line in f]
        else:
            lines = []
            for i, line in tqdm(enumerate(f)):
                if i == 10000000:
                    break
                lines.append(line.rstrip().split('|'))
    # A line without exactly two '|' fields raises ValueError here (unchanged).
    data = [xy.split('\t') for _, xy in lines]
    # Keep only well-formed two-element pairs. Positions in `data` are used as
    # sample indices later, so this filter determines the indexing.
    data = [xy for xy in data if len(xy) == 2]

# Evaluation loop.
#
# For every dataset index listed in the .npy file: encode the input, decode a
# candidate with beam search, differentiate the candidate with SymPy and count
# the sample as correct when (reference - derivative) simplifies to zero or to
# a constant. Every potentially slow step runs under a SIGALRM deadline.
bwd_test_set_ind = np.load("bwd_test_set_ind_neurips_data_2024.npy")
# Membership is tested once per dataset row; `i in ndarray` rescans the whole
# array each time, so build a set of the selected indices once.
selected_indices = set(np.asarray(bwd_test_set_ind).ravel().tolist())


class _TimeLimit:
    """Context manager: arm SIGALRM for `seconds` on entry, always disarm on exit.

    Disarming in __exit__ guarantees the alarm is off before any surrounding
    `except` handler runs, so a handler cannot itself be interrupted.
    """

    def __init__(self, seconds):
        self.seconds = seconds

    def __enter__(self):
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc, tb):
        signal.alarm(0)
        return False  # never swallow the exception


def _report_progress():
    """Show the running accuracy on the progress bar (requires corr_ind > 0)."""
    pbar.set_description(f"Accuracy {pos_count*100/corr_ind}%, Pos Count {pos_count}, Total {corr_ind}")


beam_size = 1
pos_count = 0  # samples whose decoded integral was verified
corr_ind = 0   # samples attempted (accuracy denominator)
pbar = tqdm(range(0, len(data)))
for i in pbar:
    # Kept from the original: re-install the handler on every iteration.
    # NOTE(review): presumably a guard against a callee replacing it — confirm.
    signal.signal(signal.SIGALRM, timeout_handler)
    if i not in selected_indices:
        continue

    # Encode the input token sequence, wrapped in EOS markers, as a column tensor.
    x1_prefix = data[i][0].split(" ")
    x1 = torch.LongTensor(
        [env.eos_index] +
        [env.word2id[w] for w in x1_prefix] +
        [env.eos_index]
    ).view(-1, 1)
    len1 = torch.LongTensor([len(x1)])
    x1, len1 = to_cuda(x1, len1)

    with torch.no_grad():
        encoded = encoder('fwd', x=x1, lengths=len1, causal=False).transpose(0, 1)

    corr_ind += 1

    # Reference function: the input prefix minus its first two tokens.
    # NOTE(review): presumably those two tokens are the task wrapper — confirm against the data format.
    try:
        with _TimeLimit(30):
            hyp_func = env.prefix_to_infix(x1_prefix[2:])
            hyp_func = env.infix_to_sympy(hyp_func).expand()
    except Exception:
        # Without a reference there is nothing to verify against. Skip the
        # sample (it still counts as attempted) rather than reuse the stale
        # hyp_func left over from a previous iteration.
        print(i)
        print("prefix infix sympy conversion went wrong")
        _report_progress()
        continue

    try:
        with _TimeLimit(2700), torch.no_grad():
            _, _, beam = decoder.generate_beam(
                encoded, len1, beam_size=beam_size, length_penalty=1.0,
                early_stopping=1, max_len=200)
    except TimeoutException:
        print(i)
        print("function", hyp_func)
        print("------------------------------------------timeout timeout---------------------------------------")
        _report_progress()
        continue
    assert len(beam) == 1
    hypotheses = beam[0].hyp
    assert len(hypotheses) == beam_size

    # Try hypotheses from best score to worst.
    for _score, sent in sorted(hypotheses, key=lambda h: h[0], reverse=True):

        # parse decoded hypothesis
        ids = sent[1:].tolist()                  # decoded token IDs
        tok = [env.id2word[wid] for wid in ids]  # convert to prefix

        # Stage 1: prefix -> infix -> SymPy, then differentiate w.r.t. x.
        try:
            with _TimeLimit(30):
                hyp = env.prefix_to_infix(tok)       # convert to infix
                hyp = env.infix_to_sympy(hyp)        # convert to SymPy
                hyp_df = sp.diff(hyp, x)
        except Exception as e:
            print(i)
            print(e)
            print("function", hyp_func)
            print("integral", tok)
            print("------------------------------------------wrong wrong---------------------------------------")
            continue

        # Stage 2: simplify the residual between reference and derivative.
        try:
            with _TimeLimit(30):
                simp_ic = sp.simplify(hyp_func - hyp_df)
        except TimeoutException as tex:
            print(tex)
            print("sym", hyp_func)
            print("integral", hyp)
            print("derivative", hyp_df)
            print("---------------------------------------timeout no simp-------------------------------------------------")
            continue
        except Exception as ex:
            print(ex)
            print("sym", hyp_func)
            print("integral", hyp)
            print("derivative", hyp_df)
            print("---------------------------------------no simp-------------------------------------------------")
            continue

        # Stage 3: accept a zero residual, a constant residual, or a residual
        # that mentions both log(x) and log(1/x) (heuristic kept as-is).
        try:
            with _TimeLimit(30):
                is_correct = (
                    simp_ic == 0
                    or ('log(x)' in str(simp_ic) and 'log(1/x)' in str(simp_ic))
                    or sp.simplify(simp_ic).is_constant()
                )
        except Exception:
            print(i)
            print("function", hyp_func)
            print("integral", hyp)
            print("simp_ic", simp_ic)
            print("------------------------------------------issue while final evaluation---------------------------------------")
            continue

        if is_correct:
            pos_count += 1
            # One verified hypothesis is enough; stopping here keeps a sample
            # from being counted twice if beam_size is ever raised above 1.
            break

        print(i)
        print("function", hyp_func)
        print("integral", hyp)
        print("simp_ic", simp_ic)
        print("------------------------------------------wrong wrong---------------------------------------")

    print("================================================================================================")

    _report_progress()
print("pos_count", pos_count)
print("corr_ind", corr_ind)