import copy
import queue
from multiprocessing import Queue, Process

import torch

from verl.utils.torch_functional import pad_sequence_to_length
from verl import DataProto
from verl.environments import (
    AlfworldEnv,
    AlfworldEnv_AgentBoard,
    SciworldEnv,
    SciworldEnv_Agentboard,
    JerichoEnv,
    BabyAIEnv,
    PDDLEnv,
)

# Maps the ``env_name`` carried by an "init" task to the environment class that
# EnvironmentWorker instantiates as ``cls(env_config, special_settings)``.
# An env_name missing from this table makes the worker raise KeyError.
ENV_CLASS_MAPPING = {
    "alfworld": AlfworldEnv,
    "alfworld_agentboard": AlfworldEnv_AgentBoard,
    "sciworld": SciworldEnv,
    "sciworld_agentboard": SciworldEnv_Agentboard,
    "jericho": JerichoEnv,
    "babyai": BabyAIEnv,
    "pddl": PDDLEnv,
}

import re

###########################
#### Single Env Worker ####
###########################
class EnvironmentWorker(Process):
    """Child process that owns (at most) one environment instance.

    The worker blocks on ``task_queue`` and handles task tuples of the form
    ``(task_type, running_id, data)``:

    * ``"init"``    -- build an environment from ``data`` and put
      ``(running_id, init_res)`` on ``shared_result_queue``;
    * ``"execute"`` -- run one step with ``data`` as the model response and
      put ``(running_id, exe_result)`` on ``shared_result_queue``;
    * ``"clear"``   -- drop the current environment (nothing is reported).

    A ``None`` task is the stop signal and ends :meth:`run`; any other task
    type is ignored.

    Args:
        task_queue: queue this worker reads its tasks from.
        shared_result_queue: queue, shared with the other workers, that
            results are written to.
    """

    def __init__(self, task_queue, shared_result_queue):
        super().__init__()
        self.task_queue = task_queue
        self.shared_result_queue = shared_result_queue
        # The environment currently owned by this worker, and its class.
        self.env = None
        self.env_class = None

    def run(self):
        # Overrides Process.run: serve tasks until the None stop signal arrives.
        for task_type, running_id, data in iter(self.task_queue.get, None):
            if task_type == "init":
                self._handle_init(running_id, data)
            elif task_type == "execute":
                self._handle_execute(running_id, data)
            elif task_type == "clear":
                self._handle_clear()

    def _handle_init(self, running_id, data):
        # Build the environment described by ``data`` and report its initial state.
        uid = data["uid"]
        env_name = data["env_name"]
        env_config = data["env_config"]
        special_settings = data["special_settings"]

        self.env_class = ENV_CLASS_MAPPING[env_name]
        self.env = self.env_class(env_config, special_settings)
        env = self.env

        init_res = {
            "task_name": env.task_name,
            "uid": uid,
            "task_description": env.task_description,
            "init_obs": env.init_obs,
            # Optional attributes: not every environment class defines them.
            "answer": getattr(env, "answer", None),
            "system_prompt": env.system_prompt,
            "user_prompt": env.user_prompt,
            "system_prompt_for_deepthink": env.system_prompt_for_deepthink,
            "user_prompt_for_deepthink": env.user_prompt_for_deepthink,
            "done": env.done,
            "score": env.score,
            "init_history_traj": getattr(env, "init_history_traj", []),
        }
        self.shared_result_queue.put((running_id, init_res))

    def _handle_execute(self, running_id, response):
        # Step the owned environment with the model response and report the result.
        step_result = self.env_class.execute_pred_for_acitve_env(self.env, response)
        self.shared_result_queue.put((running_id, step_result))

    def _handle_clear(self):
        # Release the reference to the environment so it can be collected.
        self.env = None
        self.env_class = None
                

