#from sympy.integrals.manualintegrate import manualintegrate, _manualintegrate, heaviside_pattern

# Rules
#from sympy.integrals.manualintegrate import manual_subs, contains_dont_know, IntegralInfo, ExpRule, AddRule, ConstantRule, ConstantTimesRule, ReciprocalRule, PowerRule, PiecewiseRule, DerivativeRule, DontKnowRule, RewriteRule, TrigRule, URule, ArctanRule, ArccothRule, ArctanhRule, EiRule, CiRule, ChiRule, SiRule, ShiRule, LiRule, ErfRule, FresnelSRule, FresnelCRule, UpperGammaRule, PolylogRule, EllipticFRule, EllipticERule, InverseHyperbolicRule, ArcsinRule, JacobiRule, GegenbauerRule, ChebyshevTRule, ChebyshevURule, LegendreRule, HermiteRule, LaguerreRule, AssocLaguerreRule

#from sympy.core.logic import fuzzy_not
from sympy.abc import a, x, y
#from manualintegrate import *
#from manualintegrate import _manualintegrate
import sympy as sp
from sympy import *
from sympy.abc import a, x, u
from sympy import Integral, cos, sqrt, Poly, erfi
from sympy import Symbol, exp, log, tan, sin, pi, diff
from sympy.parsing.sympy_parser import parse_expr

import pandas as pd
import numpy as np
from tqdm import tqdm
import re
import random
import csv
import gc
import signal

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
import torch.optim as optim
from transformers import AdamW, get_linear_schedule_with_warmup

#from sympy.integrals.manualintegrate import (
#    _manualintegrate, integral_steps, evaluates,
#    ConstantRule, ConstantTimesRule, PowerRule, AddRule, URule,
#    PartsRule, CyclicPartsRule, TrigRule, ExpRule, ReciprocalRule, ArctanRule,
#    AlternativeRule, DontKnowRule, RewriteRule
#)
import sys
sys.path.insert(0, 'SymbolicMathematics')
import os
import torch
from src.envs.char_sp import CharSPEnvironment

from src.envs.sympy_utils import remove_root_constant_terms, reduce_coefficients, reindex_coefficients
from src.utils import AttrDict
from src.envs import build_env
from src.model import build_modules

from src.utils import to_cuda
from src.envs.sympy_utils import simplify
from sympy.functions.elementary.trigonometric import TrigonometricFunction

from model_var import *

# Configuration handed to build_env() to construct the expression environment
# used below for prefix -> infix -> SymPy conversion.
# NOTE(review): 'reload_data' / 'tasks' name the prim_fwd split while the
# script further down reads 'prim_bwd.test' directly -- confirm this is intended.
params = AttrDict(dict(
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=384,
    max_int=11,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    reload_data="prim_fwd.test",
    rewrite_functions='',
    tasks='prim_fwd',
    # "name:weight" pairs, split over several literals for readability; the
    # concatenated value is one comma-separated string.
    operators=(
        'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,'
        'erfi:1,erf:1,erfc:1,erfinv:1,erfcinv:1,expint:1,ei:1,li:1,si:1,ci:1,shi:1,chi:1,'
        'fresnelc:1,fresnels:1,'
        'sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,'
        'sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1'
    ),
))

env = build_env(params)
# Rebind x to the environment's own symbol; this replaces the `x` imported
# from sympy.abc at the top of the file.
x = env.local_dict['x']

# Token vocabulary. `words` is the concatenation, in this order, of: special
# tokens, constants, variables, sorted operator names, symbols, signed digit
# tokens, integration-rule names and four extra single-letter tokens.
SPECIAL_WORDS = ['<pad>', '<s>', '</s>']
constants = ['pi', 'E']
variables = ['x', 'y', 'z', 't', 'u', '_u', '_theta']
symbols = ['I', 'INT+', 'INT-', 'INT', 'FLOAT', '-', '.', '10^']

# Digit tokens cover the integers in [max_digit - |int_base|, max_digit).
# NOTE(review): int_base here is 60 while params['int_base'] is 10 -- confirm
# that the two are meant to differ.
int_base = 60
max_digit = (int_base + 1) // 2
elements = list(map(str, range(max_digit - abs(int_base), max_digit)))

