import argparse
import json
import re
import time
import os
from collections import Counter
from dataclasses import dataclass
from datetime import datetime
from typing import List, Callable, Optional, Dict
from pathlib import Path
import sys

# Add project root to sys.path so the imports below (e.g. llm_client) resolve
# when this file is run directly as a script, from either this file's
# directory or its parent.
current_dir = Path(__file__).parent
project_root = current_dir.parent
# Both entries are prepended; current_dir is inserted last, so it ends up at
# sys.path[0] and modules next to this file take precedence over same-named
# modules in the project root.
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(current_dir))

from tqdm import tqdm
from llm_client import create_llm_client

# ===== Basic evaluation metrics =====

def pass_at_k(results: List[bool], k: int) -> float:
    """
    Compute pass@k for a single question.

    Only the first ``k`` attempts are inspected. The result is 1.0 if any of
    them succeeded and 0.0 otherwise. A ``k`` larger than the number of
    attempts is clamped to ``len(results)``, and ``k <= 0`` always yields 0.0.
    Averaging this per-question indicator over a dataset gives the empirical
    pass@k rate.

    Args:
        results: Attempt outcomes in order (truthy = success)
        k: Number of leading attempts to consider

    Returns:
        1.0 if one of the first k attempts succeeded, else 0.0

    Example:
        >>> pass_at_k([False, False, True], 2)
        0.0
        >>> pass_at_k([False, False, True], 3)
        1.0
    """
    if k <= 0:
        return 0.0

    window = results[:min(k, len(results))]
    succeeded = any(window)
    return float(succeeded)

def avg_at_k(scores: List[float], k: int) -> float:
    """
    Compute avg@k.

    avg@k is the mean score over the first ``k`` attempts. A ``k`` larger than
    the number of attempts is clamped to ``len(scores)``.

    Args:
        scores: List of scores (e.g., accuracy, BLEU)
        k: Number of attempts to consider

    Returns:
        Average of the first k scores. Returns 0.0 when ``k <= 0`` or when
        ``scores`` is empty (a question with no attempts).

    Example:
        >>> scores = [0.8, 0.6, 0.9, 0.7, 0.5]
        >>> avg_at_k(scores, 3)  # average over first 3
        0.7666666666666667
    """
    # An empty score list would clamp k to 0 and divide by zero below.
    if k <= 0 or not scores:
        return 0.0

    k = min(k, len(scores))
    return sum(scores[:k]) / k

# ===== Batch evaluation metrics =====

def batch_pass_at_k(all_results: List[List[bool]], k: int) -> float:
    """
    Average the per-question pass@k indicator over a dataset.

    Args:
        all_results: One list of attempt outcomes per question
        k: Number of leading attempts to consider for each question

    Returns:
        Mean of ``pass_at_k(results, k)`` over all questions, or 0.0 when
        ``all_results`` is empty
    """
    if not all_results:
        return 0.0

    indicators = [pass_at_k(attempts, k) for attempts in all_results]
    question_count = len(indicators)
    return sum(indicators) / question_count

def batch_avg_at_k(all_scores: List[List[float]], k: int) -> float:
    """
    Average the per-question avg@k over a dataset.

    Args:
        all_scores: One list of attempt scores per question
        k: Number of leading attempts to average for each question

    Returns:
        Mean of ``avg_at_k(scores, k)`` over all questions, or 0.0 when
        ``all_scores`` is empty
    """
    if not all_scores:
        return 0.0

    per_question_means = [avg_at_k(question_scores, k) for question_scores in all_scores]
    question_count = len(per_question_means)
    return sum(per_question_means) / question_count


# ===== LLM-as-Judge evaluation =====

