from email.mime import text
import json
import re
import os
import glob
import tiktoken
from collections import defaultdict
from parser_helper import is_equiv, remove_boxed, last_boxed_only_string
from auto_scoring_judge import AutoScoringJudge
from transformers import AutoTokenizer
import time
import numpy as np

# Module-level judge instance; parse_math_answers falls back to
# scorer.judge(...) when is_equiv does not accept an extracted answer.
scorer = AutoScoringJudge()

def count_effective_tokens(text):
    """Return how many ``cl100k_base`` tokens *text* encodes to.

    Falsy input (``None`` or an empty string) counts as zero tokens. Every
    literal ``<|endoftext|>`` marker is stripped before encoding, so those
    markers do not contribute to the count.
    """
    if text:
        cleaned = text.replace("<|endoftext|>", "")
        encoder = tiktoken.get_encoding("cl100k_base")
        return len(encoder.encode(cleaned))
    return 0


def parse_gsm_answers(json_path=None, json_data=None):
    r"""Score numeric (GSM-style) answers for every question in a payload.

    Args:
        json_path: Path of a JSON file to load. Used whenever it is truthy.
        json_data: Already-loaded payload, used when ``json_path`` is falsy.

    The payload is a mapping whose ``"generations"`` entry lists one item per
    question. Each item is read for ``"question"``, ``"ground_truth"`` and a
    list of raw ``"generations"`` strings.

    Answer extraction, per raw generation:
        1. the first number inside the first ``\boxed{...}`` that contains
           one (empty boxes and dot-only placeholders such as ``...`` are
           skipped);
        2. otherwise the last number inside ``<answer>...</answer>``.
    The parsed float is correct when it compares ``==`` to ``ground_truth``.

    Returns:
        A tuple ``(total_correct, is_correct_at_k, k, total_processed,
        processed_items, total_effective_tokens)``:

        * ``total_correct`` - number of correct generations overall;
        * ``is_correct_at_k`` - number of questions with at least one
          correct generation;
        * ``k`` - number of generations of the last question processed
          (``0`` when the payload has no questions);
        * ``total_processed`` - number of questions;
        * ``processed_items`` - one detail dict per question; its
          ``"is_correct_at_k"`` entry is that question's own boolean flag;
        * ``total_effective_tokens`` - sum over questions of the smallest
          per-generation token count.
    """
    if json_path:
        with open(json_path, "r") as file:
            data = json.load(file)
    else:
        data = json_data

    number_pattern = r"-?\d+\.?\d*"

    total_correct = 0
    total_processed = 0
    total_effective_tokens = 0
    processed_items = []
    is_correct_at_k = 0
    # Bound up front so the return statement is valid for an empty payload.
    k = 0

    for item in data.get("generations", []):
        total_processed += 1

        ground_truth = item.get("ground_truth")
        generations = item.get("generations", [])
        question = item.get("question", "")

        # Shortest generation of the k samples; 0 when there are none.
        effective_tokens = min(
            count_effective_tokens(gen) for gen in generations
        ) if generations else 0
        total_effective_tokens += effective_tokens

        parsed_answers = []
        correctness_flags = []

        for raw_generation in generations:
            parsed_answer = None

            # 1) First \boxed{...} holding a number wins.
            for boxed_content in re.findall(r"\\boxed{(.*?)}", raw_generation):
                boxed_content = boxed_content.strip()
                # Skip empty boxes and dot-only placeholders ("...").
                if not boxed_content or re.match(r"^\.+$", boxed_content):
                    continue
                numbers = re.findall(number_pattern, boxed_content)
                if not numbers:
                    continue
                try:
                    parsed_answer = float(numbers[0])
                    break
                except ValueError:
                    # Unparseable match: try the next box.
                    continue

            # 2) Fall back to the last number inside <answer>...</answer>.
            if parsed_answer is None:
                answer_match = re.search(
                    r"<answer>(.*?)</answer>", raw_generation, re.DOTALL
                )
                if answer_match:
                    numbers = re.findall(number_pattern, answer_match.group(1))
                    if numbers:
                        try:
                            parsed_answer = float(numbers[-1])
                        except ValueError:
                            pass

            is_correct = (
                parsed_answer is not None and parsed_answer == ground_truth
            )
            if is_correct:
                total_correct += 1

            parsed_answers.append(parsed_answer)
            correctness_flags.append(is_correct)

        k = len(generations)

        any_correct = any(correctness_flags)
        if any_correct:
            is_correct_at_k += 1

        processed_items.append(
            {
                "question": question,
                "raw_generations": generations,
                "extracted_answers": parsed_answers,
                "ground_truth": ground_truth,
                # Per-question flag, not the running total across questions.
                "is_correct_at_k": any_correct,
                "num_correct": sum(correctness_flags),
                "effective_tokens": effective_tokens,
            }
        )

    return (
        total_correct,
        is_correct_at_k,
        k,
        total_processed,
        processed_items,
        total_effective_tokens,
    )

