
from typing import Optional, Tuple, List
from pathlib import Path
import logging
import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params

from .utils import load_id2label

from .loss import GE2ELoss
    

class ZipformerForSpeakerVerificationModel(ZipformerEncoderModel):
    """Zipformer encoder with a mean-pooling head trained with the GE2E loss.

    The encoder output is mean-pooled over valid (non-padded) frames to give
    one utterance-level embedding per input. Optionally, that output is a
    learned softmax-weighted fusion of the per-stack outputs. The embeddings
    are fed to ``GE2ELoss`` in :meth:`forward` and compared by cosine
    similarity in :meth:`generate`.
    """

    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """Build the model from an experiment directory and load its weights.

        Args:
          exp_dir:
            Directory holding the config (read via ``AutoConfig``),
            ``id2label.json`` and the checkpoint file.
          checkpoint_filename:
            Name of the checkpoint file inside ``exp_dir``.
        Returns:
          The constructed model with parameters loaded from the checkpoint.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        exp_path = Path(exp_dir)
        id2label = load_id2label(exp_path / 'id2label.json')
        model = cls(config, id2label)
        load_model_params(model, exp_path / checkpoint_filename)
        return model

    def __init__(self, config, id2label):
        """
        Args:
          config:
            Model configuration. Besides what the base encoder reads, it must
            define ``ge2e_init_w`` and ``ge2e_init_b``. It may set
            ``fuse_encoder`` to enable learned fusion of the encoder outputs
            (treated as disabled when absent).
          id2label:
            Mapping from class index (keys convertible with ``int``) to
            speaker label string.
        """
        super().__init__(config)
        self.id2label = id2label
        # Reverse lookup used to turn string speaker ids into integer targets.
        self.label2id = {label: int(idx) for idx, label in id2label.items()}

        self.criterion = GE2ELoss(
            init_w=self.config.ge2e_init_w,
            init_b=self.config.ge2e_init_b,
        )

        if getattr(self.config, "fuse_encoder", False):
            # One learnable logit per entry of `num_encoder_layers`, presumably
            # one per encoder stack, matching dim 0 of `encoder_out_full`.
            # Zero init gives uniform fusion weights after the softmax.
            self.encoder_fusion_weights = nn.Parameter(
                torch.zeros(len(config.num_encoder_layers))
            )
        else:
            self.encoder_fusion_weights = None

    def _select_encoder_out(self, encoder_output):
        """Return the frame-level features to pool from an encoder output.

        When fusion is enabled, the stacked outputs in
        ``encoder_output.encoder_out_full`` (assumed shape
        (num_stacks, N, T, C)) are combined with softmax-normalised learned
        weights along dim 0. Otherwise ``encoder_output.encoder_out`` is used.
        """
        if self.encoder_fusion_weights is None:
            return encoder_output.encoder_out
        fusion_weights = F.softmax(self.encoder_fusion_weights, dim=0).view(-1, 1, 1, 1)
        return (encoder_output.encoder_out_full * fusion_weights).sum(dim=0)

    def forward(self, x, x_lens, target):
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          target:
            Sequence of N ground-truth speaker labels (e.g. "id00015"). Each
            must be a key of ``self.label2id``.
        Returns:
          A tuple ``(loss, embeddings, acc, embedding_norm)``:
            - ``loss``, ``acc``: as returned by the GE2E criterion.
            - ``embeddings``: (N, C) pooled utterance embeddings.
            - ``embedding_norm``: Python float, the mean L2 norm of the
              embeddings.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        speaker_ids_tensor = torch.tensor(
            [self.label2id[spk_str_id] for spk_str_id in target],
            dtype=torch.long,
            device=x.device,
        )

        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = self._select_encoder_out(encoder_output)

        # Forward the verificator to get the utterance-level embeddings for the GE2E loss.
        embeddings = self.forward_verificator(encoder_out, encoder_output.encoder_out_lens)  # (N, C)
        loss, acc = self.criterion(embeddings, speaker_ids_tensor)

        # Raw (un-normalised) embedding scale, reported as a stability metric.
        # Detached because it is only logged, never back-propagated.
        embedding_norm = embeddings.detach().norm(dim=-1).mean().item()

        return loss, embeddings, acc, embedding_norm

    def forward_verificator(self, encoder_out, encoder_out_lens):
        """Mean-pool encoder frames over the valid (non-padded) positions.

        Args:
          encoder_out:
            A 3-D tensor of shape (N, T, C). It is not modified.
          encoder_out_lens:
            A 1-D tensor of shape (N,) with the number of valid frames of
            each utterance.
        Returns:
          pooled_embeddings (torch.Tensor): A 2-D tensor of shape (N, C).
          An utterance with zero valid frames yields an all-zero row.
        """
        num_frames = encoder_out.size(1)
        # valid_mask[n, t] is True iff t < encoder_out_lens[n]. It is built
        # against encoder_out's own time axis so its width always equals T.
        frame_idx = torch.arange(num_frames, device=encoder_out.device)
        lens = encoder_out_lens.to(encoder_out.device)
        valid_mask = frame_idx.unsqueeze(0) < lens.unsqueeze(1)  # (N, T)

        # Out-of-place masking: writing into `encoder_out` would mutate the
        # caller's tensor and can break autograd if it was saved for backward.
        summed = encoder_out.masked_fill(~valid_mask.unsqueeze(-1), 0.0).sum(dim=1)  # (N, C)
        # Clamp avoids 0/0 = NaN for empty utterances.
        valid_counts = valid_mask.sum(dim=1, keepdim=True).clamp(min=1).to(encoder_out.dtype)  # (N, 1)
        return summed / valid_counts

    def generate(self, input, threshold=None):
        """Score every pair of inputs for same-speaker verification.

        Args:
          input:
            One of:
              - a tuple ``(x, x_lens)`` of precomputed features, with ``x`` a
                3-D tensor of shape (N, T, C) and ``x_lens`` of shape (N,);
              - a list of wav file paths (``str`` or ``Path``);
              - a list of raw waveforms (1-D numpy arrays or tensors).
            Paths and waveforms are converted with ``self.extract_feature``.
          threshold:
            Cosine-similarity threshold for the verification decision. It
            must be provided. A pair scoring above it is predicted as the
            same speaker.
        Returns:
          A list with one dict per unordered pair ``(i, j)``, ``i < j``:
            - ``'pair'``: ``"<name_i> - <name_j>"``, where a name is the file
              path for path inputs and the index otherwise;
            - ``'score'``: cosine similarity rounded to 4 decimals;
            - ``'prediction'``: 1 if score > threshold, else 0.
          Each result is also printed. An empty input list returns ``[]``.
        Raises:
          ValueError: If ``threshold`` is None.
        """
        if threshold is None:
            raise ValueError(
                "Threshold is not set. Please evaluate the model on a validation set "
                "to determine the best threshold (e.g., using EER), and provide it before inference. "
                "If you are using the 'asr_init_unfreeze_e40' setup, set threshold = 0.1691. "
                "If you are using the 'scratch_e40' setup, set threshold = 0.4000."
            )

        # Only a tuple whose first item is a 3-D feature tensor is treated as
        # (x, x_lens). A tuple of two paths or two waveforms goes through
        # feature extraction like a list would.
        is_feature_pair = (
            isinstance(input, tuple)
            and len(input) == 2
            and torch.is_tensor(input[0])
            and input[0].ndim == 3
        )
        if is_feature_pair:
            x, x_lens = input
            names = list(range(x.size(0)))
        else:
            if len(input) == 0:
                return []  # no utterances, hence no pairs to score
            x, x_lens = self.extract_feature(input)
            if isinstance(input[0], (str, Path)):  # wav file paths
                names = [str(p) for p in input]
            else:  # raw waveforms
                names = list(range(x.size(0)))

        device = next(self.parameters()).device
        x = x.to(device)
        x_lens = x_lens.to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            encoder_output = self.forward_encoder(x, x_lens)
            encoder_out = self._select_encoder_out(encoder_output)
            embeddings = self.forward_verificator(encoder_out, encoder_output.encoder_out_lens)
            # Cosine similarity = dot product of L2-normalised embeddings.
            embeddings = F.normalize(embeddings, p=2, dim=-1)
            # Move to host once rather than syncing per element in the loop.
            sim_matrix = torch.matmul(embeddings, embeddings.T).cpu().tolist()

        # Human-readable result
        num_utts = len(names)
        results = []
        for i in range(num_utts):
            for j in range(i + 1, num_utts):
                score = sim_matrix[i][j]
                results.append({
                    'pair': f'{names[i]} - {names[j]}',
                    'score': round(score, 4),
                    'prediction': int(score > threshold),
                })
        for r in results:
            print(f"{r['pair']}: score={r['score']}, prediction={r['prediction']}")

        return results
