import os
import re
import json
import pandas as pd
from typing import List, Dict, Any, Set

import openai
from tqdm import tqdm

# Import utility functions
from eval_utils import extract_json


class DirectInference:
    """Direct LLM inference without retrieval augmentation.

    Wraps an OpenAI-compatible chat-completions client. Each question is
    answered from the model's own knowledge, optionally primed with few-shot
    question/answer examples sampled from a reference DataFrame.
    """

    def __init__(self,
                 base_url: str,
                 model_name_or_path: str,
                 api_key: str = "EMPTY",
                 qa_df: pd.DataFrame = None,
                 num_few_shot: int = 3):
        """Initialize the direct inference model.

        Args:
            base_url: Base URL of the OpenAI-compatible endpoint.
            model_name_or_path: Model identifier sent with every request.
            api_key: API key passed to the client.
            qa_df: Pool of examples for few-shot prompting, or None to
                disable few-shot examples. The code reads the columns
                'question', 'answer', 'qa_type' and 'input_type'.
            num_few_shot: Maximum number of examples to include per prompt.
        """
        # LLM client configuration
        self.api_key = api_key
        self.base_url = base_url
        self.model_name_or_path = model_name_or_path
        self.ans_client = openai.OpenAI(api_key=api_key, base_url=base_url)

        # Few-shot configuration
        self.qa_df = qa_df
        self.num_few_shot = num_few_shot

    def get_few_shot_examples(self,
                              current_question: str,
                              current_qa_type: str,
                              current_input_type: str) -> str:
        """Build the few-shot example section for a question.

        Candidates are chosen with a three-level fallback, always excluding
        rows whose question equals ``current_question``:

        1. same ``qa_type`` and same ``input_type``;
        2. same ``qa_type`` only;
        3. any row.

        A level is abandoned (with a printed warning) when it offers fewer
        than ``num_few_shot`` rows. Sampling uses a fixed seed, so the same
        candidate pool always yields the same examples.

        Args:
            current_question: Question being answered (excluded from examples).
            current_qa_type: Value matched against the 'qa_type' column.
            current_input_type: Value matched against the 'input_type' column.

        Returns:
            The formatted example section, or an empty string when no example
            pool is configured, ``num_few_shot`` is not positive, or no
            candidate row exists.
        """
        # A non-positive count means "no examples"; a negative value would
        # otherwise reach DataFrame.sample and raise ValueError.
        if self.qa_df is None or self.num_few_shot <= 0:
            return ""

        different_question = self.qa_df['question'] != current_question
        same_qa_type = self.qa_df['qa_type'] == current_qa_type
        same_input_type = self.qa_df['input_type'] == current_input_type

        # Priority filtering: same type + different question
        candidates = self.qa_df[same_qa_type & same_input_type & different_question]

        # Downgrade filtering: same qa_type
        if len(candidates) < self.num_few_shot:
            candidates = self.qa_df[same_qa_type & different_question]
            print(f"Warning: Insufficient examples for {current_qa_type}+{current_input_type}, downgrading to same {current_qa_type} examples")

        # Final filtering: all different questions
        if len(candidates) < self.num_few_shot:
            candidates = self.qa_df[different_question]
            print(f"Warning: Insufficient examples for {current_qa_type}, using examples of all types")

        num_select = min(self.num_few_shot, len(candidates))
        if num_select == 0:
            return ""

        # Fixed seed keeps the chosen examples reproducible across runs.
        selected = candidates.sample(n=num_select, random_state=42)

        blocks = []
        for idx, (_, row) in enumerate(selected.iterrows(), 1):
            # Examples use the same JSON shape the model is asked to emit;
            # the reasoning is deliberately elided.
            example_output = json.dumps({
                "thinking": "...",
                "answer": row['answer']
            }, ensure_ascii=False)

            blocks.append(
                f"## Example {idx} (Thinking Process Omitted)\n"
                f"Question:\n{row['question']}\n\n"
                f"Output:\n{example_output}\n\n"
            )

        return ("# Few-shot Examples\n" + "".join(blocks)).strip()

    def generate_answer(self, query: str, few_shot_examples: str = "") -> str:
        """Query the LLM and return its raw reply.

        Args:
            query: The user question.
            few_shot_examples: Pre-formatted example section placed before
                the query in the user prompt; may be empty.

        Returns:
            The model's reply with surrounding whitespace stripped, or an
            empty string when the API returns no text content.
        """
        # System prompt (with few-shot examples)
        system_prompt = """/nothink You are an intelligent assistant. Answer query based on the given few-shot examples.
# Constraints
- If you don't know the answer, make the best guess based on your knowledge.
- Output must be JSON with 'thinking' and 'answer', where 'thinking' is your thinking process, and 'answer' should directly answer the given query in one or a few words.
- 'answer' should specify the compound name or form numeric answer. For question requires you to transform something into SMILES, the 'answer' should output SMILES format.
"""

        # Build user prompt
        prompt = f"""
{few_shot_examples}

# User Query
{query}

# Output
"""

        # Call LLM
        response = self.ans_client.chat.completions.create(
            model=self.model_name_or_path,
            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
            top_p=0.4,
            max_tokens=4096,
        )

        # Append token usage to a local log for endpoints whose host
        # contains 'hkust'.
        if 'hkust' in self.ans_client.base_url.host:
            with open('usge.txt', 'a+', encoding='utf-8') as f:
                f.write(str(response.usage) + '\n')

        # `content` may be None (e.g. a refusal); treat that as an empty
        # answer instead of raising AttributeError on `.strip()`.
        content = response.choices[0].message.content
        return (content or "").strip()

    def predict(self, row: pd.Series) -> Dict[str, Any]:
        """Run direct LLM prediction on one data point.

        Args:
            row: Record providing 'question', 'qa_type' and 'input_type'.

        Returns:
            A dict with:
                answer: the raw LLM reply;
                answer_short: the 'answer' field of the JSON extracted from
                    the reply, or the raw reply when extraction does not
                    yield a dict or the dict has no 'answer' key;
                num_few_shot: the configured example count (not necessarily
                    the number actually included);
                few_shot_available: whether any examples were included.
        """
        # Generate few-shot examples
        few_shot_examples = self.get_few_shot_examples(
            current_question=row['question'],
            current_qa_type=row['qa_type'],
            current_input_type=row['input_type']
        )

        # Get LLM answer
        llm_answer = self.generate_answer(
            query=row['question'],
            few_shot_examples=few_shot_examples
        )

        # Extract the short answer, falling back to the raw reply.
        extracted_answer = extract_json(llm_answer)
        if isinstance(extracted_answer, dict):
            answer_short = extracted_answer.get('answer', llm_answer)
        else:
            answer_short = llm_answer

        return {
            "answer": llm_answer,
            "answer_short": answer_short,
            "num_few_shot": self.num_few_shot,
            "few_shot_available": len(few_shot_examples) > 0
        }
