import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils.util import print_network_grad
from utils.util import update_linear_schedule
from utils.util import get_shape_from_obs_space, get_shape_from_act_space
from utils.util import get_flat_grads, get_flat_params, conjugate_gradient, set_params
from algorithms.utils.util import check
from algorithms.mat.algorithm.ma_transformer import MultiAgentTransformer
from algorithms.mat.algorithm.ma_transformer_gail import MultiAgentTransformerGail
from algorithms.mat.algorithm.ma_transformer_gail_dec import MultiAgentTransformerGailDec
from algorithms.mat.algorithm.ma_transformer_gail_gru import MultiAgentTransformerGailGru
from algorithms.mat.algorithm.ma_transformer_gail_mlp import MultiAgentTransformerGailMLP
from algorithms.mat.algorithm.ma_transformer_gail_gmlp import MultiAgentTransformerGailGMLP
from algorithms.mat.algorithm.classifier import AgentClassifier


class TransformerPolicy:
    """
    MAT policy class. Wraps actor and critic networks to compute actions and value function predictions.

    :param args: (argparse.Namespace) arguments containing relevant model and policy information.
    :param obs_space: (gym.Space) observation space.
    :param cent_obs_space: (gym.Space) value function input space (centralized input for MAPPO, decentralized for IPPO).
    :param act_space: (gym.Space) action space.
    :param num_agents: (int) number of agents.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, obs_space, cent_obs_space, act_space, num_agents, device=torch.device("cpu")):
        """
        Build the transformer network, its optimizers/schedulers and the optional agent classifier.

        The network class and its variant-specific constructor keywords are both derived from a
        single resolved variant (history > decoder > mlp > gmlp > default), so the class that is
        imported always matches the keyword set it is constructed with.

        :param args: (argparse.Namespace) model, optimizer and discriminator/classifier settings.
        :param obs_space: (gym.Space) observation space.
        :param cent_obs_space: (gym.Space) centralized observation space.
        :param act_space: (gym.Space) action space; a ``Box`` is treated as continuous, anything else as discrete.
        :param num_agents: (int) number of agents.
        :param device: (torch.device) device handed to the networks.
        :raises NotImplementedError: if ``args.algorithm_name`` is not a supported algorithm.
        """
        self.args = args
        self.device = device
        self.algorithm_name = args.algorithm_name
        self.lr = args.lr
        self.disc_lr = args.disc_lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay
        self._use_policy_active_masks = args.use_policy_active_masks
        self._disc_use_act_prob = args.disc_use_act_prob
        # Discriminator-training bookkeeping; written by save() and read back by restore().
        self._update_times = 0
        self._train_disc_flag = True
        self._dis_now_stop_round = 0

        self.action_type = 'Continuous' if act_space.__class__.__name__ == 'Box' else 'Discrete'

        self.obs_dim = get_shape_from_obs_space(obs_space)[0]
        self.share_obs_dim = get_shape_from_obs_space(cent_obs_space)[0]
        if self.action_type == 'Discrete':
            # One action index per agent, chosen among act_space.n options.
            self.act_dim = act_space.n
            self.act_num = 1
        else:
            print("act high: ", act_space.high)
            self.act_dim = act_space.shape[0]
            self.act_num = self.act_dim

        print("obs_dim: ", self.obs_dim)
        print("share_obs_dim: ", self.share_obs_dim)
        print("act_dim: ", self.act_dim)

        self.num_agents = num_agents
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.history_obs_len = args.history_obs_len

        if self.algorithm_name not in ("mat", "mat_dec", "mat_gru", "mat_decoder", "mat_encoder"):
            raise NotImplementedError(f"unsupported algorithm_name: {self.algorithm_name!r}")

        # Resolve the variant ONCE; it drives both the class import and the constructor kwargs.
        # (Previously the two used different mlp/gmlp priorities and could disagree.)
        if args.mat_use_history:
            variant = "history"
        elif args.disc_use_decoder:
            variant = "decoder"
        elif args.disc_use_mlp:
            variant = "mlp"
        elif args.disc_use_gmlp:
            variant = "gmlp"
        else:
            variant = "default"

        if self.algorithm_name in ("mat", "mat_dec"):
            if variant == "history":
                from algorithms.mat.algorithm.ma_transformer_gail_gru import MultiAgentTransformerGailGru as MAT
            elif variant == "decoder":
                from algorithms.mat.algorithm.ma_transformer_gail_dec import MultiAgentTransformerGailDec as MAT
            elif variant == "mlp":
                from algorithms.mat.algorithm.ma_transformer_gail_mlp import MultiAgentTransformerGailMLP as MAT
            elif variant == "gmlp":
                from algorithms.mat.algorithm.ma_transformer_gail_gmlp import MultiAgentTransformerGailGMLP as MAT
            else:
                from algorithms.mat.algorithm.ma_transformer_gail import MultiAgentTransformerGail as MAT
        elif self.algorithm_name == "mat_gru":
            from algorithms.mat.algorithm.mat_gru import MultiAgentGRU as MAT
        elif self.algorithm_name == "mat_decoder":
            from algorithms.mat.algorithm.mat_decoder import MultiAgentDecoder as MAT
        else:  # "mat_encoder"
            from algorithms.mat.algorithm.mat_encoder import MultiAgentEncoder as MAT

        # Keywords shared by every variant.
        mat_kwargs = dict(
            n_block=args.n_block, n_embd=args.n_embd, n_head=args.n_head,
            disc_inner_dim=args.disc_inner_dim,
            encode_state=args.encode_state, device=device,
            action_type=self.action_type, dec_actor=args.dec_actor,
            share_actor=args.share_actor, use_gail=args.use_gail,
        )
        # Variant-specific keywords; only the selected variant's args attributes are read.
        if variant == "history":
            mat_kwargs.update(
                disc_share_value=args.disc_share_value, cat_his_with_now=args.cat_his_with_now,
                history_obs_len=args.history_obs_len, gru_num_layers=args.gru_num_layers,
                gru_hidden_size=args.gru_hidden_size,
            )
        elif variant == "decoder":
            mat_kwargs.update(
                disc_share_value=args.disc_share_value, disc_mask_action=args.disc_mask_action,
                disc_cal_last_loss=args.disc_cal_last_loss, disc_drop_cross_atten=args.disc_drop_cross_atten,
            )
        elif variant == "mlp":
            mat_kwargs.update(
                disc_use_act_prob=args.disc_use_act_prob, disc_agent_independent=args.disc_agent_independent,
                disc_mlp_obs_encoder=args.disc_mlp_obs_encoder,
                disc_mlp_act_encoder=args.disc_mlp_act_encoder,
                disc_mlp_use_first_token=args.disc_mlp_use_first_token,
            )
        elif variant == "gmlp":
            mat_kwargs.update(
                disc_share_value=args.disc_share_value,
                disc_gmlp_dim_ff=args.disc_gmlp_dim_ff, disc_gmlp_use_causal=args.disc_gmlp_use_causal,
                disc_gmlp_add_embd=args.disc_gmlp_add_embd, disc_gmlp_obs_encoder=args.disc_gmlp_obs_encoder,
            )
        else:
            mat_kwargs.update(
                disc_use_act_prob=args.disc_use_act_prob, disc_agent_independent=args.disc_agent_independent,
            )

        self.transformer = MAT(self.share_obs_dim, self.obs_dim, self.act_dim, num_agents, **mat_kwargs)

        if self.action_type != 'Discrete':
            self.transformer.zero_std()

        # Exactly one optimizer family is created, depending on the training mode.
        self.optimizer = None
        self.scheduler = None
        self.optimizer_encoder = None
        self.optimizer_decoder = None
        self.optimizer_critic = None
        self.optimizer_discriminator = None
        self.scheduler_discriminator = None
        if args.use_gail:
            # GAIL: one optimizer per sub-module so each part can be stepped separately.
            self.optimizer_encoder = torch.optim.Adam(
                self.transformer.encoder.parameters(), lr=self.lr, eps=self.opti_eps,
            )
            self.optimizer_decoder = torch.optim.Adam(
                self.transformer.decoder.parameters(), lr=self.lr, eps=self.opti_eps,
            )
            self.optimizer_critic = torch.optim.Adam(
                self.transformer.critic.parameters(), lr=self.lr, eps=self.opti_eps,
            )
            self.optimizer_discriminator = torch.optim.Adam(
                self.transformer.discriminator.parameters(), lr=self.disc_lr,
            )
            if args.use_disc_lr_decay:
                self.scheduler_discriminator = torch.optim.lr_scheduler.StepLR(
                    self.optimizer_discriminator,
                    step_size=args.disc_lr_decay_step,
                    gamma=args.disc_lr_decay_gamma,
                )
        else:
            # Behavior cloning: a single optimizer/scheduler over the whole network.
            self.optimizer = torch.optim.Adam(
                self.transformer.parameters(), lr=self.lr,
                eps=self.opti_eps, weight_decay=self.weight_decay,
            )
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=args.lr_decay_step_size,
                gamma=args.lr_decay_gamma,
            )

        # Optional agent classifier (for classifier pretraining or classifier-based reward).
        self.agent_classifier = None
        if args.pretrain_classifier or args.use_classifier_reward:
            self.agent_classifier = AgentClassifier(
                state_dim=self.share_obs_dim, obs_dim=self.obs_dim, action_dim=self.act_dim,
                n_embd=args.classifier_n_embd, n_agent=num_agents, action_type=self.action_type,
                device=self.device,
                classifier_only_action=self.args.classifier_only_action,
                classifier_use_gru=self.args.classifier_use_gru,
                classifier_gru_his_len=self.args.classifier_gru_his_len,
                classifier_gru_num_layer=self.args.classifier_gru_num_layer,
                classifier_use_act_enc=self.args.classifier_use_act_enc,
                classifier_act_enc_mask=self.args.classifier_act_enc_mask,
                classifier_use_data_tag=self.args.classifier_use_data_tag,
                classifier_data_tag_num=self.args.classifier_data_tag_num,
                classifier_enc_n_block=self.args.n_block, classifier_enc_n_head=self.args.n_head,
            )
        # Load pretrained classifier parameters when the classifier is used as a reward source.
        if args.use_classifier_reward:
            self.restore_classifier()

    def lr_decay(self, episode, episodes):
        """
        Linearly decay the learning rate of every policy optimizer that exists.

        ``self.optimizer`` is ``None`` in GAIL mode (the encoder/decoder/critic optimizers are
        used instead), so each candidate is decayed only when it is not ``None``.
        The discriminator optimizer is not touched here; it uses ``disc_lr`` and its own scheduler.

        :param episode: (int) current training episode.
        :param episodes: (int) total number of training episodes.
        """
        policy_optimizers = (
            self.optimizer,
            self.optimizer_encoder,
            self.optimizer_decoder,
            self.optimizer_critic,
        )
        for optimizer in policy_optimizers:
            if optimizer is not None:
                update_linear_schedule(optimizer, episode, episodes, self.lr)

    def get_actions(self, cent_obs, obs, history_obs, rnn_states_actor, rnn_states_critic, masks,
                    available_actions=None, deterministic=False):
        """
        Compute actions and value function predictions for the given inputs.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.

        :return values: (torch.Tensor) value function predictions.
        :return actions: (torch.Tensor) actions to take.
        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.
        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
        :return rnn_states_critic: (torch.Tensor) updated critic network RNN states.
        """

        """
        cent_obs (n_rollout_threads, num_agents, share_obs_dim)
        obs (n_rollout_threads, num_agents, obs_dim)
        history_obs (n_rollout_threads, num_agents, history_obs_len, his_obs_dim)
        """
        cent_obs = cent_obs.reshape(-1, self.num_agents, self.share_obs_dim)
        obs = obs.reshape(-1, self.num_agents, self.obs_dim)
        history_obs = history_obs.reshape(-1, self.num_agents, self.history_obs_len, self.obs_dim)
        if available_actions is not None:
            available_actions = available_actions.reshape(-1, self.num_agents, self.act_dim)

        actions, action_log_probs, action_probs = self.transformer.get_actions(
            state=cent_obs, obs=obs, available_actions=available_actions, deterministic=deterministic
        ) if not self.args.mat_use_history else \
            self.transformer.get_actions(
                state=cent_obs, obs=obs, history_obs=history_obs, available_actions=available_actions, deterministic=deterministic
            )

        actions = actions.view(-1, self.act_num)
        action_log_probs = action_log_probs.view(-1, self.act_num)
        action_probs = action_probs.view(-1, self.act_dim)

        # get pred value from critic(No need for eval stage)
        # get pred reward from discriminator(No need for eval stage)
        if not deterministic:
            values = self.get_critic_values(cent_obs, obs, history_obs)
            values = values.view(-1, 1)

            actions_tmp = np.array(np.split(
                actions.detach().cpu().numpy() if not self._disc_use_act_prob else action_probs.detach().cpu().numpy(),
                self.args.n_rollout_threads))
            disc_values = self.get_discriminator_reward(cent_obs, obs, actions_tmp)
            disc_values = disc_values.view(-1, 1)
        else:
            disc_values = None
            values = None
        # unused, just for compatibility
        rnn_states_actor = check(rnn_states_actor).to(**self.tpdv) if rnn_states_actor is not None else None
        rnn_states_critic = check(rnn_states_critic).to(**self.tpdv) if rnn_states_critic is not None else None
        return values, actions, action_log_probs, action_probs, rnn_states_actor, rnn_states_critic, disc_values

    def get_values(self, cent_obs, obs, history_obs, rnn_states_critic, masks, available_actions=None):
        """
        Get value function predictions.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.

        :return values: (torch.Tensor) value function predictions.
        """

        cent_obs = cent_obs.reshape(-1, self.num_agents, self.share_obs_dim)
        obs = obs.reshape(-1, self.num_agents, self.obs_dim)
        history_obs = history_obs.reshape(-1, self.num_agents, self.history_obs_len, self.obs_dim)
        if available_actions is not None:
            available_actions = available_actions.reshape(-1, self.num_agents, self.act_dim)

        values = self.transformer.get_critic_values(cent_obs, obs) \
            if not self.args.mat_use_history else \
            self.transformer.get_critic_values(cent_obs, obs, history_obs)
        values = values.view(-1, 1)

        return values

    def evaluate_actions(self, cent_obs, obs, history_obs, rnn_states_actor, rnn_states_critic, actions, masks,
                         available_actions=None, active_masks=None):
        """
        Get action logprobs / entropy and value function predictions for actor update.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param actions: (np.ndarray) actions whose log probabilites and entropy to compute.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return values: (torch.Tensor) value function predictions.
        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        cent_obs = cent_obs.reshape(-1, self.num_agents, self.share_obs_dim)
        obs = obs.reshape(-1, self.num_agents, self.obs_dim)
        history_obs = history_obs.reshape(-1, self.num_agents, self.history_obs_len, self.obs_dim)
        actions = actions.reshape(-1, self.num_agents, self.act_num)
        if available_actions is not None:
            available_actions = available_actions.reshape(-1, self.num_agents, self.act_dim)

        action_log_probs, entropy = self.transformer(cent_obs, obs, actions, available_actions) \
            if not self.args.mat_use_history else \
            self.transformer(cent_obs, obs, history_obs, actions, available_actions)

        values = self.get_critic_values(cent_obs, obs, history_obs)
        action_log_probs = action_log_probs.view(-1, self.act_num)
        values = values.view(-1, 1)
        entropy = entropy.view(-1, self.act_num)

        if self._use_policy_active_masks and active_masks is not None:
            entropy = (entropy*active_masks).sum()/active_masks.sum()
        else:
            entropy = entropy.mean()

        return values, action_log_probs, entropy

    ############### my method
    def train_offline(self, share_obs, obs, actions):
        """
        Compute the behavior-cloning loss: the negative mean log-probability the
        network assigns to the given actions.

        :param share_obs: centralized observations, reshaped to (-1, num_agents, share_obs_dim).
        :param obs: local observations, reshaped to (-1, num_agents, obs_dim).
        :param actions: actions, reshaped to (-1, num_agents, act_num).
        :return loss: (torch.Tensor) scalar loss.
        """
        def _to_agent_batch(data, feature_dim):
            # Convert via `check`, cast/move per self.tpdv, then group rows per agent.
            return check(data).to(**self.tpdv).reshape(-1, self.num_agents, feature_dim)

        state_batch = _to_agent_batch(share_obs, self.share_obs_dim)
        obs_batch = _to_agent_batch(obs, self.obs_dim)
        action_batch = _to_agent_batch(actions, self.act_num)

        # The network returns (log-probs, entropy); only the log-probs enter the loss.
        log_likelihood, _ = self.transformer(
            state=state_batch, obs=obs_batch, action=action_batch, available_actions=None,
        )
        return -torch.mean(log_likelihood)

    def act(self, cent_obs, obs, history_obs, rnn_states_actor, masks, available_actions=None, deterministic=True):
        """
        Compute actions using the given inputs.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
        """

        # this function is just a wrapper for compatibility
        rnn_states_critic = np.zeros_like(rnn_states_actor) if rnn_states_actor is not None else None
        _, actions, _, _, rnn_states_actor, _, _ = self.get_actions(
            cent_obs, obs, history_obs, rnn_states_actor, rnn_states_critic, masks, available_actions, deterministic)

        return actions, rnn_states_actor

    def get_critic_values(self, state, obs, history_obs):
        """
        Query the network's critic, passing ``history_obs`` only when
        ``args.mat_use_history`` is enabled.
        """
        if self.args.mat_use_history:
            return self.transformer.get_critic_values(state, obs, history_obs)
        return self.transformer.get_critic_values(state, obs)

    ###### add for gail disc
    def get_discriminator_logit(self, state, obs, action):
        state = state.reshape(-1, self.num_agents, self.share_obs_dim)
        obs = obs.reshape(-1, self.num_agents, self.obs_dim)
        action = action.reshape(-1, self.num_agents, self.act_num) \
            if action.shape[-1] == 1 else action.reshape(-1, self.num_agents, self.act_dim)

        return self.transformer.get_discriminator_logit(state, obs, action)

    def get_discriminator_reward(self, state, obs, action):
        """Forward the inputs unchanged to the network's ``get_discriminator_reward``."""
        network = self.transformer
        return network.get_discriminator_reward(state, obs, action)

    def get_discriminator_rewards_from_logits(self, logits):
        """Forward ``logits`` unchanged to the network's ``get_discriminator_rewards_from_logits``."""
        network = self.transformer
        return network.get_discriminator_rewards_from_logits(logits)
    ###### add for gail disc

    def get_classifier_reward(self, obs, action):
        obs = torch.from_numpy(obs.reshape(-1, self.obs_dim)).to(**self.tpdv)
        if self.args.classifier_use_gru:
            action = torch.from_numpy(action.reshape(
                -1, self.args.classifier_gru_his_len, self.act_num
            ) if not self.args.classifier_use_act_enc else action).to(**self.tpdv)
        else:
            action = torch.from_numpy(action.reshape(-1, self.act_num)).to(**self.tpdv)

        return self.agent_classifier.get_entropy_reward(obs, action)

    def save_classifier(self, save_dir, episode):
        torch.save(self.agent_classifier.state_dict(), str(save_dir) + "/classifier.pt")

    def restore_classifier(self):
        if self.args.classifier_model_path is not None:
            classifier_state_dict = torch.load(self.args.classifier_model_path)
            self.agent_classifier.load_state_dict(classifier_state_dict)
            print('--------------- load classifier -----------------')

    ############### my method

    def save(self, save_dir, episode):
        if self.args.use_gail:
            torch.save(self.transformer.state_dict(), str(save_dir) + "/transformer" + str(episode) + ".pt")
            torch.save({
                'optimizer_encoder': self.optimizer_encoder.state_dict(),
                'optimizer_decoder': self.optimizer_decoder.state_dict(),
                'optimizer_critic': self.optimizer_critic.state_dict(),
                'optimizer_discriminator': self.optimizer_discriminator.state_dict(),
                'train_disc_flag': self._train_disc_flag,
                'dis_now_stop_round': self._dis_now_stop_round,
                'update_times': self._update_times,
            }, str(save_dir) + "/optimizer_" + str(episode) + ".pt")
        else:
            torch.save(self.transformer.state_dict(), str(save_dir) + "/transformer.pt")
            # torch.save(self.optimizer.state_dict(), str(save_dir) + "/optimizer_" + str(episode) + ".pt")

    def restore(self, model_dir, optim_dir, encoder_dir):
        # load pretrained model parameters
        if model_dir is not None:
            transformer_state_dict = torch.load(model_dir)
            self.transformer.load_state_dict(transformer_state_dict)
            print('--------------- load transformer -----------------')
        if optim_dir is not None:
            optim_state_dict = torch.load(optim_dir)
            self.optimizer_encoder.load_state_dict(optim_state_dict['optimizer_encoder'])
            self.optimizer_decoder.load_state_dict(optim_state_dict['optimizer_decoder'])
            self.optimizer_critic.load_state_dict(optim_state_dict['optimizer_critic'])
            self.optimizer_discriminator.load_state_dict(optim_state_dict['optimizer_discriminator'])
            print('--------------- load optim -----------------')
            if 'train_disc_flag' in optim_state_dict:
                self._train_disc_flag = optim_state_dict['train_disc_flag']
                print('--------------- load train_disc_flag -----------------')
            if 'dis_now_stop_round' in optim_state_dict:
                self._dis_now_stop_round = optim_state_dict['dis_now_stop_round']
                print('--------------- load dis_now_stop_round -----------------')
            if 'update_times' in optim_state_dict:
                self._update_times = optim_state_dict['update_times']
                print('--------------- load update_times -----------------')
        if encoder_dir is not None and model_dir is None:
            encoder_state_dict = torch.load(encoder_dir)
            self.transformer.encoder.load_state_dict(encoder_state_dict, strict=False)
            print('--------------- load encoder -----------------')
        # self.transformer.reset_std()

    def train(self):
        self.transformer.train()

    def eval(self):
        self.transformer.eval()