def llm_as_judge(input_file: str, model: str = "gpt-4o-mini-0718-global"):
    """Grade model responses against ground truth using an LLM as judge.

    ``input_file`` is a JSON list of samples. For every sample, each entry of
    ``sample["response_ours"]`` is reduced to its text after ``</think>`` and
    sent to the judge together with the first user message of
    ``sample["chat"]`` and ``sample["gt"]``. The judge must answer "Yes" or
    "No". Any other final reply is recorded as an unexpected judgement and
    counted as wrong. A sample passes when at least one of its responses is
    judged "Yes". Summary metrics are printed and a timestamped results file
    is written via ``_calculate_and_save_results``.

    Args:
        input_file: Path to input JSON file
        model: Judge model name passed to ``create_llm_client``
    """
    client = create_llm_client(model=model)

    # Load data
    with open(input_file, 'r', encoding='utf-8') as f:
        responses = json.load(f)

    # Initialize accumulators
    correct = 0
    unexpected_judge = []
    evaluation_results = []
    all_judgement_results = []
    all_response_contents = []

    print(f"Start evaluating {len(responses)} samples...")

    def _extract_after_think(text: str, tag: str = "</think>") -> str:
        """Return the text after the first ``tag``, or a placeholder if it is missing.

        A response without the closing tag is treated as never having produced
        a final answer. It is replaced by a fixed placeholder that the judge
        will grade against the ground truth. Non-string input is returned
        unchanged.
        """
        if not isinstance(text, str):
            return text
        _, found, after = text.partition(tag)
        if found:
            return after.strip()
        return "too long to get valid prediction"

    # total= lets tqdm show a real percentage/ETA; enumerate() has no len().
    for idx, sample in tqdm(enumerate(responses), total=len(responses), desc="Evaluation progress"):
        response_list = sample["response_ours"]
        gt = sample["gt"]
        # The question is the first user turn; stays None if there is none.
        question = next(
            (msg['content'] for msg in sample["chat"] if msg['role'] == 'user'),
            None,
        )

        response_judgement = []
        response_contents = []
        judgement_results = []

        for response_idx, response in enumerate(response_list):
            # Use only the content after </think> as Prediction
            content = _extract_after_think(response)
            response_contents.append(content)
            # Build judge prompt
            prompt = f"""Task: Determine whether the Prediction expresses the same option as the Ground Truth.
**Question:** 
```
{question}
```

**Prediction:**
```
{content}
```

**Ground Truth:**
```
{gt}
```

Instructions:
1. Focus exclusively on semantic equivalence.
2. Disregard differences in wording, phrasing, and detail.
3. If the core options match, output "Yes"; if they differ in any essential aspect, output "No".

Response Format: Output exactly one word: either "Yes" or "No". Do not include any additional text."""
            messages = [
                {"role": "system", "content": "You are a precise evaluator that only outputs Yes or No."},
                {"role": "user", "content": prompt}
            ]
            # Show the very first judge request once, for sanity-checking the prompt.
            if idx == 0 and response_idx == 0:
                print("=============================Judge Messages Case=============================")
                for mi, m in enumerate(messages):
                    print(f"[{mi}] {m['role'].upper()}\n{m['content']}\n---")
            judgement_result = _evaluate_with_retry(client, messages, idx, response_idx)
            judgement_text = judgement_result["judgement_text"]

            # Record results; anything other than Yes/No (e.g. "Invalid" after
            # exhausting retries) is logged and counted as a wrong answer.
            if judgement_text not in ["Yes", "No"]:
                unexpected_judge.append(judgement_text)
                is_correct = False
                final_judgement_text = "No"
            else:
                print(f"\nSample {idx+1}, response {response_idx+1}: {judgement_text}")
                is_correct = judgement_text == "Yes"
                final_judgement_text = judgement_text
            judgement_results.append(is_correct)

            # Save judgement details
            response_judgement.append({
                "content": final_judgement_text,
                "original_content": judgement_result.get("original_content", ""),
                "model": judgement_result.get("model", client.model),
                "prompt_tokens": judgement_result.get("prompt_tokens", 0),
                "completion_tokens": judgement_result.get("completion_tokens", 0),
                "total_tokens": judgement_result.get("total_tokens", 0),
                "response_time": judgement_result.get("response_time", 0),
                "retry_count": judgement_result.get("retry_count", 0),
                "is_correct": is_correct
            })

        # A sample passes if any of its responses was judged correct.
        sample_passed = any(judgement_results)
        if sample_passed:
            correct += 1

        # Store per-sample evaluation results
        evaluation_results.append({
            "sample_id": idx,
            "question": question,
            "ground_truth": gt,
            "responses": response_contents,
            "judgements": response_judgement,
            "judgement_results": judgement_results,
            "sample_passed": sample_passed
        })
        all_judgement_results.append(judgement_results)
        all_response_contents.append(response_contents)

    # Compute and output summary results
    _calculate_and_save_results(
        responses, correct, unexpected_judge, evaluation_results,
        all_judgement_results, all_response_contents, input_file, client.model
    )


