#!/usr/bin/env python3
"""
Multi-domain evaluation script for reference-based metrics.
Evaluates across 4 domains and computes per-domain + overall averages.
"""

import argparse
import json
import os
import sys
from typing import Dict, List, Optional

try:
    import evaluate
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
except ImportError as e:
    print(f"Missing package: {e}. Install with: pip install evaluate nltk")
    sys.exit(1)

# Model-specific prediction directory patterns. Each value is formatted with
# the domain name to give the prediction sub-directory inside the model
# directory, e.g. 'GPT-4o' + 'finance' -> 'gpt_4o_benchmark_results_finance'.
MODEL_PATTERNS = {
    'GPT-4.1': 'gpt_41_benchmark_results_{}',
    'GPT-4o': 'gpt_4o_benchmark_results_{}',
    'GPT-5': 'gpt_5_benchmark_results_{}',
    'GPT-5-chat': 'gpt_5_chat_benchmark_results_{}',
    'O4-mini': 'o4_mini_benchmark_results_{}'
}

# Pattern used when the model name matches no key of MODEL_PATTERNS.
DEFAULT_MODEL_PATTERN = 'gpt_41_chat_benchmark_results_{}'

def get_domain_config(model_name: str, domain: str) -> Dict:
    """Return the reference/prediction directory config for one model and domain.

    Args:
        model_name: Model key such as ``'GPT-4.1'``. ``main`` passes the
            ``--model_dir`` argument here, so a path whose last component is
            a known key (e.g. ``runs/GPT-4o/``) is also accepted.
        domain: Domain name substituted into the model's directory pattern.

    Returns:
        ``{'ref_dir': ..., 'pred_suffix': ...}``. ``ref_dir`` is the shared
        golden-document directory. ``pred_suffix`` is the prediction
        sub-directory name, relative to the model directory. Unknown models
        use ``DEFAULT_MODEL_PATTERN``.
    """
    # Exact key first (original behaviour), then the final path component so
    # that a directory path pointing at a known model resolves correctly
    # instead of silently falling back to the default pattern.
    model_key = os.path.basename(os.path.normpath(model_name)) if model_name else model_name
    pattern = MODEL_PATTERNS.get(model_name) or MODEL_PATTERNS.get(model_key, DEFAULT_MODEL_PATTERN)
    return {
        'ref_dir': 'golden_documents',
        'pred_suffix': pattern.format(domain)
    }

def load_metrics():
    """Load the ROUGE, BLEU and METEOR metrics via ``evaluate.load``.

    Returns:
        Dict mapping each metric name ('rouge', 'bleu', 'meteor') to the
        loaded metric object.
    """
    metric_names = ('rouge', 'bleu', 'meteor')
    return {name: evaluate.load(name) for name in metric_names}

def read_reference(path: str) -> str:
    """Return the reference text stored at *path*.

    ``.txt`` files are returned verbatim. Any other file is parsed as JSON,
    and its ``MarkdownDocContent`` field is returned (``""`` when absent).
    """
    with open(path, "r", encoding="utf-8") as handle:
        if path.endswith('.txt'):
            return handle.read()
        payload = json.load(handle)
    return payload.get("MarkdownDocContent", "")

