"""Main entry point for PeerQA experiments with ALL retrieval methods.

Runs every enabled retriever for each configured granularity/template pair and
writes a Markdown report in the same format as outputs_all_methods/report.md.
"""
import argparse
import json
import logging
import os
import sys
from pathlib import Path
import yaml
import pandas as pd
import numpy as np
from typing import Dict, Any, List, Tuple
import warnings
import time
# NOTE(review): with no category/module given, this installs a process-wide
# "ignore" filter for *every* warning, including DeprecationWarning /
# FutureWarning raised by numpy, pandas and the retrieval libraries. Consider
# narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')

# Import data loader and modules
from src.data_loader_local import PeerQALocalDataLoader
from src.retrieval.bm25_retriever import BM25Retriever, SimpleTFIDFRetriever
from src.downstream_evaluation import DownstreamEvaluator

# Configure root logging when this module is imported: INFO level with
# timestamped records (basicConfig does nothing if the root logger already
# has handlers).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy scalars and arrays.

    NumPy integer scalars become ``int``, floating scalars ``float``, booleans
    ``bool`` and arrays (nested) lists. Anything else is delegated to
    :class:`json.JSONEncoder`, which raises ``TypeError``.
    """

    def default(self, obj):
        # Match NumPy's abstract scalar bases instead of listing concrete
        # aliases: ``np.float_`` no longer exists in NumPy 2.0, so referencing
        # it made this method raise AttributeError for every non-integer value.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

def load_config(config_path: str) -> Dict:
    """Load the experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed top-level mapping. An empty file (or one containing only
        comments), for which ``yaml.safe_load`` returns ``None``, yields ``{}``
        so callers can use ``config.get(...)`` unconditionally.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML value is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config

