from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
import numpy as np


class ADEpisodeRunner:
    """Run single-environment episodes and collect them into an episode batch.

    Besides the regular ``actions``, every step stores the four extra ``ad_*``
    outputs returned by ``mac.select_actions`` and an ``ad_reward`` (the env
    reward multiplied by ``args.ad_policy_reward_scale``).  Episode returns and
    ``env_info`` statistics are accumulated per ``tag`` until ``log_info``
    flushes them.
    """

    def __init__(self, args, logger):
        """Build the environment named by ``args.env`` with ``args.env_args``.

        Only ``args.batch_size_run == 1`` is supported (one env, one episode
        at a time).
        """
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run
        assert self.batch_size == 1

        self.env = env_REGISTRY[self.args.env](**self.args.env_args)
        self.episode_limit = self.env.episode_limit
        self.t = 0  # step index within the current episode

        self.t_env = 0  # total env steps over all 'train'-tagged episodes

        self.returns = {}  # tag -> episode returns recorded since last log_info
        self.stats = {}  # tag -> summed env_info values plus n_episodes/ep_length

    def setup(self, scheme, groups, preprocess, mac):
        """Bind the batch factory and the action-selecting controller ``mac``.

        Batches get ``episode_limit + 1`` timesteps: one per env step plus a
        final slot that ``run`` fills with the post-episode observation.
        """
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.args.device)
        self.mac = mac

    def get_env_info(self):
        """Return the wrapped environment's ``get_env_info()``."""
        return self.env.get_env_info()

    def save_replay(self):
        """Delegate to the environment's ``save_replay()``."""
        self.env.save_replay()

    def close_env(self):
        """Close the wrapped environment."""
        self.env.close()

    def reset(self):
        """Start a fresh batch, reset the environment and the step counter."""
        self.batch = self.new_batch()
        self.env.reset()
        self.t = 0

    def _pre_transition_data(self):
        """Snapshot the env's current state, available actions and observations.

        ``obs``, ``obs_origin`` and ``obs_perturbed`` are each filled from a
        separate ``env.get_obs()`` call.  NOTE(review): presumably
        ``obs_perturbed`` is overwritten elsewhere (e.g. by the mac) — confirm.
        """
        return {
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "obs": [self.env.get_obs()],
            "obs_origin": [self.env.get_obs()],
            "obs_perturbed": [self.env.get_obs()],
        }

    def _select_actions(self, test_mode):
        """Query the mac at step ``self.t``; return its five outputs keyed by batch field."""
        (actions, ad_discrete_actions, ad_continuous_actions,
         ad_discrete_emb, ad_continuous_emb) = self.mac.select_actions(
            self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
        return {
            "actions": actions,
            "ad_discrete_actions": ad_discrete_actions,
            "ad_continuous_actions": ad_continuous_actions,
            "ad_discrete_emb": ad_discrete_emb,
            "ad_continuous_emb": ad_continuous_emb,
        }

    def _record_episode(self, tag, episode_return, env_info):
        """Accumulate one finished episode's return and statistics under ``tag``."""
        self.returns.setdefault(tag, []).append(episode_return)
        stats = self.stats.setdefault(tag, {})
        # Only the final step's env_info is summed into the running totals;
        # a key missing on either side counts as 0.
        for key in set(stats) | set(env_info):
            stats[key] = stats.get(key, 0) + env_info.get(key, 0)
        stats["n_episodes"] = 1 + stats.get("n_episodes", 0)
        stats["ep_length"] = self.t + stats.get("ep_length", 0)
        if tag == 'train':
            self.t_env += self.t

    def run(self, test_mode=False, tag='train'):
        """Play one episode and return the filled batch.

        Args:
            test_mode: forwarded to ``mac.select_actions``.
            tag: key under which the return and stats are accumulated; only
                ``'train'`` episodes advance ``t_env``.

        Returns:
            The batch created by ``self.new_batch()`` for this episode.
        """
        self.reset()

        terminated = False
        episode_return = 0
        self.mac.init_hidden(batch_size=self.batch_size)

        while not terminated:
            self.batch.update(self._pre_transition_data(), ts=self.t)

            # The mac sees the whole batch up to now and acts for every agent
            # at this timestep (batch of size 1).
            action_data = self._select_actions(test_mode)

            reward, terminated, env_info = self.env.step(action_data["actions"][0])
            episode_return += reward

            post_transition_data = {
                **action_data,
                "reward": [(reward,)],
                "ad_reward": [(reward * self.args.ad_policy_reward_scale,)],
                # An end flagged by env_info["episode_limit"] (presumably a
                # time-limit truncation) is stored as not terminated.
                "terminated": [(terminated != env_info.get("episode_limit", False),)],
            }
            self.batch.update(post_transition_data, ts=self.t)

            self.t += 1

        # Fill the extra final slot with the last observation and the actions
        # the mac would take there.
        self.batch.update(self._pre_transition_data(), ts=self.t)
        self.batch.update(self._select_actions(test_mode), ts=self.t)

        self._record_episode(tag, episode_return, env_info)
        return self.batch

    def log_info(self, tag="train"):
        """Log and reset the averaged statistics accumulated for ``tag``.

        Each value is logged as ``'{tag}/{name}'`` at the current ``t_env``.
        The accumulated returns/stats for ``tag`` are cleared afterwards.

        Returns:
            The logged dict.  It is empty, and nothing is logged, when no
            episode was recorded for ``tag`` since the last call. This avoids
            a ``KeyError`` for unseen tags and NaN means for empty ones.
        """
        returns = self.returns.get(tag, [])
        if not returns:
            return {}
        stats = self.stats.get(tag, {})
        n_episodes = stats.get("n_episodes", len(returns))

        log_dic = {"return_mean": float(np.mean(returns))}
        for k, v in stats.items():
            if k != "n_episodes":
                log_dic[f"{k}_mean"] = float(v / n_episodes)
        selector = getattr(self.mac, 'action_selector', None)
        if selector is not None and hasattr(selector, "epsilon"):
            log_dic['epsilon'] = float(selector.epsilon)
        stats.clear()
        returns.clear()

        log_dic = {f'{tag}/{k}': v for k, v in log_dic.items()}
        for k, v in log_dic.items():
            self.logger.log_stat(k, v, self.t_env)
        return log_dic
