from abc import ABC, abstractmethod
import re
from typing import Optional, List, Tuple, Any, Dict

from verl.environments.prompts import (
    SYSTEM_PROMPT_FOR_ACTION,
    USER_PROMPT_FOR_ACTION,
    USER_GuideLines_NO_THINK,
    USER_GuideLines_LESS_THINK,
    USER_GuideLines_THINK,
    SYSTEM_PROMPT_FOR_DEEPTHINK,
    USER_PROMPT_FOR_DEEPTHINK,
    USER_GuideLines_THINK_OLD
)

class BaseEnv(ABC):
    """Abstract base for text environments driven by LLM responses.

    Holds per-step and per-trajectory bookkeeping, the actor / deep-thinker
    prompt templates, and the parsers that turn an LLM response into
    ``(format_score, think, think_is_valid, deepthink_is_valid, action,
    action_is_valid)``. Subclasses must implement ``step``. ``self.env`` is
    None here and must be set by the subclass; ``false_step`` also reads
    ``self.max_step``, which this class never sets (presumably the subclass
    provides it).
    """

    # Compiled once. DOTALL so tag contents may span several lines
    # (e.g. "<answer>\ngo north\n</answer>").
    _THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
    _DEEPTHINK_RE = re.compile(r"<deepthink>(.*?)</deepthink>", re.DOTALL)
    _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    def __init__(self, special_settings=None):
        """Initialise episode state and build the prompt templates.

        Args:
            special_settings: optional mapping (treated as empty when None)
                with keys ``use_old_output_format`` (default False),
                ``no_short_thought`` (default False) and ``thinker_freq``
                (default -1).
        """
        if special_settings is None:
            special_settings = {}

        self.env = None

        # Per-step env state.
        self.current_step = 0
        self.score = 0          # outcome reward
        self.reward = 0         # single-step reward
        self.gameDone = False   # episode ended (task completed or failed)
        self.done = False       # task completed
        self.over = False       # task failed

        # Per-trajectory records; one entry is appended per executed step.
        self._responses = []       # raw LLM responses
        self._thinks = []          # parsed thoughts (None when malformed)
        self._actions = []         # parsed actions (None when malformed)
        self._observations = []    # observations returned to the agent
        self._rewards = []         # per-step rewards
        self._format_scores = []   # per-step format scores

        # ---- Multi-agent RL settings ----
        # Responses formatted as "Thought: ... Action: ..." (extract_action_III).
        self.use_old_output_format = special_settings.get("use_old_output_format", False)
        # Answer-only responses "<answer> ... </answer>" (extract_action_I).
        self.no_short_thought = special_settings.get("no_short_thought", False)
        # When > 0, execute_pred_for_acitve_env sets "expect_deepthink" every
        # `thinker_freq` steps.
        self.thinker_freq = special_settings.get("thinker_freq", -1)

        # ---- Prompts for the actor ----
        self.system_prompt = SYSTEM_PROMPT_FOR_ACTION
        if self.no_short_thought:
            guidelines = USER_GuideLines_NO_THINK
        elif self.use_old_output_format:
            guidelines = USER_GuideLines_THINK_OLD
        else:
            guidelines = USER_GuideLines_THINK
        self.user_prompt = USER_PROMPT_FOR_ACTION.replace("<guidelines>", guidelines)

        # ---- Prompts for the deep thinker ----
        self.system_prompt_for_deepthink = SYSTEM_PROMPT_FOR_DEEPTHINK
        self.user_prompt_for_deepthink = USER_PROMPT_FOR_DEEPTHINK

    def success(self):
        """Return True if the task was completed."""
        return self.done

    def fail(self):
        """Return True if the task failed."""
        return self.over

    @abstractmethod
    def step(self, action):
        """Apply ``action`` to the environment.

        ``execute_pred_for_acitve_env`` unpacks the result as
        ``(observation, reward, score, done)``.
        """

    def false_step(self):
        """Advance the step counter without acting; end the episode at ``max_step``."""
        self.current_step += 1
        # NOTE(review): max_step is not set in this class; assumed to come from the subclass.
        if self.current_step >= self.max_step:
            self.gameDone = True

    def reset(self, seed=None):
        """Reset the wrapped env and clear all episode and trajectory state.

        ``seed`` is forwarded to ``self.env.reset`` whenever it is not None,
        so ``seed=0`` is honoured. ``self.env.reset`` must return a 2-tuple.
        """
        if seed is not None:
            _obs, _info = self.env.reset(seed=seed)
        else:
            _obs, _info = self.env.reset()

        self.current_step = 0
        self.score = 0
        self.reward = 0
        self.gameDone = False  # episode ended (task completed or failed)
        self.done = False      # task completed
        self.over = False      # task failed

        # Trajectory records (including observations, which were previously
        # left over from the last episode).
        self._responses = []
        self._actions = []
        self._thinks = []
        self._observations = []
        self._rewards = []
        self._format_scores = []

    def _update_tracking_variables(
            self,
            response: str,
            think: str,
            think_is_valid: bool,
            deepthink_is_valid: bool,
            action: str,
            action_is_valid: bool,
            observation: str,
            reward: float,
            format_score: float,
        ):
        """Append one step's data to the trajectory records.

        Every trajectory list grows by exactly one entry per call, so they stay
        aligned by step index. ``think`` is stored only when either
        ``think_is_valid`` or ``deepthink_is_valid`` is True, and ``action``
        only when ``action_is_valid`` is True; otherwise None is stored.
        """
        self._responses.append(response)
        self._thinks.append(think if (think_is_valid or deepthink_is_valid) else None)
        self._actions.append(action if action_is_valid else None)
        self._observations.append(observation)
        self._rewards.append(reward)
        self._format_scores.append(format_score)

    @classmethod
    def _extract_answer(cls, prediction: str) -> Tuple[str, bool]:
        """Return (stripped content of the first <answer>...</answer>, whether it matched)."""
        match = cls._ANSWER_RE.search(prediction)
        if match is None:
            return "", False
        return match.group(1).strip(), True

    @classmethod
    def extract_action_III(cls, prediction: str) -> Tuple[float, str, bool, bool, str, bool]:
        """Parse a response of the form::

            Thought: [put your thought here]
            Action: [put your action here]

        Returns ``(format_score, think, think_is_valid, deepthink_is_valid,
        action, action_is_valid)``. ``deepthink_is_valid`` is always False.
        Each missing/empty part costs 0.5 of format score.
        """
        deepthink_is_valid = False

        # Thought: text after the last "Thought:" (or from the start when the
        # marker is absent) up to the next "Action:".
        think = prediction.split("Thought:")[-1].split("Action:")[0].strip()
        think_is_valid = think != ""
        think_format_score = 0.0 if think_is_valid else -1.0

        # Action: text after the last "Action:". The marker is required;
        # without it split() would return the whole response and it would be
        # wrongly accepted as a valid action.
        if "Action:" in prediction:
            action = prediction.split("Action:")[-1].strip()
        else:
            action = ""
        action_is_valid = action != ""
        action_format_score = 0.0 if action_is_valid else -1.0

        other_format_score = 0.0
        format_score = (think_format_score + action_format_score + other_format_score) * 0.5

        return format_score, think, think_is_valid, deepthink_is_valid, action, action_is_valid

    @classmethod
    def extract_action_II(cls, prediction: str) -> Tuple[float, str, bool, bool, str, bool]:
        """Parse "<think> ... </think> <answer> ... </answer>" or
        "<deepthink> ... </deepthink> <answer> ... </answer>".

        A <deepthink> block takes priority over <think>. Missing thought or
        answer tags cost 0.5 each. Text before an opening think/deepthink
        tag and text after </answer> cost a further 0.5 each.

        Returns ``(format_score, think, think_is_valid, deepthink_is_valid,
        action, action_is_valid)``.
        """
        think_is_valid, deepthink_is_valid = False, False

        deep_think_match = cls._DEEPTHINK_RE.search(prediction)
        think_match = cls._THINK_RE.search(prediction)

        if deep_think_match:  # deep think takes priority
            think = deep_think_match.group(1).strip()
            deepthink_is_valid = True
            think_format_score = 0.0
        elif think_match:
            think = think_match.group(1).strip()
            think_is_valid = True
            think_format_score = 0.0
        else:
            think = ""
            think_format_score = -1.0

        action, action_is_valid = cls._extract_answer(prediction)
        action_format_score = 0.0 if action_is_valid else -1.0

        # Having both <think> and <deepthink> is currently not penalised.
        other_format_score = 0.0
        # Prefix: any text before the opening thought tag.
        if think_match and prediction.split("<think>")[0] != "":
            other_format_score -= 1.0
        if deep_think_match and prediction.split("<deepthink>")[0] != "":
            other_format_score -= 1.0
        # Suffix: any text (including whitespace) after </answer>.
        if action_is_valid and prediction.split("</answer>")[-1] != "":
            other_format_score -= 1.0

        format_score = (think_format_score + action_format_score + other_format_score) * 0.5

        return format_score, think, think_is_valid, deepthink_is_valid, action, action_is_valid

    @classmethod
    def extract_action_I(cls, prediction: str) -> Tuple[float, str, bool, bool, str, bool]:
        """Parse an answer-only response "<answer> [Your Action] </answer>".

        A missing answer tag costs 0.5. Text before <answer> costs 0.25 and
        text after </answer> costs 0.5. ``think`` is always "" and both think
        validity flags are always False.

        Returns ``(format_score, think, think_is_valid, deepthink_is_valid,
        action, action_is_valid)``.
        """
        action, action_is_valid = cls._extract_answer(prediction)
        action_format_score = 0.0 if action_is_valid else -1.0

        other_format_score = 0.0
        # Prefix: text before the opening tag (lighter penalty).
        if action_is_valid and prediction.split("<answer>")[0] != "":
            other_format_score -= 0.5
        # Suffix: text after the closing tag.
        if action_is_valid and prediction.split("</answer>")[-1] != "":
            other_format_score -= 1.0

        format_score = (action_format_score + other_format_score) * 0.5

        think = ""
        return format_score, think, False, False, action, action_is_valid

    @classmethod
    def execute_pred_for_acitve_env(cls, env, response: str) -> Dict[str, Any]:
        """Parse ``response``, step ``env`` with the parsed action, and record the step.

        Only envs that are still active get a response; inactive envs have no
        prediction to execute. The parser is chosen from ``env``'s settings.
        If no valid action is parsed, ``env.step`` is not called: the
        observation is "Format Error!" and the reward is 0. The step is always
        appended to ``env``'s trajectory records.

        Returns:
            dict with the parsed fields, observation, reward, env status flags
            (gameDone/done/over/score), the single-step format score and
            ``expect_deepthink`` (True every ``env.thinker_freq`` steps after a
            valid action, when ``thinker_freq`` > 0).
        """
        # NOTE(review): if both flags are set, the prompt built in __init__
        # uses the no-think guidelines but the old-format parser wins here.
        if env.use_old_output_format:
            parser = cls.extract_action_III
        elif env.no_short_thought:
            parser = cls.extract_action_I
        else:
            parser = cls.extract_action_II
        format_score, think, think_is_valid, deepthink_is_valid, action, action_is_valid = parser(response)

        if action_is_valid:
            observation, reward, _score, _done = env.step(action)
        else:
            observation, reward = "Format Error!", 0

        if env.use_old_output_format:
            observation = "Observation: " + observation

        env._update_tracking_variables(
            response=response,
            think=think,
            think_is_valid=think_is_valid,
            deepthink_is_valid=deepthink_is_valid,
            action=action,
            action_is_valid=action_is_valid,
            observation=observation,
            reward=reward,
            format_score=format_score,
        )
        return {
            "response": response,
            "think_is_valid": think_is_valid,
            "deepthink_is_valid": deepthink_is_valid,
            "action_is_valid": action_is_valid,
            "think": think,
            "action": action,
            "observation": observation,
            "reward": reward,
            "gameDone": env.gameDone,
            "done": env.done,
            "over": env.over,
            "score": env.score,
            "format_score": format_score,  # single-step format reward
            # Periodic trigger for the deep thinker.
            "expect_deepthink": (env.thinker_freq > 0 and action_is_valid and env.current_step > 0 and env.current_step % env.thinker_freq == 0),
        }