import json
import re
import argparse
import pdb
def check_string_type(s):
    """
    Classify *s* as "integer", "float", or "neither".

    Non-string inputs are always "neither". A string is "integer" when int()
    accepts it, otherwise "float" when float() accepts it, otherwise "neither".
    """
    if not isinstance(s, str):
        return "neither"
    # Try the stricter parser first so "3" is reported as an integer.
    for parser, label in ((int, "integer"), (float, "float")):
        try:
            parser(s)
        except ValueError:
            continue
        return label
    return "neither"

def normalize_answer(text, unit=None):
    """
    Normalize an answer so labeled and predicted answers can be compared as strings.

    Steps:
      * non-string input (e.g. a number loaded from JSON) is converted with str();
      * surrounding whitespace, one leading '$' and one trailing ',', '.' or '/'
        are removed;
      * purely numeric text (digits, ',', '.', '/', optional leading sign) has
        its commas dropped, a simple fraction 'a/b' evaluated, the value rounded
        to 3 decimals and a trailing '.0...' removed;
      * other text has every occurrence of *unit* (if given) removed and is then
        normalized again, so '5.0 cm' compares equal to '5'.

    Args:
        text: the answer to normalize.
        unit: optional unit string removed from non-numeric answers.

    Returns:
        The normalized answer as a string. Numeric-looking text that cannot be
        evaluated (e.g. '1/0', '/2', '1.2.3', '1/2/3') is returned unevaluated
        instead of raising.
    """
    if not isinstance(text, str):
        # Labels are sometimes stored as JSON numbers rather than strings.
        text = str(text)
    text = re.sub(r"^\$", "", text.strip()).strip()
    text = re.sub(r"[,./]$", "", text).strip()

    if re.match(r"^[-+]?[\d,./]+$", text) is None:
        if unit and unit in text:
            # Re-normalize what remains; each pass strictly shortens the text,
            # so the recursion terminates.
            return normalize_answer(text.replace(unit, ""))
        return text

    text = text.replace(",", "")
    try:
        if "/" in text:
            # Unpacking raises ValueError unless there is exactly one '/'.
            numerator, denominator = text.split("/")
            number = float(numerator) / float(denominator)
        else:
            number = float(text)
    except (ValueError, ZeroDivisionError):
        # Looks numeric but is not a valid number or simple fraction:
        # compare it as plain text rather than crashing the evaluation.
        number_text = text
    else:
        number = round(number, 3)
        if number == 0:
            number = 0.0  # collapse -0.0 so '-0' and '0' compare equal
        number_text = str(number)
    return re.sub(r"\.0+$", "", number_text)

def extract_result_from_anchor(text):
    """
    Return the content of the first <a>...</a> pair in *text*, or *text* unchanged.

    The match is non-greedy, so it ends at the first '</a>'. re.DOTALL lets the
    enclosed answer span several lines; without it a multi-line answer would not
    match and the whole text would be returned instead.
    """
    match = re.search(r'<a>(.*?)<\/a>', text, flags=re.DOTALL)
    return match.group(1) if match else text

def _grade_band(grade):
    """
    Return '1_6', '7_8', or None for the grade band *grade* belongs to.

    Integer grades are used as-is; numeric strings such as "3" are converted
    first. Missing, non-numeric or out-of-range grades belong to no band.
    """
    if isinstance(grade, str):
        try:
            grade = int(grade.strip())
        except ValueError:
            return None
    if grade in range(1, 7):
        return '1_6'
    if grade in range(7, 9):
        return '7_8'
    return None


def _percent(part, whole):
    """Return part/whole as a percentage, or 0 when whole is 0."""
    return part / whole * 100 if whole > 0 else 0


