import os
import re
import json
import torch
import openai
from openai import OpenAI
from prompts import *
from typing import List, Dict, Optional, Any
import requests
import time
from urllib.parse import urlparse
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer


def get_response_from_llm(
        messages: List[Dict[str, Any]],
        client: "openai.OpenAI",
        model: str,
        enable_thinking: bool,
        max_tokens: int = 2048,
        temperature: Optional[float] = 0.6,
        top_p: Optional[float] = 0.7,
        stop: Optional[str] = None,
        include_stopword: bool = True,
        depth: int = 0,
        n: int = 1,
        repetition_penalty: Optional[float] = None,
        retry_delay: float = 1.0,
):
    """Query an OpenAI-compatible chat endpoint, retrying on any error.

    The request always carries ``top_k=20``, the ``enable_thinking`` chat
    template kwarg and ``include_stop_str_in_output`` through ``extra_body``.
    These are server-side extensions; presumably the endpoint is a vLLM
    server -- confirm against deployment.

    Args:
        messages: Chat messages in OpenAI format.
        client: Client whose ``chat.completions.create`` is called.
        model: Model name passed to the endpoint.
        enable_thinking: Forwarded as ``chat_template_kwargs.enable_thinking``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Accepted for API compatibility but currently NOT sent; the
            request always uses ``top_p=0.95``.
        stop: Stop sequence(s) forwarded as-is.
        include_stopword: Forwarded as ``include_stop_str_in_output``.
        depth: Retry counter to start from. Retries continue while it is
            below 512, so at most ``513 - depth`` attempts are made.
        n: Number of completions to request.
        repetition_penalty: When not ``None``, forwarded in ``extra_body``.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The stripped message content (``str``) when ``n == 1``, a list of
        stripped contents when ``n > 1``, and ``None`` when ``n < 1``.

    Raises:
        Exception: The last error seen once the retry budget is exhausted.
    """
    max_depth = 512
    extra_body = {
        "top_k": 20,
        "chat_template_kwargs": {"enable_thinking": enable_thinking},
        "include_stop_str_in_output": include_stopword,
    }
    if repetition_penalty is not None:
        extra_body["repetition_penalty"] = repetition_penalty

    # Iterative retry loop (the previous version recursed up to 512 frames
    # deep and crashed with NameError on the first retry).
    while True:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                stop=stop,
                max_tokens=max_tokens,
                temperature=temperature,
                # NOTE(review): the ``top_p`` argument is ignored; the fixed
                # 0.95 is kept so existing sampling behaviour is unchanged.
                top_p=0.95,
                extra_body=extra_body,
                n=n,
            )
            # Parsing stays inside ``try`` so a missing/None content is
            # retried exactly like a transport error.
            if n == 1:
                return response.choices[0].message.content.strip()
            if n > 1:
                return [choice.message.content.strip() for choice in response.choices]
            return None
        except Exception as e:
            print(f"LLM API error: {e}")
            if depth >= max_depth:
                raise
            depth += 1
            time.sleep(retry_delay)


def _last_braced_group(text, prefix):
    """Return the content of the last ``prefix{...}`` group whose braces balance.

    ``prefix`` is expected to end with ``{``. Nested braces inside the group
    are allowed (e.g. ``\\frac{1}{2}``); groups that never close are ignored.
    Returns ``None`` when no closed group exists.
    """
    found = None
    start = text.find(prefix)
    while start != -1:
        body_start = start + len(prefix)
        level = 1
        pos = body_start
        while pos < len(text) and level:
            ch = text[pos]
            if ch == "{":
                level += 1
            elif ch == "}":
                level -= 1
            pos += 1
        if level == 0:
            found = text[body_start:pos - 1]
        start = text.find(prefix, start + len(prefix))
    return found


