import asyncio
import concurrent.futures
import logging
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
import os
import json

import numpy as np
import openai
import torch
from openai.types import Completion

from rllm.agents.agent import Action, BaseAgent, Trajectory
from rllm.agents.utils import (
    convert_messages_to_tokens_and_masks,
    get_recent_assistant_user_messages,
)
from rllm.environments.base.base_env import BaseEnv
from rllm.environments.env_utils import (
    compute_mc_return,
    compute_trajectory_reward,
)
from rllm.misc import colorful_print
from rllm.parser.chat_template.parser import ChatTemplateParser
from rllm.router.router import Router
from rllm.api_utils import azure_openai_credential, SPECIAL_MODEL_LIST

logger = logging.getLogger(__name__)

class ReplayEngine:
    """Replay recorded ("gold") trajectories through agents and environments.

    Two modes are supported, selected by ``mode``:

    * ``"sft"``: the gold trace stored on the task is converted into a list of
      (observation, response) steps and fed straight into the agent; the
      environment is only reset and closed, never stepped. The resulting
      trajectory is given reward 1.0.
    * ``"interactive"``: the assistant turns of the gold trace are replayed as
      actions against the live environment through ``env.step``.

    Environment calls (reset/step/close) are dispatched to a thread pool so
    many trajectories can progress concurrently on one event loop.
    """

    def __init__(
        self,
        tokenizer=None,
        chat_parser=None,
        max_steps=5,
        max_response_length=8192,
        max_prompt_length=1024,
        n_parallel_agents=8,
        config=None,
        agent_class=None,
        env_class=None,
        agent_args=None,
        env_args=None,
        max_workers=64,
        enforce_max_prompt_length=False,  # If enabled, applies max_prompt check per step
        mode="sft",
        **kwargs,
    ):
        """Configure the engine.

        Args:
            tokenizer: Tokenizer used to measure prompt/response token lengths.
            chat_parser: Chat template parser; built from ``tokenizer`` when None.
            max_steps: Maximum number of agent/environment steps per trajectory.
            max_response_length: Token budget for the accumulated response.
            max_prompt_length: Maximum prompt length (tokens) before truncating.
            n_parallel_agents: Number of agent/environment slots run concurrently.
            config: Opaque configuration object, stored as-is.
            agent_class: Agent class (or factory exposing ``create_agent``).
            env_class: Environment class; must report ``is_multithread_safe()``.
            agent_args: Extra keyword arguments for agent construction.
            env_args: Extra keys merged into each task for ``env_class.from_dict``.
            max_workers: Thread-pool size for environment calls.
            enforce_max_prompt_length: If True, check the prompt limit per step
                instead of budgeting the whole response.
            mode: ``"sft"`` or ``"interactive"`` (see class docstring).
            **kwargs: Recognised keys: ``chat_mode`` (bool), ``disable_thinking``
                (bool, forwarded to the parser built here) and
                ``trajectory_timeout`` (seconds allowed per interactive
                trajectory; defaults to ``int(1e9)``).
        """
        if agent_args is None:
            agent_args = {}
        if env_args is None:
            env_args = {}

        self.config = config
        self.tokenizer = tokenizer
        self.n_parallel_agents = n_parallel_agents

        # For interaction
        self.max_steps = max_steps
        self.max_response_length = max_response_length
        self.max_prompt_length = max_prompt_length
        self.enforce_max_prompt_length = enforce_max_prompt_length

        # Wall-clock budget consumed by run_agent_trajectory_async, which reads
        # it unconditionally, so it must always be defined.
        trajectory_timeout = kwargs.get("trajectory_timeout")
        self.trajectory_timeout = int(1e9) if trajectory_timeout is None else trajectory_timeout

        self.agent_class = agent_class
        self.agent_args = agent_args
        self.env_class = env_class
        self.env_args = env_args

        self.agents = [None for _ in range(n_parallel_agents)]
        self.envs = [None for _ in range(n_parallel_agents)]

        self.mode = mode

        if env_class is not None:
            assert env_class.is_multithread_safe(), "Environment must be multithread safe for async engine"

        self.chat_mode = kwargs.get("chat_mode", False)

        # Create a thread pool executor for environment interactions (i.e. step, reset, close)
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

        if chat_parser is None:
            self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=kwargs.get("disable_thinking", False))
        else:
            self.chat_parser = chat_parser

    def update_envs_and_agents(self, envs, agents):
        """
        Update the environments and agents.

        Args:
            envs: List of environments to use
            agents: List of agents to use
        """
        assert len(agents) == len(envs), f"Number of agents must equal to number of environments but received, {len(agents)} and {len(envs)}"
        self.envs = envs
        # For keeping track of the environment index in the batch.
        for idx, env in enumerate(envs):
            env.idx = idx
        self.agents = agents
        self.n_parallel_agents = len(envs)

    @staticmethod
    def _format_gold_response(step_info: dict) -> str:
        """Assemble the assistant response text for one gold-trace assistant step.

        ``<think>``/``</think>`` tags and ``REASON: ``/``ACTION: `` markers
        already present in the recorded ``thought``/``reason``/``action``
        fields are removed first so they are not duplicated.
        """
        thought = step_info["thought"].replace("<think>", "").replace("</think>", "").strip()
        reason = step_info["reason"].replace("REASON: ", "").strip()
        action = step_info["action"].replace("ACTION: ", "").strip()
        return f"<think>{thought}</think>\nREASON: {reason}\nACTION: {action}"

    @classmethod
    def _make_response_list(cls, gold_trace) -> list[str]:
        """Return the formatted responses of every assistant step in ``gold_trace``.

        Steps may be dicts or JSON-encoded strings; non-assistant steps are skipped.
        """
        responses = []
        for step in gold_trace:
            if isinstance(step, str):
                step = json.loads(step)
            if step["role"] == "assistant":
                responses.append(cls._format_gold_response(step["info"]))
        return responses

    @classmethod
    def _reformat_trajectory(cls, trajectory, task_info) -> list[dict]:
        """Convert a user/assistant gold trace into replay steps.

        Each returned step has ``observation`` (from a user turn), ``response``,
        ``reward`` and ``done`` (from the following assistant turn) and ``info``
        (``{"task_info": task_info}`` on the first step, ``{}`` otherwise).
        Turns that break the strict user -> assistant alternation are skipped.
        A final step is appended whose observation is ``task_info["solution"]``
        (``""`` when absent or ``task_info`` is not a mapping), with reward 1.0
        and ``done=True``.
        """
        current_turn = "user"
        raw_trajectory = []
        tmp_step = {}
        for step in trajectory:
            if isinstance(step, str):
                step = json.loads(step)
            if step["role"] != current_turn:
                continue
            if current_turn == "user":
                tmp_step["observation"] = step["info"]["observation"]
                current_turn = "assistant"
                continue
            tmp_step["response"] = cls._format_gold_response(step["info"])
            tmp_step["reward"] = step.get("reward", 0.0)
            tmp_step["done"] = step.get("done", False)
            tmp_step["info"] = {"task_info": task_info} if not raw_trajectory else {}
            raw_trajectory.append(tmp_step)
            tmp_step = {}
            current_turn = "user"

        try:
            final_observation = task_info.get("solution", "")
        except AttributeError:
            # task_info is not a mapping, so there is no solution to show.
            final_observation = ""
        raw_trajectory.append(
            {
                "observation": final_observation,
                "response": "",
                "reward": 1.0,
                "done": True,
                "info": {},
            }
        )
        return raw_trajectory

    async def run_sft_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
        """Replay the task's gold trace into agent ``idx`` without stepping the env.

        The environment in slot ``idx`` is reset (its observation must carry
        ``task_info``) and is always closed afterwards, even on error. The
        trace comes from ``task["gold_trace"]``, falling back to
        ``task["trajectory"]``; a JSON string is decoded first.

        Returns:
            The agent's trajectory with ``reward`` set to 1.0.

        Raises:
            ValueError: If ``task_info`` or the trace is missing.
        """
        agent = self.agents[idx]
        env = self.envs[idx]
        task = agent.trajectory.task

        loop = asyncio.get_running_loop()
        initial_observation, initial_info = await loop.run_in_executor(self.executor, env.reset)
        try:
            task_info = initial_observation.get("task_info", None)
            if task_info is None:
                raise ValueError("task_info is required")

            if "gold_trace" in task:
                given_trajectory = task["gold_trace"]
            else:
                given_trajectory = task.get("trajectory", None)
            if given_trajectory is None:
                raise ValueError("given_trajectory is required")
            if isinstance(given_trajectory, str):
                given_trajectory = json.loads(given_trajectory)

            raw_trajectory = self._reformat_trajectory(given_trajectory, task_info)

            agent.reset()
            try:
                agent.update_from_env(
                    observation=raw_trajectory[0]["observation"],  # Raw observation from environment
                    reward=raw_trajectory[0]["reward"],
                    done=raw_trajectory[0]["done"],
                    info={"task_info": task_info},
                )
            except Exception:
                # Best effort (as before), but keep the traceback visible.
                logger.warning("Trajectory %s: initial update_from_env failed", idx, exc_info=True)

            # Each response is followed by the next step's observation/reward/done.
            for step, next_step in zip(raw_trajectory, raw_trajectory[1:]):
                agent.update_from_model(step["response"])
                agent.update_from_env(
                    observation=next_step["observation"],
                    reward=next_step["reward"],
                    done=next_step["done"],
                    info=next_step["info"],
                )
                cur_step = agent.get_current_state()
                cur_step.reward = next_step["reward"]
                cur_step.done = next_step["done"]
                cur_step.info.update(next_step["info"])
        finally:
            # The environment was reset above; release it like the interactive path does.
            await loop.run_in_executor(self.executor, env.close)

        trajectory = agent.trajectory
        trajectory.reward = 1.0
        return trajectory

    async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
        """Replay the task's gold assistant turns against the live environment.

        Each assistant response from ``task["gold_trace"]`` (list or JSON
        string) is fed to the agent and the resulting action is executed with
        ``env.step`` on the thread pool. The loop stops on env ``done``,
        prompt/response token limits, the trajectory timeout, ``max_steps``,
        or when the gold trace runs out of responses. The environment is
        always closed.

        Returns:
            The agent's trajectory after ``compute_trajectory_reward``.

        Raises:
            ValueError: If the task has no ``gold_trace``.
        """
        agent = self.agents[idx]
        env = self.envs[idx]
        task = agent.trajectory.task

        gold_trace = task.get("gold_trace", None)
        if gold_trace is None:
            raise ValueError("gold_trace is required")
        if isinstance(gold_trace, str):
            # Same encoding accepted by run_sft_trajectory_async.
            gold_trace = json.loads(gold_trace)
        response_list = self._make_response_list(gold_trace)

        termination_reason = None
        response_token_len = 0
        total_time = 0.0
        reward = 0.0

        # Reset environment with the task using the executor
        loop = asyncio.get_running_loop()
        observation, info = await loop.run_in_executor(self.executor, env.reset)
        try:
            info["max_steps"] = self.max_steps

            agent.reset()
            agent.update_from_env(
                observation=observation,  # Raw observation from environment
                reward=0.0,
                done=False,
                info=info,
            )

            for step_idx in range(self.max_steps):
                prompt_messages = agent.chat_completions.copy()
                prompt_str = self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True)
                prompt_len = len(self.tokenizer.encode(prompt_str, add_special_tokens=False))
                if prompt_len > self.max_prompt_length:
                    termination_reason = "PROMPT_TRUNCATION"
                    break

                # The "model" here is the gold trace; stop once it has no more turns.
                if step_idx >= len(response_list):
                    termination_reason = "GOLD_TRACE_EXHAUSTED"
                    break
                response = response_list[step_idx]

                # Update agent with model response
                action = agent.update_from_model(response).action

                # Take step in environment using the executor
                start_time = time.time()
                try:
                    next_observation, reward, done, info = await asyncio.wait_for(
                        loop.run_in_executor(self.executor, env.step, action),
                        timeout=(self.trajectory_timeout - total_time),
                    )
                except asyncio.TimeoutError:
                    termination_reason = "ENV_TIMEOUT"
                    if step_idx == 0:
                        colorful_print(f"Warning: Trajectory {idx} completed due to: {termination_reason} before able to perform 1 complete action. This might cause unexpected behavior. Consider increasing trajectory timeout limit.\n", "red")
                    reward = 0

                    cur_step = agent.get_current_state()
                    done = True
                    cur_step.done = done
                    break

                total_time += time.time() - start_time
                info["max_steps"] = self.max_steps
                info["cur_tokens"] = response_token_len

                # Update agent internal state.
                agent.update_from_env(
                    observation=next_observation,
                    reward=reward,
                    done=done,
                    info=info,
                )

                cur_step = agent.get_current_state()
                cur_step.reward = reward
                cur_step.done = done
                cur_step.info.update(info)

                chat_completions_messages = agent.chat_completions
                assistant_message, env_messages = get_recent_assistant_user_messages(chat_completions_messages)

                # Check and convert to tokens if necessary
                assert assistant_message is not None or mode != "Token", "Assistant messages is none when accumulating token trajectories which should be conversations. This should not happen."
                assert env_messages is not None or mode != "Token", "Environment messages is none when accumulating token trajectories which should be conversations. This should not happen."
                assistant_msg_tokens = []
                env_msg_tokens = []
                if assistant_message:
                    assistant_msg_tokens, _ = convert_messages_to_tokens_and_masks([assistant_message], tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=False, contains_generation_msg=False)
                if env_messages:
                    env_msg_tokens, _ = convert_messages_to_tokens_and_masks(env_messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=False, contains_generation_msg=True)

                # Update response token length
                response_token_len += len(assistant_msg_tokens) + len(env_msg_tokens)
                # Reached maximum number of tokens for the trajectory
                if not self.enforce_max_prompt_length and response_token_len >= self.max_response_length:
                    cur_step = agent.get_current_state()
                    # Zero the reward only if the assistant tokens themselves overflowed.
                    if response_token_len - len(env_msg_tokens) > self.max_response_length:
                        cur_step.reward = 0.0
                    cur_step.done = True
                    termination_reason = "TRUNCATION"
                    break

                if total_time >= self.trajectory_timeout:
                    termination_reason = "TIMEOUT"
                    cur_step = agent.get_current_state()
                    done = True
                    cur_step.done = done
                    break

                # Check if episode is done
                if done:
                    termination_reason = "ENV_DONE"
                    break

                if step_idx == self.max_steps - 1:
                    termination_reason = "MAX_STEPS"
        finally:
            # Closing environment using the executor, even if the replay raised.
            await loop.run_in_executor(self.executor, env.close)

        if termination_reason:
            color = "green" if reward > 0 else "yellow"
            colorful_print(
                f"Trajectory {idx} completed due to: {termination_reason}. Reward is {reward}. \n",
                color,
            )
        trajectory = agent.trajectory
        # Aggregate final trajectory statistics
        compute_trajectory_reward(trajectory)

        return trajectory

    async def execute_tasks(self, tasks: list[dict]):
        """
        Run every task to completion, at most ``self.n_parallel_agents`` at a time.

        Each running task takes a free slot index; a fresh environment
        (``env_class.from_dict``) and agent are built into that slot, the
        trajectory is produced by ``run_sft_trajectory_async`` or
        ``run_agent_trajectory_async`` depending on ``self.mode``, and the
        slot is released for the next task.

        Args:
            tasks: List of task dicts to process.

        Returns:
            A list with one entry per task, in the order of ``tasks``: the
            trajectory, or None if running the trajectory raised (errors while
            building the env/agent propagate).
        """
        max_concurrent = self.n_parallel_agents
        semaphore = asyncio.Semaphore(max_concurrent)
        index_queue: asyncio.Queue[int] = asyncio.Queue(maxsize=max_concurrent)
        for i in range(max_concurrent):
            index_queue.put_nowait(i)

        # Track completed trajectories
        completed = 0
        total = len(tasks)

        async def sem_wrapper(task_id, task):
            nonlocal completed
            async with semaphore:
                # Get an available index
                index = await index_queue.get()
                try:
                    self.envs[index] = self.env_class.from_dict({**task, **self.env_args})

                    # Build agent instance WITHOUT overwriting self.agent_class (keep it as a class)
                    if "task_type" in task:
                        if hasattr(self.agent_class, "create_agent"):
                            agent_instance = self.agent_class.create_agent(task["task_type"], **self.agent_args)
                        else:
                            agent_instance = self.agent_class(task["task_type"], **self.agent_args)
                    else:
                        agent_instance = self.agent_class(**self.agent_args)

                    self.agents[index] = agent_instance
                    assert self.agents[index] is not None and isinstance(self.agents[index], BaseAgent), "Agent is not initalized or not inheriting from BaseAgent"

                    self.agents[index].trajectory.task = task  # type: ignore
                    try:
                        if self.mode == "sft":
                            res = await self.run_sft_trajectory_async(index, application_id=task_id)
                        elif self.mode == "interactive":
                            res = await self.run_agent_trajectory_async(index, application_id=task_id)
                        else:
                            raise ValueError(f"Invalid mode: {self.mode}")
                    except Exception:
                        # Keep the traceback; one failed task must not abort the batch.
                        logger.exception("Error executing trajectory for task %s (slot %s)", task_id, index)
                        res = None

                    if res is not None and hasattr(res, "task"):
                        res.task = task
                    completed += 1
                    colorful_print(f"Progress: {completed}/{total} trajectories completed", "cyan")
                    return task_id, res
                finally:
                    # Put the index back in the queue when done
                    await index_queue.put(index)

        # Run all tasks concurrently; gather preserves input order.
        results = await asyncio.gather(*(sem_wrapper(task_id, task) for task_id, task in enumerate(tasks)))
        return [trajectory for _, trajectory in results]

class AsyncReplayEngine(ReplayEngine):
    """Name-compatible alias of ``ReplayEngine``.

    Construction is inherited unchanged: every positional and keyword argument
    goes straight to ``ReplayEngine.__init__``, and no extra state is added.
    """
