from contextlib import contextmanager
from codetiming import Timer
@contextmanager
def _timer(name: str, timing_raw):
    """Time the wrapped block and store the measurement in ``timing_raw``.

    Args:
        name: Key under which the measurement is stored; also used as the
            name of the underlying ``Timer``.
        timing_raw: Mutable mapping that receives ``timing_raw[name]``.

    The entry is only written when the wrapped block finishes without
    raising, because the assignment happens after the ``with`` statement.
    """
    stopwatch = Timer(name=name, logger=None)
    with stopwatch:
        yield
    # ``last`` holds the most recent measurement taken by the Timer.
    timing_raw[name] = stopwatch.last
    
from verl.manager import MultiEnvManager, BufferManager
import os
import json
import numpy as np
def write_to_json(data, path):
    """Serialize ``data`` as indented, non-ASCII-preserving JSON into ``path``.

    Args:
        data: Any JSON-serializable object.
        path: Destination file path; the file is overwritten.
    """
    with open(path, mode='w', encoding='utf8') as handle:
        serialized = json.dumps(data, indent=4, ensure_ascii=False)
        handle.write(serialized)

def read_json(path):
    """Load and return the JSON document stored in the UTF-8 file ``path``."""
    with open(path, mode='r', encoding='utf8') as handle:
        parsed = json.load(handle)
    return parsed
    
def read_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        path: Path to a UTF-8 encoded file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf8') as f:
        # Iterate the file lazily instead of materialising every line, and
        # skip blank / whitespace-only lines (e.g. a trailing empty line),
        # which would otherwise make json.loads raise.
        return [json.loads(line) for line in f if line.strip()]

# Shared options attached to every environment config under the
# "special_settings" key by the _load_*_config helpers in this file.
SETTINGS = {
    "use_old_output_format": False,
    "no_short_thought": False,
    "thinker_freq": 6,
}

"""
http://localhost:8072, Qwen2.5-72B-Instruct
http://localhost:8008, Qwen3-8B
http://localhost:8007, Qwen2.5-7B-Instruct
"""

# Base URLs of the locally served actor / thinker models. rollout_with_env
# picks the request helper from the port found in the URL ("8072" ->
# generate_text_v2, "8008" -> generate_text_v4 for the thinker, otherwise
# generate_text_v1).
ACTOR_URL = "http://localhost:8007"
THINKER_URL = "http://localhost:8007"

# When a flag is True, rollout_with_env calls generate_text_v3 (OpenAI
# client) with the matching model name instead of the local URL above.
USE_OPENAI_API_ACTOR = False
USE_OPENAI_API_THINKER = False
ACTOR_MODEL = "gpt-4o-mini"
THINKER_MODEL = "gpt-4o-mini"


from openai import OpenAI
# Module-level OpenAI client used by generate_text_v3.
# NOTE(review): api_key and base_url look like placeholders — confirm they
# are replaced with real values before enabling the OpenAI code path.
client = OpenAI(api_key="your api kei", base_url="https://xxx")

