from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class SupConLoss(nn.Module):
    """Supervised contrastive loss between query and passage embeddings.

    Every query is scored against every passage with a temperature-scaled
    dot product. The scores are turned into log-probabilities over the
    passages with ``log_softmax``. The loss is the negative mean, over
    queries, of the ``qp_mat``-weighted average log-probability of each
    query's positive passages.
    """

    def __init__(self, temperature=0.1):
        super(SupConLoss, self).__init__()
        # Divisor applied to the similarity scores before the softmax.
        self.temperature = temperature

    def forward(self, query_emb, passage_emb, qp_mat):
        """Compute the supervised contrastive loss.

        Args:
            query_emb: query embeddings; expected 2-D ``(num_queries, dim)``
                because of the ``passage_emb.T`` product.
            passage_emb: passage embeddings; expected 2-D
                ``(num_passages, dim)``.
            qp_mat: ``(num_queries, num_passages)`` positive-pair weights
                (binary or soft). Non-zero entries mark the positives of
                each query.

        Returns:
            A scalar loss tensor. A query row with no positives contributes
            0 instead of producing NaN.
        """
        # Plain dot product: this is cosine similarity only if the caller
        # L2-normalises the embeddings.
        logits = (query_emb @ passage_emb.T) / self.temperature
        log_prob = F.log_softmax(logits, dim=1)

        # Weighted sum of log-probabilities of each query's positives.
        pos_log_prob = (qp_mat * log_prob).sum(dim=1)
        pos_weight = qp_mat.sum(dim=1)
        # A query without positives would give 0 / 0 = NaN and poison the
        # batch mean. Dividing by 1 instead makes that row contribute 0, as
        # in the reference SupContrast implementation.
        pos_weight = pos_weight.masked_fill(pos_weight == 0, 1)

        return -(pos_log_prob / pos_weight).mean()


class RouterLoss(nn.Module):
    """Auxiliary mixture-of-experts router losses.

    ``forward`` returns two weighted terms computed from router logits:

    * a load-balancing loss ``num_experts * sum_i(f_i * P_i)``. Here ``f_i``
      is the fraction of rows whose arg-max expert is ``i`` and ``P_i`` is
      the mean softmax probability assigned to expert ``i``. This penalises
      routing that concentrates on a few experts.
    * a router z-loss: the mean of the squared ``logsumexp`` of each row's
      logits, which discourages large router logits.

    These match the Switch Transformer balancing loss and the ST-MoE z-loss.
    """

    def __init__(self, load_balance_weight=0.01, z_loss_weight=0.001, num_experts=1):
        super().__init__()
        # Fixed (non-trainable) loss weights; forward() uses their absolute value.
        self.wBAL = torch.tensor(load_balance_weight, requires_grad=False)
        self.wZ = torch.tensor(z_loss_weight, requires_grad=False)
        self.num_experts = num_experts

    def forward(
        self, router_logits: Union[torch.Tensor, Sequence[torch.Tensor]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the load-balancing and z-loss terms.

        Args:
            router_logits: a tensor, or a tuple/list of tensors, of router
                logits. The total element count must be divisible by
                ``num_experts``; presumably the last axis is the expert axis
                (TODO confirm against the model). A sequence is concatenated
                along dim 0. A 3-D result is then averaged over dim 1, and a
                4-D result over dim 2.

        Returns:
            ``(load_balance_loss, z_loss)``. Both are 0-d tensors already
            multiplied by their respective weights.
        """
        num_experts = self.num_experts
        w_bal = torch.abs(self.wBAL)
        w_z = torch.abs(self.wZ)

        if isinstance(router_logits, (tuple, list)):
            router_logits = torch.cat(router_logits, dim=0)
            if router_logits.dim() == 3:
                router_logits = router_logits.mean(dim=1)
            elif router_logits.dim() == 4:
                router_logits = router_logits.mean(dim=2)
        # Flatten every leading axis; each row is one routing decision
        # (e.g. batch * num_tasks * num_hidden_layers rows).
        router_logits = router_logits.reshape(-1, num_experts)

        z_loss = torch.logsumexp(router_logits, dim=-1).square().mean()

        router_probs = F.softmax(router_logits, dim=-1)
        # The top-1 assignment is non-differentiable, so the balancing
        # gradient flows only through the mean probabilities `p`.
        top1 = torch.argmax(router_probs, dim=-1)
        rows_per_expert = F.one_hot(top1, num_experts).sum(dim=0).float()
        f = rows_per_expert / rows_per_expert.sum()
        p = router_probs.mean(dim=0)

        return w_bal * num_experts * torch.sum(p * f), w_z * z_loss


if __name__ == '__main__':
    # Smoke test: run a small MoE SiameseEncoder on random token ids and
    # print the two router losses.
    from model import SiameseEncoder

    # Shared by the model and the loss; RouterLoss reshapes logits with it,
    # so the two must agree.
    NUM_EXPERTS = 20
    # Upper bound for the random token ids; presumably the tokenizer vocab
    # size. TODO confirm.
    VOCAB_SIZE = 120067
    BATCH, SEQ_LEN = 2, 16

    model = SiameseEncoder(
        128,
        moe_type='mhmoe',
        num_experts=NUM_EXPERTS,
        topk=2,
        intermediate_size_expert=1152,
        num_expert_heads=2,
    )
    in_data = {
        'query_input_ids': torch.randint(0, VOCAB_SIZE, (BATCH, SEQ_LEN)),
        'query_attention_mask': torch.randint(0, 2, (BATCH, SEQ_LEN)),
        'passage_input_ids': torch.randint(0, VOCAB_SIZE, (BATCH, SEQ_LEN)),
        'passage_attention_mask': torch.randint(0, 2, (BATCH, SEQ_LEN)),
    }

    loss_fn = RouterLoss(num_experts=NUM_EXPERTS)

    query_emb, passage_emb, query_router_logits, passage_router_logits = model(in_data)
    # Average the query- and passage-side router logits element-wise.
    router_logits = tuple((a + b) / 2 for a, b in zip(query_router_logits, passage_router_logits))
    balance_loss, z_loss = loss_fn(router_logits)
    print(f"load-balance loss: {balance_loss.item():.6f}, z-loss: {z_loss.item():.6f}")