from modules.agents import REGISTRY as agent_REGISTRY
from components.action_selectors import REGISTRY as action_REGISTRY
import torch as th
import torch.nn as nn

from .basic_controller import BasicMAC


def get_exp_neighbors(bs, n_agents, topk):
    """Build a static ring neighbour table with exponentially spaced hops.

    Agent ``i`` is paired with agents ``(i + o) % n_agents`` for the offsets
    ``o`` in ``[0, 1, 2, 4, ..., 2**(topk - 2)]``. That is the agent itself,
    followed by ``topk - 1`` successors on the ring at power-of-two distances.
    Once ``2**k`` wraps past ``n_agents``, entries may repeat.

    Args:
        bs: batch size; the batch-independent table is repeated ``bs`` times.
        n_agents: number of agents on the ring (must be >= 1).
        topk: number of neighbour slots per agent, including the agent
            itself (must be >= 1).

    Returns:
        A contiguous ``torch.long`` tensor of shape ``(bs, n_agents, topk)``
        on the CPU.

    Raises:
        ValueError: if ``n_agents`` or ``topk`` is smaller than 1.
    """
    if n_agents < 1:
        raise ValueError(f"n_agents must be >= 1, got {n_agents}")
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    # Reduce every power of two modulo n_agents with exact integer arithmetic.
    # Float offsets would lose precision for large hops, and raw int64 powers
    # would overflow for long tables.
    offsets = th.tensor(
        [0] + [pow(2, k, n_agents) for k in range(topk - 1)], dtype=th.long
    )

    agent_ind = th.arange(n_agents, dtype=th.long)
    table = (agent_ind[:, None] + offsets[None, :]) % n_agents  # (n_agents, topk)

    # Materialise the batch copies so callers get an ordinary tensor rather
    # than a stride-0 broadcast view.
    return table[None, :, :].expand(bs, -1, -1).contiguous()