import json
import re


def parse_math_answers(json_path=None, json_data=None):
    """Score free-form math answers for every question in a payload.

    Args:
        json_path: Path of a JSON file to load. Used whenever it is truthy.
        json_data: Already-loaded payload, used when ``json_path`` is falsy.

    The payload is a mapping whose ``"generations"`` entry lists one item per
    question. Each item is read for ``"question"``, ``"ground_truth"`` and a
    list of k raw ``"generations"`` strings.

    Answer extraction, per raw generation:
        1. ``remove_boxed(last_boxed_only_string(raw))``; any exception from
           those helpers is treated as "no boxed answer";
        2. if that yields nothing, the stripped content of
           ``<answer>...</answer>``.
    An extracted answer is correct when ``is_equiv`` accepts it or, failing
    that, when ``scorer.judge`` does (precision ``1e-6``). A generation with
    no extracted answer is incorrect and is not sent to the judge.

    Returns:
        A tuple ``(total_correct, is_correct_at_k, k, total_processed,
        processed_items, total_effective_tokens)``:

        * ``total_correct`` - number of correct generations overall;
        * ``is_correct_at_k`` - number of questions with at least one
          correct generation;
        * ``k`` - number of generations of the last question processed
          (``0`` when the payload has no questions);
        * ``total_processed`` - number of questions;
        * ``processed_items`` - one detail dict per question; its
          ``"is_correct_at_k"`` entry is that question's own boolean flag;
        * ``total_effective_tokens`` - sum over questions of the smallest
          per-generation token count.
    """
    if json_path:
        with open(json_path, "r") as file:
            data = json.load(file)
    else:
        data = json_data

    total_correct = 0
    total_processed = 0
    total_effective_tokens = 0
    processed_items = []
    is_correct_at_k = 0
    # Bound up front so the return statement is valid for an empty payload.
    k = 0

    for item in data.get("generations", []):
        total_processed += 1

        question = item.get("question", "")
        ground_truth = item.get("ground_truth", "")
        generations = item.get("generations", [])  # list of k generations

        # Shortest generation of the k samples; 0 when there are none.
        effective_tokens = min(
            count_effective_tokens(gen) for gen in generations
        ) if generations else 0
        total_effective_tokens += effective_tokens

        parsed_answers = []
        correctness_flags = []

        for raw_generation in generations:
            # 1) Try the last \boxed{...} expression.
            try:
                parsed_answer = remove_boxed(
                    last_boxed_only_string(raw_generation)
                )
            except Exception:
                # Helper behaviour on a missing box is not visible here;
                # any failure simply means "no boxed answer".
                parsed_answer = None

            # 2) Fall back to the <answer> tag.
            if not parsed_answer:
                answer_match = re.search(
                    r"<answer>(.*?)</answer>", raw_generation, re.DOTALL
                )
                if answer_match:
                    parsed_answer = answer_match.group(1).strip()

            # A missing answer is never correct; only real candidates are
            # passed to the equivalence check and then to the judge.
            is_correct = False
            if parsed_answer is not None:
                is_correct = bool(is_equiv(parsed_answer, ground_truth))
                if not is_correct:
                    is_correct = bool(
                        scorer.judge(
                            ground_truth, parsed_answer, precision=1e-6
                        )
                    )

            if is_correct:
                total_correct += 1

            parsed_answers.append(parsed_answer)
            correctness_flags.append(is_correct)

        any_correct = any(correctness_flags)
        if any_correct:
            is_correct_at_k += 1

        k = len(generations)

        processed_items.append(
            {
                "question": question,
                "raw_generations": generations,
                "extracted_answers": parsed_answers,
                "ground_truth": ground_truth,
                # Per-question flag, not the running total across questions.
                "is_correct_at_k": any_correct,
                "num_correct": sum(correctness_flags),
                "effective_tokens": effective_tokens,
            }
        )

    return (
        total_correct,
        is_correct_at_k,
        k,
        total_processed,
        processed_items,
        total_effective_tokens,
    )

