from typing import Optional, Tuple, List
from pathlib import Path
import logging
import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from ..audio_tag.utils import load_id2label, compute_acc

# Import ZipformerForSequenceClassificationModel as parent class 
from ..audio_tag.model import ZipformerForSequenceClassificationModel 
from .loss import AAMSoftmaxLoss


class ZipformerForSpeakerIdentificationModel(ZipformerForSequenceClassificationModel):
    """Zipformer speaker-identification model with a configurable training loss.

    Reuses ``ZipformerForSequenceClassificationModel`` for everything except
    the loss: after the parent constructor runs, ``self.criterion`` is
    replaced according to ``config.loss_type``:

    * multi-label setups (``self.is_multilabel``, presumably set by the parent
      constructor — confirm) always use ``BCEWithLogitsLoss(reduction="sum")``;
    * ``"aams"`` selects ``AAMSoftmaxLoss`` built from ``config.margin``
      (default 0.2) and ``config.scale`` (default 30.0);
    * ``"ce"`` (the default when ``loss_type`` is absent) selects
      ``CrossEntropyLoss(reduction="sum")``.

    Any other ``loss_type`` still falls back to cross entropy, as before, but
    a warning is logged so a misspelled value does not go unnoticed.
    """

    def __init__(self, config, id2label):
        """Build the parent model, then choose the training criterion.

        Args:
            config: model configuration. ``loss_type``, ``margin`` and
                ``scale`` are read with ``getattr`` and fall back to
                ``"ce"``, ``0.2`` and ``30.0`` respectively.
            id2label: label mapping, forwarded unchanged to the parent.
        """
        super().__init__(config, id2label)
        # Stored unmodified so other code reading self.loss_type sees the
        # configured value.
        self.loss_type = getattr(config, "loss_type", "ce")
        logger = logging.getLogger(__name__)

        if self.is_multilabel:
            # Multi-label targets need a per-class sigmoid loss; the
            # configured loss_type cannot be honoured here.
            if self.loss_type != "ce":
                logger.warning(
                    "loss_type=%r is ignored for multi-label classification; "
                    "using BCEWithLogitsLoss",
                    self.loss_type,
                )
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
        elif self.loss_type == "aams":
            # AAM-Softmax loss; margin/scale are overridable from the config.
            margin = getattr(config, "margin", 0.2)
            scale = getattr(config, "scale", 30.0)
            self.criterion = AAMSoftmaxLoss(margin=margin, scale=scale)
        else:
            # Cross-entropy is both the explicit "ce" choice and the
            # backward-compatible fallback for unrecognised values.
            if self.loss_type != "ce":
                logger.warning(
                    "Unknown loss_type=%r; falling back to CrossEntropyLoss "
                    "(expected 'ce' or 'aams')",
                    self.loss_type,
                )
            self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

