from openai import OpenAI
import io
import logging
import sympy as sp
import timeout_decorator
import time
import argparse
import random
from sympy import latex, simplify, E
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
import re
import os


date_time_idx = time.strftime("%Y%m%d-%H%M%S")
# One log file per run, named by start timestamp. The directory is created up
# front because logging.basicConfig opens the file immediately and would raise
# FileNotFoundError (at import time) if the directory did not exist yet.
_log_dir = 'alpha_integrate/llm/evaluatelogs'
os.makedirs(_log_dir, exist_ok=True)
logging.basicConfig(filename = f'{_log_dir}/evaluate_{date_time_idx}.txt', filemode='w', level=logging.INFO)

from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
from alpha_integrate.llm.constants import system_prompt

# Shared OpenAI client. The key comes from the environment; if OPENAI_API_KEY
# is unset, os.environ[...] raises KeyError as soon as this module is imported.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# Dataset pairs whose conversion to SymPy takes longer than this many seconds
# are excluded from evaluation (checked in main()).
sp_conversion_max_secs = 0.1
# Selects extract_answer_finetune (True) or extract_answer (False) in main().
finetune = False

# NOTE: Disabled earlier version of extract_answer, kept for reference only.
# It searched for the first "sympy.sympify(" and returned its argument using
# parenthesis depth counting. Because the whole definition sits inside a bare
# triple-quoted string, it is a no-op expression statement: the function is
# never defined, and the live extract_answer below is what main() uses.
'''
def extract_answer(text):
    if "sympy.sympify(" not in text:
        return None
    
    start_index = text.find("sympy.sympify(") + len("sympy.sympify(")
    brace_count = 1  # Start after the first '('
    end_index = start_index

    while end_index < len(text) and brace_count > 0:
        if text[end_index] == '(':
            brace_count += 1
        elif text[end_index] == ')':
            brace_count -= 1
        end_index += 1
    
    if brace_count == 0:
        return text[start_index:end_index - 1]
    else:
        return None
'''

def extract_answer(text):
    """Return the quoted argument of the last ``sp.parse_expr('...')`` in *text*.

    The argument is everything after the final ``sp.parse_expr('`` up to the
    first ``')`` that follows it, returned verbatim.

    Returns None when the marker never appears or no closing ``')`` follows it.
    """
    marker = "sp.parse_expr('"
    if marker not in text:
        return None

    # rpartition splits at the LAST marker, so earlier calls are ignored.
    _, _, tail = text.rpartition(marker)
    argument, closer, _ = tail.partition("')")
    if not closer:
        return None
    return argument

def extract_answer_finetune(text):
    """Return the argument of the last ``sp.parse_expr(...)`` call in *text*.

    The marker has no opening quote (unlike ``extract_answer``), so the
    argument is whatever sits between the call's parentheses. Parentheses are
    matched by depth counting, so nested calls are kept whole:
    ``sp.parse_expr(sin(x)*x)`` yields ``'sin(x)*x'`` instead of being cut
    off at the first ``)``.

    Returns None when the marker is absent or its parentheses never balance.
    """
    searchfor = "sp.parse_expr("
    if searchfor not in text:
        return None

    start_index = text.rfind(searchfor) + len(searchfor)
    depth = 1  # the '(' that ends the marker is already open
    for index in range(start_index, len(text)):
        char = text[index]
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return text[start_index:index]
    return None  # reached end of text with unbalanced parentheses

def extract_boxed_content(text):
    r"""Return the contents of the first ``\boxed{...}`` group in *text*.

    Curly braces are matched by depth counting, so nested groups are kept
    intact: ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}``.

    Returns None when *text* contains no ``\boxed{`` or the group is never
    closed (mismatched braces).
    """
    opener = r"\boxed{"
    if opener not in text:
        return None

    begin = text.find(opener) + len(opener)
    depth = 1  # the opener's '{' has already been consumed
    for cursor in range(begin, len(text)):
        char = text[cursor]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[begin:cursor]
    return None  # mismatched braces
    
