from typing import Optional, Tuple, List
from pathlib import Path
import logging
import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from .utils import load_id2label, compute_acc
    
class ZipformerForSequenceClassificationModel(ZipformerEncoderModel):
    """Zipformer encoder followed by a frame-wise linear classifier.

    Frame-level logits are averaged over the valid (non-padded) frames to give
    one logit vector per utterance. Single-label mode uses softmax /
    cross-entropy; multi-label mode (``config.is_multilabel``) uses sigmoid /
    binary cross-entropy. When ``config.fuse_encoder`` is set, the outputs of
    all encoder stacks are combined with learned softmax-normalised weights.
    """

    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """Build a model from an experiment directory and load its weights.

        Args:
          exp_dir:
            Directory containing the model config (read via
            ``AutoConfig.from_pretrained``), ``id2label.json`` and the
            checkpoint file.
          checkpoint_filename:
            Name of the checkpoint file inside ``exp_dir``.

        Returns:
          The constructed model with parameters loaded from the checkpoint.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        id2label_json = Path(exp_dir) / 'id2label.json'
        id2label = load_id2label(id2label_json)
        model = cls(config, id2label)
        ckpt_path = Path(exp_dir) / checkpoint_filename
        load_model_params(model, ckpt_path)
        return model

    def __init__(self, config, id2label):
        """
        Args:
          config:
            Model config; must provide ``encoder_dim``, ``is_multilabel``,
            ``num_encoder_layers`` and ``fuse_encoder`` in addition to what
            the base encoder needs.
          id2label:
            Mapping from class index to label name. Keys are looked up as
            ``str(index)`` in ``generate``, so they are presumably strings
            (e.g. JSON object keys) — confirm against ``load_id2label``.
        """
        super().__init__(config)
        self.id2label = id2label
        self.label2id = {label: int(idx) for idx, label in id2label.items()}
        self.num_classes = len(self.id2label)
        self.is_multilabel = config.is_multilabel

        # Applied to every frame; the input width is the largest encoder dim.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(max(config.encoder_dim), self.num_classes),
        )

        # Losses are summed (not averaged) over the batch.
        if self.is_multilabel:
            # for multi-label classification
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
        else:
            self.criterion = torch.nn.CrossEntropyLoss(reduction='sum')

        # One learnable weight per encoder stack. Zero init means the softmax
        # starts as a uniform average of the stacks.
        if self.config.fuse_encoder:
            self.encoder_fusion_weights = nn.Parameter(
                torch.zeros(len(config.num_encoder_layers))
            )
        else:
            self.encoder_fusion_weights = None

    def tag2multihot(self, tag_strings):
        """Convert ';'-separated tag strings into a multi-hot float matrix.

        Example (with label ids sand=0, rub=1, butterfly=2):
            ['sand;rub', 'butterfly'] -> tensor([[1., 1., 0.], [0., 0., 1.]])

        Whitespace around each tag is ignored and empty segments (from a
        trailing ';' or an empty string) are skipped, so an empty tag string
        yields an all-zero row.

        Args:
          tag_strings:
            Sequence of N strings, each holding one or more ';'-separated tags.

        Returns:
          A float32 tensor of shape (N, num_classes) on the CPU.

        Raises:
          KeyError: if a tag is not a known label.
        """
        multihot = torch.zeros((len(tag_strings), self.num_classes), dtype=torch.float32)

        for i, tag_str in enumerate(tag_strings):
            for raw_tag in tag_str.split(";"):
                tag = raw_tag.strip()
                if not tag:
                    continue
                if tag not in self.label2id:
                    raise KeyError(f"Unknown tag {tag!r} in {tag_str!r}")
                multihot[i, self.label2id[tag]] = 1.0
        return multihot

    def _fuse_encoder_output(self, encoder_output):
        """Return the (N, T, C) encoder representation fed to the classifier.

        Without fusion this is ``encoder_output.encoder_out``. With fusion,
        ``encoder_output.encoder_out_full`` is a 4-D tensor whose first dim
        indexes encoder stacks (implied by the ``view(-1, 1, 1, 1)`` broadcast
        and ``sum(dim=0)``). It is collapsed with softmax-normalised weights.
        """
        if self.encoder_fusion_weights is None:
            return encoder_output.encoder_out
        fusion_weights = F.softmax(self.encoder_fusion_weights, dim=0).view(-1, 1, 1, 1)
        return (encoder_output.encoder_out_full * fusion_weights).sum(dim=0)

    def forward(self, x, x_lens, tags):
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          tags:
            N ground-truth tag strings. Multiple labels for one example are
            separated by ';', e.g. ['dog;bark', 'cat'].
        Returns:
          A tuple ``(loss, logits, top1_acc, top5_acc)``:
            - loss: BCE-with-logits (multi-label) or cross-entropy
              (single-label), summed over the batch.
            - logits: tensor of shape (N, num_classes).
            - top1_acc, top5_acc: whatever ``compute_acc(logits, targets)``
              returns.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        # For single-label mode the multi-hot rows are one-hot and are used by
        # CrossEntropyLoss as class-probability targets.
        targets = self.tag2multihot(tags).to(x.device)

        # Compute encoder outputs
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = self._fuse_encoder_output(encoder_output)

        # Forward the classifier
        logits = self.forward_classifier(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_output.encoder_out_lens,
        )  # (N, num_classes)

        loss = self.criterion(logits, targets)

        top1_acc, top5_acc = compute_acc(logits, targets)

        return loss, logits, top1_acc, top5_acc

    def forward_classifier(self, encoder_out, encoder_out_lens):
        """
        Args:
          encoder_out:
            A 3-D tensor of shape (N, T, C).
          encoder_out_lens:
            A 1-D tensor of shape (N,). It contains the number of valid
            (non-padded) frames of each sequence in `encoder_out`.

        Returns:
          A 2-D tensor of shape (N, num_classes): frame logits averaged over
          the valid frames. A sequence with no valid frames gets all-zero
          logits rather than NaN.
        """
        logits = self.classifier(encoder_out)  # (N, T, num_classes)
        padding_mask = make_pad_mask(encoder_out_lens)  # True at padded frames
        # Zero padded frames (out-of-place) so they do not contribute to the sum.
        logits = logits.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        # Average pooling over valid frames; clamp avoids 0/0 for empty inputs.
        num_valid = (~padding_mask).sum(dim=1, keepdim=True).clamp(min=1)
        return logits.sum(dim=1) / num_valid

    @torch.no_grad()
    def generate(self, input, return_full_logits=False, threshold=0.0, topk=1):
        """
        Generate predictions from either:
            - A tuple (x, x_lens) of precomputed features
            - Any other input, which is passed to ``self.extract_feature``
              (e.g. a list of wav file paths or raw waveforms)

        Runs without autograd, so returned tensors do not require grad. This
        method does not switch the module to eval mode. Call ``model.eval()``
        first, otherwise the classifier's Dropout is active.

        Args:
            input: (x, x_lens) or list of file paths or raw waveforms
            return_full_logits (bool): Return only the full (N, num_classes)
                logits instead of top-k predictions.
            threshold (float): Multi-label only. Keep a top-k label only if its
                sigmoid probability is strictly greater than this value.
            topk (int): Number of top predictions to return (capped at
                num_classes).

        Returns:
            - full logits (N, num_classes) if return_full_logits=True
            else a tuple of:
            - labels (List[List[str]])
            - logits (Tensor): (N, k)
            - probs  (Tensor): (N, k), sigmoid (multi-label) or softmax
              (single-label) probabilities
        """
        # Handle flexible input
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        device = next(self.parameters()).device
        x = x.to(device)
        x_lens = x_lens.to(device)

        # Forward encoder and classifier
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = self._fuse_encoder_output(encoder_output)
        logits_full = self.forward_classifier(
            encoder_out, encoder_output.encoder_out_lens
        )  # (N, num_classes)

        if return_full_logits:
            return logits_full

        if self.is_multilabel:
            probs_full = torch.sigmoid(logits_full)
        else:
            probs_full = torch.softmax(logits_full, dim=-1)

        k = min(topk, probs_full.size(-1))
        topk_probs, topk_indices = torch.topk(probs_full, k=k, dim=-1)
        topk_logits = torch.gather(logits_full, dim=1, index=topk_indices)

        # One host transfer instead of a per-element .item() sync.
        labels = []
        for indices, probs in zip(topk_indices.tolist(), topk_probs.tolist()):
            if self.is_multilabel:
                labels.append(
                    [self.id2label[str(idx)] for idx, prob in zip(indices, probs)
                     if prob > threshold]
                )
            else:
                labels.append([self.id2label[str(idx)] for idx in indices])
        return labels, topk_logits, topk_probs