from modules.agents.n_rnn_agent import NRNNAgent
import torch as th
import numpy as np
import json
import os

class Tacit_constr:
    """Divergence constraint between a pretrained ("tacit") agent and a sampled batch.

    The constructor caches tensors from ``batch.data.transition_data`` and builds
    an ``NRNNAgent`` whose input width is read from a per-map obs-size JSON file.
    :meth:`tac_pretr_div` loads pretrained weights into that agent, re-runs it on
    the batch and returns a reward-gated KL-style divergence against the stored
    ``probs``.
    """

    # Only action slots [0, N_COMPARED_ACTIONS) take part in the divergence.
    # NOTE(review): presumably SMAC's no-op/stop/move actions (attack targets
    # excluded) -- confirm against the env's action layout.
    N_COMPARED_ACTIONS = 6

    # Observation-layout file per supported env.
    # NOTE(review): the leading ".../" looks like a redacted placeholder path -- confirm.
    _OBS_SIZE_FILES = {
        "sc2_v2": ".../envs/smac_v2/obs_size.json",
        "sc2": ".../envs/starcraft/obs_size.json",
    }

    def __init__(self, batch, args):
        """Cache batch tensors and build the (not yet loaded) agent model.

        Args:
            batch: episode batch; ``batch.data.transition_data`` must map the keys
                ``obs``, ``hidden_state``, ``actions``, ``probs``, ``avail_actions``,
                ``actions_onehot`` and ``reward`` to objects exposing ``.data``.
                ``batch.device`` is used later by :meth:`tac_pretr_div`.
            args: run config providing ``env``, ``env_args`` and
                ``tacit_pretr_path`` (plus ``tac_timesteps`` for :meth:`tac_pretr_div`).

        Raises:
            ValueError: if ``args.env`` is not ``"sc2"``/``"sc2_v2"``, or the
                obs-size JSON has no entry for the configured map.
        """
        self.batch = batch
        self.data = batch.data.transition_data
        self.args = args
        self.obs = self.data['obs'].data
        self.hidden_state = self.data['hidden_state'].data
        self.action_2 = self.data['actions'].data      # actions stored in the batch
        self.probs_2 = self.data['probs'].data         # agent outputs stored in the batch
        self.avail_action = self.data["avail_actions"].data
        self.actions_onehot = self.data["actions_onehot"].data
        self.reward = self.data['reward'].data
        self.config = self._load_obs_config()
        self.agent_model = NRNNAgent(self.config["input_shape"], self.args)
        self.tacit_model_path = args.tacit_pretr_path

    def _load_obs_config(self):
        """Return the obs-size JSON entry matching ``args.env`` / ``args.env_args``.

        ``sc2`` entries are matched on ``map_name`` only; ``sc2_v2`` entries must
        also match ``n_units`` and ``n_enemies`` from ``capability_config``.

        Raises:
            ValueError: for an unsupported ``args.env`` or when no entry matches.
        """
        env = self.args.env
        config_path = self._OBS_SIZE_FILES.get(env)
        if config_path is None:
            raise ValueError(
                f"Unsupported env {env!r} for Tacit_constr; expected 'sc2' or 'sc2_v2'.")
        env_args = self.args.env_args

        def matches(item):
            if item["map_name"] != env_args["map_name"]:
                return False
            if env == "sc2":
                return True
            capability = env_args['capability_config']
            return (item["n_units"] == capability['n_units']
                    and item["n_enemies"] == capability['n_enemies'])

        with open(config_path, encoding="utf-8") as f:
            entries = json.load(f)
        config = next((item for item in entries if matches(item)), None)
        if config is None:
            raise ValueError(
                f"No entry in {config_path} matches map {env_args.get('map_name')!r}.")
        return config

    def tac_pretr_div(self):
        """Return the mean reward-gated divergence of the pretrained agent from the batch.

        Steps:
          1. Load ``<tacit_pretr_path>/<tac_timesteps>/agent.th`` into the agent.
          2. Re-run the agent on the batch observations (``om_inl:om_fin`` features
             zeroed), feeding the stored hidden states.
          3. Renormalise both its outputs and the stored ``probs`` over the first
             ``N_COMPARED_ACTIONS`` available actions.
          4. Compute ``p_pre * (log p_pre - log p_batch)`` element-wise.
          5. Gate it with :meth:`trim` and average over all (sample, action) entries.

        Returns:
            0-dim tensor.
        """
        device = self.batch.device
        pretr_model_path = os.path.join(self.tacit_model_path, str(self.args.tac_timesteps))
        state_dict = th.load(os.path.join(pretr_model_path, "agent.th"),
                             map_location=lambda storage, loc: storage)
        self.agent_model.load_state_dict(state_dict)
        self.agent_model.to(device)

        agent_input = self.build_inputs(self.obs_mask(self.obs))
        _, _, n_agents, d_obs = agent_input.shape
        _, _, n_agents, d_hid = self.hidden_state.shape
        _, _, _, d_prob = self.probs_2.shape
        _, _, _, d_act = self.action_2.shape
        _, _, _, d_avact = self.avail_action.shape

        # Flatten (batch, time) so each timestep is one forward call row.
        agent_input_flat = agent_input.reshape(-1, n_agents, d_obs).to(device)
        hidden_flat = self.hidden_state.reshape(-1, n_agents, d_hid).to(device)
        # NOTE(review): autograd is left enabled here, so the result keeps a graph
        # through self.agent_model's parameters -- confirm th.no_grad() isn't intended.
        out_flat, _ = self.agent_model(agent_input_flat, hidden_flat)
        probs_1 = out_flat.reshape(-1, d_prob)
        probs_2 = self.probs_2.reshape(-1, d_prob)
        action_flat = self.action_2.reshape(-1, d_act)

        # Restrict the comparison to the first N_COMPARED_ACTIONS action slots.
        compared_avail = self.avail_action.clone()
        compared_avail[..., self.N_COMPARED_ACTIONS:] = 0
        mask_avact_flat = compared_avail.reshape(-1, d_avact)

        # The per-timestep reward is shared by all agents: repeat it along a new agent axis.
        reward_flat = self.reward.unsqueeze(2).repeat(1, 1, n_agents, 1).reshape(-1, 1)

        p_pre = self.masked_normalize_probs(probs_1, mask_avact_flat)
        p_batch = self.masked_normalize_probs(probs_2, mask_avact_flat)
        eps = 1e-10
        kl_elementwise = p_pre * (th.log(p_pre + eps) - th.log(p_batch + eps))
        # The gate looks at the raw (un-normalised) outputs for the taken actions.
        f_switch = self.trim(reward_flat, probs_1, probs_2, action_flat)
        return (f_switch * kl_elementwise).mean()

    def trim(self, reward, action_probs_1, action_probs_2, actions):
        """Return a 0/1 int gate marking which flattened samples are constrained.

        A sample is switched on when its reward is negative AND the ratio of the
        two models' values for the taken action lies outside ``[0.99, 1.01]``.

        Args:
            reward: (M, 1) rewards.
            action_probs_1: (M, A) values from the pretrained agent.
            action_probs_2: (M, A) values stored in the batch.
            actions: (M, 1) taken action indices; a tensor or array of any
                integer-convertible dtype.

        Returns:
            (M, 1) int32 tensor.
        """
        # th.as_tensor accepts tensors and arrays alike; gather needs long indices
        # on the same device as the source tensor.
        actions = th.as_tensor(actions, device=action_probs_1.device).long()
        chosen_1 = action_probs_1.gather(1, actions)
        chosen_2 = action_probs_2.gather(1, actions)
        action_ratio = th.abs((chosen_1 + 1e-7) / (chosen_2 + 1e-7))
        near_equal = (action_ratio >= 0.99) & (action_ratio <= 1.01)
        return ((reward < 0.0) & ~near_equal).int()

    def select_action(self, probs, avail_actions):
        """Return argmax indices over dim 2 after excluding unavailable actions.

        Entries where ``avail_actions == 0`` are set to ``-inf`` on a copy, so
        ``probs`` itself is not modified. Dim 2 is presumably the action axis
        (TODO confirm against callers).
        """
        masked_q_values = probs.clone()
        masked_q_values[avail_actions == 0] = -float("inf")  # should never be selected!
        return masked_q_values.max(dim=2)[1]

    def obs_mask(self, obs):
        """Return a float32 CPU copy of ``obs`` with features ``[om_inl:om_fin)`` zeroed.

        The slice bounds come from the obs-size config entry; ``obs`` is not modified.
        """
        om_inl = self.config["om_inl"]
        om_fin = self.config["om_fin"]
        # copy=True guarantees a fresh tensor even if obs is already float32 on CPU;
        # detach() allows inputs that require grad.
        masked = obs.detach().to(device="cpu", dtype=th.float32, copy=True)
        masked[..., om_inl:om_fin] = 0
        return masked

    def build_inputs(self, obs_1):
        """Assemble agent inputs: ``[obs, previous action one-hot, agent id one-hot]``.

        Args:
            obs_1: (bs, ep_len, n_agents, obs_dim) observations.

        Returns:
            (bs, ep_len, n_agents, obs_dim + action_dim + n_agents) tensor on
            ``obs_1``'s device and dtype. The previous action at t=0 is all zeros.
        """
        bs, ep_len, n_agents, _ = obs_1.shape
        onehot = self.actions_onehot.to(device=obs_1.device, dtype=obs_1.dtype)
        # Shift actions one step forward in time: slot t holds the action taken at t-1.
        last_actions = th.cat([th.zeros_like(onehot[:, :1]), onehot[:, :-1]], dim=1)
        agent_ids = th.eye(n_agents, device=obs_1.device, dtype=obs_1.dtype)
        agent_ids = agent_ids.expand(bs, ep_len, n_agents, n_agents)
        return th.cat([obs_1, last_actions, agent_ids], dim=-1)

    def masked_normalize_probs(self, probs, mask_avail_actions):
        """Min-shift and renormalise ``probs`` over the available actions.

        Per row: the smallest available value is subtracted (so that action gets
        weight 0), unavailable entries are zeroed, and the row is divided by its
        sum (+1e-8). Rows with no available action come out all zeros.
        """
        masked = probs * mask_avail_actions
        # Push unavailable entries to +1e10 so min() only sees available ones.
        min_vals = (masked + (1 - mask_avail_actions) * 1e10).min(dim=-1, keepdim=True)[0]
        shifted = (masked - min_vals) * mask_avail_actions
        normalized = shifted / (shifted.sum(dim=-1, keepdim=True) + 1e-8)
        return normalized * mask_avail_actions
