#!/usr/bin/env python3
"""
Accuracy Evaluation Script for CODABENCH
Evaluates agent predictions against ground truth answers
"""

import json
import os
import argparse
from pathlib import Path
from typing import Dict, List, Any, Tuple
from datetime import datetime


def normalize_answer(answer: Any) -> str:
    """Normalize an answer for case- and spacing-insensitive comparison.

    Spaces, line feeds and carriage returns are removed everywhere in the
    text, the result is lower-cased, and any remaining leading/trailing
    whitespace is stripped.

    Args:
        answer: The answer to normalize. ``None`` is treated as an empty
            answer. Non-string values (for example numbers decoded from a
            JSON metadata file) are converted with ``str()`` first.

    Returns:
        The normalized text, or ``""`` for ``None`` / an empty string.
    """
    if answer is None:
        return ""
    if not isinstance(answer, str):
        # A falsy-but-meaningful value such as 0 must not collapse to "",
        # otherwise it would compare equal to an empty prediction.
        answer = str(answer)
    return answer.replace(" ", "").replace("\n", "").replace("\r", "").lower().strip()


def parse_community_dir(dir_name: str) -> Tuple[str, int]:
    """Split a ``<type>_community_<id>`` directory name into type and ID.

    The name is cut at the last ``_community_`` marker: the text before it
    is the community type and the text after it must parse as an integer.

    Returns:
        ``(community_type, community_id)`` on success, or ``(None, None)``
        when the marker is absent or the ID part is not an integer.
    """
    prefix, marker, suffix = dir_name.rpartition('_community_')
    if not marker:
        # No marker in the name at all.
        return None, None
    try:
        return prefix, int(suffix)
    except ValueError:
        return None, None


def find_all_instances(eval_env_dir: str) -> List[Dict[str, Any]]:
    """Collect every evaluation instance found under ``eval_env_dir``.

    Layout read by this function::

        <eval_env_dir>/<type>_community_<id>/instance_<n>/metadata.json
        <eval_env_dir>/<type>_community_<id>/instance_<n>/<setting>/result.txt

    Directories whose name does not parse as a community directory, instance
    directories without a numeric index, and instances without a
    ``metadata.json`` are skipped silently. Instances whose metadata cannot
    be read, is not valid JSON, or is not a JSON object are skipped with a
    warning on stderr, as are unreadable ``result.txt`` files.

    Args:
        eval_env_dir: Path to the evaluation environment directory.

    Returns:
        One dict per instance with the keys ``community_type``,
        ``community_id``, ``community_dir``, ``instance_id``, ``metadata``
        and ``settings_results`` (setting name -> stripped prediction text;
        a setting without a readable ``result.txt`` has no entry). The list
        is empty when ``eval_env_dir`` does not exist or is not a directory.
    """
    import sys  # local import: only needed for the warnings below

    eval_env_path = Path(eval_env_dir)
    instances: List[Dict[str, Any]] = []
    settings = ['full_community']

    # A bad path means "nothing to evaluate" rather than an iterdir() crash,
    # so the caller can report that no instances were found.
    if not eval_env_path.is_dir():
        return instances

    for community_dir in sorted(eval_env_path.iterdir()):
        if not community_dir.is_dir() or community_dir.name.endswith('.json'):
            continue

        community_type, community_id = parse_community_dir(community_dir.name)
        if community_type is None:
            continue

        for instance_dir in sorted(community_dir.glob('instance_*')):
            if not instance_dir.is_dir():
                continue

            try:
                instance_id = int(instance_dir.name.split('_')[1])
            except (IndexError, ValueError):
                continue

            metadata_path = instance_dir / 'metadata.json'
            if not metadata_path.exists():
                continue

            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
            except (OSError, ValueError) as exc:
                # ValueError covers both JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: skipping {instance_dir}: cannot load metadata.json ({exc})",
                      file=sys.stderr)
                continue

            if not isinstance(metadata, dict):
                # Downstream code looks fields up with metadata.get(...).
                print(f"Warning: skipping {instance_dir}: metadata.json is not a JSON object",
                      file=sys.stderr)
                continue

            settings_results = {}
            for setting in settings:
                result_path = instance_dir / setting / 'result.txt'
                if not result_path.exists():
                    continue
                try:
                    with open(result_path, 'r', encoding='utf-8') as f:
                        settings_results[setting] = f.read().strip()
                except (OSError, ValueError) as exc:
                    # Keep the instance; it simply has no prediction for this setting.
                    print(f"Warning: cannot read {result_path} ({exc})", file=sys.stderr)

            instances.append({
                'community_type': community_type,
                'community_id': community_id,
                'community_dir': community_dir.name,
                'instance_id': instance_id,
                'metadata': metadata,
                'settings_results': settings_results
            })

    return instances


