import torch
import torch.nn as nn  # Already likely imported, but ensure it exists
from torch.nn.parallel import DataParallel  # Explicit import for clarity
import numpy as np
from mat.utils.util import update_linear_schedule
from mat.utils.util import get_shape_from_obs_space, get_shape_from_act_space
from mat.algorithms.utils.util import check
from mat.algorithms.mat.algorithm.ma_transformer import MultiAgentTransformer


class TransformerPolicy:
    """
    MAT policy class. Wraps the actor and critic networks to compute actions and value function predictions.

    :param args: (argparse.Namespace) arguments containing relevant model and policy information.
    :param obs_space: (gym.Space) observation space.
    :param cent_obs_space: (gym.Space) value function input space (centralized input for MAPPO, decentralized for IPPO).
    :param act_space: (gym.Space) action space.
    :param num_agents: (int) number of agents.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, obs_space, cent_obs_space, act_space, num_agents, device=torch.device("cpu")):
        """
        Build the network selected by ``args.algorithm_name`` together with its optimizer(s).

        :param args: namespace providing the hyper-parameters read below (``lr``, ``opti_eps``,
                     ``weight_decay``, ``n_block``, ``n_embd``, ``n_head``, ``n_quants``, ...).
        :param obs_space: per-agent observation space; its first shape entry becomes ``obs_dim``.
        :param cent_obs_space: critic input space; its first shape entry becomes ``share_obs_dim``.
        :param act_space: action space; a class named ``Box`` is treated as continuous,
                          anything else as discrete.
        :param num_agents: (int) number of agents.
        :param device: (torch.device) device handed to the network and used for ``self.tpdv``.

        :raises NotImplementedError: if ``args.algorithm_name`` is not one of the handled names.
        """
        # --- plain hyper-parameters copied from the argument namespace ---
        self.device = device
        self.algorithm_name = args.algorithm_name
        self.lr = args.lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay
        self._use_policy_active_masks = args.use_policy_active_masks
        self.num_quants = args.n_quants
        self.n_embd = args.n_embd
        self.truelyDistributed = args.truelyDistributed

        # --- observation / action geometry ---
        continuous = act_space.__class__.__name__ == 'Box'
        self.action_type = 'Continuous' if continuous else 'Discrete'

        self.obs_dim = get_shape_from_obs_space(obs_space)[0]
        self.share_obs_dim = get_shape_from_obs_space(cent_obs_space)[0]
        if continuous:
            # one output per action dimension
            self.act_dim = act_space.shape[0]
            self.act_num = self.act_dim
        else:
            # a single categorical choice among ``n`` actions
            self.act_dim = act_space.n
            self.act_num = 1

        self.num_agents = num_agents
        self.tpdv = dict(dtype=torch.float32, device=device)

        # ``obs_dim_`` is the feature width used when reshaping ``obs`` in
        # get_actions / get_values / evaluate_actions; some variants override it.
        self.obs_dim_ = self.obs_dim

        # --- pick the network class (imported lazily, only for the chosen variant) ---
        algo = self.algorithm_name
        if algo in ("mat", "mat_dec"):
            from mat.algorithms.mat.algorithm.ma_transformer import MultiAgentTransformer as MAT
        elif algo == "mat_gru":
            from mat.algorithms.mat.algorithm.mat_gru import MultiAgentGRU as MAT
        elif algo == "mat_decoder":
            from mat.algorithms.mat.algorithm.mat_decoder import MultiAgentDecoder as MAT
        elif algo == "mat_encoder":
            from mat.algorithms.mat.algorithm.mat_encoder import MultiAgentEncoder as MAT
        elif algo in ("mappo_gnn", "mappo_dgnn", "mappo_dgnn_dsgd"):
            from mat.algorithms.mat.algorithm.ma_gnn_transformer_new import MultiAgentGnnTransformer as MAT
            self.obs_dim_ = self.obs_dim + args.n_embd
        elif algo == "generative_mat_gnn":
            from mat.algorithms.mat.algorithm.ma_gnn_transformer import MultiAgentGnnTransformer as MAT
            self.obs_dim_ = 2 * args.n_embd
        elif algo == "vlm_mat_gnn":
            from mat.algorithms.mat.algorithm.vlm_ma_gnn_transformer import VLM_MultiAgentGnnTransformer as MAT
            self.obs_dim_ = 256
        else:
            raise NotImplementedError

        self.transformer = MAT(
            args, self.share_obs_dim, self.obs_dim, self.act_dim, num_agents,
            n_block=args.n_block, n_embd=args.n_embd, n_head=args.n_head,
            encode_state=args.encode_state, device=device,
            action_type=self.action_type, dec_actor=args.dec_actor,
            share_actor=args.share_actor, num_quants=args.n_quants,
        )

        if args.env_name == "hands":
            self.transformer.zero_std()

        # --- optimizer(s) ---
        adam_settings = dict(lr=self.lr, eps=self.opti_eps, weight_decay=self.weight_decay)

        if not self.truelyDistributed:
            # one optimizer over every parameter of the network
            self.optimizer = torch.optim.Adam(self.transformer.parameters(), **adam_settings)
            return

        # Distributed setup: one Adam instance per agent. For the dgnn variants each
        # agent optimizes its own attention entry atts[k][i]; every other variant
        # hands atts[k][0] to all of the per-agent optimizers.
        own_attention = algo == "mappo_dgnn" or algo == "mappo_dgnn_dsgd"
        per_agent_optimizers = []
        for agent in range(self.num_agents):
            att_slot = agent if own_attention else 0
            agent_params = list(self.transformer.decoder.mlp_[agent].parameters())
            agent_params += list(self.transformer.encoder.head_[agent].parameters())
            agent_params += list(self.transformer.obs_encoder.agent_encoders[agent].parameters())
            agent_params += list(self.transformer.obs_encoder.node_classifier_heads[agent].parameters())
            agent_params += [
                self.transformer.obs_encoder.atts[k][att_slot]
                for k in range(self.transformer.obs_encoder.K)
            ]
            per_agent_optimizers.append(torch.optim.Adam(agent_params, **adam_settings))
        self.optimizers = per_agent_optimizers

    def agent_parameters(self, agent_idx):
        """
        Collect the parameters that belong to a single agent.

        The returned list contains, in this order: the agent's decoder MLP, its encoder
        head, its observation-encoder sub-module, its node-classifier head, and one
        attention entry ``atts[k][agent_idx]`` for each ``k`` in ``range(obs_encoder.K)``.

        :param agent_idx: (int) index of the agent.
        :return: (list) the collected parameters.
        """
        net = self.transformer
        return [
            # decoder MLP and encoder head
            *net.decoder.mlp_[agent_idx].parameters(),
            *net.encoder.head_[agent_idx].parameters(),
            # observation encoder components
            *net.obs_encoder.agent_encoders[agent_idx].parameters(),
            *net.obs_encoder.node_classifier_heads[agent_idx].parameters(),
            *(net.obs_encoder.atts[k][agent_idx] for k in range(net.obs_encoder.K)),
        ]

    def gnn_parameters(self, agent_idx):
        """
        Collect only the observation-encoder parameters of a single agent.

        The returned list contains, in this order: the agent's observation-encoder
        sub-module, its node-classifier head, and one attention entry
        ``atts[k][agent_idx]`` for each ``k`` in ``range(obs_encoder.K)``.

        :param agent_idx: (int) index of the agent.
        :return: (list) the collected parameters.
        """
        obs_enc = self.transformer.obs_encoder
        collected = list(obs_enc.agent_encoders[agent_idx].parameters())
        collected += list(obs_enc.node_classifier_heads[agent_idx].parameters())
        collected += [obs_enc.atts[k][agent_idx] for k in range(obs_enc.K)]
        return collected

    def lr_decay(self, episode, episodes):
        """
        Decay the learning rate of every optimizer owned by this policy.

        Delegates to ``update_linear_schedule`` with ``self.lr`` as the initial rate.
        ``__init__`` creates a single ``self.optimizer`` in the default setup but only
        the per-agent list ``self.optimizers`` when ``truelyDistributed`` is set, so the
        schedule is applied to whichever of the two exists.

        :param episode: (int) current training episode.
        :param episodes: (int) total number of training episodes.
        """
        if self.truelyDistributed:
            optimizers = self.optimizers
        else:
            optimizers = [self.optimizer]

        for optimizer in optimizers:
            update_linear_schedule(optimizer, episode, episodes, self.lr)

    def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,
                    batched_edge_index=None, deterministic=False):
        """
        Compute actions and value function predictions for the given inputs.

        Inputs are reshaped to ``(-1, num_agents, feature_dim)`` before being handed to
        the transformer's ``get_actions``.

        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) actor RNN states; not used by the network, only converted and returned.
        :param rnn_states_critic: (np.ndarray) critic RNN states; not used by the network, only converted and returned.
        :param masks: (np.ndarray) accepted for interface compatibility; not read here.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param batched_edge_index: optional extra input; when not None it is forwarded to the
                                   transformer as a fifth positional argument.
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.

        :return values: (torch.Tensor) value function predictions, viewed as (-1, num_quants).
        :return actions: (torch.Tensor) actions to take, viewed as (-1, act_num).
        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions, viewed as (-1, act_num).
        :return rnn_states_actor: (torch.Tensor) the given actor RNN states, converted with ``self.tpdv``.
        :return rnn_states_critic: (torch.Tensor) the given critic RNN states, converted with ``self.tpdv``.
        """
        n_agents = self.num_agents
        cent_obs = cent_obs.reshape(-1, n_agents, self.share_obs_dim)
        obs = obs.reshape(-1, n_agents, self.obs_dim_)
        if available_actions is not None:
            available_actions = available_actions.reshape(-1, n_agents, self.act_dim)

        # The edge index is appended only when supplied, so the transformer is called
        # with four positional arguments otherwise.
        net_inputs = [cent_obs, obs, available_actions, deterministic]
        if batched_edge_index is not None:
            net_inputs.append(batched_edge_index)
        actions, action_log_probs, values = self.transformer.get_actions(*net_inputs)

        actions = actions.view(-1, self.act_num)
        action_log_probs = action_log_probs.view(-1, self.act_num)
        values = values.view(-1, self.num_quants)

        # RNN states are unused; they are passed through just for compatibility.
        actor_states = check(rnn_states_actor).to(**self.tpdv)
        critic_states = check(rnn_states_critic).to(**self.tpdv)
        return values, actions, action_log_probs, actor_states, critic_states
        

    def get_values(self, cent_obs, obs, rnn_states_critic, masks, available_actions=None, action_hats=None):
        """
        Get value function predictions.

        Inputs are reshaped to ``(-1, num_agents, feature_dim)`` before being handed to
        the transformer's ``get_values``.

        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs.
        :param rnn_states_critic: (np.ndarray) accepted for interface compatibility; not read here.
        :param masks: (np.ndarray) accepted for interface compatibility; not read here.
        :param available_actions: (np.ndarray) available-action mask, or None.
        :param action_hats: forwarded unchanged to the transformer.

        :return values: (torch.Tensor) value function predictions, viewed as (-1, num_quants).
        """
        n_agents = self.num_agents
        critic_input = cent_obs.reshape(-1, n_agents, self.share_obs_dim)
        local_input = obs.reshape(-1, n_agents, self.obs_dim_)
        action_mask = (
            None if available_actions is None
            else available_actions.reshape(-1, n_agents, self.act_dim)
        )

        predictions = self.transformer.get_values(critic_input, local_input, action_mask, action_hats)
        return predictions.view(-1, self.num_quants)

    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, actions, masks,
                         available_actions=None, active_masks=None, agent_id=None, action_hats=None):
        """
        Get action logprobs / entropy and value function predictions for actor update.

        Inputs are reshaped to ``(-1, num_agents, feature_dim)`` and passed through the
        transformer's forward call. The entropy is then reduced over the first (batch)
        dimension, leaving one entry per agent and action slot.

        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) accepted for interface compatibility; not read here.
        :param rnn_states_critic: (np.ndarray) accepted for interface compatibility; not read here.
        :param actions: (np.ndarray) actions whose log probabilites and entropy to compute.
        :param masks: (np.ndarray) accepted for interface compatibility; not read here.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead; must
                             broadcast against the (-1, num_agents, act_num) entropy.
        :param agent_id: forwarded to the transformer as the ``agent_id`` keyword.
        :param action_hats: accepted for interface compatibility; not read here.

        :return values: (torch.Tensor) value predictions, viewed as (-1, num_agents, num_quants).
        :return action_log_probs: (torch.Tensor) log probabilities, viewed as (-1, num_agents, act_num).
        :return dist_entropy: (torch.Tensor) entropy reduced over dim 0 (masked average when
                              active masks are enabled and given, plain mean otherwise).
        """
        n_agents = self.num_agents
        cent_obs = cent_obs.reshape(-1, n_agents, self.share_obs_dim)
        obs = obs.reshape(-1, n_agents, self.obs_dim_)
        actions = actions.reshape(-1, n_agents, self.act_num)
        if available_actions is not None:
            available_actions = available_actions.reshape(-1, n_agents, self.act_dim)

        log_probs, value_preds, dist_entropy = self.transformer(
            cent_obs, obs, actions, available_actions, agent_id=agent_id
        )

        log_probs = log_probs.view(-1, n_agents, self.act_num)
        value_preds = value_preds.view(-1, n_agents, self.num_quants)
        dist_entropy = dist_entropy.view(-1, n_agents, self.act_num)

        masking_enabled = self._use_policy_active_masks and active_masks is not None
        if not masking_enabled:
            return value_preds, log_probs, dist_entropy.mean(dim=0)

        # Average only over entries flagged active.
        # NOTE(review): an agent whose mask sums to zero divides by zero here.
        masked_entropy = (dist_entropy * active_masks).sum(dim=0) / active_masks.sum(dim=0)
        return value_preds, log_probs, masked_entropy

    def act(self, cent_obs, obs, rnn_states_actor, masks, available_actions=None, deterministic=True, batched_edge_index=None):
        """
        Compute actions using the given inputs.

        This is a thin compatibility wrapper around ``get_actions`` that supplies
        all-zero critic RNN states and discards everything except the actions and the
        returned actor RNN states.

        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray or torch.Tensor) if actor is RNN, RNN states for actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
        :param batched_edge_index: optional extra input forwarded unchanged to ``get_actions``.

        :return actions: actions as returned by ``get_actions``.
        :return rnn_states_actor: actor RNN states as returned by ``get_actions``.
        """
        # get_actions also wants critic states; hand it zeros shaped like the actor
        # states. Both array flavours are accepted: torch.zeros_like rejects ndarrays.
        if isinstance(rnn_states_actor, torch.Tensor):
            rnn_states_critic = torch.zeros_like(rnn_states_actor)
        else:
            rnn_states_critic = np.zeros_like(rnn_states_actor)

        # The optional arguments MUST be passed by keyword: in get_actions'
        # signature ``batched_edge_index`` precedes ``deterministic``, so passing
        # ``deterministic`` positionally right after ``available_actions`` would bind
        # it to ``batched_edge_index`` and silently drop the flag.
        _, actions, _, rnn_states_actor, _ = self.get_actions(
            cent_obs,
            obs,
            rnn_states_actor,
            rnn_states_critic,
            masks,
            available_actions=available_actions,
            batched_edge_index=batched_edge_index,
            deterministic=deterministic,
        )

        return actions, rnn_states_actor

    def save(self, save_dir, episode):
        """
        Write the transformer's state dict to ``<save_dir>/transformer_<episode>.pt``.

        :param save_dir: directory to write into (converted with ``str``).
        :param episode: value embedded in the file name (converted with ``str``).
        """
        checkpoint_path = f"{save_dir!s}/transformer_{episode!s}.pt"
        weights = self.transformer.state_dict()
        torch.save(weights, checkpoint_path)

    def restore(self, model_dir):
        transformer_state_dict = torch.load(model_dir)
        self.transformer.load_state_dict(transformer_state_dict)
        # self.transformer.reset_std()

    def train(self):
        """Call ``train()`` on the wrapped transformer; returns nothing."""
        network = self.transformer
        network.train()

    def eval(self):
        """Call ``eval()`` on the wrapped transformer; returns nothing."""
        network = self.transformer
        network.eval()