def run_retrieval_experiment(config: Dict, data_loader: PeerQALocalDataLoader,
                            granularity: str, template: str,
                            enable_oracle: bool = False) -> Dict:
    """Run every enabled retriever for one granularity/template configuration.

    The rows returned by ``data_loader.preprocess_data`` are flattened into a
    single corpus (each row's ``chunks``, in row order). Each enabled
    retriever indexes that corpus, answers one query per row and is scored
    with its ``evaluate`` method.

    Args:
        config: Configuration dictionary. Reads ``evaluation.k_values``
            (default ``[5, 10, 20]``) and the ``retrievers`` section
            (``bm25``/``tfidf``/``dense``/``colbert``/``cross_encoder``, each
            with an ``enabled`` flag; BM25 defaults to on, the others to off).
            The dict is not modified.
        data_loader: Data loader instance.
        granularity: 'sentence' or 'paragraph'.
        template: Template name.
        enable_oracle: If True, use per-paper BM25 indexes (oracle mode) and
            return only the ``'bm25_oracle'`` entry. Falls back, with a
            warning, to the regular retrievers when no per-paper index could
            be built.

    Returns:
        Mapping of retriever name to ``{'metrics': ..., 'retrieved': ...}``.
        It is empty when there is no data or nothing to index. A retriever that
        raises is logged (with traceback) and omitted.

    NOTE(review): the ground truth is a placeholder. Every answerable question
    is credited with corpus indices 0..2 regardless of its evidence, so the
    reported Recall/MRR values do not measure real retrieval quality.
    """
    logger.info(f"Running experiment for {granularity}/{template} (Oracle: {enable_oracle})")

    # Preprocess data
    data = data_loader.preprocess_data(granularity, template)

    if data.empty:
        logger.warning(f"No data available for {granularity}/{template}")
        return {}

    # Flatten every row's chunks into one corpus, in row order;
    # doc_metadata[i] describes documents[i].
    documents = []
    doc_metadata = []
    for _, row in data.iterrows():
        if 'chunks' in row and row['chunks']:
            for chunk in row['chunks']:
                documents.append(chunk)
                doc_metadata.append({
                    'question_id': row.get('question_id', ''),
                    'paper_id': row.get('paper_id', ''),
                    'template': template,
                    'granularity': granularity
                })

    if not documents:
        logger.warning("No documents to index")
        return {}

    logger.info(f"Processing {len(documents)} document chunks")

    # One query per row; retrieve enough results for the largest cutoff.
    queries = data['question'].tolist()
    k_values = config.get('evaluation', {}).get('k_values', [5, 10, 20])
    max_k = max(k_values)

    # Placeholder relevance labels keyed by DataFrame index (see NOTE above).
    # 'answer_evidence_mapped' is not mapped to corpus indices yet, so answerable
    # questions get indices 0..2 and unanswerable ones get no relevant documents.
    # TODO: map answer evidence to corpus indices so the metrics are meaningful.
    n_placeholder = min(3, len(documents))
    ground_truth = {
        idx: list(range(n_placeholder)) if row.get('answerability', True) else []
        for idx, row in data.iterrows()
    }

    results = {}
    retriever_configs = config.get('retrievers', {})

    # Oracle mode: one BM25 index per paper, searched only with that paper's questions.
    if enable_oracle:
        logger.info("🔮 Running in ORACLE mode (per-paper indexes)")
        oracle_indexes = create_oracle_indexes(data_loader, data, granularity, template)

        if oracle_indexes:
            oracle_retrieved = {}
            oracle_gt = {}

            for idx, row in data.iterrows():
                paper_id = row.get('paper_id', '')
                if paper_id not in oracle_indexes:
                    oracle_retrieved[idx] = []
                    oracle_gt[idx] = []
                    continue

                paper_index = oracle_indexes[paper_id]
                oracle_retrieved[idx] = paper_index['retriever'].retrieve(row['question'], k=max_k)
                # Same placeholder labelling as above, within the paper's own chunks.
                if row.get('answerability', True):
                    oracle_gt[idx] = list(range(min(3, paper_index['n_docs'])))
                else:
                    oracle_gt[idx] = []

            # This BM25Retriever is only used for its evaluate() method.
            bm25_eval = BM25Retriever({'bm25_k1': 1.2, 'bm25_b': 0.75})
            oracle_metrics = bm25_eval.evaluate(oracle_retrieved, oracle_gt, k_values)

            results['bm25_oracle'] = {
                'metrics': oracle_metrics,
                'retrieved': oracle_retrieved
            }
            logger.info(f"   ✅ BM25 Oracle Recall@10: {oracle_metrics.get('recall', {}).get(10, 0):.3f}")
            return results

        # Make the fallback visible: otherwise an "oracle" run silently
        # reports full-corpus numbers.
        logger.warning("No per-paper oracle indexes could be built; "
                       "falling back to full-corpus retrieval")

    # Regular retrieval methods

    # 1. BM25 Retriever
    if retriever_configs.get('bm25', {}).get('enabled', True):
        logger.info("🔍 Running BM25 retriever...")
        try:
            bm25_config = retriever_configs.get('bm25', {})
            bm25 = BM25Retriever(bm25_config)
            bm25.build_index(documents, doc_metadata)

            retrieved = bm25.batch_retrieve(queries, k=max_k)
            metrics = bm25.evaluate(retrieved, ground_truth, k_values)

            results['bm25'] = {
                'metrics': metrics,
                'retrieved': retrieved
            }
            logger.info(f"   ✅ BM25 Recall@10: {metrics.get('recall', {}).get(10, 0):.3f}, MRR: {metrics.get('mrr', 0):.3f}")
        except Exception as e:
            logger.exception(f"   ❌ BM25 failed: {e}")

    # 2. TF-IDF Retriever
    if retriever_configs.get('tfidf', {}).get('enabled', False):
        logger.info("🔍 Running TF-IDF retriever...")
        try:
            # NOTE(review): unlike the other retrievers this receives the whole
            # config rather than its own section; confirm what
            # SimpleTFIDFRetriever expects.
            tfidf = SimpleTFIDFRetriever(config)
            tfidf.build_index(documents, doc_metadata)

            retrieved = tfidf.batch_retrieve(queries, k=max_k)
            metrics = tfidf.evaluate(retrieved, ground_truth, k_values) if hasattr(tfidf, 'evaluate') else {}

            results['tfidf'] = {
                'metrics': metrics,
                'retrieved': retrieved
            }
            logger.info(f"   ✅ TF-IDF Recall@10: {metrics.get('recall', {}).get(10, 0):.3f}, MRR: {metrics.get('mrr', 0):.3f}")
        except Exception as e:
            logger.exception(f"   ❌ TF-IDF failed: {e}")

    # 3. Dense Retriever
    if retriever_configs.get('dense', {}).get('enabled', False):
        logger.info("🔍 Running Dense retriever...")
        try:
            from src.retrieval.dense_retriever import DenseRetriever

            # Copy so the default model name is not written back into the
            # caller's config, which is shared across all configurations.
            dense_config = dict(retriever_configs.get('dense', {}))
            dense_config.setdefault('model_name', 'sentence-transformers/all-MiniLM-L6-v2')

            dense = DenseRetriever(
                config=dense_config,
                model_name=dense_config['model_name'],
                device=dense_config.get('device', 'cpu')
            )
            dense.build_index(documents, doc_metadata)

            retrieved = dense.batch_retrieve(queries, k=max_k)
            metrics = dense.evaluate(retrieved, ground_truth, k_values)

            results['dense'] = {
                'metrics': metrics,
                'retrieved': retrieved
            }
            logger.info(f"   ✅ Dense Recall@10: {metrics.get('recall', {}).get(10, 0):.3f}, MRR: {metrics.get('mrr', 0):.3f}")
        except ImportError:
            logger.warning("   ⚠️ Dense retriever not available (install sentence-transformers)")
        except Exception as e:
            logger.exception(f"   ❌ Dense retriever failed: {e}")

    # 4. ColBERT Retriever
    if retriever_configs.get('colbert', {}).get('enabled', False):
        logger.info("🔍 Running ColBERT retriever...")
        try:
            from src.retrieval.colbert_retriever import ColBERTRetriever

            colbert_config = retriever_configs.get('colbert', {})
            colbert = ColBERTRetriever(colbert_config)
            colbert.build_index(documents, doc_metadata)

            retrieved = colbert.batch_retrieve(queries, k=max_k)
            metrics = colbert.evaluate(retrieved, ground_truth, k_values)

            results['colbert'] = {
                'metrics': metrics,
                'retrieved': retrieved
            }
            logger.info(f"   ✅ ColBERT Recall@10: {metrics.get('recall', {}).get(10, 0):.3f}, MRR: {metrics.get('mrr', 0):.3f}")
        except ImportError:
            logger.warning("   ⚠️ ColBERT not available (install colbert-ai)")
        except Exception as e:
            logger.exception(f"   ❌ ColBERT failed: {e}")

    # 5. Cross-Encoder Reranker
    if retriever_configs.get('cross_encoder', {}).get('enabled', False):
        logger.info("🔍 Running Cross-Encoder reranker...")
        try:
            from src.retrieval.cross_encoder_retriever import CrossEncoderRetriever

            ce_config = retriever_configs.get('cross_encoder', {})

            # Use BM25 as first-stage retriever
            if 'bm25' in results:
                base_retrieved = results['bm25']['retrieved']

                # NOTE(review): this instance is constructed but never used to
                # score anything below.
                cross_encoder = CrossEncoderRetriever(ce_config)
                cross_encoder.documents = documents
                cross_encoder.queries = queries

                # No cross-encoder scoring happens yet. The entry is BM25's
                # ranking truncated to rerank_top_k, so its downstream numbers
                # are BM25's under another name.
                rerank_k = ce_config.get('rerank_top_k', 100)
                reranked = {q_idx: docs[:rerank_k] for q_idx, docs in base_retrieved.items()}

                results['cross_encoder'] = {
                    'metrics': {},
                    'retrieved': reranked
                }
                logger.warning("   ⚠️ Cross-Encoder reranking is not implemented; "
                               "'cross_encoder' results are the truncated BM25 ranking")
            else:
                logger.warning("   ⚠️ Cross-Encoder requires BM25 as base retriever")
        except ImportError:
            logger.warning("   ⚠️ Cross-Encoder not available")
        except Exception as e:
            logger.exception(f"   ❌ Cross-Encoder failed: {e}")

    return results

