from collections import defaultdict, deque
from itertools import chain
import os
import time

import imageio
import numpy as np
import torch
import wandb
import os
import pickle
from onpolicy.utils.util import update_linear_schedule
from onpolicy.runner.shared.base_runner_trsyn import Runner
from onpolicy.algorithms.utils.distributions import FixedNormal, FixedCategorical


def _t2n(x):
    """Convert a torch tensor into a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before the conversion.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

class FootballRunner(Runner):
    def __init__(self, config):
        """Set up the base runner and an empty store for environment statistics.

        Args:
            config: passed through unchanged to ``Runner.__init__``.
        """
        super().__init__(config)
        # Metric name -> list of values; missing keys start as an empty list.
        self.env_infos = defaultdict(list)
       
    def run(self):
        """Run the whole training loop and pickle the logged win-rate curve.

        Every outer iteration ("episode") collects ``episode_length`` steps:
        at each step both the individual and the team policy are queried, one
        of the two action sets is executed in the environments, and the other
        policy is asked to evaluate those executed actions. After the rollout,
        returns are computed and the trainer is updated. Saving, logging and
        evaluation fire whenever the cumulative step count is a multiple of
        the respective interval (saving also fires on the last episode).

        When training ends, the win rates recorded at each logging point and
        the matching step counts are written to ``win_rate_seed_<seed>.pkl``
        and ``time_steps_seed_<seed>.pkl`` inside ``self.run_dir``.
        """
        self.warmup()

        start = time.time()
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads

        win_rate = []
        time_steps = []

        for episode in range(episodes):
            if self.use_linear_lr_decay:
                self.trainer.policy.lr_decay(episode, episodes)

            # The acting policy depends only on the episode index and static
            # arguments, so the decision is taken once per episode.
            act_with_team = bool(
                self.all_args.change_reward
                and episode > self.all_args.change_reward_episode
                and self.all_args.change_use_policy == "team"
            )

            for step in range(self.episode_length):
                # Sample from the individual policy.
                values, actions, action_log_probs, rnn_states, rnn_states_critic, act_dists = \
                    self.idv_collect(step)

                # Sample from the team policy.
                team_values, team_actions, team_log_probs, team_rnn, team_rnn_critic, team_act_dists = \
                    self.team_collect(step)

                if act_with_team:
                    # Team actions are executed; re-score them under the individual policy.
                    actions = team_actions
                    values, action_log_probs = self.evaluate_actions("idv", step, actions)
                else:
                    # Individual actions are executed; re-score them under the team policy.
                    team_values, team_log_probs = self.evaluate_actions("team", step, actions)

                # Observe reward and next obs.
                obs, rewards, dones, infos = self.envs.step(actions)

                # Insert data into buffer.
                self.insert((
                    obs, rewards, dones, infos,
                    values, actions, action_log_probs, rnn_states, rnn_states_critic, act_dists,
                    team_values, team_log_probs, team_rnn, team_rnn_critic, team_act_dists,
                ))

            # Compute return and update network.
            self.compute()
            train_infos = self.train(episode)

            # Post process.
            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads

            # Save model.
            if total_num_steps % self.save_interval == 0 or episode == episodes - 1:
                self.save()

            # Log information.
            if total_num_steps % self.log_interval == 0:
                end = time.time()
                print("\n Env {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                        .format(self.env_name,
                                self.algorithm_name,
                                self.experiment_name,
                                episode,
                                episodes,
                                total_num_steps,
                                self.num_env_steps,
                                int(total_num_steps / (end - start))))

                train_infos["average_episode_rewards"] = np.mean(self.buffer.team_rewards) * self.episode_length
                print("average episode rewards is {}".format(train_infos["average_episode_rewards"]))
                self.log_train(train_infos, total_num_steps)
                self.log_env(self.env_infos, total_num_steps)

                # Read the interval's win rate BEFORE clearing env_infos;
                # reading it afterwards always averaged an empty list (NaN).
                # ``.get`` avoids creating the key on the defaultdict.
                finished = self.env_infos.get("win_rate", [])
                # NaN keeps win_rate and time_steps aligned when no episode
                # finished during this interval.
                win_rate.append(np.mean(finished) if len(finished) > 0 else float("nan"))
                time_steps.append(total_num_steps)
                self.env_infos = defaultdict(list)

            # Eval: ``eval`` needs to know which policy to run, so evaluate
            # the policy that is currently acting in the environments.
            if total_num_steps % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps, "team_policy" if act_with_team else "idv_policy")

        # Persist both curves, keyed and named by the run's seed.
        seed = str(self.all_args.seed)
        for prefix, series in (("win_rate", win_rate), ("time_steps", time_steps)):
            out_path = os.path.join(str(self.run_dir), prefix + "_seed_" + seed + ".pkl")
            with open(out_path, 'wb') as f:
                pickle.dump({prefix + "_" + seed: series}, f, pickle.HIGHEST_PROTOCOL)


    def warmup(self):
        """Reset the training environments and fill slot 0 of the buffer.

        The individual share-obs, team share-obs and plain obs slots all
        receive the first observation. No flag is consulted: for this runner
        the shared observation is the plain observation in every
        configuration, so the former conditionals (whose branches were
        identical) are gone.
        """
        # Reset env.
        obs = self.envs.reset()

        # One independent copy per slot so the stored entries never alias
        # each other or the array returned by the environments.
        self.buffer.idv_share_obs[0] = obs.copy()
        self.buffer.team_share_obs[0] = obs.copy()
        self.buffer.obs[0] = obs.copy()

    @torch.no_grad()
    def idv_collect(self, step):
        """Query the individual policy with the buffer contents at ``step``.

        Each buffer entry is flattened with ``np.concatenate`` before it is
        passed to ``idv_policy.get_actions``; every tensor that comes back is
        converted to NumPy and split into ``n_rollout_threads`` equal chunks.

        Args:
            step: index into the rollout buffer.

        Returns:
            ``(values, actions, action_log_probs, rnn_states,
            rnn_states_critic, act_dists)``. The first five are NumPy arrays;
            ``act_dists`` is a NumPy array of per-sample distribution objects
            rebuilt from detached parameters.

        Raises:
            NotImplementedError: if a returned distribution is neither a
                ``FixedCategorical`` nor a ``FixedNormal``.
        """
        self.trainer.idv_prep_rollout()
        buf = self.buffer
        policy_out = self.trainer.idv_policy.get_actions(
            np.concatenate(buf.idv_share_obs[step]),
            np.concatenate(buf.obs[step]),
            np.concatenate(buf.idv_rnn_states[step]),
            np.concatenate(buf.idv_rnn_states_critic[step]),
            np.concatenate(buf.masks[step]),
            np.concatenate(buf.available_actions[step]),
        )
        value, action, action_log_prob, rnn_states, rnn_states_critic, act_dist = policy_out

        def per_thread(tensor):
            # Tensor -> ndarray, regrouped as one chunk per rollout thread.
            return np.array(np.split(_t2n(tensor), self.n_rollout_threads))

        # [self.envs, agents, dim]
        values = per_thread(value)
        actions = per_thread(action)
        action_log_probs = per_thread(action_log_prob)
        rnn_states = per_thread(rnn_states)
        rnn_states_critic = per_thread(rnn_states_critic)

        # Break every batched distribution into one distribution per row.
        split_dists = []
        for dist in act_dist:
            dist_cls = type(dist)
            if dist_cls is FixedCategorical:
                row_probs = dist.probs.detach()
                rows = [dist_cls(probs=p) for p in row_probs]
            elif dist_cls is FixedNormal:
                locs = dist.loc.detach()
                scales = dist.scale.detach()
                rows = [dist_cls(loc=m, scale=s) for m, s in zip(locs, scales)]
            else:
                raise NotImplementedError
            split_dists.append(rows)
        # Swap the two leading axes, then regroup by rollout thread.
        act_dists = np.array(np.split(np.array(split_dists).transpose((1, 0)), self.n_rollout_threads))

        return values, actions, action_log_probs, rnn_states, rnn_states_critic, act_dists

    @torch.no_grad()
    def team_collect(self, step):
        """Query the team policy deterministically with the buffer contents at ``step``.

        Each buffer entry is flattened with ``np.concatenate`` before it is
        passed to ``team_policy.get_actions`` (called with
        ``deterministic=True``); every tensor that comes back is converted to
        NumPy and split into ``n_rollout_threads`` equal chunks.

        Args:
            step: index into the rollout buffer.

        Returns:
            ``(values, actions, action_log_probs, rnn_states,
            rnn_states_critic, act_dists)``. The first five are NumPy arrays;
            ``act_dists`` is a NumPy array of per-sample distribution objects
            rebuilt from detached parameters.

        Raises:
            NotImplementedError: if a returned distribution is neither a
                ``FixedCategorical`` nor a ``FixedNormal``.
        """
        self.trainer.team_prep_rollout()
        buf = self.buffer
        policy_out = self.trainer.team_policy.get_actions(
            np.concatenate(buf.team_share_obs[step]),
            np.concatenate(buf.obs[step]),
            np.concatenate(buf.team_rnn_states[step]),
            np.concatenate(buf.team_rnn_states_critic[step]),
            np.concatenate(buf.masks[step]),
            np.concatenate(buf.available_actions[step]),
            deterministic=True,
        )
        value, action, action_log_prob, rnn_states, rnn_states_critic, act_dist = policy_out

        def per_thread(tensor):
            # Tensor -> ndarray, regrouped as one chunk per rollout thread.
            return np.array(np.split(_t2n(tensor), self.n_rollout_threads))

        # [self.envs, agents, dim]
        values = per_thread(value)
        actions = per_thread(action)
        action_log_probs = per_thread(action_log_prob)
        rnn_states = per_thread(rnn_states)
        rnn_states_critic = per_thread(rnn_states_critic)

        # Break every batched distribution into one distribution per row.
        split_dists = []
        for dist in act_dist:
            dist_cls = type(dist)
            if dist_cls is FixedCategorical:
                row_probs = dist.probs.detach()
                rows = [dist_cls(probs=p) for p in row_probs]
            elif dist_cls is FixedNormal:
                locs = dist.loc.detach()
                scales = dist.scale.detach()
                rows = [dist_cls(loc=m, scale=s) for m, s in zip(locs, scales)]
            else:
                raise NotImplementedError
            split_dists.append(rows)
        # Swap the two leading axes, then regroup by rollout thread.
        act_dists = np.array(np.split(np.array(split_dists).transpose((1, 0)), self.n_rollout_threads))

        return values, actions, action_log_probs, rnn_states, rnn_states_critic, act_dists

    @torch.no_grad()
    def evaluate_actions(self, policy: str, step, actions):
        """Score ``actions`` under one of the two policies at buffer index ``step``.

        Args:
            policy: ``"team"`` selects the team policy and its buffer fields;
                any other value selects the individual policy.
            step: index into the rollout buffer.
            actions: per-thread actions; flattened with ``np.concatenate``
                before being passed to the policy.

        Returns:
            ``(values, action_log_probs)`` as NumPy arrays split into
            ``n_rollout_threads`` chunks.
        """
        buf = self.buffer
        # The branch only picks the policy and where its inputs live;
        # the flattening below is common to both.
        if policy == "team":
            self.trainer.team_prep_rollout()
            exc_policy = self.trainer.team_policy
            share_obs_store = buf.team_share_obs
            actor_rnn_store = buf.team_rnn_states
            critic_rnn_store = buf.team_rnn_states_critic
        else:
            self.trainer.idv_prep_rollout()
            exc_policy = self.trainer.idv_policy
            share_obs_store = buf.idv_share_obs
            actor_rnn_store = buf.idv_rnn_states
            critic_rnn_store = buf.idv_rnn_states_critic

        share_obs = np.concatenate(share_obs_store[step])
        obs = np.concatenate(buf.obs[step])
        rnn_states = np.concatenate(actor_rnn_store[step])
        rnn_states_critic = np.concatenate(critic_rnn_store[step])
        masks = np.concatenate(buf.masks[step])
        available_actions = np.concatenate(buf.available_actions[step])
        active_masks = np.concatenate(buf.active_masks[step])

        # Only the first two outputs of the policy are used here.
        value, action_log_prob, _, _ = exc_policy.evaluate_actions(
            share_obs, obs, rnn_states, rnn_states_critic,
            np.concatenate(actions), masks, available_actions, active_masks,
        )

        n_threads = self.n_rollout_threads
        values = np.array(np.split(_t2n(value), n_threads))
        action_log_probs = np.array(np.split(_t2n(action_log_prob), n_threads))

        return values, action_log_probs

    def insert(self, data):
        """Post-process one environment step and store it in the buffer.

        For every environment in which all agents are done, the recurrent
        states of both policies are zeroed and the mask is set to 0. The
        individual reward is the environment reward; the team reward is each
        environment's ``info["score_reward"]`` repeated for every agent.
        Finished episodes also contribute ``goal``, ``win_rate`` and ``steps``
        entries to ``self.env_infos``.

        Args:
            data: 15-tuple ``(obs, rewards, dones, infos, values, actions,
                action_log_probs, rnn_states, rnn_states_critic, act_dists,
                team_values, team_log_probs, team_rnn, team_rnn_critic,
                team_act_dists)`` as assembled in ``run``. The four recurrent
                state arrays are modified in place.
        """
        obs, rewards, dones, infos, \
            values, actions, action_log_probs, rnn_states, rnn_states_critic, act_dists, \
            team_values, team_log_probs, team_rnn, team_rnn_critic, team_act_dists = data

        # An environment is finished only when every agent in it is done.
        # NOTE(review): assumes ``dones`` is 2-D (threads, agents); the
        # previous code reduced over axis=1 and axis=-1 interchangeably.
        dones_env = np.all(dones, axis=1)

        # Zero the recurrent state of finished environments so the next
        # episode does not start from stale memory.
        for states in (rnn_states, rnn_states_critic, team_rnn, team_rnn_critic):
            states[dones_env] = 0.0

        masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
        masks[dones_env] = 0.0

        # Both share-obs variants are the plain observation for this runner.
        team_share_obs = obs.copy()
        idv_share_obs = obs.copy()

        idv_rewards = rewards.reshape([self.n_rollout_threads, self.num_agents, 1])
        # Every agent of an environment receives that environment's score reward.
        team_rewards = np.array(
            [[[info["score_reward"]] for _ in range(self.num_agents)] for info in infos]
        ).reshape([self.n_rollout_threads, self.num_agents, 1])

        # Update env_infos for the environments that just finished.
        for done, info in zip(dones_env, infos):
            if not done:
                continue
            score = info["score_reward"]
            self.env_infos["goal"].append(score)
            self.env_infos["win_rate"].append(1 if score > 0 else 0)
            self.env_infos["steps"].append(info["max_steps"] - info["steps_left"])

        self.buffer.insert(idv_share_obs, team_share_obs, obs,
                           rnn_states, team_rnn, rnn_states_critic, team_rnn_critic,
                           actions, act_dists, team_act_dists, action_log_probs, team_log_probs,
                           values, team_values, idv_rewards, team_rewards, masks,
                           )

    def log_env(self, env_infos, total_num_steps):
        """Log the mean of every non-empty entry of ``env_infos``.

        Args:
            env_infos: mapping from metric name to a sized collection of values.
            total_num_steps: step index attached to each logged value.
        """
        for key, samples in env_infos.items():
            # Nothing was collected for this metric; skip it.
            if len(samples) == 0:
                continue
            mean_value = np.mean(samples)
            if self.use_wandb:
                wandb.log({key: mean_value}, step=total_num_steps)
            else:
                self.writter.add_scalars(key, {key: mean_value}, total_num_steps)

    @torch.no_grad()
    def eval(self, total_num_steps, title="idv_policy"):
        """Evaluate one of the two policies on the evaluation environments.

        ``eval_episodes`` episodes are distributed over the evaluation
        threads (the first ``eval_episodes % n_eval_rollout_threads`` threads
        get one extra). The loop stops once all of them are finished or after
        ``episode_length`` steps, whichever comes first. Mean goal, win rate
        and episode length are then printed/logged.

        NOTE: result slots of episodes that did not finish within the step
        budget stay at 0 and are still included in the reported means.

        Args:
            total_num_steps: training step index attached to the logged values.
            title: ``"team_policy"`` evaluates the team policy; any other
                value (default ``"idv_policy"``) evaluates the individual
                policy. The default keeps single-argument calls working.
        """
        n_threads = self.n_eval_rollout_threads
        n_episodes = self.all_args.eval_episodes

        # Reset envs and init rnn and mask.
        eval_obs = self.eval_envs.reset()
        eval_rnn_states = np.zeros((n_threads, self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
        eval_masks = np.ones((n_threads, self.num_agents, 1), dtype=np.float32)

        # Init eval goals.
        num_done = 0
        eval_goals = np.zeros(n_episodes)
        eval_win_rates = np.zeros(n_episodes)
        eval_steps = np.zeros(n_episodes)
        step = 0
        quo = n_episodes // n_threads
        rem = n_episodes % n_threads
        done_episodes_per_thread = np.zeros(n_threads, dtype=int)
        eval_episodes_per_thread = done_episodes_per_thread + quo
        eval_episodes_per_thread[:rem] += 1
        unfinished_thread = (done_episodes_per_thread != eval_episodes_per_thread)

        # Select the policy once; the choice cannot change while evaluating.
        if title == "team_policy":
            self.trainer.team_prep_rollout()
            policy = self.trainer.team_policy
        else:
            self.trainer.idv_prep_rollout()
            policy = self.trainer.idv_policy

        # Loop until enough episodes.
        while num_done < n_episodes and step < self.episode_length:
            # Act with the SELECTED policy (it used to be chosen but then
            # ignored in favour of ``self.trainer.policy``).
            # NOTE(review): assumes both policies expose
            # ``act(obs, rnn_states, masks, deterministic=...)`` - confirm.
            # [n_envs, n_agents, ...] -> [n_envs*n_agents, ...]
            eval_actions, eval_rnn_states = policy.act(
                np.concatenate(eval_obs),
                np.concatenate(eval_rnn_states),
                np.concatenate(eval_masks),
                deterministic=self.all_args.eval_deterministic
            )

            # [n_envs*n_agents, ...] -> [n_envs, n_agents, ...]
            eval_actions = np.array(np.split(_t2n(eval_actions), n_threads))
            eval_rnn_states = np.array(np.split(_t2n(eval_rnn_states), n_threads))

            eval_actions_env = [eval_actions[idx, :, 0] for idx in range(n_threads)]

            # Step.
            eval_obs, eval_rewards, eval_dones, eval_infos = self.eval_envs.step(eval_actions_env)

            # Record results of episodes that just finished on threads which
            # still owe episodes; other threads are ignored.
            eval_dones_env = np.all(eval_dones, axis=-1)
            if np.any(eval_dones_env[unfinished_thread]):
                for idx_env in range(n_threads):
                    if unfinished_thread[idx_env] and eval_dones_env[idx_env]:
                        info = eval_infos[idx_env]
                        eval_goals[num_done] = info["score_reward"]
                        eval_win_rates[num_done] = 1 if info["score_reward"] > 0 else 0
                        eval_steps[num_done] = info["max_steps"] - info["steps_left"]
                        num_done += 1
                        done_episodes_per_thread[idx_env] += 1
            unfinished_thread = (done_episodes_per_thread != eval_episodes_per_thread)

            # Reset rnn and masks for done envs.
            eval_rnn_states[eval_dones_env] = 0.0
            eval_masks = np.ones((n_threads, self.num_agents, 1), dtype=np.float32)
            eval_masks[eval_dones_env] = 0.0
            step += 1

        # Get expected goal.
        eval_goal = np.mean(eval_goals)
        eval_win_rate = np.mean(eval_win_rates)
        eval_step = np.mean(eval_steps)

        # Log and print.
        print("eval expected goal is {}.".format(eval_goal))
        if self.use_wandb:
            wandb.log({"eval_goal": eval_goal}, step=total_num_steps)
            wandb.log({"eval_win_rate": eval_win_rate}, step=total_num_steps)
            wandb.log({"eval_step": eval_step}, step=total_num_steps)
        else:
            self.writter.add_scalars("eval_goal", {"expected_goal": eval_goal}, total_num_steps)
            self.writter.add_scalars("eval_win_rate", {"eval_win_rate": eval_win_rate}, total_num_steps)
            self.writter.add_scalars("eval_step", {"expected_step": eval_step}, total_num_steps)

    @torch.no_grad()
    def render(self):
        """Play ``render_episodes`` episodes deterministically and report the goals.

        Each episode runs on ``self.envs`` until any entry of the returned
        dones is true. The reward at index ``[0, 0]`` of the last step is
        printed and stored as that episode's goal. When
        ``all_args.save_gifs`` is set, one frame per step (plus the initial
        frame) is collected and written to ``<gif_dir>/episode<i>.gif``.
        Finally the mean goal over all episodes is printed.
        """
        render_env = self.envs
        n_threads = self.n_rollout_threads
        episode_count = self.all_args.render_episodes
        state_shape = (n_threads, self.num_agents, self.recurrent_N, self.hidden_size)
        mask_shape = (n_threads, self.num_agents, 1)

        # One goal slot per rendered episode.
        render_goals = np.zeros(episode_count)
        for i_episode in range(episode_count):
            # Fresh observation, recurrent state and masks for this episode.
            render_obs = render_env.reset()
            render_rnn_states = np.zeros(state_shape, dtype=np.float32)
            render_masks = np.ones(mask_shape, dtype=np.float32)

            if self.all_args.save_gifs:
                # Start the clip with the frame of the first environment.
                first_frame = self.envs.envs[0].env.unwrapped.observation()[0]["frame"]
                frames = [first_frame]

            render_dones = False
            while not np.any(render_dones):
                self.trainer.prep_rollout()
                policy_out = self.trainer.policy.act(
                    np.concatenate(render_obs),
                    np.concatenate(render_rnn_states),
                    np.concatenate(render_masks),
                    deterministic=True
                )
                render_actions, render_rnn_states = policy_out

                # [n_envs*n_agents, ...] -> [n_envs, n_agents, ...]
                render_actions = np.array(np.split(_t2n(render_actions), n_threads))
                render_rnn_states = np.array(np.split(_t2n(render_rnn_states), n_threads))

                # Keep component 0 of every agent's action for each environment.
                render_actions_env = [thread_actions[:, 0] for thread_actions in render_actions]

                # Step.
                render_obs, render_rewards, render_dones, render_infos = render_env.step(render_actions_env)

                # Append frame.
                if self.all_args.save_gifs:
                    frames.append(render_infos[0]["frame"])

            # Print goal.
            episode_goal = render_rewards[0, 0]
            render_goals[i_episode] = episode_goal
            print("goal in episode {}: {}".format(i_episode, episode_goal))

            # Save gif.
            if self.all_args.save_gifs:
                gif_path = "{}/episode{}.gif".format(str(self.gif_dir), i_episode)
                imageio.mimsave(
                    uri=gif_path,
                    ims=frames,
                    format="GIF",
                    duration=self.all_args.ifi,
                )

        print("expected goal: {}".format(np.mean(render_goals)))
