from runners.on_policy_runner import OnPolicyRunner as Runner
from replay import OnPolicyBuffer
from elements import Agg, FPS
from logger import Logger, TerminalOutput, WandBOutput
from actor_critic.rnn_actor import RNNActor
from actor_critic.rnn_critic import RNNCritic
from utils import conditions
from utils.tools import init_device, get_task_name, make_env, n2t, build_returns
from parallel import Remote, Dummy
import elements

import torch
# Let PyTorch use faster, lower-precision kernels for float32 matrix multiplications.
torch.set_float32_matmul_precision(precision="high")

class OnPolicyTrainer:
    """On-policy multi-agent trainer built around RNN actors and critics.

    Wires together rollout environments, an ``OnPolicyBuffer``, one actor and
    one critic per agent (optionally a single shared instance), a rollout
    ``Runner`` and an optional evaluation runner. It then alternates
    environment collection with ``ppo_update`` calls on the actors and critics.
    Training, evaluation, logging and checkpointing are each gated by a
    ``conditions.Every`` schedule evaluated on the environment step counter.
    """

    def __init__(self, config):
        """Build every component of the trainer from ``config``.

        Args:
            config: Experiment configuration. Fields read here include
                ``logdir``, ``env``, ``use_eval``, ``logging.*``, ``train.*``
                and ``eval.*``.
        """
        self.config = config
        self.step = elements.Counter()
        # initialize aggregator
        self.agg = Agg()
        # initialize logger
        output_handles = [TerminalOutput(pattern=config.logging.terminal_filter)]
        if config.logging.use_wandb:
            output_handles.append(
                WandBOutput(
                    name=config.logdir.split("/")[-1],
                    pattern=config.logging.wandb_filter,
                    config=config,
                    group=config.env + "." + get_task_name(config),
                    **config.logging.wandb_config,
                )
            )
        self.logger = Logger(output_handles=output_handles)
        # initialize envs
        self.envs = self._build_envs(
            config, config.train.parallel_rollout, config.train.n_rollout_threads
        )
        # initialize replay
        self.replay = OnPolicyBuffer(config, config.train.n_rollout_threads, self.agg)
        # initialize device
        self.device = init_device(config)
        self.tpdv = dict(dtype=torch.float32, device=self.device)
        # initialize actors; in the shared case list multiplication repeats the
        # SAME module object at every agent index
        obs_shape = self.envs[0].obs_shape
        n_actions = self.envs[0].n_actions
        n_agents = self.envs[0].n_agents
        if config.train.share_actors:
            self.actors = [RNNActor(config, obs_shape, n_agents, n_actions, self.device)] * n_agents
        else:
            self.actors = [RNNActor(config, obs_shape, n_agents, n_actions, self.device) for _ in range(n_agents)]
        # initialize critics (same sharing scheme as the actors)
        if config.train.share_critics:
            self.critics = [RNNCritic(config, obs_shape, self.device)] * n_agents
        else:
            self.critics = [RNNCritic(config, obs_shape, self.device) for _ in range(n_agents)]
        # initialize runner
        self.runner = Runner(
            config=config,
            envs=self.envs,
            actors=self.actors,
            critics=self.critics,
            replay=self.replay,
            device=self.device,
        )
        # initialize for evaluation: separate envs/buffer/aggregator, same actors and critics
        if config.use_eval:
            self.eval_envs = self._build_envs(
                config, config.eval.parallel_rollout, config.eval.n_rollout_threads
            )
            self.eval_agg = Agg()
            self.eval_replay = OnPolicyBuffer(config, config.eval.n_rollout_threads, self.eval_agg)
            self.eval_runner = Runner(
                config=config,
                envs=self.eval_envs,
                actors=self.actors,
                critics=self.critics,
                replay=self.eval_replay,
                device=self.device,
            )
        # initialize conditions; training fires once per full rollout
        # (n_rollout_threads * batch_length env steps)
        train_every = config.train.n_rollout_threads * config.train.batch_length
        self.should_train = conditions.Every(every=train_every, initial=False)
        self.should_eval = conditions.Every(every=config.eval.eval_interval, initial=False)
        self.should_log = conditions.Every(every=config.logging.log_interval, initial=False)
        self.should_save = conditions.Every(every=self.config.train.checkpoint.save_interval, initial=False)
        # initialize fps tracker
        self.env_fps = FPS()
        self.train_fps = FPS()
        # setup checkpoint: the step counter, schedules, actors and critics are
        # registered as attributes so they are saved/restored together
        self.checkpoint = elements.Checkpoint(directory=config.logdir + "/ckpt", step=self.step)
        self.checkpoint.step = self.step
        self.checkpoint.should_train = self.should_train
        self.checkpoint.should_eval = self.should_eval
        self.checkpoint.should_log = self.should_log
        self.checkpoint.should_save = self.should_save
        # registered once per agent index; with sharing, every key holds the same object
        for i in range(len(self.actors)):
            setattr(self.checkpoint, f"actor_{i}", self.actors[i])
        for i in range(len(self.critics)):
            setattr(self.checkpoint, f"critic_{i}", self.critics[i])
        if config.train.checkpoint.from_checkpoint:
            self.checkpoint.load(path=config.train.checkpoint.from_checkpoint)

    @staticmethod
    def _build_envs(config, parallel, n_threads):
        """Create ``n_threads`` env wrappers: ``Remote`` if ``parallel`` is truthy, else ``Dummy``."""
        env_cls = Remote if parallel else Dummy
        return [env_cls(make_env, config, i) for i in range(n_threads)]

    def train(self):
        """Run collect -> train -> eval -> log -> save until ``num_env_steps``, then close."""
        print("On-policy trainer is running")
        self.runner.reset()
        while self.step < self.config.train.num_env_steps:
            num_steps, _ = self.runner.step(self.agg)
            self.env_fps.step(num_steps)
            self.step.increment(num_steps)

            # train model
            self.train_step()

            # evaluation
            if self.config.use_eval:
                self.eval()

            # log metrics
            self.log_step()

            # save checkpoint
            self.save_step()

        self.close()

    @elements.timer.section("train")
    def train_step(self):
        """If the train schedule fires, run one PPO update on all actors and critics.

        Tensors from the buffer are indexed as ``[time, <env dim>, agent, ...]``
        in this method: ``[-1:]`` selects the final step used to bootstrap the
        value, and ``[:, :, i]`` selects agent ``i``. (The second axis is
        presumably the rollout thread; confirm against ``OnPolicyBuffer``.)
        """
        if self.should_train(int(self.step)):
            self.train_fps.step()

            # sample data from buffer
            data = self.replay.create_dataset()
            obs = n2t(data["obs"], **self.tpdv)
            rnn_states = n2t(data["rnn_states"], **self.tpdv)
            rnn_states_critic = n2t(data["rnn_states_critic"], **self.tpdv)
            value_preds = n2t(data["value_preds"], **self.tpdv)
            rewards = n2t(data["rewards"], **self.tpdv)
            terminated = n2t(data["terminated"], **self.tpdv)
            truncated = n2t(data["truncated"], **self.tpdv)
            agent_mask = n2t(data["agent_mask"], **self.tpdv)
            actions_env = n2t(data["actions_env"], **self.tpdv)
            avail_actions = n2t(data["avail_actions"], **self.tpdv) if "avail_actions" in data else None

            # calculate the value targets: append each critic's value for the last
            # observation so build_returns can bootstrap past the rollout end
            value_preds_list = []
            for i in range(self.runner.n_agents):
                last_value_preds = self.critics[i](obs[-1:, :, i], rnn_states_critic[-1:, :, i])["value_preds"]
                value_preds_list.append(torch.cat([value_preds[:, :, i], last_value_preds], dim=0))
            value_preds = torch.stack(value_preds_list, dim=2)
            # every agent is trained on the across-agent mean reward
            rewards = rewards.mean(dim=2, keepdim=True).expand_as(rewards)
            target_returns = build_returns(
                rewards=rewards,
                value_preds=value_preds,
                terminated=terminated.unsqueeze(2),
                truncated=truncated.unsqueeze(2),
                gamma=self.config.train.gamma,
                gae_lambda=self.config.train.gae_lambda,
            )

            # calculate the advantages, normalized per agent over active (mask == 1) entries
            advantages = []
            for i in range(len(self.actors)):
                advantage = target_returns[:, :, i] - value_preds[:, :, i]
                active = advantage[agent_mask[:, :, i] == 1]
                # mean() of an empty tensor and the (unbiased) std() of a single
                # element are NaN; normalizing with them would poison the whole
                # update, so leave the advantages unnormalized in that case.
                if active.numel() > 1:
                    advantage = (advantage - active.mean()) / (active.std() + 1e-5)
                advantages.append(advantage)
            advantages = torch.stack(advantages, dim=2)

            # train actor
            if self.config.train.share_actors:
                # one shared module: update on all agents at once
                train_metrics = self.actors[0].ppo_update(
                    obs=obs[:-1],
                    rnn_states=rnn_states[:-1],
                    actions_env=actions_env,
                    agent_mask=agent_mask[:-1],
                    advantages=advantages[:-1],
                    avail_actions=avail_actions[:-1] if avail_actions is not None else None,
                )
                self.logger.add(int(self.step), train_metrics, prefix="agent_0")
            else:
                for i in range(len(self.actors)):
                    train_metrics = self.actors[i].ppo_update(
                        obs=obs[:-1, :, i],
                        rnn_states=rnn_states[:-1, :, i],
                        actions_env=actions_env[:, :, i],
                        agent_mask=agent_mask[:-1, :, i],
                        advantages=advantages[:-1, :, i],
                        avail_actions=avail_actions[:-1, :, i] if avail_actions is not None else None,
                    )
                    self.logger.add(int(self.step), train_metrics, prefix=f"agent_{i}")

            # train critic
            if self.config.train.share_critics:
                train_metrics = self.critics[0].ppo_update(
                    obs=obs[:-1],
                    rnn_states_critic=rnn_states_critic[:-1],
                    target_returns=target_returns[:-1],
                    agent_mask=agent_mask[:-1],
                )
                self.logger.add(int(self.step), train_metrics, prefix="agent_0")
            else:
                for i in range(len(self.critics)):
                    train_metrics = self.critics[i].ppo_update(
                        obs=obs[:-1, :, i],
                        rnn_states_critic=rnn_states_critic[:-1, :, i],
                        target_returns=target_returns[:-1, :, i],
                        agent_mask=agent_mask[:-1, :, i],
                    )
                    self.logger.add(int(self.step), train_metrics, prefix=f"agent_{i}")

            # reset the replay buffer
            self.replay.reset()

    @torch.no_grad()
    def eval(self):
        """If the eval schedule fires, roll out ``eval_episode_num`` episodes and log metrics under ``eval``."""
        if self.should_eval(int(self.step)):
            with elements.timer.section("eval"):
                self.eval_replay.clear()
                self.eval_runner.reset()
                episodes = elements.Counter()
                while episodes < self.config.eval.eval_episode_num:
                    _, num_episodes = self.eval_runner.step(self.eval_agg, evaluation=True)
                    episodes.increment(num_episodes)
                self.logger.add(int(self.step), self.eval_agg.result(reset=True, prefix="eval"))

    def log_step(self):
        """If the log schedule fires, log aggregated metrics, FPS and (optionally) timer stats, then flush."""
        if self.should_log(int(self.step)):
            with elements.timer.section("log"):
                self.logger.add(int(self.step), self.agg.result(reset=True))
                self.logger.add(
                    int(self.step),
                    {
                        "env_fps": self.env_fps.result(reset=False),
                        "train_fps": self.train_fps.result(reset=False),
                    }
                )
                if self.config.logging.timer:
                    timer_dict = elements.timer.stats()
                    # drop the summary entry if present; tolerate its absence
                    timer_dict.pop("summary", None)
                    self.logger.add(int(self.step), timer_dict, prefix="timer")
                self.logger.flush()

    def save_step(self):
        """If the save schedule fires, write a checkpoint."""
        if self.should_save(int(self.step)):
            with elements.timer.section("save"):
                self.checkpoint.save()

    def close(self):
        """Close the training envs, the evaluation envs (when created) and the logger."""
        for env in self.envs:
            env.close()
        # eval_envs only exists when config.use_eval was set; close it too so
        # the evaluation env wrappers are not leaked
        for env in getattr(self, "eval_envs", []):
            env.close()
        self.logger.close()
