#!/usr/bin/env python3
"""
Evaluate Improved Greenhouse HCA responses
Compare HCA_Full vs HCA_Full_Improved
"""

import pandas as pd
import numpy as np
import logging
import os
from datetime import datetime

# logging.FileHandler opens its target file as soon as it is constructed, so
# the directory must exist before logging is configured; otherwise the script
# fails with FileNotFoundError before doing any work.
os.makedirs('results', exist_ok=True)

# Log to both a file under results/ and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('results/greenhouse_improved_evaluation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Only install the in-file key when the environment does not already provide
# one, so a real key exported by the caller is never clobbered by the
# placeholder below.
# NOTE(review): presumably this key is consumed by the RAGAS metrics - confirm.
os.environ.setdefault('OPENAI_API_KEY', 'sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX')  # Replace with your actual API key
if os.environ['OPENAI_API_KEY'].startswith('sk-XXXX'):
    logger.warning(
        "OPENAI_API_KEY is still the placeholder value; "
        "export a real key in the environment before running the evaluation"
    )

def compute_rouge_metrics(prediction, reference):
    """Score one prediction against its reference with ROUGE-1, ROUGE-2 and ROUGE-L.

    Args:
        prediction: Generated text to be scored.
        reference: Reference text the prediction is compared with.

    Returns:
        dict with keys 'rouge1', 'rouge2' and 'rougeL', each holding the
        F-measure of that ROUGE variant. If anything goes wrong (including
        the ``rouge_score`` package being unavailable), a warning is logged
        and all three values are 0.0.
    """
    variants = ('rouge1', 'rouge2', 'rougeL')
    try:
        # Imported lazily so a missing package degrades to zero scores
        # instead of breaking the import of this module.
        from rouge_score import rouge_scorer
        engine = rouge_scorer.RougeScorer(list(variants), use_stemmer=True)
        scored = engine.score(reference, prediction)
        return {variant: scored[variant].fmeasure for variant in variants}
    except Exception as e:
        logger.warning(f"ROUGE computation failed: {e}")
        return dict.fromkeys(variants, 0.0)

