import time
import wandb
import numpy as np
from functools import reduce
import torch
from tqdm import tqdm
from runner.shared.base_runner import Runner
import datetime


def _t2n(x):
    """Convert a torch tensor into a NumPy array.

    The tensor is first detached from the autograd graph and moved to host
    memory, since ``Tensor.numpy()`` only accepts CPU tensors that do not
    require grad.
    """
    detached = x.detach()
    on_host = detached.cpu()
    return on_host.numpy()


class SMACRunner(Runner):
    """Runner class to perform training, evaluation, and data collection for SMAC.

    Besides the standard on-policy loop (``run``) it provides behaviour cloning
    (``run_bc``), GAIL/WAIL-style training with an optional classifier reward
    (``run_gail``), classifier pre-training (``pretrain_classifier``) and a
    random-policy environment timing utility (``test_env_time``).
    See the parent class for the attributes prepared in ``__init__``.
    """
    def __init__(self, config):
        super(SMACRunner, self).__init__(config)

    def run2(self):
        """Run a single evaluation pass, logged at step 0, without any training."""
        self.eval(0)

    def run(self):
        """Standard on-policy loop: collect rollouts, update, then save / log / evaluate on their intervals."""
        self.warmup()

        start = time.time()
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads

        last_battles_game = np.zeros(self.n_rollout_threads, dtype=np.float32)
        last_battles_won = np.zeros(self.n_rollout_threads, dtype=np.float32)

        for episode in range(episodes):
            if self.use_linear_lr_decay:
                self.trainer.policy.lr_decay(episode, episodes)

            for step in range(self.episode_length):
                # Sample actions; ``collect`` always returns seven per-thread arrays.
                values, actions, action_log_probs, action_probs, \
                    rnn_states, rnn_states_critic, disc_values = self.collect(step)

                # Observe reward and next obs
                obs, share_obs, rewards, dones, infos, available_actions = self.envs.step(actions)

                # ``insert`` expects the full 14-field tuple; this loop has no
                # classifier reward, so that last slot is None.
                data = obs, share_obs, rewards, dones, infos, available_actions, \
                       values, actions, action_log_probs, action_probs, \
                       rnn_states, rnn_states_critic, disc_values, None

                # insert data into buffer
                self.insert(data)

            # compute return and update network
            self.compute()
            train_infos = self.train()

            # post process
            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads
            # save model
            if episode % self.save_interval == 0 or episode == episodes - 1:
                self.save(total_num_steps)

            # log information
            if episode % self.log_interval == 0:
                self._print_progress(episode, episodes, total_num_steps, start)
                last_battles_game, last_battles_won = self._log_incre_win_rate(
                    infos, last_battles_game, last_battles_won, total_num_steps)
                train_infos['dead_ratio'] = self._dead_ratio()
                self.log_train(train_infos, total_num_steps)

            # eval
            if episode % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps)

    # ------------------------------------------------------------------
    # Logging helpers shared by ``run`` and ``run_gail``
    # ------------------------------------------------------------------
    def _print_progress(self, episode, episodes, total_num_steps, start):
        """Print map / algorithm / experiment, update counter, timesteps and FPS measured from ``start``."""
        end = time.time()
        print("\n Map {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
              .format(self.all_args.map_name,
                      self.algorithm_name,
                      self.experiment_name,
                      episode,
                      episodes,
                      total_num_steps,
                      self.num_env_steps,
                      int(total_num_steps / (end - start))))

    def _log_incre_win_rate(self, infos, last_battles_game, last_battles_won, total_num_steps):
        """Print and log the win rate over battles finished since the previous call.

        Args:
            infos: per-thread info list from ``envs.step``; only agent 0's dict is read.
            last_battles_game / last_battles_won: per-thread cumulative counters
                returned by the previous call (or zeros initially).
            total_num_steps: logging step.

        Returns:
            ``(battles_game, battles_won)`` per-thread cumulative counters to pass
            back in on the next call. A thread whose info lacks a counter keeps
            its previous value, so indices stay aligned with thread ids.
        """
        battles_won = np.array(last_battles_won, dtype=np.float64)
        battles_game = np.array(last_battles_game, dtype=np.float64)
        incre_battles_won = []
        incre_battles_game = []

        for i, info in enumerate(infos):
            first_agent_info = info[0]
            if 'battles_won' in first_agent_info:
                battles_won[i] = first_agent_info['battles_won']
                incre_battles_won.append(first_agent_info['battles_won'] - last_battles_won[i])
            if 'battles_game' in first_agent_info:
                battles_game[i] = first_agent_info['battles_game']
                incre_battles_game.append(first_agent_info['battles_game'] - last_battles_game[i])

        incre_games = np.sum(incre_battles_game)
        incre_win_rate = np.sum(incre_battles_won) / incre_games if incre_games > 0 else 0.0
        print("incre win rate is {}.".format(incre_win_rate))
        if self.use_wandb:
            wandb.log({"incre_win_rate": incre_win_rate}, step=total_num_steps)
        else:
            self.writter.add_scalars("incre_win_rate", {"incre_win_rate": incre_win_rate}, total_num_steps)

        return battles_game, battles_won

    def _dead_ratio(self):
        """Return one minus the mean of ``buffer.active_masks`` (share of inactive entries for a 0/1 mask)."""
        active_masks = self.buffer.active_masks
        return 1 - active_masks.sum() / reduce(lambda x, y: x * y, list(active_masks.shape))

    # ------------------------------------------------------------------
    # Behaviour cloning
    # ------------------------------------------------------------------
    def run_bc(self):
        """Train on expert data for ``num_epochs`` epochs; evaluate, log and save after each epoch."""
        for ep_i in range(self.all_args.num_epochs):
            # start a training epoch
            self.policy.train()
            train_pbar = tqdm(range(self.all_args.num_steps_per_epochs),
                              total=self.all_args.num_steps_per_epochs,
                              mininterval=1 if self.all_args.quick_tqdm else 10)
            all_train_losses = []
            for _ in train_pbar:
                all_train_losses.append(self.train_offline())
            mean_loss = np.mean(all_train_losses)
            print('episode', ep_i, 'mean_loss', mean_loss)
            train_infos = {'mean_loss': mean_loss}
            # eval model at each epoch
            if self.use_eval:
                self.eval(ep_i)
            # log intermediate results
            self.log_train(train_infos, ep_i)
            # save model at each saving interval or at the last epoch
            if ep_i % self.save_interval == 0 or ep_i == self.all_args.num_epochs - 1:
                self.save(ep_i)

    def train_offline(self):
        """Sample one expert batch and return the result of ``trainer.train_offline`` on it."""
        share_obs, obs, actions = self.expert_buffer.sample_batch_data(batch_size=self.all_args.batch_size)
        return self.trainer.train_offline(share_obs, obs, actions)

    # ------------------------------------------------------------------
    # GAIL / WAIL
    # ------------------------------------------------------------------
    def run_gail(self):
        """On-policy loop that optionally adds a classifier-based reward to every transition."""
        self.warmup()

        start = time.time()
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads

        last_battles_game = np.zeros(self.n_rollout_threads, dtype=np.float32)
        last_battles_won = np.zeros(self.n_rollout_threads, dtype=np.float32)

        # per-step mean classifier reward, reset at every log
        all_classifier_rewards = []
        for episode in range(episodes):
            if self.use_linear_lr_decay:
                self.trainer.policy.lr_decay(episode, episodes)

            for step in range(self.episode_length):
                # Sample actions
                values, actions, action_log_probs, action_probs, \
                    rnn_states, rnn_states_critic, disc_values = self.collect(step)

                # Observe reward and next obs
                obs, share_obs, rewards, dones, infos, available_actions = self.envs.step(actions)

                # get reward from classifier
                if self.all_args.use_classifier_reward:
                    if self.all_args.classifier_use_gru:
                        # update the history action sequence with this step's actions, then read it back
                        self.buffer.update_history_act_record(actions, dones)
                        classifier_input_action = self.buffer.get_history_act_record()
                    else:
                        classifier_input_action = actions
                    classifier_rewards = _t2n(self.policy.get_classifier_reward(obs, classifier_input_action))
                    classifier_rewards = classifier_rewards.reshape(self.n_rollout_threads, self.num_agents, 1)
                    all_classifier_rewards.append(classifier_rewards.mean())
                else:
                    classifier_rewards = None

                data = obs, share_obs, rewards, dones, infos, available_actions, \
                       values, actions, action_log_probs, action_probs, \
                       rnn_states, rnn_states_critic, disc_values, classifier_rewards

                # insert data into buffer
                self.insert(data)

            # get next value from critic and compute gae advantage
            self.compute()
            # update network
            train_infos = self.train()

            # calculate total simulation steps
            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads
            # save model at each saving interval or at last episode (checkpoint tagged by episode index)
            if episode % self.save_interval == 0 or episode == episodes - 1:
                self.save(episode)

            # log information at each log interval
            if episode % self.log_interval == 0:
                self._print_progress(episode, episodes, total_num_steps, start)
                last_battles_game, last_battles_won = self._log_incre_win_rate(
                    infos, last_battles_game, last_battles_won, total_num_steps)
                train_infos['dead_ratio'] = self._dead_ratio()

                # add classifier info, honouring the same wandb / tensorboard switch as other metrics
                if self.all_args.use_classifier_reward:
                    mean_classifier_reward = np.mean(all_classifier_rewards)
                    if self.use_wandb:
                        wandb.log({"classifier_rewards": mean_classifier_reward}, step=total_num_steps)
                    else:
                        self.writter.add_scalars("classifier_rewards",
                                                 {"classifier_rewards": mean_classifier_reward},
                                                 total_num_steps)
                    all_classifier_rewards = []
                self.log_train(train_infos, total_num_steps)

            # eval model at each eval interval
            if episode % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps)

    # ------------------------------------------------------------------
    # Classifier pre-training
    # ------------------------------------------------------------------
    def pretrain_classifier(self):
        """Train ``policy.agent_classifier`` on expert data with a full validation pass each epoch."""
        for ep_i in range(self.all_args.num_epochs):
            # training phase
            train_pbar = tqdm(range(self.all_args.num_steps_per_epochs),
                              total=self.all_args.num_steps_per_epochs,
                              mininterval=1 if self.all_args.quick_tqdm else 10)
            all_train_losses = []
            self.policy.agent_classifier.train()
            for _ in train_pbar:
                obs, actions, tags = self.expert_buffer.sample_classifier_batch_data(batch_size=self.all_args.batch_size)
                all_train_losses.append(self.agent_classifier_trainer.update(obs, actions, tags))
            mean_train_loss = np.mean(all_train_losses)
            print('episode', ep_i, 'mean_train_loss', mean_train_loss)
            train_infos = {'train_loss': mean_train_loss}

            # validation phase: iterate until the expert buffer reports the split is exhausted
            finish = False
            all_valid_losses = []
            self.expert_buffer.start_val()
            self.policy.agent_classifier.eval()
            # no gradients are needed for validation, so skip building the autograd graph
            with torch.no_grad():
                while not finish:
                    val_states, val_actions, val_tags, finish = \
                        self.expert_buffer.get_all_valid_data(self.all_args.batch_size)
                    val_states = val_states.to(self.device)
                    val_actions = val_actions.to(self.device)
                    val_tags = val_tags.to(self.device)
                    val_loss = self.policy.agent_classifier.get_loss(val_states, val_actions, val_tags)
                    all_valid_losses.append(val_loss.item())
            mean_valid_loss = np.mean(all_valid_losses)
            print('episode', ep_i, 'mean_valid_loss', mean_valid_loss)
            train_infos['val_loss'] = mean_valid_loss
            # log intermediate results
            self.log_train(train_infos, ep_i)
            # save classifier at each saving interval or at the last epoch
            if ep_i % self.save_interval == 0 or ep_i == self.all_args.num_epochs - 1:
                self.save_classifier(ep_i)

    # ------------------------------------------------------------------
    # Environment timing
    # ------------------------------------------------------------------
    def test_env_time(self, max_steps=100000):
        """Time ``eval_envs`` under a uniformly random policy over each agent's available actions.

        Steps until at least ``max_steps`` environment steps (summed over all
        eval threads) have run, then prints the win rate of finished episodes
        and the elapsed wall-clock seconds.
        """
        eval_battles_won = 0
        eval_episode = 0
        eval_step = 0

        eval_obs, eval_share_obs, eval_available_actions = self.eval_envs.reset()
        start_time = datetime.datetime.now()
        while True:
            # (threads, agents) matrix of random available actions, env-major order
            eval_actions = np.array([
                [np.random.choice(np.nonzero(eval_available_actions[env_id][agent_id])[0])
                 for agent_id in range(self.num_agents)]
                for env_id in range(self.n_eval_rollout_threads)
            ])

            eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_infos, eval_available_actions = \
                self.eval_envs.step(eval_actions)

            eval_dones_env = np.all(eval_dones, axis=1)
            eval_step += self.n_eval_rollout_threads

            for eval_i in range(self.n_eval_rollout_threads):
                if eval_dones_env[eval_i]:
                    eval_episode += 1
                    if eval_infos[eval_i][0]['won']:
                        eval_battles_won += 1

            if eval_step >= max_steps:
                eval_win_rate = eval_battles_won / eval_episode if eval_episode > 0 else 0.0
                print("eval win rate is {}.".format(eval_win_rate))
                end_time = datetime.datetime.now()
                print('endtime - starttime', (end_time - start_time).total_seconds())
                break

    # ------------------------------------------------------------------
    # Rollout plumbing
    # ------------------------------------------------------------------
    def warmup(self):
        """Reset the training envs and write the initial observations into slot 0 of the buffer."""
        obs, share_obs, available_actions = self.envs.reset()

        # without a centralized critic the critic sees local observations
        if not self.use_centralized_V:
            share_obs = obs

        self.buffer.share_obs[0] = share_obs.copy()
        self.buffer.obs[0] = obs.copy()
        self.buffer.available_actions[0] = available_actions.copy()
        if self.all_args.mat_use_history:
            # NOTE(review): the history axis is taken to be axis 2 (threads, agents, history, ...),
            # the axis the newest observation is written to - confirm against the buffer layout.
            hist_len = self.buffer.history_obs_len
            history = self.buffer.history_obs_record
            # drop the oldest entry by shifting the rest one slot towards the front
            history[:, :, :hist_len - 1] = history[:, :, 1:hist_len]
            # ``obs`` is a NumPy array here (``.copy()`` above), so no tensor conversion is needed
            history[:, :, hist_len - 1] = np.asarray(obs).copy()
            self.buffer.his_obs[0] = history.copy()

    @torch.no_grad()
    def collect(self, step):
        """Query the policy on buffer slot ``step`` for every thread.

        Returns:
            ``(values, actions, action_log_probs, action_probs, rnn_states,
            rnn_states_critic, disc_values)``, each split back into one chunk per
            rollout thread: ``[n_rollout_threads, agents, ...]``.
        """
        self.trainer.prep_rollout()
        value, action, action_log_prob, action_prob, rnn_state, rnn_state_critic, disc_value \
            = self.trainer.policy.get_actions(
                np.concatenate(self.buffer.share_obs[step]),
                np.concatenate(self.buffer.obs[step]),
                self.buffer.get_history_obs_record(),
                np.concatenate(self.buffer.rnn_states[step]),
                np.concatenate(self.buffer.rnn_states_critic[step]),
                np.concatenate(self.buffer.masks[step]),
                np.concatenate(self.buffer.available_actions[step]))
        flat_outputs = (value, action, action_log_prob, action_prob, rnn_state, rnn_state_critic, disc_value)
        # inputs were concatenated across threads; split outputs back per thread
        return tuple(np.array(np.split(_t2n(out), self.n_rollout_threads)) for out in flat_outputs)

    def insert(self, data):
        """Build masks for one transition and write it into ``self.buffer``.

        ``data`` is the 14-tuple ``(obs, share_obs, rewards, dones, infos,
        available_actions, values, actions, action_log_probs, action_probs,
        rnn_states, rnn_states_critic, disc_values, classifier_rewards)``.
        """
        (obs, share_obs, rewards, dones, infos, available_actions,
         values, actions, action_log_probs, action_probs,
         rnn_states, rnn_states_critic, disc_values, classifier_rewards) = data

        # an environment counts as finished only when all of its agents are done
        dones_env = np.all(dones, axis=1)
        done_agents = np.asarray(dones).astype(bool)

        # reset recurrent states of finished environments
        rnn_states[dones_env] = 0.0
        rnn_states_critic[dones_env] = 0.0

        masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
        masks[dones_env] = 0.0

        # done agents are inactive; entries of environments that just finished are set back to 1
        active_masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
        active_masks[done_agents] = 0.0
        active_masks[dones_env] = 1.0

        bad_masks = np.array([[[0.0] if info[agent_id]['bad_transition'] else [1.0]
                               for agent_id in range(self.num_agents)] for info in infos])

        if not self.use_centralized_V:
            share_obs = obs

        self.buffer.insert(share_obs, obs, rnn_states, rnn_states_critic,
                           actions, action_log_probs, action_probs, values, rewards, dones, masks, bad_masks,
                           active_masks, available_actions, disc_values, classifier_rewards)

    def log_train(self, train_infos, total_num_steps):
        """Log every entry of ``train_infos`` (plus the buffer's mean step reward) to wandb or tensorboard.

        Note: adds ``average_step_rewards`` to ``train_infos`` in place.
        """
        # these entries are passed to ``add_scalars`` as-is, i.e. as a dict of named scalars
        dict_valued_keys = {'disc_expert_score', 'disc_policy_score',
                            'all_expert_agent_mean_scores', 'all_policy_agent_mean_scores'}
        train_infos["average_step_rewards"] = np.mean(self.buffer.rewards)
        for k, v in train_infos.items():
            if self.use_wandb:
                wandb.log({k: v}, step=total_num_steps)
            elif k in dict_valued_keys:
                self.writter.add_scalars(k, v, total_num_steps)
            else:
                self.writter.add_scalars(k, {k: v}, total_num_steps)

    @torch.no_grad()
    def eval(self, total_num_steps):
        """Run deterministic evaluation until ``all_args.eval_episodes`` episodes finish, then log reward and win rate."""
        eval_battles_won = 0
        eval_episode = 0

        eval_episode_rewards = []
        # rewards of the episode currently running in each eval thread
        one_episode_rewards = [[] for _ in range(self.n_eval_rollout_threads)]

        eval_obs, eval_share_obs, eval_available_actions = self.eval_envs.reset()

        eval_rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size),
                                   dtype=np.float32)
        eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)

        while True:
            self.trainer.prep_rollout()
            # NOTE(review): the history input is read from the *training* buffer, not from the eval envs - confirm.
            eval_actions, eval_rnn_states = \
                self.trainer.policy.act(
                    np.concatenate(eval_share_obs),
                    np.concatenate(eval_obs),
                    self.buffer.get_history_obs_record(),
                    np.concatenate(eval_rnn_states),
                    np.concatenate(eval_masks),
                    np.concatenate(eval_available_actions),
                    deterministic=True)
            eval_actions = np.array(np.split(_t2n(eval_actions), self.n_eval_rollout_threads))
            eval_rnn_states = np.array(np.split(_t2n(eval_rnn_states), self.n_eval_rollout_threads))

            # Observe reward and next obs
            eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_infos, eval_available_actions = \
                self.eval_envs.step(eval_actions)
            for eval_i in range(self.n_eval_rollout_threads):
                # keep a length-1 leading thread axis on each stored step
                one_episode_rewards[eval_i].append(eval_rewards[eval_i:eval_i + 1])

            eval_dones_env = np.all(eval_dones, axis=1)

            eval_rnn_states[eval_dones_env] = 0.0
            eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)
            eval_masks[eval_dones_env] = 0.0

            for eval_i in range(self.n_eval_rollout_threads):
                if eval_dones_env[eval_i]:
                    eval_episode += 1
                    eval_episode_rewards.append(np.sum(one_episode_rewards[eval_i], axis=0))
                    one_episode_rewards[eval_i] = []
                    if eval_infos[eval_i][0]['won']:
                        eval_battles_won += 1

            if eval_episode >= self.all_args.eval_episodes:
                eval_episode_rewards = np.array(eval_episode_rewards).mean() if self.all_args.n_eval_rollout_threads > 1 \
                    else np.array(eval_episode_rewards)
                eval_env_infos = {'eval_average_episode_rewards': eval_episode_rewards}
                self.log_env(eval_env_infos, total_num_steps)
                eval_win_rate = eval_battles_won / eval_episode if eval_episode > 0 else 0.0
                print("eval win rate is {}.".format(eval_win_rate))
                if self.use_wandb:
                    wandb.log({"eval_win_rate": eval_win_rate}, step=total_num_steps)
                else:
                    self.writter.add_scalars("eval_win_rate", {"eval_win_rate": eval_win_rate}, total_num_steps)
                break