# Operator name -> integer (presumably the operator's arity -- TODO confirm).
# Insertion order is the same as in the original flat literal.
OPERATORS = {
    # Elementary functions
    **dict.fromkeys(('add', 'sub', 'mul', 'div', 'pow', 'rac'), 2),
    **dict.fromkeys(
        ('inv', 'pow2', 'pow3', 'pow4', 'pow5', 'sqrt', 'exp', 'ln', 'abs', 'sign'), 1),
    # Trigonometric Functions
    **dict.fromkeys(('sin', 'cos', 'tan', 'cot', 'sec', 'csc'), 1),
    # Trigonometric Inverses
    **dict.fromkeys(('asin', 'acos', 'atan', 'acot', 'asec', 'acsc'), 1),
    # Hyperbolic Functions
    **dict.fromkeys(('sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch'), 1),
    # Hyperbolic Inverses
    **dict.fromkeys(('asinh', 'acosh', 'atanh', 'acoth', 'asech', 'acsch'), 1),
    # Error / exponential-integral / Fresnel style functions
    **dict.fromkeys(
        ('erfi', 'erf', 'erfc', 'erfinv', 'erfcinv', 'expint', 'ei', 'li',
         'si', 'ci', 'shi', 'chi', 'fresnelc', 'fresnels'), 1),
    # Derivative
    'derivative': 2,
    # custom functions
    'f': 1,
    'g': 2,
    'h': 3,
}

# Earlier rule list, kept for reference only (not used):
# rules = ['add_rule', 'constant_rule', 'dont_know_rule', 'exp_rule',
#          'inverse_trig_rule', 'mul_rule', 'parts_rule', 'power_rule',
#          'quadratic_denom_rule', 'root_mul_rule', 'special_function_rule',
#          'steps_rule', 'substitution_rule', 'trig_expand_rule',
#          'trig_powers_products_rule', 'trig_product_rule', 'trig_rule',
#          'trig_substitution_rule']
rules = [
    'power_rule',
    'mul_rule',
    'add_rule',
    'partial_fractions_rule',
    'exp_rule',
    'trig_rule',
    'substitution_rule',
    'parts_rule',
    'constant_rule',
    'quadratic_denom_rule',
    'cancel_rule',
    'distribute_expand_rule',
    'sqrt_linear_rule',
    'trig_sindouble_rule',
    'trig_expand_rule',
    'sqrt_quadratic_rule',
    'hyperbolic_rule',
    'trig_tansec_rule',
    'trig_sincos_rule',
    'inverse_trig_rule',
    'trig_cotcsc_rule',
    'trig_substitution_rule',
    'special_function_rule',
    'trig_product_rule',
]

words = (
    SPECIAL_WORDS
    + constants
    + variables
    + sorted(OPERATORS)
    + symbols
    + elements
    + rules
    + ['e', 'f', 'g', 'h']
)
# 'f', 'g' and 'h' occur twice in `words` (as operators and as trailing
# letters); word2id therefore maps them to their last position.
id2word = dict(enumerate(words))
word2id = {word: idx for idx, word in enumerate(words)}

from manualintegrate_model import integral_steps_model

from manualintegrate_model import _manualintegrate as _manualintegrate_model

#from manualintegrate_orig import integral_steps

#from manualintegrate_orig import _manualintegrate

import io
import time

import core
#import core_orig

class TimeoutExceptionSym(Exception):
    """Raised from the SIGALRM handler when an alarm armed with signal.alarm() expires."""


def timeout_handler(signum, frame):
    """SIGALRM handler: interrupt whatever is running by raising TimeoutExceptionSym.

    Both arguments (signal number and current stack frame) are supplied by the
    signal machinery and are ignored here.
    """
    raise TimeoutExceptionSym()


# Install the handler once at import time; the evaluation loop arms and
# disarms the actual alarms.
signal.signal(signal.SIGALRM, timeout_handler)

# Load the problem file. Each line is split on '|'; the second field is split
# on a tab and only entries that yield exactly two tab-separated parts are
# kept in `data`.
path = 'prim_bwd.test'
train = False
if path is not None:
    assert os.path.isfile(path)
#    logger.info(f"Loading data from {path} ...")
    with io.open(path, mode='r', encoding='utf-8') as f:
        if train:
            # training mode: only the first 10,000,000 lines are read
            lines = []
            for line_no, raw_line in tqdm(enumerate(f)):
                if line_no == 10000000:
                    break
                lines.append(raw_line.rstrip().split('|'))
        else:
            # evaluation mode: read the whole file
            lines = [raw_line.rstrip().split('|') for raw_line in f]
    # Every entry of `lines` must have exactly two '|'-separated fields,
    # otherwise the unpacking below raises ValueError.
    data = [
        pair
        for pair in (xy.split('\t') for _, xy in lines)
        if len(pair) == 2
    ]
