import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from ast import arg, parse
from dis import Instruction
import enum
import json 
from math import fabs
from operator import concat
import random
import math

from vllm import LLM, SamplingParams
import torch
from tqdm import tqdm
import argparse
import re
import time
import datetime
from typing import List, Dict, Optional, final, Union, Any
import requests
import os
import pandas as pd
from python_executor import PythonExecutor
from tools.web_search_main import deep_search

import re

from utils import *
import gradio as gr


def execute_code_batch(
    batch_code: List[str], 
    max_length: Optional[int] = None, 
    max_workers: Optional[int] = 4,
    timeout: Optional[float] = 600,
) -> Dict:
    """
    Execute a batch of code snippets via the remote code-execution API.

    :param batch_code: list of code snippets, each a complete code string
    :param max_length: maximum output length; None means use the API default
    :param max_workers: maximum number of worker threads for batch execution
    :param timeout: request timeout in seconds (None disables it); defaults
        to 600 to match seed_api — the original call had no timeout and
        could block the caller indefinitely
    :return: the parsed JSON response from the API
    :raises:
        - requests.exceptions.RequestException: on request failure or timeout
        - json.JSONDecodeError: when the response is not valid JSON
    """
    # Endpoint of the local code-execution service
    url = "http://0.0.0.0:3999/execute"
    headers = {"Content-Type": "application/json"}

    # Build the request payload
    payload: dict[str, Any] = {"batch_code": batch_code}

    # Optional parameters — omitted entirely to let the API use its defaults
    if max_length is not None:
        payload["max_length"] = max_length

    if max_workers is not None:
        payload["max_workers"] = max_workers

    # A single snippet gains nothing from multiple workers
    if len(batch_code) <= 1:
        payload["max_workers"] = 1

    # Send the request; the timeout prevents an indefinite hang if the
    # executor service is down or wedged
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)

    # Raise for non-2xx status codes
    response.raise_for_status()

    # Parse and return the JSON response
    return response.json()

def seed_api(query, params_config: dict):
    """Send *query* to the local chat-completions endpoint and return the reply.

    Sampling behaviour comes from *params_config* (all keys listed in
    ``sampling_keys`` below must be present).  On any failure an
    ``"Error: ..."`` string is returned instead of raising, so the agent
    loop can continue.
    """

    system = """You are a helpful assistant that can solve the given question step by step with the help of the search tool and python interpreter tool.
Given a question, you need to first think about the reasoning process in the mind and then provide the answer.
During thinking, you can invoke the search tool to search and python interpreter tool to calculate the math problem for fact information about specific topics if needed.
The reasoning process is enclosed within <think> </think>, and the answer is after </think>,
and the search query and result are enclosed within <search> </search> and <result> </result> tags respectively.
For example, <think> This is the reasoning process. </think> <search> search query here </search> <result> search result here </result>
<think> This is the reasoning process. </think> <python> python code here </python> <result> python interpreter result here </result>
<think> This is the reasoning process. </think> The final answer is \\[ \\boxed{answer here} \\]
In the last part of the answer, the final exact answer is enclosed within \\boxed{} with latex format."""

    print("query: ", query)
    print("*" * 100)

    # Request body: fixed fields plus the sampling parameters copied
    # one-for-one from params_config.
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": query},
    ]
    sampling_keys = (
        "temperature", "top_p", "max_tokens", "top_k", "n", "stop",
        "presence_penalty", "repetition_penalty",
        "include_stop_str_in_output",
    )
    content = {"model": "xxx", "messages": messages, "stream": False}
    for key in sampling_keys:
        content[key] = params_config[key]

    try:
        response = requests.post(
            "http://0.0.0.0:6537/v1/chat/completions",
            json=content,
            timeout=600,
        )
        response.raise_for_status()
        data = response.json()
        print("data: ", data)
        print("*" * 100)
        return data["choices"][0]["message"]["content"]
    except Exception as e:
        # Best-effort fallback: surface the failure as a plain string
        print(f"API调用失败: {e}")
        return f"Error: {str(e)}"