def extract_answer(output, mode='gen'):
    """Extract the final answer from a model output.

    Modes:
        - ``'codegen'``: content of the last fenced python code block.
        - ``'infogen'``: text after ``**Final Information**`` (newlines
          removed) or ``**Modified Reasoning Steps**``, otherwise the
          placeholder ``"No helpful information found."``.
        - anything else (``'gen'``, ``'choose'``, ``'qa'``, ...): content of
          the last brace-balanced ``\\boxed{...}``. For ``'choose'``/``'qa'``
          an inner ``\\text{...}`` is unwrapped and surrounding parentheses
          are stripped.

    Returns:
        The extracted string, or ``''`` when nothing matches (non-infogen).
    """
    extracted_text = ''
    if mode == 'codegen':
        # Extract the code between ```python and ```
        pattern = r'```python\s*(.*?)\s*```'
        matches = re.findall(pattern, output, re.DOTALL | re.IGNORECASE)
        if matches:
            extracted_text = matches[-1].strip()  # Take the last match
    elif mode == 'infogen':
        # Extract content after **Final Information** or **Modified Reasoning Steps**
        pattern_info = "**Final Information**"
        pattern_step = "**Modified Reasoning Steps**"
        if pattern_info in output:
            # NOTE(review): newlines are removed without a separator, so
            # adjacent lines are glued together; kept for compatibility.
            extracted_text = output.split(pattern_info)[-1].replace("\n","").strip("```").strip()
        elif pattern_step in output:
            extracted_text = output.split(pattern_step)[-1].strip("```").strip()
        else:
            extracted_text = "No helpful information found."
    else:
        # A greedy regex would run to the last '}' on the line and merge
        # several \boxed{} groups; scan with brace balancing instead.
        boxed = _last_braced_group(output, "\\boxed{")
        if boxed is not None:
            extracted_text = boxed
            if mode in ['choose', 'qa']:
                inner = _last_braced_group(extracted_text, "\\text{")
                if inner is not None:
                    extracted_text = inner
                extracted_text = extracted_text.strip("()")
    return extracted_text


# Function to extract text between two tags
def extract_between(text: str, begin_search_query_token: str, end_search_query_token: str) -> Optional[List[str]]:
    """Extract the items enclosed between a begin and an end token.

    The enclosed text is split into one item per line; if it is a single
    line containing ``;`` it is split on ``;`` instead. Items are stripped
    and empty ones are dropped.

    Returns:
        The list of items, or ``None`` when the begin token occurs more than
        once or no begin/end pair is found.
    """
    # Escape every regex metacharacter, not just '|', so tokens such as
    # "[q]" or "<search.query>" are matched literally.
    begin_pattern = re.escape(begin_search_query_token)
    pattern = begin_pattern + r"([\s\S]*?)" + re.escape(end_search_query_token)

    # Ambiguous input: more than one begin token.
    if len(re.findall(begin_pattern, text)) > 1:
        return None
    matches = re.findall(pattern, text)
    if not matches:
        return None

    # Split the first enclosed span into items.
    enclosed = matches[0].strip()
    items = [line.strip().strip(";") for line in enclosed.split("\n")]
    if len(items) == 1 and ";" in items[0]:
        items = [part.strip() for part in items[0].split(";")]

    # Drop empty items.
    return [item for item in items if item]


def replace_recent_steps(origin_str, replace_str):
    """
    Merge replacement reasoning steps into an original list of steps.

    Both inputs are parsed into ``{step_number: content}`` maps, where a step
    starts at a line beginning with ``Step <n>:``. Text before the first step
    header is ignored. A replacement whose content contains
    "DELETE THIS STEP" removes that step; any other replacement overwrites
    (or adds) the step.

    Parameters:
    - origin_str (str): The original reasoning steps.
    - replace_str (str): The steps to replace or delete.

    Returns:
    - str: The contents of the surviving steps, ordered by step number and
      joined by blank lines (the "Step <n>:" labels are not re-emitted).
    """
    header_re = re.compile(r"Step\s+(\d+):\s*")

    def parse_steps(text):
        """Map each step number in ``text`` to its stripped multi-line content."""
        parsed = {}
        step_no = None
        buffer = []
        for row in text.splitlines():
            header = header_re.match(row)
            if header is None:
                # Continuation line: only kept once a step has started.
                if step_no is not None:
                    buffer.append(row)
                continue
            # A new header closes the step collected so far.
            if step_no is not None:
                parsed[step_no] = "\n".join(buffer).strip()
            step_no = int(header.group(1))
            remainder = row[header.end():].strip()
            buffer = [remainder] if remainder else []
        # Flush the final step.
        if step_no is not None:
            parsed[step_no] = "\n".join(buffer).strip()
        return parsed

    merged = parse_steps(origin_str)
    for step_no, body in parse_steps(replace_str).items():
        if "DELETE THIS STEP" in body:
            merged.pop(step_no, None)
        else:
            merged[step_no] = body

    return "\n\n".join(body for _, body in sorted(merged.items()))

