import sys
import time
from debug import debug_print, get_size
import torch
from onpolicy.algorithms.diffusion_ac.diffusion_actor_critic import Diffusion_R_Actor, Diffusion_R_Critic
from onpolicy.utils.util import update_linear_schedule


class Diffuison_R_MAPPOPolicy:
    """
    MAPPO policy class for a diffusion-based recurrent actor/critic.

    Wraps a ``Diffusion_R_Actor`` and a ``Diffusion_R_Critic`` and owns their
    optimizers. It exposes helpers to sample actions, query the critic and
    evaluate stored actions for the PPO update. The class name keeps its
    historical spelling ("Diffuison") because callers import it under that name.

    :param args: (argparse.Namespace) arguments containing relevant model and policy
                 information; reads ``lr``, ``critic_lr``, ``opti_eps``,
                 ``weight_decay`` and ``clone_weight_decay`` here, and is also passed
                 to the actor/critic constructors.
    :param obs_space: (gym.Space) observation space for the actor.
    :param cent_obs_space: (gym.Space) value function input space (centralized input
                           for MAPPO, decentralized for IPPO).
    :param act_space: (gym.Space) action space, given to both actor and critic.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")):
        self.device = device
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay

        self.obs_space = obs_space
        self.share_obs_space = cent_obs_space
        self.act_space = act_space
        self.clone_weight_decay = args.clone_weight_decay

        self.actor = Diffusion_R_Actor(args, self.obs_space, self.act_space, self.device)
        # The critic is also given the action space (it consumes action inputs, see critic_forward).
        self.critic = Diffusion_R_Critic(args, self.share_obs_space, self.act_space, self.device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.lr, eps=self.opti_eps,
                                                weight_decay=self.weight_decay)
        # A second Adam over the *same* actor parameters, differing only in weight decay.
        # Separate Adam instances keep separate moment estimates. Presumably this is
        # used for a cloning-style update step; confirm against the trainer.
        self.clone_actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                      lr=self.lr, eps=self.opti_eps,
                                                      weight_decay=self.clone_weight_decay)
        self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(),
                                                  lr=self.critic_lr,
                                                  eps=self.opti_eps,
                                                  weight_decay=self.weight_decay)

    def lr_decay(self, episode, episodes):
        """
        Linearly decay the actor and critic learning rates.

        :param episode: (int) current training episode.
        :param episodes: (int) total number of training episodes.
        """
        # NOTE(review): clone_actor_optimizer is created with the same initial lr but is
        # NOT decayed here, so its lr stays at self.lr. Confirm that this is intended.
        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
        update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)

    def get_actions(self, obs, rnn_states_actor, masks,
                    available_actions=None,
                    deterministic=False, return_noise=False):
        """
        Sample actions from the diffusion actor.

        :param obs: (np.ndarray) local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available).
        :param deterministic: (bool) forwarded to the actor; whether the action should be
                              the mode of the distribution or sampled.
        :param return_noise: (bool) if True, also return the actor's ``noise`` output.

        :return: tuple ``(action_seqs, sampled_actions, actions, action_log_probs,
                 action_log_probs_last, rnn_states_actor)``; when ``return_noise`` is
                 True, ``noise`` is appended as a 7th element. Every element is passed
                 through unchanged from ``self.actor(...)``.
        """
        # The actor is expected to return exactly 7 values; the unpack enforces that.
        (action_seqs, sampled_actions, actions, action_log_probs,
         action_log_probs_last, rnn_states_actor, noise) = self.actor(obs,
                                                                      rnn_states_actor,
                                                                      masks,
                                                                      available_actions,
                                                                      deterministic)
        outputs = (action_seqs, sampled_actions, actions, action_log_probs,
                   action_log_probs_last, rnn_states_actor)
        if return_noise:
            return outputs + (noise,)
        return outputs

    def critic_forward(self, cent_obs, rnn_states_critic, action_seqs, masks):
        """
        Run the critic forward and return both its value output and its new RNN states.

        :param cent_obs: (np.ndarray) centralized input to the critic.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param action_seqs: action input for the critic (e.g. the ``action_seqs`` output
                            of ``get_actions``; TODO confirm).
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.

        :return values: value function predictions from the critic.
        :return rnn_states_critic: updated critic RNN states.
        """
        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, action_seqs, masks)
        return values, rnn_states_critic

    def get_values(self, cent_obs, action_seqs, rnn_states_critic, masks):
        """
        Get value function predictions, discarding the critic's RNN state output.

        Note that this method takes ``action_seqs`` *before* ``rnn_states_critic``, unlike
        ``critic_forward`` and the underlying critic call.

        :param cent_obs: (np.ndarray) centralized input to the critic.
        :param action_seqs: action input for the critic.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.

        :return values: value function predictions.
        """
        values, _ = self.critic(cent_obs, rnn_states_critic, action_seqs, masks)
        return values

    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, sampled_actions, action_ts, agent_idx, masks,
                         available_actions=None, active_masks=None, joint_ppo=False, noises=None, r_active_masks_batch=None, latent_actions=None):
        """
        Evaluate stored actions for the actor update and compute critic values.

        Actor-side arguments are forwarded to ``self.actor.evaluate_actions``, and the
        critic is queried through ``self.critic.evaluate_states``.

        :param cent_obs: (np.ndarray) centralized input to the critic.
        :param obs: (np.ndarray) local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param sampled_actions: actions whose log probabilities / entropy are evaluated.
        :param action_ts: forwarded to the actor AND used as the critic's action input.
        :param agent_idx: agent index, forwarded to the actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available).
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.
        :param joint_ppo: (bool) forwarded to the actor.
        :param noises: forwarded to the actor (presumably the noise obtained from
                       ``get_actions(..., return_noise=True)``; TODO confirm).
        :param r_active_masks_batch: forwarded to the actor.
        :param latent_actions: forwarded to the actor.

        :return values: value predictions from the critic (its RNN state output is dropped).
        :return action_log_probs: log probabilities returned by the actor.
        :return latent_probs: second output of ``actor.evaluate_actions``.
        :return dist_entropy: entropy returned by the actor.
        """
        action_log_probs, latent_probs, dist_entropy = self.actor.evaluate_actions(obs,
                                                                                   rnn_states_actor,
                                                                                   sampled_actions,
                                                                                   action_ts,
                                                                                   agent_idx,
                                                                                   masks,
                                                                                   available_actions,
                                                                                   active_masks,
                                                                                   joint_ppo=joint_ppo,
                                                                                   noises=noises,
                                                                                   r_active_masks_batch=r_active_masks_batch,
                                                                                   latent_actions=latent_actions)
        # NOTE(review): the critic receives action_ts (not sampled_actions) as its action
        # input. Confirm that this is intended.
        values, _ = self.critic.evaluate_states(cent_obs, rnn_states_critic, action_ts, masks)
        return values, action_log_probs, latent_probs, dist_entropy

    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):
        """
        Compute actions only (used for rollouts/evaluation).

        :param obs: (np.ndarray) local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available).
        :param deterministic: (bool) forwarded to the actor; whether the action should be
                              the mode of the distribution or sampled.

        :return actions: the third output of ``self.actor(...)``.
        :return rnn_states_actor: updated actor RNN states (sixth output of the actor).
        """
        _, _, actions, _, _, rnn_states_actor, _ = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)
        return actions, rnn_states_actor