def parse_countdown_answers(json_path=None, json_data=None):
    """Score Countdown-style arithmetic answers for every question in a payload.

    Args:
        json_path: Path of a JSON file to load. Used whenever it is truthy.
        json_data: Already-loaded payload, used when ``json_path`` is falsy.

    The payload is a mapping whose ``"generations"`` entry lists one item per
    question. Each item is read for ``"question"``, ``"ground_truth"`` and a
    list of k raw ``"generations"`` strings. When ``ground_truth`` is a
    two-element list it is unpacked as ``(numbers, target)``; otherwise both
    are parsed from ``"Numbers: [...]"`` / ``"Target: N"`` in the question.

    Per raw generation the equation comes from
    ``remove_boxed(last_boxed_only_string(raw))``; if that raises, from the
    ``<answer>`` tag, else the raw text. LaTeX operators are replaced, a
    right-hand side after ``=`` is dropped, and the equation is correct when
    it uses exactly the available numbers and evaluates to within ``1e-5``
    of the target.

    Returns:
        A tuple ``(total_correct, is_correct_at_k, k, total_processed,
        processed_items, total_effective_tokens)``:

        * ``total_correct`` - number of correct generations overall;
        * ``is_correct_at_k`` - number of questions with at least one
          correct generation;
        * ``k`` - number of generations of the last question processed
          (``0`` when the payload has no questions);
        * ``total_processed`` - number of questions;
        * ``processed_items`` - one detail dict per question; its
          ``"is_correct_at_k"`` entry is that question's own boolean flag;
        * ``total_effective_tokens`` - sum over questions of the smallest
          per-generation token count.
    """
    if json_path:
        with open(json_path, "r") as file:
            data = json.load(file)
    else:
        data = json_data

    total_correct = 0
    total_processed = 0
    total_effective_tokens = 0
    processed_items = []
    is_correct_at_k = 0
    # Bound up front so the return statement is valid for an empty payload.
    k = 0

    def validate_equation(equation_str, available_numbers):
        """Return True when the equation uses exactly the available numbers, each once."""
        try:
            numbers_in_eq = [int(n) for n in re.findall(r"\d+", equation_str)]
            return sorted(numbers_in_eq) == sorted(available_numbers)
        except (TypeError, ValueError):
            # Malformed equation or number list: treat as invalid instead of
            # re-running the expression that just failed.
            return False

    def evaluate_equation(equation_str):
        """Evaluate the arithmetic equation; return ``float("Inf")`` on any failure."""
        try:
            allowed_pattern = r"^[\d+\-*/().\s]+$"
            if not re.match(allowed_pattern, equation_str):
                raise ValueError("Invalid characters")
            # NOTE(security): this evaluates model output. The character
            # whitelist plus empty builtins rules out names and calls, but
            # "**" is still expressible and can be arbitrarily expensive;
            # an ast-based evaluator would be safer.
            return eval(equation_str.strip(), {"__builtins__": None}, {})
        except Exception:
            return float("Inf")

    for item in data.get("generations", []):
        total_processed += 1

        question = item.get("question", "")
        ground_truth = item.get("ground_truth", [])
        generations = item.get("generations", [])  # list of k generations

        # Shortest generation of the k samples; 0 when there are none.
        effective_tokens = min(
            count_effective_tokens(gen) for gen in generations
        ) if generations else 0
        total_effective_tokens += effective_tokens

        # Available numbers and target: from ground_truth when it is a pair,
        # otherwise parsed out of the question text.
        numbers = []
        target = None

        if isinstance(ground_truth, list) and len(ground_truth) == 2:
            numbers, target = ground_truth
        else:
            numbers_match = re.search(
                r"Numbers: \[([\d, ]+)\]", question, re.IGNORECASE
            )
            if numbers_match:
                numbers = [int(x) for x in numbers_match.group(1).split(",")]

            target_match = re.search(r"Target: (\d+)", question, re.IGNORECASE)
            if target_match:
                target = int(target_match.group(1))

        extracted_equations = []
        correctness_flags = []

        for raw_generation in generations:
            try:
                equation = remove_boxed(
                    last_boxed_only_string(raw_generation)
                )
            except Exception:
                # No usable boxed expression: try the <answer> tag, then
                # fall back to the whole generation.
                answer_match = re.search(
                    r"<answer>(.*?)</answer>", raw_generation, re.DOTALL
                )
                if answer_match:
                    equation = answer_match.group(1).strip()
                else:
                    equation = raw_generation

            # Replace LaTeX operators with their Python equivalents.
            equation = (
                equation.replace(r"\div", "/")
                        .replace(r"\times", "*")
                        .replace(r"\cdot", "*")
            )

            # Keep only the left-hand side when the equation contains '='.
            eq_match = re.search(r"([0-9+\-*/() ]+)=", equation)
            if eq_match:
                equation = eq_match.group(1).strip()

            is_correct = False
            if validate_equation(equation, numbers):
                result = evaluate_equation(equation)
                if target is not None and abs(result - target) < 1e-5:
                    is_correct = True

            if is_correct:
                total_correct += 1

            extracted_equations.append(equation)
            correctness_flags.append(is_correct)

        any_correct = any(correctness_flags)
        if any_correct:
            is_correct_at_k += 1

        k = len(generations)

        processed_items.append(
            {
                "question": question,
                "raw_generations": generations,
                "extracted_equations": extracted_equations,
                "ground_truth": ground_truth,
                # Per-question flag, not the running total across questions.
                "is_correct_at_k": any_correct,
                "num_correct": sum(correctness_flags),
                "effective_tokens": effective_tokens,
            }
        )

    return (
        total_correct,
        is_correct_at_k,
        k,
        total_processed,
        processed_items,
        total_effective_tokens,
    )

