"""
Oracle evaluation with per-paper indexes, matching the PeerQA paper's setup.
This builds a separate BM25 index for each paper and, for every question, searches only within the paper that question is about.
"""
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd

from src.retrieval.bm25_retriever import BM25Retriever

# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)

class OraclePeerQAEvaluator:
    """Oracle evaluator that matches PeerQA paper's per-paper index setup.

    Rather than one index over the whole collection, a separate BM25 index is
    built for every paper, and each question is searched only against the
    index of the paper it belongs to.

    Typical flow::

        evaluator = OraclePeerQAEvaluator("data")
        evaluator.load_data()
        evaluator.build_per_paper_indexes("paragraph")
        run = evaluator.run_oracle_retrieval(evaluator.qa_data, "paragraph")
        metrics = evaluator.evaluate_retrieval(run["retrieved"], run["ground_truth"])
    """

    # Container types accepted for evidence lists. numpy arrays are accepted in
    # case the QA frame was built from a source that stores lists as arrays
    # (e.g. parquet) -- assumption, json.loads itself only yields lists.
    _SEQUENCE_TYPES = (list, tuple, np.ndarray)

    def __init__(self, data_dir: str = "data"):
        """Initialize with PeerQA data directory.

        Args:
            data_dir: Directory containing ``qa.jsonl`` and ``papers.jsonl``.
        """
        self.data_dir = data_dir
        self.qa_data: Optional[pd.DataFrame] = None      # one row per question; set by load_data()
        self.papers_data: Optional[pd.DataFrame] = None  # one row per paper chunk; set by load_data()
        self.paper_indexes = {}       # paper_id -> BM25Retriever over that paper's corpus
        self.paper_corpus = {}        # paper_id -> list of indexed texts
        self.paper_idx_mappings = {}  # paper_id -> {'idx_to_pos': {...}, 'pidx_to_pos': {...}}

    @staticmethod
    def _read_jsonl(path: str) -> List[Dict[str, Any]]:
        """Parse a JSON Lines file into a list of records, skipping blank lines."""
        records = []
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():  # tolerate blank / trailing lines
                    records.append(json.loads(line))
        return records

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for falsy values (None, '') and for float NaN.

        pandas fills keys absent from some JSON records with NaN, which is
        truthy, so a plain ``not value`` check would let it through.
        """
        if isinstance(value, float) and np.isnan(value):
            return True
        return not value

    def load_data(self):
        """Load ``qa.jsonl`` and ``papers.jsonl`` from ``data_dir``.

        Sets ``self.qa_data`` (one row per question) and ``self.papers_data``
        (one row per paper chunk; a ``paper_id`` column is required).

        Raises:
            FileNotFoundError: If either file is missing.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If the papers file has no ``paper_id`` field.
        """
        qa_file = os.path.join(self.data_dir, 'qa.jsonl')
        self.qa_data = pd.DataFrame(self._read_jsonl(qa_file))
        logger.info("Loaded %d QA pairs", len(self.qa_data))

        papers_file = os.path.join(self.data_dir, 'papers.jsonl')
        self.papers_data = pd.DataFrame(self._read_jsonl(papers_file))
        logger.info("Loaded %d paper chunks", len(self.papers_data))

        unique_papers = self.papers_data['paper_id'].unique()
        logger.info("Found %d unique papers", len(unique_papers))

    @staticmethod
    def _build_paragraph_corpus(
        paper_chunks: pd.DataFrame,
    ) -> Tuple[List[str], Dict[Any, int], Dict[Any, int]]:
        """Merge each paragraph's chunks into one document, in ``pidx`` order.

        Returns the corpus plus maps from every chunk ``idx`` and every ``pidx``
        to the paragraph's corpus position. NOTE: ``groupby`` drops rows whose
        ``pidx`` is NaN by default, so such chunks are not indexed.
        """
        corpus: List[str] = []
        idx_to_pos: Dict[Any, int] = {}
        pidx_to_pos: Dict[Any, int] = {}
        for pidx, pidx_chunks in paper_chunks.groupby('pidx'):
            pos = len(corpus)
            corpus.append(' '.join(pidx_chunks['content'].tolist()))
            pidx_to_pos[pidx] = pos
            # Every chunk of the paragraph resolves to the merged document.
            for idx in pidx_chunks['idx'].values:
                idx_to_pos[idx] = pos
        return corpus, idx_to_pos, pidx_to_pos

    @staticmethod
    def _build_sentence_corpus(
        paper_chunks: pd.DataFrame,
    ) -> Tuple[List[str], Dict[Any, int], Dict[Any, int]]:
        """Index every chunk as its own document, in row order (empty ``pidx`` map)."""
        corpus = paper_chunks['content'].tolist()
        idx_to_pos = {idx: pos for pos, idx in enumerate(paper_chunks['idx'].tolist())}
        return corpus, idx_to_pos, {}

    def build_per_paper_indexes(self, granularity: str = 'paragraph'):
        """Build separate BM25 indexes for each paper (Oracle setting).

        For each paper this fills ``paper_indexes``, ``paper_corpus`` and
        ``paper_idx_mappings``. Papers that yield an empty corpus are skipped.

        Args:
            granularity: 'paragraph' joins all chunks sharing a ``pidx`` into
                one document; any other value (e.g. 'sentence') indexes each
                chunk on its own.

        Raises:
            RuntimeError: If ``load_data()`` has not been called.
        """
        if self.papers_data is None:
            raise RuntimeError("papers_data is not loaded; call load_data() first")

        logger.info("Building per-paper indexes (%s level)...", granularity)

        for paper_id, paper_chunks in self.papers_data.groupby('paper_id'):
            if granularity == 'paragraph':
                corpus, idx_to_pos, pidx_to_pos = self._build_paragraph_corpus(paper_chunks)
            else:
                corpus, idx_to_pos, pidx_to_pos = self._build_sentence_corpus(paper_chunks)

            if not corpus:
                continue

            bm25 = BM25Retriever({})
            bm25.build_index(corpus)
            self.paper_indexes[paper_id] = bm25
            self.paper_corpus[paper_id] = corpus
            self.paper_idx_mappings[paper_id] = {
                'idx_to_pos': idx_to_pos,
                'pidx_to_pos': pidx_to_pos,
            }
            logger.debug("Paper %s: %d chunks indexed", paper_id, len(corpus))

        logger.info("Built indexes for %d papers", len(self.paper_indexes))
        if self.paper_corpus:  # np.mean([]) would warn and log 'nan'
            avg_chunks = np.mean([len(c) for c in self.paper_corpus.values()])
            logger.info("Average chunks per paper: %.1f", avg_chunks)

    def get_ground_truth_for_question_oracle(self, question_row: pd.Series, paper_id: str,
                                            granularity: str = 'paragraph') -> Set[int]:
        """Get ground truth corpus positions within a specific paper.

        Reads ``answer_evidence_mapped`` (a list of dicts, each with an ``idx``
        list of chunk ids) and maps each chunk id through the paper's
        ``idx_to_pos`` table. Unknown or ``None`` ids are ignored.

        Args:
            question_row: Row from QA dataframe.
            paper_id: Paper ID to search in.
            granularity: Unused; positions come from the mapping built by
                ``build_per_paper_indexes``, so they follow the granularity
                used there.

        Returns:
            Set of corpus positions within that paper's index (empty if the
            paper has no index or the row has no usable evidence).
        """
        mappings = self.paper_idx_mappings.get(paper_id)
        if mappings is None:
            return set()
        idx_to_pos = mappings['idx_to_pos']

        evidence_mapped = question_row.get('answer_evidence_mapped', [])
        # A missing cell comes back as NaN (a truthy float); iterating it would
        # raise TypeError, so only accept real sequences.
        if not isinstance(evidence_mapped, self._SEQUENCE_TYPES):
            return set()

        relevant_positions = set()
        for evidence_item in evidence_mapped:
            if not isinstance(evidence_item, dict):
                continue
            idx_list = evidence_item.get('idx')
            if not isinstance(idx_list, self._SEQUENCE_TYPES):
                continue
            for idx in idx_list:
                if idx is not None and idx in idx_to_pos:
                    relevant_positions.add(idx_to_pos[idx])
        return relevant_positions

    def run_oracle_retrieval(self, qa_data: pd.DataFrame, granularity: str = 'paragraph',
                           k: int = 100) -> Dict:
        """Run oracle retrieval (search only within each question's paper).

        Rows without a usable ``paper_id``/``question`` (None, empty, NaN) or
        whose paper has no index are skipped.

        Args:
            qa_data: QA dataframe.
            granularity: Forwarded to ``get_ground_truth_for_question_oracle``.
            k: Number of results to retrieve per question.

        Returns:
            ``{'retrieved': {row_label: bm25.retrieve(...) result},
            'ground_truth': {row_label: [corpus positions]}}``, keyed by the
            DataFrame index label of each question row. The retrieve result is
            presumably a ranked list of ``(corpus_position, score)`` pairs, as
            ``evaluate_retrieval`` expects -- confirm against BM25Retriever.
        """
        retrieved = {}
        ground_truth = {}

        questions_with_index = 0
        questions_with_gt = 0

        for row_id, row in qa_data.iterrows():
            paper_id = row.get('paper_id')
            question = row.get('question')

            if self._is_missing(paper_id) or self._is_missing(question):
                continue

            bm25 = self.paper_indexes.get(paper_id)
            if bm25 is None:
                logger.debug("No index for paper %s", paper_id)
                continue

            questions_with_index += 1

            # Oracle setting: only this paper's index is searched.
            retrieved[row_id] = bm25.retrieve(question, k=k)

            gt_positions = self.get_ground_truth_for_question_oracle(row, paper_id, granularity)
            ground_truth[row_id] = list(gt_positions)
            if gt_positions:
                questions_with_gt += 1

        logger.info("Oracle retrieval completed:")
        logger.info("  - Questions with indexes: %d/%d", questions_with_index, len(qa_data))
        logger.info("  - Questions with ground truth: %d", questions_with_gt)

        return {
            'retrieved': retrieved,
            'ground_truth': ground_truth
        }

    def evaluate_retrieval(self, retrieved: Dict[int, List[Tuple[int, float]]],
                          ground_truth: Dict[int, List[int]],
                          k_values: Optional[List[int]] = None) -> Dict[str, Any]:
        """Compute Recall@K, MRR and nDCG@K with binary relevance.

        Only queries present in ``retrieved`` whose ground-truth list is
        non-empty are scored; every metric is averaged over those queries.

        Args:
            retrieved: Query id -> ranked list of ``(doc_id, score)`` pairs.
            ground_truth: Query id -> list of relevant doc ids.
            k_values: Cutoffs for Recall@K and nDCG@K (default ``[5, 10, 20]``).

        Returns:
            ``{'recall': {k: float}, 'mrr': float, 'ndcg': {k: float}}``; a
            metric with no scorable queries is reported as 0.0.
        """
        if k_values is None:
            k_values = [5, 10, 20]

        recalls: Dict[int, List[float]] = {k: [] for k in k_values}
        ndcgs: Dict[int, List[float]] = {k: [] for k in k_values}
        reciprocal_ranks: List[float] = []

        for query_id, retrieved_docs in retrieved.items():
            gt_docs = ground_truth.get(query_id)
            if not gt_docs:
                continue

            relevant_docs = set(gt_docs)
            ranked_ids = [doc_id for doc_id, _ in retrieved_docs]

            for k in recalls:
                top_k = ranked_ids[:k]
                recalls[k].append(len(set(top_k) & relevant_docs) / len(relevant_docs))

                # Binary-gain DCG: a hit at 0-based rank r contributes 1/log2(r + 2).
                dcg = sum(1.0 / np.log2(rank + 2)
                          for rank, doc_id in enumerate(top_k) if doc_id in relevant_docs)
                idcg = sum(1.0 / np.log2(rank + 2)
                           for rank in range(min(k, len(relevant_docs))))
                if idcg > 0:
                    ndcgs[k].append(dcg / idcg)

            # Reciprocal rank of the first relevant hit anywhere in the list.
            reciprocal_ranks.append(next(
                (1.0 / (rank + 1) for rank, doc_id in enumerate(ranked_ids)
                 if doc_id in relevant_docs),
                0.0,
            ))

        def _mean(values: List[float]) -> float:
            return float(np.mean(values)) if values else 0.0

        return {
            'recall': {k: _mean(v) for k, v in recalls.items()},
            'mrr': _mean(reciprocal_ranks),
            'ndcg': {k: _mean(v) for k, v in ndcgs.items()},
        }