#!/usr/bin/env python3
import argparse
import json
import re
import os
from pathlib import Path
from typing import Dict, Any, List, Optional
import glob


# Patterns tried in order when pulling a multiple-choice option out of an
# answer line.  Standalone tokens (not glued to other letters/digits) come
# first so that letters inside ordinary words ("Based", "correct") and digits
# inside larger numbers ("2020") are only used as a last resort.
_CHOICE_TOKEN_PATTERNS = (
    re.compile(r'(?<![A-Za-z0-9])[1-4](?![A-Za-z0-9])'),  # standalone 1-4
    re.compile(r'(?<![A-Za-z0-9])[A-D](?![A-Za-z0-9])'),  # standalone A-D
    re.compile(r'[1-4]'),                                 # any digit 1-4
    re.compile(r'(?<![A-Za-z0-9])[a-d](?![A-Za-z0-9])'),  # standalone a-d
    re.compile(r'[A-Da-d]'),                              # any letter a-d
)

# A number, optionally written with thousands separators ("1,234.50").
_MATH_NUMBER_PATTERN = re.compile(
    r'-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?'
)


def extract_answer_from_output(text: str, domain: str = 'math') -> Optional[str]:
    """Extract the final answer that follows 'Answer:' in a model output.

    Only the last (case-insensitive) occurrence of ``Answer:`` is considered,
    and only the remainder of that line is inspected.

    Args:
        text: Raw model output. Empty or ``None`` yields ``None``.
        domain: ``'legal'``, ``'medical'``, ``'psychology'`` or
            ``'commonsense'`` are treated as multiple choice; any other value
            is treated as a numeric (math) answer.

    Returns:
        For multiple-choice domains, an uppercase letter: a digit 1-4 is
        mapped to A-D, a letter a-d is uppercased. For other domains, the
        first number on the answer line as a string, with thousands
        separators removed. ``None`` when no ``Answer:`` marker or no usable
        token is found.
    """
    if not text:
        return None

    # ``.`` does not cross newlines, so each capture is the rest of one line.
    candidates = re.findall(r'Answer:\s*(.+)', text, re.IGNORECASE)
    if not candidates:
        return None
    answer = candidates[-1].strip()

    if domain in ['legal', 'medical', 'psychology', 'commonsense']:
        for token_pattern in _CHOICE_TOKEN_PATTERNS:
            hit = token_pattern.search(answer)
            if hit is None:
                continue
            token = hit.group(0)
            if token.isdigit():
                # Convert 1->A, 2->B, 3->C, 4->D (64 + 1 == ord('A')).
                return chr(64 + int(token))
            return token.upper()
        return None

    number_hit = _MATH_NUMBER_PATTERN.search(answer)
    if number_hit is None:
        return None
    # Drop thousands separators so the result can be parsed by float().
    return number_hit.group(0).replace(',', '')


# Matches a whole number written with thousands separators, e.g. "1,234.50".
_GROUPED_NUMBER_PATTERN = re.compile(r'[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?')


def _parse_number(value: str) -> float:
    """Convert *value* to float, accepting thousands separators ("1,234")."""
    candidate = value.strip()
    if _GROUPED_NUMBER_PATTERN.fullmatch(candidate):
        candidate = candidate.replace(',', '')
    return float(candidate)


def check_correctness(predicted: Optional[str], gold: str, domain: str = 'math') -> bool:
    """Check whether a predicted answer matches the gold answer.

    Args:
        predicted: Extracted prediction, or ``None`` when nothing was extracted.
        gold: Reference answer.
        domain: ``'legal'``, ``'medical'``, ``'psychology'`` or
            ``'commonsense'`` compare option letters case-insensitively; any
            other value compares numerically.

    Returns:
        ``True`` when the answers agree. A ``None`` prediction is always
        incorrect. Numeric answers must agree within an absolute tolerance of
        1e-9; if either side is not a number, a case-insensitive string
        comparison is used instead.
    """
    if predicted is None:
        return False

    if domain in ['legal', 'medical', 'psychology', 'commonsense']:
        # For multiple choice problems, do direct ABCD comparison
        return predicted.strip().upper() == gold.strip().upper()

    try:
        # Thousands-grouped values ("1,234") would make float() fail and
        # silently fall through to the string comparison, so normalise first.
        pred_num = _parse_number(predicted)
        gold_num = _parse_number(gold)
    except ValueError:
        # Not numeric: fall back to string comparison
        return predicted.strip().lower() == gold.strip().lower()

    return abs(pred_num - gold_num) < 1e-9  # Small tolerance for floating point


