import gym
import numpy as np
import torch
from onpolicy.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor, R_Critic
from onpolicy.utils.util import update_linear_schedule
from onpolicy.algorithms.utils.util import check

def make_input_obs_space(state_dim, action_dim):
    """
    Build an unbounded Box space for a flat vector of length ``state_dim + action_dim``.

    :param state_dim: (int) number of state/observation features.
    :param action_dim: (int) number of action features appended after the state features.
    :return: (gym.spaces.Box) space whose lower bound is all ``-inf`` and upper bound all ``+inf``.
    """
    total_dim = state_dim + action_dim
    upper = np.full(total_dim, np.inf)
    return gym.spaces.Box(low=-upper, high=upper)

def to_onehot(x, d):
    """
    One-hot encode the values of ``x`` into a new trailing dimension of size ``d``.

    Entry ``[..., k]`` of the result is 1.0 where ``x == k`` and 0.0 otherwise, so values outside
    ``[0, d)`` produce an all-zero row. ``x`` may be an integer or float tensor.

    :param x: (torch.Tensor) values to encode.
    :param d: (int) number of categories.
    :return: (torch.Tensor) float32 tensor of shape ``(*x.shape, d)`` on ``x``'s device.
    """
    categories = torch.arange(d, device=x.device)
    # Broadcast (*x.shape, 1) against (d,) -> (*x.shape, d) boolean mask.
    matches = x.unsqueeze(-1) == categories
    return matches.float()