############################
#### Multi Envs Manager ####
############################
class MultiEnvManager:
    """Fan environment tasks out to one ``EnvironmentWorker`` per environment.

    Worker ``i`` owns the environment described by ``total_env_infos[i]``;
    ``i`` is that environment's *running id*. Tasks are routed through the
    per-worker task queues, and results come back, in arbitrary order, on a
    single shared result queue as ``(running_id, result)`` pairs.

    Args:
        total_env_infos: one init payload per environment (sent as the
            ``data`` of an ``"init"`` task). One worker process is started per
            entry.
    """

    def __init__(self, total_env_infos):
        self.total_env_infos = total_env_infos

        # Start one worker process per environment to deploy.
        num_workers = len(total_env_infos)
        self.task_queues = [Queue() for _ in range(num_workers)]
        self.shared_result_queue = Queue()
        self.workers = []
        for task_queue in self.task_queues:
            worker = EnvironmentWorker(task_queue, self.shared_result_queue)
            worker.start()
            self.workers.append(worker)

    def init_envs(self):
        """Initialise every environment and return the workers' init results.

        Returns:
            list: one ``init_res`` per environment, ordered by running id.
        """
        assert len(self.total_env_infos) == len(self.workers)

        for running_id, env_info in enumerate(self.total_env_infos):
            self.task_queues[running_id].put(("init", running_id, env_info))

        return self._collect_in_order(list(range(len(self.total_env_infos))))

    def execute_actions(self, running_ids, responses):
        """Step the given environments concurrently.

        Args:
            running_ids: ids of the environments to step; must be unique and
                within ``range(len(self.workers))``.
            responses: model responses; ``responses[k]`` goes to environment
                ``running_ids[k]``.

        Returns:
            list: one step result per id, aligned with ``running_ids`` (so
            ``zip(running_ids, feedbacks)`` pairs each result with its env).

        Raises:
            ValueError: on a length mismatch, duplicate ids or out-of-range ids.
        """
        running_ids = list(running_ids)
        responses = list(responses)
        if len(running_ids) != len(responses):
            raise ValueError(
                f"got {len(running_ids)} running ids but {len(responses)} responses"
            )
        if len(set(running_ids)) != len(running_ids):
            raise ValueError(f"running ids must be unique, got {running_ids}")
        num_workers = len(self.workers)
        # A negative id would silently address a worker from the end of the list.
        out_of_range = [rid for rid in running_ids if not 0 <= rid < num_workers]
        if out_of_range:
            raise ValueError(f"running ids out of range [0, {num_workers}): {out_of_range}")

        for running_id, response in zip(running_ids, responses):
            self.task_queues[running_id].put(("execute", running_id, response))

        return self._collect_in_order(running_ids)

    def _collect_in_order(self, running_ids):
        # Wait for one result per id and return them in the order of ``running_ids``
        # (results arrive on the shared queue in completion order).
        results_by_id = dict(self._get_results(len(running_ids)))
        return [results_by_id[running_id] for running_id in running_ids]

    ## auxiliary function for multiprocess envs
    def _get_results(self, expected_num, poll_interval=5.0):
        """Read ``expected_num`` ``(running_id, result)`` pairs from the shared queue.

        Instead of blocking forever, the queue is polled every
        ``poll_interval`` seconds. If nothing arrived and a worker process has
        exited (e.g. its environment raised), the missing result can never
        come, so ``RuntimeError`` is raised instead of hanging.
        """
        results = []
        while len(results) < expected_num:
            try:
                results.append(self.shared_result_queue.get(timeout=poll_interval))
            except queue.Empty:
                dead = [i for i, worker in enumerate(self.workers) if not worker.is_alive()]
                if dead:
                    raise RuntimeError(
                        f"environment worker(s) {dead} exited while "
                        f"{expected_num - len(results)} result(s) were still pending"
                    )
        return results

    ## auxiliary function for multiprocess envs
    def _clear_all_envs(self):
        # Ask every worker to drop its environment; no result is sent back.
        for task_queue in self.task_queues:
            task_queue.put(("clear", None, None))

    def shutdown(self):
        """Clear all environments, send the stop signal and join every worker."""
        self._clear_all_envs()
        for q in self.task_queues:
            q.put(None)
        for worker in self.workers:
            worker.join()


