import abc
from typing import Dict, List, Optional
import numpy as np
import copy
import pdb
import torch
from utils import print_with_rank
from transformers import PreTrainedTokenizer
from reason.inference.lm_call import LMCallingConfig, ConcatedLMGenResult
from metrics.recorder import LatencyRecorder

# Sentinel answer string. Not referenced by the code visible in this module;
# presumably used by answer extraction/grading elsewhere — TODO confirm.
INVALID_ANS = "[invalid]"


class NoLegalActionException(Exception):
    """Signals that sampling produced no usable candidate action.

    Raised by ``CoTEnv.update_legal_actions`` when every generated sample was
    filtered out, and caught by ``CoTEnv.step`` to drive its retry loop.
    """


class ResetException(Exception):
    """Environment-reset failure marker.

    NOTE(review): nothing in this module raises or catches it; the exact
    trigger is defined by callers elsewhere.
    """


class BaseEnv(abc.ABC):
    """Abstract environment interface used by the search code (e.g. MCTS).

    Concrete environments must implement ``reset``, ``step``, ``legal_actions``
    and ``copy``. ``build_query_str`` is a shared helper that assembles the
    initial chat prompt.
    """

    @abc.abstractmethod
    def reset(self, update_legal_action: bool):
        """Put the environment back into its initial state."""
        raise NotImplementedError

    @abc.abstractmethod
    def step(self, action, update_legal_action=True):
        """Apply ``action`` to the environment."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def legal_actions(self):
        """Actions currently available from this state."""
        raise NotImplementedError

    @abc.abstractmethod
    def copy(self):
        """Return an independent copy of the environment."""
        raise NotImplementedError

    @staticmethod
    def build_query_str(
        cot_task_desc: Optional[str],
        cot_examples: Optional[str],
        problem_format_str: str,
        problem_input: str,
        is_few_shot: bool = False,
        model_names=None,
    ):
        """Wrap a problem into a list of chat messages.

        Despite the name, the return value is a list of
        ``{"role": ..., "content": ...}`` dicts, not a single string.

        Args:
            cot_task_desc: Task description. Becomes the system message, or is
                appended to the user message for DeepSeek-R1 models.
            cot_examples: Few-shot examples. Only read when ``is_few_shot`` is
                true; it is iterated and each item is indexed with
                ``"question"`` and ``"answer"``.
                NOTE(review): the ``Optional[str]`` annotation does not match
                that usage — confirm the real type with callers.
            problem_format_str: Template formatted with ``question=problem_input``.
            problem_input: The raw problem text.
            is_few_shot: Whether to emit the few-shot example turns.
            model_names: Model names; only the first entry is inspected (case
                insensitively) for the substring ``"deepseek-r1"``. ``None`` or
                an empty sequence selects the generic layout.

        Returns:
            The message list, ending with the user turn that holds the problem.

        >>> BaseEnv.build_query_str("Solve it.", None, "Q: {question}", "1+1")
        [{'role': 'system', 'content': 'Solve it.'}, {'role': 'user', 'content': 'Q: 1+1'}]
        """
        messages = []
        problem_text = problem_format_str.format(question=problem_input)

        # An empty/missing model list must not crash; fall back to the generic layout.
        primary_model = model_names[0].lower() if model_names else ""

        if 'deepseek-r1' in primary_model:
            # This layout uses no system prompt: the task description is folded
            # into the single user turn.
            if cot_task_desc is None:
                content = problem_text
            else:
                content = problem_text + '\n' + cot_task_desc
            messages.append({"role": "user", "content": content})
            return messages

        if cot_task_desc:
            messages.append({"role": "system", "content": cot_task_desc})
        if is_few_shot:
            for example in cot_examples or ():
                messages.append({"role": "user", "content": example["question"]})
                messages.append({"role": "assistant", "content": example["answer"]})
        messages.append({"role": "user", "content": problem_text})

        return messages

    @staticmethod
    def build_response_str(answer_str: str, tokenizer: PreTrainedTokenizer, add_eos_token: bool):
        """Format a response string; not implemented in the base class."""
        raise NotImplementedError


class CoTEnv(BaseEnv):
    """The basic environment for solving natural language problems using CoT"""

    def _is_correct(self, completion) -> bool:
        """Judge whether ``completion`` is a correct answer.

        Not implemented in this class. ``get_done_and_info`` calls it with the
        most recent action once an episode is terminated or truncated.
        """
        raise NotImplementedError

    def get_reward(self):
        """Return the reward for the current state.

        Not implemented in this class; meant to be backed by a learned reward
        model. ``step`` calls it after recording each action.
        """
        raise NotImplementedError

    def __init__(
        self,
        config,
        math_problems,
        llm_gen_fns,
        rm_call,
        task_desc_str: str,
        cot_example_str: str,
        problem_format_str: str,
        reset=True,
        sep=None,
        model_names=None,
        update_legal_action=True,
        latency_recorder: Optional[LatencyRecorder] = None,
    ):
        """Set up the environment and optionally reset it.

        Args:
            config: Mapping supporting ``.get`` and item access. Keys read here:
                ``is_few_shot``, ``add_step_prompt``, ``direct_io``,
                ``double_line_break``, ``selected_model_idx``, ``cot_prompt``.
            math_problems: Indexable collection of problems; ``reset`` uses
                index 0.
            llm_gen_fns: Sequence of LM callables used to sample actions.
            rm_call: Reward-model handle; only stored here.
            task_desc_str: Task description, overridden by a non-empty
                ``config["cot_prompt"]``.
            cot_example_str: Few-shot examples.
            problem_format_str: Template with a ``{question}`` placeholder.
            reset: When true, call ``reset`` at the end of construction.
            sep: Step separator (a string or list of strings) or ``None``.
            model_names: Names of the models; ``None`` means an empty list.
            update_legal_action: Forwarded to ``reset``.
            latency_recorder: Recorder to use; a fresh ``LatencyRecorder`` is
                created when ``None``.
        """
        self.config = config
        self.mcts_mode = "play_with_bot_mode"
        self.math_problems = math_problems
        self.llm_gen_fns = llm_gen_fns
        self.rm_call = rm_call

        # Per-episode state; populated by reset().
        self.action_history = None
        self.reward_history = None
        self.token_history = None
        self.prob_history = None
        self.model_history = None
        self.math_problem = None
        self._legal_actions = None
        # Initialised here so copy() works on an environment that was never reset.
        self._next_state_terminated = {}
        self.question_start_time = None
        self.question_latency = 0.0

        self.is_few_shot = config.get("is_few_shot", False)
        self.add_step_prompt = config.get("add_step_prompt", False)
        self.direct_io = config.get("direct_io", 0)
        self.double_line_break = config.get("double_line_break", 0)
        # self.prm_step_tag = rm_call.prm_step_tag  # "ки\n"
        self.prm_step_tag = "ки\n"
        self.sep = sep
        # A fresh list per instance: a shared default list would leak
        # mutations from one environment into all the others.
        self.model_names = [] if model_names is None else model_names

        self.selected_model_idx = config.get("selected_model_idx", 0)
        # Explicit None check so a recorder that happens to be falsy is kept.
        if latency_recorder is None:
            latency_recorder = LatencyRecorder()
        self.latency_recorder = latency_recorder

        if config.get("cot_prompt", ""):
            task_desc_str = config["cot_prompt"]
        self._task_desc_str = task_desc_str
        self._cot_example_str = cot_example_str
        self._problem_format_str = problem_format_str

        # NOTE(review): when is_few_shot is set, _cot_example_str must be a str
        # for the join below, while build_query_str iterates it as a sequence
        # of dicts — confirm which shape callers really pass.
        prefixes = []
        if self._task_desc_str is not None:
            prefixes.append(self._task_desc_str)
        if self.is_few_shot:
            prefixes.append(self._cot_example_str)
        self.task_prefix = "\n".join(prefixes) if prefixes else None

        if reset:
            self.reset(update_legal_action=update_legal_action)

    def reset(self, update_legal_action=True):
        """Restart the episode on the first problem.

        Clears every history list, resets the latency recorder for the
        problem's ``id``, rebuilds the initial query and, if requested, samples
        the first set of legal actions.

        Args:
            update_legal_action: When false, no LM call is made.

        Returns:
            ``(state, info)`` where ``state`` is the raw-text state and
            ``info`` holds ``"api_completion_token"``.
        """
        # The environment always starts from problem index 0.
        self.set_problem(idx=0)

        self.action_history = []
        self.reward_history = []
        self.token_history = []

        recorder = self.latency_recorder
        recorder.reset(question_id=self.math_problem.get("id"))
        self.question_start_time = recorder.start_question()
        self.question_latency = 0.0  # running latency total for this question

        self.prob_history = []
        self.model_history = []

        self._init_query = self.build_query_str(
            cot_examples=self._cot_example_str,
            cot_task_desc=self._task_desc_str,
            problem_format_str=self._problem_format_str,
            problem_input=self.math_problem["question"],
            is_few_shot=self.is_few_shot,
            model_names=self.model_names,
        )

        if not update_legal_action:
            api_completion_token = 0
        else:
            try:
                self._legal_actions, api_completion_token = self.update_legal_actions(initial=True)
            except Exception:
                # One regular attempt only; on any failure retry once with
                # filtering relaxed and let a second failure propagate.
                self._legal_actions, api_completion_token = self.update_legal_actions(
                    initial=True, force_update=True
                )
                print("Force update legal actions:", self._legal_actions)

        return self.get_state(model_name='raw'), {"api_completion_token": api_completion_token}

    def step(self, action, update_legal_action=True, model_name="", custom_n=0, reward=0.0, num_token=0, prob=0.0):
        """Apply ``action`` and advance the episode by one step.

        The ``reward``, ``num_token`` and ``prob`` arguments are only recorded
        in the history lists; the returned reward comes from ``get_reward``
        (or is forced to 1.0 / 0 as described below).

        Args:
            action: Text of the chosen step.
            update_legal_action: When true and the episode continues, sample
                the next legal actions.
            model_name: Recorded in ``model_history`` when non-empty, and
                forwarded to ``get_state``.
            custom_n: Forwarded to ``update_legal_actions``.
            reward: Value appended to ``reward_history``.
            num_token: Value appended to ``token_history``.
            prob: Value appended to ``prob_history``.

        Returns:
            ``(state, reward, terminated, truncated, info)``.
        """
        self.action_history.append(action)
        self.reward_history.append(reward)
        self.token_history.append(num_token)
        self.prob_history.append(prob)
        if model_name:
            self.model_history.append(model_name)

        state = self.get_state(model_name=model_name)
        reward = self.get_reward()
        terminated, truncated, info = self.get_done_and_info()

        if terminated or truncated or not update_legal_action:
            # Nothing left to expand from here.
            self._legal_actions = None
            if info["winner"] == 1:
                reward = 1.0
            info["api_completion_token"] = 0
            return state, reward, terminated, truncated, info

        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                self._legal_actions, api_completion_token = self.update_legal_actions(custom_n=custom_n)
            except NoLegalActionException:
                if attempt == max_attempts:
                    # Out of retries: end the episode as a loss.
                    terminated = True
                    reward = 0
                    self._legal_actions = None
                    info["winner"] = 2
                    info["api_completion_token"] = 0
                continue
            info["api_completion_token"] = api_completion_token
            break

        return state, reward, terminated, truncated, info

    def get_state(self, model_name='other', add_step_prompt=False):
        """Build the current conversation state.

        The state is the initial query followed by one assistant message that
        concatenates every action taken so far.

        Args:
            model_name: ``'raw'`` selects a plain-text rendering; any other
                value returns the message list.
            add_step_prompt: When true (and ``direct_io != 2`` and
                ``double_line_break`` is falsy) a ``"Step N: "`` header for the
                next step is appended to the assistant message. ``"## Step"``
                is used when the first model name contains ``"llama-3"``.

        Returns:
            A list of message dicts, or for ``'raw'`` a string with one
            ``"role: content"`` entry per message joined by newlines.
        """
        # Deep copy so the cached initial query is never mutated.
        messages = copy.deepcopy(self._init_query)
        messages.append({"role": "assistant", "content": "".join(self.action_history)})

        if add_step_prompt and self.direct_io != 2 and not self.double_line_break:  # TODO: Check double
            # Tolerate an empty model list (the constructor default) instead of
            # failing with IndexError.
            primary_model = self.model_names[0].lower() if self.model_names else ""
            step_label = "## Step" if 'llama-3' in primary_model else "Step"  # TODO: Check llama
            messages[-1]["content"] += f"{step_label} {len(self.action_history) + 1}: "

        if model_name == 'raw':
            return "\n".join(f'{mess["role"]}: {mess["content"]}' for mess in messages)
        return messages
    
    def get_prefix_key(self):
        """Return the action history as a tuple, usable as a hashable key for the current prefix."""
        return tuple(self.action_history)
    
    @staticmethod
    def batch_get_states(envs, model_name='other', add_step_prompt=False):
        """Batch get environment states to avoid redundant operations"""
        return [env.get_state(model_name, add_step_prompt) for env in envs]

    def post_process_act(self, action: str):
        """Hook for normalising a sampled action; this implementation returns it unchanged.

        ``update_legal_actions`` passes every generated text through this
        method before filtering.
        """
        # This step may change the token count
        return action

    def update_legal_actions(self, initial=False, force_update=False, custom_n=0):
        """Sample candidate next actions from the selected language model.

        Calls the LM on the current state, filters the samples and returns
        them as action records. As side effects the LM latency is recorded on
        ``self.latency_recorder`` and ``self._next_state_terminated`` is
        replaced with a mapping from each kept action to its terminal flag.

        A sample is kept when it is non-empty and either
        (a) its effective finish reason is ``"stop"`` and it is not a duplicate
        of an already kept action, or (b) ``force_update`` is true.

        Args:
            initial: Sample ``config["max_actions"]`` candidates.
            force_update: Keep every non-empty sample regardless of finish
                reason or duplication.
            custom_n: Explicit sample count, used when ``initial`` is false
                and the value is non-zero. Otherwise
                ``max_actions // beam_size`` is used.

        Returns:
            ``(legal_actions, completion_tokens)``. Each legal action is a dict
            with keys ``action``, ``prob``, ``num_token``, ``finish_reason``,
            ``model_name``, ``model_id``, ``messages``, ``stop_str`` and
            ``raw_action``. ``finish_reason`` is the LM-reported reason for
            that same sample.

        Raises:
            ValueError: If the LM result's ``finish_reason`` is not a list.
            NoLegalActionException: If no sample survives filtering.
        """
        # Local imports: neither module is imported at file level.
        import re
        import time

        # A single-model setup always uses index 0.
        model_idx = 0 if len(self.llm_gen_fns) == 1 else self.selected_model_idx
        lm = self.llm_gen_fns[model_idx]

        if initial:
            n = self.config["max_actions"]
        elif custom_n:  # Allow dynamic adjustment of sampling number for advanced search algorithms
            n = custom_n
        else:
            n = self.config["max_actions"] // self.config["beam_size"]

        # In direct-IO mode the whole answer is generated at once, so no stop string is used.
        stop_str = None if self.direct_io else self.sep
        include_stop_str_in_output = not self.direct_io

        first_generation = len(self.action_history) == 0
        messages = self.get_state(lm.model_name, add_step_prompt=self.add_step_prompt)

        lm_start_time = time.time()
        result: ConcatedLMGenResult = lm(
            messages=messages,
            config=LMCallingConfig(
                n=n,
                stop_str=stop_str,
                include_stop_str_in_output=include_stop_str_in_output,
                first_generation=first_generation,
                **self.config["generation_config"],
            ),
        )
        self.latency_recorder.record_lm_latency(time.time() - lm_start_time)

        texts = result.text
        logps_avg_by_len = result.logp_avg_by_len
        token_len = result.num_tokens
        finish_reason_list = result.finish_reason
        if not isinstance(finish_reason_list, list):
            raise ValueError("finish_reason should be a list")
        completion_tokens = getattr(result, "completion_tokens", 0)

        # A step that ends with a separator is non-terminal; anything else
        # (including direct-IO mode or no separator) is terminal.
        if self.direct_io:
            sep_suffixes = None
        elif isinstance(self.sep, str):
            sep_suffixes = self.sep
        elif isinstance(self.sep, list):
            sep_suffixes = tuple(self.sep)
        else:
            sep_suffixes = None

        # Matches a step that is only a header, e.g. "Step 3: \n\n". A regex is
        # used because a pure length comparison also rejects real one-character
        # steps such as "Step 1: 5\n\n" and misses step numbers above 99.
        empty_step = re.compile(r"Step \d+: \n\n")

        text_list, raw_text_list = [], []
        logp_list, num_token_list = [], []
        kept_finish_reasons = []
        next_state_terminated = {}

        for i, raw_text in enumerate(texts):
            terminated = True if sep_suffixes is None else not raw_text.endswith(sep_suffixes)
            processed_act = self.post_process_act(raw_text)

            finish_reason = finish_reason_list[i]
            if not self.double_line_break:
                temp_act = processed_act.replace("## Step ", "Step ")
                if empty_step.fullmatch(temp_act):
                    # Treat a header-only step as an unfinished generation.
                    finish_reason = "length"

            if not processed_act:
                continue
            # next_state_terminated holds exactly the actions kept so far, so
            # it doubles as the duplicate index.
            is_new_complete = finish_reason == "stop" and processed_act not in next_state_terminated
            if not (is_new_complete or force_update):
                continue

            text_list.append(processed_act)
            raw_text_list.append(raw_text)
            logp_list.append(logps_avg_by_len[i])
            num_token_list.append(token_len[i])
            # Record the reason of THIS sample; zipping with the unfiltered
            # list would misalign reasons once any sample is dropped.
            kept_finish_reasons.append(finish_reason_list[i])
            next_state_terminated[processed_act] = terminated

        if not text_list:
            print_with_rank("state: {}".format(self.get_state(model_name='raw')))
            if len(self.llm_gen_fns) == 1:
                print_with_rank("gen_result: {}".format(result))
            raise NoLegalActionException("No possible action have been generated.")

        # Length-averaged log-probabilities -> probabilities.
        prob_list = list(np.exp(logp_list))

        _legal_actions = [{
            "action": action,
            "prob": prob,
            "num_token": n_token,
            "finish_reason": finish_reason,
            "model_name": lm.model_name,
            "model_id": model_idx,
            "messages": messages,
            "stop_str": stop_str,
            "raw_action": raw_action,
        } for action, prob, n_token, finish_reason, raw_action in zip(
            text_list, prob_list, num_token_list, kept_finish_reasons, raw_text_list)]

        self._next_state_terminated = next_state_terminated

        return _legal_actions, completion_tokens

    def set_problem(self, idx):
        """Select ``self.math_problems[idx]`` as the current problem."""
        self.math_problem = self.math_problems[idx]

    @property
    def query(self):
        """The initial query built by ``reset`` (the value of ``self._init_query``)."""
        return self._init_query

    @property
    def question(self) -> str:
        """The ``"question"`` field of the current problem."""
        return self.math_problem["question"]

    @property
    def answer(self):
        """The answer accumulated so far, formatted according to ``direct_io``.

        * no actions yet: the empty string;
        * ``direct_io == 2``: the single action, unchanged;
        * ``direct_io == 1``: the single action split on blank lines, with
          every non-blank chunk stripped and followed by a space and
          ``prm_step_tag``;
        * otherwise: every action stripped and followed by a space and
          ``prm_step_tag``.
        """
        history = self.action_history
        if len(history) == 0:
            return ""

        if self.direct_io == 2:
            assert len(history) == 1
            return history[0]

        if self.direct_io == 1:
            assert len(history) == 1
            chunks = history[0].split("\n\n")
            return "".join(
                chunk.strip() + f" {self.prm_step_tag}"
                for chunk in chunks
                if chunk.strip() != ""
            )

        return "".join(act.strip() + f" {self.prm_step_tag}" for act in history)

    def check_stop_by_answer(self):
        """Tell whether the last action contains the configured stop marker(s).

        ``self._stop_str`` may be a single string (must occur in the last
        action) or a list of strings (every one must occur).

        Returns:
            ``True`` when the stop condition is met. ``False`` otherwise,
            including when there is no action yet or ``_stop_str`` is neither
            a string nor a list.
        """
        if not self.action_history:
            return False
        last_action = self.action_history[-1]
        stop_str = self._stop_str

        if isinstance(stop_str, str):
            # Always return a boolean; falling through on a miss would leave
            # the result variable unbound.
            return stop_str in last_action
        if isinstance(stop_str, list):
            return all(marker in last_action for marker in stop_str)
        return False

    def check_stop_by_sep(self):
        """Tell whether the last action lacks every step separator.

        A step that contains a separator is expected to be followed by another
        step; one without any separator ends the episode.

        Returns:
            For a string ``sep``: ``True`` when it does not occur in the last
            action. For a list ``sep``: ``True`` when none of its entries
            occurs. ``False`` when ``sep`` is neither (e.g. ``None``).
        """
        last_action = self.action_history[-1]
        if isinstance(self.sep, str):
            return self.sep not in last_action
        if isinstance(self.sep, list):
            # Mirror the single-separator rule: stop only if no separator is present.
            return not any(sep in last_action for sep in self.sep)
        return False

    def get_done_and_info(self):
        """Work out whether the episode has ended after the last action.

        Termination is decided, in order, by: the stop marker
        (``check_stop_by_answer``, only when a ``_stop_str`` is configured),
        the terminal flag stored for the last action by
        ``update_legal_actions``, and finally ``check_stop_by_sep``.
        Truncation applies when ``config["max_length"] > 1`` and the number
        of actions has reached it.

        Returns:
            ``(terminated, truncated, info)``. ``info["winner"]`` is 0 while
            the episode continues, 1 when it ended and ``_is_correct`` accepts
            the last action, 2 when it ended otherwise.

        Raises:
            KeyError: If the last action is not a key of
                ``self._next_state_terminated``.
        """
        info = {"winner": 0}
        last_action = self.action_history[-1]

        # _stop_str is not assigned anywhere in this class; treat a missing
        # attribute like "no stop marker configured".
        stop_str = getattr(self, "_stop_str", None)

        # done when reaches maximum length or LLM generates stop words
        if stop_str is not None and self.check_stop_by_answer():
            terminated = True
        elif self._next_state_terminated[last_action]:
            terminated = True
        else:
            terminated = self.check_stop_by_sep()

        max_length = self.config["max_length"]
        if max_length > 1:
            truncated = len(self.action_history) >= max_length
            # Invariant: stepping past max_length indicates a caller bug.
            assert len(self.action_history) <= max_length
        else:
            truncated = False

        if terminated or truncated:
            info["winner"] = 1 if self._is_correct(last_action) else 2
        return terminated, truncated, info

    def copy(self):
        """Optimized environment copy: use shallow copy + selective deep copy to improve performance"""
        env = self.__class__(
            self.config,
            self.math_problems,
            self.llm_gen_fns,
            self.rm_call,
            self._task_desc_str,
            self._cot_example_str,
            self._problem_format_str,
            reset=False,
        )
        # For objects that won't be modified, use shallow copy
        env.math_problem = self.math_problem  
        env._init_query = self._init_query  
        
        # For lists that will be modified, use list() for shallow copy (faster than deepcopy)
        env.action_history = list(self.action_history)
        env.reward_history = list(self.reward_history)
        env.token_history = list(self.token_history)
        env.prob_history = list(self.prob_history)
        env.model_history = list(self.model_history)
        env.latency_recorder = self.latency_recorder.clone()
        
        # For dictionaries, use shallow copy
        env._legal_actions = copy.copy(self._legal_actions) if self._legal_actions else None
        env._next_state_terminated = dict(self._next_state_terminated) if hasattr(self, '_next_state_terminated') else {}
        
        # Simple values are directly assigned
        env.question_start_time = self.question_start_time
        env.question_latency = self.latency_recorder.question_latency
        return env

    @property
    def legal_actions(self):
        """The action records from the latest ``update_legal_actions`` call, or ``None`` once cleared by ``step``."""
        return self._legal_actions

    @property
    def step_latency_history(self) -> List[float]:
        """Pass-through to ``self.latency_recorder.step_latency_history``."""
        return self.latency_recorder.step_latency_history

    @property
    def step_lm_latency_history(self) -> List[float]:
        """Pass-through to ``self.latency_recorder.step_lm_latency_history``."""
        return self.latency_recorder.step_lm_latency_history

    @property
    def step_rm_latency_history(self) -> List[float]:
        """Pass-through to ``self.latency_recorder.step_rm_latency_history``."""
        return self.latency_recorder.step_rm_latency_history

    @property
    def step_wait_history(self) -> List[float]:
        """Pass-through to ``self.latency_recorder.step_wait_history``."""
        return self.latency_recorder.step_wait_history
