
import ast
import os
import re
import json
import random
import itertools
from typing import List, Dict, Any, Set, Tuple, Optional
from functools import reduce
from rouge_score import rouge_scorer
from concurrent.futures import ThreadPoolExecutor, as_completed

import openai
import pandas as pd
from tqdm import tqdm

def extract_json(text: str) -> Dict[str, Any]:
    """Extract the first JSON object embedded in ``text``.

    The whole string is tried first. If that fails, or yields something
    other than a JSON object, every ``{`` in the text is tried in order as
    the start of an object and the first one that decodes to a dict wins.
    This tolerates surrounding prose and stray braces before or after the
    payload. Note that if an outer object is malformed, a well-formed
    nested object may be returned instead.

    Args:
        text: Raw text (typically an LLM reply) that may contain JSON.

    Returns:
        The decoded JSON object, or an empty dict when ``text`` is not a
        string or contains no decodable JSON object.
    """
    if not isinstance(text, str):
        return {}

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    # Callers use ``.get`` on the result, so only a dict is acceptable.
    if isinstance(parsed, dict):
        return parsed

    decoder = json.JSONDecoder()
    for brace in re.finditer(r'\{', text):
        try:
            candidate, _ = decoder.raw_decode(text, brace.start())
        except json.JSONDecodeError:
            continue
        if isinstance(candidate, dict):
            return candidate
    return {}


