import os
import math
import re
import json
import csv
import pandas as pd
from typing import List, Dict, Any, Set
from collections import defaultdict

import openai
from tqdm import tqdm

# Import utility functions
from eval_utils import extract_json


class InvertedIndexRAG:
    """RAG implementation based on inverted index, supporting TF/TF-IDF/BM25 scoring"""
    
    def __init__(
        self,
        base_url: str,
        model_name_or_path: str,
        stopwords: Set[str] = None,
        api_key: str = "EMPTY",
        qa_df: pd.DataFrame = None,
        num_few_shot: int = 3,
        bm25_k1: float = 1.2,
        bm25_b: float = 0.75,
    ):
        """Create an empty index and the chat-completion client.

        Args:
            base_url: Endpoint URL handed to ``openai.OpenAI``.
            model_name_or_path: Model identifier sent with each completion request.
            stopwords: Terms dropped from queries; a falsy value means "no stopwords".
            api_key: API key handed to ``openai.OpenAI``.
            qa_df: Optional QA table used as the pool for few-shot examples.
            num_few_shot: Maximum number of few-shot examples per prompt.
            bm25_k1: BM25 term-frequency saturation coefficient.
            bm25_b: BM25 document-length normalization coefficient.
        """
        # --- LLM endpoint settings and client ---
        self.base_url = base_url
        self.api_key = api_key
        self.model_name_or_path = model_name_or_path
        self.ans_client = openai.OpenAI(api_key=api_key, base_url=base_url)

        # --- Few-shot example pool ---
        self.num_few_shot = num_few_shot
        self.qa_df = qa_df

        # --- Tokenization settings ---
        self.stopwords = stopwords if stopwords else set()

        # --- BM25 hyper-parameters ---
        self.bm25_b = bm25_b
        self.bm25_k1 = bm25_k1

        # --- Document store and id <-> position lookup tables ---
        self.documents = []
        self.doc_idx_to_ids = {}
        self.ids_to_doc_idx = {}

        # --- Inverted index: term -> list of (document index, term frequency) ---
        self.inverted_index = defaultdict(list)

        # --- Per-document token counts and their mean (filled in by build_index) ---
        self.avg_doc_len = 0.0
        self.doc_total_terms = []

    def _clean_text(self, text: str) -> str:
        """Strip tag-like spans and cap the text size.

        Removes every ``<...>`` span, keeps only the first 40 lines, and, if the
        result has more than 1024 space-separated pieces, keeps only the first
        1024 of them.
        """
        without_tags = re.sub(r'<[^>]+>', '', text)
        leading_lines = without_tags.split('\n')[:40]
        shortened = "\n".join(leading_lines)

        pieces = shortened.split(' ')
        if len(pieces) <= 1024:
            return shortened
        return " ".join(pieces[:1024])

    def _document_tokenizer(self, text: str) -> List[str]:
        """Turn document text into a list of lowercase tokens.

        The text is first passed through ``_clean_text``. It is then split on
        spaces, tabs and newlines, and on punctuation characters that are
        directly followed by whitespace or the end of the text. Empty pieces
        are discarded. Stopwords are NOT removed here.
        """
        whitespace_pattern = r'[ \t\n]'
        symbol_pattern = r'[.,;:!?()"\'-](?=\s|$)'
        splitter = f'{whitespace_pattern}|{symbol_pattern}'

        lowered = []
        for piece in re.split(splitter, self._clean_text(text)):
            if piece:
                lowered.append(piece.lower())
        return lowered

    def _query_tokenizer(self, text: str) -> List[str]:
        """Turn a query into index-compatible tokens.

        The query is cleaned with ``_clean_text``. Text outside quotes is split
        like document text, lowercased, and filtered against ``self.stopwords``.
        Text inside matching single or double quotes is exempt from stopword
        filtering, but it is still lowercased and split with the same rule:
        index terms are lowercase and never contain whitespace, so a
        case-preserved or multi-word quoted string could never match anything.

        Returns:
            Unquoted tokens first, followed by the tokens found inside quotes.
        """
        cleaned_text = self._clean_text(text)
        quote_pattern = r'(["\'])(.*?)\1'
        symbol_pattern = r'[.,;:!?()"\'-](?=\s|$)'
        whitespace_pattern = r'[ \t\n]'
        split_pattern = f'{whitespace_pattern}|{symbol_pattern}'

        # Quoted spans: keep every term (even stopwords), normalized like index terms.
        quoted_tokens = [
            piece.lower()
            for match in re.finditer(quote_pattern, cleaned_text)
            for piece in re.split(split_pattern, match.group(2))
            if piece
        ]

        # Remaining text: regular tokenization with stopword filtering.
        text_without_quotes = re.sub(quote_pattern, '', cleaned_text)
        plain_tokens = [
            piece.lower()
            for piece in re.split(split_pattern, text_without_quotes)
            if piece and piece.lower() not in self.stopwords
        ]
        return plain_tokens + quoted_tokens

    def read_tsv_documents(self, reactions_tsv: str, compounds_tsv: str) -> None:
        """Load documents from TSV (original logic unchanged)"""
        def _parse_single_tsv(tsv_path: str, source_type: str) -> List[Dict[str, str]]:
            docs = []
            if not os.path.exists(tsv_path):
                raise FileNotFoundError(f"TSV file does not exist: {tsv_path}")
            
            with open(tsv_path, 'r', encoding='utf-8') as f:
                reader = csv.DictReader(f, delimiter='\t')
                for row_idx, row in enumerate(reader):
                    doc_id = f"{source_type}_{row_idx:04d}"
                    
                    if source_type == "compounds":
                        filtered_cols = {col.replace('_', ' '): val for col, val in row.items() 
                                        if col not in ['mol_id', 'relevant_rxn', 'pubchem_id'] and val.strip()}
                        content = '\t'.join([f"{col}: {val}" for col, val in filtered_cols.items()])
                        source = f"{source_type}_{row.get('mol_id', f'row_{row_idx}')}"
                    elif source_type == "reactions":
                        filtered_cols = {col.replace('_', ' '): val for col, val in row.items() 
                                        if col not in ['rxn_id', 'source_patent'] and val.strip()}
                        content = '\t'.join([f"{col}: {val}" for col, val in filtered_cols.items()])
                        source = f"{source_type}_{row.get('rxn_id', f'row_{row_idx}')}"
                    
                    docs.append({
                        "id": doc_id,
                        "content": content,
                        "source": source,
                        "row_idx": row_idx
                    })
            
            print(f"Loaded {len(docs)} documents from {tsv_path}")
            return docs
        
        reactions_docs = _parse_single_tsv(reactions_tsv, source_type="reactions")
        compounds_docs = _parse_single_tsv(compounds_tsv, source_type="compounds")
        self.documents = reactions_docs + compounds_docs
        
        self.ids_to_doc_idx = {doc["id"]: idx for idx, doc in enumerate(self.documents)}
        self.doc_idx_to_ids = {idx: doc["id"] for idx, doc in enumerate(self.documents)}
        
        print(f"Total documents loaded: {len(self.documents)}")
        print(f"Built ID mapping table for {len(self.ids_to_doc_idx)} documents")

    def build_index(self) -> None:
        """(Re)build the inverted index and the document-length statistics.

        For every loaded document the content is tokenized, its token count is
        appended to ``self.doc_total_terms``, and one ``(document index, term
        frequency)`` posting per distinct term is added to
        ``self.inverted_index``. The mean token count is stored in
        ``self.avg_doc_len`` for later use by BM25 scoring.

        Raises:
            ValueError: If no documents have been loaded.
        """
        if not self.documents:
            raise ValueError("Documents not loaded, please call read_tsv_documents method first")

        # Start from a clean slate so repeated calls do not duplicate postings.
        self.inverted_index.clear()
        self.doc_total_terms = []

        for document in tqdm(self.documents, desc="Building inverted index"):
            position = self.ids_to_doc_idx[document["id"]]
            terms = self._document_tokenizer(document["content"])
            self.doc_total_terms.append(len(terms))

            # Per-document term counts, in first-seen order.
            counts = {}
            for term in terms:
                counts[term] = counts.get(term, 0) + 1

            for term, count in counts.items():
                self.inverted_index[term].append((position, count))

        # Mean document length, computed once here instead of on every search.
        indexed_docs = len(self.doc_total_terms)
        if indexed_docs > 0:
            self.avg_doc_len = sum(self.doc_total_terms) / indexed_docs
        else:
            self.avg_doc_len = 0.0

        print(f"Index building completed | Documents: {len(self.documents)} | Terms: {len(self.inverted_index)} | Average document length: {self.avg_doc_len:.2f}")

    def _build_context(self, context_docs: List[Dict[str, Any]]) -> str:
        """Render retrieved documents into one prompt section.

        Each document is cleaned with ``_clean_text`` and rendered with an id
        header. Documents are added in order while the running count of
        space-separated pieces stays within 10240; a document that would exceed
        the budget is skipped with a printed warning, and later (smaller)
        documents are still considered. Sections are joined by a blank line.
        """
        max_tokens = 10240
        used_tokens = 0
        sections = []

        for entry in context_docs:
            body = self._clean_text(entry["content"])
            section = f"=== Document ID: {entry['id']} ===\nContent:\n{body}"
            cost = len(section.split(' '))

            if used_tokens + cost > max_tokens:
                print(f"Warning: Skipping document {entry['id']} (exceeds context token limit)")
                continue

            sections.append(section)
            used_tokens += cost

        return "\n\n".join(sections)

    def search(self, query: str, k: int = 5, scoring: str = "bm25") -> List[Dict[str, Any]]:
        """Retrieve Top-k documents (new BM25 scoring, retain original TF/TF-IDF)"""
        if not self.documents or len(self.inverted_index) == 0:
            raise ValueError("Index not built: please load documents and call build_index")
        
        query_tokens = self._query_tokenizer(query)
        if not query_tokens:
            return []
        
        doc_scores = defaultdict(float)
        total_docs = len(self.documents)
        
        # Traverse each query term, calculate document scores
        for token in query_tokens:
            if token not in self.inverted_index:
                continue  # Term not in index, skip
            
            postings = self.inverted_index[token]  # (document index, term frequency) list
            doc_freq = len(postings)  # Number of documents containing this term (DF)
            
            # -------------------------- 1. Calculate IDF (BM25 and TF-IDF have different IDF formulas) --------------------------
            if scoring == "bm25":
                # BM25 IDF smoothing formula (avoid division by 0 when DF=0)
                idf = math.log( (total_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0 )
            else:
                # Original TF-IDF IDF formula (maintain compatibility)
                idf = math.log((total_docs + 1) / (doc_freq + 1))
            
            # -------------------------- 2. Traverse postings to calculate each document's score --------------------------
            for doc_idx, term_freq in postings:
                doc_len = self.doc_total_terms[doc_idx]  # Current document total terms
                # Avoid division by 0 when document length is 0 (extreme case handling)
                doc_len = max(doc_len, 1)
                
                if scoring == "bm25":
                    # -------------------------- BM25 core scoring calculation --------------------------
                    # Document length normalization factor: 1 - b + b*(|D|/avg_doc_len)
                    len_norm = 1 - self.bm25_b + self.bm25_b * (doc_len / self.avg_doc_len)
                    # Term frequency term: TF*(k1+1)/(TF + k1*len_norm)
                    tf_term = (term_freq * (self.bm25_k1 + 1)) / (term_freq + self.bm25_k1 * len_norm)
                    # BM25 total score = IDF * term frequency term
                    doc_score = idf * tf_term
                
                elif scoring == "tfidf":
                    # Original TF-IDF scoring (maintain compatibility)
                    doc_score = term_freq * idf
                
                else:
                    # Original TF scoring (maintain compatibility)
                    doc_score = term_freq / doc_len
                
                # Accumulate current term's score for the document
                doc_scores[doc_idx] += doc_score
        
        # Sort by score in descending order, take Top-k
        sorted_doc_idx = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)[:k]
        return [
            {**self.documents[doc_idx], "score": round(score, 4)}
            for doc_idx, score in sorted_doc_idx
        ]

    def get_few_shot_examples(self, 
                             current_question: str, 
                             current_qa_type: str, 
                             current_input_type: str) -> str:
        """Generate few-shot examples of the same type (original logic unchanged)"""
        if self.qa_df is None:
            return ""
            
        candidate_mask = (
            (self.qa_df['qa_type'] == current_qa_type) &
            (self.qa_df['input_type'] == current_input_type) &
            (self.qa_df['question'] != current_question)
        )
        candidates = self.qa_df[candidate_mask].copy()
        
        if len(candidates) < self.num_few_shot:
            candidate_mask = (
                (self.qa_df['qa_type'] == current_qa_type) &
                (self.qa_df['question'] != current_question)
            )
            candidates = self.qa_df[candidate_mask].copy()
            print(f"Warning: Insufficient examples for {current_qa_type}+{current_input_type}, downgrading to same {current_qa_type} examples")
        
        if len(candidates) < self.num_few_shot:
            candidates = self.qa_df[self.qa_df['question'] != current_question].copy()
            print(f"Warning: Insufficient examples for {current_qa_type}, using examples of all types")
        
        num_select = min(self.num_few_shot, len(candidates))
        if num_select == 0:
            return ""
        
        selected = candidates.sample(n=num_select, random_state=42)
        examples_str = ""
        
        for idx, (_, row) in enumerate(selected.iterrows(), 1):
            example_output = json.dumps({
                "thinking": f"...",
                "answer": row['answer']
            }, ensure_ascii=False)
            
            examples_str += f"## Example {idx} (Thinking Process Omitted)\n"
            examples_str += f"Question:\n{row['question']}\n\n"
            examples_str += f"Output:\n{example_output}\n\n"
        
        return examples_str.strip()
    
    def generate_answer(self, query: str, context_docs: List[Dict[str, Any]], few_shot_examples: str = "") -> str:
        """Generate answer (original logic unchanged)"""
        context = self._build_context(context_docs)
        
        system_prompt = """/nothink You are an intelligent assistant. Answer query based on the given context (retrieved documents and few-shot examples).
# Constraints
- Take context as references to infer the answer. If no context explicitly points to the answer, derive the question-to-answer reasoning path from the given documents or examples, especially for similar compounds or reactions, and thereby draw an analogy to the current query.
- Output must be JSON with 'thinking' and 'answer', where 'thinking' is your step-by-step thinking process, and 'answer' should directly answer the given query.
- 'answer' should specify the compound name in specific format if required in the query or form numeric answer."""
        
        prompt = f"""# Retrieved Documents
{context}

# Few-shot Examples
{few_shot_examples}

# User Query
{query}

# Output
"""
        
        response = self.ans_client.chat.completions.create(
            model=self.model_name_or_path,
            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
            top_p=0.4,
            max_tokens=4096,
        )
        return response.choices[0].message.content.strip()

    def retrieve_and_generate(self, query: str, k: int = 5, scoring: str = "bm25") -> tuple[str, List[Dict[str, Any]]]:
        """End-to-end RAG: retrieval + generation (default to BM25 scoring)"""
        context_docs = self.search(query, k, scoring)
        few_shot_examples = self.get_few_shot_examples(
            current_question=query,
            current_qa_type="",  # Actual use requires passing through predict method
            current_input_type=""
        )
        llm_answer = self.generate_answer(query, context_docs, few_shot_examples)
        return llm_answer, context_docs
    
    def predict(self, row: pd.Series, retrieve_k: int = 5, retrieve_scoring: str = "bm25") -> Dict[str, Any]:
        """Answer one QA row with retrieval-augmented generation.

        Args:
            row: Record providing ``question``, ``qa_type`` and ``input_type``.
            retrieve_k: Number of documents to retrieve.
            retrieve_scoring: Scoring mode forwarded to ``search``.

        Returns:
            A dict with the raw model output (``answer``), the ``answer`` field
            extracted from its JSON when available, otherwise the raw output
            (``answer_short``), the comma-joined record ids of the retrieved
            documents (``retrieved_suffixes``), the configured few-shot count
            (``num_few_shot``) and whether any example text was produced
            (``few_shot_available``).
        """
        few_shot_examples = self.get_few_shot_examples(
            current_question=row['question'],
            current_qa_type=row['qa_type'],
            current_input_type=row['input_type']
        )
        
        rag_answer, retrieved_docs = self.search_and_generate(
            query=row['question'],
            k=retrieve_k,
            scoring=retrieve_scoring,
            few_shot_examples=few_shot_examples
        )
        
        # A source looks like "<source type>_<record id>". Split only at the
        # first underscore so that record ids which themselves contain
        # underscores are kept whole instead of being cut at their first segment.
        retrieved_suffixes = [
            doc['source'].split('_', 1)[1].strip()
            for doc in retrieved_docs
            if '_' in doc['source']
        ]
        
        extracted_answer = extract_json(rag_answer)
        # Fall back to the raw output when no JSON object could be extracted.
        answer_short = extracted_answer.get('answer', rag_answer) if isinstance(extracted_answer, dict) else rag_answer
        
        return {
            "answer": rag_answer,
            "answer_short": answer_short,
            "retrieved_suffixes": ','.join(retrieved_suffixes),
            "num_few_shot": self.num_few_shot,
            "few_shot_available": len(few_shot_examples) > 0
        }
    
    def search_and_generate(self, query: str, k: int = 5, scoring: str = "bm25", few_shot_examples: str = "") -> tuple[str, List[Dict[str, Any]]]:
        """Search and generate answer (default to BM25 scoring)"""
        context_docs = self.search(query, k, scoring)
        llm_answer = self.generate_answer(query, context_docs, few_shot_examples)
        return llm_answer, context_docs
