import copy

from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
import numpy as np
import os,wandb,csv


class EpisodeRunner:

    def __init__(self, args, logger):
        """Create the environment and prepare CSV / Weights & Biases outputs.

        Args:
            args: Run configuration. Attributes read here: ``env``,
                ``env_args`` (must contain ``map_name`` and ``seed``),
                ``mixer``, ``label`` and ``wandb``; when ``wandb`` is truthy
                also ``key``, ``project`` and ``entity``.
            logger: Stored as ``self.logger``; the other methods of this class
                call its ``log_stat(name, value, step)``.
        """
        self.args = args
        self.logger = logger
        # run() wraps every observation in a one-element list and steps the
        # env with actions[0], so exactly one episode is collected at a time.
        self.batch_size = 1

        self.env = env_REGISTRY[self.args.env](**self.args.env_args)
        self.episode_limit = self.env.episode_limit
        # Step index inside the current episode.
        self.t = 0

        # Total environment steps taken by training (non-test) episodes.
        self.t_env = 0

        self.train_returns = []
        self.test_returns = []
        self.train_stats = {}
        self.test_stats = {}

        self.n_native_state = None

        # Log the first run
        self.log_train_stats_t = -1000000

        map_name = self.args.env_args['map_name']
        seed = self.args.env_args['seed']
        self.csv_dir = f'./csv_files/{map_name}/{args.mixer}/{args.label}/'
        self.csv_path = f'{self.csv_dir}seed_{seed}_{args.label}.csv'
        # exist_ok avoids the check-then-create race: several seeds launched
        # together share this directory and could otherwise crash with
        # FileExistsError between the existence check and the mkdir.
        os.makedirs(self.csv_dir, exist_ok=True)
        if args.wandb:
            job_type = 'training'
            wandb_name = f'{map_name}_{args.label}'
            wandb.login(key=args.key, relogin=True)
            wandb.init(project=args.project, entity=args.entity, name=wandb_name, group=map_name, job_type=job_type, config=args)

    def setup(self, scheme, groups, preprocess, mac):
        """Bind the batch factory and the multi-agent controller.

        Stores ``self.new_batch``, a zero-argument factory that builds an
        ``EpisodeBatch`` from ``scheme`` / ``groups`` for ``self.batch_size``
        episodes, and stores ``mac`` as ``self.mac``.
        """
        # One slot more than the episode limit: run() writes a final
        # observation at index self.t after its stepping loop has ended.
        max_seq_length = self.episode_limit + 1
        batch_factory = partial(
            EpisodeBatch,
            scheme,
            groups,
            self.batch_size,
            max_seq_length,
            preprocess=preprocess,
            device=self.args.device,
        )
        self.new_batch = batch_factory
        self.mac = mac

    def get_env_info(self):
        """Return the env's info dict extended with four extra entries.

        The mapping returned by ``self.env.get_env_info()`` is extended in
        place with ``native_state_size``, ``alive_state_size``, ``n_enemies``
        and ``native_state_summary`` (each read from the matching env getter),
        cached on ``self.env_info`` and returned.
        """
        env = self.env
        info = env.get_env_info()
        extra_fields = (
            ("native_state_size", env.get_native_state_size),
            ("alive_state_size", env.get_alive_state_size),
            ("n_enemies", env.get_n_enemies),
            ("native_state_summary", env.get_native_state_summary),
        )
        for key, getter in extra_fields:
            info[key] = getter()
        self.env_info = info
        return info

    def save_replay(self):
        """Delegate to the environment's ``save_replay()``."""
        env = self.env
        env.save_replay()

    def close_env(self):
        """Delegate to the environment's ``close()``."""
        env = self.env
        env.close()

    def reset(self):
        """Start a new episode: fresh batch, reset env, step counter at zero."""
        fresh_batch = self.new_batch()
        self.batch = fresh_batch
        self.env.reset()
        self.t = 0

    def run(self, test_mode=False):
        """Roll out one episode and return its batch plus a stats snapshot.

        The environment is reset and stepped with the controller's actions
        until it reports termination. Every step's observations, actions,
        reward and terminated flag are written into ``self.batch``; one final
        observation/action pair is written after the last step.

        Args:
            test_mode: When true the episode is accounted in the test
                statistics, ``self.t_env`` is not advanced and a running
                ``win_rate`` is stored in the stats. Once
                ``args.test_nepisode`` test returns have been collected the
                aggregate is written to the CSV file, optionally to wandb and
                to the logger, after which the test accumulators are cleared.

        Returns:
            ``(batch, stats)`` where ``stats`` is a deep copy of the
            accumulated statistics taken before any logging clears them.
        """
        self.reset()

        terminated = False
        episode_return = 0
        self.mac.init_hidden(batch_size=self.batch_size)

        while not terminated:
            self.batch.update(self._get_pre_transition_data(), ts=self.t)

            # Pass the entire batch of experiences up till now to the agents
            # Receive the actions for each agent at this timestep in a batch of size 1
            actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
            # Fix memory leak: the batch stores a CPU numpy copy of the actions
            # instead of the object returned by the controller.
            cpu_actions = actions.to("cpu").numpy()

            reward, terminated, env_info = self.env.step(actions[0])
            episode_return += reward

            post_transition_data = {
                "actions": cpu_actions,
                "reward": [(reward,)],
                "terminated": [(terminated,)],
            }

            self.batch.update(post_transition_data, ts=self.t)

            self.t += 1

        # Record the observation that follows the final transition.
        self.batch.update(self._get_pre_transition_data(), ts=self.t)

        # Select actions in the last stored state
        actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
        # Fix memory leak
        cpu_actions = actions.to("cpu").numpy()
        self.batch.update({"actions": cpu_actions}, ts=self.t)

        if not test_mode:
            self.t_env += self.t

        cur_stats = self.test_stats if test_mode else self.train_stats
        cur_returns = self.test_returns if test_mode else self.train_returns
        log_prefix = "test_" if test_mode else ""
        # NOTE(review): only keys present in BOTH dicts are accumulated, so a
        # key reported solely by env_info (e.g. 'battle_won') is never added
        # here unless something else already put it in cur_stats. Confirm the
        # intersection (rather than a union) is intended.
        cur_stats.update({k: cur_stats.get(k, 0) + env_info.get(k, 0) for k in set(cur_stats) & set(env_info)})
        cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0)
        cur_stats['step'] = self.t_env

        # Record this episode's return BEFORE deriving the win rate: the
        # fallback below counts wins from cur_returns while dividing by
        # n_episodes, which already includes the current episode.
        cur_returns.append(episode_return)

        if test_mode:
            if 'battle_won' in cur_stats:
                win_rate = cur_stats['battle_won'] / cur_stats['n_episodes']
            else:
                # Fallback: an episode counts as won when its return exceeds 8.
                win_rate = ((np.array(cur_returns) > 8).astype('int')).sum() / cur_stats['n_episodes']
            cur_stats['win_rate'] = win_rate

        # Snapshot for the caller; _log() below empties cur_stats.
        stats = copy.deepcopy(cur_stats)

        if test_mode:
            if len(self.test_returns) == self.args.test_nepisode:
                mean_return = np.mean(cur_returns)
                mean_steps = cur_stats["ep_length"] / cur_stats['n_episodes']
                self.writereward(self.csv_path, mean_return, win_rate, self.t_env)
                if self.args.wandb:
                    # The 'mean_steps:' key (with its colon) is kept verbatim so
                    # existing wandb charts keep receiving the same metric.
                    wandb.log({'Test_win_rate':win_rate, log_prefix + "return_mean": mean_return,
                               log_prefix + "return_std": np.std(cur_returns), 'mean_steps:': mean_steps}, step=self.t_env)
                self._log(cur_returns, cur_stats, log_prefix)
        elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            # Interval-based logging applies to training episodes only. If it
            # fired during evaluation it would flush and clear partial test
            # stats, and test_returns would never reach test_nepisode.
            self._log(cur_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env)
            self.log_train_stats_t = self.t_env

        return self.batch, stats

    def _get_pre_transition_data(self):
        """Read the env's current state/observation getters into batch fields."""
        return {
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "alive_state": [self.env.get_alive_state()],
            "native_state": [self.env.get_native_state()],
            "obs": [self.env.get_obs()],
        }

    def writereward(self, path, reward, win_rate, step):
        """Append one ``step, reward, win_rate`` row to the CSV file at *path*.

        When no file exists at *path* yet, it is created and a header row
        ``step, reward, win_rate`` is written before the data row.
        """
        is_new_file = not os.path.isfile(path)
        # newline='' is what the csv module asks for on files it writes to;
        # without it, platforms that translate '\n' produce '\r\r\n' endings
        # (a blank line between rows when the file is read back).
        with open(path, 'a', newline='') as f:
            csv_write = csv.writer(f)
            if is_new_file:
                csv_write.writerow(['step', 'reward', 'win_rate'])
            csv_write.writerow([step, reward, win_rate])
    def _log(self, returns, stats, prefix):
        """Send aggregated returns and stats to the logger, then clear both.

        Args:
            returns: List of episode returns; its mean and standard deviation
                are logged and the list is emptied in place.
            stats: Dict of statistics that must contain ``n_episodes``. Every
                other entry is logged as ``prefix + key + "_mean"`` and the
                dict is emptied in place.
            prefix: String prepended to every logged stat name.
        """
        self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env)
        self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env)
        returns.clear()

        n_episodes = stats["n_episodes"]
        # run() overwrites these two entries on every episode instead of
        # summing them, so dividing by the episode count would distort them
        # (e.g. a win rate of 0.5 over 32 episodes would be logged as 0.0156).
        # They are logged as-is, under the unchanged "<key>_mean" names.
        point_in_time_keys = ("step", "win_rate")
        for k, v in stats.items():
            if k == "n_episodes":
                continue
            value = v if k in point_in_time_keys else v / n_episodes
            self.logger.log_stat(prefix + k + "_mean", value, self.t_env)
        stats.clear()
