#!/usr/bin/env python3
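"""
Evaluate neural baselines (LSTM+Attention, RETAIN, IOC) for the Greenhouse domain.

Computes RAGAS (Answer Correctness, Faithfulness) and ROUGE (ROUGE-1, ROUGE-2,
ROUGE-L) metrics for each method's responses, then writes the per-response
metrics (CSV) and a per-method summary (JSON) under ``results/``.
"""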

import os
import sys
import json
import pandas as pd
import numpy as np
from typing import Dict, List
import logging
from datetime import datetime

# RAGAS imports
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import answer_correctness, faithfulness

# ROUGE imports
from rouge_score import rouge_scorer

# Setup logging
# logging.FileHandler opens its file as soon as it is constructed, so the
# target directory must exist before basicConfig() is called; otherwise the
# script dies at import time with FileNotFoundError on a fresh checkout.
os.makedirs('results', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Log messages contain non-ASCII characters ("✅", "±"); write the file
        # as UTF-8 so it does not depend on the platform's default encoding.
        logging.FileHandler('results/neural_baselines_evaluation.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# OpenAI API key
# Prefer a key that is already exported in the environment. The literal below
# is only a fallback, so a real key provided by the shell is never overwritten
# by the placeholder.
if not os.environ.get('OPENAI_API_KEY'):
    os.environ['OPENAI_API_KEY'] = 'sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'  # Replace with your actual API key

def compute_rouge_metrics(prediction: str, reference: str) -> Dict[str, float]:
    """Score ``prediction`` against ``reference`` with ROUGE-1, ROUGE-2 and ROUGE-L.

    Args:
        prediction: Text produced by the method under evaluation.
        reference: Ground-truth text to compare against.

    Returns:
        A dict with keys ``rouge_1``, ``rouge_2`` and ``rouge_l``, each holding
        the F-measure of the corresponding ROUGE variant.
    """
    # Output key -> rouge_score metric name, in the order the scorer is built.
    rouge_types = {'rouge_1': 'rouge1', 'rouge_2': 'rouge2', 'rouge_l': 'rougeL'}

    rouge = rouge_scorer.RougeScorer(list(rouge_types.values()), use_stemmer=True)
    # The scorer takes the reference (target) first and the prediction second.
    scored = rouge.score(reference, prediction)

    return {
        out_key: scored[rouge_type].fmeasure
        for out_key, rouge_type in rouge_types.items()
    }

def _ragas_scores_for_batch(batch_df: pd.DataFrame, batch_number: int) -> tuple:
    """Run RAGAS on one batch and return ``(answer_correctness, faithfulness)`` lists.

    Both lists have exactly ``len(batch_df)`` entries, in row order. If RAGAS
    raises, or returns a different number of scores than rows, the failure is
    logged and both lists are filled with ``0.0``.
    """
    n_rows = len(batch_df)
    try:
        ground_truths = batch_df['ground_truth'].tolist()

        # Prepare dataset for RAGAS.
        # For neural baselines, use ground_truth as context since they don't use retrieval.
        ragas_data = {
            'question': batch_df['question'].tolist(),
            'answer': batch_df['response'].tolist(),
            'contexts': [[gt] for gt in ground_truths],
            'ground_truth': ground_truths,
        }
        dataset = Dataset.from_dict(ragas_data)

        logger.info(f"Computing RAGAS metrics for {n_rows} responses...")
        result = evaluate(
            dataset,
            metrics=[answer_correctness, faithfulness]
        )

        # Extract per-row metrics using to_pandas()
        df_result = result.to_pandas()
        ac_scores = df_result['answer_correctness'].tolist()
        f_scores = df_result['faithfulness'].tolist()

        # Scores are matched to rows by position, so a short result must be
        # treated as a failed batch rather than silently misaligned.
        if len(ac_scores) != n_rows or len(f_scores) != n_rows:
            raise ValueError(
                f"expected {n_rows} scores, got {len(ac_scores)} answer_correctness "
                f"and {len(f_scores)} faithfulness"
            )

        logger.info(f"RAGAS complete - AC: {np.mean(ac_scores):.3f}, F: {np.mean(f_scores):.3f}")
        return ac_scores, f_scores

    except Exception as e:
        logger.error(f"Batch {batch_number} RAGAS failed: {e}")
        # Use default values if RAGAS fails
        return [0.0] * n_rows, [0.0] * n_rows


def evaluate_method(method_name: str, csv_file: str) -> pd.DataFrame:
    """Evaluate a single neural baseline method.

    Reads the method's responses from ``csv_file``, computes RAGAS answer
    correctness and faithfulness in batches of 20 rows, computes ROUGE-1/2/L
    per row, logs summary statistics, and returns one result row per response.

    Args:
        method_name: Label stored in the ``method`` column of the result.
        csv_file: Path to a CSV with the columns ``question_id``, ``question``,
            ``response`` and ``ground_truth``.

    Returns:
        A DataFrame with the columns ``domain``, ``method``, ``method_type``,
        ``question_id``, ``question``, ``response``, ``ground_truth``,
        ``answer_correctness``, ``faithfulness``, ``rouge_1``, ``rouge_2`` and
        ``rouge_l``. It is empty (but has these columns) when the CSV has no rows.

    Raises:
        KeyError: If ``csv_file`` lacks one of the required columns.
    """
    required_columns = ('question_id', 'question', 'response', 'ground_truth')
    text_columns = ('question', 'response', 'ground_truth')
    result_columns = [
        'domain', 'method', 'method_type', 'question_id', 'question',
        'response', 'ground_truth', 'answer_correctness', 'faithfulness',
        'rouge_1', 'rouge_2', 'rouge_l',
    ]

    logger.info("="*80)
    logger.info(f"EVALUATING {method_name}")
    logger.info("="*80)

    # Load data
    df = pd.read_csv(csv_file)
    logger.info(f"Loaded {len(df)} responses from {csv_file}")

    # Fail fast on a malformed CSV, before any RAGAS (LLM) calls are made.
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise KeyError(f"{csv_file} is missing required column(s): {missing}")

    # read_csv turns empty cells into float NaN. A single NaN among the strings
    # would make the RAGAS dataset construction fail and zero the whole batch,
    # so normalise every text cell to a plain string ('' for missing).
    for col in text_columns:
        df[col] = df[col].map(lambda value: '' if pd.isna(value) else str(value))

    # Process in batches for RAGAS
    batch_size = 20
    total_batches = (len(df) + batch_size - 1) // batch_size
    results = []

    for batch_number, start_idx in enumerate(range(0, len(df), batch_size), start=1):
        batch_df = df.iloc[start_idx:start_idx + batch_size]
        last_idx = start_idx + len(batch_df) - 1

        logger.info(f"Processing batch {batch_number}/{total_batches} (rows {start_idx}-{last_idx})")

        ac_scores, f_scores = _ragas_scores_for_batch(batch_df, batch_number)

        # Pair rows with their RAGAS scores by position instead of by index
        # label, so the alignment does not depend on the DataFrame's index.
        for (_, row), ac_score, f_score in zip(batch_df.iterrows(), ac_scores, f_scores):
            try:
                rouge_metrics = compute_rouge_metrics(row['response'], row['ground_truth'])
            except Exception as e:
                logger.error(f"ROUGE failed for question {row['question_id']}: {e}")
                rouge_metrics = {'rouge_1': 0.0, 'rouge_2': 0.0, 'rouge_l': 0.0}

            results.append({
                'domain': 'greenhouse',
                'method': method_name,
                'method_type': 'neural_baseline',
                'question_id': row['question_id'],
                'question': row['question'],
                'response': row['response'],
                'ground_truth': row['ground_truth'],
                'answer_correctness': ac_score,
                'faithfulness': f_score,
                'rouge_1': rouge_metrics['rouge_1'],
                'rouge_2': rouge_metrics['rouge_2'],
                'rouge_l': rouge_metrics['rouge_l']
            })

    # Create results dataframe; passing the columns keeps the schema stable
    # even when there are no rows.
    results_df = pd.DataFrame(results, columns=result_columns)

    if results_df.empty:
        # No rows means no statistics to report.
        logger.warning(f"{method_name}: no responses found in {csv_file}; skipping summary statistics")
        return results_df

    # Log summary statistics
    logger.info("\n" + "="*80)
    logger.info(f"{method_name} EVALUATION SUMMARY")
    logger.info("="*80)
    logger.info(f"Total responses: {len(results_df)}")
    metric_labels = (
        ('Answer Correctness', 'answer_correctness'),
        ('Faithfulness', 'faithfulness'),
        ('ROUGE-1', 'rouge_1'),
        ('ROUGE-2', 'rouge_2'),
        ('ROUGE-L', 'rouge_l'),
    )
    for label, column in metric_labels:
        logger.info(f"{label}: {results_df[column].mean():.4f} ± {results_df[column].std():.4f}")
    logger.info("="*80)

    return results_df

def main():
    """Evaluate every configured neural baseline and persist the results.

    Each method is evaluated independently; a failure in one is logged and does
    not stop the others. When at least one method succeeds, the per-response
    metrics are written to ``results/neural_baselines_greenhouse_metrics.csv``
    and per-method mean/std statistics to
    ``results/neural_baselines_greenhouse_summary.json``.
    """
    logger.info("="*80)
    logger.info("NEURAL BASELINES EVALUATION - GREENHOUSE DOMAIN")
    logger.info("="*80)
    logger.info(f"Started at: {datetime.now().isoformat()}")

    # Define methods to evaluate
    methods = [
        ('LSTM+Attention', 'results/lstm_attention_greenhouse.csv'),
        ('RETAIN', 'results/retain_greenhouse.csv'),
        ('IOC', 'results/ioc_greenhouse.csv')
    ]
    metric_columns = ['answer_correctness', 'faithfulness', 'rouge_1', 'rouge_2', 'rouge_l']

    all_results = []
    failed_methods = []

    # Evaluate each method
    for method_name, csv_file in methods:
        try:
            all_results.append(evaluate_method(method_name, csv_file))
        except Exception as e:
            failed_methods.append(method_name)
            logger.error(f"Failed to evaluate {method_name}: {e}", exc_info=True)

    if all_results:
        # Combine all results
        combined_df = pd.concat(all_results, ignore_index=True)

        # Save combined results
        output_file = 'results/neural_baselines_greenhouse_metrics.csv'
        combined_df.to_csv(output_file, index=False)
        logger.info(f"\n✅ Saved combined results to {output_file}")

        # Build summary statistics; methods with no rows are left out.
        summary = {}
        for method_name, _ in methods:
            method_df = combined_df[combined_df['method'] == method_name]
            if len(method_df) > 0:
                method_summary = {'count': len(method_df)}
                for column in metric_columns:
                    method_summary[column] = {
                        'mean': float(method_df[column].mean()),
                        'std': float(method_df[column].std())
                    }
                summary[method_name] = method_summary

        summary_file = 'results/neural_baselines_greenhouse_summary.json'
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)
        logger.info(f"✅ Saved summary statistics to {summary_file}")

        # Print final summary
        logger.info("\n" + "="*80)
        logger.info("FINAL SUMMARY")
        logger.info("="*80)
        logger.info(f"Total responses evaluated: {len(combined_df)}")
        # Report how many methods actually completed, not how many were configured.
        logger.info(f"Methods evaluated: {len(all_results)}")
        if failed_methods:
            logger.info(f"Methods failed: {', '.join(failed_methods)}")
        for method_name, method_summary in summary.items():
            logger.info(f"\n{method_name}:")
            logger.info(f"  Answer Correctness: {method_summary['answer_correctness']['mean']:.4f}")
            logger.info(f"  Faithfulness: {method_summary['faithfulness']['mean']:.4f}")
            logger.info(f"  ROUGE-L: {method_summary['rouge_l']['mean']:.4f}")
        logger.info("="*80)
    else:
        # Make the "nothing was produced" outcome explicit in the log.
        logger.error("No methods were evaluated successfully; no output files were written")

    logger.info(f"\nCompleted at: {datetime.now().isoformat()}")
    logger.info("="*80)
    logger.info("EVALUATION COMPLETE")
    logger.info("="*80)

# Script entry point: run the evaluation and turn any uncaught error into a
# logged traceback plus a non-zero exit status.
if __name__ == '__main__':
    try:
        main()
    except Exception as exc:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception(f"Fatal error: {exc}")
        raise SystemExit(1)
