from typing import Optional, Tuple, List
from pathlib import Path
import logging
import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from .utils import load_id2label, compute_acc
    
class MeanPooling(nn.Module):
    """Average frame embeddings over the valid (non-padded) time steps."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x: A 3-D tensor of shape (N, T, C).
          x_lens: A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
        Returns:
          A 2-D tensor of shape (N, C): the mean of the first ``x_lens[i]`` frames
          of each utterance. ``x`` itself is not modified.
        """
        num_frames = x.size(1)
        # True at padded positions. The mask is sized to x's own time axis, so it lines
        # up even when x is padded beyond max(x_lens).
        padding_mask = (
            torch.arange(num_frames, device=x_lens.device).unsqueeze(0)
            >= x_lens.unsqueeze(1)
        )  # (N, T)
        # Out-of-place fill: never mutate the caller's tensor (e.g. the encoder output).
        x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        # clamp avoids 0/0 = NaN for a zero-length utterance (it pools to zeros instead).
        num_valid = (~padding_mask).sum(dim=1, keepdim=True).clamp(min=1)  # (N, 1)
        return x.sum(dim=1) / num_valid  # (N, C)

class MultiheadAttentionPooling(nn.Module):
    """Pool a (B, T, d_in) sequence into (B, d_in) using learned per-head queries.

    Each of the ``num_heads`` heads has a single learned query vector. The heads
    attend over the projected keys of the valid frames. The per-head summaries are
    concatenated and projected back to ``d_in``.
    """

    def __init__(self, d_in, num_heads=4, d_qkv=None):
        super().__init__()
        self.h = num_heads
        d_qkv = d_qkv or d_in
        # One learned query per head, shared by every utterance in the batch.
        self.W_q = nn.Parameter(torch.randn(num_heads, d_qkv))
        self.proj_k = nn.Linear(d_in, d_qkv)
        self.proj_v = nn.Linear(d_in, d_qkv)
        self.W_o = nn.Linear(num_heads * d_qkv, d_in)
        nn.init.xavier_uniform_(self.W_q)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x: A 3-D tensor of shape (B, T, d_in).
          x_lens: A 1-D tensor of shape (B,) with the number of valid frames per row.
        Returns:
          A 2-D tensor of shape (B, d_in).

        Note: a row with ``x_lens == 0`` has every position masked with -inf, so
        its softmax (and output) is NaN.
        """
        B, T, _ = x.shape
        K = self.proj_k(x)  # (B, T, d_qkv)
        V = self.proj_v(x)  # (B, T, d_qkv)
        Q = self.W_q.unsqueeze(0).expand(B, -1, -1)  # (B, h, d_qkv)
        # Scaled dot-product scores of each head's query against every frame.
        attn = torch.einsum('bhd,btd->bht', Q, K) / K.shape[-1] ** 0.5  # (B, h, T)

        # True at padded positions, sized to x's own time axis (T) so it always aligns.
        padding_mask = (
            torch.arange(T, device=x_lens.device).unsqueeze(0) >= x_lens.unsqueeze(1)
        )  # (B, T)
        attn = attn.masked_fill(padding_mask.unsqueeze(1), float('-inf'))

        alpha = F.softmax(attn, dim=-1)  # (B, h, T)
        z = torch.einsum('bht,btd->bhd', alpha, V)  # (B, h, d_qkv)
        return self.W_o(z.reshape(B, -1))  # (B, d_in)

def _get_pooling_method(config):
    if config.pooling == 'mean':
        return MeanPooling()
    elif config.pooling == 'mhap':
        d_in = max(config.encoder_dim)
        mhap_config = config.mhap
        num_heads = mhap_config.num_heads if hasattr(mhap_config, 'num_heads') else 4
        d_qkv = mhap_config.d_qkv if hasattr(mhap_config, 'd_qkv') else None
        return MultiheadAttentionPooling(d_in, num_heads=num_heads, d_qkv=d_qkv)
    else:
        raise ValueError(f"Unknown pooling method: {config.pooling}")

class ZipformerForSequenceClassificationModel(ZipformerEncoderModel):
    """Zipformer encoder + utterance-level pooling + linear classifier.

    Single-label (softmax / cross-entropy) or multi-label (sigmoid /
    BCE-with-logits) classification is selected by ``config.is_multilabel``.
    """

    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """Build the model from ``exp_dir`` and load its weights.

        ``exp_dir`` must contain the config read by ``AutoConfig.from_pretrained``,
        an ``id2label.json`` label map and the checkpoint ``checkpoint_filename``.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        exp_path = Path(exp_dir)
        id2label = load_id2label(exp_path / 'id2label.json')
        model = cls(config, id2label)
        load_model_params(model, exp_path / checkpoint_filename)
        return model

    def __init__(self, config, id2label):
        super().__init__(config)
        # id2label is indexed with string ids elsewhere (see `generate`: str(idx)).
        self.id2label = id2label
        self.label2id = {label: int(idx) for idx, label in id2label.items()}
        self.num_classes = len(self.id2label)
        self.is_multilabel = config.is_multilabel

        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(max(config.encoder_dim), self.num_classes),
        )

        self.pooling = _get_pooling_method(config)

        # Both criteria sum (not average) the loss over the batch.
        if self.is_multilabel:
            self.criterion = nn.BCEWithLogitsLoss(reduction="sum")
        else:
            self.criterion = nn.CrossEntropyLoss(reduction='sum')

        # Learnable mixing logits, one per entry of config.num_encoder_layers
        # (presumably one per encoder stack — TODO confirm); see _fuse_encoder_output.
        if self.config.fuse_encoder:
            self.encoder_fusion_weights = nn.Parameter(torch.zeros(len(config.num_encoder_layers)))
        else:
            self.encoder_fusion_weights = None

    def tag2multihot(self, tag_strings):
        """Convert ';'-separated tag strings into a float32 multi-hot matrix.

        Example (label2id = {'sand': 0, 'rub': 1, 'butterfly': 2}):
          ['sand;rub', 'butterfly'] -> tensor([[1., 1., 0.], [0., 0., 1.]])

        Returns:
          A CPU tensor of shape (len(tag_strings), num_classes).
        Raises:
          KeyError: if a tag is not present in ``label2id``.
        """
        multihot = torch.zeros((len(tag_strings), self.num_classes), dtype=torch.float32)
        for row, tag_str in enumerate(tag_strings):
            for tag in tag_str.split(";"):
                multihot[row, self.label2id[tag]] = 1.0
        return multihot

    def _fuse_encoder_output(self, encoder_output):
        """Select the encoder representation that is fed to the pooling layer.

        Without fusion this is ``encoder_output.encoder_out``. With fusion it is a
        softmax-weighted sum of ``encoder_output.encoder_out_full`` over its leading
        dimension. The weights are broadcast as (-1, 1, 1, 1), so that dimension
        indexes the fused outputs.
        """
        if self.encoder_fusion_weights is None:
            return encoder_output.encoder_out
        weights = F.softmax(self.encoder_fusion_weights, dim=0).view(-1, 1, 1, 1)
        return (encoder_output.encoder_out_full * weights).sum(dim=0)

    def forward(self, x, x_lens, tags):
        """Compute the training loss and accuracies for a batch.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          tags:
            N ground-truth tag strings; multiple labels are separated by ';',
            e.g. ['dog;bark', 'cat'].
        Returns:
          A tuple ``(loss, logits, top1_acc, top5_acc)``:
            - loss: summed BCE-with-logits loss (multi-label) or summed
              cross-entropy against the multi-hot targets (single-label);
            - logits: tensor of shape (N, num_classes);
            - top1_acc, top5_acc: as returned by ``compute_acc(logits, targets)``.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        targets = self.tag2multihot(tags).to(x.device)

        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = self._fuse_encoder_output(encoder_output)

        logits = self.forward_classifier(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_output.encoder_out_lens,
        )  # (N, num_classes)

        loss = self.criterion(logits, targets)
        top1_acc, top5_acc = compute_acc(logits, targets)
        return loss, logits, top1_acc, top5_acc

    def forward_classifier(self, encoder_out, encoder_out_lens):
        """Pool the encoder output over time and classify it.

        Args:
          encoder_out:
            A 3-D tensor of shape (N, T, C).
          encoder_out_lens:
            A 1-D tensor of shape (N,). It contains the number of valid frames in
            `encoder_out` before padding.

        Returns:
          A 2-D tensor of shape (N, num_classes).
        """
        uttr_emb = self.pooling(encoder_out, encoder_out_lens)  # (N, C)
        return self.classifier(uttr_emb)  # (N, num_classes)

    # Inference only: no autograd graph is needed, so don't build one.
    @torch.no_grad()
    def generate(self, input, return_full_logits=False, threshold=0.0, topk=1):
        """
        Generate predictions from either:
            - A tuple (x, x_lens) of precomputed features
            - Anything else, passed to ``self.extract_feature`` (e.g. a list of wav
              file paths or raw waveforms)

        Args:
            input: (x, x_lens) or input accepted by ``self.extract_feature``
            return_full_logits (bool): Whether to return the full (N, num_classes)
                logits instead of top-k predictions
            threshold (float): Multi-label only. A top-k class is kept in ``labels``
                only if its sigmoid probability is > threshold. The default 0.0
                keeps essentially every top-k class.
            topk (int): Number of top predictions to return (capped at num_classes)

        Returns:
            - full logits (Tensor, (N, num_classes)) if return_full_logits=True
            else:
            - labels (List[List[str]])
            - logits (Tensor): (N, topk)
            - probs  (Tensor): (N, topk)
        """
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        device = next(self.parameters()).device
        x = x.to(device)
        x_lens = x_lens.to(device)

        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = self._fuse_encoder_output(encoder_output)
        logits_full = self.forward_classifier(encoder_out, encoder_output.encoder_out_lens)  # (N, num_classes)

        if return_full_logits:
            return logits_full

        # Multi-label: independent per-class sigmoid. Single-label: softmax over classes.
        if self.is_multilabel:
            probs_full = torch.sigmoid(logits_full)
        else:
            probs_full = torch.softmax(logits_full, dim=-1)

        k = min(topk, probs_full.size(-1))
        topk_probs, topk_indices = torch.topk(probs_full, k=k, dim=-1)
        topk_logits = torch.gather(logits_full, dim=1, index=topk_indices)

        labels = []
        for indices, probs in zip(topk_indices.tolist(), topk_probs.tolist()):
            if self.is_multilabel:
                labels.append(
                    [self.id2label[str(idx)] for idx, prob in zip(indices, probs) if prob > threshold]
                )
            else:
                labels.append([self.id2label[str(idx)] for idx in indices])
        return labels, topk_logits, topk_probs