def load_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its records in file order.

    Lines that are empty or contain only whitespace are skipped; every other
    line is stripped and parsed with ``json.loads``.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped_lines if entry]


def create_gold_answer_to_label_mapping(domain: str) -> Dict[str, str]:
    """Build a lookup from gold answer text to its uppercase option label.

    Only the multiple-choice domains (legal, medical, psychology,
    commonsense) have such a lookup; any other domain yields an empty dict.
    The lookup is read from ``data/<domain>/dataset/<domain>_dataset.jsonl``.
    A missing file gives an empty dict, and any error while reading is
    reported as a warning, returning whatever was collected up to that point.
    """
    label_by_text: Dict[str, str] = {}
    if domain not in ['legal', 'medical', 'psychology', 'commonsense']:
        return label_by_text

    try:
        # The original dataset is the source of the text -> label pairs.
        dataset_path = f"data/{domain}/dataset/{domain}_dataset.jsonl"
        if os.path.exists(dataset_path):
            for record in load_jsonl(dataset_path):
                has_both = 'gold_answer' in record and 'gold_label' in record
                if not has_both:
                    continue
                label_by_text[record['gold_answer']] = record['gold_label'].upper()
    except Exception as e:
        print(f"Warning: Could not create gold answer mapping: {e}")

    return label_by_text


def evaluate_accuracy(data: List[Dict[str, Any]], domain: str = 'math') -> Dict[str, Dict]:
    """Score the baseline, primed and persona outputs of every item.

    Args:
        data: Generation records. Each may carry ``baseline_output``,
            ``primed_output`` and a ``persona_outputs`` dict (name -> output).
        domain: Domain used for answer extraction and comparison.

    Returns:
        ``{}`` for empty input. Otherwise a dict with ``total_items`` and an
        ``accuracy`` table mapping each variant name to its ``correct`` count,
        ``total`` and ``accuracy`` percentage. Every variant is measured
        against the full item count, including items lacking that variant.
    """
    item_count = len(data)
    if item_count == 0:
        return {}

    # Text -> label lookup, used only when a multiple-choice item has no gold_label.
    label_lookup = create_gold_answer_to_label_mapping(domain)
    is_choice_domain = domain in ['legal', 'medical', 'psychology', 'commonsense']

    fixed_hits = {'baseline': 0, 'primed': 0}
    persona_hits = {}

    def is_hit(raw_output, reference):
        # Extract the prediction and compare it with the reference answer.
        guess = extract_answer_from_output(raw_output, domain)
        return check_correctness(guess, reference, domain)

    for record in data:
        # Resolve the reference answer for this item.
        if not is_choice_domain:
            reference = str(record['gold_answer'])
        elif 'gold_label' in record:
            reference = str(record['gold_label']).upper()
        else:
            gold_text = str(record['gold_answer'])
            reference = label_lookup.get(gold_text, gold_text).upper()

        # Fixed variants: baseline, then primed.
        for variant, field in (('baseline', 'baseline_output'), ('primed', 'primed_output')):
            if field in record and is_hit(record[field], reference):
                fixed_hits[variant] += 1

        # Persona variants are discovered as they appear.
        if 'persona_outputs' in record:
            for persona_name, persona_output in record['persona_outputs'].items():
                persona_hits.setdefault(persona_name, 0)
                if is_hit(persona_output, reference):
                    persona_hits[persona_name] += 1

    def summarize(hits):
        return {
            'correct': hits,
            'total': item_count,
            'accuracy': (hits / item_count) * 100
        }

    accuracy_table = {}
    for variant, hits in fixed_hits.items():
        accuracy_table[variant] = summarize(hits)
    for persona_name, hits in persona_hits.items():
        accuracy_table[persona_name] = summarize(hits)

    return {
        'total_items': item_count,
        'accuracy': accuracy_table
    }


