import torch
import torch.nn as nn
import torch.nn.functional as F
from deguc.core.stats import global_stats

class HierarchicalRouter(nn.Module):
    """Two-level router: choose expert groups first, then experts inside them.

    A linear ``group_router`` scores every group for each token. For the
    selected groups, a per-group linear router (``intra_group_routers``,
    keyed by ``str(group_id)``) scores that group's experts.

    Row ``r`` of ``group_router`` always corresponds to ``self.group_ids[r]``,
    where ``group_ids`` is the sorted list of keys of ``group_expert_map``.

    Args:
        hidden_dim: Size of the last dimension of the tokens passed to
            ``forward``.
        group_expert_map: Mapping ``group_id -> sequence of expert ids``.
            The mapping is stored by reference, not copied.
        top_k: Maximum number of experts picked inside each selected group.
        group_top_g: Maximum number of groups picked per token.
        device: Optional device the module is moved to after construction.
    """

    def __init__(self, hidden_dim: int, group_expert_map, top_k: int = 2, group_top_g: int = 1, device=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.group_expert_map = group_expert_map
        self.top_k = top_k
        self.group_top_g = group_top_g
        self.device = device
        self.group_ids = sorted(group_expert_map.keys())
        self.num_groups = len(self.group_ids)
        self.group_router = nn.Linear(hidden_dim, self.num_groups)
        self.intra_group_routers = nn.ModuleDict(
            {str(g): nn.Linear(hidden_dim, len(experts)) for g, experts in group_expert_map.items()}
        )
        # Compare against None rather than truthiness so that a falsy but
        # valid device spec (e.g. device index 0) is still honoured.
        if device is not None:
            self.to(device)

    def update_group_map(self, new_map) -> None:
        """Replace the group -> experts mapping and rebuild the routers.

        Group-router rows are carried over *by group id*: a group present in
        both the old and the new map keeps its learned weight row and bias,
        even if its position in the sorted id list changes. Rows for new
        groups keep their fresh ``nn.Linear`` initialisation.

        All intra-group routers are re-created from scratch.

        NOTE: this creates new ``nn.Parameter`` objects; anything holding
        references to the old parameters has to be refreshed by the caller.

        Args:
            new_map: Mapping ``group_id -> sequence of expert ids``.
        """
        old_router = self.group_router
        old_row_of = {g: row for row, g in enumerate(self.group_ids)}
        # Keep the replacement modules on the same device/dtype as the
        # modules they replace.
        device = old_router.weight.device
        dtype = old_router.weight.dtype

        self.group_expert_map = new_map
        self.group_ids = sorted(new_map.keys())
        self.num_groups = len(self.group_ids)

        new_router = nn.Linear(self.hidden_dim, self.num_groups).to(device=device, dtype=dtype)
        with torch.no_grad():
            for row, g in enumerate(self.group_ids):
                old_row = old_row_of.get(g)
                if old_row is None:
                    continue  # brand-new group: keep the fresh initialisation
                new_router.weight[row].copy_(old_router.weight[old_row])
                new_router.bias[row].copy_(old_router.bias[old_row])
        self.group_router = new_router

        self.intra_group_routers = nn.ModuleDict(
            {str(g): nn.Linear(self.hidden_dim, len(experts)) for g, experts in new_map.items()}
        ).to(device=device, dtype=dtype)

    def _route_within_groups(self, hidden: torch.Tensor, selected_groups: torch.Tensor):
        """Pick experts inside every selected group, one batched pass per group.

        Returns a dict ``group_row -> {token_index: [(expert_id, prob), ...]}``
        with probabilities as Python floats. Groups without experts are
        omitted.
        """
        choices = {}
        for row in torch.unique(selected_groups).tolist():
            g_id = self.group_ids[row]
            experts = self.group_expert_map[g_id]
            if not experts:
                continue
            # Tokens that selected this group in any of their group slots.
            token_idx = (selected_groups == row).any(dim=-1).nonzero(as_tuple=True)[0]
            logits = self.intra_group_routers[str(g_id)](hidden[token_idx])
            probs = F.softmax(logits, dim=-1)
            k = min(self.top_k, probs.shape[-1])
            top_scores, top_idx = torch.topk(probs, k=k, dim=-1)
            choices[row] = {
                tok: [(experts[e], s) for s, e in zip(scores, idxs)]
                for tok, scores, idxs in zip(token_idx.tolist(), top_scores.tolist(), top_idx.tolist())
            }
        return choices

    def forward(self, hidden: torch.Tensor):
        """Route a batch of tokens to experts.

        Args:
            hidden: 2-D tensor; the first dimension indexes tokens and the
                second is fed to the linear routers.

        Returns:
            A pair ``(routing_info, balance_loss)``:

            * ``routing_info`` is a list with one ``(token_index, routes)``
              entry per token, where ``routes`` is a list of
              ``(expert_id, weight)`` tuples. ``weight`` is
              ``group_prob * expert_prob`` divided by the token's total (plus
              ``1e-9``). Weights are plain Python floats, so no gradient
              flows through them.
            * ``balance_loss`` is ``F.kl_div(log(mean group probs), uniform)``
              with ``reduction="batchmean"``; it is the only returned value
              attached to the autograd graph.

        Raises:
            ValueError: If ``hidden`` is not 2-D.

        Side effects:
            Calls ``global_stats.ensure_expert`` for every routed expert id.
        """
        if hidden.dim() != 2:
            raise ValueError(f"hidden must be 2-D (tokens, hidden_dim), got shape {tuple(hidden.shape)}")

        group_logits = self.group_router(hidden)
        if hidden.shape[0] == 0:
            # The mean over an empty batch would be NaN. The sum of an empty
            # tensor is a zero that is still attached to the autograd graph.
            return [], group_logits.sum()

        group_probs = F.softmax(group_logits, dim=-1)
        mean_probs = group_probs.mean(dim=0)
        uniform = torch.full_like(mean_probs, 1.0 / self.num_groups)
        # Clamp before log so an underflowed probability cannot yield -inf.
        tiny = torch.finfo(mean_probs.dtype).tiny
        balance_loss = F.kl_div(mean_probs.clamp_min(tiny).log(), uniform, reduction="batchmean")

        # Never ask for more groups than exist (mirrors the top_k clamp used
        # for experts inside a group).
        num_selected = min(self.group_top_g, self.num_groups)
        if num_selected == 1:
            top_group_scores, top_group_idx = torch.max(group_probs, dim=-1)
            selected_groups = top_group_idx.unsqueeze(-1)
            selected_group_scores = top_group_scores.unsqueeze(-1)
        else:
            selected_group_scores, selected_groups = torch.topk(group_probs, k=num_selected, dim=-1)

        expert_choices = self._route_within_groups(hidden, selected_groups)

        # One host transfer for all indices/scores instead of one per element.
        group_rows = selected_groups.tolist()
        score_rows = selected_group_scores.tolist()

        routing_info = []
        for token, (rows, scores) in enumerate(zip(group_rows, score_rows)):
            token_routes = []
            for row, group_score in zip(rows, scores):
                group_choices = expert_choices.get(row)
                if group_choices is None:
                    continue  # selected group has no experts
                token_routes.extend(
                    (expert_id, group_score * s) for expert_id, s in group_choices[token]
                )
            if token_routes:
                total = sum(w for _, w in token_routes)
                token_routes = [(e, w / (total + 1e-9)) for e, w in token_routes]
            routing_info.append((token, token_routes))
            for e, _ in token_routes:
                global_stats.ensure_expert(e)
        return routing_info, balance_loss