#    logger.info(f"Loaded {len(data)} equations from the disk.")

## integral_steps_model
#
# Evaluation driver. For every loaded problem it:
#   1. converts the prefix tokens to a SymPy expression (`sym_inf`),
#   2. runs integral_steps_model / _manualintegrate_model under a SIGALRM
#      time budget to obtain a candidate result `ral`,
#   3. differentiates `ral` and compares it with `sym_inf`,
#   4. appends (index, elapsed seconds, explored nodes) to the running result
#      lists and checkpoints them to .npy files.
# Problems whose conversion fails, which contain sinh/cosh/tanh, or whose
# verification step errors out or times out are skipped without being counted.

# Time budgets in seconds, enforced through signal.alarm().
SOLVE_TIMEOUT_S = 2700  # rule search plus construction of the result
CHECK_TIMEOUT_S = 30    # each verification step (diff / simplify / constant check)

pos_count = 0   # problems that passed the verification check
corr_ind = 0    # problems counted as attempted (denominator of the accuracy)
time_taken = 0
pbar = tqdm(range(0, len(data)))
index_file_name = 'bwd_test_set_ind_neurips_data_2024.npy'
nodes_file_name = 'bwd_test_set_model_nodes_neurips_data_2024.npy'
time_file_name = 'bwd_test_set_model_time_neurips_data_2024.npy'
if (os.path.isfile(index_file_name) and os.path.isfile(nodes_file_name) and os.path.isfile(time_file_name)):
    print("recovering old files")
    # NOTE(review): previously saved results are reloaded, but the loop below
    # still starts at index 0 and does not skip indices that are already
    # present, so a restarted run appends duplicates. pos_count / corr_ind are
    # not restored either. Confirm the intended resume semantics.
    # NOTE(review): checkpoints written by earlier versions of this script
    # stored elapsed time modulo 60; new entries store the full elapsed time.
    bwd_test_set_ind = np.load(index_file_name).tolist()
    bwd_test_set_model_nodes = np.load(nodes_file_name).tolist()
    bwd_test_set_model_time = np.load(time_file_name).tolist()
else:
    bwd_test_set_ind = []
    bwd_test_set_model_nodes = []
    bwd_test_set_model_time = []


def _save_results():
    """Write the three running result lists to their .npy checkpoint files."""
    np.save(index_file_name, bwd_test_set_ind)
    np.save(time_file_name, bwd_test_set_model_time)
    np.save(nodes_file_name, bwd_test_set_model_nodes)


def _record_result(index, elapsed):
    """Append one attempted problem (index, elapsed seconds, node count) and checkpoint."""
    bwd_test_set_ind.append(index)
    bwd_test_set_model_time.append(elapsed)
    bwd_test_set_model_nodes.append(core.nodes_config.nodes_explored)
    _save_results()