def parse_filename(filename: str) -> Dict[str, str]:
    """Derive model and chain-of-thought information from a generation filename.

    Examples of accepted names:
        generations_Meta_Llama_3.1_8B_Instruct_cot.jsonl
        generations_gpt_4.1_nano_2025_04_14_no_cot.jsonl

    Returns a dict with ``model_type`` (``llama``, ``gpt``, ``gemini`` or the
    first underscore-separated token of the model name), ``model_name`` (the
    stem without the ``generations_`` prefix and CoT suffix) and ``cot_mode``
    (``cot``, ``no_cot`` or ``unknown``).
    """
    stem = Path(filename).stem

    # '_no_cot' must be tested before '_cot' because it also ends with '_cot'.
    for suffix, mode in (('_no_cot', 'no_cot'), ('_cot', 'cot')):
        if stem.endswith(suffix):
            cot_mode = mode
            remainder = stem[:-len(suffix)]
            break
    else:
        cot_mode = 'unknown'
        remainder = stem

    prefix = 'generations_'
    model_name = remainder[len(prefix):] if remainder.startswith(prefix) else remainder

    # Map well-known model name prefixes to a short family name.
    for marker, family in (('Meta_Llama', 'llama'), ('gpt', 'gpt'), ('gemini', 'gemini')):
        if model_name.startswith(marker):
            model_type = family
            break
    else:
        model_type = model_name.split('_')[0]

    return {
        'model_type': model_type,
        'model_name': model_name,
        'cot_mode': cot_mode
    }


def save_extracted_answers(data: List[Dict[str, Any]], domain: str, model_type: str, cot_mode: str, source_file: str, gold_mapping: Dict[str, str]):
    """Dump the extracted answer of every variant next to the source file.

    The output is ``extracted_answers_<source stem>.json`` in the source
    file's directory: a list with one entry per item holding its ``id``, the
    resolved gold answer and the extracted baseline, primed and
    ``persona_<name>`` predictions that are present on the item.
    ``model_type`` and ``cot_mode`` are accepted but not used here.
    """
    origin = Path(source_file)
    output_file = origin.parent / f"extracted_answers_{origin.stem}.json"
    is_choice_domain = domain in ['legal', 'medical', 'psychology', 'commonsense']

    rows = []
    for record in data:
        # Resolve the reference answer for this item.
        if not is_choice_domain:
            reference = str(record['gold_answer'])
        elif 'gold_label' in record:
            reference = str(record['gold_label']).upper()
        else:
            # No explicit label: translate the answer text through the mapping.
            gold_text = str(record['gold_answer'])
            reference = gold_mapping.get(gold_text, gold_text).upper()

        row = {
            'id': record.get('id', 'unknown'),
            'gold_answer': reference
        }

        # Baseline first, then primed, each only when present.
        for field in ('baseline_output', 'primed_output'):
            if field in record:
                row[field] = extract_answer_from_output(record[field], domain)

        if 'persona_outputs' in record:
            for persona_name, persona_output in record['persona_outputs'].items():
                row[f'persona_{persona_name}'] = extract_answer_from_output(persona_output, domain)

        rows.append(row)

    with open(output_file, 'w', encoding='utf-8') as sink:
        json.dump(rows, sink, indent=2)

    print(f"Saved extracted answers to {output_file}")


def save_results(results: Dict[str, Any], output_dir: Path, model_type: str, cot_mode: str, source_file: str):
    """Write the accuracy results as JSON under ``output_dir``.

    Args:
        results: JSON-serialisable results dict.
        output_dir: Base output directory (created if missing).
        model_type: Model family, used in the output filename.
        cot_mode: Chain-of-thought mode, used in the output filename.
        source_file: Path of the generation file the results came from; it
            decides whether the ``in-domain`` subdirectory is used.

    The file is named ``<model_type>_<cot_mode>_accuracy.json``. Sources under
    ``math/generation/{llama,gpt,gemini}/`` go to ``output_dir/in-domain``;
    everything else (including ``math/generation/legal/``) goes directly to
    ``output_dir``.

    NOTE(review): the filename depends only on model_type and cot_mode, so two
    source files sharing both and landing in the same directory overwrite each
    other — confirm this is intended.
    """
    # Normalise for the substring checks only: accept backslash separators and
    # let a relative path that starts at "math/..." match the "/math/..." markers.
    normalized_source = "/" + source_file.replace("\\", "/")
    in_domain_markers = (
        "/math/generation/llama/",
        "/math/generation/gpt/",
        "/math/generation/gemini/",
    )

    if "/math/generation/legal/" in normalized_source:
        # Cross-domain: legal personas on math data
        final_output_dir = output_dir
    elif any(marker in normalized_source for marker in in_domain_markers):
        # In-domain: math personas on math data
        final_output_dir = output_dir / "in-domain"
    else:
        # Default to original behavior
        final_output_dir = output_dir

    final_output_dir.mkdir(parents=True, exist_ok=True)

    output_file = final_output_dir / f"{model_type}_{cot_mode}_accuracy.json"

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"Saved results to {output_file}")


