import torch
import torch.nn as nn
import torch.nn.functional as F

class GE2ELoss(nn.Module):
    """
    Generalized End-to-End (GE2E) softmax loss for speaker verification.

    Speaker centroids are derived from ``speaker_ids``, so the batch may contain
    any number of utterances per speaker, in any order. Each utterance is
    classified against every speaker centroid present in the batch; for its own
    speaker the centroid is computed with that utterance left out.
    """

    def __init__(self, init_w=10.0, init_b=-5.0):
        """
        Args:
            init_w (float): initial value of the learnable scaling parameter 'w'.
            init_b (float): initial value of the learnable bias parameter 'b'.
        """
        super().__init__()
        # float(...) keeps the parameters floating-point even when integers are
        # passed; an integer tensor cannot require gradients.
        self.w = nn.Parameter(torch.tensor(float(init_w)))
        self.b = nn.Parameter(torch.tensor(float(init_b)))

    def forward(self, embeddings: torch.Tensor, speaker_ids: torch.Tensor):
        """
        Calculates the GE2E Loss.

        Args:
            embeddings (torch.Tensor): A 2-D tensor of shape (N, D),
                                       where N is the total number of utterances in the batch,
                                       and D is the embedding dimension.
            speaker_ids (torch.Tensor): A 1-D tensor of shape (N,),
                                        containing the speaker ID for each embedding.
                                        Utterances of one speaker do not need to be contiguous.

        Returns:
            tuple:
                loss (torch.Tensor): A scalar tensor, the cross-entropy summed over all
                                     N utterances (reduction='sum').
                acc (float): Top-1 accuracy for the batch as a fraction in [0, 1]
                             (share of embeddings whose highest score is their own speaker).

        Raises:
            ValueError: If the input shapes are inconsistent or the batch is empty.
        """
        if embeddings.ndim != 2:
            raise ValueError(f"Expected embeddings to be 2-D (N, D), but got {embeddings.ndim}-D: {embeddings.shape}")
        if speaker_ids.ndim != 1 or embeddings.size(0) != speaker_ids.size(0):
            raise ValueError(f"Expected speaker_ids to be 1-D (N,) and match N of embeddings. "
                             f"Got embeddings N={embeddings.size(0)}, speaker_ids shape={speaker_ids.shape}")
        if embeddings.size(0) == 0:
            raise ValueError("Expected at least one embedding, but got an empty batch")

        num_utterances = embeddings.size(0)

        # Unit-length embeddings, shape (N, D).
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)

        # labels[n] is the row of utterance n's speaker in the sorted unique-speaker
        # list; it doubles as the classification target.
        unique_speakers, labels, counts = torch.unique(
            speaker_ids, sorted=True, return_inverse=True, return_counts=True
        )
        labels = labels.to(embeddings_norm.device)
        counts = counts.to(embeddings_norm.device)
        num_speakers = unique_speakers.numel()

        # One-hot speaker membership, shape (N, S).
        is_own_speaker = F.one_hot(labels, num_classes=num_speakers).bool()
        membership = is_own_speaker.to(embeddings_norm.dtype)

        # Per-speaker sums and means of the normalized embeddings, shape (S, D).
        # NOTE: centroids are not re-normalized, so the scores below are dot
        # products with the mean vector rather than strict cosine similarities.
        speaker_sums = membership.t() @ embeddings_norm
        centroids = speaker_sums / counts.to(embeddings_norm.dtype).unsqueeze(1)

        # Leave-one-out centroid for every utterance, indexed through `labels` so the
        # result is correct for any ordering of speakers within the batch.
        own_counts = counts[labels].unsqueeze(1)  # (N, 1)
        has_other_utterances = own_counts > 1
        # clamp avoids 0/0 for single-utterance speakers; that branch is discarded below.
        loo_denominator = (own_counts - 1).clamp(min=1).to(embeddings_norm.dtype)
        loo_centroids = torch.where(
            has_other_utterances,
            (speaker_sums[labels] - embeddings_norm) / loo_denominator,
            centroids[labels],  # single utterance: fall back to the full centroid
        )  # (N, D)

        # Similarity of every utterance to every centroid, with the own-speaker entry
        # replaced by the similarity to the leave-one-out centroid.
        similarities = embeddings_norm @ centroids.t()  # (N, S)
        own_loo_similarities = (embeddings_norm * loo_centroids).sum(dim=1, keepdim=True)  # (N, 1)
        similarities = torch.where(is_own_speaker, own_loo_similarities, similarities)

        logits = self.w * similarities + self.b  # (N, S)

        loss = F.cross_entropy(logits, labels, reduction='sum')

        # Batch-level accuracy
        with torch.no_grad():
            preds = logits.argmax(dim=1)  # (N,)
            correct = (preds == labels).sum().item()
            acc = correct / num_utterances

        return loss, float(acc)