import json
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from llama_index.core import Document

from .retriever import VectorRetriever
from .bm25_retriever import BM25Retriever


class RetrievalPipeline:
    """Load a segmented game corpus, index it, and run retrieval over questions.

    Documents are read from
    ``<corpus_dir>/<game_name>/corpus/segment_*/corpus.jsonl`` and indexed
    with a BM25 retriever when ``config.retrieval_method == 'bm25'``, or a
    vector retriever otherwise. Retrieval results are plain dicts and can
    optionally be appended to a JSONL file.

    Config attributes read by this class: ``verbose``, ``corpus_dir``,
    ``game_name``, ``target_segment_id``, ``include_timeless``, ``top_k``,
    ``qa_data_dir`` and, optionally, ``retrieval_method``, ``bm25_k1``,
    ``bm25_b``, ``data_type``, ``embedding_model``, ``embedding_service``.
    """

    _SEGMENT_PREFIX = 'segment_'
    _TIMELESS_DIR = 'segment_timeless'

    def __init__(self, config):
        """Store *config* and construct the retriever it selects.

        ``bm25_k1`` and ``bm25_b`` are forwarded to ``BM25Retriever`` each on
        their own when present; an absent one is left to the retriever's own
        default (previously a lone ``bm25_k1`` or ``bm25_b`` was ignored).
        """
        self.config = config
        self.documents = []

        if getattr(config, 'retrieval_method', None) == 'bm25':
            bm25_params = {
                name: getattr(config, f'bm25_{name}')
                for name in ('k1', 'b')
                if hasattr(config, f'bm25_{name}')
            }
            self.retriever = BM25Retriever(config, **bm25_params)
        else:
            self.retriever = VectorRetriever(config)

        if config.verbose:
            # Segment id 0 is a valid target, so compare with None, not truthiness.
            if config.target_segment_id is not None:
                print("RetrievalPipeline initialized with target segment id")
            else:
                print("RetrievalPipeline initialized without target segment id")

    def load_documents(self) -> bool:
        """(Re)load every document of the configured game into ``self.documents``.

        Returns:
            True if at least one document was loaded; False when the corpus
            directory is missing or contains no documents.
        """
        if self.config.verbose:
            print("Loading documents")
        corpus_dir = Path(self.config.corpus_dir) / self.config.game_name / "corpus"
        if not corpus_dir.is_dir():
            if self.config.verbose:
                print("Corpus directory does not exist")
            return False

        self.documents = []
        for segment_dir in self._get_segment_dirs(corpus_dir):
            self._load_segment_documents(segment_dir)

        if self.config.verbose:
            self._print_document_stats()

        return bool(self.documents)

    @classmethod
    def _segment_sort_key(cls, path: Path):
        """Sort key: numeric segment suffixes in numeric order, then others by name."""
        suffix = path.name[len(cls._SEGMENT_PREFIX):]
        # isdecimal (not isdigit) guarantees int() accepts the suffix.
        return (0, int(suffix), '') if suffix.isdecimal() else (1, 0, suffix)

    def _get_segment_dirs(self, corpus_dir: Path) -> List[Path]:
        """Return the segment directories to load, in a deterministic order.

        With a target segment id, only ``segment_<id>`` is returned, whether
        or not it exists (a missing corpus file is skipped when loading).
        Otherwise every ``segment_*`` subdirectory except the timeless one is
        returned, sorted numerically by suffix. In both cases
        ``segment_timeless`` is appended last only when ``include_timeless``
        is set and the directory exists.
        """
        timeless_dir = corpus_dir / self._TIMELESS_DIR
        if self.config.target_segment_id is not None:
            dirs = [corpus_dir / f"{self._SEGMENT_PREFIX}{self.config.target_segment_id}"]
        else:
            dirs = sorted(
                (
                    d for d in corpus_dir.iterdir()
                    if d.is_dir()
                    and d.name.startswith(self._SEGMENT_PREFIX)
                    and d.name != self._TIMELESS_DIR
                ),
                key=self._segment_sort_key,
            )
        if self.config.include_timeless and timeless_dir.is_dir() and timeless_dir not in dirs:
            dirs.append(timeless_dir)
        return dirs

    def _load_segment_documents(self, segment_dir: Path) -> None:
        """Append one ``Document`` per JSON line of ``<segment_dir>/corpus.jsonl``.

        Each line is read as a JSON object with optional ``id``, ``title``,
        ``source``, ``contents`` and ``metadata`` keys. Keys found in
        ``metadata`` are copied onto the document metadata and override the
        defaults built here (including ``id``). The exception is ``entities``:
        its dict entries are reduced to their non-empty ``text`` values and
        stored under ``entity_texts``. Blank, undecodable and non-object lines
        are skipped. A missing corpus file is ignored.
        """
        corpus_file = segment_dir / "corpus.jsonl"
        if not corpus_file.exists():
            return

        if self.config.verbose:
            print("Loading segment documents")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for line_num, raw_line in enumerate(f, 1):
                line = raw_line.strip()
                if not line:
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    if self.config.verbose:
                        print(f"Error loading segment documents: {e}")
                    continue
                if not isinstance(data, dict):
                    # A bare JSON list/string would otherwise abort the whole load.
                    if self.config.verbose:
                        print(f"Error loading segment documents: line {line_num} "
                              f"of {corpus_file} is not a JSON object")
                    continue

                metadata = {
                    'id': data.get('id', f'{segment_dir.name}_doc_{line_num}'),
                    'title': data.get('title', ''),
                    'source': data.get('source', ''),
                    'game': self.config.game_name,
                    'segment_id': segment_dir.name,
                }
                original_metadata = data.get('metadata')
                if isinstance(original_metadata, dict):
                    for key, value in original_metadata.items():
                        if key == 'entities' and isinstance(value, list):
                            metadata['entity_texts'] = [
                                entity['text'] for entity in value
                                if isinstance(entity, dict) and entity.get('text')
                            ]
                        else:
                            metadata[key] = value

                doc = Document(text=data.get('contents', ''), metadata=metadata)
                doc.id_ = metadata['id']
                self.documents.append(doc)

    def _print_document_stats(self) -> None:
        """Print per-segment document counts and whether timeless docs were loaded."""
        segment_stats: Dict[Any, int] = {}
        for doc in self.documents:
            segment = doc.metadata.get('segment_id', 'unknown')
            segment_stats[segment] = segment_stats.get(segment, 0) + 1

        # Keys mix directory names (str) with values copied from per-document
        # metadata (e.g. -1), so sort on their string form to avoid a TypeError.
        for segment, count in sorted(segment_stats.items(), key=lambda item: str(item[0])):
            label = " (timeless)" if segment in (self._TIMELESS_DIR, -1) else ""
            print(f"{segment}: {count} documents{label}")

        has_timeless = self._TIMELESS_DIR in segment_stats or -1 in segment_stats
        if self.config.include_timeless and has_timeless:
            print("Has timeless")
        elif self.config.include_timeless:
            print("Has not timeless")
        elif has_timeless:
            print("Has timeless (include_timeless is False)")

    def initialize(self) -> bool:
        """Load documents and build the retriever index.

        Returns:
            True on success; False if no documents were loaded, the index
            could not be built, or any exception was raised along the way.
        """
        try:
            if not self.load_documents():
                return False
            if not self.retriever.build_index(self.documents):
                return False
            if self.config.verbose:
                print("Initialized")
            return True
        except Exception as e:
            # Top-level boundary: report failure through the boolean result.
            if self.config.verbose:
                print(f"Error initializing: {e}")
            return False

    def retrieve_single(self, question: str) -> Dict[str, Any]:
        """Retrieve documents for one question and wrap them in a result record.

        Returns:
            A dict with ``question``, ``retrieved_docs`` (as returned by the
            retriever), ``retrieved_doc_ids`` (each doc's ``id`` key, or ''),
            ``retrieval_time`` in seconds, an ISO ``timestamp`` and
            ``retrieval_config``.
        """
        start_time = time.perf_counter()
        if self.config.verbose:
            print("Retrieving")

        docs = self.retriever.retrieve(question)
        result = {
            'question': question,
            'retrieved_docs': docs,
            'retrieved_doc_ids': [doc.get('id', '') for doc in docs],
            'retrieval_time': time.perf_counter() - start_time,
            'timestamp': datetime.now().isoformat(),
            'retrieval_config': self._get_retrieval_config(),
        }

        if self.config.verbose:
            print("Retrieved")
        return result

    @staticmethod
    def _append_jsonl(output_file: Optional[str], records: List[Dict[str, Any]]) -> None:
        """Append *records* as JSON lines to *output_file*; no-op when it is falsy."""
        if not output_file or not records:
            return
        with open(output_file, 'a', encoding='utf-8') as f:
            for record in records:
                f.write(json.dumps(record, ensure_ascii=False) + '\n')

    def _run_batch(self, questions: List[str],
                   extras: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Retrieve all *questions* in one retriever call and build result records.

        ``extras[i]`` is merged into record ``i`` just before ``retrieval_config``.
        Raises whatever the retriever raises. Also raises ValueError when the
        retriever returns a different number of hit lists than questions;
        ``zip`` would otherwise silently drop questions.
        """
        start_time = time.perf_counter()
        batch_results = list(self.retriever.batch_retrieve(questions))
        batch_time = time.perf_counter() - start_time

        if len(batch_results) != len(questions):
            raise ValueError(
                f"batch_retrieve returned {len(batch_results)} results "
                f"for {len(questions)} questions"
            )

        # Only the total batch time is known; report an even per-question share.
        per_question_time = batch_time / len(questions) if questions else 0.0
        base_config = self._get_retrieval_config()

        results = []
        for i, (question, docs, extra) in enumerate(zip(questions, batch_results, extras)):
            results.append({
                'question': question,
                'retrieved_docs': docs,
                'retrieved_doc_ids': [doc.get('id', '') for doc in docs],
                'retrieval_time': per_question_time,
                'timestamp': datetime.now().isoformat(),
                'question_index': i,
                **extra,
                'retrieval_config': {**base_config, 'batch_mode': True},
            })
        return results

    def _run_fallback(self, questions: List[str], extras: List[Dict[str, Any]],
                      output_file: Optional[str], progress_msg: str) -> List[Dict[str, Any]]:
        """Retrieve *questions* one at a time, recording per-question failures.

        Each record (success or error) is appended to *output_file* as soon
        as it is produced. ``extras[i]`` is merged into record ``i``.
        """
        results = []
        for i, (question, extra) in enumerate(zip(questions, extras)):
            if self.config.verbose:
                print(progress_msg)
            try:
                result = self.retrieve_single(question)
                result['question_index'] = i
                result.update(extra)
            except Exception as e:
                if self.config.verbose:
                    print(f"Error retrieving for question {i}: {e}")
                result = {
                    'question': question,
                    'question_index': i,
                    'retrieved_docs': [],
                    'retrieved_doc_ids': [],
                    'error': str(e),
                    'timestamp': datetime.now().isoformat(),
                    **extra,
                }
            results.append(result)
            self._append_jsonl(output_file, [result])
        return results

    def batch_retrieve(self, questions: List[str],
                       output_file: Optional[str] = None) -> List[Dict[str, Any]]:
        """Retrieve documents for many questions, preferring one batch call.

        If the batch call (or building its records) fails, every question is
        retried individually and failures become records with an ``error``
        key. Partial batch output is discarded first, so no question is
        reported or written twice.

        Args:
            questions: Questions to retrieve for.
            output_file: Optional JSONL path; each result record is appended.

        Returns:
            One result record per question, in input order.
        """
        if self.config.verbose:
            print("Batch retrieving")
        extras: List[Dict[str, Any]] = [{} for _ in questions]
        try:
            results = self._run_batch(questions, extras)
        except Exception as e:
            if self.config.verbose:
                print(f"Batch retrieval failed, falling back to per-question retrieval: {e}")
            results = self._run_fallback(questions, extras, output_file, "Batch retrieving")
        else:
            if self.config.verbose:
                print("Batch retrieving")
            self._append_jsonl(output_file, results)

        if self.config.verbose:
            print("Batch retrieved")
        return results

    @staticmethod
    def _qa_extra_fields(qa_pair: Dict[str, Any]) -> Dict[str, Any]:
        """Ground-truth fields copied from a QA pair onto its result record."""
        return {
            'ground_truth_answer': qa_pair['ground_truth_answer'],
            'ground_truth_doc_ids': qa_pair['ground_truth_doc_ids'],
            'ground_truth_docs': qa_pair['ground_truth_docs'],
            'original_qa_data': qa_pair.get('original_data', {}),
        }

    def batch_retrieve_qa_pairs(self, qa_pairs: List[Dict[str, Any]],
                                output_file: Optional[str] = None) -> List[Dict[str, Any]]:
        """Like ``batch_retrieve``, but for QA pairs, keeping ground truth on each record.

        Each pair must provide ``question``, ``ground_truth_answer``,
        ``ground_truth_doc_ids`` and ``ground_truth_docs`` (the shape produced
        by ``load_qa_pairs``); ``original_data`` is optional.

        Returns:
            One result record per pair, in input order. Error records also
            carry the ground-truth fields.
        """
        if self.config.verbose:
            print("Batch retrieving QA pairs")
        questions = [qa_pair['question'] for qa_pair in qa_pairs]
        extras = [self._qa_extra_fields(qa_pair) for qa_pair in qa_pairs]
        try:
            results = self._run_batch(questions, extras)
        except Exception as e:
            if self.config.verbose:
                print(f"Batch retrieval failed, falling back to per-question retrieval: {e}")
            results = self._run_fallback(questions, extras, output_file,
                                         "Batch retrieving QA pairs")
        else:
            self._append_jsonl(output_file, results)
        return results

    def load_qa_pairs(self, segment_id: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load QA pairs for a segment from its ``generated_qa_pairs_resampled.jsonl``.

        The file lives at
        ``<qa_data_dir>/<game_name>_<data_type>/segment_<id>/``, where
        ``data_type`` defaults to ``'resample'``. Blank lines, undecodable or
        malformed records, and records without a non-empty question are
        skipped.

        Args:
            segment_id: Segment to load; defaults to ``config.target_segment_id``.

        Returns:
            Dicts with ``question``, ``ground_truth_answer``,
            ``ground_truth_doc_ids``, ``ground_truth_docs`` and ``original_data``.

        Raises:
            ValueError: If no segment id is given or configured.
            FileNotFoundError: If the QA file does not exist.
        """
        if segment_id is None:
            segment_id = self.config.target_segment_id
        if segment_id is None:
            raise ValueError("Segment ID must be specified")

        # Every data_type (including 'resample', 'wo_interest_drift_resample'
        # and 'wo_knowledge_update_resample') maps to "<game>_<data_type>".
        data_type = getattr(self.config, 'data_type', 'resample')
        qa_dir = (Path(self.config.qa_data_dir)
                  / f"{self.config.game_name}_{data_type}"
                  / f"segment_{segment_id}")
        qa_file = qa_dir / "generated_qa_pairs_resampled.jsonl"

        if not qa_file.exists():
            error_msg = f"❌ QA data file does not exist: {qa_file}"
            if self.config.verbose:
                print(error_msg)
            raise FileNotFoundError(error_msg)

        qa_pairs = []
        try:
            with open(qa_file, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        question = (data.get('question') or '').strip()
                        if not question:
                            continue
                        ground_truth_docs = data.get('retrieved_docs') or []
                        ground_truth_doc_ids = [doc.get('id', '') for doc in ground_truth_docs]
                    except (json.JSONDecodeError, AttributeError, TypeError):
                        # Skip one malformed record instead of abandoning the rest of the file.
                        continue

                    qa_pairs.append({
                        'question': question,
                        'ground_truth_answer': data.get('answer', ''),
                        'ground_truth_doc_ids': ground_truth_doc_ids,
                        'ground_truth_docs': ground_truth_docs,
                        'original_data': data,
                    })
        except (OSError, UnicodeError) as e:
            # Best effort: keep whatever was read before the I/O/decoding failure.
            if self.config.verbose:
                print(f"Error loading QA pairs: {e}")
        return qa_pairs

    def _get_retrieval_config(self) -> Dict[str, Any]:
        """Describe the active retrieval settings for inclusion in result records.

        The BM25 fallbacks 1.2 / 0.75 are only what is reported when the config
        lacks ``bm25_k1`` / ``bm25_b``; they are not passed to the retriever.
        """
        config = {
            'game': self.config.game_name,
            'segment_id': self.config.target_segment_id,
            'top_k': self.config.top_k,
            'include_timeless': self.config.include_timeless,
        }

        if getattr(self.config, 'retrieval_method', None) == 'bm25':
            config.update({
                'retrieval_method': 'bm25',
                'bm25_k1': getattr(self.config, 'bm25_k1', 1.2),
                'bm25_b': getattr(self.config, 'bm25_b', 0.75),
            })
        else:
            config.update({
                'retrieval_method': 'vector',
                'embedding_model': getattr(self.config, 'embedding_model', 'unknown'),
                'embedding_service': getattr(self.config, 'embedding_service', 'unknown'),
            })
        return config