def process_question(question: str, params_config: dict) -> tuple[str, str]:
    """
    Run the tool-using agent loop for a single question.

    The model is prompted repeatedly.  Each generated chunk is classified by
    the stop tag it ended on: a </python> chunk is executed by the remote
    code executor, a </search> chunk is answered via deep_search, and the
    tool output is appended to the prompt inside <result> tags before the
    next round.  Each tool may fire at most 3 times.

    Args:
        question: the input question
        params_config: sampling config forwarded to seed_api; note that its
            'stop' entry is overwritten inside this function

    Returns:
        tuple[str, str]: (full generated trajectory, extracted final answer)
    """
    
    
    # Initialize loop state (list-based layout kept from a batched version;
    # only index 0 is ever used here)
    generating = [0]  # only one sample
    completed = []
    python_rounds = [0]
    search_rounds = [0]
    max_python_times = 3
    max_search_times = 3
    question = question + "\nPlease give the final answer in the format of \\boxed{}."
    ori_question = question

    while generating:
        # Stop generation at any tool-closing tag so the tool can run first
        params_config['stop'] = ['</python>', '</search>', '</answer>']
        output = seed_api(
            question,
            params_config,
        )

        python_indices = []
        search_indices = []
        other_indices = []
        text_generating_indices = []
        
      
        # Classify the chunk; once a tool's budget is exhausted, the chunk is
        # treated as plain text so the model finishes without that tool
        if output.strip().endswith('</python>'):
            if python_rounds[0] >= max_python_times:
                text_generating_indices.append((0, output))
            else:
                python_indices.append((0, output))
                python_rounds[0] += 1
        elif output.strip().endswith('</search>'):
            if search_rounds[0] >= max_search_times:
                text_generating_indices.append((0, output))
            else:
                search_indices.append((0, output))
                search_rounds[0] += 1
        else:
            other_indices.append((0, output))
        
        # Execute the <python> code via the remote executor
        if python_indices:
            print('python begin')
            python_contents = []
            for i, content in python_indices:
                python_contents.append(content)
                question += content
            python_contents = [extract_python_content(content) for content in python_contents]
            for i, (idx, content) in enumerate(python_indices):
                # NOTE(review): the [0] index assumes the /execute endpoint
                # returns a list of per-snippet result dicts, although
                # execute_code_batch documents a JSON dict — confirm against
                # the service.
                data = execute_code_batch([python_contents[i]])[0]
                result = data['result']
                report = data['report']
                if report == "Done":
                    question += f'<result>\n{result}\n</result>'
                    print(result)
                else:
                    # On failure, feed the error report back to the model
                    question += f'<result>\n{report}\n</result>'
                    print(report)
            print('python end')

        # Run the <search> query via deep_search
        if search_indices:
            print('search begin')
            search_contents = []
            for i, content in search_indices:
                search_contents.append(
                    content
                )
                question += content
            search_contents = [extract_search_content(content) for content in search_contents]

            for i, (idx, content) in enumerate(search_indices):
                try:
                    search_result = deep_search(search_contents[i])
                    question += f'<result>\n{search_result}\n</result>'
                except Exception as e:
                    # Best-effort: an empty <result> keeps the trajectory valid
                    print(f"search error: {e}")
                    question += f'<result>\n\n</result>'
            print('search end')


        # Tool budget exhausted: let the model run to completion with only
        # </answer> as a stop string, then mark the sample finished
        if text_generating_indices:
            generate_results = []
            for i, content in text_generating_indices:
                generate_results.append(
                    question + content
                )
                question += content
            params_config['stop'] = ['</answer>']
            output_texts = seed_api(
                question,
                params_config,
            )
            question += output_texts
            completed.append(0)
        
        # Chunk ended without a tool tag: the sample is finished
        if other_indices:
            for i, content in other_indices:
                question += content
                completed.append(0)
        
        generating = [i for i in generating if i not in completed]
    
    # Extract the answer from the generated trajectory (everything after
    # the original prompt)
    text = question[len(ori_question):]
    last_answer_end = text.rfind('</answer>')
    if last_answer_end != -1:
        # Find the corresponding opening tag before this closing tag
        temp_text = text[:last_answer_end]
        last_answer_start = temp_text.rfind('<answer>')
        if last_answer_start != -1:
            temp_answer = text[last_answer_start + len('<answer>'):last_answer_end]
        else:
            temp_answer = None
    else:
        temp_answer = None
    if temp_answer:
        # Prefer the last \boxed{...} inside the <answer> block, falling back
        # to the raw <answer> content when no box is found
        boxed_answer = temp_answer.strip()
        boxed_answer = last_boxed_only_string(boxed_answer)
        if boxed_answer and boxed_answer.startswith("\\boxed{") and boxed_answer.endswith("}"):
            boxed_content = boxed_answer[7:-1]  # Extract content between \\boxed{ and }
            boxed_answer = boxed_content
        if not boxed_answer:
            final_answer = temp_answer
        else:
            final_answer = boxed_answer
    else:
        # No <answer> tags at all: look for the last \boxed{...} in the text
        boxed_answer = text.strip()
        final_answer = last_boxed_only_string(boxed_answer)
        if final_answer and final_answer.startswith("\\boxed{") and final_answer.endswith("}"):
            final_answer = final_answer[7:-1]  # Extract content between \\boxed{ and }
    if type(final_answer) == str:
        # Strip SMILES markers (chemistry answers) from the final string
        final_answer = final_answer.replace("<SMILES>", "").replace("</SMILES>", "")
    else:
        # last_boxed_only_string returned None — no extractable answer
        final_answer = "None"

    return text, final_answer

    

if __name__ == "__main__":

    # Sampling configuration forwarded to seed_api.  seed_api reads every
    # key below unconditionally, so all of them must be present.
    params_config = {
        'temperature': 0.4,
        'max_tokens': 15000,
        'top_p': 0.95,
        'top_k': -1,
        'presence_penalty': 1.0,
        # seed_api reads 'repetition_penalty' (it was missing here, causing a
        # KeyError on every request); 1.0 is the neutral/no-op value.
        'repetition_penalty': 1.0,
        'n': 1,
        'stop': ['</python>', '</search>', '</answer>'],
        'include_stop_str_in_output': True,
    }

    def gradio_answer(question):
        """Gradio callback: run the agent and return only the final answer.

        process_question returns (trajectory, answer); the UI binds a single
        output component, so returning the 2-tuple directly would make Gradio
        treat it as two output values and fail.
        """
        _trajectory, final_answer = process_question(question, params_config)
        return final_answer

    with gr.Blocks() as demo:
        gr.Markdown("# 单问题推理界面")
        with gr.Row():
            question = gr.Textbox(label="请输入问题", lines=3)
        with gr.Row():
            answer = gr.Textbox(label="模型答案", lines=12)
        submit_btn = gr.Button("提交")
        submit_btn.click(fn=gradio_answer, inputs=question, outputs=answer)

    # Bind on all interfaces so the demo is reachable from other hosts
    demo.launch(server_name="0.0.0.0", server_port=7888)
