import torch
from code_ptmc_mappo.algorithms.utils.cnn import CNNBase
from code_ptmc_mappo.algorithms.utils.mlp import MLPBase
from code_ptmc_mappo.algorithms.utils.act import ACTLayer
from code_ptmc_mappo.utils.util import get_shape_from_obs_space
from code_ptmc_mappo.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor
import numpy as np

class Tacit_constr():
    """Divergence penalty between a pre-trained reference actor and a live policy.

    The main entry point is :meth:`tac_pretr_div`, which returns a scalar
    built from a KL term between the action distribution of a frozen actor
    loaded from disk (called "ippo" in this file) and the distribution of
    the policy being trained (called "mappo" in this file).
    """

    def __init__(self, args, obs_space, action_space):
        """Store the configuration and build the feature/action layers.

        Args:
            args: Config namespace; ``hidden_size``, ``gain`` and
                ``use_orthogonal`` are read here, and the whole object is
                forwarded to the base network, ``ACTLayer`` and ``R_Actor``.
            obs_space: Observation space descriptor. ``obs_mask`` indexes
                entries ``[1]`` and ``[2]`` as ``(h, w)`` pairs.
            action_space: Action space descriptor forwarded to ``ACTLayer``
                and ``R_Actor``.
        """
        self.hidden_size = args.hidden_size
        self._gain = args.gain
        self._use_orthogonal = args.use_orthogonal
        self.obs_space = obs_space
        self.act_space = action_space
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.args = args
        obs_shape = get_shape_from_obs_space(obs_space)
        # Rank-3 observations get the CNN base, everything else the MLP base.
        base = CNNBase if len(obs_shape) == 3 else MLPBase
        self.base = base(args, obs_shape)
        self.act = ACTLayer(action_space, self.hidden_size, self._use_orthogonal, self._gain, args)
        # Reference actors keyed by checkpoint path, so the checkpoint is
        # read from disk once rather than on every call of tac_pretr_div.
        self._ippo_actor_cache = {}

    def trim(self, rewards_batch, action_probs_1, action_probs_2, actions):
        """Return the 0/1 switch that gates the divergence term per sample.

        The switch is 1 where the reward is non-negative AND the two models
        assign noticeably different probability to the action that was
        taken (ratio outside ``[0.999, 1.001]``); it is 0 everywhere else.

        Args:
            rewards_batch: Rewards as a tensor or NumPy array.
            action_probs_1: Action probabilities from the reference actor.
            action_probs_2: Action probabilities from the live policy.
            actions: Indices of the taken actions (tensor or NumPy array of
                any numeric dtype); used as the ``gather`` index on dim 1.

        Returns:
            An int tensor with the same shape as ``rewards_batch``.
        """
        target_device = action_probs_1.device
        # as_tensor accepts both arrays and tensors; the previous
        # torch.from_numpy call raised on tensors that were not int64.
        rewards_batch = torch.as_tensor(rewards_batch, device=target_device)
        actions = torch.as_tensor(actions, dtype=torch.long, device=target_device)

        f_switch = (rewards_batch >= 0.0).int()
        chosen_action_probs_1 = action_probs_1.gather(1, actions)
        chosen_action_probs_2 = action_probs_2.gather(1, actions)
        action_ratio = chosen_action_probs_1 / chosen_action_probs_2
        # Samples where both models (almost) agree carry no constraint.
        ar_mask = (action_ratio >= 0.999) & (action_ratio <= 1.001)
        f_switch[ar_mask] = 0
        return f_switch

    def _load_ippo_actor(self, tacit_model_dir):
        """Return the reference actor stored in ``<tacit_model_dir>/actor.pt``."""
        ippo_model_path = str(tacit_model_dir) + '/actor.pt'
        ippo_actor = self._ippo_actor_cache.get(ippo_model_path)
        if ippo_actor is None:
            # map_location lets a checkpoint saved on GPU load on a CPU host.
            ippo_state_dict = torch.load(ippo_model_path, map_location=self.device)
            ippo_actor = R_Actor(self.args, self.obs_space, self.act_space, self.device)
            ippo_actor.load_state_dict(ippo_state_dict)
            self._ippo_actor_cache[ippo_model_path] = ippo_actor
        return ippo_actor

    def tac_pretr_div(self, policy, obs, available_actions, actions, rewards, tacit_model_dir):
        """Compute the tacit pre-trained divergence for one batch.

        Args:
            policy: Object exposing ``actor.base`` and ``actor.act.get_probs``.
            obs: NumPy observation batch. It is not modified.
            available_actions: NumPy availability mask. It is not modified.
            actions: Taken action indices, see :meth:`trim`.
            rewards: Rewards, see :meth:`trim`.
            tacit_model_dir: Directory that contains ``actor.pt``.

        Returns:
            Scalar tensor: the gated KL term averaged over the rows that are
            kept (0 when no row is kept).
        """
        ippo_actor = self._load_ippo_actor(tacit_model_dir)

        # Restrict both distributions to the first six actions for the KL.
        # NOTE(review): presumably these are the non-attack actions - confirm.
        mask_available_actions = available_actions.copy()
        mask_available_actions[:, 6:] = 0.0

        # ippo action_probs: the reference actor is frozen, so no graph is
        # needed; it only sees the masked observation.
        with torch.no_grad():
            mask_obs = torch.from_numpy(self.obs_mask(obs)).float().to(self.device)
            actor_features_1 = ippo_actor.base(mask_obs)
            action_probs_1 = ippo_actor.act.get_probs(actor_features_1, available_actions)
            action_probs_m1 = ippo_actor.act.get_probs(actor_features_1, mask_available_actions)

        # mappo action_probs: the live policy sees the full observation.
        obs = torch.from_numpy(obs).float().to(self.device)
        actor_features_2 = policy.actor.base(obs)
        action_probs_2 = policy.actor.act.get_probs(actor_features_2, available_actions)
        action_probs_m2 = policy.actor.act.get_probs(actor_features_2, mask_available_actions)

        f_switch = self.trim(rewards, action_probs_1, action_probs_2, actions)

        # KL(reference || policy) per action; eps keeps log() finite.
        eps = 1e-10
        kl_elementwise = action_probs_m1 * (
                torch.log(action_probs_m1 + eps) - torch.log(action_probs_m2 + eps)
        )
        mask_available_actions = torch.from_numpy(mask_available_actions).float().to(self.device)
        kl_elementwise_masked = kl_elementwise * mask_available_actions
        kl_sum = torch.sum(kl_elementwise_masked, dim=1, keepdim=True)
        # Normalise by the number of actions that remained available.
        valid_counts = mask_available_actions.sum(dim=1, keepdim=True) + eps
        kl_mean = kl_sum / valid_counts
        tac_constr = f_switch * kl_mean

        # Drop rows in which action 0 is available.
        # NOTE(review): presumably action 0 is the no-op of inactive agents.
        mask_tac_constr = torch.ones_like(tac_constr)
        zero_indices = (mask_available_actions[:, 0] == 1)
        mask_tac_constr[zero_indices] = 0
        # Clamp so that a batch with every row dropped yields 0, not NaN.
        kept = mask_tac_constr.sum().clamp(min=1.0)
        tac_constr_mean = (tac_constr * mask_tac_constr).sum() / kept
        return tac_constr_mean

    def obs_mask(self, obs):
        """Return a copy of ``obs`` with one feature block zeroed.

        The zeroed columns are ``[h1*w1, h1*w1 + h2*w2)``, where
        ``(h1, w1) = obs_space[1]`` and ``(h2, w2) = obs_space[2]``, i.e. the
        second block of a flattened observation.

        Args:
            obs: 2-D NumPy array or tensor of flattened observations.

        Returns:
            A masked copy; the input is left untouched so callers can keep
            using the full observation.
        """
        obs_shape = self.obs_space
        h1, w1 = obs_shape[1]
        h2, w2 = obs_shape[2]
        start = h1 * w1
        masked = obs.clone() if torch.is_tensor(obs) else np.array(obs, copy=True)
        masked[:, start:start + h2 * w2] = 0.0
        return masked
