from typing import Optional, Tuple, List
from pathlib import Path
import logging
import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from .utils import load_id2label, compute_acc
import pdb

from ..audio_tag.model import ZipformerForSequenceClassificationModel 

class ZipformerForMusicTaggingModel(ZipformerForSequenceClassificationModel):
    """Zipformer classifier variant for multi-label music tagging.

    Extends ``ZipformerForSequenceClassificationModel`` with a helper that turns
    separator-joined tag strings into multi-hot target tensors. It reads
    ``self.num_classes`` and ``self.label2id``, which are presumably set up by
    the parent class (not visible here) -- confirm against the base class.
    """

    def tag2multihot(self, tag_strings, sep: str = ";"):
        """Convert separator-joined tag strings into a multi-hot target matrix.

        Example (assuming label2id == {"sand": 0, "rub": 1, "butterfly": 2}):
            input:  ['sand;rub', 'butterfly']
            output: torch.tensor([[1., 1., 0.], [0., 0., 1.]])

        Args:
            tag_strings: sequence of strings, one per sample. Each string holds
                zero or more tag names joined by ``sep``. An empty string (or
                None) means the sample has no positive labels.
            sep: separator between tag names. Defaults to ``";"``.

        Returns:
            A float32 tensor of shape ``(len(tag_strings), self.num_classes)``.
            It holds 1.0 at the index of every tag present in a sample's string
            and 0.0 elsewhere. It is allocated on torch's default device.

        Raises:
            KeyError: if a tag name is not present in ``self.label2id``.
        """
        multihot = torch.zeros(
            (len(tag_strings), self.num_classes), dtype=torch.float32
        )
        for row, tag_str in enumerate(tag_strings):
            if not tag_str:
                # Empty/None: the sample has no positive labels; keep the row all-zero.
                continue
            for raw_tag in tag_str.split(sep):
                tag = raw_tag.strip()
                if not tag:
                    # Tolerate stray separators/whitespace such as "sand;" or "a;;b",
                    # which would otherwise look up "" in label2id and crash.
                    continue
                try:
                    # int() because label2id values may not already be ints
                    # (e.g. when loaded from a JSON mapping).
                    label_id = int(self.label2id[tag])
                except KeyError:
                    raise KeyError(
                        f"Unknown tag {tag!r} in {tag_str!r}; not found in label2id"
                    ) from None
                multihot[row, label_id] = 1.0
        return multihot