def compute_ragas_batch(questions, responses, ground_truths, contexts, batch_size=20):
    """Compute RAGAS answer-correctness and faithfulness in batches.

    Args:
        questions: Sequence of question strings.
        responses: Sequence of generated answers, aligned with ``questions``.
        ground_truths: Sequence of reference answers, aligned with ``questions``.
        contexts: Sequence of context strings, aligned with ``questions``. A
            missing (NaN/None) or blank entry falls back to the matching
            ground truth.
        batch_size: Number of rows sent to ``ragas.evaluate`` per call.

    Returns:
        A list with exactly one dict per question, in input order, holding
        the keys 'answer_correctness' and 'faithfulness'. Every row of a
        batch that fails is reported as 0.0 for both metrics. If ``ragas``
        or ``datasets`` cannot be imported, both values are None for every
        row.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    try:
        from ragas import evaluate
        from ragas.metrics import answer_correctness, faithfulness
        from datasets import Dataset
    except ImportError as e:
        logger.error(f"RAGAS not available: {e}")
        # One independent dict per row, so callers can mutate a row safely.
        return [
            {'answer_correctness': None, 'faithfulness': None}
            for _ in range(len(questions))
        ]

    total = len(questions)
    total_batches = (total + batch_size - 1) // batch_size
    all_results = []

    logger.info(f"Computing RAGAS for {total} responses in {total_batches} batches...")

    for batch_num, start in enumerate(range(0, total, batch_size), start=1):
        batch_end = min(start + batch_size, total)
        expected_rows = batch_end - start

        logger.info(f"  Batch {batch_num}/{total_batches}: Processing {start + 1}-{batch_end}")

        # Prepare contexts properly
        batch_contexts = []
        for j in range(start, batch_end):
            ctx = contexts[j] if pd.notna(contexts[j]) and str(contexts[j]).strip() else ground_truths[j]
            # RAGAS expects list of context strings
            batch_contexts.append([str(ctx)])

        batch_data = {
            'question': questions[start:batch_end],
            'answer': responses[start:batch_end],
            'ground_truth': ground_truths[start:batch_end],
            'contexts': batch_contexts
        }

        try:
            # Building the dataset belongs inside the guard: a malformed
            # batch (e.g. a NaN answer mixed with strings) must fail this
            # batch only, not abort the run and lose earlier results.
            dataset = Dataset.from_dict(batch_data)
            result = evaluate(
                dataset,
                metrics=[answer_correctness, faithfulness],
                raise_exceptions=False
            )

            df_result = result.to_pandas()
            batch_scores = [
                {'answer_correctness': ac, 'faithfulness': f}
                for ac, f in zip(df_result['answer_correctness'], df_result['faithfulness'])
            ]
            if len(batch_scores) != expected_rows:
                # Keep the output aligned with the input rows.
                raise ValueError(
                    f"expected {expected_rows} result rows, got {len(batch_scores)}"
                )

            avg_ac = df_result['answer_correctness'].mean()
            avg_f = df_result['faithfulness'].mean()
            non_zero_f = (df_result['faithfulness'] > 0).sum()
            logger.info(f"    Batch {batch_num} complete: AC={avg_ac:.3f}, F={avg_f:.3f} ({non_zero_f}/{len(df_result)} non-zero)")

        except Exception as e:
            logger.error(f"    Batch {batch_num} failed: {e}")
            batch_scores = [
                {'answer_correctness': 0.0, 'faithfulness': 0.0}
                for _ in range(expected_rows)
            ]

        # Extend exactly once per batch so a late failure inside the try
        # can never add the same rows twice.
        all_results.extend(batch_scores)

    return all_results

def evaluate_improved():
    """Evaluate the HCA_Full_Improved greenhouse responses and compare them with HCA_Full.

    Reads ``results/all_responses_with_ground_truth.csv``, keeps the rows
    whose domain is 'greenhouse' and whose method is 'HCA_Full_Improved',
    scores them with RAGAS (answer correctness, faithfulness) and ROUGE
    (1, 2, L), and logs mean/std per metric. It then builds a comparison
    table against the 'HCA_Full' rows of the same file and writes:

    * ``results/greenhouse_hca_improved_metrics.csv`` - per-row metrics
    * ``results/greenhouse_improvement_comparison.csv`` - comparison table

    Returns:
        None. Results are reported through the logger, stdout and the two
        CSV files above.
    """

    logger.info("="*80)
    logger.info("IMPROVED GREENHOUSE HCA EVALUATION")
    logger.info("="*80)
    logger.info(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Load data
    df = pd.read_csv('results/all_responses_with_ground_truth.csv')

    # Filter to Greenhouse + HCA_Full_Improved
    df_improved = df[(df['domain'] == 'greenhouse') & (df['method'] == 'HCA_Full_Improved')].copy()
    logger.info(f"Loaded {len(df_improved)} improved responses")

    # Prepare data
    questions = df_improved['question'].tolist()
    responses = df_improved['response'].tolist()
    ground_truths = df_improved['ground_truth'].tolist()
    # Use the ground truth as context when the file has no 'improved_context' column.
    contexts = df_improved.get('improved_context', df_improved['ground_truth']).tolist()

    # Compute RAGAS (with contexts)
    logger.info("\n" + "="*60)
    logger.info("Computing RAGAS Metrics (AC, F)")
    logger.info("="*60)
    ragas_results = compute_ragas_batch(questions, responses, ground_truths, contexts, batch_size=20)

    # Compute ROUGE
    logger.info("\n" + "="*60)
    logger.info("Computing ROUGE Metrics")
    logger.info("="*60)
    rouge_results = []
    for i, (pred, ref) in enumerate(zip(responses, ground_truths)):
        if i % 20 == 0:
            logger.info(f"  Progress: {i}/{len(responses)}")
        rouge_results.append(compute_rouge_metrics(pred, ref))

    # Combine results
    for column in ('answer_correctness', 'faithfulness'):
        # compute_ragas_batch reports None for every row when RAGAS cannot be
        # imported. Coerce to numeric so those rows become NaN; an object
        # column of None would make the `values > 0` check below raise
        # TypeError.
        df_improved[column] = pd.to_numeric(
            [r[column] for r in ragas_results], errors='coerce'
        )
    for column in ('rouge1', 'rouge2', 'rougeL'):
        df_improved[column] = [r[column] for r in rouge_results]

    # Compute statistics
    logger.info("\n" + "="*80)
    logger.info("RESULTS: HCA_Full_Improved")
    logger.info("="*80)

    metrics = {
        'Answer Correctness': df_improved['answer_correctness'],
        'Faithfulness': df_improved['faithfulness'],
        'ROUGE-1': df_improved['rouge1'],
        'ROUGE-2': df_improved['rouge2'],
        'ROUGE-L': df_improved['rougeL']
    }

    for metric_name, values in metrics.items():
        mean_val = values.mean()
        std_val = values.std()
        non_zero = (values > 0).sum()
        logger.info(f"{metric_name:20s}: {mean_val:.3f} ± {std_val:.3f} ({non_zero}/{len(values)} non-zero)")

    # Load original HCA_Full for comparison
    df_original = df[(df['domain'] == 'greenhouse') & (df['method'] == 'HCA_Full')].copy()

    def _baseline_mean(column, fallback):
        """Return the HCA_Full mean of ``column``, or ``fallback`` if the column is absent."""
        if column in df_original:
            return df_original[column].mean()
        # NOTE(review): the fallback constants are presumably scores from an
        # earlier HCA_Full evaluation run - confirm their provenance.
        logger.warning(
            f"Column '{column}' not found for HCA_Full; using hard-coded baseline {fallback}"
        )
        return fallback

    logger.info("\n" + "="*80)
    logger.info("COMPARISON: Original vs Improved")
    logger.info("="*80)

    comparison = pd.DataFrame({
        'Metric': ['Answer Correctness', 'Faithfulness', 'ROUGE-L'],
        'Original (HCA_Full)': [
            _baseline_mean('answer_correctness', 0.336),
            _baseline_mean('faithfulness', 0.028),
            _baseline_mean('rougeL', 0.086)
        ],
        'Improved (HCA_Full_Improved)': [
            df_improved['answer_correctness'].mean(),
            df_improved['faithfulness'].mean(),
            df_improved['rougeL'].mean()
        ]
    })

    comparison['Absolute Change'] = comparison['Improved (HCA_Full_Improved)'] - comparison['Original (HCA_Full)']
    comparison['Relative Change (%)'] = (comparison['Absolute Change'] / comparison['Original (HCA_Full)']) * 100

    print("\n" + str(comparison.to_string(index=False)))
    logger.info("\n" + str(comparison.to_string(index=False)))

    # Save results
    output_file = 'results/greenhouse_hca_improved_metrics.csv'
    df_improved.to_csv(output_file, index=False)
    logger.info(f"\n✅ Saved detailed results to {output_file}")

    comparison.to_csv('results/greenhouse_improvement_comparison.csv', index=False)
    logger.info(f"✅ Saved comparison to results/greenhouse_improvement_comparison.csv")

    logger.info("\n" + "="*80)
    logger.info("EVALUATION COMPLETE")
    logger.info("="*80)
    logger.info(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    evaluate_improved()
