"""
PeerQA retrieval evaluation that follows the same ground-truth setup as the PeerQA paper.
Ground truth is built from the `idx` mappings in each question's `answer_evidence_mapped` field.
"""
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd

# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

class ProperPeerQAEvaluator:
    """Evaluator that exactly matches PeerQA paper's ground truth setup.

    Typical flow: ``load_data`` -> ``build_retrieval_corpus`` ->
    ``create_ground_truth`` -> ``evaluate_retrieval``. The ground-truth helpers
    read the index mappings filled by ``build_retrieval_corpus``, so build the
    corpus with the same ``granularity`` later passed to ``create_ground_truth``.
    """

    def __init__(self, data_dir: str = "data"):
        """Initialize with the PeerQA data directory (holding qa.jsonl and papers.jsonl)."""
        self.data_dir = data_dir
        self.qa_data: Optional[pd.DataFrame] = None
        self.papers_data: Optional[pd.DataFrame] = None
        # idx (and, at sentence level, "pidx/sidx" strings) -> corpus position
        self.idx_to_corpus_pos: Dict[Any, int] = {}
        # pidx -> corpus position (only filled at paragraph level)
        self.pidx_to_corpus_pos: Dict[Any, int] = {}
        self.corpus: List[str] = []

    @staticmethod
    def _read_jsonl(path: str) -> List[Dict[str, Any]]:
        """Parse a UTF-8 JSONL file, one JSON object per non-blank line."""
        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def load_data(self):
        """Load qa.jsonl and papers.jsonl from ``data_dir`` into DataFrames."""
        self.qa_data = pd.DataFrame(self._read_jsonl(os.path.join(self.data_dir, 'qa.jsonl')))
        logger.info(f"Loaded {len(self.qa_data)} QA pairs")

        self.papers_data = pd.DataFrame(self._read_jsonl(os.path.join(self.data_dir, 'papers.jsonl')))
        logger.info(f"Loaded {len(self.papers_data)} paper chunks")

    def build_retrieval_corpus(self, granularity: str = 'paragraph') -> List[str]:
        """Build the retrieval corpus from papers data, preserving indices.

        Any ``granularity`` other than ``'paragraph'`` is treated as sentence level.
        Calling this again rebuilds the corpus and both index mappings from scratch.

        Args:
            granularity: 'paragraph' or 'sentence'

        Returns:
            List of all paper chunks (the retrieval corpus)
        """
        if self.papers_data is None:
            self.load_data()

        # Fresh mappings so a rebuild (e.g. switching granularity) leaves no
        # stale positions from an earlier corpus behind.
        self.idx_to_corpus_pos = {}
        self.pidx_to_corpus_pos = {}
        corpus: List[str] = []

        if granularity == 'paragraph':
            # NOTE(review): rows are grouped by ``pidx`` alone. If papers.jsonl
            # holds several papers whose ``pidx`` restarts per paper, paragraphs
            # of different papers get merged here -- confirm against the data.
            pidx_groups: Dict[Any, List[pd.Series]] = {}
            for _, row in self.papers_data.iterrows():
                pidx_groups.setdefault(row['pidx'], []).append(row)

            for pidx in sorted(pidx_groups):
                rows = pidx_groups[pidx]
                corpus_pos = len(corpus)
                # Sentences are joined in file order to form the paragraph chunk.
                corpus.append(' '.join(row['content'] for row in rows))
                self.pidx_to_corpus_pos[pidx] = corpus_pos
                # Every idx inside the paragraph resolves to the paragraph chunk.
                for row in rows:
                    self.idx_to_corpus_pos[row['idx']] = corpus_pos
        else:
            # Sentence level: one corpus entry per row, addressable both by
            # idx and by a "pidx/sidx" string key.
            for _, row in self.papers_data.iterrows():
                corpus_pos = len(corpus)
                corpus.append(row['content'])
                self.idx_to_corpus_pos[row['idx']] = corpus_pos
                self.idx_to_corpus_pos[f"{row['pidx']}/{row['sidx']}"] = corpus_pos

        self.corpus = corpus
        logger.info(f"Built corpus with {len(corpus)} chunks ({granularity} level)")
        logger.info(f"Mapped {len(self.idx_to_corpus_pos)} indices to corpus positions")

        return corpus

    @staticmethod
    def _evidence_idx_values(evidence_mapped: Any) -> List[Any]:
        """Flatten the non-None ``idx`` values of an ``answer_evidence_mapped`` value.

        Anything that is not a list/tuple/ndarray (``None``, or the NaN pandas
        uses for a field missing from some JSON rows) yields an empty list.
        Items that are not dicts, or whose ``idx`` is not a list, are skipped.
        """
        if not isinstance(evidence_mapped, (list, tuple, np.ndarray)):
            return []
        idx_values: List[Any] = []
        for evidence_item in evidence_mapped:
            if isinstance(evidence_item, dict) and isinstance(evidence_item.get('idx'), list):
                idx_values.extend(idx for idx in evidence_item['idx'] if idx is not None)
        return idx_values

    def get_ground_truth_for_question(self, question_row: pd.Series, granularity: str = 'paragraph') -> Set[int]:
        """Get the ground truth corpus positions for a question using idx mappings.

        Requires the mappings from ``build_retrieval_corpus`` (built at the same
        granularity); idx values with no mapping are ignored.

        Args:
            question_row: Row from QA dataframe
            granularity: 'paragraph' or 'sentence'

        Returns:
            Set of corpus positions that contain the evidence
        """
        relevant_positions: Set[int] = set()
        idx_values = self._evidence_idx_values(question_row.get('answer_evidence_mapped', []))
        if not idx_values:
            return relevant_positions

        if granularity == 'paragraph':
            if self.papers_data is None:
                return relevant_positions
            # One scan per question instead of one per idx; keep='first' matches
            # taking the first papers row that carries each idx.
            matches = self.papers_data[self.papers_data['idx'].isin(idx_values)]
            matches = matches.drop_duplicates(subset='idx', keep='first')
            for pidx in matches['pidx']:
                if pidx in self.pidx_to_corpus_pos:
                    relevant_positions.add(self.pidx_to_corpus_pos[pidx])
        else:
            for idx in idx_values:
                if idx in self.idx_to_corpus_pos:
                    relevant_positions.add(self.idx_to_corpus_pos[idx])

        return relevant_positions

    def create_ground_truth(self, qa_data: pd.DataFrame, granularity: str = 'paragraph') -> Dict[int, List[int]]:
        """Create ground truth mapping for evaluation using the exact idx mappings.

        Questions whose ``answerable`` field is falsy, or that have no usable
        evidence, map to an empty list (a missing ``answerable`` counts as True).

        Args:
            qa_data: QA dataframe
            granularity: 'paragraph' or 'sentence'

        Returns:
            Dictionary mapping query index to the sorted list of relevant corpus positions
        """
        ground_truth: Dict[int, List[int]] = {}

        for query_idx, row in qa_data.iterrows():
            if row.get('answerable', True):
                # get_ground_truth_for_question copes with empty/missing evidence.
                ground_truth[query_idx] = sorted(self.get_ground_truth_for_question(row, granularity))
            else:
                ground_truth[query_idx] = []

        # Statistics
        non_empty = [len(v) for v in ground_truth.values() if v]
        questions_with_gt = len(non_empty)
        avg_relevant = np.mean(non_empty) if non_empty else 0

        logger.info(f"Ground truth created for {len(ground_truth)} questions")
        logger.info(f"Questions with ground truth: {questions_with_gt}")
        logger.info(f"Average relevant chunks per question: {avg_relevant:.2f}")

        return ground_truth

    @staticmethod
    def _reciprocal_rank(retrieved_docs: List[Tuple[int, float]], relevant_docs: Set[int]) -> float:
        """Return 1/rank of the first relevant doc in the ranking, or 0.0 if none."""
        for rank, (doc_id, _) in enumerate(retrieved_docs, start=1):
            if doc_id in relevant_docs:
                return 1.0 / rank
        return 0.0

    @staticmethod
    def _dcg_at_k(retrieved_docs: List[Tuple[int, float]], relevant_docs: Set[int], k: int) -> float:
        """Binary-relevance DCG over the top ``k`` results.

        Each relevant doc is credited once (its first occurrence), so duplicate
        ids in a ranking cannot push nDCG above 1.
        """
        dcg = 0.0
        credited: Set[int] = set()
        for i, (doc_id, _) in enumerate(retrieved_docs[:k]):
            if doc_id in relevant_docs and doc_id not in credited:
                credited.add(doc_id)
                dcg += 1.0 / np.log2(i + 2)
        return dcg

    def evaluate_retrieval(self, retrieved: Dict[int, List[Tuple[int, float]]],
                           ground_truth: Dict[int, List[int]],
                           k_values: Optional[List[int]] = None) -> Dict[str, Any]:
        """Evaluate retrieval performance.

        Only queries that appear in ``retrieved`` and have non-empty ground truth
        are scored. Queries with ground truth but no retrieval results are
        excluded, and a warning reports how many.

        Args:
            retrieved: Ranked ``(doc_id, score)`` pairs per query, best first
            ground_truth: Relevant documents per query
            k_values: K values for Recall@K and nDCG@K (default ``[5, 10, 20]``)

        Returns:
            Dictionary of metrics: ``{'recall': {k: ...}, 'mrr': ..., 'ndcg': {k: ...}}``
        """
        if k_values is None:
            k_values = [5, 10, 20]

        missing = sum(1 for qid, rel in ground_truth.items() if rel and qid not in retrieved)
        if missing:
            logger.warning(
                "%d queries with ground truth have no retrieval results and are excluded from the metrics",
                missing,
            )

        # (ranking, relevant set) for every query that can be scored.
        scored = [
            (retrieved_docs, set(ground_truth[query_id]))
            for query_id, retrieved_docs in retrieved.items()
            if ground_truth.get(query_id)
        ]

        metrics: Dict[str, Any] = {}

        # Recall@K
        recall_at_k = {}
        for k in k_values:
            recalls = [
                len({doc_id for doc_id, _ in docs[:k]} & relevant) / len(relevant)
                for docs, relevant in scored
            ]
            recall_at_k[k] = np.mean(recalls) if recalls else 0.0
        metrics['recall'] = recall_at_k

        # MRR
        mrrs = [self._reciprocal_rank(docs, relevant) for docs, relevant in scored]
        metrics['mrr'] = np.mean(mrrs) if mrrs else 0.0

        # nDCG@K
        ndcg_at_k = {}
        for k in k_values:
            ndcgs = []
            for docs, relevant in scored:
                idcg = sum(1.0 / np.log2(i + 2) for i in range(min(k, len(relevant))))
                if idcg > 0:
                    ndcgs.append(self._dcg_at_k(docs, relevant, k) / idcg)
            ndcg_at_k[k] = np.mean(ndcgs) if ndcgs else 0.0
        metrics['ndcg'] = ndcg_at_k

        return metrics