import cloudpickle as pickle
import os
from sympy import latex
import sympy as sp
from alpha_integrate.train.mathjax_script import mathjax_script

def parse_rule(rule):
    """Render an integration-rule record as a short human-readable phrase.

    ``rule`` is indexable: ``rule[0]`` is the rule name and any remaining
    items are that rule's arguments.

    - A rule with no arguments renders as just its name.
    - ``"URule"`` renders as ``substitution with <rule[1]> = <rule[2]>``,
      with both arguments converted by ``latex`` and wrapped in MathJax
      inline-math delimiters.
    - ``"PartsRule"`` renders as ``integration by parts with <rule[1]> and
      <rule[2]>`` (u and dv), formatted the same way.
    - Any other rule that carries arguments falls back to its name, so the
      caller always receives a string (the HTML report interpolates the
      result directly into a sentence).

    Returns:
        The phrase describing the rule.
    """
    rule_name = rule[0]

    if len(rule) == 1:
        return rule_name

    rest = rule[1:]

    if rule_name == "URule":
        var = f"\\({latex(rest[0])}\\)"
        func = f"\\({latex(rest[1])}\\)"
        return f"substitution with {var} = {func}"

    if rule_name == "PartsRule":
        u = f"\\({latex(rest[0])}\\)"
        dv = f"\\({latex(rest[1])}\\)"
        return f"integration by parts with {u} and {dv}"

    # Rules without a dedicated description: report the name rather than
    # implicitly returning None (which would render as "Applied None ...").
    return rule_name



# Name of the evaluated dataset; it selects both the pickled results file and
# the folder the HTML solution pages are written to.
dataset = 'prim_ibp'
res_file = f"alpha_integrate/train/results_{dataset}.pkl"
output_folder = f"alpha_integrate/train/{dataset}_html/"

os.makedirs(output_folder, exist_ok=True)

# NOTE: (cloud)pickle executes arbitrary code while loading; only point this at
# a results file produced by a trusted run.
with open(res_file, 'rb') as f:
    res = pickle.load(f)

# Solution pages are numbered 1..n_successes by *success* index, so the number
# of successes must be known up front to decide whether a "Next" page exists.
n_successes = sum(1 for r in res if r['Transformer_Success'])

total_t_success = 0

for i, r in enumerate(res):
    t_success = 1 if r['Transformer_Success'] else 0
    total_t_success += t_success

    print(f"Expression {i+1}: {r['Integral']}")
    print(f"Transformer Success: {r['Transformer_Success']}")
    if not t_success:
        continue

    # Link to neighbouring *solved* expressions; "#" disables the button at
    # either end of the sequence.
    prev_file = f"solution_{total_t_success-1}.html" if total_t_success > 1 else "#"
    next_file = f"solution_{total_t_success+1}.html" if total_t_success < n_successes else "#"
    page_path = os.path.join(output_folder, f"solution_{total_t_success}.html")
    with open(page_path, 'w', encoding='utf-8') as html_file:
        html_file.write(mathjax_script)
        html_file.write("<body>\n")

        html_file.write("<div class='navigation-buttons'>\n")
        html_file.write(f"<a href='{prev_file}' class='button'>Previous</a>\n")
        html_file.write(f"<a href='{next_file}' class='button'>Next</a>\n")
        html_file.write("</div>\n")
        html_file.write(f"<h2>Step by Step Integration of \\( {latex(r['Integral'])} \\)</h2>\n")

        # Each entry of Transformer_Steps is one solution: a sequence of steps
        # indexed as (expression, sub-expression, rule, result).
        solutions = r['Transformer_Steps']
        for j, s in enumerate(solutions):
            html_file.write(f"<div class='step-header'>Solution {j+1}</div>\n")
            html_file.write("<div class='step-content'>\n")
            for step in s:
                expr = '\\(' + latex(step[0]) + '\\)'
                subexpr = '\\(' + latex(step[1]) + '\\)'
                rule = parse_rule(step[2])
                result = '\\(' + latex(step[3]) + '\\)'
                html_file.write(f"<p>Applied {rule} on {subexpr} to transform {expr} <span class='arrow'>&#8594;</span> {result}</p>\n")
            html_file.write("</div>\n")
        html_file.write("</body></html>")

print(f"Total number of expressions: {len(res)}")
# Guard against an empty results file (would otherwise divide by zero).
success_rate = total_t_success / len(res) if res else 0.0
print(f"Total number of expressions solved by Transformer: {total_t_success}, {success_rate:.2%}")
