"""
Data loader for local PeerQA JSONL files.
This reads directly from the extracted data/ directory without HuggingFace.
"""
import json
import logging
import os
import random
import re
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd

# Module-level logger used by the loader; no handlers or levels are set here.
logger = logging.getLogger(__name__)

class PeerQALocalDataLoader:
    """Loader for local PeerQA JSONL files.

    Reads ``qa.jsonl`` and, when present, ``qa-augmented-answers.jsonl`` and
    ``papers.jsonl`` from ``data_dir``, and converts QA rows into a flat
    DataFrame of question / context / chunks / answer / answerability.
    """

    # Sentence boundary: whitespace that directly follows '.', '!' or '?'.
    _SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')
    # One or more consecutive newlines separate candidate paragraph pieces.
    _LINE_BREAKS = re.compile(r'\n+')
    # Column order of the DataFrame produced by ``preprocess_data``; also used
    # so that an empty result still has the expected columns.
    _OUTPUT_COLUMNS = (
        'question_id', 'question', 'context', 'chunks',
        'answer', 'answerability', 'paper_id', 'source',
    )

    def __init__(self, config: Dict):
        """Initialize the data loader.

        Args:
            config: Configuration dictionary. Keys read by this class:
                ``data_dir`` (default ``'data'``), ``seed`` (default ``42``)
                and ``n_samples`` (read by ``preprocess_data``).

        Note:
            Seeds the *global* ``random`` and ``numpy.random`` generators,
            which affects any other code in the process that uses them.
        """
        self.config = config
        self.data_dir = config.get('data_dir', 'data')
        self.qa_data: Optional[pd.DataFrame] = None
        self.papers_data: Optional[pd.DataFrame] = None
        self.qa_augmented: Optional[pd.DataFrame] = None
        self.processed_data: Dict[str, pd.DataFrame] = {}
        seed = config.get('seed', 42)
        random.seed(seed)
        np.random.seed(seed)

    # ------------------------------------------------------------------
    # Missing-value helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _is_missing(value: object) -> bool:
        """Return True for None, ``pd.NA`` and float NaN.

        pandas fills absent JSON keys with NaN (or keeps JSON ``null`` as
        None), so both must be treated as "no value".
        """
        if value is None or value is pd.NA:
            return True
        return isinstance(value, float) and bool(np.isnan(value))

    @classmethod
    def _field(cls, row: pd.Series, *keys: str, default=None):
        """Return the first non-missing value among ``keys`` in ``row``.

        ``row.get(key, default)`` alone is insufficient: when the column
        exists but this row had no value, pandas returns NaN/None rather
        than the default.

        Args:
            row: A single record.
            *keys: Candidate field names, in priority order.
            default: Value returned when every key is absent or missing.
        """
        for key in keys:
            value = row.get(key)
            if not cls._is_missing(value):
                return value
        return default

    @classmethod
    def _to_bool(cls, value: object, default: bool = True) -> bool:
        """Coerce an answerability label to a Python ``bool``.

        Strings count as True when they are 'true', 'yes' or '1'
        (case-insensitive). Missing values map to ``default``.
        """
        if isinstance(value, (bool, np.bool_)):
            return bool(value)
        if isinstance(value, str):
            return value.strip().lower() in ('true', 'yes', '1')
        if cls._is_missing(value):
            return default
        return bool(value)

    @classmethod
    def _resolve_answerable(cls, row: pd.Series) -> bool:
        """Answerability from ``answerable``, else ``answerable_mapped``, else True."""
        return cls._to_bool(cls._field(row, 'answerable', 'answerable_mapped', default=True))

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def read_jsonl(self, file_path: str) -> List[Dict]:
        """Read a JSONL file.

        Blank lines are skipped.

        Args:
            file_path: Path to JSONL file

        Returns:
            List of dictionaries from the file

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON (the
                offending line number is logged before re-raising).
        """
        records: List[Dict] = []
        with open(file_path, 'r', encoding='utf-8') as handle:
            for line_no, line in enumerate(handle, start=1):
                if not line.strip():
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.error("Invalid JSON on line %d of %s", line_no, file_path)
                    raise
        return records

    def _load_optional_jsonl(self, filename: str, label: str) -> Optional[pd.DataFrame]:
        """Load ``data_dir/filename`` into a DataFrame, or return None if absent."""
        path = os.path.join(self.data_dir, filename)
        if not os.path.isfile(path):
            return None
        frame = pd.DataFrame(self.read_jsonl(path))
        logger.info(f"✅ Loaded {len(frame)} {label} from {filename}")
        return frame

    @staticmethod
    def _log_sample_structure(frame: pd.DataFrame) -> None:
        """Log the field names and value kinds of the first row of ``frame``."""
        if len(frame) == 0:
            return
        sample = frame.iloc[0]
        logger.info("📊 Sample QA structure:")
        for key in sample.keys():
            value = sample[key]
            if isinstance(value, str):
                logger.info(f"  - {key}: '{value[:50]}...'")
            elif isinstance(value, list):
                logger.info(f"  - {key}: list with {len(value)} items")
            elif isinstance(value, bool):
                logger.info(f"  - {key}: {value}")
            else:
                logger.info(f"  - {key}: {type(value).__name__}")

    def load_data(self) -> Optional[pd.DataFrame]:
        """Load PeerQA data from local JSONL files.

        Files that do not exist leave the corresponding attribute
        (``qa_data``, ``qa_augmented``, ``papers_data``) unchanged.

        Returns:
            DataFrame containing the loaded QA data, or None if ``qa.jsonl``
            was not found and nothing had been loaded before.
        """
        logger.info(f"Loading PeerQA data from {self.data_dir}/")

        qa_frame = self._load_optional_jsonl('qa.jsonl', 'QA pairs')
        if qa_frame is not None:
            self.qa_data = qa_frame
            self._log_sample_structure(qa_frame)
        else:
            logger.warning(f"QA file not found: {os.path.join(self.data_dir, 'qa.jsonl')}")

        augmented = self._load_optional_jsonl('qa-augmented-answers.jsonl', 'augmented answers')
        if augmented is not None:
            self.qa_augmented = augmented

        papers = self._load_optional_jsonl('papers.jsonl', 'paper chunks')
        if papers is not None:
            self.papers_data = papers

        return self.qa_data

    # ------------------------------------------------------------------
    # Preprocessing
    # ------------------------------------------------------------------
    def _augmented_answer_lookup(self) -> Dict:
        """Map question_id -> augmented answer, built once per preprocessing run.

        The first non-missing ``answer_augmented`` per question_id wins.
        Returns an empty dict when no augmented data or required column exists.
        """
        aug = self.qa_augmented
        if aug is None or 'answer_augmented' not in aug.columns:
            return {}
        if 'question_id' not in aug.columns:
            logger.warning("Augmented answers have no 'question_id' column; ignoring them")
            return {}
        lookup: Dict = {}
        for qid, augmented in zip(aug['question_id'], aug['answer_augmented']):
            if self._is_missing(qid) or self._is_missing(augmented):
                continue
            lookup.setdefault(qid, augmented)
        return lookup

    @classmethod
    def _evidence_to_text(cls, evidence: object) -> str:
        """Flatten an evidence value (list / dict / scalar) into one string.

        For dict items, the 'sentence' key is preferred, then 'text', else
        the dict's string form. Missing evidence (None/NaN) yields ''.
        """
        if isinstance(evidence, list):
            parts = []
            for item in evidence:
                if isinstance(item, dict):
                    if 'sentence' in item:
                        parts.append(str(item['sentence']))
                    elif 'text' in item:
                        parts.append(str(item['text']))
                    else:
                        parts.append(str(item))
                else:
                    parts.append(str(item))
            return ' '.join(parts)
        if isinstance(evidence, dict):
            if 'sentence' not in evidence:
                return str(evidence)
            sentences = evidence.get('sentence', [])
            if isinstance(sentences, list):
                return ' '.join(str(s) for s in sentences)
            return str(sentences)
        if cls._is_missing(evidence):
            return ''
        return str(evidence)

    def preprocess_data(self, granularity: str, template: str) -> pd.DataFrame:
        """Preprocess data for a specific granularity and template.

        Loads data on first use. At most ``config['n_samples']`` QA rows are
        processed (default: all). The result is also stored in
        ``self.processed_data[f"{granularity}_{template}"]``.

        Args:
            granularity: 'sentence' or 'paragraph' (any value other than
                'sentence' uses paragraph chunking)
            template: Decontextualization template to apply

        Returns:
            Preprocessed DataFrame (empty if no QA data is available)
        """
        if self.qa_data is None:
            self.load_data()

        if self.qa_data is None or len(self.qa_data) == 0:
            logger.error("No QA data available!")
            return pd.DataFrame()

        n_samples = self.config.get('n_samples', len(self.qa_data))
        augmented = self._augmented_answer_lookup()

        processed_data = []
        for idx, row in self.qa_data.head(n_samples).iterrows():
            question_id = self._field(row, 'question_id', default=f'q_{idx}')
            question = self._field(row, 'question', default='')

            # Prefer the augmented answer when one exists for this question.
            answer = augmented.get(question_id, self._field(row, 'answer', default=''))

            # First non-missing evidence field in priority order.
            evidence = self._field(
                row, 'answer_evidence_mapped', 'answer_evidence_sent', 'answer_evidence',
                default=[],
            )
            context = self._evidence_to_text(evidence)

            # If no evidence text was found, try answer_evidence_sent directly.
            if not context or context == '[]':
                evidence_sent = row.get('answer_evidence_sent', [])
                if isinstance(evidence_sent, list) and evidence_sent:
                    context = ' '.join(str(s) for s in evidence_sent)

            answerable = self._resolve_answerable(row)

            context = self._apply_template(context, template, row)
            if granularity == 'sentence':
                chunks = self._chunk_sentences(context)
            else:
                chunks = self._chunk_paragraphs(context)

            processed_data.append({
                'question_id': question_id,
                'question': question,
                'context': context,
                'chunks': chunks,
                'answer': answer,
                'answerability': answerable,
                'paper_id': self._field(row, 'paper_id', default=''),
                'source': 'PeerQA_Local',
            })

        # Explicit columns so an empty result (e.g. n_samples=0) keeps its schema.
        df = pd.DataFrame(processed_data, columns=list(self._OUTPUT_COLUMNS))

        key = f"{granularity}_{template}"
        self.processed_data[key] = df

        logger.info(f"Preprocessed {len(df)} samples for {key}")
        if df.empty:
            return df
        logger.info(f"  - Questions with answers: {(df['answer'] != '').sum()}")
        logger.info(f"  - Answerable: {df['answerability'].sum()} ({df['answerability'].mean():.1%})")
        logger.info(f"  - Avg chunks per question: {df['chunks'].apply(len).mean():.1f}")

        return df

    def _apply_template(self, context: str, template: str, row: pd.Series) -> str:
        """Apply decontextualization template to context.

        Unknown template names return ``context`` unchanged.

        Args:
            context: Original context text
            template: Template type ('minimal', 'title_only', 'heading_only',
                'title_heading' or 'aggressive_title')
            row: Original data row; only ``paper_id`` is read

        Returns:
            Processed context string
        """
        paper_id = self._field(row, 'paper_id', default='')

        if template == 'title_only':
            title = f"Paper: {paper_id}" if paper_id else "Research Paper"
            return f"{title}\n{context}"
        if template == 'heading_only':
            return f"Evidence Section\n{context}"
        if template == 'title_heading':
            title = f"Paper: {paper_id}" if paper_id else "Research Paper"
            return f"{title}\nEvidence Section\n{context}"
        if template == 'aggressive_title':
            title = f"[PAPER: {paper_id}]" if paper_id else "[RESEARCH PAPER]"
            return f"{title} EVIDENCE: {context}"
        # 'minimal' and unrecognized templates.
        return context

    def _chunk_sentences(self, text: str) -> List[str]:
        """Chunk text into sentences.

        Splits on whitespace following '.', '!' or '?'; empty pieces dropped.

        Args:
            text: Input text

        Returns:
            List of sentence chunks ([] for blank text)
        """
        return [s.strip() for s in self._SENTENCE_BOUNDARY.split(text) if s.strip()]

    def _chunk_paragraphs(self, text: str, merge_threshold: int = 200) -> List[str]:
        """Chunk text into paragraphs.

        Text is split on newlines. Consecutive short pieces are merged
        (space-joined) while their combined length stays below
        ``merge_threshold``. A single piece longer than that stays as its
        own chunk.

        Args:
            text: Input text
            merge_threshold: Character budget for merging adjacent pieces.

        Returns:
            List of paragraph chunks ([] for blank text, consistent with
            ``_chunk_sentences``)
        """
        chunks: List[str] = []
        current = ""
        for piece in self._LINE_BREAKS.split(text):
            piece = piece.strip()
            if not piece:
                continue
            if len(current) + len(piece) < merge_threshold:
                current = f"{current} {piece}" if current else piece
            else:
                if current:
                    chunks.append(current)
                current = piece
        if current:
            chunks.append(current)
        return chunks

    # ------------------------------------------------------------------
    # Paper access & statistics
    # ------------------------------------------------------------------
    def get_paper_chunks(self, paper_id: str, granularity: str = 'sentence') -> List[str]:
        """Get paper chunks for a given paper ID.

        Rows whose ``content`` is missing or not a non-empty string are skipped.

        Args:
            paper_id: Paper identifier
            granularity: 'sentence' splits each row's content into sentences;
                anything else returns each row's content as one chunk

        Returns:
            List of text chunks from the paper ([] if unavailable)
        """
        papers = self.papers_data
        if papers is None or 'paper_id' not in papers.columns or 'content' not in papers.columns:
            return []

        chunks: List[str] = []
        for content in papers.loc[papers['paper_id'] == paper_id, 'content']:
            if not isinstance(content, str) or not content:
                continue
            if granularity == 'sentence':
                chunks.extend(self._chunk_sentences(content))
            else:
                chunks.append(content)
        return chunks

    def get_statistics(self) -> Dict:
        """Get statistics about the loaded data.

        ``answerable_ratio`` uses the same per-row resolution as
        ``preprocess_data`` (``answerable`` -> ``answerable_mapped`` -> True)
        and is NaN when there are no QA rows.

        Returns:
            Dictionary with dataset statistics
        """
        stats: Dict = {}

        if self.qa_data is not None:
            qa = self.qa_data
            stats['qa_samples'] = len(qa)
            flags = [self._resolve_answerable(row) for _, row in qa.iterrows()]
            stats['answerable_ratio'] = float(np.mean(flags)) if flags else float('nan')
            stats['avg_question_length'] = qa.get('question', pd.Series([''])).str.len().mean()

        if self.papers_data is not None:
            stats['paper_chunks'] = len(self.papers_data)
            stats['unique_papers'] = self.papers_data.get('paper_id', pd.Series(dtype=object)).nunique()

        return stats