#! /usr/bin/python
# -*- encoding: utf-8 -*-
# Adapted from https://github.com/wujiyang/Face_Pytorch (Apache License)

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AAMSoftmax(nn.Module):
    """Additive angular margin (ArcFace-style) softmax loss for speaker classification.

    Logits are cosine similarities between L2-normalised inputs and L2-normalised
    per-speaker weight vectors, multiplied by ``scale``. For the target class,
    ``cos(theta)`` is replaced by ``cos(theta + margin)`` before cross-entropy.

    Args:
        speaker_embedding_dim: Size of the embedding dimension of ``x``.
        n_speakers: Number of output classes.
        reduction: Forwarded to ``nn.CrossEntropyLoss`` (e.g. ``"mean"``, ``"none"``).
        output_rep: When ``"elbo"``, one weight matrix is kept per encoder layer and
            ``forward`` picks one with ``idx``; otherwise a single matrix is used.
        encoder_layers: Number of per-layer weight matrices (only used for ``"elbo"``).
        margin: Angular margin ``m``, in radians.
        scale: Logit scale ``s``.
    """

    def __init__(self, speaker_embedding_dim, n_speakers, reduction, output_rep, encoder_layers, margin=0.3, scale=15):
        super().__init__()
        self.m = margin
        self.s = scale
        self.in_feats = speaker_embedding_dim
        if output_rep == "elbo":
            # One (n_speakers, speaker_embedding_dim) weight matrix per encoder layer.
            # NOTE(review): this branch samples U[0, 1) (torch.rand) while the other samples
            # N(0, 1) (torch.randn). Uniform init places every class vector in the positive
            # orthant. Confirm the asymmetry is intended.
            self.weight_smx = nn.Parameter(torch.rand((encoder_layers, n_speakers, speaker_embedding_dim)))
        else:
            self.weight_smx = nn.Parameter(torch.randn((n_speakers, speaker_embedding_dim)))

        self.ce = nn.CrossEntropyLoss(reduction=reduction)
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        # cos(theta + m) is only monotonically decreasing for theta in [0, pi - m].
        # Beyond that threshold, forward() falls back to cos(theta) minus a constant penalty.
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, x, label=None, idx=None):
        """Compute the AAM-softmax loss and the predicted classes.

        Args:
            x: Embeddings. They are normalised along dim 1 and projected with ``F.linear``,
                so presumably shaped (batch, speaker_embedding_dim) — TODO confirm.
            label: Integer class index per row of ``x``. When ``None``, no loss is computed.
            idx: Encoder-layer index selecting the weight matrix. Required when the
                module was built with ``output_rep="elbo"``; ignored otherwise.

        Returns:
            ``(loss, [predicted_classes])``. With a label, the predictions are the argmax
            of the margin-adjusted, scaled logits that the loss sees. Without a label,
            ``loss`` is ``None`` and the predictions are the argmax of the cosine scores.

        Raises:
            ValueError: If the module holds per-layer weights and ``idx`` is ``None``.
        """
        weight = self.weight_smx
        if weight.dim() != 2:
            if idx is None:
                # Indexing with None would add an axis instead of selecting a layer.
                raise ValueError("idx is required when output_rep='elbo' (per-layer weights)")
            weight = weight[idx, :, :]

        # cos(theta) between each embedding and each class weight vector.
        cosine = F.linear(F.normalize(x), F.normalize(weight))

        if label is None:
            # Without targets there is no class to apply the margin to.
            return None, [torch.argmax(cosine.detach(), dim=-1)]

        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m); sin(theta) >= 0 on [0, pi].
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        # Apply the margin only to the target-class logit of each row.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = self.s * (one_hot * phi + (1.0 - one_hot) * cosine)

        loss = self.ce(output, label)
        # NOTE(review): the argmax is taken over margin-penalised logits, so the target's
        # score is lowered and the accuracy derived from these predictions is pessimistic.
        predicted_classes = torch.argmax(output.detach(), dim=-1)
        return loss, [predicted_classes]