def create_oracle_indexes(data_loader: PeerQALocalDataLoader, data: pd.DataFrame,
                         granularity: str, template: str) -> Dict:
    """Build one BM25 index per paper for oracle (per-paper) evaluation.

    For each distinct ``paper_id`` (in order of first appearance), the chunks
    of all rows belonging to that paper are concatenated in row order and
    indexed with a BM25Retriever (k1=1.2, b=0.75). Papers without any chunks
    get no index.

    ``data_loader``, ``granularity`` and ``template`` are accepted but not
    used here.

    Returns:
        ``{paper_id: {'retriever': BM25Retriever, 'n_docs': int}}``, or an
        empty dict when ``data`` has no ``paper_id`` column.
    """
    if 'paper_id' not in data.columns:
        return {}

    indexes = {}
    for pid in data['paper_id'].unique():
        chunks_for_paper = [
            chunk
            for _, paper_row in data[data['paper_id'] == pid].iterrows()
            if 'chunks' in paper_row and paper_row['chunks']
            for chunk in paper_row['chunks']
        ]
        if not chunks_for_paper:
            continue

        paper_index = BM25Retriever({'bm25_k1': 1.2, 'bm25_b': 0.75})
        paper_index.build_index(chunks_for_paper, None)
        indexes[pid] = {
            'retriever': paper_index,
            'n_docs': len(chunks_for_paper)
        }

    return indexes

