#!/usr/bin/env python3
"""
Complete Electricity domain evaluation
Compute RAGAS (Answer Correctness, Faithfulness) and ROUGE (L, 1, 2) metrics
54 questions × 16 methods = 864 responses

NOTE: LSTM_Attention is used as ground truth/reference baseline for electricity domain
"""

import logging
import os
from functools import lru_cache

import numpy as np
import pandas as pd
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import answer_correctness, faithfulness
from rouge_score import rouge_scorer

# Setup logging.
# The log file lives inside results/, so make sure that directory exists first:
# logging.FileHandler opens the file immediately and raises FileNotFoundError
# if its parent directory is missing.
os.makedirs('results', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('results/electricity_complete_evaluation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# OpenAI API key: the RAGAS metrics used below expect OPENAI_API_KEY to be
# available in the environment. Export it before running this script, e.g.
#   export OPENAI_API_KEY=...
# Never hard-code the secret here, and never overwrite a key the user already
# exported (the previous placeholder assignment did exactly that, which made
# every RAGAS call fail).
if not os.environ.get('OPENAI_API_KEY'):
    logger.warning(
        "OPENAI_API_KEY is not set; RAGAS metric computation is expected to fail."
    )

@lru_cache(maxsize=1)
def _get_rouge_scorer():
    """Return a shared ROUGE-1/2/L scorer with stemming enabled.

    Constructing the scorer is identical for every response, so it is built
    once and reused instead of being rebuilt on each call.
    """
    return rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)


def _as_text(value) -> str:
    """Coerce a CSV cell to text; missing cells (None / float NaN) become ''."""
    if isinstance(value, str):
        return value
    # pandas reads empty CSV cells as float NaN, which the ROUGE tokenizer
    # cannot handle; treat them as empty text (zero overlap).
    if value is None or (isinstance(value, float) and np.isnan(value)):
        return ''
    return str(value)


def compute_rouge_metrics(prediction: str, reference: str):
    """Compute ROUGE-1, ROUGE-2 and ROUGE-L F-measures of a prediction.

    Args:
        prediction: Candidate text (the method's response).
        reference: Reference text the prediction is scored against.

    Returns:
        dict with keys 'rouge_1', 'rouge_2', 'rouge_l' holding the
        corresponding F-measure values.
    """
    # RougeScorer.score takes (target, prediction) in that order.
    scores = _get_rouge_scorer().score(_as_text(reference), _as_text(prediction))
    return {
        'rouge_1': scores['rouge1'].fmeasure,
        'rouge_2': scores['rouge2'].fmeasure,
        'rouge_l': scores['rougeL'].fmeasure
    }

# Column order of the output CSV (also used so an empty run still produces
# a DataFrame with the expected schema).
_RESULT_COLUMNS = [
    'domain', 'method', 'method_type', 'question_id', 'question', 'response',
    'ground_truth', 'answer_correctness', 'faithfulness',
    'rouge_1', 'rouge_2', 'rouge_l',
]


def _score_ragas_batch(batch_df):
    """Run RAGAS answer_correctness and faithfulness on one batch of rows.

    Returns two lists aligned row-for-row with ``batch_df``. If RAGAS raises,
    or returns a different number of rows than it was given, every score in
    the batch is NaN: a failed API call is not evidence of a wrong answer, so
    it must not be averaged in as a 0.0 score.
    """
    n_rows = len(batch_df)
    failed = ([float('nan')] * n_rows, [float('nan')] * n_rows)
    try:
        ragas_data = {
            'question': batch_df['question'].tolist(),
            'answer': batch_df['response'].tolist(),
            # NOTE: the reference answer is also passed as the only context,
            # so faithfulness is scored against the reference rather than
            # against retrieved passages.
            'contexts': [[gt] for gt in batch_df['ground_truth'].tolist()],
            'ground_truth': batch_df['ground_truth'].tolist()
        }
        dataset = Dataset.from_dict(ragas_data)
        result = evaluate(dataset, metrics=[answer_correctness, faithfulness])
        df_result = result.to_pandas()
        ac_scores = df_result['answer_correctness'].tolist()
        f_scores = df_result['faithfulness'].tolist()
    except Exception as e:  # LLM/API failures must not abort the whole run
        logger.error(f"RAGAS failed: {e}")
        return failed

    if len(ac_scores) != n_rows or len(f_scores) != n_rows:
        logger.error(
            f"RAGAS returned {len(ac_scores)}/{len(f_scores)} scores for "
            f"{n_rows} rows; marking batch as failed"
        )
        return failed

    # Series.mean skips NaN, so a single unscored row does not hide the rest.
    logger.info(
        f"RAGAS - AC: {pd.Series(ac_scores, dtype=float).mean():.3f}, "
        f"F: {pd.Series(f_scores, dtype=float).mean():.3f}"
    )
    return ac_scores, f_scores


def _score_rouge_row(response, ground_truth):
    """Return ROUGE metrics for one row, falling back to zeros on failure."""
    try:
        return compute_rouge_metrics(response, ground_truth)
    except Exception as e:  # not bare: let KeyboardInterrupt/SystemExit through
        logger.warning(f"ROUGE failed: {e}")
        return {'rouge_1': 0.0, 'rouge_2': 0.0, 'rouge_l': 0.0}


def _log_summary(results_df):
    """Log per-method and overall mean ± std for every metric column."""
    if results_df.empty:
        logger.warning("No results to summarise.")
        return

    logger.info("\n" + "="*80)
    logger.info("SUMMARY BY METHOD")
    logger.info("="*80)

    summary = results_df.groupby('method').agg({
        'answer_correctness': ['mean', 'std'],
        'faithfulness': ['mean', 'std'],
        'rouge_1': ['mean', 'std'],
        'rouge_2': ['mean', 'std'],
        'rouge_l': ['mean', 'std'],
    }).round(4)
    # Logged (not printed) so the table also lands in the log file.
    logger.info("\n" + summary.to_string())

    logger.info("\n" + "="*80)
    logger.info("OVERALL STATISTICS")
    logger.info("="*80)
    logger.info(f"Total responses: {len(results_df)}")
    for label, column in (
        ('Answer Correctness', 'answer_correctness'),
        ('Faithfulness', 'faithfulness'),
        ('ROUGE-1', 'rouge_1'),
        ('ROUGE-2', 'rouge_2'),
        ('ROUGE-L', 'rouge_l'),
    ):
        series = results_df[column]
        logger.info(f"{label}: {series.mean():.4f} ± {series.std():.4f}")


def evaluate_electricity_complete(batch_size: int = 20):
    """Evaluate every Electricity-domain response with RAGAS and ROUGE.

    Reads ``results/all_responses_with_ground_truth.csv`` and keeps rows whose
    ``domain`` is 'electricity'. Scores them in batches of ``batch_size`` with
    RAGAS (answer correctness, faithfulness) and per row with ROUGE-1/2/L.
    Writes ``results/electricity_complete_metrics.csv`` and logs per-method and
    overall summaries.

    Args:
        batch_size: Number of rows sent to RAGAS per ``evaluate`` call.

    Returns:
        DataFrame with one row per response and the columns in
        ``_RESULT_COLUMNS``. RAGAS scores are NaN for batches RAGAS failed on.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    logger.info("="*80)
    logger.info("ELECTRICITY DOMAIN - COMPLETE EVALUATION")
    logger.info("54 questions × 16 methods = 864 responses")
    logger.info("="*80)
    logger.info("NOTE: LSTM_Attention baseline used as ground truth reference")
    logger.info("="*80)

    # Load data
    df = pd.read_csv('results/all_responses_with_ground_truth.csv')
    elec_df = df[df['domain'] == 'electricity'].copy()

    logger.info(f"Total Electricity responses: {len(elec_df)}")
    logger.info(f"Methods: {elec_df['method'].nunique()}")
    logger.info(f"Questions: {elec_df['question_id'].nunique()}")
    if elec_df.empty:
        logger.warning("No electricity rows found; nothing to evaluate.")

    results = []
    total_batches = (len(elec_df) + batch_size - 1) // batch_size

    for batch_idx, start_idx in enumerate(range(0, len(elec_df), batch_size)):
        end_idx = min(start_idx + batch_size, len(elec_df))
        batch_df = elec_df.iloc[start_idx:end_idx]

        # Row range is shown inclusive.
        logger.info(f"\nBatch {batch_idx + 1}/{total_batches} (rows {start_idx}-{end_idx - 1})")

        ac_scores, f_scores = _score_ragas_batch(batch_df)

        # _score_ragas_batch guarantees both lists match the batch length.
        for (_, row), ac, f in zip(batch_df.iterrows(), ac_scores, f_scores):
            rouge_metrics = _score_rouge_row(row['response'], row['ground_truth'])
            results.append({
                'domain': 'electricity',
                'method': row['method'],
                'method_type': row['method_type'],
                'question_id': row['question_id'],
                'question': row['question'],
                'response': row['response'],
                'ground_truth': row['ground_truth'],
                'answer_correctness': ac,
                'faithfulness': f,
                'rouge_1': rouge_metrics['rouge_1'],
                'rouge_2': rouge_metrics['rouge_2'],
                'rouge_l': rouge_metrics['rouge_l']
            })

    # Explicit columns keep the schema even when no rows were evaluated.
    results_df = pd.DataFrame(results, columns=_RESULT_COLUMNS)

    results_df.to_csv('results/electricity_complete_metrics.csv', index=False)
    logger.info("\n✅ Saved results to results/electricity_complete_metrics.csv")

    _log_summary(results_df)

    logger.info("\n" + "="*80)
    logger.info("ELECTRICITY EVALUATION COMPLETE")
    logger.info("="*80)

    return results_df

if __name__ == '__main__':
    # Script entry point: run the evaluation and report the outcome.
    try:
        results = evaluate_electricity_complete()
    except Exception as e:
        # logger.exception records the traceback in the log file as well.
        logger.exception(f"Evaluation failed: {e}")
        # Exit non-zero so shell scripts / CI can detect the failure.
        raise SystemExit(1)
    else:
        print("\n✅ Electricity evaluation completed successfully!")
