#!/usr/bin/env python3
"""
Complete Greenhouse Domain Evaluation
Computes RAGAS (AC, F) + ROUGE (L, 1, 2) for all Greenhouse responses
"""

import logging
import os
from datetime import datetime
from functools import lru_cache

import numpy as np
import pandas as pd

# FileHandler opens its file immediately, so the output directory must exist
# first; otherwise the script dies at import time with FileNotFoundError.
os.makedirs('results', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('results/greenhouse_evaluation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# OpenAI API key, presumably consumed by RAGAS' default LLM backend.
# A key already exported in the environment takes precedence; the placeholder
# is only a fallback and must be replaced (or OPENAI_API_KEY exported).
_PLACEHOLDER_OPENAI_KEY = 'sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'  # Replace with your actual API key
os.environ.setdefault('OPENAI_API_KEY', _PLACEHOLDER_OPENAI_KEY)
if os.environ['OPENAI_API_KEY'] == _PLACEHOLDER_OPENAI_KEY:
    logger.warning("OPENAI_API_KEY is the placeholder value; LLM-based RAGAS metrics will likely fail")

# ROUGE variants scored, and the flat result keys in the order they are returned.
_ROUGE_TYPES = ('rouge1', 'rouge2', 'rougeL')
_ROUGE_KEYS = tuple(
    f'{rouge_type}_{part}'
    for rouge_type in _ROUGE_TYPES
    for part in ('precision', 'recall', 'f1')
)


@lru_cache(maxsize=1)
def _get_rouge_scorer():
    """Build the stemming ROUGE scorer once and reuse it across calls.

    Raises ImportError if ``rouge_score`` is not installed. lru_cache does not
    cache exceptions, so a later call simply retries the import.
    """
    from rouge_score import rouge_scorer
    return rouge_scorer.RougeScorer(list(_ROUGE_TYPES), use_stemmer=True)


def compute_rouge_metrics(prediction, reference):
    """Compute ROUGE-1, ROUGE-2 and ROUGE-L precision/recall/F1 for one pair.

    Args:
        prediction: Generated response text.
        reference: Ground-truth text the prediction is scored against.

    Returns:
        Dict keyed ``rouge1_precision``, ``rouge1_recall``, ``rouge1_f1``, ...,
        ``rougeL_f1``. On any failure (including ``rouge_score`` missing) a
        warning is logged and every value is 0.0.
    """
    try:
        # RougeScorer.score takes (target, prediction).
        scores = _get_rouge_scorer().score(reference, prediction)
    except Exception as e:
        # Best-effort: keep the run going, but these zeros lower the averages.
        logger.warning(f"ROUGE computation failed: {e}")
        return dict.fromkeys(_ROUGE_KEYS, 0.0)

    metrics = {}
    for rouge_type in _ROUGE_TYPES:
        score = scores[rouge_type]
        metrics[f'{rouge_type}_precision'] = score.precision
        metrics[f'{rouge_type}_recall'] = score.recall
        metrics[f'{rouge_type}_f1'] = score.fmeasure
    return metrics

def compute_ragas_batch(questions, responses, ground_truths, batch_size=20):
    """Compute RAGAS answer correctness and faithfulness in batches.

    Args:
        questions: Sequence of question strings.
        responses: Model answers, index-aligned with ``questions``.
        ground_truths: Reference answers, index-aligned with ``questions``.
        batch_size: Rows passed to ``ragas.evaluate`` per call; must be >= 1.

    Returns:
        One dict per input row, in input order, with keys
        ``'answer_correctness'`` and ``'faithfulness'``. A value is ``None``
        when ragas/datasets cannot be imported or the row's batch failed, so
        callers can exclude it from averages instead of counting it as 0.
        Values may also be NaN for rows RAGAS itself could not score
        (``raise_exceptions=False``) -- confirm against the installed ragas.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def _missing(count):
        # Independent dicts per row, so mutating one row cannot affect others.
        return [{'answer_correctness': None, 'faithfulness': None} for _ in range(count)]

    try:
        from ragas import evaluate
        from ragas.metrics import answer_correctness, faithfulness
        from datasets import Dataset
    except ImportError as e:
        logger.error(f"RAGAS not available: {e}")
        return _missing(len(questions))

    n_rows = len(questions)
    all_results = []
    total_batches = (n_rows + batch_size - 1) // batch_size

    logger.info(f"Computing RAGAS for {n_rows} responses in {total_batches} batches...")

    for start in range(0, n_rows, batch_size):
        end = min(start + batch_size, n_rows)
        batch_len = end - start
        batch_num = start // batch_size + 1

        logger.info(f"  Batch {batch_num}/{total_batches}: Processing {start+1}-{end}")

        try:
            dataset = Dataset.from_dict({
                'question': questions[start:end],
                'answer': responses[start:end],
                'ground_truth': ground_truths[start:end],
                # NOTE(review): faithfulness is judged against this placeholder
                # context, so its scores likely don't reflect real grounding.
                'contexts': [['No context available'] for _ in range(batch_len)],
            })
            result = evaluate(
                dataset,
                metrics=[answer_correctness, faithfulness],
                raise_exceptions=False
            )
            df_result = result.to_pandas()
            if len(df_result) != batch_len:
                # Row-wise alignment with the inputs can't be trusted.
                raise ValueError(f"expected {batch_len} result rows, got {len(df_result)}")
            scores = df_result[['answer_correctness', 'faithfulness']]
        except Exception as e:
            logger.error(f"    Batch {batch_num} failed: {e}")
            # None (not 0.0) so a failed batch is excluded from averages.
            all_results.extend(_missing(batch_len))
            continue

        all_results.extend(scores.to_dict('records'))

        avg_ac = scores['answer_correctness'].mean()
        avg_f = scores['faithfulness'].mean()
        logger.info(f"    Batch {batch_num} complete: AC={avg_ac:.3f}, F={avg_f:.3f}")

    return all_results

# Metric columns averaged per method for the summary CSV, in output order.
_SUMMARY_METRICS = ['answer_correctness', 'faithfulness', 'rouge1_f1', 'rouge2_f1', 'rougeL_f1']

# Column order of the per-question detailed CSV.
_DETAIL_COLUMNS = [
    'domain', 'method', 'question_id',
    'answer_correctness', 'faithfulness',
    'rouge1_f1', 'rouge2_f1', 'rougeL_f1',
    'rouge1_precision', 'rouge1_recall',
    'rouge2_precision', 'rouge2_recall',
    'rougeL_precision', 'rougeL_recall',
]


def _mean_ignoring_missing(values):
    """Return the mean of ``values`` skipping None/NaN; NaN (no warning) if none remain."""
    return pd.Series(values, dtype=float).mean()


def _select_valid_rows(method_df):
    """Keep rows with non-null response and ground truth and a response longer than 10 chars."""
    valid_mask = (
        method_df['response'].notna()
        & method_df['ground_truth'].notna()
        & (method_df['response'].astype(str).str.len() > 10)
    )
    return method_df[valid_mask]


def _score_method(method, valid_df):
    """Compute ROUGE and RAGAS for one method's valid rows; return one result dict per row."""
    questions = valid_df['question'].tolist()
    responses = valid_df['response'].astype(str).tolist()
    ground_truths = valid_df['ground_truth'].astype(str).tolist()
    question_ids = valid_df['question_id'].tolist()

    logger.info("Computing ROUGE metrics...")
    rouge_results = [compute_rouge_metrics(resp, gt) for resp, gt in zip(responses, ground_truths)]
    for label, key in (('ROUGE-1', 'rouge1_f1'), ('ROUGE-2', 'rouge2_f1'), ('ROUGE-L', 'rougeL_f1')):
        logger.info(f"  {label}: {_mean_ignoring_missing([r[key] for r in rouge_results]):.3f}")

    logger.info("Computing RAGAS metrics...")
    ragas_results = compute_ragas_batch(questions, responses, ground_truths, batch_size=20)
    if len(ragas_results) != len(question_ids):
        # zip() below would silently truncate and misalign scores with question ids.
        logger.error(
            f"RAGAS returned {len(ragas_results)} rows for {len(question_ids)} responses; "
            f"discarding RAGAS scores for {method}"
        )
        ragas_results = [{'answer_correctness': None, 'faithfulness': None} for _ in question_ids]

    # None (RAGAS unavailable / failed batch) and NaN (unscored row) are skipped.
    avg_ac = _mean_ignoring_missing([r['answer_correctness'] for r in ragas_results])
    avg_f = _mean_ignoring_missing([r['faithfulness'] for r in ragas_results])
    logger.info(f"  Answer Correctness: {avg_ac:.3f}")
    logger.info(f"  Faithfulness: {avg_f:.3f}")

    return [
        {
            'domain': 'greenhouse',
            'method': method,
            'question_id': qid,
            'answer_correctness': ragas['answer_correctness'],
            'faithfulness': ragas['faithfulness'],
            'rouge1_f1': rouge['rouge1_f1'],
            'rouge2_f1': rouge['rouge2_f1'],
            'rougeL_f1': rouge['rougeL_f1'],
            'rouge1_precision': rouge['rouge1_precision'],
            'rouge1_recall': rouge['rouge1_recall'],
            'rouge2_precision': rouge['rouge2_precision'],
            'rouge2_recall': rouge['rouge2_recall'],
            'rougeL_precision': rouge['rougeL_precision'],
            'rougeL_recall': rouge['rougeL_recall'],
        }
        for qid, rouge, ragas in zip(question_ids, rouge_results, ragas_results)
    ]


def _summarize_by_method(df_results):
    """Average ``_SUMMARY_METRICS`` per method (sorted by method); empty in, empty out."""
    if df_results.empty:
        return pd.DataFrame(columns=['method', *_SUMMARY_METRICS])
    # Coerce to float: an all-None column (e.g. RAGAS not installed) has object
    # dtype, and a groupby mean over object columns can raise TypeError.
    metrics = df_results[_SUMMARY_METRICS].apply(pd.to_numeric, errors='coerce')
    metrics.insert(0, 'method', df_results['method'])
    return metrics.groupby('method', as_index=False).mean()


def evaluate_greenhouse_domain():
    """Run the full Greenhouse evaluation and write detailed and summary CSVs.

    Reads ``results/all_responses_with_ground_truth.csv``, keeps rows whose
    ``domain`` is ``'greenhouse'``, and scores each method's valid responses
    with ROUGE and RAGAS. Writes ``results/greenhouse_detailed_metrics.csv``
    and ``results/greenhouse_summary_metrics.csv``.

    Returns:
        ``(df_results, summary)``: per-question metrics and per-method means.
    """
    logger.info("="*80)
    logger.info("GREENHOUSE DOMAIN COMPLETE EVALUATION")
    logger.info("="*80)
    logger.info(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Load data
    input_file = 'results/all_responses_with_ground_truth.csv'
    df = pd.read_csv(input_file)

    # Filter Greenhouse domain
    df_greenhouse = df[df['domain'] == 'greenhouse']
    logger.info(f"Loaded {len(df_greenhouse)} Greenhouse responses")

    # dropna(): a missing method would make sorted() compare float NaN with str.
    methods = sorted(df_greenhouse['method'].dropna().unique())
    method_counts = df_greenhouse['method'].value_counts()
    logger.info(f"Methods to evaluate: {len(methods)}")
    for m in methods:
        logger.info(f"  - {m}: {method_counts[m]} responses")

    logger.info("\n" + "="*80)
    logger.info("PROCESSING EACH METHOD")
    logger.info("="*80)

    all_results = []
    for method in methods:
        method_df = df_greenhouse[df_greenhouse['method'] == method]

        logger.info(f"\n{'='*60}")
        logger.info(f"Method: {method} ({len(method_df)} responses)")
        logger.info(f"{'='*60}")

        valid_df = _select_valid_rows(method_df)
        logger.info(f"Valid responses: {len(valid_df)}/{len(method_df)}")

        if valid_df.empty:
            logger.warning(f"No valid responses for {method}, skipping...")
            continue

        all_results.extend(_score_method(method, valid_df))

    # Save results
    logger.info("\n" + "="*80)
    logger.info("SAVING RESULTS")
    logger.info("="*80)

    # Explicit columns keep a header and stable order even when nothing was scored.
    df_results = pd.DataFrame(all_results, columns=_DETAIL_COLUMNS)
    if df_results.empty:
        logger.warning("No Greenhouse results were produced; writing empty outputs")

    detailed_file = 'results/greenhouse_detailed_metrics.csv'
    df_results.to_csv(detailed_file, index=False)
    logger.info(f"✅ Saved detailed results: {detailed_file}")

    summary = _summarize_by_method(df_results)

    summary_file = 'results/greenhouse_summary_metrics.csv'
    summary.to_csv(summary_file, index=False)
    logger.info(f"✅ Saved summary: {summary_file}")

    # Print summary table
    logger.info("\n" + "="*80)
    logger.info("GREENHOUSE DOMAIN SUMMARY")
    logger.info("="*80)
    print("\n" + summary.to_string(index=False))

    logger.info("\n" + "="*80)
    logger.info(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info("✅ GREENHOUSE EVALUATION COMPLETE")
    logger.info("="*80)

    return df_results, summary

if __name__ == "__main__":
    # Script entry point: run the evaluation and keep both returned frames
    # (per-question details, per-method summary) bound as module globals.
    results, summary = evaluate_greenhouse_domain()