def _chunks_of(row: pd.Series) -> List:
    """Return the row's chunk list, or [] when 'chunks' is missing or empty."""
    if 'chunks' in row and row['chunks']:
        return list(row['chunks'])
    return []


def _retrieval_corpora(data: pd.DataFrame) -> Tuple[List, Dict[Any, List]]:
    """Rebuild, in the same order, the corpora the retrievers indexed.

    Returns the full corpus (every row's chunks concatenated in row order, as
    built by run_retrieval_experiment). It also returns, per paper id, the
    concatenated chunks of that paper's rows (as built by
    create_oracle_indexes).
    """
    corpus: List = []
    by_paper: Dict[Any, List] = {}
    has_paper_id = 'paper_id' in data.columns
    for _, row in data.iterrows():
        chunks = _chunks_of(row)
        corpus.extend(chunks)
        if has_paper_id and chunks:
            by_paper.setdefault(row['paper_id'], []).extend(chunks)
    return corpus, by_paper


def run_downstream_evaluation(config: Dict, data: pd.DataFrame,
                             retrieval_results: Dict,
                             context_top_k: int = 5) -> Dict:
    """Run downstream (RAG) evaluation using each retriever's top chunks as context.

    For every retriever entry that has a ``'retrieved'`` key, the texts of its
    top ``context_top_k`` retrieved chunks are joined into the question's
    context. If none resolve, the row's own ``context`` is used. The result is
    scored with ``DownstreamEvaluator.evaluate_all(..., task_type='rag')``.

    Retrieved indices are resolved against the corpus the retriever actually
    searched:
    - regular retrievers: the full flattened corpus;
    - ``*_oracle`` entries: the paper's own chunks.
    Resolving them against the current row's ``chunks`` gave wrong or missing
    context for every row whose chunks do not start at corpus position 0.

    Args:
        config: Configuration passed to DownstreamEvaluator.
        data: Preprocessed rows (question, chunks, answer, answerability, ...).
        retrieval_results: Output of run_retrieval_experiment.
        context_top_k: Number of top retrieved chunks used as context.

    Returns:
        Mapping of retriever name to the metrics returned by evaluate_all.

    NOTE(review): assumes ``retrieved`` maps the DataFrame index to a list of
    ``(corpus_index, score)`` pairs, as the unpacking below requires. Also
    assumes ``data`` is the same preprocessed frame (rows in the same order)
    that was indexed; main() calls preprocess_data again to obtain it.
    """
    evaluator = DownstreamEvaluator(config)
    corpus, corpus_by_paper = _retrieval_corpora(data)

    results = {}
    for retriever_name, retriever_results in retrieval_results.items():
        if 'retrieved' not in retriever_results:
            continue

        retrieved = retriever_results['retrieved']
        per_paper = retriever_name.endswith('_oracle')

        augmented_data = []
        for idx, row in data.iterrows():
            if per_paper:
                searched = corpus_by_paper.get(row.get('paper_id', ''), [])
            else:
                searched = corpus
            retrieved_texts = [
                searched[doc_idx]
                for doc_idx, _score in retrieved.get(idx, [])[:context_top_k]
                if 0 <= doc_idx < len(searched)
            ]

            augmented_data.append({
                'question': row['question'],
                'context': ' '.join(retrieved_texts) if retrieved_texts else str(row.get('context', '')),
                'answer': row.get('answer', ''),
                'answerability': row.get('answerability', True)
            })

        results[retriever_name] = evaluator.evaluate_all(
            pd.DataFrame(augmented_data),
            task_type='rag'
        )

    return results

