import heapq
import json
import math
import pickle
import re
from collections import defaultdict, Counter
from pathlib import Path
from typing import List, Dict, Any, Optional

import jieba
from tqdm import tqdm

from llama_index.core import Document


class BM25Retriever:
    """Keyword retriever that scores documents with BM25 over jieba tokens.

    The index is held in memory and cached on disk as a pickle under
    ``<config.index_dir>/<config.game_name>/<index name>``.

    The ``config`` object is read for these attributes: ``verbose``,
    ``index_dir``, ``game_name``, ``force_rebuild``, ``target_segment_id``,
    ``include_timeless`` and ``top_k``.

    Indexed documents must expose ``text``, ``id_`` and ``metadata``.
    """

    # Anything that is neither a word character nor whitespace is replaced by
    # a space before segmentation.
    _NON_WORD_RE = re.compile(r'[^\w\s\u4e00-\u9fff]')
    _INDEX_FILE = "bm25_index.pkl"
    _METADATA_FILE = "metadata.json"

    def __init__(self, config, k1: float = 1.2, b: float = 0.75):
        """Create an empty retriever.

        Args:
            config: Settings object (see the class docstring for the
                attributes that are read).
            k1: BM25 term-frequency saturation parameter.
            b: BM25 document-length normalization parameter.
        """
        self.config = config
        self.k1 = k1
        self.b = b

        self.documents = []   # indexed documents, in index order
        self.doc_freqs = []   # per-document {token: count}
        self.idf = {}         # token -> inverse document frequency
        self.doc_lens = []    # per-document token count
        self.avgdl = 0.0      # mean of doc_lens
        self.N = 0            # number of indexed documents

        if config.verbose:
            print("BM25Retriever initialized")

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into lower-cased jieba tokens.

        Punctuation is replaced by spaces first; tokens that are one
        character or shorter after stripping whitespace are dropped.
        Returns an empty list for empty input.
        """
        if not text:
            return []

        cleaned = self._NON_WORD_RE.sub(' ', text).lower()
        stripped = (piece.strip() for piece in jieba.lcut(cleaned))
        return [token for token in stripped if len(token) > 1]

    def _compute_idf(self):
        """Recompute ``self.idf`` from ``self.doc_freqs`` and ``self.N``.

        Uses the plain ``log(N / df)`` form, so a token that occurs in every
        document gets an IDF of 0 and the value is never negative.
        """
        # Each document contributes at most 1 to a token's document frequency.
        df = Counter(
            token for doc_freq in self.doc_freqs for token in doc_freq
        )
        self.idf = {
            token: math.log(self.N / count) for token, count in df.items()
        }

    def build_index(self, documents: List["Document"]) -> bool:
        """Load the cached index if allowed and possible, otherwise build it.

        The cache is used only when its directory exists and
        ``config.force_rebuild`` is falsy; a cache that fails to load falls
        through to a fresh build.

        Returns:
            True when an index was loaded or built and saved, False otherwise.
        """
        index_name = self._get_index_name()
        index_path = Path(self.config.index_dir) / self.config.game_name / index_name

        if index_path.exists() and not self.config.force_rebuild:
            if self._load_existing_index(index_path):
                return True

        return self._build_new_index(documents, index_path)

    def _get_index_name(self) -> str:
        """Return the relative cache directory name for the current config."""
        if self.config.target_segment_id:
            suffix = "_with_timeless" if self.config.include_timeless else ""
            return f"bm25/segment_{self.config.target_segment_id}{suffix}"
        return "bm25/full_corpus"

    def _load_existing_index(self, index_path: Path) -> bool:
        """Restore the index state from ``index_path``.

        Returns:
            True on success; False if the file is missing, unreadable or
            lacks an expected key (the error is printed when verbose).
        """
        try:
            if self.config.verbose:
                print("Loading existing BM25 index")

            # SECURITY: unpickling can execute arbitrary code. Only point
            # config.index_dir at directories whose contents you trust.
            with open(index_path / self._INDEX_FILE, 'rb') as f:
                index_data = pickle.load(f)

            self.documents = index_data['documents']
            self.doc_freqs = index_data['doc_freqs']
            self.idf = index_data['idf']
            self.doc_lens = index_data['doc_lens']
            self.avgdl = index_data['avgdl']
            self.N = index_data['N']

            if self.config.verbose:
                print("Loaded existing BM25 index")
            return True
        except Exception as e:
            # Best effort: any failure here means "rebuild instead".
            if self.config.verbose:
                print(f"Error loading existing BM25 index: {e}")
            return False

    def _build_new_index(self, documents: List["Document"], index_path: Path) -> bool:
        """Tokenize ``documents``, compute BM25 statistics and save them.

        The in-memory index is replaced only after every document has been
        tokenized, so a failure during tokenization (or an empty corpus)
        leaves the previous state untouched. A failure while writing the
        cache files returns False but keeps the freshly built in-memory index.

        Returns:
            True when the index was built and written to ``index_path``,
            False otherwise (the error is printed when verbose).
        """
        try:
            if self.config.verbose:
                print("Building new BM25 index")

            # An empty corpus has no average length; fail with a clear
            # message instead of a ZeroDivisionError further down.
            if not documents:
                raise ValueError("no documents to index")

            if self.config.verbose:
                print("Processing documents")
            doc_freqs = []
            doc_lens = []
            for doc in tqdm(documents, desc="Processing documents", disable=not self.config.verbose):
                tokens = self._tokenize(doc.text)
                doc_freqs.append(dict(Counter(tokens)))
                doc_lens.append(len(tokens))

            # Commit all per-document state together so N, doc_freqs and
            # doc_lens can never disagree.
            self.documents = documents
            self.N = len(documents)
            self.doc_freqs = doc_freqs
            self.doc_lens = doc_lens
            self.avgdl = sum(doc_lens) / len(doc_lens)

            if self.config.verbose:
                print("Computing IDF")
            self._compute_idf()

            self._persist_index(index_path)

            if self.config.verbose:
                print("Built new BM25 index")
            return True
        except Exception as e:
            if self.config.verbose:
                print(f"Error building new BM25 index: {e}")
            return False

    def _persist_index(self, index_path: Path) -> None:
        """Write the pickled index and a JSON metadata summary to ``index_path``."""
        index_path.mkdir(parents=True, exist_ok=True)

        index_data = {
            'documents': self.documents,
            'doc_freqs': self.doc_freqs,
            'idf': self.idf,
            'doc_lens': self.doc_lens,
            'avgdl': self.avgdl,
            'N': self.N
        }
        with open(index_path / self._INDEX_FILE, 'wb') as f:
            pickle.dump(index_data, f)

        metadata = {
            'k1': self.k1,
            'b': self.b,
            'num_documents': self.N,
            'num_vocab': len(self.idf),
            'avg_doc_len': self.avgdl
        }
        with open(index_path / self._METADATA_FILE, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    def _bm25_score(self, query_tokens: List[str], doc_idx: int) -> float:
        """Return the BM25 score of document ``doc_idx`` for ``query_tokens``.

        Tokens absent from the document or from the IDF table contribute
        nothing; a token repeated in the query is counted once per
        occurrence.
        """
        doc_freq = self.doc_freqs[doc_idx]
        doc_len = self.doc_lens[doc_idx]

        # avgdl is 0 only when every indexed document tokenized to nothing;
        # treat lengths as average in that case rather than dividing by zero.
        len_ratio = doc_len / self.avgdl if self.avgdl else 1.0
        # The length-normalization term does not depend on the token.
        norm = self.k1 * (1 - self.b + self.b * len_ratio)

        score = 0.0
        for token in query_tokens:
            if token in doc_freq and token in self.idf:
                tf = doc_freq[token]
                score += self.idf[token] * (tf * (self.k1 + 1)) / (tf + norm)

        return score

    def _rank(self, query_tokens: List[str], top_k: int) -> List[Any]:
        """Return up to ``top_k`` ``(score, doc_idx)`` pairs, best first.

        Ties on score are broken in favour of the higher document index,
        which is the order a descending sort of the pairs would give.
        Documents with a score of 0 are still eligible. A non-positive
        ``top_k`` yields an empty list.
        """
        if top_k <= 0:
            return []
        scored = (
            (self._bm25_score(query_tokens, idx), idx) for idx in range(self.N)
        )
        # Equivalent to sorted(scored, reverse=True)[:top_k] without sorting
        # the whole corpus.
        return heapq.nlargest(top_k, scored)

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the ``top_k`` best-scoring documents for ``query``.

        Args:
            query: Raw query text; tokenized the same way as the documents.
            top_k: Maximum number of results; defaults to ``config.top_k``.

        Returns:
            A list of dicts with keys ``id``, ``content``, ``metadata`` and
            ``score``, best first. Empty when the query yields no tokens.

        Raises:
            ValueError: If no documents have been indexed or loaded.
        """
        if not self.documents:
            raise ValueError("BM25 index not initialized")

        if top_k is None:
            top_k = self.config.top_k

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        results = []
        for score, doc_idx in self._rank(query_tokens, top_k):
            doc = self.documents[doc_idx]
            results.append({
                'id': doc.id_,
                'content': doc.text,
                'metadata': doc.metadata,
                'score': score
            })

        return results

    def batch_retrieve(self, queries: List[str], top_k: Optional[int] = None) -> List[List[Dict[str, Any]]]:
        """Run ``retrieve`` for every query, one result list per query.

        A query whose retrieval raises is reported (when verbose) and gets
        an empty result list, so the output always lines up with ``queries``.
        This includes the ``ValueError`` raised for an uninitialized index.
        """
        if top_k is None:
            top_k = self.config.top_k

        results = []

        if self.config.verbose:
            print("Batch retrieving")
        for i, query in enumerate(tqdm(queries, desc="BM25 retrieval", disable=not self.config.verbose)):
            if self.config.verbose and i % 20 == 0:
                print("Batch retrieving")
            try:
                query_results = self.retrieve(query, top_k)
                results.append(query_results)
            except Exception as e:
                # Deliberate best effort: one bad query must not abort the batch.
                if self.config.verbose:
                    print(f"Error retrieving for query {i}: {e}")
                results.append([])

        if self.config.verbose:
            print("Batch retrieved")
        return results