def get_qa_pairs(start_idx=0, end_idx=1000,
                 data_path="reasoning_rag/parallel_dataset/datasets/nq_train.json"):
    """Load question/answer pairs from a JSON dataset file.

    Each record is expected to look like
    ``{"prompt": [{"content": <question>}, ...], "reward_model": {"ground_truth": <answer>}}``.
    Missing, empty or ``null`` fields yield ``None`` instead of raising.

    Args:
        start_idx: First record index (slice semantics).
        end_idx: End record index, exclusive (slice semantics).
        data_path: Path of the JSON file (defaults to the NQ training set).

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts for
        ``records[start_idx:end_idx]``.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    qa_pairs = []
    # Slice first so only the requested records are converted.
    for item in data[start_idx:end_idx]:
        prompt = item.get("prompt") or [{}]
        reward_model = item.get("reward_model") or {}
        qa_pairs.append({
            "question": prompt[0].get("content"),
            "answer": reward_model.get("ground_truth"),
        })
    return qa_pairs

def call_local_api(prompt, model="Qwen/Qwen3-235B-A22B", max_tokens=1024, temperature=0.7):
    """Send a single user prompt to the local OpenAI-compatible server.

    Thinking mode is disabled via ``chat_template_kwargs``; ``top_p`` is
    fixed at 0.95 and ``top_k`` at 20. A new client is created per call.

    Returns:
        The raw (unstripped) content of the first choice.
    """
    local_client = OpenAI(api_key="EMPTY", base_url="http://0.0.0.0:8000/v1")
    request = dict(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=0.95,
        extra_body={
            "top_k": 20,
            "chat_template_kwargs": {"enable_thinking": False},
        },
    )
    reply = local_client.chat.completions.create(**request)
    return reply.choices[0].message.content

def call_openrouter_api(prompt, model="qwen/qwen3-235b-a22b:free", max_tokens=1024, temperature=0.7):
    """Send a single user prompt to OpenRouter and return the stripped reply.

    The API key is read from the ``OPENROUTER_API_KEY`` environment variable.
    ``max_tokens`` and ``temperature`` are forwarded to the request
    (previously they were accepted but silently ignored).
    """
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.getenv("OPENROUTER_API_KEY"),
    )
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return completion.choices[0].message.content.strip()

def call_openai_api(prompt, max_tokens=150, temperature=0.7, model="gpt-4.5-preview-2025-02-27"):
    """Send a single user prompt to the OpenAI API via the module-level client.

    Sets ``openai.api_key`` from ``OPENAI_API_KEY`` on every call. Any error
    (including a missing message content) is printed and reported as ``None``.

    Returns:
        The stripped reply text, or ``None`` on failure.
    """
    openai.api_key = os.getenv("OPENAI_API_KEY")
    request_messages = [{"role": "user", "content": prompt}]
    try:
        completion = openai.chat.completions.create(
            model=model,
            messages=request_messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        answer = completion.choices[0].message.content.strip()
    except Exception as err:
        print(f"OpenAI API error: {err}")
        return None
    return answer

def _step_rank(step_name):
    """Return the numeric prefix of a step name (``"3.5_x"`` -> 3.5), or None."""
    try:
        return float(step_name.split("_", 1)[0])
    except ValueError:
        return None


def skip_steps(result, step, msg):
    """Mark ``step`` and every later pipeline step as skipped and record the error.

    Steps are ordered by their numeric prefix, so ``3.5_...`` sits between
    ``3_...`` and ``4_...``. A plain string comparison would place
    ``"3.5_..."`` before ``"3_..."``. If ``step`` has no numeric prefix, the
    original lexicographic comparison is used as a fallback.

    Args:
        result: Dict with ``result["steps"][name]`` sub-dicts for every
            pipeline step and for ``step`` itself (mutated in place).
        step: Name of the step that failed.
        msg: Error message stored under ``result["steps"][step]["error"]``.

    Returns:
        The same ``result`` dict, with ``overall_status`` set to
        ``"failed_at_<step>"``.
    """
    failed_rank = _step_rank(step)
    for s in ["2_related_queries", "3_statements", "3.5_statement_verification", "4_complex_question", "5_integration"]:
        s_rank = _step_rank(s)
        if failed_rank is not None and s_rank is not None:
            at_or_after = s_rank >= failed_rank
        else:
            at_or_after = s >= step
        if at_or_after:
            result["steps"][s]["status"] = "skipped"
    result["overall_status"] = f"failed_at_{step}"
    result["steps"][step]["error"] = msg
    return result

        
class Generator:
    """Batch text generation with a local vLLM engine.

    The engine is sharded across every visible CUDA device
    (``tensor_parallel_size=torch.cuda.device_count()``).
    """

    def __init__(self, model_id_or_path="Qwen/Qwen3-32B"):
        """Load the vLLM engine and the matching Hugging Face tokenizer."""
        self.llm = LLM(
            model=model_id_or_path, 
            enforce_eager=True,
            tensor_parallel_size=torch.cuda.device_count(),
            gpu_memory_utilization=0.9,
            max_num_seqs=32, # Reduce this due to vllm prefill
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)

    def apply_chat_template(self, prompt, enable_thinking=True):
        """Render ``prompt`` as a single user turn plus the generation prompt (untokenized)."""
        return self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}], 
            tokenize=False, 
            add_generation_prompt=True,
            enable_thinking=enable_thinking
        )

    def _build_params(self, max_tokens, sampling_params=None, n=1, stop=None,
                      repetition_penalty=None, enable_thinking=True):
        """Assemble the keyword dict for ``SamplingParams`` without mutating inputs.

        A truthy ``sampling_params`` dict is copied (the caller's dict is not
        modified). Otherwise defaults are chosen by ``enable_thinking``. Then
        ``max_tokens``, ``stop`` (str, list or tuple) and
        ``repetition_penalty`` are layered on top.
        """
        if sampling_params:
            # Copy: the old code stored and mutated the caller's dict, leaking
            # max_tokens/stop into the caller and into later calls.
            params = dict(sampling_params)
        else:
            params = {
                "temperature": 0.6 if enable_thinking else 0.7,
                "top_p": 0.95 if enable_thinking else 0.8,
                "top_k": 20,
                "min_p": 0.0,
                "n": n
            }
        print("Generator params:", params)
        params["max_tokens"] = max_tokens
        if stop is not None:
            if isinstance(stop, str):
                params["stop"] = [stop]
            elif isinstance(stop, (list, tuple)):
                params["stop"] = list(stop)
            params["include_stop_str_in_output"] = True
        if repetition_penalty is not None:
            params["repetition_penalty"] = repetition_penalty
        return params

    def generate(self, prompts, max_tokens, sampling_params=None, n=1, stop=None, repetition_penalty=None, apply_chat=True, enable_thinking=True):
        """Generate completions for a batch of prompts.

        Args:
            prompts: List of prompt strings.
            max_tokens: Maximum tokens per completion.
            sampling_params: Optional dict of ``SamplingParams`` kwargs that
                replaces the defaults (it is copied, not mutated).
            n: Samples per prompt when defaults are used.
            stop: Stop string or list of strings; stop text is kept in output.
            repetition_penalty: Optional repetition penalty.
            apply_chat: Wrap each prompt with the tokenizer's chat template.
            enable_thinking: Passed to the chat template and picks defaults.

        Returns:
            A list of ``(text, num_generated_tokens)`` tuples, one per
            prompt. NOTE(review): only the first sample is returned even
            when ``n > 1``.
        """
        # Kept on the instance for callers that inspect ``self.params``.
        self.params = self._build_params(
            max_tokens,
            sampling_params=sampling_params,
            n=n,
            stop=stop,
            repetition_penalty=repetition_penalty,
            enable_thinking=enable_thinking,
        )

        if apply_chat:
            prompts = [self.apply_chat_template(p, enable_thinking) for p in prompts]

        responses = self.llm.generate(
            prompts, 
            SamplingParams(**self.params),
        )
        return [(response.outputs[0].text, len(response.outputs[0].token_ids)) for response in responses]

