import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from ast import arg, parse
from dis import Instruction
import enum
import json 
from math import fabs
from operator import concat
import random
import math
import time
from vllm import LLM, SamplingParams
import torch
from tqdm import tqdm
import argparse
import re
import time
import datetime
from typing import List, Dict, Optional, final, Union
import requests
import os
import pandas as pd
from python_executor import PythonExecutor
from tools.web_search_main import deep_search
from tools.debug_code import debug_code_function
from tools.rollback_code import rollback
from tools.refine_code import refine

import re

from utils import *
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed

# Default sampling parameters sent to the model server with every request.
default_params_config = {
    'temperature': 1.0,          # sampling temperature
    'max_tokens': 15000,         # per-call generation length cap
    'top_p': 0.9,                # nucleus-sampling threshold
    'top_k': -1,                 # -1 disables top-k filtering in vLLM
    'presence_penalty': 1.0,
    'repetition_penalty': 1.05,
    'n': 1,                      # one completion per request
    # Stop at tool-call / answer closing tags — presumably so the serving side
    # can execute the tool and resume generation; confirm against the server.
    'stop': ['</python>', '</search>', '</answer>'],
    'include_stop_str_in_output': True,  # keep the stop tag in the returned text
}


def seed_api(idx, query):
    """Send *query* to the local answer API and extract the boxed final answer.

    Args:
        idx: Caller-side index of the query; echoed back unchanged so results
            can be re-ordered after concurrent execution.
        query: Question/prompt text forwarded to the model server.

    Returns:
        Tuple ``(idx, answer, final_answer)``. ``answer`` is the raw model
        output; ``final_answer`` is the content of the last ``\\boxed{...}``
        after the final ``</think>`` tag, or the string ``"None"`` when no such
        expression exists. On any failure both answer fields carry the error
        message instead of raising.
    """
    try:
        response = requests.post(
            "http://0.0.0.0:8008/v1/chat/completions",
            json={"question": query, "default_params_config": default_params_config},
            timeout=600,
        )
        data = response.json()
        ans = data["answer"]
        # NOTE(review): the original also read data["final_answer"] here, but the
        # value was unconditionally overwritten below (dead store), so the read
        # has been dropped. If the intent was to fall back to the server-supplied
        # final_answer when no \boxed{} is found, restore it in the else paths.

        final_answer = "None"
        think_end = ans.rfind("</think>")
        if think_end != -1:
            post_think = ans[think_end + len("</think>"):]
            boxed = last_boxed_only_string(post_think)
            if boxed:
                # Strip the leading "\boxed{" (7 chars) and the trailing "}".
                final_answer = boxed[7:-1]
        return idx, ans, final_answer
    except Exception as e:
        print(f"API调用失败: {e}")
        return idx, f"Error: {str(e)}", f"Error: {str(e)}"