class DiffRLPolicy:
    """
    Diffusion-based RL policy class. Wraps actor and critic networks to compute actions and value function predictions.

    A diffusion behaviour-cloning agent (weights loaded from ``args.diffusion_policy_dir``) samples
    "reference actions". The actor receives the local observation concatenated with an encoding of
    those reference actions (one-hot for Discrete / MultiDiscrete ``dummy_act_space``, raw otherwise).
    Returned actions are ``[reference_actions, actor_actions]`` concatenated along the last dimension.

    :param args: (argparse.Namespace) arguments containing relevant model and policy information.
    :param obs_space: (gym.Space) observation space. Not used; the actor input space is rebuilt from ``args``.
    :param cent_obs_space: (gym.Space) value function input space. Not used; the critic input space is rebuilt
                           from ``args.dummy_state_dim``.
    :param act_space: (gym.Space) action space. Not used; the actor action spaces are
                      ``[args.dummy_act_space, gym.spaces.Discrete(2)]``.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")):
        self.device = device
        self.lr = args.lr
        self.stack_obs = args.stack_obs
        self.max_action = args.max_action
        self.diffusion_timesteps = args.diffusion_timesteps
        self.predict_epsilon = bool(args.diffusion_predict_epsilon)
        self.tpdv = dict(dtype=torch.float32, device=device)
        # When True, zero log-probs (one per reference-action dimension) are prepended to the actor's
        # log-probs. Defaults to False if the flag is missing from ``args``.
        self.split_action_prob = getattr(args, "split_action_prob", False)

        self.critic_lr = args.critic_lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay

        self.dummy_act_space = args.dummy_act_space
        self.dummy_obs_dim = dummy_obs_dim = args.dummy_obs_dim
        self.dummy_state_dim = dummy_state_dim = args.dummy_state_dim
        self.dummy_action_dim = dummy_action_dim = args.dummy_act_dim
        self.dummy_action_input_dim = getattr(args, 'dummy_act_input_dim', dummy_action_dim)

        # Actor input = [local obs, encoded reference action]. Discrete reference actions are one-hot
        # encoded, so their width is multiplied by the number of categories.
        if isinstance(args.dummy_act_space, gym.spaces.Discrete):
            ref_action_width = self.dummy_action_input_dim * args.dummy_act_space.n
        elif isinstance(args.dummy_act_space, gym.spaces.MultiDiscrete):
            # assumes every sub-space has the same cardinality as the first one
            ref_action_width = self.dummy_action_input_dim * args.dummy_act_space.nvec[0]
        else:
            ref_action_width = self.dummy_action_input_dim
        self.obs_space = make_input_obs_space(dummy_obs_dim, ref_action_width)
        self.share_obs_space = make_input_obs_space(dummy_state_dim, 0)
        self.act_space = [args.dummy_act_space, gym.spaces.Discrete(2)]

        self.diffusion_policy_dir = args.diffusion_policy_dir

        self.actor = R_Actor(args, self.obs_space, self.act_space, self.device)
        self.critic = R_Critic(args, self.share_obs_space, self.device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.lr, eps=self.opti_eps,
                                                weight_decay=self.weight_decay)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr,
                                                 eps=self.opti_eps,
                                                 weight_decay=self.weight_decay)

        # Imported lazily so the diffusion package is only required when this policy is built.
        from diffusion_q_learning.bc import build_diffusion_bc_agent

        self.diffusion_agent = build_diffusion_bc_agent(dummy_state_dim,
                                                        dummy_action_dim,
                                                        max_action=self.max_action,
                                                        stack_obs=self.stack_obs,
                                                        device=self.device,
                                                        n_timesteps=self.diffusion_timesteps,
                                                        predict_epsilon=self.predict_epsilon,
                                                        layer_norm=args.diffusion_layer_norm,
                                                        t_dim=args.diffusion_t_dim,
                                                        beta_schedule=getattr(args, "diffusion_beta_schedule", "vp"),
                                                        model_type=getattr(args, "diffusion_model_type", "MLP"))
        self.diffusion_agent.load_model(self.diffusion_policy_dir)

    def lr_decay(self, episode, episodes):
        """
        Decay the actor and critic learning rates.
        :param episode: (int) current training episode.
        :param episodes: (int) total number of training episodes.
        """
        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
        update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)

    def generate_reference_actions(self, stack_obs):
        """
        Sample reference actions from the diffusion agent without tracking gradients.
        :param stack_obs: (np.ndarray / torch.Tensor) input passed to the diffusion actor's ``sample``.

        :return reference_actions: (torch.Tensor) sampled reference actions.
        :raises ValueError: if ``stack_obs`` is None.
        """
        if stack_obs is None:
            raise ValueError("stack_obs is required to sample reference actions from the diffusion agent")
        with torch.no_grad():
            reference_actions = self.diffusion_agent.actor.sample(check(stack_obs).to(**self.tpdv))
        return reference_actions

    def _encode_reference_actions(self, reference_actions):
        """Encode reference actions to match the actor's observation layout (one-hot for discrete spaces)."""
        if isinstance(self.dummy_act_space, gym.spaces.Discrete):
            n_categories = int(self.dummy_act_space.n)
        elif isinstance(self.dummy_act_space, gym.spaces.MultiDiscrete):
            # assumes every sub-space has the same cardinality as the first one (as in __init__)
            n_categories = int(self.dummy_act_space.nvec[0])
        else:
            return reference_actions
        # (..., k) -> (..., k, n) -> (..., k * n)
        return to_onehot(reference_actions, n_categories).reshape(*reference_actions.shape[:-1], -1)

    def _actor_input(self, obs, reference_actions):
        """Concatenate local observations with the encoded reference actions along the last dim."""
        return torch.cat([check(obs).to(**self.tpdv), self._encode_reference_actions(reference_actions)], dim=-1)

    @staticmethod
    def _prepend_zero_log_probs(action_log_probs, n_reference_dims):
        """Prepend ``n_reference_dims`` zero log-probs (for the reference-action entries) to ``action_log_probs``."""
        zeros = torch.zeros(*action_log_probs.shape[:-1], n_reference_dims).to(action_log_probs)
        return torch.cat([zeros, action_log_probs], dim=-1)

    def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,
                    deterministic=False, stack_obs=None, reference_actions=None):
        """
        Compute actions and value function predictions for the given inputs.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
        :param stack_obs: (np.ndarray) diffusion-agent input; used only when ``reference_actions`` is None.
        :param reference_actions: (np.ndarray / torch.Tensor) precomputed reference actions; sampled from the
                                  diffusion agent when None.

        :return values: (torch.Tensor) value function predictions.
        :return actions: (torch.Tensor) ``[reference_actions, actor_actions]`` concatenated on the last dim.
        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.
        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
        :return rnn_states_critic: (torch.Tensor) updated critic network RNN states.
        """
        if reference_actions is None:
            reference_actions = self.generate_reference_actions(stack_obs)
        elif not torch.is_tensor(reference_actions):
            # Non-tensor inputs (e.g. numpy) cannot be concatenated with torch tensors below.
            reference_actions = check(reference_actions).to(**self.tpdv)

        obs = self._actor_input(obs, reference_actions)

        actions, action_log_probs, rnn_states_actor = self.actor(obs,
                                                                 rnn_states_actor,
                                                                 masks,
                                                                 available_actions,
                                                                 deterministic)

        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)

        actions = torch.cat([reference_actions, actions], dim=-1)

        if self.split_action_prob:
            action_log_probs = self._prepend_zero_log_probs(action_log_probs, reference_actions.shape[-1])

        return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic

    def get_values(self, cent_obs, rnn_states_critic, masks):
        """
        Get value function predictions.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.

        :return values: (torch.Tensor) value function predictions.
        """
        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values

    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks, latent_actions=None,
                         available_actions=None, active_masks=None):
        """
        Get action logprobs / entropy and value function predictions for actor update.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param action: (np.ndarray) stored actions; the first ``dummy_action_input_dim`` entries of the last dim
                       are the reference actions, the rest are the actor's actions to evaluate.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param latent_actions: unused; kept for interface compatibility.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return values: (torch.Tensor) value function predictions.
        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        reference_actions = check(action[..., :self.dummy_action_input_dim]).to(**self.tpdv)
        action = action[..., self.dummy_action_input_dim:]

        obs = self._actor_input(obs, reference_actions)

        action_log_probs, dist_entropy = self.actor.evaluate_actions(obs,
                                                                     rnn_states_actor,
                                                                     action,
                                                                     masks,
                                                                     available_actions,
                                                                     active_masks,
                                                                     )

        if self.split_action_prob:
            action_log_probs = self._prepend_zero_log_probs(action_log_probs, reference_actions.shape[-1])

        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values, action_log_probs, dist_entropy

    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False, stack_obs=None):
        """
        Compute actions using the given inputs.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
        :param stack_obs: (np.ndarray) diffusion-agent input used to sample reference actions.

        :return actions: (torch.Tensor) ``[reference_actions, actor_actions]`` concatenated on the last dim.
        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
        """
        reference_actions = self.generate_reference_actions(stack_obs)
        # Encode exactly as in get_actions so the actor sees the input layout it was built for.
        obs = self._actor_input(obs, reference_actions)

        actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)

        actions = torch.cat([reference_actions, actions], dim=-1)

        return actions, rnn_states_actor