class StructuralRAG:
    """Structure-based RAG system for compounds, supporting specified compound name formats (SMILES/IUPAC)"""
    
    # System prompt for _generate_first_step_reasoning: asks for a numbered
    # reasoning path without a final answer.
    # NOTE(review): the leading "/nothink" is presumably a model-specific
    # directive that disables thinking mode — confirm against the served model.
    _FIRST_STEP_PROMPT = """/nothink You are a chemistry expert. Given a chemical query, generate a step-by-step reasoning path to solve it.
Requirements:
1. Each reasoning path must have clear step numbering (e.g., Step 1...;Step 2...;Step 3...).
2. Highlight specific chemical names and numerical values.
3. No answer at the end."""

    # System prompt for _generate_single_second_step_reasoning. It is a
    # str.format template: "{aspect}" is filled in with the query aspect.
    _SECOND_STEP_PROMPT = """/nothink You're a chemistry expert. Adapt the given reasoning path to strictly match the chemical record.
Steps:
1. Start by providing a json string with keys "answer" pertaining to the {aspect}, copied directly from the corresponding key of the record.
2. Then, replace the compound and reaction information in the reasoning path with data from the record.
Requirements:
1. Strictly copy the name and info from the record without modifications."""

    # System prompt for _generate_third_step_answer: infer the final answer
    # from the retrieved records and adapted reasoning paths.
    _THIRD_STEP_PROMPT = """/nothink You're a chemistry expert. Infer the answer to the query based on the given context.
Requirements:
1. If there is context highly matching the query, you should directly use the answer. Otherwise, take the reasoning paths as few-shot examples and try to find something in common between compounds in the reasoning paths and given query, then infer the answer with step-by-step thinking. 
2. If there is no valid context, infer the answer using your knowledge step-by-step.
3. Conclude with a JSON string with a key 'answer'. The "answer" should follow the format of the concise answer in the reasoning path. Unless the query is a description task, the "answer" should only consist of one or several words or numbers that indicates only one answer to the query directly."""

    # System prompt for _extract_query_information: requests a JSON object
    # with the keys 'level', 'compounds', 'format' and 'aspect'.
    _QUERY_EXTRACTION_PROMPT = """/nothink You are a chemistry expert. Given a chemical query, extract information into a JSON string with keys:
- 'level': the level involved in the query ('compound' or 'reaction')
- 'compounds': list of compound names in the query. If there is one compound, output a list with a single element. Directly extract names from the query.
- 'format': input format of compounds in the query ('smiles' or 'iupac'), NOT the output format.
- 'aspect': query target (e.g., weight, product, reactants, condition, name conversion (to IUPAC or SMILES), description, etc.)
Output only the JSON string."""

    def __init__(
        self,
        llm_base_url: str = "http://localhost:10000/v1",
        llm_model_path: str = "/home/share/ckpt/Qwen3-8B",
        llm_api_key: str = "EMPTY",
        qa_df: Optional[pd.DataFrame] = None,
        num_few_shot: int = 3,
        batch_similar_top_n: int = 5,
        max_threads: int = 5,
    ):
        """Create the LLM client and store retrieval settings.

        Args:
            llm_base_url: Base URL of the OpenAI-compatible endpoint.
            llm_model_path: Model identifier sent with every completion call.
            llm_api_key: API key handed to the OpenAI client.
            qa_df: Optional QA dataframe; only stored here.
            num_few_shot: Number of few-shot examples; only stored here.
            batch_similar_top_n: Default number of similar compounds returned
                by ``_find_similar_compounds`` when no ``top_n`` is given.
            max_threads: Worker count for the second-step thread pool.
        """
        # Plain configuration values.
        self._llm_model = llm_model_path
        self._qa_df = qa_df
        self._num_few_shot = num_few_shot
        self._batch_similar_top_n = batch_similar_top_n
        self._max_threads = max_threads

        # Knowledge bases stay unset until load_knowledge_base() is called.
        self._kb_compound: Optional[pd.DataFrame] = None
        self._kb_reaction: Optional[pd.DataFrame] = None

        # External helpers: chat client and ROUGE-L scorer.
        self._llm_client = openai.OpenAI(
            base_url=llm_base_url,
            api_key=llm_api_key,
        )
        self._rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)
        
    def load_knowledge_base(self, compound_tsv: str, reaction_tsv: str) -> None:
        """Load compound and reaction knowledge base"""
        for path in [compound_tsv, reaction_tsv]:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Knowledge base file does not exist: {path}")
        
        self._kb_compound = pd.read_csv(compound_tsv, sep='\t')
        self._validate_dataframe(self._kb_compound, 'compound', ['smiles', 'iupac', 'mol_id', 'relevant_rxn'])
        
        self._kb_reaction = pd.read_csv(reaction_tsv, sep='\t')
        self._validate_dataframe(self._kb_reaction, 'reaction', ['rxn_id'])
        
        self._kb_compound[['smiles', 'iupac', 'relevant_rxn']] = self._kb_compound[
            ['smiles', 'iupac', 'relevant_rxn']
        ].astype(str).fillna('')
        self._kb_compound['relevant_rxn'] = self._kb_compound['relevant_rxn'].apply(
            lambda x: re.sub(r'[,\s;]+', ',', x.strip())
        )

    def process_query(self, query: str, retrieve_k: int = 5, few_shot_examples: str = "") -> Tuple[str, List[Dict], List[str], str, str]:
        """Run the full pipeline for one query.

        Stages: query parsing, compound retrieval, context building,
        first-step reasoning, record-adapted reasoning, final answer.

        Returns:
            ``(final_answer, context, retrieved_ids, reasoning_original,
            reasoning_retrieved)``.
        """
        info = self._extract_query_information(query)
        level = info.get('level', '')
        aspect = info.get('aspect', '')

        # Anything other than SMILES/IUPAC falls back to SMILES.
        name_type = info.get('format', 'smiles').upper()
        if name_type not in ('SMILES', 'IUPAC'):
            name_type = 'SMILES'

        # Retrieval: matching compounds, then their context records.
        candidates = self._filter_compounds(
            compounds=info.get('compounds', []),
            level=level,
            compound_format=info.get('format', ''),
            retrieve_k=retrieve_k,
        )
        context, retrieved_ids = self._generate_context(
            filtered_compounds=candidates,
            level=level,
            aspect=aspect,
            retrieve_k=retrieve_k,
        )

        # Reasoning: generic paths first, then paths adapted to the records.
        reasoning_original = self._generate_first_step_reasoning(query, retrieve_k, few_shot_examples)
        reasoning_retrieved = self._generate_batch_second_step_reasoning(
            self._split_reasoning_paths(reasoning_original),
            context,
            name_type,
            aspect=aspect,
        )

        records_json = json.dumps(context, ensure_ascii=False)
        final_answer = self._generate_third_step_answer(query, reasoning_retrieved, few_shot_examples, records_json)

        return final_answer, context, retrieved_ids, reasoning_original, reasoning_retrieved

    def _tokenize_query(self, query: str) -> List[str]:
        """Split a query into tokens.

        Splits on spaces, tabs and newlines, and on a punctuation character
        that is immediately followed by whitespace or the end of the string.
        Punctuation inside a token (e.g. in a SMILES string) is kept. Empty
        pieces are dropped.
        """
        delimiters = (
            r'[ \t\n]',
            r'[.,;:!?()"\'-](?=\s|$)',
        )
        pieces = re.split('|'.join(delimiters), query)
        return [piece for piece in pieces if piece]

    def _find_compounds_in_query(self, query: str, candidates: List[str]) -> List[str]:
        """Map candidate names onto the query tokens they best match.

        Each candidate claims the unused query token with the highest overlap
        score; the token is kept only if that score exceeds 0.3. A token is
        used at most once. If no candidate claims a token, ``candidates`` is
        returned unchanged.

        Returns:
            The matched query tokens, ``candidates`` when nothing matched, or
            an empty list when the query or candidate list is empty.
        """
        if not (candidates and query):
            return []

        query_tokens = self._tokenize_query(query)
        if not query_tokens:
            return []

        picked = []
        taken = set()

        for cand in candidates:
            # Skip anything that is not a non-blank string.
            if not isinstance(cand, str) or not cand.strip():
                continue

            pieces = [piece.strip() for piece in cand.split() if piece.strip()]
            if not pieces:
                continue

            top_score, top_pos = 0.0, -1
            for pos, token in enumerate(query_tokens):
                if pos in taken:
                    continue
                current = self._calculate_overlap(pieces, [token])
                # Strict comparison: the earliest best token wins ties.
                if current > top_score:
                    top_score, top_pos = current, pos

            if top_pos != -1 and top_score > 0.3:
                picked.append(query_tokens[top_pos])
                taken.add(top_pos)

        return picked if picked else candidates

    def _calculate_overlap(self, candidate_parts: List[str], tokens: List[str]) -> float:
        """Return the best similarity over all (candidate part, token) pairs.

        Returns 0.0 when either list is empty and stops early on a perfect
        score of 1.0.
        """
        if not (candidate_parts and tokens):
            return 0.0

        best = 0.0
        for part, token in itertools.product(candidate_parts, tokens):
            current = self._calculate_similarity(part, token)
            if current > best:
                if current == 1.0:
                    # Cannot be beaten; skip the remaining pairs.
                    return 1.0
                best = current
        return best

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """Return the ROUGE-L F-measure of two texts.

        Both texts are lower-cased and stripped first; if either is then
        empty the result is 0.0 and the scorer is not called.
        """
        left, right = (text.lower().strip() for text in (text1, text2))
        if left and right:
            return self._rouge_scorer.score(left, right)['rougeL'].fmeasure
        return 0.0

    def _clean_compound(self, text: str) -> str:
        """Normalise a compound name: lower-case it and trim outer whitespace."""
        lowered = text.lower()
        return lowered.strip()

    def _split_reasoning_paths(self, combined_reasoning: str) -> List[str]:
        """Split combined reasoning text into individual paths.

        If the text contains "Reasoning Path <n>" headers, each path runs
        from one header to the next (or to the end of the text). Otherwise
        the text is split on blank lines and each chunk is given a generated
        "Reasoning Path <n>" header. Empty input yields an empty list.
        """
        if not combined_reasoning:
            return []

        starts = [m.start() for m in re.finditer(r"Reasoning Path \d+", combined_reasoning)]
        if starts:
            # Pair each header offset with the next one; the last path runs
            # to the end of the text.
            ends = starts[1:] + [len(combined_reasoning)]
            return [combined_reasoning[begin:end].strip() for begin, end in zip(starts, ends)]

        chunks = [chunk.strip() for chunk in re.split(r'\n\s*\n', combined_reasoning) if chunk.strip()]
        return [f"Reasoning Path {number}\n{chunk}" for number, chunk in enumerate(chunks, start=1)]

    def _format_compound_list(self, compound_str: str) -> str:
        """Render a "."-separated compound list as natural language.

        - One compound: returned as is.
        - Two compounds: joined with " and ".
        - Three or more: separated by ", ", with ", and " before the last.

        Args:
            compound_str: Compound names separated by ".".

        Returns:
            The formatted text, or an empty string when the input is empty
            or holds only separators/whitespace (e.g. "." or " ").
        """
        if not compound_str:
            return ""

        # Split the list and drop blank components.
        compounds = [comp.strip() for comp in compound_str.split('.') if comp.strip()]

        # Inputs such as "." leave nothing behind; without this guard the
        # final branch would index into an empty list.
        if not compounds:
            return ""
        if len(compounds) == 1:
            return compounds[0]
        if len(compounds) == 2:
            return f"{compounds[0]} and {compounds[1]}"
        return ", ".join(compounds[:-1]) + f", and {compounds[-1]}"

    def _generate_first_step_reasoning(self, query: str, retrieve_k: int, few_shot_examples: str = "") -> str:
        """Ask the LLM for ``retrieve_k`` reasoning paths for ``query``.

        Each returned choice is given a "Reasoning Path <n>" header unless it
        already starts with one; the paths are joined with blank lines. On
        any error a failure message is printed and returned instead of
        raising. ``few_shot_examples`` is accepted but not used here.
        """
        try:
            messages = [
                {"role": "system", "content": self._FIRST_STEP_PROMPT},
                {"role": "user", "content": f"# Current Query\n{query}"},
            ]
            completion = self._llm_client.chat.completions.create(
                model=self._llm_model,
                messages=messages,
                top_p=0.4,
                max_tokens=4096,
                n=retrieve_k,
            )

            # Token usage is appended to a local file for 'hkust' hosts only.
            if 'hkust' in self._llm_client.base_url.host:
                with open('usage.txt', 'a+') as usage_log:
                    usage_log.write(str(completion.usage) + '\n')

            labelled = []
            for number, choice in enumerate(completion.choices, 1):
                body = choice.message.content.strip()
                if re.match(rf"Reasoning Path {number}", body, re.IGNORECASE):
                    labelled.append(body)
                else:
                    labelled.append(f"Reasoning Path {number}\n{body}")

            return "\n\n".join(labelled)

        except Exception as e:
            error_msg = f"First step reasoning failed: {str(e)}"
            print(error_msg)
            return f"{error_msg}\nFinal Answer: [[Failed]]"
    
    def _generate_single_second_step_reasoning(self, original_reasoning: str, context_item: Dict, aspect: str) -> str:
        """Adapt one reasoning path to one retrieved record.

        The record is serialised as indented JSON and sent to the LLM with
        the original path. On any error a failure message is printed and
        returned instead of raising.
        """
        try:
            system_prompt = self._SECOND_STEP_PROMPT.format(aspect=aspect)
            # The record dictionary is passed to the model as JSON text.
            context_json = json.dumps(context_item, ensure_ascii=False, indent=2)
            user_prompt = f"""# Original Reasoning Path
{original_reasoning}

# Chemical Record (JSON)
{context_json}

# Adapt the reasoning path to match the chemical record"""
            completion = self._llm_client.chat.completions.create(
                model=self._llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                top_p=0.4,
                max_tokens=4096,
            )

            # Token usage is appended to a local file for 'hkust' hosts only.
            if 'hkust' in self._llm_client.base_url.host:
                with open('usage.txt', 'a+') as usage_log:
                    usage_log.write(str(completion.usage) + '\n')

            first_choice = completion.choices[0]
            return first_choice.message.content.strip()

        except Exception as e:
            error_msg = f"Second step reasoning failed for a path: {str(e)}"
            print(error_msg)
            return f"{error_msg}\nAdapted Reasoning: [[Failed]]"

    def _generate_batch_second_step_reasoning(self, original_paths: List[str], context_items: List[Dict], name_type: str, aspect: str) -> str:
        """Adapt reasoning paths to context records concurrently.

        Path ``i`` is paired with record ``i``; surplus paths or records are
        ignored. Each pair is adapted on a thread pool, and the results are
        joined in pairing order under "Adapted Reasoning Path <n>" headers.
        ``name_type`` is accepted but not used here.
        """
        if not (original_paths and context_items):
            return "No valid reasoning paths or context items for adaptation."

        # zip() stops at the shorter input, giving min(len, len) pairs.
        jobs = list(zip(original_paths, context_items))
        outputs = [None] * len(jobs)

        with ThreadPoolExecutor(max_workers=self._max_threads) as pool:
            pending = {}
            for position, (path, record) in enumerate(jobs):
                submitted = pool.submit(self._generate_single_second_step_reasoning, path, record, aspect)
                pending[submitted] = position

            # Store each result at its original position as it finishes.
            for finished in as_completed(pending):
                position = pending[finished]
                try:
                    outputs[position] = finished.result()
                except Exception as e:
                    print(f"Error processing task {position}: {str(e)}")
                    outputs[position] = f"Error adapting reasoning path {position+1}: {str(e)}"

        return "\n\n".join(
            f"Adapted Reasoning Path {number}\n{text}"
            for number, text in enumerate(outputs, start=1)
            if text is not None
        )

    def _generate_third_step_answer(self, query: str, retrieved_reasoning: str, few_shot_examples: str, retrieved_records_json: str) -> str:
        """Ask the LLM for the final answer to ``query``.

        The prompt contains the retrieved records (already a JSON string),
        the adapted reasoning paths and the query. On any error a failure
        message is printed and returned instead of raising.
        ``few_shot_examples`` is accepted but not used here.
        """
        try:
            user_prompt = f"""# Retrieved Records for Your Reference
{retrieved_records_json}

# Reasoning Paths for Your Reference
{retrieved_reasoning}

# Current Query
{query}

# Your Reasoning and Answer to the Query"""
            completion = self._llm_client.chat.completions.create(
                model=self._llm_model,
                messages=[
                    {"role": "system", "content": self._THIRD_STEP_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                top_p=0.4,
                max_tokens=8192,
            )

            # Token usage is appended to a local file for 'hkust' hosts only.
            if 'hkust' in self._llm_client.base_url.host:
                with open('usage.txt', 'a+') as usage_log:
                    usage_log.write(str(completion.usage) + '\n')

            first_choice = completion.choices[0]
            return first_choice.message.content.strip()

        except Exception as e:
            error_msg = f"Third step inference failed: {str(e)}"
            print(error_msg)
            return f"{error_msg}\nFinal Answer: [[Failed]]"

    def _validate_dataframe(self, df: pd.DataFrame, name: str, required_cols: List[str]) -> None:
        """Validate if dataframe contains required columns"""
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"{name} knowledge base missing required columns: {missing_cols}")

    def _extract_query_information(self, query: str) -> Dict[str, Any]:
        """Ask the LLM to parse a query into level, compounds, format and aspect.

        The reported format is lower-cased and stripped so that values such
        as "SMILES" are recognised by later stages. For SMILES queries the
        extracted names are mapped back onto the literal query tokens; for
        other formats the extracted names are used as returned.

        Args:
            query: The user's chemistry question.

        Returns:
            A dict with the keys ``level``, ``compounds``, ``format`` and
            ``aspect``. If anything fails, a fallback dict is returned with
            level ``'others'``, empty format and aspect, and whatever
            compound candidates had been extracted so far (possibly none).
        """
        # Bound before the try block: the except handler reads this name,
        # and the LLM call may fail before anything has been extracted.
        raw_candidates: List[Any] = []
        try:
            response = self._llm_client.chat.completions.create(
                model=self._llm_model,
                messages=[
                    {"role": "system", "content": self._QUERY_EXTRACTION_PROMPT},
                    {"role": "user", "content": f"Query: {query}\nExtraction:"}
                ],
                top_p=0.4,
                max_tokens=1024
            )

            # Token usage is appended to a local file for 'hkust' hosts only.
            if 'hkust' in self._llm_client.base_url.host:
                with open('usage.txt', 'a+') as f:
                    f.write(str(response.usage) + '\n')

            extracted = extract_json(response.choices[0].message.content.strip())

            raw_candidates = extracted.get('compounds', [])
            if raw_candidates is None:
                raw_candidates = []
            elif not isinstance(raw_candidates, list):
                raw_candidates = [raw_candidates]

            # Normalise case/whitespace; a non-string value (e.g. JSON null)
            # falls back to the default format.
            raw_format = extracted.get('format', 'iupac')
            name_format = raw_format.strip().lower() if isinstance(raw_format, str) else 'iupac'

            if name_format == 'smiles':
                cleaned_compounds = self._find_compounds_in_query(query, raw_candidates)
            else:
                cleaned_compounds = raw_candidates

            return {
                'level': extracted.get('level', 'reaction'),
                'compounds': cleaned_compounds,
                'format': name_format,
                'aspect': extracted.get('aspect', '')
            }
        except Exception as e:
            print(f"Query information extraction failed: {str(e)}")
            return {
                'level': 'others',
                'compounds': raw_candidates,
                'format': '',
                'aspect': ''
            }

    def _find_similar_compounds(self, 
                               target: str, 
                               format_type: str, 
                               exclude_indices: Set[int], 
                               top_n: Optional[int] = None) -> pd.DataFrame:
        """Find similar compounds"""
        if self._kb_compound is None or format_type not in ['smiles', 'iupac']:
            return pd.DataFrame()
            
        actual_top_n = top_n if top_n is not None else self._batch_similar_top_n
        
        target_clean = self._clean_compound(target)
        temp_df = self._kb_compound.copy()
        temp_df['similarity'] = temp_df.apply(
            lambda row: self._calculate_similarity(target_clean, self._clean_compound(row[format_type])),
            axis=1
        )
        
        similar_df = temp_df[~temp_df.index.isin(exclude_indices)]
        return similar_df.sort_values('similarity', ascending=False)\
                        .head(actual_top_n)\
                        .drop(columns=['similarity'], errors='ignore')

    def _filter_compounds(self, 
                         compounds: List[str], 
                         level: str, 
                         compound_format: str, 
                         retrieve_k: int = 5) -> pd.DataFrame:
        """Filter compounds"""
        if not isinstance(compounds, list):
            compounds = []
        
        if compound_format not in ['smiles', 'iupac']:
            compound_format = 'iupac'
            
        if not compounds or self._kb_compound is None:
            return pd.DataFrame()
        
        compounds_clean = [self._clean_compound(comp) for comp in compounds if isinstance(comp, str) and comp.strip()]
        
        if not compounds_clean:
            return pd.DataFrame()
            
        temp_df = self._kb_compound.copy()
        temp_df['cleaned'] = temp_df[compound_format].apply(self._clean_compound)
        matched_mask = temp_df['cleaned'].isin(compounds_clean)
        matched_df = temp_df[matched_mask].drop(columns=['cleaned']).copy()
        matched_df['match'] = 'exact'
        matched_indices = set(matched_df.index)
        remaining = retrieve_k - len(matched_df)
        
        if remaining <= 0:
            return matched_df.head(retrieve_k)

        compound_similar_cache: Dict[int, pd.DataFrame] = {}
        for comp_idx, compound in enumerate(compounds):
            similar_batch = self._find_similar_compounds(
                target=compound,
                format_type=compound_format,
                exclude_indices=matched_indices,
                top_n=remaining
            )
            if not similar_batch.empty:
                compound_similar_cache[comp_idx] = similar_batch.reset_index(drop=True)

        used_similar_idx: Dict[int, int] = {comp_idx: 0 for comp_idx in compound_similar_cache.keys()}
        comp_idx_cycle = 0

        similar_compounds = []
        while remaining > 0 and compound_similar_cache:
            current_comp_idx = list(compound_similar_cache.keys())[comp_idx_cycle % len(compound_similar_cache)]
            current_similar_df = compound_similar_cache[current_comp_idx]
            current_used_idx = used_similar_idx[current_comp_idx]

            if current_used_idx >= len(current_similar_df):
                del compound_similar_cache[current_comp_idx]
                del used_similar_idx[current_comp_idx]
                comp_idx_cycle += 1
                continue

            selected_comp = current_similar_df.iloc[[current_used_idx]]
            selected_idx = selected_comp.index[0]

            if selected_idx not in matched_indices:
                similar_compounds.append(selected_comp)
                matched_indices.add(selected_idx)
                remaining -= 1

            used_similar_idx[current_comp_idx] += 1
            comp_idx_cycle += 1

        if similar_compounds:
            similar_df = pd.concat(similar_compounds, ignore_index=True)
            similar_df['match'] = 'similar'
            matched_df = pd.concat([matched_df, similar_df], ignore_index=True)

        return matched_df.head(retrieve_k)

    def _find_reaction_similar_compounds(self, 
                                        compounds: List[str], 
                                        format_type: str) -> pd.DataFrame:
        """Find similar compounds related to reactions"""
        similar_list = []
        for compound in compounds:
            similar_batch = self._find_similar_compounds(
                target=compound,
                format_type=format_type,
                exclude_indices=set()
            )
            if not similar_batch.empty:
                similar_list.append(similar_batch)
                
        if not similar_list:
            return pd.DataFrame()
        combined_df = pd.concat(similar_list, ignore_index=True)
        return combined_df.drop_duplicates(subset=[format_type])
    
    def flatten_dict(self, d):
        """Lift nested dictionaries one level up into a flat dictionary.

        A value that is a dict, or a string holding a Python dict literal,
        has its items merged into the result in place of its own key. All
        other values (including lists) are kept under their original key.
        Later items overwrite earlier ones when keys collide.

        Parameters:
            d: Dictionary to flatten.

        Returns:
            A new dictionary; ``d`` itself is not modified.
        """
        def parse_dict_literal(text):
            """Return the dict encoded by ``text``, or None if it is not one."""
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # These are the errors literal_eval raises for malformed or
                # overly complex input; anything else should propagate.
                return None
            return parsed if isinstance(parsed, dict) else None

        result = {}
        for key, value in d.items():
            if isinstance(value, str):
                # Parse once and reuse the result.
                parsed = parse_dict_literal(value)
                if parsed is not None:
                    value = parsed

            if isinstance(value, dict):
                result.update(value)
            # Other types (including lists) are preserved directly
            else:
                result[key] = value
        return result
    
    def _generate_context(self, 
                         filtered_compounds: pd.DataFrame, 
                         level: str, 
                         aspect: str,
                         retrieve_k: int = 5) -> Tuple[List[Dict], List[str]]:
        """Generate context information, returning a list of dictionaries instead of a table string"""
        if filtered_compounds.empty or self._kb_reaction is None:
            return [], []
            
        retrieved_ids = [str(id_) for id_ in filtered_compounds['mol_id'].tolist()]
        
        if level == 'compound':
            base_context = self._format_compound_context(filtered_compounds)
        elif level == 'reaction':
            base_context, retrieved_ids = self._format_reaction_context(filtered_compounds, retrieve_k, retrieved_ids)
        else:  
            base_context = self._format_compound_context(filtered_compounds)

        base_context = [self.flatten_dict(x) for x in base_context]
        return base_context, retrieved_ids

    def _format_compound_context(self, compounds_df: pd.DataFrame) -> List[Dict]:
        """Format compound context as a list of dictionaries"""
        if 'match' not in compounds_df.columns:
            compounds_df = compounds_df.copy()
            compounds_df['match'] = 'unknown'
        
        exact_mask = compounds_df['match'] == 'exact'
        if exact_mask.any():
            filtered_df = compounds_df[exact_mask].copy()
        else:
            filtered_df = compounds_df.copy()
        
        # Remove unnecessary columns
        filtered_df = filtered_df.drop(columns=['relevant_rxn', 'pubchem_id'], errors='ignore')
        
        # Convert to list of dictionaries
        return filtered_df.to_dict('records')

    def _safe_convert_to_int(self, value: Any) -> Optional[int]:
        """Convert ``value`` to a non-negative int, or return None.

        The value is stringified and stripped. A single all-zero fractional
        part is tolerated ("12.0" gives 12), which covers ids that were read
        as floats. Anything else that is not purely decimal digits (signs,
        real fractions, text, "nan") gives None.

        Args:
            value: Any object; typically a reaction id as str, int or float.

        Returns:
            The parsed integer, or None when the value is not a plain
            non-negative integer.
        """
        if value is None:
            return None

        text = str(value).strip()
        if not text:
            return None

        # Accept exactly one "." followed only by zeros, e.g. "12.0".
        # isdecimal() is used rather than isdigit() because isdigit() also
        # accepts characters such as "²" that int() cannot parse.
        whole, dot, fraction = text.partition('.')
        if dot and '.' not in fraction and fraction.isdecimal() and int(fraction) == 0:
            text = whole

        return int(text) if text.isdecimal() else None

    def _format_reaction_context(self, 
                                compounds_df: pd.DataFrame, 
                                retrieve_k: int, 
                                retrieved_ids: List[str]) -> Tuple[List[Dict], List[str]]:
        """Format reaction context as a list of dictionaries, including compound information in reaction information"""
        # Collect reaction IDs and match types related to each compound
        rxn_sets = []
        match_types = []
        compound_info_map = {}  # Mapping to store compound information
        
        # Define compound fields to include
        compound_columns = ['mol_id', 'iupac', 'smiles', 'match']
        available_compound_cols = [col for col in compound_columns if col in compounds_df.columns]
        
        for _, row in compounds_df.iterrows():
            # Extract and store compound information
            comp_info = {col: str(row[col]) for col in available_compound_cols}
            comp_id = str(row.get('mol_id', 'unknown'))
            compound_info_map[comp_id] = comp_info
            
            # Process reaction IDs
            rxns = str(row.get('relevant_rxn', '')).split(',')
            valid_rxns = set()
            for r in rxns:
                rxn_id = self._safe_convert_to_int(r.strip())
                if rxn_id is not None:
                    valid_rxns.add(str(rxn_id))
            
            if valid_rxns:
                rxn_sets.append(valid_rxns)
                match_types.append(row.get('match', 'similar'))
                    
        if not rxn_sets:
            return [], retrieved_ids
                
        # Get candidate reactions and their match types
        candidate_rxns, candidate_matches = self._collect_reaction_candidates(rxn_sets, match_types, retrieve_k)
        
        # All candidate IDs for recall calculation
        all_candidate_ids = [self._safe_convert_to_int(r) for r in candidate_rxns]
        all_candidate_ids = [str(id_) for id_ in all_candidate_ids if id_ is not None]
        
        # Check if there are exact matches
        has_exact = any(match == "exact" for match in candidate_matches)
        
        # Determine which reaction IDs to use: take only exact matches if available, otherwise take all
        if has_exact:
            selected_rxns = [rxn for rxn, match in zip(candidate_rxns, candidate_matches) if match == "exact"]
        else:
            selected_rxns = candidate_rxns
        
        # Convert to valid IDs and filter
        valid_rxn_ids = [self._safe_convert_to_int(r) for r in selected_rxns]
        valid_rxn_ids = [id_ for id_ in valid_rxn_ids if id_ is not None]

        if not valid_rxn_ids:
            return [], all_candidate_ids

        # Sort reactions by candidate order
        self._kb_reaction['rxn_id_cat'] = pd.Categorical(
            self._kb_reaction['rxn_id'],
            categories=valid_rxn_ids,
            ordered=True
        )

        # Filter and sort
        rxn_mask = self._kb_reaction['rxn_id'].isin(valid_rxn_ids)
        result_rxns = self._kb_reaction[rxn_mask].sort_values(by='rxn_id_cat').drop(
            columns=['source_text', 'source_patent', 'remark', 'rxn_id_cat'],
            errors='ignore'
        )
        
        if result_rxns.empty:
            return [], all_candidate_ids
                
        # Build a list of reaction dictionaries with compound information and match information
        reaction_list = []
        # List of fields to format
        fields_to_format = ['reactants', 'products', 'solvents', 'catalysts']
        
        for _, rxn_row in result_rxns.iterrows():
            rxn_dict = rxn_row.to_dict()
            
            # Format compound lists in specified fields as natural language
            for field in fields_to_format:
                if field in rxn_dict:
                    rxn_dict[field] = self._format_compound_list(str(rxn_dict[field]))
            
            rxn_id = str(rxn_dict['rxn_id'])
            rel_match_types = []
            related_compounds = []
            
            # Collect related compound information
            for i, rxn_set in enumerate(rxn_sets):
                if rxn_id in rxn_set:
                    rel_match_types.append(match_types[i])
                    # Get the corresponding compound ID
                    comp_row = compounds_df.iloc[i]
                    comp_id = str(comp_row.get('mol_id', 'unknown'))
                    # Get information for this compound
                    if comp_id in compound_info_map:
                        related_compounds.append(compound_info_map[comp_id])
            
            # Add compound information and match information
            rxn_dict['compound_info'] = related_compounds
            rxn_dict['match_info'] = ', '.join(set(rel_match_types)) if rel_match_types else "N/A"
            
            reaction_list.append(rxn_dict)
        
        return reaction_list, all_candidate_ids
        
    def _collect_reaction_candidates(self, rxn_sets: List[Set[str]], match_types: List[str], target_count: int) -> Tuple[List[str], List[str]]:
        """Collect reaction candidates and return corresponding match types
        Prioritize building overlapping IDs from exact match reaction sets, and supplement with overlapping IDs from all reaction sets if insufficient (deduplicated)
        """
        num_sets = len(rxn_sets)
        candidates = []
        candidate_matches = []  # Store match types for each candidate
        added = set()
        
        # Separate reaction sets corresponding to exact matches
        exact_rxn_sets = [rxn_sets[i] for i in range(len(rxn_sets)) if match_types[i] == 'exact']
        has_exact_sets = len(exact_rxn_sets) > 0
        
        # Phase 1: Build overlapping IDs from exact match reaction sets
        if has_exact_sets:
            for subset_size in range(len(exact_rxn_sets), 1, -1):
                for combination in itertools.combinations(range(len(exact_rxn_sets)), subset_size):
                    subset_rxn_sets = [exact_rxn_sets[i] for i in combination]
                    overlapping_rxns = reduce(set.intersection, subset_rxn_sets)
                    
                    new_rxns = [rxn for rxn in overlapping_rxns if rxn not in added]
                    if new_rxns:
                        need = min(target_count - len(candidates), len(new_rxns))
                        candidates.extend(new_rxns[:need])
                        candidate_matches.extend(["exact"] * need)
                        added.update(new_rxns[:need])
                        
                        if len(candidates) >= target_count:
                            return candidates[:target_count], candidate_matches[:target_count]
                
                if len(candidates) >= target_count:
                    return candidates[:target_count], candidate_matches[:target_count]
        
        # Calculate how many more are needed
        remaining = target_count - len(candidates)

        if remaining > 0 and len(exact_rxn_sets) == 1:
            candidates.extend(exact_rxn_sets[0])
            candidate_matches.extend(len(exact_rxn_sets[0]) * ['exact'])
            remaining = target_count - len(exact_rxn_sets[0])
        
        # Phase 2: If overlapping IDs from exact matches are insufficient, build overlapping IDs from all reaction sets (deduplicated)
        if remaining > 0:
            for subset_size in range(num_sets, 1, -1):
                for combination in itertools.combinations(range(num_sets), subset_size):
                    subset_rxn_sets = [rxn_sets[i] for i in combination]
                    overlapping_rxns = reduce(set.intersection, subset_rxn_sets)
                    
                    new_rxns = [rxn for rxn in overlapping_rxns if rxn not in added]
                    if new_rxns:
                        need = min(remaining, len(new_rxns))
                        candidates.extend(new_rxns[:need])
                        
                        candidate_matches.extend(["similar"] * need)
                        added.update(new_rxns[:need])
                        remaining -= need
                        
                        if remaining <= 0:
                            return candidates[:target_count], candidate_matches[:target_count]
                
                if remaining <= 0:
                    return candidates[:target_count], candidate_matches[:target_count]
        
        # Phase 3: Process exact match sets, labeled as "similar"
        exact_rxn_sets_all = [rxn_sets[i] for i, match_type in enumerate(match_types) if match_type == 'exact']
        similar_rxn_sets = [rxn_sets[i] for i, match_type in enumerate(match_types) if match_type == 'similar']
        
        if exact_rxn_sets_all:
            all_exact_rxns = [rxn for rxn_set in exact_rxn_sets_all for rxn in rxn_set]
            unique_exact_rxns = []
            seen_exact = set()
            for rxn in all_exact_rxns:
                if rxn not in seen_exact and rxn not in added:
                    seen_exact.add(rxn)
                    unique_exact_rxns.append(rxn)
            
            if remaining > 0 and unique_exact_rxns:
                take = min(remaining, len(unique_exact_rxns))
                candidates.extend(unique_exact_rxns[:take])
                candidate_matches.extend(["similar"] * take)
                added.update(unique_exact_rxns[:take])
                remaining -= take
        
        # Phase 4: Process similar match sets, labeled as "similar"
        if remaining > 0 and similar_rxn_sets:
            all_similar_rxns = [rxn for rxn_set in similar_rxn_sets for rxn in rxn_set]
            unique_similar_rxns = []
            seen_similar = set()
            for rxn in all_similar_rxns:
                if rxn not in seen_similar and rxn not in added:
                    seen_similar.add(rxn)
                    unique_similar_rxns.append(rxn)
            
            if remaining > 0 and unique_similar_rxns:
                take = min(remaining, len(unique_similar_rxns))
                candidates.extend(unique_similar_rxns[:take])
                candidate_matches.extend(["similar"] * take)
                remaining -= take
        
        return candidates[:target_count], candidate_matches[:target_count]
    
    def predict(self, row: pd.Series, retrieve_k: int = 5) -> Dict[str, Any]:
        """Answer the question stored in ``row`` and return the result as a dict.

        The method reads ``rxn_id``, ``mol_id``, ``question``, ``qa_type`` and
        ``input_type`` from ``row``, builds few-shot examples, and delegates
        to ``process_query``. Any exception is caught, printed, and reported
        through the returned dict instead of being raised.

        Note:
            ``row['rxn_id']`` and ``row['mol_id']`` are overwritten in place
            with their normalized string form (or ``None``).

        Args:
            row: Record describing one question.
            retrieve_k: Passed through to ``process_query``.

        Returns:
            A dict with keys ``answer``, ``answer_short``,
            ``reasoning_original``, ``reasoning_retrieved``, ``context``,
            ``retrieved_suffixes``, ``num_few_shot`` and
            ``few_shot_available``. On failure ``answer`` carries the error
            message and ``answer_short`` is ``"Failed Prediction"``.
        """
        try:
            rxn_id = self._safe_convert_to_int(row.get('rxn_id'))
            mol_id = self._safe_convert_to_int(row.get('mol_id'))

            row['rxn_id'] = str(rxn_id) if rxn_id is not None else None
            row['mol_id'] = str(mol_id) if mol_id is not None else None

            question = row.get('question', '')

            few_shot_examples = self._get_few_shot_examples(
                current_question=question,
                current_qa_type=row.get('qa_type', ''),
                current_input_type=row.get('input_type', '')
            )

            final_answer, context, retrieved_ids, reasoning_original, reasoning_retrieved = self.process_query(
                query=question,
                retrieve_k=retrieve_k,
                few_shot_examples=few_shot_examples
            )

            # extract_json returns whatever json.loads produced, which is not
            # a dict when the model replies with a bare number, string or
            # list. Treat that as "no short answer" rather than discarding
            # the whole prediction through an AttributeError on .get().
            parsed = extract_json(final_answer)
            answer_short = parsed.get('answer', '') if isinstance(parsed, dict) else ''

            return {
                "answer": final_answer,
                "answer_short": answer_short,
                "reasoning_original": reasoning_original,
                "reasoning_retrieved": reasoning_retrieved,
                "context": context,
                # str() keeps the join from failing on non-string IDs.
                "retrieved_suffixes": ','.join(map(str, retrieved_ids)),
                "num_few_shot": self._num_few_shot,
                "few_shot_available": len(few_shot_examples) > 0
            }

        except Exception as e:
            # Best-effort boundary: one failing row must not abort the caller.
            error_msg = str(e)
            print(f"Error in rxn_id={row.get('rxn_id')}, mol_id={row.get('mol_id')}: {error_msg}")
            return {
                "answer": f"Prediction error: {error_msg}",
                "answer_short": "Failed Prediction",
                "reasoning_original": "",
                "reasoning_retrieved": "",
                "context": "",
                "retrieved_suffixes": "",
                "num_few_shot": self._num_few_shot,
                "few_shot_available": False
            }

    def _get_few_shot_examples(self, current_question: str, current_qa_type: str, current_input_type: str) -> str:
        """Get few-shot examples"""
        if self._qa_df is None:
            return ""

        candidate_mask = (
            (self._qa_df['qa_type'] == current_qa_type) &
            (self._qa_df['input_type'] == current_input_type) &
            (self._qa_df['question'] != current_question)
        )
        candidates = self._qa_df[candidate_mask].copy()

        if len(candidates) < self._num_few_shot:
            candidate_mask = (
                (self._qa_df['qa_type'] == current_qa_type) &
                (self._qa_df['question'] != current_question)
            )
            candidates = self._qa_df[candidate_mask].copy()

        if len(candidates) < self._num_few_shot:
            candidates = self._qa_df[self._qa_df['question'] != current_question].copy()

        num_select = min(self._num_few_shot, len(candidates))
        if num_select == 0:
            return ""

        selected = candidates.sample(n=num_select, random_state=42)
        examples_str = ""
        for idx, (_, row) in enumerate(selected.iterrows(), 1):
            example_output = json.dumps({
                "answer": row['answer']
            }, ensure_ascii=False)

            examples_str += f"## Example {idx}\n"
            examples_str += f"Query:\n{row['question']}\n\n"
            examples_str += f"Expected Answer (reasoning process omitted):\n{example_output}\n\n"

        return examples_str.strip()

def main() -> None:
    """Run one example query end to end and print the final answer.

    The endpoint, model path and knowledge-base paths below are sample
    values; replace them with real ones before running.
    """
    # Build the RAG system against the configured LLM endpoint.
    system = StructuralRAG(
        batch_similar_top_n=20,
        llm_base_url="http://localhost:10000/v1",
        llm_model_path="/home/share/ckpt/Qwen3-8B",
        llm_api_key="xxxx",
        max_threads=5,
    )

    # Load the compound and reaction tables of the knowledge base.
    system.load_knowledge_base(
        compound_tsv='kb_and_qas/compounds_samples.tsv',
        reaction_tsv='kb_and_qas/reactions_samples.tsv',
    )

    # Only the final answer is printed; the other four results are unused here.
    final_answer, _context, _retrieved_ids, _reasoning_original, _reasoning_retrieved = system.process_query(
        query="Find the SMILES string corresponding to triphenylphosphane.",
        retrieve_k=5,
        few_shot_examples="N/A",
    )

    print("=== Final Answer ===")
    print(final_answer)


if __name__ == '__main__':
    main()