def parse_sudoku_answers(json_path=None, json_data=None):
    """Score 4x4 Sudoku answers (16-character grids) for every question in a payload.

    Args:
        json_path: Path of a JSON file to load. Used whenever it is truthy.
        json_data: Already-loaded payload, used when ``json_path`` is falsy.

    The payload is a mapping whose ``"generations"`` entry lists one item per
    puzzle. Each item is read for ``"question"``, ``"ground_truth"`` and a
    list of k raw ``"generations"`` strings. The puzzle is the first 16
    characters of the question when they are all digits, otherwise the
    16 digits following ``"Sudoku puzzle: "``; ``"0"`` marks an empty cell.

    Each generation's solution is taken from the first matching extraction
    pattern, stripped of whitespace, then zero-padded or truncated to 16
    characters. Only cells that were empty in the puzzle are compared with
    ``ground_truth``. Per puzzle, the best generation's cell count is kept,
    and the puzzle is solved when any generation gets every empty cell right.

    Returns:
        A tuple ``(total_correct_cells, is_correct_at_k, k,
        total_empty_cells, processed_items, total_effective_tokens)``:

        * ``total_correct_cells`` - sum of each puzzle's best cell count;
        * ``is_correct_at_k`` - number of puzzles solved by some generation;
        * ``k`` - number of generations of the last puzzle processed
          (``0`` when the payload has no puzzles);
        * ``total_empty_cells`` - empty cells summed over all puzzles;
        * ``processed_items`` - one detail dict per puzzle;
        * ``total_effective_tokens`` - sum over puzzles of the smallest
          per-generation token count.

    Raises:
        AssertionError: if no 16-character puzzle string can be found in a
            question.
    """
    if json_path:
        with open(json_path, "r") as file:
            data = json.load(file)
    else:
        data = json_data

    # Extraction patterns, tried in order; the first non-blank capture wins.
    patterns = [
        r"<answer>.*?```\s*([\d\s]+)```",
        r"<answer>(.*?)(?:<\|eot_id\|>|<\|endoftext\|>|</answer>)",
        r"</answer>\s*(.*?)(?:<\|eot_id\|>|<\|endoftext\|>|$)",
        r".*?(\d{16})\s*</answer>",
        r"\b(\d{16})\b",
    ]

    total_correct_cells = 0
    total_empty_cells = 0
    total_processed = 0
    total_effective_tokens = 0
    is_correct_at_k = 0
    processed_items = []
    # Bound up front so the return statement is valid for an empty payload.
    k = 0

    for item in data.get("generations", []):
        total_processed += 1

        question = item.get("question", "")
        ground_truth = item.get("ground_truth", "")
        generations = item.get("generations", [])

        # Shortest generation of the k samples; 0 when there are none.
        effective_tokens = min(
            count_effective_tokens(gen) for gen in generations
        ) if generations else 0
        total_effective_tokens += effective_tokens

        # Locate the 16-character puzzle string.
        puzzle_str = ""
        if len(question) >= 16 and question[:16].isdigit():
            puzzle_str = question[:16]
        else:
            match = re.search(r"Sudoku puzzle: ([0-9]{16})", question)
            if match:
                puzzle_str = match.group(1)
        if len(puzzle_str) != 16:
            # Raised explicitly (same exception type as the former assert)
            # so the check still runs under ``python -O``.
            raise AssertionError(f"Invalid puzzle string: {puzzle_str}")

        empty_indices = [i for i, c in enumerate(puzzle_str) if c == "0"]
        empty_cells = len(empty_indices)
        total_empty_cells += empty_cells

        best_correct_cells = 0
        solved = False
        extracted_solutions = []
        per_gen_correct_cells = []

        for raw_generation in generations:
            solution_str = ""
            for pattern in patterns:
                match = re.search(pattern, raw_generation, re.DOTALL)
                if match and match.group(1).strip():
                    solution_str = match.group(1).strip()
                    break

            solution_str = re.sub(r"\s", "", solution_str)

            if not solution_str:
                correct_cells = 0
            else:
                # Normalise to exactly 16 characters before indexing.
                if len(solution_str) < 16:
                    solution_str += "0" * (16 - len(solution_str))
                elif len(solution_str) > 16:
                    solution_str = solution_str[:16]

                correct_cells = sum(
                    1 for i in empty_indices if solution_str[i] == ground_truth[i]
                )

            extracted_solutions.append(solution_str)
            per_gen_correct_cells.append(correct_cells)

            best_correct_cells = max(best_correct_cells, correct_cells)
            if correct_cells == empty_cells:
                solved = True

        if solved:
            is_correct_at_k += 1
        total_correct_cells += best_correct_cells

        k = len(generations)

        processed_items.append({
            "question": question,
            "raw_generations": generations,
            "extracted_solutions": extracted_solutions,
            "ground_truth": ground_truth,
            "empty_cells": empty_cells,
            "best_correct_cells": best_correct_cells,
            "per_generation_correct_cells": per_gen_correct_cells,
            "best_accuracy": best_correct_cells / empty_cells if empty_cells > 0 else 0.0,
            "is_solved_at_k": solved,
            "effective_tokens": effective_tokens,
        })

    return (
        total_correct_cells,
        is_correct_at_k,
        k,
        total_empty_cells,
        processed_items,
        total_effective_tokens,
    )



