import requests
import os

from sglang import assistant_begin, assistant_end
from sglang import assistant, function, gen, system, user
from sglang import image
from sglang import RuntimeEndpoint, set_default_backend
from sglang.srt.utils import load_image
from sglang.test.test_utils import is_in_ci
from sglang.utils import print_highlight, terminate_process, wait_for_server

import json
import re
import json_repair
from tqdm import tqdm

from ckans_v3 import check_answer
from concurrent.futures import ThreadPoolExecutor

# Chat-template note: the physics runs used the old backend chat_template;
# the math runs use the new backend chat_template.

from sglang.lang.chat_template import (  
    register_chat_template,   
    ChatTemplate,   
    ChatTemplateStyle  
)
  
# Register a Qwen-style template named "qwen-no-system": every role is wrapped
# in <|im_start|>{role}\n ... <|im_end|>\n, generation stops at <|im_end|>, and
# no default system prompt is configured.
register_chat_template(
    ChatTemplate(
        name="qwen-no-system",
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
        # Deliberately None: no default system prompt.
        default_system_prompt=None,
        role_prefix_and_suffix=dict(
            system=("<|im_start|>system\n", "<|im_end|>\n"),
            user=("<|im_start|>user\n", "<|im_end|>\n"),
            assistant=("<|im_start|>assistant\n", "<|im_end|>\n"),
        ),
    )
)

port = 8001
set_default_backend(RuntimeEndpoint(f"http://localhost:{port}", chat_template_name='qwen-no-system'))


#cks model port, you can change to your own server address
cks_backend = RuntimeEndpoint("http://xx.xx.xx.xx:8001", chat_template_name='qwen-no-system')


def ckans(question, s1_answer, answer):
    """Ask the checker model whether ``s1_answer`` agrees with the reference ``answer``.

    Runs ``check_answer`` on ``cks_backend`` and parses its ``"judgement"``
    output as (possibly malformed) JSON with ``json_repair``.

    Args:
        question: The original question text.
        s1_answer: The candidate answer extracted from a model rollout.
        answer: The reference (standard) answer.

    Returns:
        True only if the judgement's ``is_correct`` field is affirmative, either
        as a JSON boolean ``true`` or as the string ``"true"`` (case-insensitive).
        Any failure (checker error, unparseable output, missing field) yields
        False. Grading is best-effort, so one bad call never aborts a rollout.
    """
    try:
        check_result = check_answer.run(question, s1_answer, answer, backend=cks_backend)
        json_obj = json_repair.loads(check_result["judgement"])
        verdict = json_obj['is_correct']
    except Exception:
        # Best-effort: treat any checker/parse failure as "not correct".
        return False
    # The checker may emit either a JSON boolean (parsed to a Python bool) or a
    # quoted string; accept both forms.
    if isinstance(verdict, bool):
        return verdict
    return str(verdict).strip().lower() == 'true'

@function
def solve_question(s, question):
    """Single-turn chat: send ``question`` as the user turn, then sample a reply.

    The sampled assistant text is captured under the ``'response'`` variable.
    Keep the first parameter named ``s``: sglang's ``@function`` treats it as
    the program state (NOTE(review): sglang checks this name; confirm against
    the installed version before renaming).
    """
    prompt = f"{question}\n\n"
    s += user(prompt)
    s += assistant(
        gen(
            'response',
            max_tokens=32768,
            temperature=0.6,
            top_k=20,
            presence_penalty=1.5,
        )
    )


def rollout_k(x, k=16):
    """Sample ``k`` independent solutions for one question and grade each one.

    Mutates ``x`` in place by adding ``x[f'rollout_{k}']``, a list with one
    entry per attempt, and returns ``x``.

    A successful attempt is recorded as::

        {'messages': [...], 'extract_answer': str | None, 'ckans': bool}

    ``extract_answer`` is the text after the last ``</think>`` of the assistant
    reply. It is None when the reply has no ``</think>`` tag; such replies are
    not sent to the checker and are marked ``ckans=False``. A failed attempt is
    recorded as ``{'err': str(exception)}``.

    Args:
        x: A dict with at least ``'question'`` and ``'standard_answer'``.
        k: Number of rollouts to sample (default 16).
    """
    key = f'rollout_{k}'
    x[key] = []
    for _ in range(k):
        try:
            question = x['question']
            state = solve_question(question)
            msgs = state.messages()
            # The last message is the assistant reply; index from the end so a
            # leading system turn would not shift it.
            content = msgs[-1]['content']
            _, sep, tail = content.rpartition('</think>')
            if sep:
                extract_answer = tail
                ckans_res = ckans(question, extract_answer, x['standard_answer'])
            else:
                extract_answer = None
                ckans_res = False
            x[key].append({
                'messages': msgs,
                'extract_answer': extract_answer,
                'ckans': ckans_res,
            })
        except Exception as e:
            x[key].append({
                'err': str(e),
            })

    return x
    


# Load the evaluation questions (one JSON object per line), skipping blank
# lines, and keep only the fields the rollout needs.
with open('../boba/math_200_questions.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

data = [
    {
        'id': x['id'],
        'question': x['question'],
        'standard_answer': x['standard_answer'],
    }
    for x in data
]

# Roll out all questions concurrently. The work is network-bound (calls to the
# SGLang servers), so threads are used; executor.map keeps input order.
with ThreadPoolExecutor(max_workers=5000) as executor:
    results = list(tqdm(executor.map(rollout_k, data), total=len(data)))

# ensure_ascii=False writes raw non-ASCII text, so open the file as UTF-8
# explicitly instead of relying on the platform's default encoding.
with open('../boba/rollout_16_200_qwq.jsonl', 'w', encoding='utf-8') as f:
    f.writelines(json.dumps(x, ensure_ascii=False) + '\n' for x in results)