def _evaluate_with_retry(client, messages, sample_idx, response_idx):
    """Evaluation helper with retry logic (messages version), auto-switching to fallback models after repeated failures.

    Logic:
    1. Retry the current model multiple times (default per_model_retry_limit attempts per model).
    2. If still no valid (Yes/No) response, switch to the next fallback model.
    3. Fallback model list can be specified via env var EVAL_FALLBACK_MODELS (comma-separated);
       otherwise a built-in default order is used.
    4. Returned result includes model_sequence of models used.
    """
    max_total_retries = 12  # total retry cap (across all models)
    per_model_retry_limit = 3  # per-model max attempts
    original_model = client.model  # record original model; restore after success/final

    # Read fallback model list (exclude the current initial model, keep order)
    env_fallback = os.getenv("EVAL_FALLBACK_MODELS", "")
    if env_fallback.strip():
        fallback_models = [m.strip() for m in env_fallback.split(',') if m.strip()]
    else:
        # Default list; adjust as needed. Any entry equal to the current model will be removed.
        fallback_models = [
            "gpt-41-mini-0414-global",
            "o4-mini-0416-global",
            "o3-mini-2025-01-31",
            "gpt-4o-1120-global",
            "gemini-2.5-pro-06-17",
            "gemini-2.5-flash-06-17",
            "claude_sonnet4",
            "claude37_sonnet",
        ]

    # Deduplicate and remove current model
    seen = set()
    cleaned_fallback = []
    for m in fallback_models:
        if m == client.model:
            continue
        if m not in seen:
            cleaned_fallback.append(m)
            seen.add(m)
    fallback_models = cleaned_fallback

    retry_count_total = 0
    retry_count_current_model = 0
    model_sequence = [client.model]
    fallback_index = 0

    while retry_count_total < max_total_retries:
        try:
            judgement = client.generate_response(messages=messages, temperature=0.1)
            judgement_text = judgement.content.strip()
            if judgement_text in ["Yes", "No"]:
                used_model = client.model  # model actually used
                # Restore to original model upon success
                if client.model != original_model:
                    client.model = original_model
                return {
                    "judgement_text": judgement_text,
                    "original_content": judgement.content,
                    "model": used_model,
                    "prompt_tokens": getattr(judgement, 'prompt_tokens', 0),
                    "completion_tokens": getattr(judgement, 'completion_tokens', 0),
                    "total_tokens": getattr(judgement, 'total_tokens', 0),
                    "response_time": getattr(judgement, 'response_time', 0),
                    "retry_count": retry_count_total,
                    "model_sequence": model_sequence,
                }
            else:
                retry_count_total += 1
                retry_count_current_model += 1
                print(
                    f"Sample {sample_idx+1}, response {response_idx+1}, model {client.model}, "
                    f"attempt {retry_count_current_model}: invalid response '{judgement_text}'"
                )
        except Exception as e:
            retry_count_total += 1
            retry_count_current_model += 1
            print(
                f"Sample {sample_idx+1}, response {response_idx+1}, model {client.model}, "
                f"attempt {retry_count_current_model}: request failed {str(e)}"
            )

        # Decide whether to switch model
        should_switch = (
            retry_count_current_model >= per_model_retry_limit
            and fallback_index < len(fallback_models)
            and retry_count_total < max_total_retries
        )
        if should_switch:
            new_model = fallback_models[fallback_index]
            fallback_index += 1
            print(f"--> Switching to fallback model: {new_model} (previous: {client.model})")
            client.model = new_model  # directly update client model
            model_sequence.append(new_model)
            retry_count_current_model = 0

        if retry_count_total < max_total_retries:
            # Simple backoff: slightly longer wait after repeated attempts on the same model
            sleep_time = 1 if retry_count_current_model == 0 else 1.5
            time.sleep(sleep_time)

    print(
        f"Warning: Sample {sample_idx+1}, response {response_idx+1}: retried across models "
        f"{max_total_retries} times with no valid response; models used: {model_sequence}"
    )
    # On failure also restore to original model if it was switched
    if client.model != original_model:
        client.model = original_model
    return {
        "judgement_text": "Invalid",
        "original_content": "",
        "model": model_sequence[-1] if model_sequence else original_model,
        "retry_count": retry_count_total,
        "model_sequence": model_sequence,
        "final_client_model": client.model,
    }

def _calculate_and_save_results(responses, correct, unexpected_judge, evaluation_results,
                               all_judgement_results, all_response_contents, input_file, model_name):
    """Print summary accuracy and pass@k / avg@k metrics, then save everything.

    Args:
        responses: The loaded input samples (only its length is used)
        correct: Number of samples counted as passed
        unexpected_judge: Judge replies that were not "Yes"/"No"
        evaluation_results: Per-sample result dicts
        all_judgement_results: Per-sample lists of per-response booleans
        all_response_contents: Per-sample lists of extracted predictions
        input_file: Path of the evaluated input file (recorded in metadata)
        model_name: Judge model name (recorded in metadata)
    """
    total = len(responses)
    # An empty input file would otherwise raise ZeroDivisionError here.
    accuracy = correct / total if total else 0.0
    print("\n=== Basic Evaluation Results ===")
    print(f"Total samples: {total}")
    print(f"Correct samples: {correct}")
    print(f"Incorrect samples: {total - correct}")
    print(f"Base accuracy: {accuracy:.2%}")
    print(f"Unexpected judgement count: {len(unexpected_judge)}")
    if unexpected_judge:
        print(f"Unexpected judgements: {unexpected_judge}")

    # Compute metrics for various k values
    print("\n=== Metric Computation ===")
    k_values = [1, 2, 4, 8, 16]
    # Only report k up to the longest response list; default=0 keeps max()
    # from raising ValueError when there are no samples.
    max_attempts = max((len(results) for results in all_judgement_results), default=0)
    # Same outcomes as 1.0/0.0 floats for avg@k; loop-invariant, so built once.
    all_scores = [[float(result) for result in results] for results in all_judgement_results]

    for k in k_values:
        if k > max_attempts:
            continue
        pass_k = batch_pass_at_k(all_judgement_results, k)
        avg_k = batch_avg_at_k(all_scores, k)
        print(f"pass@{k}: {pass_k:.4f}")
        print(f"avg@{k}: {avg_k:.4f}")

    # Save results
    _save_evaluation_results(
        input_file, model_name, correct, total, accuracy, unexpected_judge,
        evaluation_results, all_judgement_results, all_response_contents, k_values
    )