def extract_setup_name(filename):
    """Extract the setup name from a ``<setup>_generations.json`` filename.

    Returns the ``<setup>`` prefix, or ``None`` when *filename* does not
    follow that pattern.
    """
    match = re.match(r"^(.+)_generations\.json$", filename)
    return match.group(1) if match else None


def _select_parser(setup_name):
    """Return ``(parser, cell_level)`` for *setup_name*, or ``(None, False)`` if unknown.

    ``cell_level`` is True for Sudoku, whose parser reports correct/empty
    cell counts (already best-of-k) rather than per-generation correctness.
    """
    if "gsm" in setup_name:
        return parse_gsm_answers, False
    # Each keyword needs its own membership test: `"aime" or ...` is always
    # truthy and would send every remaining setup to the math parser.
    if "aime" in setup_name or "math" in setup_name:
        return parse_math_answers, False
    if "countdown" in setup_name:
        return parse_countdown_answers, False
    if "sudoku" in setup_name:
        return parse_sudoku_answers, True
    return None, False


def aggregate_results(directory="."):
    """Aggregate every ``*_generations.json`` file in *directory* and print a summary table.

    Each file is routed to a task parser chosen by keywords in its setup name
    (``gsm``, ``aime``/``math``, ``countdown``, ``sudoku``); files with an
    unrecognised setup name are reported and skipped.

    Printed metrics per setup:
        * Accuracy - correct generations over all generations (``k`` per
          question); for Sudoku, best-of-k correct cells over empty cells.
        * Pass@k - share of questions with at least one correct generation.
        * Avg Effective Tokens - mean per-question effective token count.

    Args:
        directory: Directory searched (non-recursively) for result files.

    Returns:
        None. The table is written to stdout.
    """
    # Find all JSON files matching the pattern
    json_files = glob.glob(os.path.join(directory, "*_generations.json"))
    print(json_files)

    # Aggregated results by setup
    setups = defaultdict(
        lambda: {
            "correct": 0,
            "pass@k": 0,
            "processed": 0,
            "accuracy": 0.0,
            "questions": [],
            "total_effective_tokens": 0,
            "k": 0,
            "attempts": 0,  # denominator for accuracy
        }
    )

    for json_file in json_files:
        filename = os.path.basename(json_file)
        setup_name = extract_setup_name(filename)
        if not setup_name:
            continue

        parser, cell_level = _select_parser(setup_name)
        if parser is None:
            # Skip instead of reusing the previous file's numbers.
            print(f"Skipping {filename}: no parser for setup '{setup_name}'")
            continue

        (
            correct,
            is_correct_at_k,
            k,
            processed,
            detailed_results,
            total_effective_tokens,
        ) = parser(json_path=json_file)

        entry = setups[setup_name]
        entry["correct"] += correct
        entry["processed"] += processed
        entry["total_effective_tokens"] += total_effective_tokens
        entry["questions"].extend(detailed_results)
        entry["pass@k"] += is_correct_at_k
        entry["k"] = k
        # Sudoku's `correct`/`processed` are cell counts already reduced to
        # best-of-k; the other tasks count every one of the k generations.
        entry["attempts"] += processed if cell_level else k * processed

    # Calculate final metrics; question count comes from the detail records
    # because Sudoku's `processed` counts cells, not questions.
    for setup in sorted(setups):
        results = setups[setup]
        num_questions = len(results["questions"])
        attempts = results["attempts"]
        results["accuracy"] = (
            results["correct"] / attempts * 100 if attempts > 0 else 0
        )
        results["pass@k"] = (
            results["pass@k"] / num_questions * 100 if num_questions > 0 else 0
        )
        results["avg_effective_tokens"] = (
            results["total_effective_tokens"] / num_questions
            if num_questions > 0
            else 0
        )

    # Header: one placeholder per column, so no value is silently dropped.
    header_format = "{:<40} {:>12} {:>12} {:>25}"
    print(
        header_format.format(
            "Setup (task_model_seqlen_diffusteps)",
            "Accuracy",
            "Pass@k",
            "Avg Effective Tokens",
        )
    )
    print("-" * 92)

    # Data rows
    row_format = "{:<40} {:>11.2f}% {:>11.2f}% {:>25.2f}"
    for setup in sorted(setups):
        results = setups[setup]
        print(
            row_format.format(
                setup,
                results["accuracy"],
                results["pass@k"],
                results["avg_effective_tokens"],
            )
        )

    print("=" * 92)

if __name__ == "__main__":
    import sys

    # Results directory: first command-line argument when given, otherwise
    # the previous hard-coded default.
    aggregate_results(directory=sys.argv[1] if len(sys.argv) > 1 else "results3")