def main(args):
    """Evaluate an OpenAI chat model on the integration dataset and log results.

    Steps:
      1. Read tab-separated ``<integrand> <antiderivative>`` token-sequence
         lines from
         ``alpha_integrate/synthetic_data/final_steps_dataset/{testorval}/{dataset}_{testorval}.txt``,
         optionally a random sample of ``num_expressions`` lines seeded by
         ``seed``.
      2. Keep pairs whose integrand has fewer than ``block_size // 2`` tokens
         and whose conversion to SymPy succeeds within
         ``sp_conversion_max_secs`` seconds.
      3. Ask ``model`` to integrate each kept integrand, then extract its
         answer with ``extract_answer_finetune`` when the module flag
         ``finetune`` is set, otherwise with ``extract_answer``.
      4. Score an answer as correct when its derivative minus the integrand
         simplifies to 0, or when it differs from the reference
         antiderivative by a constant.

    Per-expression details, running and final success rates are written via
    ``logging``; nothing is returned.

    Args:
        args: Namespace with ``dataset``, ``testorval``, ``num_expressions``,
            ``seed`` and ``model`` attributes.
    """
    tokenizer = ExpressionTokenizer()
    x = sp.Symbol('x')

    @timeout_decorator.timeout(1)  # Timeout after 1 second
    def seq_to_sp_with_timeout(tok, seq_expr) -> sp.Expr:
        return tok.seq_to_sp(seq_expr)

    @timeout_decorator.timeout(3)  # Timeout after 3 seconds
    def parseexpr_timeout(expr_str):
        # Bind both 'e' and 'E' to Euler's number so "e**x" parses as exp(x).
        local_dict = {'e': E, 'E': E}
        return parse_expr(expr_str, local_dict=local_dict)

    @timeout_decorator.timeout(3)  # Timeout after 3 seconds
    def simplify_timeout(ai_result, sp_expr):
        return simplify(ai_result - sp_expr)

    @timeout_decorator.timeout(3)  # Timeout after 3 seconds
    def differentiate_timeout(expr):
        return sp.diff(expr, x)

    @timeout_decorator.timeout(3)  # Timeout after 3 seconds
    def is_constant_timeout(expr):
        # Guarded like the other SymPy calls so one pathological expression
        # cannot stall or crash the whole evaluation.
        return expr.is_constant()

    dataset = args.dataset
    testorval = args.testorval
    num_expressions = args.num_expressions
    model = args.model

    random.seed(args.seed)

    path = f'alpha_integrate/synthetic_data/final_steps_dataset/{testorval}/{dataset}_{testorval}.txt'

    logging.info(f"Using the {dataset} test dataset to evaluate GPT-4o: {path}")

    with io.open(path, mode='r', encoding='utf-8') as f:
        lines = f.readlines()

    if num_expressions is not None:
        num_expressions = min(num_expressions, len(lines))
        lines = random.sample(lines, num_expressions)

    # --- Collect eligible (integrand, antiderivative) pairs -----------------
    sp_exprs = []
    sp_results = []
    total_exprs = 0

    for line in lines:
        total_exprs += 1
        fields = line.split('\t')
        if len(fields) != 2:
            # No tab or several tabs (e.g. a blank trailing line): skip it
            # instead of aborting the whole run on the unpack.
            logging.info(f"Skipping malformed dataset line: {line!r}")
            continue
        expr = fields[0].split()
        result = fields[1].split()
        # Cheap length filter first so over-long integrands are never converted.
        if len(expr) >= block_size // 2:
            continue
        t1 = time.perf_counter()
        try:
            sp_expr = seq_to_sp_with_timeout(tokenizer, expr)
            sp_result = seq_to_sp_with_timeout(tokenizer, result)
        except Exception:
            # Invalid token sequence or conversion timeout: not eligible.
            continue
        if time.perf_counter() - t1 > sp_conversion_max_secs:
            continue
        sp_exprs.append(sp_expr)
        sp_results.append(sp_result)

    num_exprs = len(sp_exprs)
    if total_exprs == 0:
        logging.info("No expressions found in the dataset.")
        return

    logging.info(f"{num_exprs} out of {total_exprs} expressions eligible for testing. ({num_exprs/total_exprs:.2%})")

    # --- Query the model and score each answer ------------------------------
    extract = extract_answer_finetune if finetune else extract_answer
    success = 0

    for i, (sp_expr, sp_result) in enumerate(zip(sp_exprs, sp_results)):
        logging.info(f"Expression {i+1}: {sp_expr}")

        user_prompt = f"Expression to integrate: {sp.Integral(sp_expr, x)}"
        logging.info(f"User Prompt: {user_prompt}")

        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=2048,
                n=1,
                stop=None,
                temperature=0.9,
            )
            ai_result_str = response.choices[0].message.content
            logging.info(f"GPT-4o Response: {ai_result_str}")
        except Exception as err:
            logging.error(f"Error in GPT-4o response: {str(err)}")
            continue

        # Message content can be None (e.g. a refusal); treat it like "no answer".
        answer = extract(ai_result_str) if ai_result_str else None
        logging.info(f"Answer: {answer}")
        if answer is None:
            logging.info("No answer found.")
            continue

        try:
            ai_result = parseexpr_timeout(answer)
        except Exception as err:
            logging.info(f"Error parsing AI result: {str(err)}")
            continue

        # parse_expr evaluates Python syntax, so an answer such as "1, 2" can
        # come back as a plain tuple, which has no .has() method.
        if not isinstance(ai_result, sp.Basic):
            logging.info(f"AI result is not a SymPy expression: {ai_result!r}")
            continue

        if ai_result.has(sp.Integral):
            logging.info("AI result contains an integral.")
            continue

        # Two independent best-effort checks; either one may time out or fail.
        derivative_difference = None
        antiderivative_difference = None
        try:
            ai_derivative = differentiate_timeout(ai_result)
            derivative_difference = simplify_timeout(ai_derivative, sp_expr)
        except Exception:
            pass
        try:
            antiderivative_difference = simplify_timeout(ai_result, sp_result)
        except Exception:
            pass

        if derivative_difference is None and antiderivative_difference is None:
            logging.info("Error in simplification.")
            continue

        is_correct = derivative_difference == 0
        if not is_correct and isinstance(antiderivative_difference, sp.Expr):
            # Antiderivatives are only unique up to an additive constant.
            # is_constant() may return None when undecided; that counts as False.
            try:
                is_correct = bool(is_constant_timeout(antiderivative_difference))
            except Exception:
                is_correct = False

        if is_correct:
            success += 1

        logging.info(f"True Result: {sp_result}")
        logging.info(f"AI Result: {ai_result} ({'Correct' if is_correct else 'Incorrect'})")
        logging.info(f"AI Success Rate: {(success/(i+1)):.3%}")
        logging.info("-----------------------------")

    logging.info(f"Total Expressions: {total_exprs}")
    if num_exprs:
        # The running rate above is not logged for skipped expressions, so
        # report the overall figure explicitly.
        logging.info(f"Final AI Success Rate: {success}/{num_exprs} ({success/num_exprs:.3%})")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='final_steps', help='Dataset to evaluate on')
    parser.add_argument('--testorval', type=str, default='test', help='Test or validation set')
    parser.add_argument('--num_expressions', type=int, default=None)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--model', type=str, required=True, help='Model to evaluate')
    parser.add_argument('--finetune', action='store_true',
                        help='Extract answers with extract_answer_finetune (unquoted sp.parse_expr(...) calls)')
    args = parser.parse_args()
    # This runs at module scope, so the assignment rebinds the global flag that
    # main() reads. OR-ing keeps a hard-coded `finetune = True` working without the flag.
    finetune = finetune or args.finetune
    main(args)
