import os
import math
import re
import json
import csv
import string
import random
import requests
import numpy as np
from typing import List, Dict, Any, Set, Union
from collections import defaultdict

import pandas as pd 
from tqdm import tqdm
import openai


class DenseVectorRAG:
    """RAG implementation based on dense vector retrieval, maintaining original logic without stopwords"""
    
    def __init__(
        self,
        base_url: str,
        model_name_or_path: str,
        vector_server_url: str,
        api_key: str = "EMPTY",
        qa_df: pd.DataFrame = None,
        num_few_shot: int = 3,
    ):
        """Create an empty retrieval store and the chat-completion client.

        Args:
            base_url: Base URL handed to the OpenAI-compatible client.
            model_name_or_path: Model identifier stored for chat completion calls.
            vector_server_url: URL that text-encoding requests are POSTed to.
            api_key: API key handed to the OpenAI-compatible client.
            qa_df: Optional DataFrame used as the pool for few-shot examples.
            num_few_shot: How many few-shot examples to select per question.
        """
        # Connection / model settings
        self.base_url = base_url
        self.api_key = api_key
        self.model_name_or_path = model_name_or_path
        self.vector_server_url = vector_server_url

        # Few-shot settings
        self.num_few_shot = num_few_shot
        self.qa_df = qa_df

        # Loaded documents plus lookups between document id and list position
        self.documents = []
        self.doc_idx_to_ids = {}
        self.ids_to_doc_idx = {}

        # Encoded document vectors; the dimension is filled in on first encode
        self.vector_dim = None
        self.document_vectors = []

        # Client used to request answers from the LLM endpoint
        self.ans_client = openai.OpenAI(api_key=api_key, base_url=base_url)

    def _clean_text_light(self, text: str) -> str:
        """Strip tags and truncate the text.

        Removes every ``<...>`` span, keeps only the first 40 lines, and then
        caps the result at 2048 single-space-separated words.
        """
        without_tags = re.sub(r'<[^>]+>', '', text)
        head_lines = without_tags.split('\n')[:40]
        trimmed = "\n".join(head_lines)

        words = trimmed.split(' ')
        if len(words) <= 2048:
            return trimmed
        return " ".join(words[:2048])

    def encode_text_to_vector(self, text: Union[str, List[str]], batch_size: int = 64) -> np.ndarray:
        """Encode one string or a list of strings through the vector server.

        Each batch is POSTed to ``self.vector_server_url`` as
        ``{"texts": batch}`` and the ``'embeddings'`` entry of the JSON reply is
        converted to a float32 array. Batch results are stacked row-wise.
        ``self.vector_dim`` is set from the first non-empty result.

        Args:
            text: A single string or a list of strings.
            batch_size: Maximum number of texts per request; must be >= 1.

        Returns:
            A float32 array with one row per input text. An empty input list
            yields an array with zero rows (and ``self.vector_dim`` columns if
            that is already known, otherwise zero columns).

        Raises:
            ValueError: If ``text`` is not a string / list of strings, or
                ``batch_size`` is smaller than 1.
            RuntimeError: If a request fails or the reply cannot be processed.
        """
        texts = [text] if isinstance(text, str) else text

        if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
            raise ValueError("Input must be a string or list of strings")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        if not texts:
            # np.vstack([]) raises, so short-circuit with an empty matrix and
            # avoid a pointless round trip to the server.
            return np.empty((0, self.vector_dim or 0), dtype=np.float32)

        batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
        is_single_text = len(texts) == 1

        # A single text is encoded silently; multiple texts get a progress bar.
        if is_single_text:
            batch_iter = batches
            failure_label = "Vector service request failed"
        else:
            batch_iter = tqdm(batches, desc="Encoding text vectors (batch processing)", leave=False)
            failure_label = "Vector service request failed (batch)"

        all_embeddings = []
        try:
            for batch in batch_iter:
                response = requests.post(
                    self.vector_server_url,
                    json={"texts": batch},
                    timeout=600
                )
                if response.status_code != 200:
                    raise Exception(f"{failure_label}: {response.text}")
                result = response.json()
                all_embeddings.append(np.array(result['embeddings'], dtype=np.float32))

            embeddings = np.vstack(all_embeddings)

            if self.vector_dim is None and len(embeddings) > 0:
                self.vector_dim = embeddings.shape[1]

            return embeddings

        except Exception as e:
            # Chain the cause so the original traceback stays attached.
            raise RuntimeError(f"Failed to encode text to vector: {str(e)}") from e

    def read_tsv_documents(self, reactions_tsv: str, compounds_tsv: str) -> None:
        """Load the reactions and compounds TSV files into ``self.documents``.

        Every data row becomes one document dict with the keys ``id``
        (``"<source_type>_<row index, zero-padded to 4>"``), ``content``
        (tab-joined ``"column: value"`` pairs with underscores in column names
        replaced by spaces, passed through ``_clean_text_light``), ``source``
        (``"<source_type>_<id column value>"``) and ``row_idx``. Reactions are
        stored first, then compounds, and both id<->index lookups are rebuilt.

        Cells that are missing (rows shorter than the header), surplus (rows
        longer than the header) or blank are left out of ``content``.

        Args:
            reactions_tsv: Path of the reactions TSV file.
            compounds_tsv: Path of the compounds TSV file.

        Raises:
            FileNotFoundError: If either path does not exist.
        """
        # Per source type: (columns kept out of the text, column naming the source)
        source_config = {
            "reactions": (('rxn_id', 'source_patent'), 'rxn_id'),
            "compounds": (('mol_id', 'relevant_rxn', 'pubchem_id'), 'mol_id'),
        }

        def _parse_single_tsv(tsv_path: str, source_type: str) -> List[Dict[str, str]]:
            if not os.path.exists(tsv_path):
                raise FileNotFoundError(f"TSV file does not exist: {tsv_path}")

            excluded_cols, id_col = source_config[source_type]
            docs = []

            # newline='' lets the csv module do its own line-ending handling
            with open(tsv_path, 'r', encoding='utf-8', newline='') as f:
                reader = csv.DictReader(f, delimiter='\t')
                for row_idx, row in enumerate(reader):
                    # DictReader fills short rows with None values and stores
                    # surplus cells as a list under the key None; skip both
                    # instead of crashing on .strip() / .replace().
                    filtered_cols = {
                        col.replace('_', ' '): val
                        for col, val in row.items()
                        if isinstance(col, str)
                        and col not in excluded_cols
                        and isinstance(val, str)
                        and val.strip()
                    }
                    content = '\t'.join(f"{col}: {val}" for col, val in filtered_cols.items())
                    source = f"{source_type}_{row.get(id_col, f'row_{row_idx}')}"

                    docs.append({
                        "id": f"{source_type}_{row_idx:04d}",
                        "content": self._clean_text_light(content),
                        "source": source,
                        "row_idx": row_idx
                    })

            print(f"Loaded {len(docs)} documents from {tsv_path}")
            return docs

        reactions_docs = _parse_single_tsv(reactions_tsv, source_type="reactions")
        compounds_docs = _parse_single_tsv(compounds_tsv, source_type="compounds")
        self.documents = reactions_docs + compounds_docs

        self.ids_to_doc_idx = {doc["id"]: idx for idx, doc in enumerate(self.documents)}
        self.doc_idx_to_ids = {idx: doc["id"] for idx, doc in enumerate(self.documents)}

        print(f"Total documents loaded: {len(self.documents)}")

    def build_index(self) -> None:
        """Encode every loaded document and store the vectors on the instance.

        Raises:
            ValueError: If no documents are loaded, or the encoder returns a
                different number of vectors than there are documents.
        """
        if not self.documents:
            raise ValueError("Please call read_tsv_documents to load documents first")

        total = len(self.documents)
        print(f"Starting document vector encoding (total {total} documents)...")

        # One encode call for the whole corpus; batching happens inside it
        self.document_vectors = self.encode_text_to_vector(
            [entry["content"] for entry in self.documents]
        )

        encoded = len(self.document_vectors)
        if encoded != total:
            raise ValueError(f"Vector encoding count mismatch: documents {total}, vectors {encoded}")

        print(f"Vector index building completed | Documents: {total} | Vector dimension: {self.vector_dim}")

    def _build_context(self, context_docs) -> str:
        """Render retrieved documents into one context string for the LLM prompt.

        Documents are added in order while the running count of
        single-space-separated words stays within 10240. A document that would
        exceed the budget is skipped with a printed warning, and later
        documents are still considered.
        """
        token_budget = 10240
        used_tokens = 0
        rendered_docs = []

        for entry in context_docs:
            rendered = f"=== Document ID: {entry['id']} ===\nContent:\n{entry['content']}"
            cost = len(rendered.split(' '))

            if used_tokens + cost > token_budget:
                print(f"Warning: Skipping document {entry['id']} (exceeds context token limit)")
                continue

            rendered_docs.append(rendered)
            used_tokens += cost

        return "\n\n".join(rendered_docs)

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return the ``k`` documents whose vectors score highest against the query.

        The query is encoded with ``encode_text_to_vector`` and scored against
        ``self.document_vectors`` by dot product.

        Args:
            query: Query text to encode.
            k: Number of documents to return. Values <= 0 yield an empty list;
                values larger than the corpus return every document.

        Returns:
            Copies of the matching document dicts, best first, each extended
            with a ``"score"`` entry (similarity rounded to 4 decimals).

        Raises:
            ValueError: If no documents or no document vectors are present.
        """
        if not self.documents or len(self.document_vectors) == 0:
            raise ValueError("Index not built: please load documents and call build_index first")

        # Without this guard k == 0 would slice as [-0:] == [0:] and return the
        # whole corpus; it also saves the encoding request.
        if k <= 0:
            return []

        query_vector = self.encode_text_to_vector(query)

        similarities = np.dot(self.document_vectors, query_vector.T).reshape(-1)

        # argsort is ascending: take the last k entries and flip to best-first
        top_k_indices = np.argsort(similarities)[-k:][::-1]

        return [
            {**self.documents[idx], "score": round(float(similarities[idx]), 4)}
            for idx in top_k_indices
        ]

    def generate_answer(self, query: str, context_docs: List[Dict[str, Any]], few_shot_examples: str = "") -> str:
        """Ask the LLM to answer ``query`` using retrieved documents and examples.

        Args:
            query: The user question placed in the prompt.
            context_docs: Retrieved documents, rendered via ``_build_context``.
            few_shot_examples: Pre-formatted example text inserted verbatim.

        Returns:
            The first choice's message content with surrounding whitespace
            stripped, or an empty string when the reply carries no content.
        """
        context = self._build_context(context_docs)
        
        system_prompt = """/nothink You are an intelligent assistant. Answer query based on the given context (retrieved documents and few-shot examples).