def evaluate_all(instances: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score every instance and aggregate accuracy overall and per community type.

    A prediction is correct when its normalized text equals the normalized
    ``answer`` from the instance metadata. An instance without a prediction
    for a setting is listed in ``detailed_results`` but is not counted in
    that setting's ``total``.

    Returns:
        A dict with ``metadata`` (timestamp, instance count, settings),
        ``overall`` and ``by_community_type`` counters
        (``total`` / ``correct`` / ``accuracy``) and ``detailed_results``
        (one entry per instance).
    """
    settings = ['full_community']

    def fresh_counters() -> Dict[str, Dict[str, Any]]:
        # One zeroed counter bucket per setting.
        return {name: {'total': 0, 'correct': 0, 'accuracy': 0.0} for name in settings}

    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    overall = fresh_counters()
    per_type: Dict[str, Any] = {}
    details: List[Dict[str, Any]] = []

    for entry in instances:
        kind = entry['community_type']
        meta = entry['metadata']
        truth = meta.get('answer', '')
        truth_key = normalize_answer(truth)

        if kind not in per_type:
            per_type[kind] = fresh_counters()
        type_counters = per_type[kind]

        outcomes = {}
        record = {
            'community_dir': entry['community_dir'],
            'instance_id': entry['instance_id'],
            'question': meta.get('question', ''),
            'correct_answer': truth,
            'results': outcomes
        }

        for name in settings:
            if name not in entry['settings_results']:
                continue

            guess = entry['settings_results'][name]
            hit = (truth_key == normalize_answer(guess))
            outcomes[name] = {
                'predicted': guess,
                'is_correct': hit
            }

            # The global and the per-type bucket are updated the same way.
            for bucket in (overall[name], type_counters[name]):
                bucket['total'] += 1
                if hit:
                    bucket['correct'] += 1

        details.append(record)

    # Turn the raw counts into ratios; empty buckets keep accuracy 0.0.
    for counters in [overall, *per_type.values()]:
        for name in settings:
            bucket = counters[name]
            if bucket['total'] > 0:
                bucket['accuracy'] = bucket['correct'] / bucket['total']

    return {
        'metadata': {
            'evaluation_time': timestamp,
            'total_instances': len(instances),
            'settings': settings
        },
        'overall': overall,
        'by_community_type': per_type,
        'detailed_results': details
    }


def print_summary(results: Dict[str, Any]):
    """Write a human-readable accuracy report to stdout.

    Prints the evaluation time and instance count, then one table of
    overall accuracy per setting and one table per community type (sorted
    by type name). Rows whose ``total`` is zero are omitted.
    """
    setting_names = results['metadata']['settings']
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    def section(title_line: str, rule: str) -> None:
        # Blank line, rule, title, rule.
        print("\n" + rule)
        print(title_line)
        print(rule)

    section("  EVALUATION RESULTS SUMMARY", heavy_rule)

    print(f"\n  Evaluation Time: {results['metadata']['evaluation_time']}")
    print(f"  Total Instances: {results['metadata']['total_instances']}")

    section("  Overall Accuracy", light_rule)

    print(f"\n  {'Setting':<20} {'Correct':>10} {'Total':>10} {'Accuracy':>12}")
    print("  " + "-" * 54)

    for name in setting_names:
        row = results['overall'][name]
        if row['total'] > 0:
            print(f"  {name:<20} {row['correct']:>10} {row['total']:>10} "
                  f"{row['accuracy']*100:>10.2f}%")

    section("  Accuracy by Community Type", light_rule)

    for comm_type, per_setting in sorted(results['by_community_type'].items()):
        print(f"\n  {comm_type}:")
        print(f"     {'Setting':<18} {'Correct':>8} {'Total':>8} {'Accuracy':>10}")
        print("     " + "-" * 46)

        for name in setting_names:
            row = per_setting[name]
            if row['total'] > 0:
                print(f"     {name:<18} {row['correct']:>8} {row['total']:>8} {row['accuracy']*100:>9.2f}%")

    print("\n" + heavy_rule + "\n")


def save_results(results: Dict[str, Any], output_path: str):
    """Write ``results`` to ``output_path`` as indented UTF-8 JSON.

    Non-ASCII characters are written as-is rather than escaped. A
    confirmation line naming the file is printed once it has been written.
    """
    with open(output_path, mode='w', encoding='utf-8') as sink:
        json.dump(results, sink, ensure_ascii=False, indent=2)

    print(f"Results saved to: {output_path}")


def main():
    """Command-line entry point: evaluate a directory and write the reports.

    Parses ``eval_dir`` and ``--output/-o`` from the command line, collects
    the instances under ``eval_dir``, scores them, writes
    ``evaluation_<name>_results.json`` (full details) and
    ``evaluation_<name>_summary.json`` (aggregates only) into the output
    directory, and prints a summary table. ``<name>`` is the final component
    of the absolute evaluation path.

    Exits through ``parser.error`` (usage message, status 2) when
    ``eval_dir`` is not an existing directory. Returns without writing any
    report when no instances are found.
    """
    parser = argparse.ArgumentParser(description='Evaluate CODABENCH results')
    parser.add_argument('eval_dir', help='Path to evaluation environment directory')
    parser.add_argument('--output', '-o', help='Output directory for results', default='.')
    args = parser.parse_args()

    eval_env_dir = Path(args.eval_dir)
    if not eval_env_dir.is_dir():
        # Fail with a usage error instead of a traceback from the directory scan.
        parser.error(f"evaluation directory not found: {eval_env_dir}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("  Starting Evaluation...")
    print("=" * 80)
    print(f"\n  Evaluation Directory: {eval_env_dir}\n")

    # Find all instances
    instances = find_all_instances(str(eval_env_dir))

    if not instances:
        print("No instances found. Please check the directory path.")
        return

    print(f"Found {len(instances)} instances\n")

    # Evaluate
    results = evaluate_all(instances)

    # Save results. The name comes from the absolute path because
    # Path('.').name is '' and would yield "evaluation__results.json";
    # abspath normalizes lexically and does not follow symlinks.
    eval_name = Path(os.path.abspath(eval_env_dir)).name
    output_json_path = output_dir / f'evaluation_{eval_name}_results.json'
    save_results(results, str(output_json_path))

    # Save simplified summary (everything except the per-instance details)
    simplified_results = {
        'metadata': results['metadata'],
        'overall': results['overall'],
        'by_community_type': results['by_community_type']
    }
    simplified_path = output_dir / f'evaluation_{eval_name}_summary.json'
    with open(simplified_path, 'w', encoding='utf-8') as f:
        json.dump(simplified_results, f, indent=2, ensure_ascii=False)
    print(f"Summary saved to: {simplified_path}")

    # Print summary
    print_summary(results)


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