def evaluate_results(input_file, output_file, args):
    """
    Score every prediction in a JSONL file against its labeled answer.

    Each input line is a JSON object with 'answer' and 'result' keys and,
    optionally, 'unit', 'grade' and 'answer_prompt'. The prediction is the text
    inside the first <a>...</a> tag of 'result' (or all of 'result' if there is
    none). Each record is written to output_file as one JSON object per line,
    with an added 'require_opt' field: 'false' if judged correct, else 'true'.
    Per-example details and an accuracy summary (overall, grades 1-6,
    grades 7-8) are printed to stdout.

    Args:
        input_file: path of the JSONL file to evaluate.
        output_file: path the annotated JSONL records are written to.
        args: parsed argparse namespace; stored via vars() in the summary.

    Returns:
        A dict with overall and per-band accuracy, totals and correct counts,
        the per-example results, and vars(args).
    """
    correct = 0
    total = 0
    band_correct = {'1_6': 0, '7_8': 0}
    band_total = {'1_6': 0, '7_8': 0}
    results = []

    # Explicit UTF-8: records are written with ensure_ascii=False, which can
    # fail under a non-UTF-8 platform default encoding.
    with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile:
        for line in infile:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            data = json.loads(line)
            answer = data['answer']
            prediction = extract_result_from_anchor(data['result'])
            # The anchor match ends at the first '</a>'; if the capture still
            # contains '<a>', keep only the text after the last opening tag.
            if '<a>' in prediction:
                prediction = prediction.split('<a>')[-1]

            unit = data.get('unit', None)
            grade = data.get('grade', None)

            # Labels may be stored as JSON numbers; normalize their text form.
            answer_norm = normalize_answer(str(answer), unit)
            prediction_norm = normalize_answer(prediction, unit)

            is_correct = answer_norm.lower() == prediction_norm.lower()
            data['require_opt'] = 'false' if is_correct else 'true'
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')

            band = _grade_band(grade)
            total += 1
            if band is not None:
                band_total[band] += 1
            if is_correct:
                correct += 1
                if band is not None:
                    band_correct[band] += 1

            results.append({
                'answer': answer,
                'prediction': prediction,
                'answer_norm': answer_norm,
                'prediction_norm': prediction_norm,
                'correct': is_correct,
                'grade': grade
            })

            # Print the original evaluation details
            print("\n##################################")
            print(f"# [Prompt] {data.get('answer_prompt', 'N/A')}\n")
            print(f"# [Output] {prediction}\n")
            print(f"[A] labeled answer (normalized): {answer_norm}")
            print(f"[P] predicted answer (normalized): {prediction_norm}")
            print(f"[Acc]: {is_correct}\n")
            print(f"[A] labeled answer: {answer}")
            print(f"[P] predicted answer: {prediction}")

    # Guarded against empty input (the old code divided by zero here).
    accuracy = _percent(correct, total)
    correct_grades_1_6, total_grades_1_6 = band_correct['1_6'], band_total['1_6']
    correct_grades_7_8, total_grades_7_8 = band_correct['7_8'], band_total['7_8']
    accuracy_grades_1_6 = _percent(correct_grades_1_6, total_grades_1_6)
    accuracy_grades_7_8 = _percent(correct_grades_7_8, total_grades_7_8)

    output_data = {
        'accuracy': accuracy,
        'total': total,
        'correct': correct,
        'results': results,
        'args': vars(args),
        'accuracy_grades_1_6': accuracy_grades_1_6,
        'total_grades_1_6': total_grades_1_6,
        'correct_grades_1_6': correct_grades_1_6,
        'accuracy_grades_7_8': accuracy_grades_7_8,
        'total_grades_7_8': total_grades_7_8,
        'correct_grades_7_8': correct_grades_7_8
    }

    print(f"\nTotal: {total}, Correct: {correct}, Accuracy: {round(accuracy, 2)}%")
    print(f"Grades 1-6 - Total: {total_grades_1_6}, Correct: {correct_grades_1_6}, Accuracy: {round(accuracy_grades_1_6, 2)}%")
    print(f"Grades 7-8 - Total: {total_grades_7_8}, Correct: {correct_grades_7_8}, Accuracy: {round(accuracy_grades_7_8, 2)}%")

    # Previously built and discarded; returned so callers can use the summary.
    return output_data


def parse_args(argv=None):
    """
    Parse the evaluator's command-line options.

    Args:
        argv: optional list of argument strings. When None, argparse reads
            sys.argv[1:], matching the previous behaviour.

    Returns:
        An argparse.Namespace with input_file and output_file attributes.
    """
    parser = argparse.ArgumentParser(description="Evaluate results against ground truth answers.")
    parser.add_argument('--input_file', type=str, required=True, help='Path to the input JSONL file.')
    # The evaluator writes one annotated JSON record per line, i.e. JSONL.
    parser.add_argument('--output_file', type=str, required=True,
                        help='Path to the output JSONL file (input records annotated with require_opt).')
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Script entry point: read the CLI options, then run the evaluation with them.
    args = parse_args()
    evaluate_results(
        input_file=args.input_file,
        output_file=args.output_file,
        args=args,
    )
