from typing import Optional, Tuple, List
from pathlib import Path
import logging
import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from ..audio_tag.utils import load_id2label, compute_acc
from sklearn.metrics import precision_score, recall_score, f1_score

class ZipformerForFrameClassificationModel(ZipformerEncoderModel):
    """
    Zipformer encoder followed by a dropout + linear head applied to every encoder frame.

    Depending on ``config.is_multilabel`` the head is trained with
    ``BCEWithLogitsLoss`` (multi-label) or ``CrossEntropyLoss`` (single-label).
    When ``config.fuse_encoder`` is truthy the classifier input is a learned,
    softmax-weighted sum of ``encoder_out_full`` instead of ``encoder_out``.
    """

    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """
        Build a model from an experiment directory and load its parameters.

        Args:
            exp_dir:
                Directory passed to ``AutoConfig.from_pretrained``; it must also
                contain ``id2label.json`` and the checkpoint file.
            checkpoint_filename:
                Name of the checkpoint file inside ``exp_dir``.

        Returns:
            The constructed model after ``load_model_params`` has been applied.
        """
        exp_dir_path = Path(exp_dir)
        config = AutoConfig.from_pretrained(exp_dir)
        id2label = load_id2label(exp_dir_path / 'id2label.json')
        model = cls(config, id2label)
        load_model_params(model, exp_dir_path / checkpoint_filename)
        return model

    def __init__(self, config, id2label):
        """
        Args:
            config:
                Model configuration. This class reads ``is_multilabel``,
                ``encoder_dim`` (an iterable; its maximum is the classifier input
                size), ``fuse_encoder`` and ``num_encoder_layers`` (a sized
                container; its length is the number of fusion weights).
            id2label:
                Mapping from class id to label name. Ids are converted with
                ``int()`` when building the inverse ``label2id`` mapping.
        """
        super().__init__(config)
        self.id2label = id2label
        self.label2id = {label: int(idx) for idx, label in id2label.items()}
        self.num_classes = len(self.id2label)
        self.is_multilabel = config.is_multilabel

        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(max(config.encoder_dim), self.num_classes),
        )

        if self.is_multilabel:
            # Independent sigmoid per class for multi-label classification.
            self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            self.criterion = torch.nn.CrossEntropyLoss()

        if self.config.fuse_encoder:
            # Zero initialisation makes the softmax-normalised weights uniform at start.
            self.encoder_fusion_weights = nn.Parameter(torch.zeros(len(config.num_encoder_layers)))
        else:
            self.encoder_fusion_weights = None

    def _fuse_encoder_output(self, encoder_output) -> torch.Tensor:
        """
        Pick the encoder representation that feeds the classifier.

        Returns ``encoder_output.encoder_out`` when fusion is disabled. Otherwise
        returns the sum over dim 0 of ``encoder_output.encoder_out_full`` weighted
        by the softmax of ``self.encoder_fusion_weights``.
        """
        if self.encoder_fusion_weights is None:
            return encoder_output.encoder_out
        # NOTE(review): the (-1, 1, 1, 1) view assumes encoder_out_full is 4-D with
        # the fused outputs stacked along dim 0 - confirm against forward_encoder.
        fusion_weights = F.softmax(self.encoder_fusion_weights, dim=0).view(-1, 1, 1, 1)
        return (encoder_output.encoder_out_full * fusion_weights).sum(dim=0)

    def forward(self, x, x_lens, frame_target, compute_metrics: bool=False, return_ground_truth: bool=False):
        """
        Compute the frame-level loss (and optionally metrics) for a batch.

        Args:
            x:
                A 3-D tensor of shape (B, T, D).
            x_lens:
                A 1-D tensor of shape (B,). It contains the lengths of each sequence in the batch before padding.
            frame_target:
                Frame-level targets, e.g. a 3-D tensor of shape (B, T', num_classes)
                for multi-label classification. Only dim 1 is cropped to the encoder
                output length, so targets without a trailing class dim also pass
                the cropping step.
            compute_metrics:
                If True, add ``'f1'`` (multi-label, micro-averaged) or ``'acc'``
                (single-label) to the result.
            return_ground_truth:
                If True, add the flattened, unpadded targets as ``'ground_truth'``.

        Returns:
            A dict with a single key ``'frame'`` mapping to a dict holding
            ``'loss'``, ``'logits'`` (padded, per frame), ``'output_lens'`` and the
            optional entries described above.
        """
        assert x.ndim == 3, f"Expected input x to be 3D tensor, got {x.ndim}D tensor"
        assert x_lens.ndim == 1, f"Expected input x_lens to be 1D tensor, got {x_lens.ndim}D tensor"

        # Compute encoder outputs
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = self._fuse_encoder_output(encoder_output)

        logits = self.forward_classifier(encoder_out)
        encoder_out_lens = encoder_output.encoder_out_lens

        # The mask is inverted below to keep only the non-padded frames.
        padding_mask = make_pad_mask(encoder_out_lens)
        valid_frames = ~padding_mask

        flattened_logits = logits[valid_frames]
        # Crop the time axis only; leaving trailing dims alone keeps this valid for
        # both (B, T, C) and (B, T) targets.
        frame_target = frame_target[:, :padding_mask.size(1)]
        flattened_frame_target = frame_target[valid_frames]

        loss = self.criterion(flattened_logits, flattened_frame_target)

        results = {
            'frame': {
                'loss': loss,
                'logits': logits,
                'output_lens': encoder_out_lens,
            }
        }

        if compute_metrics:
            with torch.no_grad():
                if self.is_multilabel:
                    pred_binary_map = (torch.sigmoid(flattened_logits) > 0.5).float()
                    f1 = f1_score(flattened_frame_target.cpu(), pred_binary_map.cpu(), average='micro', zero_division=0)
                    results['frame']['f1'] = f1
                else:
                    acc = compute_acc(flattened_logits, flattened_frame_target)
                    results['frame']['acc'] = acc

        if return_ground_truth:
            results['frame']['ground_truth'] = flattened_frame_target

        return results

    def forward_classifier(self, encoder_out: torch.Tensor) -> torch.Tensor:
        """
        Forward the classifier to get logits.

        Args:
            encoder_out: Tensor of shape (B, T, D) where D is the output dimension of the encoder.

        Returns:
            logits: Tensor of shape (B, T, num_classes) for frame-level classification.
        """
        return self.classifier(encoder_out)

    def generate(self, input, remove_padding: bool = True):
        """
        Generate frame-level predictions from either:
            - A tuple (x, x_lens) of precomputed features
            - Anything else, which is handed to ``self.extract_feature`` (e.g. a
              list of wav file paths or a list of raw waveforms)

        This method does not itself switch to eval mode or disable gradient
        tracking; callers that need that must do it around the call.

        Args:
            input: (x, x_lens) or list of file paths or raw waveforms
            remove_padding: whether to strip padded frames from the outputs

        Returns:
            A pair ``(logits, probs)``. ``probs`` is ``sigmoid(logits)`` for
            multi-label models and ``softmax(logits, dim=-1)`` otherwise.
            - remove_padding=True: two lists with one (T_i, num_classes) tensor per
              utterance, containing only the non-padded frames
            - remove_padding=False: two tensors of shape (N, T, num_classes)
        """
        # Handle flexible input
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        device = next(self.parameters()).device
        x = x.to(device)
        x_lens = x_lens.to(device)

        # Forward encoder
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = self._fuse_encoder_output(encoder_output)

        # Forward classifier
        logits_full = self.forward_classifier(encoder_out)  # (N, T, num_classes)
        if self.is_multilabel:
            probs_full = torch.sigmoid(logits_full)
        else:
            probs_full = F.softmax(logits_full, dim=-1)

        if not remove_padding:
            return logits_full, probs_full

        # Remove padding, one utterance at a time
        padding_mask = make_pad_mask(encoder_output.encoder_out_lens)
        unpadded_logits, unpadded_probs = [], []
        for utt_logits, utt_probs, utt_mask in zip(logits_full, probs_full, padding_mask):
            keep = ~utt_mask
            unpadded_logits.append(utt_logits[keep])
            unpadded_probs.append(utt_probs[keep])

        return unpadded_logits, unpadded_probs