def _save_evaluation_results(input_file, model_name, correct, total_samples, accuracy, unexpected_judge,
                           evaluation_results, all_judgement_results, all_response_contents, k_values):
    """Write LLM-as-judge results to ``evaluation_results_<timestamp>.json``.

    The path is relative, so the file lands in the current working directory.
    It contains:
    - run metadata and summary counts;
    - pass@k / avg@k for each k in ``k_values`` that does not exceed the
      longest per-sample judgement list;
    - the per-sample ``evaluation_results``.

    ``all_response_contents`` is accepted for interface compatibility but is
    not written.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"evaluation_results_{timestamp}.json"

    final_results = {
        "metadata": {
            "input_file": input_file,
            "evaluation_time": timestamp,
            "total_samples": total_samples,
            "model_used": model_name
        },
        "summary": {
            "correct_samples": correct,
            "incorrect_samples": total_samples - correct,
            "accuracy": accuracy,
            "unexpected_judgements": len(unexpected_judge),
            "unexpected_judgement_list": unexpected_judge
        },
        "metrics": {},
        "detailed_results": evaluation_results
    }

    # Add metrics for different k values. default=0 keeps max() from raising
    # ValueError when there are no samples (then no metrics are written).
    max_attempts = max((len(results) for results in all_judgement_results), default=0)
    all_scores = [[float(result) for result in results] for results in all_judgement_results]
    for k in k_values:
        if k > max_attempts:
            continue
        final_results["metrics"][f"pass@{k}"] = batch_pass_at_k(all_judgement_results, k)
        final_results["metrics"][f"avg@{k}"] = batch_avg_at_k(all_scores, k)

    # Save to file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, ensure_ascii=False, indent=2)

    print("\n=== Results Saved ===")
    print(f"Detailed evaluation results saved to: {output_file}")


def analyze_judgement_results(evaluation_file: str) -> Dict:
    """Analyze an evaluation results file and bucket samples by their "Yes" ratio.

    The file is a JSON object whose ``detailed_results`` list holds per-sample
    dicts with ``sample_id`` and ``judgements``, e.g. one written by
    ``_save_evaluation_results``. A judgement counts as "Yes" when its
    ``content`` is exactly "Yes" after stripping whitespace.

    Samples are bucketed by yes-ratio r:
    - '0%': r == 0
    - '1-25%': 0 < r <= 0.25
    - '26-50%': r <= 0.5
    - '51-75%': r <= 0.75
    - '76-99%': r < 1
    - '100%': r == 1

    Args:
        evaluation_file: Path to the evaluation results JSON file

    Returns:
        A dict with these keys:
        - ``file_info``
        - ``category_statistics``: count and sample ids per bucket
        - ``metrics``: pass@k / avg@k for each k in 1, 2, 4, 8, 16 that does
          not exceed the longest judgement list
        - ``sample_details``
    """
    with open(evaluation_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    detailed_results = data.get('detailed_results', [])

    # Initialize categories
    categories = {
        '0%': [], '1-25%': [], '26-50%': [],
        '51-75%': [], '76-99%': [], '100%': []
    }

    all_judgement_results = []
    sample_analysis = []

    for sample in detailed_results:
        sample_id = sample.get('sample_id', 0)
        judgements = sample.get('judgements', [])

        # 1 for a "Yes" verdict, 0 otherwise; computed once and reused for the
        # counts. `or ''` also tolerates an explicit null content.
        judgement_results = [
            1 if (j.get('content') or '').strip() == 'Yes' else 0 for j in judgements
        ]
        yes_count = sum(judgement_results)
        total_count = len(judgement_results)
        yes_ratio = yes_count / total_count if total_count > 0 else 0

        # Categorize
        if yes_ratio == 0:
            category = '0%'
        elif yes_ratio <= 0.25:
            category = '1-25%'
        elif yes_ratio <= 0.50:
            category = '26-50%'
        elif yes_ratio <= 0.75:
            category = '51-75%'
        elif yes_ratio < 1.0:
            category = '76-99%'
        else:
            category = '100%'

        categories[category].append(sample_id)
        all_judgement_results.append(judgement_results)

        sample_analysis.append({
            'sample_id': sample_id,
            'total_judgements': total_count,
            'yes_count': yes_count,
            'yes_ratio': yes_ratio,
            'category': category,
            'judgement_results': judgement_results
        })

    # Compute metrics. default=0 keeps max() from raising ValueError when
    # there are no samples or every judgement list is empty.
    k_values = [1, 2, 4, 8, 16]
    metrics = {}
    max_attempts = max((len(results) for results in all_judgement_results), default=0)
    all_scores = [[float(result) for result in results] for results in all_judgement_results]

    for k in k_values:
        if k <= max_attempts:
            metrics[f'pass@{k}'] = batch_pass_at_k(all_judgement_results, k)
            metrics[f'avg@{k}'] = batch_avg_at_k(all_scores, k)

    # Stats per category
    category_stats = {
        category: {'count': len(sample_ids), 'sample_ids': sample_ids}
        for category, sample_ids in categories.items()
    }

    return {
        'file_info': {
            'source_file': evaluation_file,
            'total_samples': len(detailed_results)
        },
        'category_statistics': category_stats,
        'metrics': metrics,
        'sample_details': sample_analysis
    }


# ===== Reasoning evaluation related classes and functions =====

@dataclass
class ReasoningEvaluation:
    """Reasoning evaluation result for one response of one sample.

    Populated (e.g. in ``process_dataset``) from the dict returned by
    ``ReasoningEvaluator.evaluate_reasoning``.
    """
    sample_id: int  # index of the sample in the (possibly truncated) input list
    question: str  # question text extracted from the chat content
    reasoning_process: str  # response text after the closing </think> tag
    is_correct_reasoning: bool  # judge's "Correctness Assessment" parsed as Yes -> True
    is_independent: bool  # judge's "Independence Assessment" parsed as Yes -> True
    correct_reasoning_explanation: str  # judge's correctness explanation, or an error message
    independence_explanation: str  # judge's independence explanation, or an error message
    evaluation_time: float  # wall-clock seconds (time.time() delta) around the judge request


class ReasoningEvaluator:
    """Use an LLM judge to grade a response's correctness and independence.

    The judge answers two Yes/No questions: does the reasoning correctly lead
    to the answer, and does it avoid relying on the reference answer. It also
    gives an explanation for each, in a fixed "Label: value" format that the
    ``_parse_*`` helpers read back.
    """

    # Section labels the evaluation prompt asks the judge to emit; used to find
    # where one explanation ends and the next section begins.
    _SECTION_LABELS = (
        "Correctness Assessment",
        "Correctness Explanation",
        "Independence Assessment",
        "Independence Explanation",
    )

    def __init__(self, model: str = "gemini-2.5-pro-06-17", api_key: Optional[str] = None):
        self.model = model
        self.client = create_llm_client(model=model, api_key=api_key)
        self._printed_example = False  # print evaluation prompt example once

    def extract_question_from_chat(self, chat_content: str) -> str:
        """Return the text between "**Question:**" and "**Reference Answer:**".

        Returns "Question extraction failed" when the markers are absent and an
        "Error extracting question: ..." string if extraction raises.
        """
        try:
            question_pattern = r'\*\*Question:\*\*(.*?)\*\*Reference Answer:\*\*'
            match = re.search(question_pattern, chat_content, re.DOTALL)
            return match.group(1).strip() if match else "Question extraction failed"
        except Exception as e:
            return f"Error extracting question: {str(e)}"

    def extract_reasoning_process(self, response: str, cot_solution_start_tag="</think>") -> str:
        """Return the stripped text after the first ``cot_solution_start_tag``.

        Returns "" when the tag is missing or the input cannot be searched.
        """
        try:
            think_end = response.find(cot_solution_start_tag)
            if think_end != -1:
                return response[think_end + len(cot_solution_start_tag):].strip()
            return ""
        except Exception:
            return ""

    def evaluate_reasoning(self, question: str, reasoning_process: str, correct_answer: str) -> Dict:
        """Evaluate the correctness and independence of a reasoning process.

        Returns a dict with these keys:
        - ``is_correct_reasoning`` and ``is_independent`` (bools)
        - both explanations
        - ``evaluation_time`` (seconds)
        - ``raw_evaluation`` and ``evaluation_prompt``

        If the request fails, both flags are False and the explanations carry
        the error message.
        """
        evaluation_prompt = f"""Please evaluate the following reasoning process on two aspects:

**Question:**
```
{question}
```

**Reasoning Process:**
```
{reasoning_process}
```

**Reference Answer:**
```
{correct_answer}
```

Please evaluate separately:

1. **Reasoning Correctness**: Does this reasoning process correctly lead to the answer?
   - Judge whether the reasoning logic is correct
   - Judge whether each step is reasonable and coherent
   - Judge whether the final conclusion is correct

2. **Reasoning Independence**: Does this reasoning process not rely on the reference answer?
   - Judge whether there are obvious traces of using the reference answer
   - Judge whether there are statements like "I know the answer is X"

Please answer in the following format:

Correctness Assessment: [Yes/No]
Correctness Explanation: [Detailed explanation of why you think the reasoning is correct or incorrect]

Independence Assessment: [Yes/No]
Independence Explanation: [Detailed explanation of why you think the reasoning is independent or not independent]

Please strictly follow the above format and only answer "Yes" or "No", do not include any other expressions."""

        messages = [
            {"role": "system", "content": "You are a professional reasoning process evaluation expert who can objectively evaluate the correctness and independence of reasoning."},
            {"role": "user", "content": evaluation_prompt}
        ]

        # Print one example request per evaluator. The flag is set only after
        # both parts are printed (previously the messages part was unreachable).
        if not self._printed_example:
            print("=============================Reasoning Evaluation Prompt Case=============================")
            print(evaluation_prompt)
            print("=============================Reasoning Messages Case=============================")
            for mi, m in enumerate(messages):
                print(f"[{mi}] {m['role'].upper()}\n{m['content'][:1000]}\n---")
            self._printed_example = True

        start_time = time.time()

        try:
            response = self.client.generate_response(
                messages=messages,
                temperature=0.1,
                max_tokens=30000
            )

            evaluation_time = time.time() - start_time
            evaluation_text = response.content.strip()

            # Parse evaluation result
            is_correct = self._parse_evaluation_result(evaluation_text, "Correctness Assessment")
            correct_explanation = self._parse_explanation(evaluation_text, "Correctness Explanation")
            is_independent = self._parse_evaluation_result(evaluation_text, "Independence Assessment")
            independence_explanation = self._parse_explanation(evaluation_text, "Independence Explanation")

            return {
                "is_correct_reasoning": is_correct,
                "is_independent": is_independent,
                "correct_reasoning_explanation": correct_explanation,
                "independence_explanation": independence_explanation,
                "evaluation_time": evaluation_time,
                "raw_evaluation": evaluation_text,
                "evaluation_prompt": evaluation_prompt
            }

        except Exception as e:
            return {
                "is_correct_reasoning": False,
                "is_independent": False,
                "correct_reasoning_explanation": f"Evaluation failed: {str(e)}",
                "independence_explanation": f"Evaluation failed: {str(e)}",
                "evaluation_time": time.time() - start_time,
                "raw_evaluation": str(e),
                "evaluation_prompt": evaluation_prompt
            }

    def _parse_evaluation_result(self, text: str, label: str) -> bool:
        """Parse the Yes/No value after ``label:`` in ``text``; default False."""
        try:
            # Normalize full-width colon to ASCII to avoid locale-specific punctuation
            text_norm = text.replace('\uFF1A', ':')
            label_norm = label.replace('\uFF1A', ':')
            pattern = rf"{re.escape(label_norm)}\s*:\s*(Yes|No)"
            match = re.search(pattern, text_norm, re.IGNORECASE)
            if match:
                return match.group(1).lower() == "yes"

            # Fallback: a line mentioning the label and exactly one of yes/no.
            for line in text_norm.split('\n'):
                if label_norm in line:
                    lowered = line.lower()
                    if "yes" in lowered and "no" not in lowered:
                        return True
                    elif "no" in lowered and "yes" not in lowered:
                        return False

            return False
        except Exception:
            # Best-effort parser: malformed input counts as "No".
            return False

    def _parse_explanation(self, text: str, label: str) -> str:
        """Return the (possibly multi-line) explanation following ``label:``."""
        try:
            # Normalize punctuation
            text_norm = text.replace('\uFF1A', ':')
            label_norm = label.replace('\uFF1A', ':')
            # The explanation runs until the next known section label (optionally
            # preceded by markdown such as "**" or "#") or the end of the text,
            # so multi-line explanations are kept whole.
            next_section = "|".join(re.escape(name) for name in self._SECTION_LABELS)
            pattern = (
                rf"{re.escape(label_norm)}\s*:\s*(.*?)"
                rf"(?=\n[\s*#]*(?:{next_section})[\s*]*:|\Z)"
            )
            match = re.search(pattern, text_norm, re.DOTALL)
            if match:
                return match.group(1).strip()

            # Fallback parse method
            lines = text_norm.split('\n')
            for i, line in enumerate(lines):
                if label_norm in line:
                    if ':' in line:
                        explanation = line.split(':', 1)[1].strip()
                        if len(explanation) < 10 and i + 1 < len(lines):
                            explanation += " " + lines[i + 1].strip()
                        return explanation

            return "Failed to parse explanation"
        except Exception:
            return "Parsing failed"


def process_dataset(input_file: str, output_file: str, max_samples: Optional[int] = None, model: Optional[str] = None):
    """Process a dataset: extract questions and reasoning processes, then evaluate them.

    For each sample:
    - The question is extracted from the first user turn of ``sample["chat"]``
      (falling back to ``chat[0]``).
    - Each entry of ``sample["response_ours"]`` is graded by
      ``ReasoningEvaluator`` against ``sample["gt"]``.
    - Responses whose text after ``</think>`` is shorter than 50 characters
      are skipped.
    - If an error occurs while processing a sample, it is printed and that
      sample's results are dropped.

    Args:
        input_file: Input reasoning results file (JSON list of samples)
        output_file: Output evaluation results file
        max_samples: Optional limit on number of samples to process; None or 0
            processes all samples
        model: Model to use for reasoning evaluation (overrides default gemini-2.5-pro-06-17)
    """
    print(f"Loading data from {input_file}...")
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if max_samples:
        data = data[:max_samples]

    print(f"Processing {len(data)} samples...")

    evaluator = ReasoningEvaluator(model=model or "gemini-2.5-pro-06-17")
    evaluation_results = []

    for idx, sample in enumerate(data):
        print(f"\nProcessing sample {idx + 1}/{len(data)}...")

        try:
            # Extract question and correct answer. Like llm_as_judge, read the
            # first user turn (chat[0] may be a system prompt); fall back to
            # chat[0] for chats without an explicit user role.
            chat = sample["chat"]
            user_turn = next(
                (msg for msg in chat if msg.get("role") == "user"),
                chat[0] if chat else None,
            )
            chat_content = user_turn["content"] if user_turn is not None else ""
            question = evaluator.extract_question_from_chat(chat_content)
            correct_answer = sample.get("gt", "")
            print(f"Question extracted: {question[:100]}...")

            # Process each response
            responses = sample.get("response_ours", [])
            sample_results = []

            for resp_idx, response in enumerate(responses):
                print(f"  Processing response {resp_idx + 1}/{len(responses)}...")

                # Extract reasoning process
                reasoning_process = evaluator.extract_reasoning_process(response)

                if len(reasoning_process) < 50:
                    print(f"    Skipping short reasoning process: {len(reasoning_process)} chars")
                    continue

                print(f"    Reasoning process extracted: {len(reasoning_process)} chars")

                # Evaluate reasoning process
                evaluation = evaluator.evaluate_reasoning(question, reasoning_process, correct_answer)

                # Create result object
                result = ReasoningEvaluation(
                    sample_id=idx,
                    question=question,
                    reasoning_process=reasoning_process,
                    is_correct_reasoning=evaluation["is_correct_reasoning"],
                    is_independent=evaluation["is_independent"],
                    correct_reasoning_explanation=evaluation["correct_reasoning_explanation"],
                    independence_explanation=evaluation["independence_explanation"],
                    evaluation_time=evaluation["evaluation_time"]
                )

                # Keep saved output compact: truncate long reasoning to 1000 chars.
                stored_reasoning = (
                    reasoning_process[:1000] + "..."
                    if len(reasoning_process) > 1000
                    else reasoning_process
                )
                sample_results.append({
                    "sample_id": idx,
                    "response_id": resp_idx,
                    "question": question,
                    "reasoning_process": stored_reasoning,
                    "is_correct_reasoning": result.is_correct_reasoning,
                    "is_independent": result.is_independent,
                    "correct_reasoning_explanation": result.correct_reasoning_explanation,
                    "independence_explanation": result.independence_explanation,
                    "evaluation_time": result.evaluation_time,
                    "raw_evaluation": evaluation.get("raw_evaluation", ""),
                    "evaluation_prompt": evaluation.get("evaluation_prompt", "")
                })

                print(f"    Correct: {result.is_correct_reasoning}, Independent: {result.is_independent}")
                time.sleep(1)  # API rate limit delay

            evaluation_results.extend(sample_results)

        except Exception as e:
            print(f"Error processing sample {idx}: {str(e)}")
            continue

    # Compute stats and save results
    _save_reasoning_evaluation_results(input_file, data, evaluation_results, evaluator, output_file)


def _save_reasoning_evaluation_results(input_file, data, evaluation_results, evaluator, output_file):
    """Save reasoning evaluation results."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    # Compute statistics
    correct_reasoning_samples = set(r["sample_id"] for r in evaluation_results if r["is_correct_reasoning"])
    independent_reasoning_samples = set(r["sample_id"] for r in evaluation_results if r["is_independent"])
    both_correct_and_independent_samples = set(r["sample_id"] for r in evaluation_results 
                                             if r["is_correct_reasoning"] and r["is_independent"])
    
    final_results = {
        "metadata": {
            "input_file": input_file,
            "total_samples": len(data),
            "processed_samples": len(evaluation_results),
            "evaluation_time": timestamp,
            "model_used": evaluator.client.model
        },
        "summary": {
            "total_evaluations": len(evaluation_results),
            "correct_reasoning_count": sum(1 for r in evaluation_results if r["is_correct_reasoning"]),
            "independent_reasoning_count": sum(1 for r in evaluation_results if r["is_independent"]),
            "both_correct_and_independent": sum(1 for r in evaluation_results 
                                              if r["is_correct_reasoning"] and r["is_independent"]),
            "correct_reasoning_samples_count": len(correct_reasoning_samples),
            "independent_reasoning_samples_count": len(independent_reasoning_samples),
            "both_correct_and_independent_samples_count": len(both_correct_and_independent_samples)
        },
        "detailed_results": evaluation_results
    }
    
    # Compute rates
    if evaluation_results:
        total = len(evaluation_results)
        summary = final_results["summary"]
        summary["correct_reasoning_rate"] = summary["correct_reasoning_count"] / total
        summary["independence_rate"] = summary["independent_reasoning_count"] / total
        summary["both_correct_and_independent_rate"] = summary["both_correct_and_independent"] / total
    
    # Save to file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, ensure_ascii=False, indent=2)
    
    print(f"\n=== Reasoning Evaluation Completed ===")
    print(f"Results saved to: {output_file}")
    print(f"Total evaluations: {final_results['summary']['total_evaluations']}")
    print(
        f"Correct reasoning count: {final_results['summary']['correct_reasoning_count']} "
        f"across {final_results['summary']['correct_reasoning_samples_count']} samples"
    )
    print(
        f"Independent reasoning count: {final_results['summary']['independent_reasoning_count']} "
        f"across {final_results['summary']['independent_reasoning_samples_count']} samples"
    )
    print(
        f"Both correct and independent: {final_results['summary']['both_correct_and_independent']} "
        f"across {final_results['summary']['both_correct_and_independent_samples_count']} samples"
    )
    
    if evaluation_results:
        print(f"Correct reasoning rate: {final_results['summary']['correct_reasoning_rate']:.2%}")
        print(f"Independence rate: {final_results['summary']['independence_rate']:.2%}")
        print(f"Both correct and independent rate: {final_results['summary']['both_correct_and_independent_rate']:.2%}")