class DTAPSMAC(BasicMAC):
    """Multi-agent controller that samples a learned communication graph.

    Each forward step:
      1. encodes the observations with ``self.agent.get_encoding``;
      2. scores the candidate edges of a static exponential-hop neighbour
         table (``get_exp_neighbors``) with ``PairwiseScorer``;
      3. samples a binary adjacency with a straight-through Gumbel-sigmoid,
         whose temperature ``tau`` is annealed over training forward calls;
      4. passes the sampled edges to ``self.agent.communicate_and_act``,
         which produces the action outputs.

    NOTE(review): ``pairwise_scorer`` is a module separate from
    ``self.agent``, and this class does not override ``parameters`` /
    ``load_state`` / ``cuda``. Confirm that ``BasicMAC`` (or the learner)
    includes the scorer in optimisation and target-network syncing.
    """

    def __init__(self, scheme, groups, args):
        super().__init__(scheme, groups, args)
        self.one_peer = False  # not read anywhere in this class

        # Temperature schedule for the Gumbel-sigmoid edge sampler:
        #   "exp"    -> tau *= tau_decay every tau_decay_steps training calls
        #   "linear" -> tau moves linearly from tau_init to tau_min over
        #               tau_total_steps training calls
        #   any other value -> tau stays at tau_init
        # Both annealing modes clamp tau at tau_min.
        self.tau_mode = getattr(args, "tau_mode", "linear").lower()
        self.tau_init = float(getattr(args, "tau_init", 1.0))
        self.tau_min = float(getattr(args, "tau_min", 0.3))

        # for exp annealing
        self.tau_decay = float(getattr(args, "tau_decay", 0.97))
        self.tau_decay_steps = int(getattr(args, "tau_decay_steps", 5000))

        # for linear annealing
        self.tau_total_steps = int(getattr(args, "tau_total_steps", 100000))

        # Value returned by _next_tau(training=False).
        self.eval_tau = float(getattr(args, "eval_tau", self.tau_min))
        self._tau = self.tau_init
        self._tau_step = 0
        self._tiny = 1e-6  # lower bound that keeps tau strictly positive
        self.pairwise_scorer = PairwiseScorer(
            d_hid=args.hidden_dim
        )

        if args.use_cuda and th.cuda.is_available():
            self.pairwise_scorer = self.pairwise_scorer.cuda()

    def init_hidden(self, batch_size):
        """Reset the recurrent state to the agent's initial ``(h, msg)`` pair.

        Both tensors are broadcast with ``expand`` (no copy) to
        ``(batch_size, n_agents, -1)`` and stored in ``self.hidden_states``.
        """
        h, msg = self.agent.init_hidden()
        h = h.unsqueeze(0).expand(batch_size, self.n_agents, -1)  # bav
        msg = msg.unsqueeze(0).expand(batch_size, self.n_agents, -1)
        self.hidden_states = (h, msg)

    def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
        """Choose actions at episode step ``t_ep`` for the batch entries in ``bs``.

        The forward pass runs on the whole batch, and the result is sliced
        with ``bs`` before it reaches the action selector.
        """
        # Only select actions for the selected batch elements in bs
        avail_actions = ep_batch["avail_actions"][:, t_ep]
        agent_outputs, _ = self.forward(ep_batch, t_ep, test_mode=test_mode)
        chosen_actions = self.action_selector.select_action(
            agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode
        )
        return chosen_actions

    def save_models(self, path):
        """Save the agent and edge-scorer weights to ``path``."""
        th.save(self.agent.state_dict(), f"{path}/agent.th")
        th.save(self.pairwise_scorer.state_dict(), f"{path}/scorer.th")

    def load_models(self, path):
        """Load agent and edge-scorer weights saved by ``save_models``.

        The tensors are mapped to CPU storage on load.
        """
        self.agent.load_state_dict(
            th.load(f"{path}/agent.th", map_location=lambda storage, loc: storage)
        )
        self.pairwise_scorer.load_state_dict(
            th.load(f"{path}/scorer.th", map_location=lambda storage, loc: storage)
        )

    def _next_tau(self, training: bool):
        """Return the Gumbel-sigmoid temperature to use for this call.

        With ``training=True``, each call advances the schedule by one step
        before the temperature is returned. The schedule therefore counts
        training forward calls, not environment steps. With
        ``training=False``, ``eval_tau`` is returned and the schedule is left
        untouched. The result is never below ``self._tiny``.
        """
        if not training:
            return max(self._tiny, self.eval_tau)

        self._tau_step += 1
        if self.tau_mode == "exp":
            if self._tau_step % max(1, self.tau_decay_steps) == 0:
                self._tau = max(self.tau_min, self._tau * self.tau_decay)
        elif self.tau_mode == "linear":
            frac = min(1.0, self._tau_step / max(1, self.tau_total_steps))
            self._tau = max(self.tau_min, self.tau_init - (self.tau_init - self.tau_min) * frac)
        return max(self._tiny, self._tau)

    def forward(self, ep_batch, t, test_mode=False):
        """Run one communication + action step for time ``t``.

        Returns:
            ``(agent_outs, states_predicted)``, each viewed as
            ``(batch_size, n_agents, -1)``. The agent outputs are softmaxed
            when ``agent_output_type == "pi_logits"``.
        """
        agent_inputs, topk_indices = self._build_inputs(ep_batch, t)
        avail_actions = ep_batch["avail_actions"][:, t]

        h, temp_hidden_states = self.agent.get_encoding(agent_inputs, self.hidden_states)
        # Reuse the neighbour table built by _build_inputs instead of
        # rebuilding (and re-uploading) it for the adjacency logits.
        P = self._compute_adjacency_logits(h, topk_indices)
        sampled_adj = self._gumbel_sigmoid(P, training=not test_mode)

        agent_outs, self.hidden_states, msgs = self.agent.communicate_and_act(
            h, temp_hidden_states, topk_indices, sampled_adj
        )

        states_predicted = self.agent.aux_forward(msgs)

        # Softmax the agent outputs if they're policy logits
        if self.agent_output_type == "pi_logits":

            if getattr(self.args, "mask_before_softmax", True):
                # Make the logits for unavailable actions very negative to minimise their effect on the softmax
                reshaped_avail_actions = avail_actions.reshape(
                    ep_batch.batch_size * self.n_agents, -1
                )
                agent_outs[reshaped_avail_actions == 0] = -1e10
            agent_outs = th.nn.functional.softmax(agent_outs, dim=-1)

        # CAUTION: API changed -- returns a pair, not just the agent outputs.
        return agent_outs.view(
            ep_batch.batch_size, self.n_agents, -1
        ), states_predicted.view(ep_batch.batch_size, self.n_agents, -1)

    def _build_inputs(self, batch, t):
        """Assemble per-agent inputs and the static neighbour table for step ``t``.

        Returns:
            ``(inputs, static_topk_indices)``. ``inputs`` is
            ``(bs * n_agents, input_dim)``: observations, optionally followed
            by the previous one-hot action and a one-hot agent id.
            ``static_topk_indices`` is the ``(bs, n_agents, topk)`` long
            table from ``get_exp_neighbors``, on ``batch.device``.
        """
        # Assumes homogenous agents with flat observations.
        # Other MACs might want to e.g. delegate building inputs to each agent
        bs = batch.batch_size
        inputs = []

        inputs.append(batch["obs"][:, t])  # b1av
        if self.args.obs_last_action:
            if t == 0:
                inputs.append(th.zeros_like(batch["actions_onehot"][:, t]))
            else:
                inputs.append(batch["actions_onehot"][:, t - 1])
        if self.args.obs_agent_id:
            inputs.append(
                th.eye(self.n_agents, device=batch.device)
                .unsqueeze(0)
                .expand(bs, -1, -1)
            )

        # (bs * n_agents, input_dim)
        inputs = th.cat([x.reshape(bs * self.n_agents, -1) for x in inputs], dim=1)

        static_topk_indices = get_exp_neighbors(
            bs=bs, n_agents=self.n_agents, topk=self.args.topk_neighbors
        )
        static_topk_indices = static_topk_indices.to(device=batch.device)

        return inputs, static_topk_indices

    def _compute_adjacency_logits(self, hidden_states, static_indices=None):
        """Score every candidate edge of the neighbour table.

        Args:
            hidden_states: encodings with ``bs * n_agents`` leading rows,
                reshaped here to ``(bs, n_agents, -1)`` and L2-normalised
                along the feature axis.
            static_indices: optional neighbour table, ``(bs, n_agents, topk)``
                or ``(n_agents, topk)``. When omitted it is rebuilt with
                ``get_exp_neighbors``.

        Returns:
            Edge logits of shape ``(bs, n_agents, topk)``.
        """
        bs = hidden_states.shape[0] // self.n_agents
        h = hidden_states.reshape(bs, self.n_agents, -1)  # (bs, n_agents, d)
        h = th.nn.functional.normalize(h, p=2, dim=-1)

        if static_indices is None:
            static_indices = get_exp_neighbors(bs, self.n_agents, self.args.topk_neighbors)
        # A 2-D table is broadcast over the batch by PairwiseScorer itself.
        static_indices = static_indices.to(h.device)

        return self.pairwise_scorer(h, static_indices)

    def _gumbel_sigmoid(self, logits, training=True):
        """Sample a binary adjacency from edge logits.

        Training: straight-through Gumbel-sigmoid. The forward value is the
        hard 0/1 sample, while gradients flow through the relaxed sigmoid.
        Each call advances the tau schedule via ``_next_tau``.
        Evaluation: deterministic threshold ``logits > 0`` (no noise, no
        gradient).
        """
        if training:
            # The difference of two Gumbel(0, 1) samples is Logistic(0, 1).
            g1 = -th.empty_like(logits).exponential_().log()
            g2 = -th.empty_like(logits).exponential_().log()
            noise = g1 - g2

            tau = self._next_tau(training=True)
            y_soft = th.sigmoid((logits + noise) / tau)

            y_hard = (y_soft > 0.5).float()
            return (y_hard - y_soft).detach() + y_soft
        else:
            return (logits > 0).float()