def read_prediction(path: str) -> str:
    """Return the generated document text from a prediction JSON file.

    The text is read from ``detailed_evaluation.document.content``. A
    missing key, or a key whose value is ``null``/empty, yields ``""``.
    Callers can therefore treat it like any other empty prediction.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # `or {}` rather than a .get() default: a key that is present with a
    # null value would otherwise make the next `.get` raise AttributeError.
    evaluation = data.get("detailed_evaluation") or {}
    document = evaluation.get("document") or {}
    return document.get("content") or ""

def evaluate_texts(reference: str, prediction: str, metrics: Dict) -> Dict:
    """Score one prediction against one reference with ROUGE, BLEU and METEOR.

    Args:
        reference: Gold text.
        prediction: Generated text.
        metrics: Dict with 'rouge', 'bleu' and 'meteor' metric objects, each
            exposing ``compute(predictions=..., references=...)``.

    Returns:
        Dict with ``rouge1``/``rouge2``/``rougeL`` (each a
        ``{'precision', 'recall', 'fmeasure'}`` dict) plus scalar ``bleu``
        and ``meteor``. When the ROUGE metric yields a single number rather
        than a dict, that number is copied into all three fields. In that
        case precision and recall are not separate measurements.
    """
    rouge_raw = metrics['rouge'].compute(
        predictions=[prediction], references=[reference], use_stemmer=True
    )
    # BLEU accepts several references per prediction, hence the nested list.
    bleu_raw = metrics['bleu'].compute(
        predictions=[prediction], references=[[reference]]
    )
    meteor_raw = metrics['meteor'].compute(
        predictions=[prediction], references=[reference]
    )

    def as_prf(value):
        # ROUGE may come back as a bare float or as a dict; normalise to a dict.
        if not isinstance(value, dict):
            value = {'precision': value, 'recall': value, 'fmeasure': value}
        return value

    scores = {key: as_prf(rouge_raw[key]) for key in ('rouge1', 'rouge2', 'rougeL')}
    scores['bleu'] = bleu_raw['bleu']
    scores['meteor'] = meteor_raw['meteor']
    return scores

def find_reference(pred_path: str, ref_dir: str, domain: Optional[str] = None) -> Optional[str]:
    """Locate the golden reference file matching a prediction file.

    Prediction files are expected to be named ``query_<N>_result.json``.
    Candidates are tried in order, with ``<NNN>`` being N zero-padded to
    three digits:

    1. ``<ref_dir>/<domain>_<NNN>.json`` (only when *domain* is given);
    2. ``<ref_dir>/chat_0_<NNN>/chat_0_<NNN>_f4_0.json``.

    Returns:
        Path of the first candidate that exists. Returns ``None`` when the
        file name does not follow the pattern, when N is not an integer, or
        when no candidate exists.
    """
    pred_name = os.path.basename(pred_path)
    prefix, suffix = 'query_', '_result.json'
    if not (pred_name.startswith(prefix) and pred_name.endswith(suffix)):
        return None

    # Slice off only the leading prefix and trailing suffix (str.replace would
    # strip every occurrence). A non-numeric middle part such as
    # "query_abc_result.json" means "no reference" rather than a crash:
    # the caller does not guard this call with try/except.
    try:
        query_num = int(pred_name[len(prefix):-len(suffix)])
    except ValueError:
        return None

    candidates = []
    if domain:
        # Extracted golden documents layout: <domain>_<qid>.json
        candidates.append(os.path.join(ref_dir, f"{domain}_{query_num:03d}.json"))
    # Original nested layout: chat_0_<qid>/chat_0_<qid>_f4_0.json
    chat_id = f"chat_0_{query_num:03d}"
    candidates.append(os.path.join(ref_dir, chat_id, f"{chat_id}_f4_0.json"))

    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None

def evaluate_domain(model_dir: str, domain: str, domain_config: Dict, metrics: Dict) -> Dict:
    """Score every prediction file of one domain against its reference.

    Args:
        model_dir: Directory containing the per-domain prediction folders.
        domain: Domain name, used in log output and for reference lookup.
        domain_config: Dict with ``'ref_dir'`` (reference directory) and
            ``'pred_suffix'`` (prediction folder name inside *model_dir*).
        metrics: Loaded metric objects, as passed to ``evaluate_texts``.

    Returns:
        ``{'results': {tag: scores}, 'count': int, 'average': dict | None}``.
        *tag* is the prediction file name without its extension. *count* is
        the number of scored files. *average* is the ``compute_average`` of
        *results*, or ``None`` when nothing was scored.
    """
    ref_dir = domain_config['ref_dir']
    pred_dir = os.path.join(model_dir, domain_config['pred_suffix'])

    print(f"\nEvaluating {domain.upper()} domain...")
    print(f"   Reference: {ref_dir}")
    print(f"   Prediction: {pred_dir}")

    # isdir rather than exists: pred_dir is passed to os.listdir below.
    if not os.path.isdir(ref_dir):
        print(f"   Reference directory not found: {ref_dir}")
        return {'results': {}, 'count': 0, 'average': None}

    if not os.path.isdir(pred_dir):
        print(f"   Prediction directory not found: {pred_dir}")
        return {'results': {}, 'count': 0, 'average': None}

    results = {}
    missing_ref = 0
    empty_text = 0
    # Sorted so per-file results (and the saved JSON) have a stable order;
    # os.listdir order is arbitrary.
    pred_files = sorted(f for f in os.listdir(pred_dir) if f.endswith('.json'))

    for pred_file in pred_files:
        pred_path = os.path.join(pred_dir, pred_file)
        # Best-effort per file: a bad file is reported and skipped instead of
        # aborting the whole domain. The reference lookup is inside the try too.
        try:
            ref_path = find_reference(pred_path, ref_dir, domain)
            if not ref_path:
                missing_ref += 1
                continue

            ref_text = read_reference(ref_path)
            pred_text = read_prediction(pred_path)
            if not (ref_text and pred_text):
                empty_text += 1
                continue

            tag = os.path.splitext(pred_file)[0]
            results[tag] = evaluate_texts(ref_text, pred_text, metrics)
        except Exception as e:
            print(f"   Error processing {pred_file}: {e}")

    processed_count = len(results)
    if missing_ref:
        print(f"   Skipped {missing_ref} file(s) with no matching reference")
    if empty_text:
        print(f"   Skipped {empty_text} file(s) with empty reference or prediction text")

    # Compute domain average
    domain_avg = None
    if results:
        domain_avg = compute_average(results)
        print(f"   Processed {processed_count} files")
        print_results(domain_avg, f"{domain.upper()} Average")
    else:
        print(f"   No files processed for {domain}")

    return {
        'results': results,
        'count': processed_count,
        'average': domain_avg
    }

def compute_average(results: Dict) -> Optional[Dict]:
    """Average per-item score dicts into one score dict of the same shape.

    Args:
        results: Mapping of item name -> score dict with ``rouge1``,
            ``rouge2``, ``rougeL`` (P/R/F dicts), ``bleu`` and ``meteor``.
            Outputs of this function can be fed back in, e.g. to average
            domain averages.

    Returns:
        The averaged score dict, or ``None`` if *results* is empty. Every
        item has equal weight. For each ROUGE variant, precision, recall
        and F1 are averaged independently. A missing precision/recall falls
        back to F1, and a missing F1 falls back to precision (then 0).
    """
    if not results:
        return None

    scores = list(results.values())
    n = len(scores)

    def mean(values):
        return sum(values) / n

    avg_results = {}
    for metric in ('rouge1', 'rouge2', 'rougeL'):
        parts = [r[metric] for r in scores]
        # Average each field separately; overwriting precision/recall with
        # the F1 average would discard real P/R values when they exist.
        avg_results[metric] = {
            'precision': mean(p.get('precision', p.get('fmeasure', 0)) for p in parts),
            'recall': mean(p.get('recall', p.get('fmeasure', 0)) for p in parts),
            'fmeasure': mean(p.get('fmeasure', p.get('precision', 0)) for p in parts),
        }
    for metric in ('bleu', 'meteor'):
        avg_results[metric] = mean(r[metric] for r in scores)

    return avg_results

def print_results(results: Dict, title: str = "Results"):
    """Print one score dict as percentages under a ``=== title ===`` banner.

    Prints ROUGE-1/2/L precision, recall and F1, then BLEU and METEOR. A
    missing precision/recall falls back to F1, and a missing F1 falls back
    to precision (then 0). Prints nothing when *results* is empty or ``None``.
    """
    if results:
        def as_pct(value):
            return f"{value*100:.2f}%"

        print(f"\n=== {title} ===")
        rouge_rows = [
            (label, results[key])
            for label, key in (("ROUGE-1", "rouge1"), ("ROUGE-2", "rouge2"), ("ROUGE-L", "rougeL"))
        ]
        for label, score in rouge_rows:
            precision = score.get('precision', score.get('fmeasure', 0))
            recall = score.get('recall', score.get('fmeasure', 0))
            f1 = score.get('fmeasure', score.get('precision', 0))
            print(f"{label}: P:{as_pct(precision)} R:{as_pct(recall)} F1:{as_pct(f1)}")
        print(f"BLEU: {as_pct(results['bleu'])}")
        print(f"METEOR: {as_pct(results['meteor'])}")

def _f1_percent(score: Dict) -> float:
    """Return a ROUGE score dict's F1 (falling back to precision, then 0) in percent."""
    return score.get('fmeasure', score.get('precision', 0)) * 100