# ===== Main function and CLI =====

def main():
    """CLI entry point: parse arguments and dispatch to the chosen evaluation mode.

    Subcommands:
        llm-judge            -- run ``llm_as_judge`` on an input file.
        eval_reasoning_trace -- run ``process_dataset`` and write results to ``--output``.
        analyze_eval_result  -- summarize a results file with
                                ``analyze_judgement_results`` and optionally
                                save the analysis as JSON.
        example              -- print sample values of the metric helpers.

    With no subcommand the help text is printed and the function returns.
    If a subcommand raises, the error and its traceback are printed and the
    process exits with status 1, so shell scripts and pipelines can detect
    the failure.
    """
    parser = argparse.ArgumentParser(
        description="LLM Evaluation Toolkit: multiple evaluation modes supported",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Evaluate using LLM-as-Judge with defaults
  python evaluation.py llm-judge --input data.json

  # Specify model/provider
  python evaluation.py llm-judge --input data.json --model gpt-4

  # Reasoning process evaluation
  python evaluation.py eval_reasoning_trace --input data.json --output results.json

  # Analyze an existing evaluation results file
  python evaluation.py analyze_eval_result --input evaluation_results.json
        """
    )

    # Subcommands
    subparsers = parser.add_subparsers(dest='command', help='Choose evaluation mode')

    # LLM-as-Judge subcommand
    judge_parser = subparsers.add_parser('llm-judge', help='Evaluate using an LLM as the judge')
    judge_parser.add_argument('--input', '-i', type=str, required=True,
                              help='Path to the input JSON data file')
    # %(default)s keeps the help text in sync with the actual default.
    judge_parser.add_argument('--model', '-m', type=str,
                              default='gpt-4o-mini-0718-global',
                              help='Model name to use (default: %(default)s)')

    # Reasoning process evaluation subcommand
    reasoning_parser = subparsers.add_parser('eval_reasoning_trace', help='Evaluate reasoning traces')
    reasoning_parser.add_argument('--input', '-i', type=str, required=True,
                                  help='Path to the input JSON data file')
    reasoning_parser.add_argument('--output', '-o', type=str, required=True,
                                  help='Path to the output results file')
    reasoning_parser.add_argument('--max-samples', type=int, default=None,
                                  help='Maximum number of samples to process (default: all)')
    reasoning_parser.add_argument('--model', '-m', type=str,
                                  default='gemini-2.5-pro-06-17',
                                  help='Model name to use (default: %(default)s)')

    # Results analysis subcommand
    analyze_parser = subparsers.add_parser('analyze_eval_result', help='Analyze evaluation results')
    analyze_parser.add_argument('--input', '-i', type=str, required=True,
                                help='Path to the evaluation results JSON file')
    analyze_parser.add_argument('--output', '-o', type=str, default=None,
                                help='Optional output path for analysis results')

    # Example subcommand (takes no arguments)
    subparsers.add_parser('example', help='Run example code')

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        return

    try:
        if args.command == 'llm-judge':
            print("Starting LLM-as-Judge evaluation...")
            print(f"Input file: {args.input}")
            print(f"Model: {args.model}")
            llm_as_judge(args.input, args.model)

        elif args.command == 'eval_reasoning_trace':
            print("Starting reasoning process evaluation...")
            print(f"Input file: {args.input}")
            print(f"Output file: {args.output}")
            print(f"Model: {args.model}")
            if args.max_samples is not None:
                print(f"Max samples: {args.max_samples}")
            process_dataset(args.input, args.output, args.max_samples, model=args.model)

        elif args.command == 'analyze_eval_result':
            print("Analyzing evaluation results...")
            print(f"Input file: {args.input}")
            results = analyze_judgement_results(args.input)

            print("\n=== Analysis Results ===")
            print(f"Total samples: {results['file_info']['total_samples']}")
            print("\nCategory stats:")
            for category, stats in results['category_statistics'].items():
                print(f"  {category}: {stats['count']} samples")

            print("\nEvaluation metrics:")
            for metric, value in results['metrics'].items():
                print(f"  {metric}: {value:.4f}")

            # Save analysis results to file if an output path is provided
            if args.output:
                with open(args.output, 'w', encoding='utf-8') as f:
                    json.dump(results, f, ensure_ascii=False, indent=2)
                print(f"\nAnalysis results saved to: {args.output}")

        elif args.command == 'example':
            print("=== Example Usage of Metric Functions ===")

            # Example data
            results = [True, False, True, False, True]
            scores = [0.8, 0.6, 0.9, 0.7, 0.5]

            print(f"pass@3: {pass_at_k(results, 3)}")
            print(f"avg@3: {avg_at_k(scores, 3):.3f}")

            # Batch computation example
            all_results = [[True, False, True], [False, True, True], [True, True, False]]
            print(f"batch pass@2: {batch_pass_at_k(all_results, 2):.3f}")

            print("=== Example complete ===")

    except Exception as e:
        print(f"Execution error: {e}")
        import traceback
        traceback.print_exc()
        # Signal failure to the calling shell instead of exiting with status 0.
        sys.exit(1)


if __name__ == "__main__":
    # Script entry: run the CLI and hand main()'s return value to sys.exit
    # (None -> exit status 0).
    sys.exit(main())
