import sympy
from sympy.integrals.manualintegrate import integral_steps
from functools import wraps
import io
import random
import timeout_decorator
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
import numpy as np

# Shared tokenizer instance; its seq_to_sp method is used further down to turn
# a whitespace-split token sequence read from the dataset into an expression
# (presumably a sympy expression, since the result is handed to integral_steps).
e = ExpressionTokenizer()

# Global counter for explored nodes: incremented by the count_nodes wrapper on
# every call of the wrapped function and reset to 0 before each expression is
# processed.
node_count = 0

# Decorator that tallies calls to the wrapped function in the module-level
# ``node_count`` counter.
def count_nodes(func):
    """Return a version of *func* that bumps ``node_count`` on every call.

    The counter is incremented before *func* runs, so a call is counted even
    when *func* raises. Arguments and the return value pass through unchanged.
    """
    def counting_call(integrand, symbol, *args, **kwargs):
        global node_count
        node_count = node_count + 1  # one more call observed
        outcome = func(integrand, symbol, *args, **kwargs)
        return outcome

    # Copy func's metadata (name, docstring, ...) onto the counting wrapper.
    return wraps(func)(counting_call)

# Apply the node-counting wrapper to integral_steps.
# The wrapped function is stored back on the manualintegrate module, so every
# call that resolves the name through that module attribute (as
# integrate_with_timeout does) increments node_count.
# NOTE(review): presumably sympy's own recursive calls also look the name up in
# the module namespace and are therefore counted as well — confirm.
sympy.integrals.manualintegrate.integral_steps = count_nodes(integral_steps)

# Run integral_steps under a timeout of 5; when the limit is hit the decorator
# raises StopIteration (the configured timeout_exception).
@timeout_decorator.timeout(5, timeout_exception=StopIteration)
def integrate_with_timeout(expr, symbol):
    """Return ``integral_steps(expr, symbol)``, subject to the timeout above.

    The function is fetched from the manualintegrate module at call time, so
    whatever is currently stored on that attribute (e.g. a wrapped version)
    is the one that runs.
    """
    manual = sympy.integrals.manualintegrate
    return manual.integral_steps(expr, symbol)

# Integration variable shared by every expression in the benchmark.
x = sympy.Symbol('x')
# Successfully parsed expressions, in sampling order.
expressions = []
# Upper bound on the number of dataset lines to sample.
N_samples = 5000

# Get the list of expressions from the file
PATH = 'alpha_integrate/synthetic_data/final_steps_dataset/test/prim_fwd_test.txt'
with io.open(PATH, mode='r', encoding='utf-8') as f:
    # Read everything up front so the file is closed before parsing starts.
    lines = f.readlines()

# Draw a reproducible random sample of at most N_samples lines. Capping at
# len(lines) avoids the ValueError random.sample raises when asked for more
# items than the population holds.
random.seed(19012002)
lines = random.sample(lines, min(N_samples, len(lines)))
for line in lines:
    # The first tab-separated field holds the whitespace-separated tokens of
    # the expression.
    try:
        expr = e.seq_to_sp(line.split('\t')[0].split())
    except Exception:
        # Best effort: skip lines the tokenizer cannot convert, but let
        # KeyboardInterrupt / SystemExit propagate.
        continue
    expressions.append(expr)

print(f"Number of expressions: {len(expressions)}")
# Test each expression and count explored nodes.
# results maps each solved expression to the number of counted calls.
# NOTE(review): because the dict is keyed by expression, equal expressions
# share a single entry (the last count wins).
results = {}
# Parsing may have dropped lines, so report progress against the number of
# expressions actually available rather than N_samples.
n_expressions = len(expressions)
for i, expr in enumerate(expressions):
    print(f"Processing: {i+1}/{n_expressions}", end='\r')
    node_count = 0  # Reset counter before each expression
    try:
        steps = integrate_with_timeout(expr, x)
        if not steps.contains_dont_know():  # Check if it's a known integral
            results[expr] = node_count
    except Exception as exc:
        # The timeout is signalled with StopIteration; other failures raised
        # during integration are skipped the same way. Catching Exception
        # (not a bare except) keeps Ctrl-C able to stop the run.
        # NOTE(review): a StopIteration raised inside a generator frame may
        # surface as RuntimeError (PEP 479) — hence the exception type is shown.
        print(f"\nTimeout or error ({type(exc).__name__}) for expression {i+1}, skipping...")

# Output the results: one node count per successfully integrated expression.
counts = list(results.values())

# Print mean, standard deviation, and quantiles of the node counts, max, min.
# Guarded because np.max / np.min raise on an empty sequence, which would
# abort the script before the output file is written.
if counts:
    print(f"Mean: {np.mean(counts)}")
    print(f"Std: {np.std(counts)}")
    print(f"Quantiles: {np.quantile(counts, [0.25, 0.5, 0.75])}")
    print(f"Max: {np.max(counts)}")
    print(f"Min: {np.min(counts)}")
else:
    print("No expressions were integrated successfully; no statistics to report.")

# Write node counts newline-separated to sympy_tree.txt
with open("alpha_integrate/train/sympy_tree.txt", "w", encoding="utf-8") as f:
    f.writelines(f"{count}\n" for count in counts)