import requests
def generate_text_v1(input_messages, max_tokens=128, temperature=0.7, base_url="http://g2:8000", timeout=None):
    """Generate completions by POSTing a batch to ``{base_url}/chat``.

    Args:
        input_messages: Batch of inputs, sent unchanged under the
            ``"input_messages"`` key of the JSON payload.
        max_tokens: Value sent as ``"max_tokens"``.
        temperature: Value sent as ``"temperature"``.
        base_url: Server root; ``/chat`` is appended.
        timeout: Optional timeout forwarded to ``requests.post``. The default
            ``None`` keeps the previous behaviour of waiting without a limit.

    Returns:
        The ``"generated_texts"`` field of the JSON response, or ``[]`` when
        ``input_messages`` is empty.
    """
    # Nothing to generate: skip the HTTP round trip, mirroring the
    # empty-input shortcut in generate_text_v3.
    if len(input_messages) == 0:
        return []

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {
        "input_messages": input_messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    response = requests.post(
        f"{base_url}/chat",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    # NOTE(review): the HTTP status is not checked (the original
    # raise_for_status() call was commented out); consider enabling it so
    # server errors surface clearly.
    return response.json()["generated_texts"]

def package_messages(messages):
    """Render chat messages into a single ``<|im_start|>``-delimited prompt.

    Each message becomes ``<|im_start|>{role}\\n{content}<|im_end|>\\n``. Any
    role other than "system" or "user" is rendered as "assistant". The prompt
    always ends with an opened assistant turn so the model continues from it.

    Args:
        messages: Iterable of dicts with "role" and "content" keys.

    Returns:
        The assembled prompt string.
    """
    segments = []
    for message in messages:
        content = message["content"]
        role = message["role"]
        speaker = role if role in ("system", "user") else "assistant"
        segments.append(f"<|im_start|>{speaker}\n{content}<|im_end|>\n")

    segments.append("<|im_start|>assistant\n")
    return "".join(segments)


def generate_text_v2(input_messages, max_tokens=512, temperature=0.0, base_url="http://g7:8000", timeout=None):
    """Generate completions by POSTing rendered prompts to ``{base_url}/generate``.

    Every conversation in ``input_messages`` is first flattened into a prompt
    string with ``package_messages``.

    Args:
        input_messages: Batch of conversations (each a list of message dicts).
        max_tokens: Value sent as ``"max_tokens"``.
        temperature: Value sent as ``"temperature"``.
        base_url: Server root; ``/generate`` is appended.
        timeout: Optional timeout forwarded to ``requests.post``. The default
            ``None`` keeps the previous behaviour of waiting without a limit.

    Returns:
        The ``"generated_texts"`` field of the JSON response, or ``[]`` when
        ``input_messages`` is empty.
    """
    # Nothing to generate: skip the HTTP round trip, mirroring the
    # empty-input shortcut in generate_text_v3.
    if len(input_messages) == 0:
        return []

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {
        "input_messages": [package_messages(x) for x in input_messages],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    response = requests.post(
        f"{base_url}/generate",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    # NOTE(review): the HTTP status is not checked (the original
    # raise_for_status() call was commented out); consider enabling it so
    # server errors surface clearly.
    return response.json()["generated_texts"]

from joblib import Parallel, delayed
def generate_text_v3(input_messages, max_tokens=512, temperature=0.0, model_name="gpt-5-mini-2025-08-07"):
    """Generate one completion per conversation through the module-level OpenAI ``client``.

    Conversations are dispatched concurrently, one joblib thread per
    conversation. Each request is attempted up to three times; if every
    attempt raises, the placeholder string "ERROR: api call error" is
    returned for that conversation.

    Args:
        input_messages: Batch of conversations (each a list of message dicts).
        max_tokens: Forwarded as ``max_tokens`` to the chat completion call.
        temperature: Forwarded as ``temperature`` to the chat completion call.
        model_name: Model identifier passed to the API.

    Returns:
        A list of response strings aligned with ``input_messages``; ``[]``
        for an empty batch.
    """
    attempts_allowed = 3

    def _complete_with_retries(conversation, token_limit, sampling_temperature):
        # Any exception (including a malformed response) counts as a failed
        # attempt; the error is printed and the request is retried.
        for _ in range(attempts_allowed):
            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=conversation,
                    max_tokens=token_limit,
                    temperature=sampling_temperature,
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Error: {e}")
        return "ERROR: api call error"

    if len(input_messages) == 0:
        return []

    worker_pool = Parallel(n_jobs=len(input_messages), backend="threading")
    return worker_pool(
        delayed(_complete_with_retries)(conversation, max_tokens, temperature)
        for conversation in input_messages
    )

# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("Qwen3-8B-Instruct")
def generate_text_v4(input_messages, max_tokens=512, temperature=0.0, base_url="http://g7:8000", timeout=None):
    """Generate completions for thinking-enabled prompts via ``{base_url}/generate``.

    Each conversation is rendered with the module-level ``tokenizer``'s chat
    template (``enable_thinking=True``) before being sent to the server.

    Args:
        input_messages: Batch of conversations (each a list of message dicts).
        max_tokens: Value sent as ``"max_tokens"``.
        temperature: Value sent as ``"temperature"``.
        base_url: Server root; ``/generate`` is appended.
        timeout: Optional timeout forwarded to ``requests.post``. The default
            ``None`` keeps the previous behaviour of waiting without a limit.

    Returns:
        The ``"generated_texts"`` field of the JSON response, or ``[]`` when
        ``input_messages`` is empty.

    Raises:
        NameError: If no module-level ``tokenizer`` is defined.
    """
    # Nothing to generate: skip the HTTP round trip, mirroring the
    # empty-input shortcut in generate_text_v3.
    if len(input_messages) == 0:
        return []

    # The tokenizer setup above this function is commented out in this file,
    # so fail fast with an actionable message instead of a bare NameError
    # raised from inside the templating helper.
    try:
        chat_tokenizer = tokenizer
    except NameError:
        raise NameError(
            "generate_text_v4 requires a module-level `tokenizer` that provides "
            "apply_chat_template(); the AutoTokenizer setup above this function "
            "is currently commented out"
        ) from None

    def package_thinking_messages(messages):
        return chat_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,  # True is the default value for enable_thinking
        )

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {
        "input_messages": [package_thinking_messages(x) for x in input_messages],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    response = requests.post(
        f"{base_url}/generate",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    return response.json()["generated_texts"]

def _split_into_batches(items, batch_size):
    """Split ``items`` into consecutive slices of at most ``batch_size`` elements."""
    return [items[start:start + batch_size] for start in range(0, len(items), batch_size)]


def _sample_actor_responses(messages):
    """Generate one actor response per prompt with the configured actor backend."""
    if USE_OPENAI_API_ACTOR:
        return generate_text_v3(messages, max_tokens=128, temperature=0.0, model_name=ACTOR_MODEL)
    if "8072" in ACTOR_URL:
        return generate_text_v2(messages, max_tokens=128, temperature=0.0, base_url=ACTOR_URL)
    return generate_text_v1(messages, max_tokens=128, temperature=0.0, base_url=ACTOR_URL)


def _sample_thinker_responses(messages):
    """Generate one deep-think response per prompt with the configured thinker backend."""
    if USE_OPENAI_API_THINKER:
        return generate_text_v3(messages, max_tokens=512, temperature=0.0, model_name=THINKER_MODEL)
    if "8072" in THINKER_URL:
        # As of transformers v4.44, default chat template is no longer allowed,
        # so you must provide a chat template if the tokenizer does not define one
        return generate_text_v2(messages, max_tokens=512, temperature=0.0, base_url=THINKER_URL)
    if "8008" in THINKER_URL:
        return generate_text_v4(messages, max_tokens=4096, temperature=0.0, base_url=THINKER_URL)
    return generate_text_v1(messages, max_tokens=512, temperature=0.0, base_url=THINKER_URL)


def _rollout_batch(env_infos, max_turns):
    """Run one batch of environments to completion and return its rollout data."""
    timing_raw = {}
    env_manager = None
    try:
        ###########################################
        #### acquire env configs and init envs ####
        ###########################################
        with _timer('init env', timing_raw):
            # Starts multiple processes, locally or remotely (per the original note).
            env_manager = MultiEnvManager(env_infos)
            # Initialises the environments in parallel and returns the initial
            # info such as task and init_obs (per the original note).
            initial_feedbacks = env_manager.init_envs()

        with _timer('init buffer', timing_raw):
            # Records trajectory data and builds prompts (per the original note).
            buffer_manager = BufferManager(initial_feedbacks)

        print(timing_raw)

        # Stop once the turn budget is exhausted.
        while buffer_manager.step < max_turns:
            #######################################
            #### prepare input prompts (actor) ####
            #######################################
            # Returns the to-do prompt list and records running_ids internally
            # (per the original note).
            running_ids, messagess_todo = buffer_manager.build_prompts_for_actors()
            print(f"step {buffer_manager.step}:")
            print("Env number: {}".format(len(messagess_todo)))

            # Stop when no environment still needs an action.
            if len(messagess_todo) == 0:
                break

            ##########################
            #### generate by vLLM ####
            ##########################
            # Timings are reset every step; the init timings were printed above.
            timing_raw = {}
            with _timer('vllm sampling', timing_raw):
                response_texts = _sample_actor_responses(messagess_todo)

            #################################################
            #### execute in environment and get feedback ####
            #################################################
            with _timer('action executing', timing_raw):
                feedbacks = env_manager.execute_actions(running_ids, response_texts)

            ###################################
            #### postprocess the feedbacks ####
            ###################################
            with _timer('update buffer', timing_raw):
                buffer_manager.update_trajectory(running_ids, response_texts, feedbacks)
                buffer_manager.step += 1

            #########################################
            #### prepare input prompts (summary) ####
            #########################################
            summary_ids, messagess_todo_v2 = buffer_manager.build_prompts_for_deepthinks(running_ids)
            print("Deepthink number: {}".format(len(messagess_todo_v2)))

            ##########################
            #### generate by vLLM ####
            ##########################
            with _timer('vllm sampling v2', timing_raw):
                response_texts = _sample_thinker_responses(messagess_todo_v2)

            ###################################
            #### postprocess the feedbacks ####
            ###################################
            with _timer('update buffer v2', timing_raw):
                buffer_manager.update_trajectory_for_deepthinks(summary_ids, response_texts)

            print(timing_raw)

        return buffer_manager.batch_rollout_data
    finally:
        #####################################
        #### clear all envs and shutdown ####
        #####################################
        # Always release the environments, even when a step above raised;
        # otherwise a failed batch would leave them running.
        if env_manager is not None:
            env_manager.shutdown()


def rollout_with_env(total_env_infos, batch_size, max_turns):
    """Roll out every environment in ``total_env_infos``, batch by batch.

    Args:
        total_env_infos: Sequence of environment descriptions; per the
            original docstring each one carries ``uid``, ``env_name`` and
            ``env_config``.
        batch_size: Maximum number of environments run together.
        max_turns: Maximum number of actor turns per batch.

    Returns:
        A list concatenating ``batch_rollout_data`` of every batch, in batch
        order. Empty input yields an empty list.
    """
    ###############################
    #### splited by batch size ####
    ###############################
    batch_env_infos = _split_into_batches(total_env_infos, batch_size)
    print("ENV NUMBER with BATCH: ", [len(x) for x in batch_env_infos])

    total_rollout_data = []
    for env_infos in batch_env_infos:
        total_rollout_data.extend(_rollout_batch(env_infos, max_turns))

    return total_rollout_data

def _evaluate_benchmark(name, env_configs, out_path, timing_raw,
                        batch_size=256, max_turns=50, report_per_game=False):
    """Roll out one benchmark, print its metrics and dump its trajectories.

    Args:
        name: Benchmark label; used as timer key, in the log line and as the
            stem of the output file ``{out_path}/{name}.json``.
        env_configs: Environment configs passed to ``rollout_with_env``.
        out_path: Directory receiving the JSON dump.
        timing_raw: Mapping that receives the wall-clock timing under ``name``.
        batch_size: Batch size forwarded to ``rollout_with_env``.
        max_turns: Turn limit forwarded to ``rollout_with_env``.
        report_per_game: When True, also print the mean score grouped by
            ``config["env_config"]["game_name"]``.

    Returns:
        A dict with the keys "success rate" and "mean score".

    Raises:
        ZeroDivisionError: If the rollout produced no trajectories.
    """
    with _timer(name, timing_raw):
        print(f"TESTING {name}")
        total_rollout_data = rollout_with_env(env_configs, batch_size, max_turns)
        max_scores = [x["state"]["env_score"] * 10 for x in total_rollout_data]

        if report_per_game:
            # Pairs rollouts with configs by position — assumes the rollout
            # order matches the config order (TODO confirm).
            scores_by_game = {}
            for score, config in zip(max_scores, env_configs):
                scores_by_game.setdefault(config["env_config"]["game_name"], []).append(score)
            for game_name, scores in scores_by_game.items():
                print(game_name, sum(scores) / len(scores))

        print("total_max_scores: ", max_scores)
        # A run counts as a success when its scaled score is exactly 100.
        success_rate = sum(np.array(max_scores) == 100.0) / len(max_scores) * 100
        mean_score = sum(max_scores) / len(max_scores)
        print("success rate: ", success_rate)
        print("mean score: ", mean_score)
        write_to_json(total_rollout_data, f"{out_path}/{name}.json")

    return {"success rate": success_rate, "mean score": mean_score}


def test_all(out_path):
    """Evaluate all benchmarks and write one JSON file per benchmark to ``out_path``.

    Args:
        out_path: Output directory; created (including missing parents) if it
            does not exist yet.
    """
    # makedirs also creates missing parent directories, which os.mkdir does
    # not; the nested default "./evaluate_results/<model>" relies on that.
    os.makedirs(out_path, exist_ok=True)

    # All configs are loaded before any rollout starts, so a missing dataset
    # fails before time is spent on earlier benchmarks. The trailing numbers
    # are the original author's annotations (presumably environment counts).
    benchmarks = [
        ("alfworld_agentboard", _load_alfworld_config("agentboard"), False),   # 134
        ("sciworld_agentboard", _load_sciworld_config("agentboard"), False),   # 90
        ("babyai_agentboard", _load_babyai_config(), False),                   # 114
        ("jericho_agentboard", _load_jericho_config(), False),                 # 20
        ("pddl_agentboard", _load_pddl_config(), True),                        # 60
    ]

    timing_raw = {}
    total_results = [
        _evaluate_benchmark(name, env_configs, out_path, timing_raw, report_per_game=per_game)
        for name, env_configs, per_game in benchmarks
    ]

    print(timing_raw)
    for item in total_results:
        print("{}, {}".format(item["success rate"], item["mean score"]), end=", ")
      
def _load_alfworld_config(tag):
    """Build the ALFWorld environment configs from the test JSONL file.

    Args:
        tag: Unused; kept for interface compatibility.

    Returns:
        One config dict per record of ``./verl/environments/alfworld/test.jsonl``
        with the keys "uid", "env_name", "env_config" and "special_settings".
    """
    game_root = "./verl/environments/alfworld/alfworld_data/json_2.1.1/valid_unseen/"
    records = read_jsonl("./verl/environments/alfworld/test.jsonl")
    return [
        {
            "uid": f"alfworld_{idx}",
            "env_name": "alfworld_agentboard",
            "env_config": {
                "game_file": game_root + record["additional_info"]["description"] + "/game.tw-pddl",
                "task_id": f"alfworld_{idx}",
                "subgoals": record["subgoals"],
            },
            "special_settings": SETTINGS,
        }
        for idx, record in enumerate(records)
    ]

def _load_sciworld_config(tag="agentboard"):
    """Build the SciWorld environment configs from the test JSONL file.

    Args:
        tag: Unused; kept for interface compatibility.

    Returns:
        One config dict per record of ``./verl/environments/sciworld/test.jsonl``
        with the keys "uid", "env_name", "env_config" and "special_settings".
    """
    configs = []
    for record in read_jsonl("./verl/environments/sciworld/test.jsonl"):
        info = record["additional_info"]
        task_name = info["env_name"]
        var = info["var"]
        entry = {
            "uid": f"{task_name}_{var}",
            "env_name": "sciworld_agentboard",
            "env_config": {
                "task_name": task_name,
                "var": var,
                "subgoals": record["subgoals"],
            },
            "special_settings": SETTINGS,
        }
        configs.append(entry)
    return configs

def _load_jericho_config():
    """Wrap every Jericho dataset config into the rollout config format.

    Returns:
        One dict per config returned by the Jericho ``get_all_environment_configs``,
        with "uid" taken from the config's "game_id".
    """
    from verl.environments.jericho.create_dataset import get_all_environment_configs

    return [
        {
            "uid": game["game_id"],
            "env_name": "jericho",
            "env_config": game,
            "special_settings": SETTINGS,
        }
        for game in get_all_environment_configs()
    ]


def _load_babyai_config():
    """Wrap every BabyAI dataset config into the rollout config format.

    Returns:
        One dict per config returned by the BabyAI ``get_all_environment_configs``,
        with "uid" taken from the config's "game_id".
    """
    from verl.environments.babyai.create_dataset import get_all_environment_configs

    return [
        {
            "uid": game["game_id"],
            "env_name": "babyai",
            "env_config": game,
            "special_settings": SETTINGS,
        }
        for game in get_all_environment_configs()
    ]


def _load_pddl_config():
    """Wrap every PDDL dataset config into the rollout config format.

    Returns:
        One dict per config returned by the PDDL ``get_all_environment_configs``,
        with "uid" taken from the config's "game_id".
    """
    from verl.environments.pddl.create_dataset import get_all_environment_configs

    return [
        {
            "uid": game["game_id"],
            "env_name": "pddl",
            "env_config": game,
            "special_settings": SETTINGS,
        }
        for game in get_all_environment_configs()
    ]
    
if __name__ == "__main__":
    # Echo the active configuration so it ends up in the run log.
    print("SETTINGS: ")
    print(SETTINGS)
    run_summary = (
        ("Actor", ACTOR_URL),
        ("Thinker", THINKER_URL),
        ("Actor USE_OPENAI_API", USE_OPENAI_API_ACTOR),
        ("Thinker USE_OPENAI_API", USE_OPENAI_API_THINKER),
        ("Actor", ACTOR_MODEL),
        ("Thinker", THINKER_MODEL),
    )
    for label, value in run_summary:
        print(label, value)

    # Results of every benchmark are written below this directory.
    output_path = "./evaluate_results/Qwen2.5-7B-Instruct"
    test_all(output_path)