class Inference():
    """Batch evaluation driver that fans questions out to the local answer API.

    Loads questions from CSV (TOMG-Bench style) or JSONL (math benchmarks),
    sends them to ``seed_api`` in concurrent batches, persists every raw
    response to disk as it arrives, and collects the extracted final answers
    back in the original input order.
    """

    def __init__(self, batch_size=20, prompt_type='code_search', use_debug=False, use_rollback=False, use_refiner=False):
        # Hard upper bound on how many questions a single run() processes.
        self.counts = 100
        self.prompt_type = prompt_type
        self.batch_size = batch_size
        self.use_debug = use_debug
        self.use_rollback = use_rollback
        self.use_refiner = use_refiner
        self.prompt_template = ''
        # Per-question tool-invocation budgets — presumably consumed by the
        # serving side; they are not read in this file.
        self.max_python_times = 3
        self.max_search_times = 3
        self.max_debug_times = 1
        self.max_refine_times = 1
        self.max_rollback_times = 1
        self.questions = []
        self.answers = []

        # Pick the system prompt matching the tool set for this run.
        if self.prompt_type == 'code_search':
            self.prompt_template = """You are a helpful assistant that can solve the given question step by step with the help of the search tool and python interpreter tool.
Given a question, you need to first think about the reasoning process in the mind and then provide the answer.
During thinking, you can invoke the search tool to search and python interpreter tool to calculate the math problem for fact information about specific topics if needed.
The reasoning process is enclosed within <think> </think>, and the answer is after </think>,
and the search query and result are enclosed within <search> </search> and <result> </result> tags respectively.
For example, <think> This is the reasoning process. </think> <search> search query here </search> <result> search result here </result>
<think> This is the reasoning process. </think> <python> python code here </python> <result> python interpreter result here </result>
<think> This is the reasoning process. </think> The final answer is \\[ \\boxed{answer here} \\]
In the last part of the answer, the final exact answer is enclosed within \\boxed{} with latex format."""
        elif self.prompt_type == 'search':
            self.prompt_template = """
You are a helpful assistant that can solve the given question step by step with the help of the wikipedia search tool. \
Given a question, you need to first think about the reasoning process in the mind and then provide the answer. \
During thinking, you can invoke the wikipedia search tool to search for fact information about specific topics if needed. \
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags respectively, \
and the search query and result are enclosed within <search> </search> and <result> </result> tags respectively. \
For example, <think> This is the reasoning process. </think> <search> search query here </search> <result> search result here </result> \
<think> This is the reasoning process. </think> <answer> The final answer is \\[ \\boxed{answer here} \\] </answer>. \
In the last part of the answer, the final exact answer is enclosed within \\boxed{} with latex format.
"""
        elif self.prompt_type == 'math':
            self.prompt_template = """
You are a helpful assistant that can solve the given question step by step with the help of the python interpreter tool. \
Given a question, you need to first think about the reasoning process in the mind and then provide the answer. \
During thinking, you can invoke the python interpreter tool to calculate the math problem for fact information about specific topics if needed. \
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags respectively. \
For example, <think> This is the reasoning process. </think> <python> python code here </python> <result> python interpreter result here </result> \
<think> This is the reasoning process. </think> <answer> The final answer is \\[ \\boxed{answer here} \\] </answer>. \
In the last part of the answer, the final exact answer is enclosed within \\boxed{} with latex format.
"""

    def load_csvs(self, csv_file_path):
        """Load up to ``self.counts`` molecule-generation questions from a CSV.

        Each row's ``Instruction`` column is substituted into the chemist
        prompt template. Sets ``self.dataset_name`` from the CSV's parent
        directory name and returns the list of prompts.
        """
        pre_instruction = """You are working as an assistant of a chemist user. Please follow the instruction of the chemist and generate a molecule that satisfies the requirements of the chemist user. You could think step by step, but your final response should be a SMILES string. 

Final Result Format:
- Place the final calculation or derived answer within the symbol \\boxed{ SMILES string }.

Questions:
{question}
"""
        questions = []
        df = pd.read_csv(csv_file_path)
        for index, row in df.iterrows():
            # Was a hard-coded 100; self.counts is the same default but
            # configurable.
            if index < self.counts:
                questions.append(pre_instruction.replace('{question}', row['Instruction']))
        self.dataset_name = csv_file_path.split('/')[-2]
        return questions

    def load_json(self, data_path):
        """Load a JSONL benchmark into ``self.questions`` / ``self.answers``.

        Each line must carry an ``answer`` plus either a ``question`` or a
        ``problem`` field (benchmarks disagree on the name). Sets
        ``self.dataset_name`` from the file's parent directory name.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                if "question" in data:
                    question = data['question']
                else:
                    question = data['problem']
                self.questions.append(question)
                self.answers.append(data['answer'])
        self.dataset_name = data_path.split('/')[-2]

    def _process_batches(self, questions, write_record):
        """Dispatch *questions* to ``seed_api`` in concurrent batches.

        Calls ``write_record(idx, query, result, final_answer)`` for every
        completed request (including failures), then returns the final
        answers sorted back into input order.
        """
        num_batches = math.ceil(len(questions) / self.batch_size)
        print(f"dataset {self.dataset_name} all counts: {len(questions)}, batch size: {self.batch_size}, batch counts: {num_batches}")
        extracted_answers = []
        for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
            start_idx = batch_idx * self.batch_size
            batch_samples = questions[start_idx:start_idx + self.batch_size]
            with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
                # Map each future to its (global index, query) so a failed
                # future can still be attributed to the right question. The
                # original read ``idx``/``result`` in the except branch, where
                # they were unbound (NameError) or stale from a prior loop turn.
                future_to_item = {
                    executor.submit(seed_api, start_idx + offset, query): (start_idx + offset, query)
                    for offset, query in enumerate(batch_samples)
                }
                for future in as_completed(future_to_item):
                    idx, query = future_to_item[future]
                    try:
                        idx, result, final_answer = future.result()
                    except Exception as e:
                        print(f"请求'{query}'时发生异常: {e}")
                        result = f"Error: {str(e)}"
                        final_answer = "Error"
                    extracted_answers.append((idx, final_answer))
                    write_record(idx, query, result, final_answer)
        # Restore input order before returning only the answers.
        extracted_answers.sort(key=lambda x: x[0])
        return [answer for _, answer in extracted_answers]

    def math_run(self, data_path):
        """Run the math benchmark at *data_path*; return answers in input order.

        Every API response is appended to
        ``evaluation/math_result/<dataset>.jsonl`` as it arrives, so partial
        progress survives a crash.
        """
        self.load_json(data_path)
        questions = self.questions
        answers = self.answers
        folder_dir = "evaluation/math_result/"
        os.makedirs(folder_dir, exist_ok=True)
        save_path = folder_dir + self.dataset_name + ".jsonl"

        def write_record(idx, query, result, final_answer):
            with open(save_path, "a", encoding="utf-8") as f:
                f.write(json.dumps({"idx": idx, "query": query, "real_answer": answers[idx], "final_answer": final_answer, "answer": result}, ensure_ascii=False) + "\n")

        return self._process_batches(questions, write_record)

    def run(self, csv_file_path):
        """Run the CSV benchmark at *csv_file_path*.

        Writes raw responses to a fixed local jsonl log as they arrive and the
        ordered final answers to ``<result_dir>/<dataset>.csv``.
        """
        questions = self.load_csvs(csv_file_path)
        questions = questions[:min(len(questions), self.counts)]
        folder_dir = "evaluation/eval_result_qwen3-8b-tools-grpo-ckp140-answer-temp10_tp09_rp105/"
        os.makedirs(folder_dir, exist_ok=True)
        save_path = folder_dir + self.dataset_name + ".csv"
        # Truncate and write the CSV header; answers are appended at the end.
        with open(save_path, "w+") as f:
            f.write("outputs\n")

        def write_record(idx, query, result, final_answer):
            with open("./qwen3-8b-tools-grpo-ckp140-answer-temp10_tp09_rp105.jsonl", "a", encoding="utf-8") as f:
                f.write(json.dumps({"idx": idx, "query": query, "answer": result, "final_answer": final_answer}, ensure_ascii=False) + "\n")

        extracted_answers = self._process_batches(questions, write_record)
        df = pd.DataFrame(extracted_answers, columns=["outputs"])
        df.to_csv(save_path, mode='a', header=False, index=True)

if __name__ == "__main__":

    # Run every TOMG-Bench open-generation split through one shared driver
    # and report the total wall-clock time at the end.
    runner = Inference()
    started_at = time.time()

    benchmark_csv_paths = [
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolCustom/AtomNum/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolCustom/BondNum/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolCustom/FunctionalGroup/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolEdit/AddComponent/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolEdit/DelComponent/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolEdit/SubComponent/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolOpt/LogP/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolOpt/MR/test.csv",
        "tomg_bench-4cham/datasets/TOMG-Bench/benchmarks/open_generation/MolOpt/QED/test.csv",
    ]
    for benchmark_csv in benchmark_csv_paths:
        runner.run(benchmark_csv)

    finished_at = time.time()
    print(f"Total time: {finished_at - started_at} seconds")