def _build_report(config: Dict, stats: Dict, retriever_performance: Dict,
                  all_results: Dict) -> str:
    """Render the Markdown report (same layout as outputs_all_methods/report.md).

    Args:
        config: Experiment configuration (reads ``data.n_samples``).
        stats: Dataset statistics from the data loader (reads ``qa_samples``
            and ``unique_papers``).
        retriever_performance: Retriever name -> list of
            ``{'config': key, 'metrics': {...}}`` entries.
        all_results: ``"granularity/template"`` ->
            ``{'retrieval': ..., 'downstream': ...}``.
    """
    report = ["# PeerQA Decontextualization Audit - All Methods Results\n\n"]

    # Dataset info
    n_samples = config.get('data', {}).get('n_samples', stats.get('qa_samples', 0))
    report.append("**Dataset**: Real PeerQA JSONL files\n")
    report.append(f"**Samples Processed**: {n_samples}\n")
    report.append(f"**Source**: {stats.get('qa_samples', 0)} real Q&As, {stats.get('unique_papers', 0)} papers\n\n")

    # Best performing configurations
    report.append("## Best Performing Configurations\n\n")
    report.append("| Retriever | Best Config | Recall@10 | MRR |\n")
    report.append("|-----------|-------------|-----------|-----|\n")
    for retriever, perf_list in retriever_performance.items():
        if not perf_list:
            continue
        # max() keeps the first config on ties, as the old strict '>' scan did.
        # Unlike that scan (which started from 0), it still names a config
        # when every Recall@10 is 0.
        best = max(perf_list, key=lambda p: p['metrics'].get('recall', {}).get(10, 0))
        best_recall = best['metrics'].get('recall', {}).get(10, 0)
        best_mrr = best['metrics'].get('mrr', 0)
        report.append(f"| {retriever} | {best['config']} | {best_recall:.3f} | {best_mrr:.3f} |\n")

    # Detailed results by configuration
    report.append("\n## Detailed Results by Configuration\n\n")
    for config_key, results in all_results.items():
        report.append(f"### {config_key}\n\n")

        if 'retrieval' in results and results['retrieval']:
            report.append("**Retrieval Performance:**\n")
            for retriever, ret_results in results['retrieval'].items():
                if 'metrics' in ret_results and ret_results['metrics']:
                    metrics = ret_results['metrics']
                    recall_10 = metrics.get('recall', {}).get(10, 0)
                    mrr = metrics.get('mrr', 0)
                    report.append(f"- {retriever}: Recall@10={recall_10:.3f}, MRR={mrr:.3f}\n")

        if 'downstream' in results and results['downstream']:
            report.append("\n**Downstream Tasks:**\n")
            for retriever, down_results in results['downstream'].items():
                if 'answerability' in down_results:
                    acc = down_results['answerability'].get('accuracy', 0)
                    f1 = down_results['answerability'].get('f1_score', 0)
                    report.append(f"- {retriever}: Accuracy={acc:.3f}, F1={f1:.3f}\n")

        report.append("\n")

    return ''.join(report)


def _log_retriever_summary(retriever_performance: Dict) -> None:
    """Log average/best Recall@10 and MRR per retriever across configurations."""
    logger.info("\n🏆 Retriever Performance Summary:")
    for retriever, perf_list in retriever_performance.items():
        if not perf_list:
            continue
        recalls = [p['metrics'].get('recall', {}).get(10, 0) for p in perf_list]
        mrrs = [p['metrics'].get('mrr', 0) for p in perf_list]
        logger.info(f"  {retriever}:")
        logger.info(f"    - Avg Recall@10: {np.mean(recalls):.3f}")
        logger.info(f"    - Best Recall@10: {np.max(recalls):.3f}")
        logger.info(f"    - Avg MRR: {np.mean(mrrs):.3f}")
        logger.info(f"    - Best MRR: {np.max(mrrs):.3f}")