def _summary_row(label: str, count: int, avg: Dict) -> str:
    """Format one summary-table row: label, file count, ROUGE F1s, BLEU, METEOR (in %)."""
    return (f"{label:<15} {count:<8} "
            f"{_f1_percent(avg['rouge1']):<10.2f} "
            f"{_f1_percent(avg['rouge2']):<10.2f} "
            f"{_f1_percent(avg['rougeL']):<10.2f} "
            f"{avg['bleu'] * 100:<8.2f} {avg['meteor'] * 100:<8.2f}")

def main():
    """Command-line entry point: evaluate the selected domains and report.

    Scores each requested domain and prints the per-domain and cross-domain
    averages plus a summary table. When ``--output_json`` is given, all
    results are also written to that file.
    """
    parser = argparse.ArgumentParser(description="Multi-domain evaluation for reference-based metrics")
    parser.add_argument("--model_dir", required=True,
                        help="Model directory (e.g., 'GPT-4.1')")
    parser.add_argument("--output_json",
                        help="Save detailed results to JSON file")
    parser.add_argument("--domains", nargs='+',
                        choices=['finance', 'healthcare', 'manufacturing', 'technology'],
                        default=['finance', 'healthcare', 'manufacturing', 'technology'],
                        help="Domains to evaluate (default: all)")

    args = parser.parse_args()
    # Drop repeated domains (order preserved). Otherwise a repeated domain is
    # evaluated twice and its files are counted twice in the total.
    domains = list(dict.fromkeys(args.domains))

    print("Multi-Domain Evaluation Starting...")
    print(f"Model Directory: {args.model_dir}")
    print(f"Domains: {', '.join(domains)}")

    print("\nLoading evaluation metrics...")
    metrics = load_metrics()
    print("Metrics loaded successfully")

    all_domain_results = {}
    domain_averages = {}
    total_files = 0
    overall_avg = None

    for domain in domains:
        domain_config = get_domain_config(args.model_dir, domain)
        domain_data = evaluate_domain(args.model_dir, domain, domain_config, metrics)
        all_domain_results[domain] = domain_data

        if domain_data['average']:
            domain_averages[domain] = domain_data['average']
            total_files += domain_data['count']

    if domain_averages:
        print(f"\nComputing overall average across {len(domain_averages)} domains...")
        # Macro-average: compute_average divides by the number of entries, so
        # every domain weighs the same regardless of how many files it has.
        overall_avg = compute_average(domain_averages)
        print_results(overall_avg, "OVERALL AVERAGE (Cross-Domain)")

        print("\nSUMMARY")
        print(f"{'Domain':<15} {'Files':<8} {'ROUGE-1':<10} {'ROUGE-2':<10} {'ROUGE-L':<10} {'BLEU':<8} {'METEOR':<8}")
        print("-" * 75)
        for domain in domains:
            if domain in domain_averages:
                print(_summary_row(domain.capitalize(),
                                   all_domain_results[domain]['count'],
                                   domain_averages[domain]))

        if overall_avg:
            print("-" * 75)
            # Files column = total across domains; scores = macro-average above.
            print(_summary_row('OVERALL', total_files, overall_avg))

    if args.output_json:
        output_data = {
            'model': args.model_dir,
            'domains_evaluated': domains,
            'domain_results': all_domain_results,
            'domain_averages': domain_averages,
            'overall_average': overall_avg,
            'summary': {
                'total_files_processed': total_files,
                'domains_processed': len(domain_averages)
            }
        }

        # Create the parent directory so a fresh output path does not fail
        # only after the whole (slow) evaluation has finished.
        out_dir = os.path.dirname(args.output_json)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(args.output_json, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2)
        print(f"\nResults saved to: {args.output_json}")

    print("\nMulti-domain evaluation complete!")
    print(f"Total files processed: {total_files}")
    print(f"Domains evaluated: {len(domain_averages)}/{len(domains)}")

if __name__ == "__main__":
    main()