for i in pbar:
    # Re-install the handler on every iteration, as the original script did.
    signal.signal(signal.SIGALRM, timeout_handler)
    if corr_ind != 0:
        pbar.set_description(f"Accuracy {pos_count*100/corr_ind}%, Pos Count {pos_count}, Total {corr_ind}")
    #if i not in bwd_test_set_ind:
    #    continue

    # The first two tokens of the problem field are dropped before use.
    tokens = data[i][0].split(" ")[2:]
    linf_replace = [s.replace('_u', 'x') for s in tokens]
    hyp_inf = env.prefix_to_infix(linf_replace)
    try:
        sym_inf = env.infix_to_sympy(hyp_inf).expand()
        sym_str = str(sym_inf)
        # Expressions containing these hyperbolic functions are skipped.
        if any(name in sym_str for name in ('sinh', 'cosh', 'tanh')):
            continue
    except Exception as ex:
        print(ex)
        print("infix to sympy issue")
        print("-------------------------------------------------------------------")
        continue

    # --- Step 1: rule search and construction of the candidate result ------
    ral = None
    gave_up = False
    try:
        core.nodes_config.nodes_explored = 0
        start_time = time.time()
        signal.alarm(SOLVE_TIMEOUT_S)
        ral_steps = integral_steps_model(tokens, x, True)
        # Full elapsed wall-clock seconds of the rule search. The previous
        # `% 60` wrapped every run of a minute or more.
        time_taken = time.time() - start_time
        gave_up = "DontKnowRule" in str(ral_steps)
        if not gave_up:
            ral = _manualintegrate_model(ral_steps)
        signal.alarm(0)
    except TimeoutExceptionSym as tex:
        signal.alarm(0)
        time_taken = time.time() - start_time
        corr_ind += 1
        _record_result(i, time_taken)
        print(tex)
        print("timeout sym not able", sym_inf)
        print("----------------------------------------------------------------------------------------")
        continue
    except Exception as ex:
        signal.alarm(0)
        if str(ex) == "node_limit_breached":
            # Hitting the node limit counts as an attempted, failed problem.
            time_taken = time.time() - start_time
            print("========node_limit_breached=========")
            print("sym not able", sym_inf)
            print("----------------------------------------------------------------------------------------")
            corr_ind += 1
            _record_result(i, time_taken)
        else:
            # Any other failure is logged and the problem is not counted.
            print(ex)
            print("sym not able", sym_inf)
            print("----------------------------------------------------------------------------------------")
        continue

    if gave_up:
        # The search ended in a DontKnowRule: attempted but not solved.
        print("DontKnowRule", i)
        print("----------------------------------------------------------------------------------------")
        corr_ind += 1
        _record_result(i, time_taken)
        continue

    # --- Step 2: differentiate the candidate ------------------------------
    try:
        signal.alarm(CHECK_TIMEOUT_S)
        pal = sp.diff(ral, x)
        signal.alarm(0)
    except TimeoutExceptionSym as tex:
        print(tex)
        print("sym", sym_inf)
        print("ral", ral)
        print("timeout no pal")
        print("----------------------------------------------------------------------------------------")
        continue
    except Exception as ex:
        # Cancel the still-pending alarm so it cannot fire during the next problem.
        signal.alarm(0)
        print(ex)
        print("sym", sym_inf)
        print("ral", ral)
        print("no pal")
        print("----------------------------------------------------------------------------------------")
        continue

    # --- Step 3: simplify the difference -----------------------------------
    try:
        signal.alarm(CHECK_TIMEOUT_S)
        #simp = (sym_inf-pal).is_zero()
        simp_ic = sp.simplify(sym_inf-pal)
        signal.alarm(0)
    except TimeoutExceptionSym as tex:
        print(tex)
        print("sym", sym_inf)
        print("ral", ral)
        print("pal", pal)
        print("---------------------------------------timeout no simp-------------------------------------------------")
        continue
    except Exception as ex:
        signal.alarm(0)  # see step 2
        print(ex)
        print("sym", sym_inf)
        print("ral", ral)
        print("pal", pal)
        print("---------------------------------------no simp-------------------------------------------------")
        continue

    # --- Step 4: decide whether the residual is acceptable -----------------
    try:
        signal.alarm(CHECK_TIMEOUT_S)
        simp_str = str(simp_ic)
        # Accepted when the residual is zero, when it mentions both log(x) and
        # log(1/x), or when SymPy reports it as constant.
        is_correct = (
            simp_ic == 0
            or ('log(x)' in simp_str and 'log(1/x)' in simp_str)
            or sp.simplify(simp_ic).is_constant()
        )
        signal.alarm(0)
    except TimeoutExceptionSym as tex:
        print(tex)
        print("sym", sym_inf)
        print("ral", ral)
        print("pal", pal)
        print("ral-pal", simp_ic)
        print("------------------------------------------timeout wrong wrong----------------------------------------------")
        continue
    except Exception as ex:
        # Previously uncaught: an error here aborted the whole run. Handle it
        # like the other verification steps (log and skip the problem).
        signal.alarm(0)
        print(ex)
        print("sym", sym_inf)
        print("ral", ral)
        print("pal", pal)
        print("ral-pal", simp_ic)
        print("------------------------------------------no constant check----------------------------------------------")
        continue

    if is_correct:
        pos_count += 1
        #print(pos_count/corr_ind)
    else:
        print("sym", sym_inf)
        print("ral", ral)
        print("pal", pal)
        print("ral-pal", simp_ic)
        print("------------------------------------------wrong wrong----------------------------------------------")
    corr_ind += 1
    _record_result(i, time_taken)
    pbar.set_description(f"Accuracy {pos_count*100/corr_ind}%, Pos Count {pos_count}, Total {corr_ind}")

print("pos_count", pos_count)
print("corr_ind", corr_ind)
print("---------------------------------------------------------------------------------------------")
_save_results()