def print_results(results: Dict[str, Any], model_type: str, cot_mode: str):
    """Print a plain-text accuracy table to standard output.

    The header shows the upper-cased model type and CoT mode plus the item
    count. The ``baseline`` and ``primed`` rows (when present) come first,
    followed by a blank line and every other variant sorted by name.
    """
    item_count = results['total_items']
    table = results['accuracy']
    headline_variants = ['baseline', 'primed']

    def format_row(label, stats):
        return f"{label:25}: {stats['correct']:4d}/{stats['total']:4d} ({stats['accuracy']:6.2f}%)"

    print(f"\n{model_type.upper()} - {cot_mode.upper()} Results ({item_count} items):")
    print("=" * 60)

    for label in headline_variants:
        if label in table:
            print(format_row(label, table[label]))

    print()

    # Everything that is not a headline variant is a persona.
    persona_labels = sorted(name for name in table if name not in headline_variants)
    for label in persona_labels:
        print(format_row(label, table[label]))


# Domains recognised in generation file paths, in fallback priority order.
_KNOWN_DOMAINS = ('legal', 'math', 'medical', 'psychology', 'commonsense')


def _detect_domain(gen_file: Path) -> str:
    """Return the data domain of a generation file, defaulting to 'math'."""
    directories = gen_file.parts[:-1]

    # Preferred: the directory directly above "generation" names the data
    # domain, e.g. ".../math/generation/legal/..." holds math data even though
    # "legal" also appears in the path.
    for parent, child in zip(directories, directories[1:]):
        if child == 'generation' and parent in _KNOWN_DOMAINS:
            return parent

    # Fallback: first known domain appearing anywhere among the directories.
    for domain in _KNOWN_DOMAINS:
        if domain in directories:
            return domain

    return 'math'


def main() -> int:
    """Evaluate every ``generations_*.jsonl`` file under the input directory.

    For each file (searched recursively, processed in sorted order) this
    parses the filename, detects the data domain from the path, writes the
    extracted answers next to the source file, prints the accuracy table and
    saves the results under the output directory.

    Returns:
        0 on success, 1 when the input directory is missing or contains no
        generation files.
    """
    parser = argparse.ArgumentParser(description="Evaluate accuracy of math generation results")
    parser.add_argument("--in", dest="input_dir", default="data/medical/generation", help="Input directory with generation results")
    parser.add_argument("--out", default="eval/medical", help="Output directory for accuracy results")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.out)

    if not input_dir.exists():
        print(f"Error: Input directory {args.input_dir} does not exist")
        return 1

    # Sorted so that processing order (and which file wins when two map to
    # the same output name) does not depend on filesystem enumeration order.
    generation_files = sorted(input_dir.glob("**/generations_*.jsonl"))

    if not generation_files:
        print(f"Error: No generation files found in {args.input_dir}")
        return 1

    print(f"Found {len(generation_files)} generation files")

    for gen_file in generation_files:
        print(f"\nProcessing {gen_file.name}...")

        # Parse filename to get model and CoT info
        file_info = parse_filename(gen_file.name)
        model_type = file_info['model_type']
        cot_mode = file_info['cot_mode']

        # Detect domain from path components (separator-independent)
        domain = _detect_domain(gen_file)

        # Load and evaluate data
        data = load_jsonl(str(gen_file))
        if not data:
            print(f"  Warning: No data in {gen_file.name}")
            continue

        # Create gold mapping
        gold_mapping = create_gold_answer_to_label_mapping(domain)

        # Save extracted answers for debugging
        save_extracted_answers(data, domain, model_type, cot_mode, str(gen_file), gold_mapping)

        # Evaluate accuracy
        results = evaluate_accuracy(data, domain)
        results['file_info'] = file_info
        results['source_file'] = str(gen_file)
        results['domain'] = domain

        # Print results
        print_results(results, model_type, cot_mode)

        # Save results
        save_results(results, output_dir, model_type, cot_mode, str(gen_file))

    print(f"\nAll results saved to {output_dir}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so that
    # failures (return 1) are visible to shells and calling scripts.
    raise SystemExit(main())