import asyncio
import json
import os
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from pprint import pprint
from queue import Queue
from threading import Thread

import numpy as np
from omegaconf import OmegaConf
from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine
from verl import DataProto
from verl.trainer.ppo.ray_trainer import (
    RayPPOTrainer,
    RayWorkerGroup,
    ResourcePoolManager,
    Role,
    WorkerType,
    marked_timer,
)

from rllm.report.monitor import (
    get_instance_summary,
    get_metrics,
)


class AgentInferencePipeline(RayPPOTrainer):
    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict[Role, WorkerType],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
        reward_fn=None,
        val_reward_fn=None,
        env_class=None,
        agent_class=None,
        env_args=None,
        agent_args=None,
    ):
        """Set up the inference pipeline on top of ``RayPPOTrainer``.

        Args:
            config: trainer configuration; ``agent.use_stepwise_advantage`` and
                the optional ``trainer.experiment_start_idx`` /
                ``trainer.num_experiments`` entries are read here.
            env_class / agent_class: classes used later to build one env and
                one agent per validation sample.
            env_args / agent_args: extra constructor kwargs; falsy values are
                replaced by an empty dict.
            All remaining arguments are forwarded unchanged to ``RayPPOTrainer``.
        """
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
        self.env_class, self.agent_class = env_class, agent_class
        self.env_args = env_args if env_args else {}
        self.agent_args = agent_args if agent_args else {}

        # Announce which granularity the prompt/response length limits use.
        announcement = (
            "Using step-level advantage, max_prompt_length and max_response_length will be applied step-wise"
            if self.config.agent.use_stepwise_advantage
            else "Using trajectory-level advantage, max_prompt_length and max_response_length will be applied episode-wise"
        )
        print(announcement)

        # Experiment sweep bounds: trials [start, start + count) are run by inference_agent.
        trainer_cfg = self.config.trainer
        self.experiment_start_idx = trainer_cfg.get("experiment_start_idx", 0)
        self.num_experiments = trainer_cfg.get("num_experiments", 1)

    def init_workers(self):
        """Run the base-class worker initialization, then build the agent execution engine.

        ``self.rollout_engine`` becomes ``self.async_rollout_manager`` when
        ``actor_rollout_ref.rollout.mode == "async"``; otherwise it is
        ``self.actor_rollout_wg`` (hybrid engine) or ``self.rollout_wg``.
        That engine, together with the agent/env classes and their kwargs,
        is handed to a new ``AsyncAgentExecutionEngine`` stored on
        ``self.agent_execution_engine``. Extra engine kwargs can be supplied via
        ``agent.engine_args``.
        """
        super().init_workers()

        # The worker group is resolved unconditionally, before the mode check.
        worker_group = self.actor_rollout_wg if self.hybrid_engine else self.rollout_wg
        is_async = self.config.actor_rollout_ref.rollout.mode == "async"
        self.rollout_engine = self.async_rollout_manager if is_async else worker_group

        agent_cfg = self.config.agent
        data_cfg = self.config.data
        self.agent_execution_engine = AsyncAgentExecutionEngine(
            rollout_engine=self.rollout_engine,
            config=self.config,
            engine_name="verl",
            tokenizer=self.tokenizer,
            model_path=self.config.actor_rollout_ref.model.path,
            max_steps=agent_cfg.max_steps,
            max_response_length=data_cfg.max_response_length,
            max_prompt_length=data_cfg.max_prompt_length,
            agent_class=self.agent_class,
            agent_args=self.agent_args,
            env_class=self.env_class,
            env_args=self.env_args,
            gamma=agent_cfg.get("gamma", 0.99),
            # Prompt-length enforcement follows the advantage granularity.
            enforce_max_prompt_length=agent_cfg.use_stepwise_advantage,
            trajectory_timeout=agent_cfg.trajectory_timeout,
            overlong_filter=agent_cfg.overlong_filter,
            **agent_cfg.get("engine_args", {}),
        )



    def init_envs_and_agents(self, batch):
        """Build one environment and one agent per sample and load them into the execution engine.

        Per-sample env kwargs come from ``batch.non_tensor_batch["extra_info"]``.
        Entries may be dicts or JSON strings; strings are decoded first. Each
        sample's kwargs are merged with ``env.env_args`` from the config and the
        constructor ``env_args``; on key collisions the merged base args win
        over the per-sample values. Agents are created with the sample's
        ``task_type`` (``None`` if absent) plus ``agent.agent_args`` merged with
        the constructor ``agent_args``. If that fails for any reason, the agent
        is built as ``agent_class(**agent_args)`` instead and a summary line is
        printed.

        Args:
            batch: a ``DataProto`` whose ``non_tensor_batch["extra_info"]``
                holds the per-sample env kwargs.

        Returns:
            list: the created environments, in batch order.
        """
        # Decode JSON-encoded entries *before* reading task_type; reading
        # task_type first made string entries crash with AttributeError.
        env_args = [
            json.loads(arg) if isinstance(arg, str) else arg
            for arg in batch.non_tensor_batch["extra_info"].tolist()
        ]
        task_types = [arg.get("task_type", None) for arg in env_args]

        # dict() on both sides so DictConfig inputs also work with the `|` merge.
        full_agent_args = dict(self.config.agent.get("agent_args", {})) | dict(self.agent_args)
        base_env_args = dict(self.config.env.get("env_args", {})) | dict(self.env_args)

        fallback_errors = []

        def _create_env(sample_args):
            # Builds one environment from the sample kwargs overlaid with the base kwargs.
            return self.env_class.from_dict({**sample_args, **base_env_args})

        def _create_agent(task_type):
            # Builds one agent, preferring the create_agent factory when available.
            try:
                if hasattr(self.agent_class, "create_agent"):
                    return self.agent_class.create_agent(task_type, **full_agent_args)
                return self.agent_class(task_type, **full_agent_args)
            except Exception as e:
                # Best-effort fallback for agents that don't accept a task_type.
                # Record the error so it is reported instead of silently swallowed.
                fallback_errors.append(e)
                return self.agent_class(**full_agent_args)

        # Executor.map yields results in input order, so envs/agents line up with the batch.
        # Envs are fully built before any agent is created, as list() drains the first map.
        with ThreadPoolExecutor(max_workers=64) as executor:
            envs = list(executor.map(_create_env, env_args))
            agents = list(executor.map(_create_agent, task_types))

        if fallback_errors:
            print(
                f"Agent creation with task_type failed for {len(fallback_errors)}/{len(agents)} samples; "
                f"fell back to agent_class(**agent_args). First error: {fallback_errors[0]!r}"
            )

        self.agent_execution_engine.update_envs_and_agents(envs, agents)
        return envs

    def inference_agent(self):
        """Run validation-only inference once per configured experiment trial.

        Trials ``i`` range over
        ``[experiment_start_idx, experiment_start_idx + num_experiments)``. For
        each trial, ``self.results_path`` is set to
        ``<trainer.default_local_dir>/<prefix>_<i+1>.json`` and
        ``_validate_agent`` runs unless that file already exists, so an
        interrupted sweep can be resumed. ``prefix`` is read from the optional
        ``trainer.results_file_prefix`` and defaults to
        ``"sudoku_results_trial"``. Python garbage and the CUDA cache are
        released once all trials finish.
        """
        import gc
        import time

        import torch

        self.global_steps = 0

        prefix = self.config.trainer.get("results_file_prefix", "sudoku_results_trial")
        first = self.experiment_start_idx
        for i in range(first, first + self.num_experiments):
            trial = i + 1
            # Report both the absolute trial number (used in the file name) and progress in this sweep.
            print(f"Running experiment {trial} ({i - first + 1} of {self.num_experiments})")
            start_time = time.perf_counter()

            self.results_path = os.path.join(self.config.trainer.default_local_dir, f"{prefix}_{trial}.json")

            if os.path.exists(self.results_path):
                print(f"Validation results already exist for experiment {trial}, skipping validation")
                continue

            val_metrics = self._validate_agent()
            # pprint the dict itself; pprint-ing an f-string only prints a quoted string repr.
            print("Validation metrics:")
            pprint(val_metrics)
            print(f"Time taken to validate agent: {time.perf_counter() - start_time}")

        # memory cleanup
        gc.collect()
        torch.cuda.empty_cache()
    
    def _save_validation_results(self, raw_results: list):
        """Write validation trajectories to ``self.results_path`` as JSON and summarize them.

        Args:
            raw_results: one dict per rollout with keys ``"task"``, ``"steps"``,
                ``"last_observation"`` and ``"rewards"``. Each step must expose
                ``observation`` (indexable by ``"observation"``) and
                ``model_response``.

        Returns:
            The result of ``get_metrics`` over per-instance summaries, with
            subtypes ordered ``easy``/``medium``/``hard``. Returns ``{}`` if
            summarizing fails. The JSON file is written before the summary is
            computed, so results are kept even if reporting fails.
        """
        final_results = [
            {
                "task": result["task"],
                "steps": [
                    {
                        "observation": step.observation["observation"],
                        "model_response": step.model_response,
                    }
                    for step in result["steps"]
                ],
                "last_observation": result["last_observation"],
                "rewards": result["rewards"],
            }
            for result in raw_results
        ]

        # save the results
        print(f"Saving validation results to {self.results_path}")
        # The output directory may not exist yet on a fresh run. Creating it avoids
        # losing the whole validation pass to a FileNotFoundError at write time.
        results_dir = os.path.dirname(self.results_path)
        if results_dir:
            os.makedirs(results_dir, exist_ok=True)
        with open(self.results_path, "w") as f:
            json.dump(final_results, f, indent=4)

        # report the results (best effort: a reporting failure must not lose saved results)
        try:
            results_summary = [get_instance_summary(result, use_difficulty=True) for result in final_results]
            subtype_order = ["easy", "medium", "hard"]
            metrics = get_metrics(results_summary, subtype_order)
        except Exception as e:
            print(f"Error getting metrics: {e}")
            metrics = {}

        return metrics

    def _validate_agent(self):
        """Roll out the agent over the whole validation set and return flat metrics.

        For each batch from ``self.val_dataloader``:
        - assign a fresh uid to each sample;
        - repeat the batch ``actor_rollout_ref.rollout.val_kwargs.n`` times (interleaved);
        - rebuild envs/agents via ``init_envs_and_agents``;
        - collect text trajectories.

        The rollout engine is woken up for the pass and is always put back to
        sleep, even if a batch raises. Per-rollout results are written by
        ``_save_validation_results``.

        Returns:
            dict: ``"val/<group>/<metric>"`` -> value. ``<group>`` is ``total``
            or a subtype key from the metrics. ``<metric>`` is
            ``success_rate``, ``step_num_mean`` or ``progress_mean``.
        """
        raw_results = []
        self.rollout_engine.wake_up()
        try:
            n_val_samples = self.config.actor_rollout_ref.rollout.val_kwargs.n
            for test_data in self.val_dataloader:
                test_batch = DataProto.from_single_dict(test_data)
                # uids are assigned before repeating, so repeats of one sample share a uid.
                test_batch.non_tensor_batch["uid"] = np.array(
                    [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
                )
                test_batch = test_batch.repeat(repeat_times=n_val_samples, interleave=True)
                # Drop the tokenized prompt tensors; they are not used below.
                test_batch.pop(["input_ids", "attention_mask", "position_ids"])
                test_batch.meta_info = {
                    "eos_token_id": self.tokenizer.eos_token_id,
                    "pad_token_id": self.tokenizer.pad_token_id,
                    "agent_rollout": True,
                    "validate": True,
                }
                self.init_envs_and_agents(test_batch)

                task_info = [
                    json.loads(info) if isinstance(info, str) else info
                    for info in test_batch.non_tensor_batch["extra_info"].tolist()
                ]

                test_output = self.generate_validation_agent_trajectory(meta_info=test_batch.meta_info)

                for i, task in enumerate(task_info):
                    raw_results.append({
                        "task": task,
                        "steps": test_output["steps"][i],
                        "last_observation": test_output["last_observation"][i],
                        "rewards": test_output["rewards"][i],
                    })
        finally:
            self.rollout_engine.sleep()

        metrics = self._save_validation_results(raw_results)
        return self._flatten_validation_metrics(metrics)

    @staticmethod
    def _flatten_validation_metrics(metrics):
        """Flatten ``get_metrics`` output into ``val/<group>/<metric>`` scalars.

        Groups are ``total`` plus every key of ``metrics["subtype"]`` when that
        entry is present. Values are read from ``is_success``, ``step_num`` and
        ``progress``. If any of them cannot be read, all three for that group
        are set to 0, so the returned keys stay stable.
        """
        names = ("success_rate", "step_num_mean", "progress_mean")

        def _group_metrics(group, source, key):
            # Returns the three flattened metrics for source[key], or zeros if unavailable.
            try:
                stats = source[key]
                values = (stats["is_success"], stats["step_num"], stats["progress"])
            except (KeyError, IndexError, TypeError):
                # set to zero if extraction fails
                values = (0, 0, 0)
            return {f"val/{group}/{name}": value for name, value in zip(names, values)}

        metric_dict = _group_metrics("total", metrics, "total")
        subtypes = metrics.get("subtype", None)
        if subtypes is not None:
            for subkey in subtypes.keys():
                metric_dict.update(_group_metrics(subkey, subtypes, subkey))
        return metric_dict
    
    def generate_validation_agent_trajectory(self, timing_raw=None, meta_info=None):
        """
        Generates agent trajectories for validation. Does not close or reset the environment afterwards

        Args:
            timing_raw (dict, optional): timing dict shared with ``marked_timer``
                and the trajectory generator. A new dict is used when ``None``.
            meta_info (dict, optional): forwarded to the trajectory generator.

        Returns:
            dict: parallel lists ordered by each trajectory's ``idx``:
            ``"steps"`` (each trajectory's step list), ``"last_observation"``,
            and ``"rewards"`` (per-step reward lists).
        """
        if timing_raw is None:
            timing_raw = {}
        with marked_timer("collect_trajectory", timing_raw):
            # Drain the generator inside the timer so the timing covers generation.
            trajectories = list(
                self.generate_agent_trajectories_async(timing_raw=timing_raw, meta_info=meta_info, mode="Text")
            )

        # Trajectories may arrive out of order; restore env/agent order by idx.
        trajectories.sort(key=lambda x: x["idx"])
        episodes = [result["trajectory"] for result in trajectories]

        outputs = {
            "steps": [episode.steps for episode in episodes],
            "last_observation": [episode.last_observation for episode in episodes],
            "rewards": [[step.reward for step in episode.steps] for episode in episodes],
        }

        # Debug aid; guarded so an empty batch or a step-less trajectory no longer raises IndexError.
        if episodes and episodes[-1].steps:
            last_interaction = episodes[-1].steps[-1].chat_completions
            print(f"Last interaction: {last_interaction}")

        return outputs

    def generate_agent_trajectories_async(self, timing_raw=None, meta_info=None, mode="Token"):
        """
        Generates agent trajectories asynchronously using the agent execution engine.

        This method runs the asynchronous `trajectory_generator` in a
        separate daemon thread (with its own event loop via ``asyncio.run``)
        and yields the results synchronously through a queue. This lets
        synchronous callers consume asynchronously generated trajectories.

        If the generator raises, the stream ends and the same exception is
        re-raised to the caller once every item produced before the failure
        has been yielded. The caller therefore never blocks forever waiting
        for a sentinel that will not arrive.

        Args:
            timing_raw (dict, optional): Dictionary to store timing information. Defaults to {}.
            meta_info (dict, optional): Additional metadata for the generation process. Defaults to None.
            mode (str, optional): Output mode forwarded to `trajectory_generator`. Defaults to "Token".

        Yields:
            Any: Items generated by the `trajectory_generator`, typically
                 representing parts or results of agent trajectories.
        """
        if timing_raw is None:
            timing_raw = {}
        queue = Queue()
        # A private sentinel object, so a generated None cannot end the stream early.
        done = object()
        errors = []

        def runner():
            # Worker-thread body: drains the async generator into the queue.
            async def consume():
                async for item in self.agent_execution_engine.trajectory_generator(timing_raw=timing_raw, mode=mode, meta_info=meta_info):
                    queue.put(item)

            try:
                asyncio.run(consume())
            except BaseException as exc:
                # Hand the failure to the consuming thread instead of losing it in this thread.
                errors.append(exc)
            finally:
                # Always signal completion, otherwise queue.get() below would block forever.
                queue.put(done)

        Thread(target=runner, daemon=True).start()
        while True:
            item = queue.get()
            if item is done:
                break
            yield item

        # errors is filled before `done` is enqueued, so it is complete here.
        if errors:
            raise errors[0]