def main():
    """Run every configured experiment and write per-config results and a report.

    Command-line options:
        --config      YAML configuration file (default: config_local_all.yaml).
        --output-dir  Output directory; overrides ``output_dir`` from the
                      config. When omitted, the config value is used, falling
                      back to ``outputs_all_methods``.
        --oracle      Use per-paper indexes (oracle evaluation).

    Writes ``results_<granularity>_<template>.json`` per configuration, plus
    ``report.md`` and ``complete_results.json``, into the output directory.
    Exits with status 1 when no data could be loaded.
    """
    parser = argparse.ArgumentParser(description='Run PeerQA experiments with formatted report')
    parser.add_argument('--config', type=str, default='config_local_all.yaml',
                       help='Path to configuration file')
    # default=None so the config file's output_dir is honoured unless the flag
    # is given explicitly; a non-None default always overrode it.
    parser.add_argument('--output-dir', type=str, default=None,
                       help='Directory for output files '
                            '(default: output_dir from the config, else outputs_all_methods)')
    parser.add_argument('--oracle', action='store_true',
                       help='Enable oracle evaluation (per-paper indexes)')
    args = parser.parse_args()

    # An empty YAML file parses to None; treat it as an empty configuration.
    config = load_config(args.config) or {}

    if args.output_dir:
        config['output_dir'] = args.output_dir

    output_dir = Path(config.get('output_dir', 'outputs_all_methods'))
    output_dir.mkdir(exist_ok=True, parents=True)

    logger.info("=" * 80)
    logger.info("RUNNING PEERQA EXPERIMENTS WITH ALL RETRIEVAL METHODS")
    if args.oracle:
        logger.info("MODE: Oracle Evaluation (Per-Paper Indexes)")
    else:
        logger.info("MODE: Full Corpus Evaluation")
    logger.info(f"Data directory: {config.get('data', {}).get('data_dir', 'data')}")
    logger.info("=" * 80)

    # Check available dependencies (informational only).
    logger.info("\n📦 Checking available libraries:")
    try:
        import sentence_transformers  # noqa: F401
        logger.info("   ✅ sentence-transformers available")
    except (ImportError, AttributeError):
        logger.info("   ❌ sentence-transformers NOT available")

    # Initialize data loader
    data_loader = PeerQALocalDataLoader(config.get('data', {}))
    data = data_loader.load_data()

    if data is None or data.empty:
        logger.error("Failed to load data!")
        sys.exit(1)

    # Show data statistics
    stats = data_loader.get_statistics()
    logger.info("\n📊 Dataset Statistics:")
    for key, value in stats.items():
        if isinstance(value, float):
            logger.info(f"  - {key}: {value:.2f}")
        else:
            logger.info(f"  - {key}: {value}")

    # Get configurations to test
    granularities = config.get('granularities', ['sentence', 'paragraph'])
    templates = config.get('templates', ['minimal'])

    logger.info(f"\n🔬 Experiment Plan:")
    logger.info(f"  - Granularities: {granularities}")
    logger.info(f"  - Templates: {templates}")
    logger.info(f"  - Total configurations: {len(granularities) * len(templates)}")

    all_results = {}
    retriever_performance = {}

    for granularity in granularities:
        for template in templates:
            config_key = f"{granularity}/{template}"
            logger.info(f"\n{'='*60}")
            logger.info(f"🧪 Configuration: {config_key}")
            logger.info(f"{'='*60}")

            retrieval_results = run_retrieval_experiment(
                config, data_loader, granularity, template, enable_oracle=args.oracle
            )

            # Track retriever performance (entries without metrics get an empty list).
            for retriever, ret_results in retrieval_results.items():
                perf_list = retriever_performance.setdefault(retriever, [])
                if 'metrics' in ret_results and ret_results['metrics']:
                    perf_list.append({
                        'config': config_key,
                        'metrics': ret_results['metrics']
                    })

            # Downstream evaluation needs the preprocessed rows again.
            data_processed = data_loader.preprocess_data(granularity, template)
            downstream_results = run_downstream_evaluation(
                config, data_processed, retrieval_results
            )

            all_results[config_key] = {
                'retrieval': retrieval_results,
                'downstream': downstream_results
            }

            # Save intermediate results (metrics only, no retrieved lists).
            result_file = output_dir / f'results_{granularity}_{template}.json'
            with open(result_file, 'w', encoding='utf-8') as f:
                json.dump({
                    'retrieval': {k: v.get('metrics', {}) for k, v in retrieval_results.items()},
                    'downstream': downstream_results
                }, f, indent=2, cls=NumpyEncoder)

    # Report in the format of outputs_all_methods/report.md
    report_file = output_dir / 'report.md'
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(_build_report(config, stats, retriever_performance, all_results))

    # Save complete results (including retrieved lists)
    complete_file = output_dir / 'complete_results.json'
    with open(complete_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, cls=NumpyEncoder)

    logger.info("\n" + "=" * 80)
    logger.info("📊 EXPERIMENT SUMMARY")
    logger.info("=" * 80)
    _log_retriever_summary(retriever_performance)

    if args.oracle:
        logger.info("\n⚠️ NOTE: Oracle mode uses per-paper indexes (~270 chunks)")
        logger.info("   This gives higher scores than full corpus search (24,265 chunks)")
        logger.info("   For realistic evaluation, run without --oracle flag")

    logger.info("\n" + "=" * 80)
    logger.info("✅ EXPERIMENT COMPLETE!")
    logger.info(f"📊 Results saved to: {output_dir}/")
    logger.info(f"📄 Report available at: {report_file}")
    logger.info("=" * 80)

# Run the experiments only when executed as a script, not when imported.
if __name__ == "__main__":
    main()