import json
import time
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any

from llama_index.core import Document

from .config import RAGConfig
from .retriever import VectorRetriever
from .generator import TextGenerator


class RAGPipeline:
                         

    def __init__(self, config: RAGConfig):
        self.config = config
        self.retriever = VectorRetriever(config)
        self.generator = TextGenerator(config)
        self.documents = []

        if config.verbose:
            if config.target_segment_id:
                print("RAGPipeline initialized with target segment id")
            else:
                print("RAGPipeline initialized without target segment id")

    def load_documents(self) -> bool:
                  
        if self.config.verbose:
            print("Loading documents")
        corpus_dir = Path(self.config.corpus_dir) / self.config.game_name / "corpus"
        if not corpus_dir.exists():
            if self.config.verbose:
                print("Corpus directory does not exist")
            return False

                    
        segment_dirs = self._get_segment_dirs(corpus_dir)

              
        self.documents = []
        for segment_dir in segment_dirs:
            self._load_segment_documents(segment_dir)

        if self.config.verbose:
            self._print_document_stats()

        return len(self.documents) > 0

    def _get_segment_dirs(self, corpus_dir: Path) -> List[Path]:
                        
        if self.config.target_segment_id:
                    
            dirs = [corpus_dir / f"segment_{self.config.target_segment_id}"]
            if self.config.include_timeless:
                timeless_dir = corpus_dir / "segment_timeless"
                if timeless_dir.exists():
                    dirs.append(timeless_dir)
        else:
                    
            dirs = [d for d in corpus_dir.iterdir() if d.is_dir() and d.name.startswith('segment_')]
            if self.config.include_timeless:
                timeless_dir = corpus_dir / "segment_timeless"
                if timeless_dir.exists() and timeless_dir not in dirs:
                    dirs.append(timeless_dir)
        return dirs

    def _load_segment_documents(self, segment_dir: Path):
                       
        corpus_file = segment_dir / "corpus.jsonl"
        if not corpus_file.exists():
            return

        if self.config.verbose:
            print("Loading segment documents")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue

                try:
                    data = json.loads(line.strip())

                                              
                    original_metadata = data.get('metadata', {})
                    metadata = {
                        'id': data.get('id', f'{segment_dir.name}_doc_{line_num}'),
                        'title': data.get('title', ''),
                        'source': data.get('source', ''),
                        'game': self.config.game_name,
                        'segment_id': segment_dir.name,
                    }

                                               
                    for key, value in original_metadata.items():
                        if key == 'entities' and isinstance(value, list):
                                                               
                            entity_texts = []
                            for entity in value:
                                if isinstance(entity, dict) and entity.get('text'):
                                    entity_texts.append(entity.get('text', ''))
                            metadata['entity_texts'] = entity_texts
                        else:
                            metadata[key] = value

                    doc = Document(text=data.get('contents', ''), metadata=metadata)
                    doc.id_ = metadata['id']
                    self.documents.append(doc)

                except json.JSONDecodeError as e:
                    if self.config.verbose:
                        print(f"Error loading segment documents: {e}")
    def _print_document_stats(self):
                      
        segment_stats = {}
        for doc in self.documents:
            segment = doc.metadata.get('segment_id', 'unknown')
            segment_stats[segment] = segment_stats.get(segment, 0) + 1

        for segment, count in sorted(segment_stats.items()):
            label = " (timeless)" if segment == 'segment_timeless' or segment == -1 else ""

                    
        has_timeless = 'segment_timeless' in segment_stats or -1 in segment_stats
        if self.config.include_timeless and has_timeless:
            print("Has timeless")
        elif self.config.include_timeless and not has_timeless:
            print("Has not timeless")
        elif not self.config.include_timeless and has_timeless:
            print("Has not timeless")

    def initialize(self) -> bool:
                         
        try:
                  
            if not self.load_documents():
                return False

                  
            if not self.retriever.build_index(self.documents):
                return False

            if self.config.verbose:
                print("Initialized")
            return True
        except Exception as e:
            if self.config.verbose:
                print(f"Error initializing: {e}")
            return False

    def query(self, question: str) -> Dict[str, Any]:
                  
        start_time = time.time()

        if self.config.verbose:
            print("Querying")
              
        docs = self.retriever.retrieve(question)

        if self.config.verbose:
            for i, doc in enumerate(docs, 1):
                metadata = doc.get('metadata', {})
                title = metadata.get('title', 'Unknown')
                segment = metadata.get('segment_id', 'Unknown')
                score = doc.get('score', 0.0)


        answer = self.generator.generate(question, docs)

              
        result = {
            'question': question,
            'answer': answer,
            'retrieved_docs': docs,
            'retrieved_doc_ids': [doc.get('id', '') for doc in docs],
            'processing_time': time.time() - start_time,
            'timestamp': datetime.now().isoformat(),
            'config': {
                'game': self.config.game_name,
                'embedding_model': self.config.embedding_model,
                'llm_model': self.config.llm_model,
                'segment_id': self.config.target_segment_id
            }
        }

        return result

    def batch_query(self, questions: List[str], output_file: str = None) -> List[Dict[str, Any]]:
                  
        total = len(questions)
        results = []

        if self.config.verbose:
            print("Batch querying")
        for i, question in enumerate(questions):
            if self.config.verbose:
                print("Batch querying")
            try:
                result = self.query(question)
                result['question_index'] = i
                results.append(result)

                       
                if output_file:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(result, ensure_ascii=False) + '\n')

            except Exception as e:
                if self.config.verbose:
                    print(f"Error querying for question {i}: {e}")
                error_result = {
                    'question': question,
                    'question_index': i,
                    'answer': f"Query failed: {str(e)}",
                    'error': str(e),
                    'timestamp': datetime.now().isoformat()
                }
                results.append(error_result)

                if output_file:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(error_result, ensure_ascii=False) + '\n')

        if self.config.verbose:
            print("Batch queried")
        return results

    def batch_evaluate(self, qa_pairs: List[Dict[str, Any]], output_file: str = None) -> List[Dict[str, Any]]:
                                     
        total = len(qa_pairs)
        results = []

        if self.config.verbose:
            print("Batch evaluating")
        for i, qa_pair in enumerate(qa_pairs):
            question = qa_pair['question']
            ground_truth_answer = qa_pair['ground_truth_answer']
            ground_truth_doc_ids = qa_pair['ground_truth_doc_ids']

            if self.config.verbose:
                print("Batch evaluating")
            try:
                         
                result = self.query(question)

                          
                result.update({
                    'question_index': i,
                    'ground_truth_answer': ground_truth_answer,
                    'ground_truth_doc_ids': ground_truth_doc_ids,
                    'ground_truth_docs': qa_pair['ground_truth_docs'],
                                                     
                                                   
                })

                results.append(result)

                       
                if output_file:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(result, ensure_ascii=False) + '\n')

            except Exception as e:
                if self.config.verbose:
                    print(f"Error evaluating for question {i}: {e}")
                error_result = {
                    'question': question,
                    'question_index': i,
                    'answer': f"Evaluation failed: {str(e)}",
                    'error': str(e),
                    'timestamp': datetime.now().isoformat(),
                    'ground_truth_answer': ground_truth_answer,
                    'ground_truth_doc_ids': ground_truth_doc_ids,
                    'retrieved_doc_ids': []
                }
                results.append(error_result)

                if output_file:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(error_result, ensure_ascii=False) + '\n')

        if self.config.verbose:
            print("Batch evaluated")
        return results

    def load_qa_pairs(self, segment_id: int = None) -> List[Dict[str, Any]]:
                                    
        if segment_id is None:
            segment_id = self.config.target_segment_id

        if segment_id is None:
            raise ValueError("Segment ID must be specified")

                           
        qa_dir = Path(self.config.qa_data_dir) / f"{self.config.game_name}_resample" / f"segment_{segment_id}"
        qa_filename = "generated_qa_pairs_resampled.jsonl"
        qa_file = qa_dir / qa_filename

        if not qa_file.exists():
            error_msg = f"❌ Resample QA data file does not exist: {qa_file}"
            if self.config.verbose:
                print(f"Error loading QA pairs: {e}")
            raise FileNotFoundError(error_msg)

        if self.config.verbose:
            print("Loading QA pairs")
        qa_pairs = []
        try:
            with open(qa_file, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        try:
                            data = json.loads(line.strip())
                            question = data.get('question', '').strip()
                            if question:
                                               
                                ground_truth_answer = data.get('answer', '')
                                ground_truth_docs = data.get('retrieved_docs', [])
                                ground_truth_doc_ids = [doc.get('id', '') for doc in ground_truth_docs]

                                qa_pairs.append({
                                    'question': question,
                                    'ground_truth_answer': ground_truth_answer,
                                    'ground_truth_doc_ids': ground_truth_doc_ids,
                                    'ground_truth_docs': ground_truth_docs,
                                    'original_data': data                  
                                })
                        except json.JSONDecodeError:
                            continue

            if self.config.verbose:
                print("Loaded QA pairs")
        except Exception as e:
            if self.config.verbose:
                print(f"Error loading QA pairs: {e}") 
        return qa_pairs

    def load_qa_questions(self, segment_id: int = None) -> List[str]:
                          
        qa_pairs = self.load_qa_pairs(segment_id)
        return [pair['question'] for pair in qa_pairs]
