# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from torch.distributions.normal import Normal
import torch.nn.functional as F

def _one_hot_with_dtype(data, num_classes, dtype, hot_value=1):
    result = torch.zeros([data.size(0), num_classes], device=data.device, dtype=dtype)
    result.scatter_(1, data.unsqueeze(-1), hot_value)
    return result

def _one_hot_with_dtype_and_padding(data, num_classes, dtype, hot_value=1):
    result = torch.zeros([data.size(0), num_classes + 1], device=data.device, dtype=dtype)
    result.scatter_(1, data.unsqueeze(-1), hot_value)
    return result[:, :-1]

def empty_loss(scores_w_noise, top_ids):
    """Return a constant zero auxiliary loss.

    Both arguments are accepted only so the signature matches the other loss
    functions. They are ignored.
    """
    no_loss = 0.0
    return no_loss

def sparse_loss(scores_w_noise, top_ids, num_global_experts):
    """Mean over rows of the per-row score total, scaled by ``1 / num_global_experts``.

    Args:
        scores_w_noise: tensor of gate scores, reduced over dim 1 (presumably
            ``(num_samples, num_experts)``; confirm with caller).
        top_ids: unused; kept for signature compatibility with the other losses.
        num_global_experts: divisor applied to each per-row total.

    Returns:
        Scalar tensor ``mean_i(sum_j scores_w_noise[i, j]) / num_global_experts``.
    """
    # torch's ``/`` is true division, so no float coercion of the divisor is needed.
    per_row = torch.sum(scores_w_noise, dim=1) / num_global_experts
    return torch.mean(per_row)

def diverse_and_simple_gate_loss(scores_w_noise, top_ids, gates, expert_mask):
    """Penalize mutually similar gate columns plus large gate column norms.

    Args:
        scores_w_noise: unused; kept for signature compatibility with the other losses.
        top_ids: unused; kept for signature compatibility with the other losses.
        gates: 2-D tensor whose columns are compared with each other (presumably
            one column per expert; confirm with caller).
        expert_mask: 1-D tensor with one entry per column of ``gates``. The
            similarity term for column pair ``(i, j)`` is weighted by
            ``expert_mask[i] * expert_mask[j]``.

    Returns:
        Scalar tensor: the Frobenius norm of the masked difference between the
        column cosine-similarity matrix and the identity, plus the mean L2 norm
        of the columns of ``gates``.
    """
    # Normalize each column once and reuse it on both sides of the Gram matrix.
    normed = F.normalize(gates, dim=0)
    sims = torch.matmul(normed.T, normed)

    # Build the identity directly on the target device (same default dtype as
    # before) instead of allocating on CPU and copying it over.
    targets = torch.eye(sims.shape[0], device=sims.device)

    # Pairwise mask: outer product of the 1-D expert mask with itself.
    sim_mask = torch.outer(expert_mask, expert_mask)

    sim_loss = torch.norm(sims * sim_mask - targets * sim_mask)
    simple_loss = torch.mean(torch.norm(gates, dim=0))
    return sim_loss + simple_loss





def gshard_loss(scores_w_noise, top_ids, num_global_experts):
    """Auxiliary loss pairing per-expert score mass with first-choice assignments.

    Args:
        scores_w_noise: 2-D tensor; dim 0 is the sample axis and dim 1 has one
            slot per (possibly padded) expert.
        top_ids: 2-D tensor of expert ids; only column 0 is used. Ids equal to
            ``scores_w_noise.size(1)`` are dropped via the padding column.
        num_global_experts: numerator of the per-assignment weight
            ``num_global_experts / num_samples``.

    Returns:
        Scalar tensor ``sum_e(score_mass[e] * assignment_mass[e]) / num_samples``,
        where ``score_mass`` sums scores over samples and ``assignment_mass``
        sums the weighted first-choice one-hot over samples.
    """
    n_samples = int(scores_w_noise.size(0))
    n_slots = int(scores_w_noise.size(1))

    first_choice = top_ids[:, 0]
    weight = num_global_experts / n_samples
    assignments = _one_hot_with_dtype_and_padding(
        first_choice, n_slots, dtype=scores_w_noise.dtype, hot_value=weight
    )

    score_mass = scores_w_noise.sum(dim=0)
    assignment_mass = assignments.sum(dim=0)
    return (score_mass * assignment_mass).sum() / n_samples

def load_importance_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise):
    """Average of an importance loss and a load loss.

    Both terms are the squared coefficient of variation, ``var / (mean**2 + 1e-10)``,
    of a total computed by summing over dim 0:

    * importance: the totals of ``scores_wo_noise``;
    * load: the totals of ``Phi(score - threshold)``. Here ``Phi`` is the CDF of a
      zero-mean normal with standard deviation ``gate_noise / num_global_experts``,
      and ``threshold`` is the last entry of the matching row of ``topk_logits``.

    Args:
        scores_wo_noise: 2-D score tensor reduced over dim 0 (presumably
            ``(num_samples, num_experts)``; confirm with caller).
        topk_logits: 2-D tensor; ``topk_logits[:, -1]`` is the per-row threshold
            (assumed to be the smallest of each row's top-k logits; confirm).
        num_global_experts: divisor applied to ``gate_noise`` for the normal's scale.
        gate_noise: positive noise scale.

    Returns:
        Scalar float32 tensor ``(l_imp + l_load) / 2``.

    Raises:
        ValueError: if ``gate_noise`` is not > 0.
    """
    # Validate before doing any work. An ``assert`` would be stripped under ``python -O``.
    if not gate_noise > 0:
        raise ValueError("`gate_noise` must be > 0 for normalization in load_importance_loss().")

    def cv_squared(totals):
        # Squared coefficient of variation; the epsilon guards against a zero mean.
        totals = totals.float()
        return totals.var() / (totals.mean() ** 2 + 1e-10)

    scores = scores_wo_noise.float()

    l_imp = cv_squared(scores.sum(0))

    normal = Normal(
        torch.tensor([0.0], device=scores.device),
        torch.tensor([gate_noise / num_global_experts], device=scores.device),
    )
    # One threshold per row, broadcast across that row's scores.
    threshold = topk_logits[:, -1].view(-1, 1).float()
    prob = normal.cdf(scores - threshold)
    l_load = cv_squared(prob.sum(0))

    return (l_imp + l_load) / 2.0