########################
#### Buffer Manager ####
########################
class BufferManager:
    """Per-environment rollout buffer for actor / deepthink (thinker) training.

    ``batch_rollout_data[running_id]`` holds everything recorded for one
    environment: its prompts, a ``state`` dict, the step ``trajectory`` and the
    chat transcripts fed to the actor (``actor_messages``) and to the deepthink
    model (``deepthink_messages``). Each transcript field is a list of message
    lists; the last one is the conversation currently being extended.

    A ``trajectory`` entry has the keys::

        step:               -1 for steps replayed from the env's initial
                            history, otherwise ``self.step``
        original_response:  raw actor output
        thought, action:    parsed from the response by the environment
        score:              score reported in the environment feedback
        format_score:       per-step format score from the feedback
        observation:        observation from the feedback
        # added by update_trajectory_for_deepthinks:
        original_deepthink, deepthink, deepthink_score

    Message dicts carry ``role`` / ``content`` plus ``cal_loss`` (whether the
    turn's tokens take part in the loss) and, for assistant turns recorded from
    feedback, ``format_score``. ``self.step`` is never advanced inside this
    class; presumably the caller updates it.
    """

    # Extracts the deepthink body from a thinker response.
    _DEEPTHINK_PATTERN = re.compile(r"<deepthink>(.*?)</deepthink>", re.DOTALL)

    def __init__(self, initial_feedbacks):
        """Create one buffer entry per environment.

        Args:
            initial_feedbacks: per-env dicts with the keys an
                ``EnvironmentWorker`` reports for an ``"init"`` task (``uid``,
                ``task_description``, ``init_obs``, ``answer``, prompts,
                ``score``, ``done``, ``init_history_traj``). The
                ``init_history_traj`` steps are replayed into the buffer.
        """
        self.step = 0
        self.batch_rollout_data = []

        for info in initial_feedbacks:
            self.batch_rollout_data.append({
                "uid": info["uid"],
                "task": info["task_description"],
                "init_obs": info["init_obs"],
                "answer": info["answer"],
                "system_prompt": info["system_prompt"],
                "user_prompt": info["user_prompt"],
                "system_prompt_for_deepthink": info["system_prompt_for_deepthink"],
                "user_prompt_for_deepthink": info["user_prompt_for_deepthink"],
                "state": {
                    "init_env_score": info["score"],  # env score right after init
                    "env_score": info["score"],  # latest env score (outcome reward)
                    "format_score": 0.0,  # summed by update_final_score
                    "deepthink_score": 0.0,  # summed by update_final_score
                    "last_deepthink_point": 0,  # trajectory length at the last deepthink
                    "expect_deepthink": False,
                    "current_deepthink": "",
                    "done": info["done"],
                    "score_up": False,  # True once any step had a positive reward
                },
                "trajectory": [],
                "actor_messages": [],
                "deepthink_messages": [],
            })

        env_feedback_trajs = [info["init_history_traj"] for info in initial_feedbacks]
        self.update_trajectory_batch(env_feedback_trajs)

    def build_prompts_for_actors(self):
        """Extend and return the actor transcripts of all unfinished envs.

        Note: this mutates the transcripts (see ``_build_single_message_for_actor``).

        Returns:
            tuple: ``(running_ids, total_messages)`` -- ids of the envs whose
            ``state["done"]`` is false, and their current actor message lists.
        """
        running_ids = []
        total_messages = []
        for running_id, data in enumerate(self.batch_rollout_data):
            if data["state"]["done"]:
                continue  # finished envs get no new prompt
            total_messages.append(BufferManager._build_single_message_for_actor(data))
            running_ids.append(running_id)
        return running_ids, total_messages

    def build_prompts_earlystop_for_actors(self):
        """Like ``build_prompts_for_actors`` but skips envs whose score already went up.

        Only ``state["score_up"]`` is checked; ``state["done"]`` is not.
        """
        running_ids = []
        total_messages = []
        for running_id, data in enumerate(self.batch_rollout_data):
            if data["state"]["score_up"]:
                continue
            total_messages.append(BufferManager._build_single_message_for_actor(data))
            running_ids.append(running_id)
        return running_ids, total_messages

    def build_prompts_for_deepthinks(self, running_ids, force=False):
        """Build deepthink prompts for the envs among ``running_ids`` that want one.

        Args:
            running_ids: candidate env ids.
            force: build a prompt for every candidate, ignoring
                ``state["expect_deepthink"]``.

        Returns:
            tuple: ``(summary_ids, total_messages)``.
        """
        summary_ids = []
        total_messages = []
        for running_id in running_ids:
            data = self.batch_rollout_data[running_id]
            if force or data["state"]["expect_deepthink"]:
                total_messages.append(BufferManager._build_single_message_for_deepthink(data))
                summary_ids.append(running_id)
        return summary_ids, total_messages

    def update_trajectory_batch(self, env_feedback_trajs):
        """Replay pre-recorded steps into every env's buffer entry.

        Replayed assistant turns get ``cal_loss=False`` so no loss is computed
        on the pre-existing path.

        Args:
            env_feedback_trajs: one list of step-feedback dicts per env,
                aligned with ``batch_rollout_data``.
        """
        assert len(self.batch_rollout_data) == len(env_feedback_trajs)

        for data, feedback_traj in zip(self.batch_rollout_data, env_feedback_trajs):
            for feedback in feedback_traj:
                # Called for its side effect on data["actor_messages"]: it starts
                # the transcript before the first replayed step and afterwards
                # adds the previous observation as a user turn.
                BufferManager._build_single_message_for_actor(data)

                data["trajectory"].append({
                    "step": -1,
                    "original_response": feedback["original_response"],
                    "thought": feedback["think"],
                    "action": feedback["action"],
                    "score": feedback["score"],
                    "format_score": feedback["format_score"],  # per-step format score
                    "observation": feedback["observation"],
                })

                data["actor_messages"][-1].append({
                    "role": "assistant",
                    "content": feedback["original_response"],
                    "format_score": feedback["format_score"],
                    "cal_loss": False,  # no loss on the replayed history
                })

    def update_trajectory(self, running_ids, model_responses, env_feedbacks):
        """Record one actor step for each env in ``running_ids``.

        Args:
            running_ids: env ids that were stepped.
            model_responses: raw actor response per env.
            env_feedbacks: environment step result per env, aligned with
                ``running_ids``.
        """
        assert len(running_ids) == len(model_responses)
        assert len(running_ids) == len(env_feedbacks)

        for running_id, response, feedback in zip(running_ids, model_responses, env_feedbacks):
            data = self.batch_rollout_data[running_id]
            state = data["state"]
            state["done"] = feedback["done"]  # once True, no more prompts are built
            state["env_score"] = feedback["score"]  # outcome reward
            state["expect_deepthink"] = feedback["expect_deepthink"]
            if not state["score_up"]:  # latch: did the score rise anywhere in the trajectory?
                state["score_up"] = (feedback["reward"] > 0.0)

            data["trajectory"].append({
                "step": self.step,
                "original_response": response,
                "thought": feedback["think"],
                "action": feedback["action"],
                "score": feedback["score"],
                "format_score": feedback["format_score"],  # per-step format score
                "observation": feedback["observation"],
            })
            data["actor_messages"][-1].append({
                "role": "assistant",
                "content": response,
                "format_score": feedback["format_score"],
                "cal_loss": True,
            })

    def update_trajectory_for_deepthinks(self, summary_ids, model_responses):
        """Parse and record thinker responses for the envs in ``summary_ids``.

        Format score: -1.0 when no ``<deepthink>...</deepthink>`` block is
        found; otherwise 0.0, minus 0.5 for any text before ``<deepthink>`` and
        minus 0.5 for any text after ``</deepthink>``.
        """
        assert len(summary_ids) == len(model_responses)

        for running_id, response in zip(summary_ids, model_responses):
            match = self._DEEPTHINK_PATTERN.search(response)
            if match:
                think_format_score = 0.0
                deepthink = match.group(1).strip()
                if response.split("<deepthink>")[0] != "":
                    think_format_score -= 0.5
                if response.split("</deepthink>")[-1] != "":
                    think_format_score -= 0.5
            else:
                think_format_score = -1.0
                deepthink = ""

            data = self.batch_rollout_data[running_id]
            data["trajectory"][-1].update({
                "original_deepthink": response,
                "deepthink": deepthink,
                "deepthink_score": think_format_score
            })
            data["state"].update({
                "last_deepthink_point": len(data["trajectory"]),
                "current_deepthink": deepthink
            })
            data["deepthink_messages"][-1].append({
                "role": "assistant",
                "content": response,
                "format_score": think_format_score,
                "cal_loss": True,
            })

    def update_final_score(self):
        """Sum per-step format and deepthink scores into each env's ``state``."""
        for item in self.batch_rollout_data:
            trajectory = item["trajectory"]
            item["state"]["format_score"] = sum([x["format_score"] for x in trajectory])
            item["state"]["deepthink_score"] = sum(
                (x["deepthink_score"] for x in trajectory if "deepthink_score" in x), 0.0
            )

    @staticmethod
    def _build_abbr_history_trajectory(trajectory):
        """Render ``"\\n> action\\nobservation"`` lines for steps with a non-empty action.

        Newlines inside observations are flattened to spaces.
        """
        parts = []
        for traj in trajectory:
            act = traj["action"]
            obs = traj["observation"].replace("\n", " ")
            if act != "":
                parts.append(f"\n> {act}\n{obs}")
        return "".join(parts)

    @staticmethod
    def _build_single_message_for_actor(data):
        """Extend ``data``'s current actor transcript for the next actor turn.

        Mutates ``data["actor_messages"]`` and returns its last message list:

        * empty trajectory: start a new transcript ``[system, user]``;
        * no deepthink since the last step: append the latest observation as
          a user turn;
        * a deepthink was just recorded (``last_deepthink_point`` equals the
          trajectory length): append the deepthink as an assistant turn
          followed by a user ``"OK."`` turn.

        All appended turns have ``cal_loss=False``.
        """
        actor_messages = data["actor_messages"]
        trajectory = data["trajectory"]

        if len(trajectory) == 0:
            actor_messages.append([
                {
                    "role": "system",
                    "content": data["system_prompt"],
                    "cal_loss": False,
                },
                {
                    "role": "user",
                    "content": data["user_prompt"],
                    "cal_loss": False,
                }
            ])
        elif data["state"]["last_deepthink_point"] != len(trajectory):
            actor_messages[-1].append({
                "role": "user",
                "content": trajectory[-1]["observation"],
                "cal_loss": False,
            })
        else:
            actor_messages[-1].extend([
                {
                    "role": "assistant",
                    "content": trajectory[-1]["deepthink"],
                    "cal_loss": False,
                },
                {
                    "role": "user",
                    "content": "OK.",
                    "cal_loss": False,
                }
            ])
        return actor_messages[-1]

    @staticmethod
    def _build_single_message_for_deepthink(data):
        """Start a new deepthink transcript ``[system, user]`` and return it.

        The ``<interactive history>`` placeholder in the deepthink user prompt
        is replaced by the abbreviated action/observation history, which must
        be non-empty (at least one step with a non-empty action).
        """
        history = BufferManager._build_abbr_history_trajectory(data["trajectory"])
        assert history != ""
        user_prompt = data["user_prompt_for_deepthink"].replace("<interactive history>", history)

        message = [
            {
                "role": "system",
                "content": data["system_prompt_for_deepthink"],
                "cal_loss": False,
            },
            {
                "role": "user",
                "content": user_prompt,
                "cal_loss": False,
            }
        ]
        data["deepthink_messages"].append(message)
        return message

    @classmethod
    def pad_or_truncate(cls, tensor, max_len, pad_token_id, left_pad=False):
        """Bring every 1-D tensor in ``tensor`` to length ``max_len`` and stack them.

        Longer tensors keep their first ``max_len`` elements (with a warning);
        shorter ones are padded with ``pad_token_id`` via
        ``pad_sequence_to_length`` (on the left when ``left_pad``).
        """
        padded_tensors = []
        for t in tensor:
            if t.size(0) > max_len:
                padded_tensors.append(t[:max_len])
                print("WARNING: prompt or response ({}) large than the max_length ({}) !!!".format(t.size(0), max_len))
            elif t.size(0) < max_len:
                padded_tensors.append(pad_sequence_to_length(t, max_len, pad_token_id, left_pad=left_pad))
            else:
                padded_tensors.append(t)
        return torch.stack(padded_tensors, dim=0)

    @staticmethod
    def _turn_affixes(chat_template, role):
        # Header and trailer strings wrapped around one chat turn for ``role``.
        if chat_template == "qwen":
            return f"<|im_start|>{role}\n", "<|im_end|>\n"
        if chat_template == "llama":
            return f"<|start_header_id|>{role}<|end_header_id|>\n\n", "<|eot_id|>\n"
        raise ValueError("Unknown chat template value.")

    @classmethod
    def _process_messages(cls, messages, adv, chat_template, tokenizer, update_role):
        """Tokenize one transcript into prompt ids and masked/advantaged response ids.

        The system turn becomes the prompt; every following user/assistant turn
        goes into the response. Only assistant turns with ``cal_loss`` set are
        trained: their header tokens are masked out and their content + trailer
        tokens get ``adv``. For ``update_role == "actor"`` a negative step
        ``format_score`` zeroes that turn's advantage (the mask stays set).
        Turns with any other role are skipped.

        Returns:
            tuple: ``(prompt_ids, response_ids, response_mask, advantages)``;
            the last three lists have equal length.

        Raises:
            ValueError: for an unknown chat template, or an unknown
                ``update_role`` when a trainable assistant turn is present.
        """
        prompt_ids, response_ids, response_mask, advantages = [], [], [], []
        assert len(messages) >= 3
        chat_template = chat_template.lower()

        # system prompt
        assert messages[0]["role"] == "system"
        head, tail = cls._turn_affixes(chat_template, "system")
        prompt_ids += tokenizer.encode(f"{head}{messages[0]['content']}{tail}", add_special_tokens=False)

        # Token lengths of the assistant header/trailer, computed once per call.
        asst_head, asst_tail = cls._turn_affixes(chat_template, "assistant")
        head_len = len(tokenizer.encode(asst_head, add_special_tokens=False))

        for item in messages[1:]:
            role = item["role"]
            if role not in ("user", "assistant"):
                continue

            head, tail = cls._turn_affixes(chat_template, role)
            input_ids = tokenizer.encode(f"{head}{item['content']}{tail}", add_special_tokens=False)
            n_tokens = len(input_ids)
            response_ids += input_ids

            if role == "user" or not item["cal_loss"]:
                response_mask += [False] * n_tokens
                advantages += [0.0] * n_tokens
                continue

            # Trainable assistant turn: header masked, content + trailer trained.
            response_mask += [False] * head_len + [True] * (n_tokens - head_len)
            if update_role == "thinker":
                advantages += [0.0] * head_len + [adv] * (n_tokens - head_len)
            elif update_role == "actor":
                if item["format_score"] < 0.0:  # step-level format penalty
                    advantages += [0.0] * n_tokens
                else:
                    advantages += [0.0] * head_len + [adv] * (n_tokens - head_len)
            else:
                # Previously this left advantages shorter than response_mask.
                raise ValueError("Unknown update_role value.")

        return prompt_ids, response_ids, response_mask, advantages

    @classmethod
    def rebuild_prompt_repsonse_mask_advantage_from_trajectory(cls, outcome_adv, rollout_data, tokenizer, config):
        """Turn one env's transcripts into padded training tensors.

        Uses ``deepthink_messages`` when training the thinker and
        ``actor_messages`` when training the actor. Prompts are left-padded to
        ``max_prompt_length``; responses, masks and advantages are right-padded
        to ``max_response_length``::

            prompts    [0, 0, 0, ... p, p, p]   N x max_prompt_length
            responses  [R, R, R, ... 0, 0, 0]   N x max_response_length

        Returns:
            tuple: ``(prompts, responses, response_mask, advantages)`` -- stacked
            tensors, or four empty lists when there are no transcripts.
        """
        max_prompt_length = config.data.max_prompt_length
        max_response_length = config.data.max_response_length
        update_role = config.actor_rollout_ref.rollout.train_actor_or_thinker
        chat_template = config.actor_rollout_ref.model.chat_template

        batch_prompts, batch_responses, batch_response_mask = [], [], []
        batch_advantages = []

        if update_role == "thinker":
            used_messages = rollout_data["deepthink_messages"]
        elif update_role == "actor":
            used_messages = rollout_data["actor_messages"]
        else:
            raise ValueError("Unknown update_role value.")

        for messages in used_messages:
            prompt_ids, response_ids, response_mask, advantages = cls._process_messages(
                messages, outcome_adv, chat_template, tokenizer, update_role
            )
            batch_prompts.append(torch.tensor(prompt_ids, dtype=torch.int64))
            batch_responses.append(torch.tensor(response_ids, dtype=torch.int64))
            batch_response_mask.append(torch.tensor(response_mask, dtype=torch.bool))
            # Explicit dtype: otherwise it depends on the type of ``outcome_adv``.
            batch_advantages.append(torch.tensor(advantages, dtype=torch.float32))

        if len(batch_prompts) > 0:
            batch_prompts = cls.pad_or_truncate(batch_prompts, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
            batch_responses = cls.pad_or_truncate(batch_responses, max_response_length, tokenizer.pad_token_id)
            batch_response_mask = cls.pad_or_truncate(batch_response_mask, max_response_length, False)
            batch_advantages = cls.pad_or_truncate(batch_advantages, max_response_length, 0.0)

        return batch_prompts, batch_responses, batch_response_mask, batch_advantages

    @classmethod
    def remake_dataproto(cls, gen_batch_output, outcome_adv_batch, tokenizer, config):
        """Build the training ``DataProto`` and per-env statistics from a rollout batch.

        Args:
            gen_batch_output: object whose ``non_tensor_batch["batch_rollout_data"]``
                holds one buffer entry per env.
            outcome_adv_batch: one outcome advantage per env, aligned with it.
            tokenizer: provides ``encode`` and ``pad_token_id``.
            config: training config (lengths, update role, chat template).

        Returns:
            tuple: ``(DataProto, external_info)``, where ``external_info`` maps
            statistic names to per-env numpy arrays.

        Raises:
            ValueError: if no env produced any transcript to train on.
        """
        rollout_items = gen_batch_output.non_tensor_batch["batch_rollout_data"]
        total_prompts, total_responses = [], []
        total_response_mask, total_advantages = [], []

        assert len(outcome_adv_batch) == len(rollout_items)
        for outcome_adv, item in zip(outcome_adv_batch, rollout_items):
            prompts, responses, response_mask, advantages = \
                cls.rebuild_prompt_repsonse_mask_advantage_from_trajectory(outcome_adv, item, tokenizer, config)

            if len(prompts) > 0:
                total_prompts.append(prompts)
                total_responses.append(responses)
                total_response_mask.append(response_mask)
                total_advantages.append(advantages)

        if not total_prompts:
            raise ValueError("batch_rollout_data contains no transcripts to train on")

        total_prompts = torch.cat(total_prompts, dim=0)
        total_responses = torch.cat(total_responses, dim=0)
        total_response_mask = torch.cat(total_response_mask, dim=0)
        total_advantages = torch.cat(total_advantages, dim=0)

        # Left padding for prompts, right padding for responses:
        #   0 0 0 p p p | r r r r r 0 0 0
        #   0 0 p p p p | r r r r r r r r
        total_input_ids = torch.cat((total_prompts, total_responses), dim=-1)
        total_attention_mask = torch.where(total_input_ids != tokenizer.pad_token_id, 1, 0)
        total_position_ids = (torch.cumsum(total_attention_mask, dim=1) - 1) * total_attention_mask

        final_output = {
            "prompts": total_prompts,
            "responses": total_responses,
            "input_ids": total_input_ids,
            "attention_mask": total_attention_mask,
            "position_ids": total_position_ids,
            "generation_mask": total_response_mask,
            "advantages": total_advantages,
            "returns": total_advantages,
        }

        # WARNING: needs adjusting if the batch size changes.
        total_env_num = len(rollout_items)
        external_info = {
            "total_env": np.ones(total_env_num, dtype=np.int64),
            "finished_env": np.array([int(item["state"]["done"]) for item in rollout_items], dtype=np.int64),
            "traj_length": np.array([len(item["trajectory"]) for item in rollout_items], dtype=np.int64),
            "env_score": np.array([float(item["state"]["env_score"]) for item in rollout_items], dtype=np.float32),
            "format_score": np.array([float(item["state"]["format_score"]) for item in rollout_items], dtype=np.float32),
            "deepthink_score": np.array([float(item["state"]["deepthink_score"]) for item in rollout_items], dtype=np.float32),
        }

        return DataProto.from_dict(final_output), external_info

import numpy as np
from collections import defaultdict
def compute_grpo_outcome_advantage_numpy(scores: list[float], index: list[str], epsilon: float = 1e-6):
    """Compute GRPO advantages from outcome rewards (one scalar score per response).

    Each score is normalised within its group, i.e. the responses that share
    the same ``index`` value::

        advantage = (score - group_mean) / (group_std + epsilon)

    where ``group_std`` is the unbiased (``ddof=1``) standard deviation. A
    group with a single member uses mean 0 and std 1, so its score is passed
    through (divided by ``1 + epsilon``).

    Args:
        scores: outcome score per response (length ``bs``).
        index: group key per response (length ``bs``); responses with equal
            keys are normalised together.
        epsilon: added to the std to avoid division by zero.

    Returns:
        tuple: ``(advantages, id2std)`` -- ``advantages`` is a float32
        ``numpy.ndarray`` with the same shape as ``np.array(scores)``;
        ``id2std`` maps each group key to the std used for it.

    Raises:
        ValueError: if ``scores`` and ``index`` have different lengths.
    """
    scores = np.array(scores, dtype=np.float32)
    bsz = scores.shape[0]
    if len(index) != bsz:
        raise ValueError(f"scores has {bsz} entries but index has {len(index)}")

    id2score = defaultdict(list)
    for key, score in zip(index, scores):
        id2score[key].append(score)

    id2mean = {}
    id2std = {}
    for key, group in id2score.items():
        if len(group) == 1:
            # A lone response has no peers to compare against.
            id2mean[key] = np.array(0.0)
            id2std[key] = np.array(1.0)
        else:
            group_arr = np.array(group)
            id2mean[key] = np.mean(group_arr)
            id2std[key] = np.std(group_arr, ddof=1)  # unbiased estimate

    # Element-wise on purpose: keeps each group's own mean/std dtype in the arithmetic.
    for i in range(bsz):
        key = index[i]
        scores[i] = (scores[i] - id2mean[key]) / (id2std[key] + epsilon)
    return scores, id2std


