import torch
from tensordict.tensordict import TensorDict
import numpy as np

from typing import Dict, List
import re

class DreamerRunner:
    """Drive a batch of environments with a world model and a list of actors.

    ``reset`` initializes every environment and the agent state; each ``step``
    encodes the latest observations, updates the latent state, samples actions,
    steps (or resets, if finished) every environment and stores the resulting
    step dicts in ``replay``.

    Args:
        config: Experiment config. ``config.logging.log_keys_{avg,sum,max}``
            (regex patterns) and ``config.env_args`` are read here.
        envs: Environment handles whose ``reset()`` / ``step(action)`` return a
            zero-argument callable that yields a step dict.
        actors: Actor modules; actor ``i`` receives ``latent[:, i]``
            (presumably one actor per agent — confirm against the caller).
        world_model: Provides ``encoder``, ``observe_step`` and
            ``initialize_agent_state``.
        replay: Buffer exposing ``add(step_dict, worker=...)``.
        device: Torch device on which the batched step data is placed.
    """

    def __init__(self, config, envs, actors, world_model, replay, device):
        self.config = config
        self.envs = envs
        self.actors = actors
        self.wm = world_model
        self.replay = replay
        self.n_agents = envs[0].n_agents
        self.device = device

    def _merge_step_dicts(self, step_dict_list: List[Dict[str, np.ndarray]]) -> TensorDict:
        """Stack per-env step dicts into one batched TensorDict on ``self.device``.

        Entries whose key starts with ``"log_"`` are dropped (returning ``None``
        from ``named_apply`` removes the entry); those entries are only consumed
        by the stats aggregation in ``step``.
        """
        return torch.stack(
            [
                TensorDict(step_dict, device=self.device).named_apply(
                    lambda k, v: v if not k.startswith("log_") else None
                )
                for step_dict in step_dict_list
            ],
        )

    def _add_to_replay(self) -> None:
        """Attach the last action / log-prob of each env to its step dict and store it."""
        for i, step_dict in enumerate(self.step_dict_list):
            # `.cpu()` is a no-op for CPU tensors and `.numpy()` shares memory with
            # the tensor, so copy to keep replay entries independent of later
            # changes to `agent_state`; detach so tensors requiring grad convert.
            step_dict["prev_actions"] = self.agent_state["actions"][i].detach().cpu().numpy().copy()
            step_dict["prev_log_probs"] = self.agent_state["log_probs"][i].detach().cpu().numpy().copy()
            self.replay.add(step_dict, worker=i)

    def reset(self):
        """Reset all environments and the agent state, then store the first step."""
        # Issue every reset before resolving any future (presumably so the
        # environments can work concurrently).
        futures = [env.reset() for env in self.envs]
        self.step_dict_list: List[Dict[str, np.ndarray]] = [future() for future in futures]
        self.merged_step_dict: TensorDict = self._merge_step_dicts(self.step_dict_list)

        # initialize the agent state
        self.agent_state: TensorDict = self.wm.initialize_agent_state(batch_size=len(self.envs))

        # store the first step
        self._add_to_replay()

    def _aggregate_logs(self, agg) -> None:
        """Add every ``log_*`` entry matching a configured pattern to ``agg``."""
        logging_cfg = self.config.logging
        patterns = (
            (logging_cfg.log_keys_avg, "avg"),
            (logging_cfg.log_keys_sum, "sum"),
            (logging_cfg.log_keys_max, "max"),
        )
        for step_dict in self.step_dict_list:
            for key, value in step_dict.items():
                if not key.startswith("log_"):
                    continue
                # A key may match several patterns and is then added once per mode.
                for pattern, mode in patterns:
                    if re.match(pattern, key):
                        agg.add(key, value, agg=mode)

    def _count_steps_and_episodes(self):
        """Return ``(num_steps, num_episodes)`` for the current step dicts."""
        use_absorbing = self.config.env_args.get("use_absorbing_state", False)
        # Steps flagged by this key are not counted as environment steps.
        skip_key = "is_trailing_absorbing_state" if use_absorbing else "is_first"
        num_steps = 0
        num_episodes = 0
        for step_dict in self.step_dict_list:
            done = np.logical_or(step_dict["terminated"], step_dict["truncated"])
            num_episodes += bool(done)
            num_steps += not step_dict[skip_key]
        return num_steps, num_episodes

    @torch.no_grad()
    def step(self, agg=None, evaluation=False):
        """Advance every environment by one step and record the transitions.

        Environments whose previous step was terminated or truncated are reset
        instead of stepped.

        Args:
            agg: Optional stats aggregator; ``log_*`` entries matching
                ``config.logging.log_keys_{avg,sum,max}`` are added with the
                corresponding ``agg`` mode.
            evaluation: Forwarded to the actors.

        Returns:
            ``(num_steps, num_episodes)``: steps not flagged
            ``is_trailing_absorbing_state`` (when ``env_args.use_absorbing_state``
            is set) or not flagged ``is_first`` (otherwise), and the number of
            step dicts that are terminated or truncated.
        """
        # TensorDict does not (in all versions) support `key in td`; use keys().
        avail_actions = (
            self.merged_step_dict["avail_actions"]
            if "avail_actions" in self.merged_step_dict.keys()
            else None
        )
        actor_outputs = self.get_actions(
            obs=self.merged_step_dict["obs"],
            is_first=self.merged_step_dict["is_first"],
            agent_state=self.agent_state,
            avail_actions=avail_actions,
            evaluation=evaluation,
        )
        # Move done flags to the host once instead of syncing per environment.
        dones = (self.merged_step_dict["terminated"] | self.merged_step_dict["truncated"]).cpu()
        actions_env = actor_outputs["actions_env"].cpu().numpy()
        futures = [
            env.step(actions_env[i]) if not dones[i] else env.reset()
            for i, env in enumerate(self.envs)
        ]
        self.step_dict_list: List[Dict[str, np.ndarray]] = [future() for future in futures]
        self.merged_step_dict: TensorDict = self._merge_step_dicts(self.step_dict_list)

        # add step_dict to replay
        self._add_to_replay()

        # aggregate env stats
        if agg is not None:
            self._aggregate_logs(agg)

        return self._count_steps_and_episodes()

    def get_actions(
            self,
            obs: torch.Tensor,
            is_first: torch.Tensor,
            agent_state: TensorDict,
            avail_actions: torch.Tensor | None = None,
            evaluation: bool = False,
    ):
        """Update the latent state from ``obs`` and sample actions from every actor.

        Args:
            obs: Batched observations passed to ``wm.encoder``.
            is_first: Episode-start flags passed to ``wm.observe_step``.
            agent_state: Source of the previous ``actions`` and ``latents``.
            avail_actions: Optional action mask; ``avail_actions[:, i]`` goes to actor ``i``.
            evaluation: Forwarded to each actor.

        Returns:
            The actor outputs stacked along dim 1.

        Note:
            The updated latents / actions / log-probs are written to
            ``self.agent_state``, not to the ``agent_state`` argument.
        """
        embed = self.wm.encoder(obs)
        latent: TensorDict = self.wm.observe_step(
            embed=embed,
            is_first=is_first,
            prev_actions=agent_state["actions"],
            prev_latent=agent_state["latents"],
        )
        actor_outputs: List[TensorDict] = [
            actor(
                latent=latent[:, i],
                avail_actions=avail_actions[:, i] if avail_actions is not None else None,
                evaluation=evaluation,
            )
            for i, actor in enumerate(self.actors)
        ]
        actor_outputs = torch.stack(actor_outputs, dim=1)

        # update the agent state
        self.agent_state["latents"] = latent
        self.agent_state["actions"] = actor_outputs["actions"]
        self.agent_state["log_probs"] = actor_outputs["log_probs"]

        return actor_outputs
