import sys
import os
import argparse
import pandas as pd
from tqdm import tqdm
import json
import re

# Append the directory two levels above this file's own folder to sys.path so
# that the `utils` package imported below can be resolved when this file is run
# directly as a script (presumably that directory is the project root — verify).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from utils.formula_utils import FormulaParser
from utils.exact_match_utils import judge_exact_match





def extract_solution(solution_str):
    """Return the text inside the final ``<answer>...</answer>`` pair.

    The search is non-greedy and spans newlines. When several tag pairs are
    present, only the last one is used.

    Args:
        solution_str: Raw generated text to search.

    Returns:
        The whitespace-stripped content of the last ``<answer>`` block, or an
        empty string when no complete block is present.
    """
    last_body = None
    # Walk every match and remember only the most recent captured group.
    for found in re.finditer(r'<answer>(.*?)</answer>', solution_str, re.DOTALL):
        last_body = found.group(1)

    if last_body is None:
        return ''
    return last_body.strip()



def calculate_answer_score(pred, gold_answer, table_data, task):
    """Score one prediction against its gold answer by exact match.

    For the ``"formula"`` task the prediction is handed to ``FormulaParser``
    and executed against ``table_data``; for the ``"text"`` task the
    prediction is used as-is. The resulting value is converted to a string and
    split on ``', '`` into a list of stripped strings, which is compared with
    ``gold_answer.tolist()`` via ``judge_exact_match``.

    Any exception raised along the way (including an unknown ``task``) is
    reported on stdout and yields a zero score with an empty result list.

    Args:
        pred: Predicted formula (``"formula"``) or predicted answer (``"text"``).
        gold_answer: Gold answer object; it must provide ``.tolist()``
            (presumably a numpy array loaded from parquet — verify with caller).
        table_data: Table passed to ``FormulaParser.execute``; unused for
            the ``"text"`` task.
        task: Either ``"formula"`` or ``"text"``.

    Returns:
        A ``(answer_score, pred_results)`` tuple where ``answer_score`` is
        ``1`` or ``0`` and ``pred_results`` is the list of predicted strings
        (``[]`` when evaluation failed).
    """
    try:
        if task == "text":
            raw_output = pred
        elif task == "formula":
            raw_output = FormulaParser(pred).execute(table_data)
        else:
            raise ValueError(f"Unknown task: {task}")

        # judge_exact_match is given lists: split the stringified prediction
        # on ", " and trim each piece.
        pred_results = [piece.strip() for piece in str(raw_output).split(', ')]
        gold_results = gold_answer.tolist()

        matched = judge_exact_match(pred_results, gold_results)
        answer_score = 1 if matched else 0
    except Exception as e:
        print(f"[Error] Evaluation error: {e}")
        return 0, []

    return answer_score, pred_results


# Core evaluation routine (called by main() below)
def evaluate(output_path, task, sample_per_data=1):
    """Compute and print exact-match accuracy for generations stored in a parquet file.

    For every row the first entry of ``responses`` is searched for an
    ``<answer>...</answer>`` block (via ``extract_solution``). The block is
    parsed as JSON and the value under ``"formula"`` (when ``task`` is
    ``'formula'``) or ``"answer"`` (otherwise) is scored with
    ``calculate_answer_score``. A row scores 0 when it has no response, no
    answer block, malformed JSON, JSON that is not an object, or no such key.

    Args:
        output_path: Path to a parquet file. The columns read are
            ``reward_model`` (indexed as ``['ground_truth']['answer']``),
            ``extra_info`` (indexed as ``['table']``) and ``responses``.
        task: Task name forwarded to ``calculate_answer_score``; ``'formula'``
            selects the ``"formula"`` JSON key, anything else selects ``"answer"``.
        sample_per_data: Number of responses scored per row. Only ``1`` is
            supported.

    Returns:
        The accuracy as a float in ``[0, 1]`` (``0`` when the file has no rows).

    Raises:
        NotImplementedError: If ``sample_per_data`` is not ``1``.
    """
    # Fail fast, before any file I/O, instead of re-checking on every row.
    if sample_per_data != 1:
        raise NotImplementedError(f"sample_per_data = {sample_per_data} is not implemented")

    df = pd.read_parquet(output_path)

    answer_key = 'formula' if task == 'formula' else 'answer'

    scores = []
    for reward_model, extra_info, responses in zip(
        df['reward_model'], df['extra_info'], df['responses']
    ):
        ground_truth = reward_model['ground_truth']['answer']
        table = extra_info['table']

        # len() rather than truthiness: `responses` may be an array-like whose
        # truth value is ambiguous. An empty response list simply scores 0.
        first_response = responses[0] if len(responses) > 0 else ''
        answer_text = extract_solution(first_response)

        score = 0
        if answer_text:
            try:
                pred = json.loads(answer_text)[answer_key]
            except (json.JSONDecodeError, KeyError, TypeError):
                # TypeError: the answer was valid JSON but not an object
                # (e.g. a list, string or number), so it cannot be indexed by
                # key. Count it as wrong rather than aborting the whole run.
                pass
            else:
                score, _ = calculate_answer_score(pred, ground_truth, table, task)
        scores.append(score)

    # Calculate and print final metrics
    final_accuracy = sum(scores) / len(scores) if scores else 0
    print("-"*60)
    print("\nFinal Results:")
    print(f"Accuracy: {final_accuracy:.4f}")
    print("-"*60)

    return final_accuracy
    

def main():
    """Parse command-line arguments and run :func:`evaluate`.

    Command-line options:
        --output_path: Parquet file containing the generated responses (required).
        --task: ``formula`` or ``text`` (required).
        --sample_per_data: Responses scored per row; defaults to ``1``.

    Missing or invalid options make argparse print a usage message and exit
    with status 2 instead of passing ``None`` into the evaluation.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate generated answers stored in a parquet file."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to the parquet file with model responses.",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        # These are the only task names the scoring code accepts.
        choices=["formula", "text"],
        help="Evaluation task.",
    )
    parser.add_argument(
        "--sample_per_data",
        type=int,
        default=1,
        help="Number of responses per row to score (only 1 is implemented).",
    )
    args = parser.parse_args()

    evaluate(args.output_path, args.task, args.sample_per_data)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()