class PairwiseScorer(nn.Module):
    """Score candidate edges ``i -> j`` from pairs of agent embeddings.

    For every agent ``i`` and each of its ``K`` candidate neighbours ``j``
    (taken from an index table), an MLP scores ``[h_i, h_j, e_k]``, where
    ``e_k`` is a learned embedding of neighbour slot ``k``. Its raw output is
    squashed to ``logit_scale * tanh(x)``.

    Args:
        d_hid: feature size ``d`` of each agent embedding.
        hidden_mlp: width of the MLP hidden layer.
        edge_embed_dim: size of the per-slot edge embedding.
        num_edge_types: number of neighbour slots the embedding table holds;
            ``K`` in ``forward`` must not exceed it.
        logit_scale: bound on the magnitude of the returned logits.
    """

    def __init__(self, d_hid, hidden_mlp=128, edge_embed_dim=16,
                 num_edge_types=10, logit_scale=5.0):
        super().__init__()
        self.edge_embed_dim = edge_embed_dim
        self.num_edge_types = num_edge_types
        self.logit_scale = logit_scale
        # One learned embedding per neighbour slot (column k of the index table).
        self.edge_emb = nn.Embedding(num_embeddings=num_edge_types, embedding_dim=edge_embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(2 * d_hid + edge_embed_dim, hidden_mlp),
            nn.ReLU(),
            nn.Linear(hidden_mlp, 1),
        )

        # Start the output layer near zero so initial logits are close to 0.
        # (nn.init functions already run without autograd tracking.)
        last = self.mlp[-1]  # Linear(hidden_mlp, 1)
        nn.init.uniform_(last.weight, -0.01, 0.01)
        nn.init.constant_(last.bias, 0.0)

    def forward(self, h, skel_idx):
        """Return edge logits of shape ``(B, N, K)``.

        Args:
            h: agent embeddings, shape ``(B, N, d)``.
            skel_idx: integer neighbour table of shape ``(B, N, K)``, or
                ``(N, K)`` to be broadcast over the batch. Entry ``[b, i, k]``
                is the index (in ``[0, N)``) of agent ``i``'s k-th candidate.

        Raises:
            ValueError: if ``K`` exceeds ``num_edge_types``.
        """
        B, N, d = h.shape

        if skel_idx.dim() == 2:
            skel_idx = skel_idx.unsqueeze(0).expand(B, -1, -1)  # (B, N, K)

        K = skel_idx.shape[2]
        if K > self.num_edge_types:
            raise ValueError(
                f"{K} neighbour slots requested but only "
                f"num_edge_types={self.num_edge_types} edge embeddings exist"
            )

        hi = h.unsqueeze(2).expand(-1, -1, K, -1)  # (B, N, K, d)

        batch_idx = th.arange(B, device=h.device).view(B, 1, 1).expand(B, N, K)
        hj = h[batch_idx, skel_idx]  # (B, N, K, d)

        edge_type_indices = th.arange(K, device=h.device)  # slot ids 0..K-1
        edge_embeddings = self.edge_emb(edge_type_indices)  # (K, edge_embed_dim)
        edge_embeddings = edge_embeddings.view(1, 1, K, self.edge_embed_dim).expand(B, N, -1, -1)

        pair = th.cat([hi, hj, edge_embeddings], dim=-1)  # (B, N, K, 2*d + edge_embed_dim)

        logits = self.mlp(pair).squeeze(-1)  # (B, N, K)
        return self.logit_scale * th.tanh(logits)