# Constraints
- Take context as references to infer the answer. If no context explicitly points to the answer, derive the question-to-answer reasoning path from the given documents or examples, especially for similar compounds or reactions, and thereby draw an analogy to the current query.
- Output must be JSON with 'thinking' and 'answer', where 'thinking' is your thinking process, and 'answer' should directly answer the given query in one or a few words.
- 'answer' should specify the compound name in specific format if required in the query or form numeric answer."""
        
        prompt = f"""# Retrieved Documents
{context}

# Few-shot Examples
{few_shot_examples}

# User Query
{query}

# Output
"""
        
        response = self.ans_client.chat.completions.create(
            model=self.model_name_or_path,
            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
            top_p=0.4,
            max_tokens=4096,
        )
        # The message content is optional in the chat completion schema; guard
        # against None so an empty reply does not crash on .strip().
        content = response.choices[0].message.content
        return (content or "").strip()

    def get_few_shot_examples(self, 
                             current_question: str, 
                             current_qa_type: str, 
                             current_input_type: str) -> str:
        """Format few-shot examples drawn from ``self.qa_df``.

        Candidates never include rows whose question equals
        ``current_question``. The pool is widened step by step until it holds
        at least ``self.num_few_shot`` rows: same ``qa_type`` and
        ``input_type`` first, then same ``qa_type`` only, then every other
        question. Rows are sampled with a fixed ``random_state`` of 42.

        Returns:
            The formatted examples, or an empty string when there is no
            ``qa_df`` or nothing can be selected.
        """
        pool = self.qa_df
        if pool is None:
            return ""

        wanted = self.num_few_shot

        # Tier 1: same qa_type and input_type, excluding the current question
        same_qa_type = pool['qa_type'] == current_qa_type
        same_input_type = pool['input_type'] == current_input_type
        other_question = pool['question'] != current_question
        candidates = pool[same_qa_type & same_input_type & other_question].copy()

        # Tier 2: relax the input_type constraint
        if len(candidates) < wanted:
            relaxed_mask = (pool['qa_type'] == current_qa_type) & (pool['question'] != current_question)
            candidates = pool[relaxed_mask].copy()
            print(f"Warning: Insufficient {current_qa_type}+{current_input_type} examples, downgrading to same {current_qa_type} examples")

        # Tier 3: any row with a different question
        if len(candidates) < wanted:
            candidates = pool[pool['question'] != current_question].copy()
            print(f"Warning: Insufficient {current_qa_type} examples, using all type examples")

        num_select = min(wanted, len(candidates))
        if num_select == 0:
            return ""

        chosen = candidates.sample(n=num_select, random_state=42)

        pieces = []
        for position, (_, record) in enumerate(chosen.iterrows(), 1):
            # The reasoning is deliberately elided; only the answer is shown
            rendered_output = json.dumps(
                {"thinking": "...", "answer": record['answer']},
                ensure_ascii=False,
            )
            pieces.append(f"## Example {position} (Thinking Process Omitted)\n")
            pieces.append(f"Question:\n{record['question']}\n\n")
            pieces.append(f"Output:\n{rendered_output}\n\n")

        return "".join(pieces).strip()

    def retrieve_and_generate(self, query: str, k: int = 5, few_shot_examples: str = "") -> tuple[str, List[Dict[str, Any]]]:
        """Retrieve the top ``k`` documents for ``query``, then generate an answer.

        Returns:
            A pair ``(llm_answer, retrieved_documents)``.
        """
        retrieved = self.search(query, k)
        return self.generate_answer(query, retrieved, few_shot_examples), retrieved
    
    def predict(self, row: pd.Series, retrieve_k: int = 5) -> Dict[str, Any]:
        """Run the full RAG pipeline for one row and package the result.

        Args:
            row: Record providing ``'question'``, ``'qa_type'`` and
                ``'input_type'``.
            retrieve_k: Number of documents to retrieve.

        Returns:
            A dict with the raw LLM reply (``answer``), the extracted short
            answer (``answer_short``), the comma-joined identifiers taken from
            the retrieved documents' sources (``retrieved_suffixes``), the
            configured ``num_few_shot`` and whether any few-shot text was
            produced (``few_shot_available``).
        """
        question = row['question']

        examples = self.get_few_shot_examples(
            current_question=question,
            current_qa_type=row['qa_type'],
            current_input_type=row['input_type'],
        )

        raw_answer, hits = self.retrieve_and_generate(
            query=question,
            k=retrieve_k,
            few_shot_examples=examples,
        )

        # Second '_'-separated field of each source string.
        # NOTE(review): an identifier that itself contains '_' is cut at its
        # first underscore here - confirm this is what the evaluator expects.
        suffixes = [
            hit['source'].split('_')[1].strip()
            for hit in hits
            if '_' in hit['source']
        ]

        # Imported here, after retrieval, exactly where the answer is parsed
        from eval_utils import extract_json
        parsed = extract_json(raw_answer)
        if isinstance(parsed, dict):
            short_answer = parsed.get('answer', raw_answer)
        else:
            # Fall back to the raw reply when no JSON object was extracted
            short_answer = raw_answer

        return {
            "answer": raw_answer,
            "answer_short": short_answer,
            "retrieved_suffixes": ','.join(suffixes),
            "num_few_shot": self.num_few_shot,
            "few_shot_available": len(examples) > 0
        }
