#!/usr/bin/env python3
import argparse
import json
import re
import os
from pathlib import Path
from typing import Dict, Any, List, Optional
import glob


def extract_math_answer_from_output(text: str) -> Optional[str]:
    """Extract the final numeric answer from a model's output.

    Two strategies are tried in order:

    1. If the text contains "Answer:" (case-insensitive), take the last
       occurrence and return the first number that follows it.
    2. Otherwise (or if no number follows the marker), walk the lines from
       bottom to top and return the number that ends a line. Trailing
       punctuation and closing markup (e.g. "540.", "**42**", "20%") are
       ignored so that a final sentence is not skipped in favour of an
       intermediate result on an earlier line.

    "$" and "," are deleted before matching, so "$1,234.50" yields
    "1234.50".

    Args:
        text: Raw model output; may be empty or None.

    Returns:
        The number as a string (optional leading "-", digits, optional
        fractional part), or None when nothing numeric is found.
    """
    if not text:
        return None

    # Commas and dollar signs are stripped before matching, so the pattern
    # itself never needs to account for thousands separators.
    number = r'-?\d+(?:\.\d+)?'

    def _normalise(fragment: str) -> str:
        return fragment.replace('$', '').replace(',', '')

    # Strategy 1: the last explicit "Answer:" marker wins.
    marked = re.findall(r'Answer:\s*(.+)', text, re.IGNORECASE)
    if marked:
        found = re.search(number, _normalise(marked[-1].strip()))
        if found:
            return found.group(0)

    # Strategy 2: last line (bottom-up) that ends with a number.
    # A single end-anchored pattern covers "a * b = 180", a bare "180" and
    # "... is 180" alike, because re.search returns the leftmost start from
    # which a number runs to the end of the line.
    trailing_noise = ' \t.!?;:)]}*_`"\'%'
    ends_with_number = re.compile(number + r'$')
    for raw_line in reversed(text.strip().split('\n')):
        candidate = _normalise(raw_line.strip()).rstrip(trailing_noise)
        if not candidate:
            continue
        found = ends_with_number.search(candidate)
        if found:
            return found.group(0)

    return None


def check_math_correctness(predicted: Optional[str], gold: str) -> bool:
    """Check whether a predicted answer matches the gold answer.

    Both values are stripped of "$", "," and surrounding whitespace and then
    compared numerically, so a gold answer written as "1,000" or "$18"
    matches a prediction of "1000" or "18". If either value still cannot be
    parsed as a float, the original strings are compared case-insensitively
    instead.

    Args:
        predicted: Extracted answer, or None when extraction failed.
        gold: Reference answer as a string.

    Returns:
        True when the answers agree, False otherwise (always False for a
        None prediction).
    """
    if predicted is None:
        return False

    def _clean(value: str) -> str:
        # Drop currency symbols and thousands separators before float().
        return value.replace('$', '').replace(',', '').strip()

    try:
        pred_num = float(_clean(predicted))
        gold_num = float(_clean(gold))
    except ValueError:
        # Non-numeric answer: fall back to a plain string comparison.
        return predicted.strip().lower() == gold.strip().lower()

    # Small absolute tolerance to absorb floating point representation error.
    return abs(pred_num - gold_num) < 1e-9


def load_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its records in file order.

    Lines that are empty or whitespace-only are skipped; every other line is
    stripped and parsed with ``json.loads``.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(payload) for payload in stripped_lines if payload]