class ZipformerForDESED(ZipformerForFrameClassificationModel):
    """
    Zipformer model for DESED task.

    Strongly labelled utterances are trained with the frame-level loss of the
    parent class. Weakly labelled ("clip only") utterances are trained with a
    clip-level loss: frame probabilities are pooled over time using attention
    weights produced by ``self.w_q``.
    """
    def __init__(self, config, id2label):
        """
        Args:
            config: Model configuration; ``config.is_multilabel`` must be truthy.
            id2label: Mapping from class id to label name.
        """
        super().__init__(config, id2label)
        assert config.is_multilabel, "DESED task requires multi-label classification"
        # Produces per-frame, per-class attention scores for clip-level pooling.
        self.w_q = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(max(config.encoder_dim), self.num_classes),
        )
        # NOTE(review): the softmax runs over the class dim (dim=-1), not time; the
        # pooled value is renormalised over time in forward(). This looks like the
        # DCASE baseline attention pooling - confirm that is the intent.
        self.aggregator = nn.Softmax(dim=-1)

    def forward(self, x, x_lens, frame_target, clip_only_indices=None, compute_metrics: bool=False, return_ground_truth: bool=False):
        """
        Override to handle desed task as both frame-level target and clip-level target are provided.

        Args:
            x:
                A 3-D tensor of shape (B, T, D).
            x_lens:
                A 1-D tensor of shape (B,).
            frame_target:
                Either the frame-level target tensor, or a tuple
                ``(frame_target, clip_target)``. The frame-level tensor must hold
                one row per strongly labelled utterance (i.e. every batch index
                not in ``clip_only_indices``), in ascending batch order, because it
                is indexed with the padding mask of exactly those utterances.
            clip_only_indices:
                Batch indices of utterances that only have a clip-level label.
                ``None`` or an empty sequence means every utterance is strongly
                labelled.
            compute_metrics:
                If True, add frame-level ``'f1'`` to ``results['frame']``.
            return_ground_truth:
                If True, add the flattened frame targets as ``'ground_truth'``.

        Returns:
            A dict with ``'frame'`` (``'loss'``, ``'logits'``, ``'output_lens'`` and
            optional metric entries), ``'total'`` (``'loss'``: the sum of frame and
            clip losses) and, when clip-only utterances are present, ``'clip'``
            (``'loss'``).

        Raises:
            ValueError: if ``clip_only_indices`` is non-empty but no clip-level
                target was supplied.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        # Normalise to plain ints: avoids a shared mutable default and makes the
        # membership test below work for tensor / numpy indices too.
        if clip_only_indices is None:
            clip_only_indices = []
        else:
            clip_only_indices = [int(i) for i in clip_only_indices]

        clip_target = None
        if isinstance(frame_target, tuple):
            frame_target, clip_target = frame_target
        if clip_only_indices and clip_target is None:
            raise ValueError(
                "clip_only_indices is non-empty, so frame_target must be a "
                "(frame_target, clip_target) tuple"
            )

        # Compute encoder outputs
        encoder_output = self.forward_encoder(x, x_lens)

        if self.encoder_fusion_weights is not None:
            fusion_weights = F.softmax(self.encoder_fusion_weights, dim=0).view(-1, 1, 1, 1)
            encoder_out = (encoder_output.encoder_out_full * fusion_weights).sum(dim=0)
        else:
            encoder_out = encoder_output.encoder_out

        logits_all = self.forward_classifier(encoder_out)
        encoder_out_lens_all = encoder_output.encoder_out_lens
        padding_mask_all = make_pad_mask(encoder_out_lens_all)

        # Frame-level classification for strongly labeled data.
        # Built in ascending order so rows stay aligned with frame_target.
        clip_only_set = set(clip_only_indices)
        indices = [i for i in range(encoder_out_lens_all.shape[0]) if i not in clip_only_set]
        logits = logits_all[indices]
        encoder_out_lens = encoder_out_lens_all[indices]
        padding_mask = padding_mask_all[indices]
        valid_frames = ~padding_mask

        flattened_logits = logits[valid_frames]
        frame_target = frame_target[:, :padding_mask.size(1)]
        flattened_frame_target = frame_target[valid_frames]

        if flattened_logits.numel() > 0:
            loss = self.criterion(flattened_logits, flattened_frame_target)
        else:
            # No strongly labelled frames in this batch: a mean-reduced loss over an
            # empty tensor is NaN, so contribute an exact zero that is still
            # attached to the graph.
            loss = flattened_logits.sum()

        results = {
            'frame': {
                'loss': loss,
                'logits': logits,
                'output_lens': encoder_out_lens,
            }
        }
        loss_all = loss

        if clip_only_indices:
            # Clip-level classification for weakly labeled data
            clip_embs = encoder_out[clip_only_indices]
            clip_logits = logits_all[clip_only_indices]
            clip_padding_mask = padding_mask_all[clip_only_indices].unsqueeze(-1)

            # Aggregate frame-level probabilities to clip-level.
            frame_probs = torch.sigmoid(clip_logits)
            # float32 so the -1e30 fill value is representable even when the linear
            # layer ran in half precision.
            query = self.w_q(clip_embs).float()
            query = query.masked_fill(clip_padding_mask, -1e30)
            attn_weights = self.aggregator(query)
            # Padded frames must not contribute to the pooled value.
            attn_weights = attn_weights.masked_fill(clip_padding_mask, 0.0)
            norm = attn_weights.clamp(min=1e-10, max=1.0).sum(dim=1)
            clip_probs = (attn_weights * frame_probs).sum(dim=1) / norm

            # self.criterion is BCEWithLogitsLoss and applies a sigmoid itself, so
            # the pooled probabilities are mapped back to logits first. eps keeps
            # the logit finite for probabilities at exactly 0 or 1.
            clip_logit = torch.logit(clip_probs.float(), eps=1e-6)
            clip_loss = self.criterion(clip_logit, clip_target.to(clip_logit.dtype))
            loss_all = loss_all + clip_loss

            results['clip'] = {'loss': clip_loss}

        results['total'] = {'loss': loss_all}

        if compute_metrics:
            with torch.no_grad():
                if self.is_multilabel:
                    pred_binary_map = (torch.sigmoid(flattened_logits) > 0.5).float()
                    f1 = f1_score(flattened_frame_target.cpu(), pred_binary_map.cpu(), average='micro', zero_division=0)
                    results['frame']['f1'] = f1
                else:
                    acc = compute_acc(flattened_logits, flattened_frame_target)
                    results['frame']['acc'] = acc

        if return_ground_truth:
            results['frame']['ground_truth'] = flattened_frame_target

        return results