def save_extracted_answers(data: List[Dict[str, Any]], model_type: str, cot_mode: str, source_file: str):
    """Write the answers extracted from each record to a debug JSON file.

    The file is created next to ``source_file`` and named
    ``extracted_answers_<source stem>.json``. Each entry holds the record id
    ('unknown' if absent), the gold answer as a string, and the extracted
    answer for every output variant present on the record. ``model_type``
    and ``cot_mode`` are accepted but not used here.
    """
    origin = Path(source_file)
    destination = origin.parent / f"extracted_answers_{origin.stem}.json"

    def _summarise(record: Dict[str, Any]) -> Dict[str, Any]:
        gold = str(record['gold_answer'])
        summary = {
            'id': record.get('id', 'unknown'),
            'gold_answer': gold
        }

        # Single-output variants keep their original field name.
        for field in ('baseline_output', 'primed_output'):
            if field in record:
                summary[field] = extract_math_answer_from_output(record[field])

        # Persona outputs are flattened into one key per persona.
        if 'persona_outputs' in record:
            for name, output in record['persona_outputs'].items():
                summary[f'persona_{name}'] = extract_math_answer_from_output(output)

        return summary

    summaries = [_summarise(record) for record in data]

    with open(destination, 'w', encoding='utf-8') as sink:
        json.dump(summaries, sink, indent=2)

    print(f"Saved extracted answers to {destination}")


def evaluate_math_accuracy(data: List[Dict[str, Any]]) -> Dict[str, Dict]:
    """Compute per-variant accuracy over a list of generation records.

    For every record, the answer is extracted from each output variant that
    is present ('baseline_output', 'primed_output', and each entry of
    'persona_outputs') and compared against the record's 'gold_answer'.

    Returns an empty dict for empty input. Otherwise returns
    ``{'total_items': N, 'accuracy': {variant: {'correct', 'total',
    'accuracy'}}}`` where 'accuracy' is a percentage. The denominator is
    always the total number of records, including records that lack a given
    variant.
    """
    n_items = len(data)
    if n_items == 0:
        return {}

    fixed_fields = (('baseline', 'baseline_output'), ('primed', 'primed_output'))
    fixed_hits = {'baseline': 0, 'primed': 0}
    persona_hits = {}

    for record in data:
        gold = str(record['gold_answer'])

        for variant, field in fixed_fields:
            if field not in record:
                continue
            guess = extract_math_answer_from_output(record[field])
            if check_math_correctness(guess, gold):
                fixed_hits[variant] += 1

        if 'persona_outputs' in record:
            for persona, output in record['persona_outputs'].items():
                # Register the persona even when it never scores a hit.
                persona_hits.setdefault(persona, 0)
                guess = extract_math_answer_from_output(output)
                if check_math_correctness(guess, gold):
                    persona_hits[persona] += 1

    def _score(hits: int) -> Dict[str, Any]:
        return {
            'correct': hits,
            'total': n_items,
            'accuracy': (hits / n_items) * 100
        }

    # Fixed variants first, then personas in first-seen order.
    accuracy = {}
    for tally in (fixed_hits, persona_hits):
        for label, hits in tally.items():
            accuracy[label] = _score(hits)

    return {
        'total_items': n_items,
        'accuracy': accuracy
    }


def parse_filename(filename: str) -> Dict[str, str]:
    """Derive model and chain-of-thought info from a generation filename.

    The file stem is expected to look like ``generations_<model>_<cot|no_cot>``;
    a missing prefix or suffix is tolerated. Returns a dict with
    'model_type' (a short family name), 'model_name' (the stem without the
    prefix/suffix) and 'cot_mode' ('cot', 'no_cot' or 'unknown').
    """
    stem = Path(filename).stem

    # '_no_cot' must be tested before '_cot' because it also ends in '_cot'.
    cot_mode, remainder = 'unknown', stem
    for suffix, mode in (('_no_cot', 'no_cot'), ('_cot', 'cot')):
        if stem.endswith(suffix):
            cot_mode = mode
            remainder = stem[:-len(suffix)]
            break

    prefix = 'generations_'
    model_name = remainder[len(prefix):] if remainder.startswith(prefix) else remainder

    # Map known model-name prefixes to a family; otherwise use the first
    # underscore-separated token of the model name.
    model_type = model_name.split('_')[0]
    for leading, family in (('Meta_Llama', 'llama'), ('gpt', 'gpt'), ('gemini', 'gemini')):
        if model_name.startswith(leading):
            model_type = family
            break

    return {
        'model_type': model_type,
        'model_name': model_name,
        'cot_mode': cot_mode
    }


def save_results(results: Dict[str, Any], output_dir: Path, model_type: str, cot_mode: str, source_file: str):
    """Write accuracy results as JSON under ``<output_dir>/in-domain``.

    The directory is created if needed and the file is named
    ``<model_type>_<cot_mode>_accuracy.json``. ``source_file`` is accepted
    but not used here.
    """
    # Math results are stored in the in-domain directory.
    target_dir = output_dir / "in-domain"
    target_dir.mkdir(parents=True, exist_ok=True)

    target = target_dir / f"{model_type}_{cot_mode}_accuracy.json"
    with open(target, 'w', encoding='utf-8') as sink:
        sink.write(json.dumps(results, indent=2))

    print(f"Saved results to {target}")


def print_results(results: Dict[str, Any], model_type: str, cot_mode: str):
    """Print a formatted accuracy table to stdout.

    The 'baseline' and 'primed' rows (when present) come first, followed by
    a blank line and then every other variant sorted by name.
    """
    item_count = results['total_items']
    per_variant = results['accuracy']
    headline = ('baseline', 'primed')

    def _row(label: str) -> str:
        stats = per_variant[label]
        return f"{label:25}: {stats['correct']:4d}/{stats['total']:4d} ({stats['accuracy']:6.2f}%)"

    print(f"\n{model_type.upper()} - {cot_mode.upper()} Results ({item_count} items):")
    print("=" * 60)

    for label in headline:
        if label in per_variant:
            print(_row(label))

    print()

    # Everything that is not a headline variant is treated as a persona.
    for label in sorted(name for name in per_variant if name not in headline):
        print(_row(label))


def main():
    """Evaluate every math generation file under the input directory.

    Looks for ``generations_*.jsonl`` inside the ``llama``, ``gpt`` and
    ``gemini`` subdirectories of ``--in``. For each file it saves the
    extracted answers, computes accuracy, prints a table and writes the
    results under ``--out``.

    Returns 0 on success, or 1 when the input directory is missing or
    contains no generation files.
    """
    cli = argparse.ArgumentParser(description="Evaluate accuracy of math generation results")
    cli.add_argument("--in", dest="input_dir", default="data/math/generation", help="Input directory with generation results")
    cli.add_argument("--out", default="eval/math", help="Output directory for accuracy results")
    opts = cli.parse_args()

    src_root = Path(opts.input_dir)
    dest_root = Path(opts.out)

    if not src_root.exists():
        print(f"Error: Input directory {opts.input_dir} does not exist")
        return 1

    # Only the known per-model subdirectories are searched.
    candidate_dirs = (src_root / name for name in ('llama', 'gpt', 'gemini'))
    generation_files = [
        path
        for folder in candidate_dirs
        if folder.exists()
        for path in folder.glob("generations_*.jsonl")
    ]

    if not generation_files:
        print(f"Error: No generation files found in {opts.input_dir}")
        return 1

    print(f"Found {len(generation_files)} generation files")

    for gen_path in generation_files:
        print(f"\nProcessing {gen_path.name}...")

        info = parse_filename(gen_path.name)
        model_type, cot_mode = info['model_type'], info['cot_mode']

        records = load_jsonl(str(gen_path))
        if not records:
            print(f"  Warning: No data in {gen_path.name}")
            continue

        # Debug dump of what was extracted from each output.
        save_extracted_answers(records, model_type, cot_mode, str(gen_path))

        report = evaluate_math_accuracy(records)
        report['file_info'] = info
        report['source_file'] = str(gen_path)
        report['domain'] = 'math'

        print_results(report, model_type, cot_mode)
        save_results(report, dest_root, model_type, cot_mode, str(gen_path))

    print(f"\nAll results saved to {dest_root}")
    return 0


if __name__ == "__main__":
    # main() returns 0 on success and 1 on failure; pass that on as the
    # process exit status so shells and CI jobs can detect a failed run.
    raise SystemExit(main())