from typing import Optional, Tuple, List
from pathlib import Path
import logging
import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from .loss import AudioTextContrastiveLoss
from .utils import a2t_metric
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask
from transformers import (
    BertModel, BertTokenizer, GPT2Model, GPT2Tokenizer,
    RobertaModel, RobertaTokenizer, DistilBertModel, DistilBertTokenizer,
    CLIPTokenizer, CLIPTextModel
)
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params

from transformers import AutoConfig as HFConfig
import torch.distributed as dist
from torch.distributed.nn.functional import all_gather  # differentiable all_gather

# Supported text encoders, keyed by the identifier passed to the tokenizer's
# and model's `from_pretrained`.
# Each value is a tuple `(model class, tokenizer class, width)`; `width` is
# used as the input dimension of the text projection head.
TEXT_MODELS = {
    'openai/clip-vit-base-patch32': (CLIPTextModel, CLIPTokenizer, 512),
    'prajjwal1/bert-tiny': (BertModel, BertTokenizer, 128),
    'prajjwal1/bert-mini': (BertModel, BertTokenizer, 256),
    'prajjwal1/bert-small': (BertModel, BertTokenizer, 512),
    'prajjwal1/bert-medium': (BertModel, BertTokenizer, 512),
    'gpt2': (GPT2Model, GPT2Tokenizer, 768),
    'distilgpt2': (GPT2Model, GPT2Tokenizer, 768),
    'bert-base-uncased': (BertModel, BertTokenizer, 768),
    'bert-large-uncased': (BertModel, BertTokenizer, 1024),
    'roberta-base': (RobertaModel, RobertaTokenizer, 768),
    'roberta-large': (RobertaModel, RobertaTokenizer, 1024),
    'distilbert-base-uncased': (DistilBertModel, DistilBertTokenizer, 768),
    'distilroberta-base': (RobertaModel, RobertaTokenizer, 768),
}

class MeanPooling(nn.Module):
    """Average frame-level features over the valid (unpadded) time steps."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x: A 3-D tensor of shape (N, T, C).
          x_lens: A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
        Returns:
          A 2-D tensor of shape (N, C). Rows whose length is 0 come out as
          all zeros instead of NaN.
        """
        # Pin the mask to x's time dimension and device so it always lines up
        # with `x`, even when T is larger than the longest entry of `x_lens`.
        padding_mask = make_pad_mask(x_lens, max_len=x.size(1)).to(x.device)
        # Out-of-place masking: the caller's tensor is left untouched and
        # autograd never sees an in-place write on an activation.
        x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        # clamp(min=1) guards against division by zero for empty sequences.
        frame_counts = (~padding_mask).sum(dim=1, keepdim=True).clamp(min=1)
        return x.sum(dim=1) / frame_counts  # (N, C)

class MultiheadAttentionPooling(nn.Module):
    """Pool a padded sequence into one vector using learned per-head queries.

    Each of the `num_heads` heads owns one learned query vector; keys and
    values come from a single shared projection of the input. The per-head
    summaries are concatenated and projected back to `d_in`.

    Args:
      d_in: Feature dimension of the input (and of the pooled output).
      num_heads: Number of learned query vectors.
      d_qkv: Dimension of queries/keys/values; falls back to `d_in` when
        None (or 0).
    """

    def __init__(self, d_in, num_heads=4, d_qkv=None):
        super().__init__()
        self.h = num_heads
        d_qkv = d_qkv or d_in
        self.W_q = nn.Parameter(torch.randn(num_heads, d_qkv))
        self.proj_k = nn.Linear(d_in, d_qkv)
        self.proj_v = nn.Linear(d_in, d_qkv)
        self.W_o = nn.Linear(num_heads * d_qkv, d_in)
        nn.init.xavier_uniform_(self.W_q)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x: A 3-D tensor of shape (B, T, d_in).
          x_lens: A 1-D tensor of shape (B,) with the number of valid frames
            per sequence.
        Returns:
          A 2-D tensor of shape (B, d_in).
        """
        B, T, _ = x.shape
        K = self.proj_k(x)  # (B, T, d_qkv)
        V = self.proj_v(x)  # (B, T, d_qkv)
        Q = self.W_q.unsqueeze(0).expand(B, -1, -1)  # (B, h, d_qkv)
        attn = torch.einsum('bhd, btd -> bht', Q, K)  # (B, h, T)
        attn = attn / K.shape[-1] ** 0.5  # scale by sqrt(d_qkv)

        # Pin the mask to T and to x's device so it always broadcasts against
        # `attn`, even when T is larger than the longest entry of `x_lens`.
        padding_mask = make_pad_mask(x_lens, max_len=T).to(x.device)
        padding_mask = padding_mask.unsqueeze(1)  # (B, 1, T), shared by heads
        # NOTE(review): a sequence with length 0 is fully masked and softmax
        # then yields NaN for that row; callers must not pass empty sequences.
        attn = attn.masked_fill(padding_mask, float('-inf'))

        alpha = F.softmax(attn, dim=-1)  # (B, h, T)
        z = torch.einsum('bht, btd -> bhd', alpha, V)  # (B, h, d_qkv)
        z = z.reshape(B, -1)  # (B, h * d_qkv)
        z = self.W_o(z)  # (B, d_in)
        return z

def _get_pooling_method(config):
    """Build the pooling module selected by `config.pooling`.

    Args:
      config: Object exposing `pooling` ('mean' or 'mhap'). For 'mhap' it
        must also expose `encoder_dim` (an iterable of ints, the largest of
        which becomes the pooling input size) and `mhap`, whose optional
        `num_heads` / `d_qkv` attributes override the defaults (4 / None).
    Returns:
      A `MeanPooling` or `MultiheadAttentionPooling` instance.
    Raises:
      ValueError: If `config.pooling` names an unsupported method.
    """
    kind = config.pooling
    if kind == 'mean':
        return MeanPooling()
    if kind != 'mhap':
        raise ValueError(f"Unknown pooling method: {config.pooling}")

    input_dim = max(config.encoder_dim)
    mhap_opts = config.mhap
    return MultiheadAttentionPooling(
        input_dim,
        num_heads=getattr(mhap_opts, 'num_heads', 4),
        d_qkv=getattr(mhap_opts, 'd_qkv', None),
    )

class ZipformerForClapModel(ZipformerEncoderModel):
    """Zipformer audio encoder paired with a HuggingFace text encoder,
    trained with a symmetric audio-text contrastive loss.

    Both modalities are mapped by a two-layer MLP into a shared embedding
    space of size `config.shared_emb_dim` and L2-normalized.
    """

    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """
        Load model from exp_dir.

        Args:
            exp_dir: Directory holding the model config and the checkpoint.
            checkpoint_filename: Checkpoint file name inside `exp_dir`.
        Returns:
            The model built from the config, with parameters loaded from the
            checkpoint.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        # NOTE: building the model also loads the tokenizer (and, unless
        # `init_text_encoder_from_scratch` is set, the text encoder weights)
        # via `from_pretrained` before the checkpoint is applied.
        model = cls(config)
        ckpt_path = Path(exp_dir) / checkpoint_filename
        load_model_params(model, ckpt_path)
        return model

    def __init__(self, config):
        """
        Args:
            config: Model config. Besides what the parent class reads, this
                class uses `text_encoder_type` (a key of `TEXT_MODELS`),
                `init_text_encoder_from_scratch`, `shared_emb_dim`,
                `encoder_dim`, `pooling`, `temp` and `embed_reg`.
        Raises:
            KeyError: If `config.text_encoder_type` is not in `TEXT_MODELS`.
        """
        # Local import keeps this block self-contained (stdlib only).
        import inspect

        super().__init__(config)
        # text_encoder initialization
        self.text_encoder_type = config.text_encoder_type
        model_cls, tokenizer_cls, width = TEXT_MODELS[self.text_encoder_type]
        self.text_encoder_width = width
        self.tokenizer = tokenizer_cls.from_pretrained(self.text_encoder_type)

        # Not every class in TEXT_MODELS accepts `add_pooling_layer` (it is a
        # BERT/RoBERTa-style constructor argument; e.g. GPT2Model takes only
        # a config). Passing it unconditionally raises TypeError for those
        # classes, so forward it only when the constructor declares it.
        encoder_kwargs = {}
        init_params = inspect.signature(model_cls.__init__).parameters
        if 'add_pooling_layer' in init_params:
            encoder_kwargs['add_pooling_layer'] = False

        if config.init_text_encoder_from_scratch:
            logging.info(f"Initializing text encoder {self.text_encoder_type} from scratch")
            model_config = HFConfig.from_pretrained(self.text_encoder_type)
            self.text_encoder = model_cls(model_config, **encoder_kwargs)
        else:
            logging.info(f"Loading pretrained text encoder {self.text_encoder_type} from HuggingFace checkpoint")
            self.text_encoder = model_cls.from_pretrained(self.text_encoder_type, **encoder_kwargs)

        # clap initialization
        self.shared_emb_dim = config.shared_emb_dim
        self.text_proj = nn.Sequential(
            nn.Linear(self.text_encoder_width, self.shared_emb_dim),
            nn.ReLU(),
            nn.Linear(self.shared_emb_dim, self.shared_emb_dim),
        )

        self.audio_proj = nn.Sequential(
            nn.Linear(max(config.encoder_dim), self.shared_emb_dim),
            nn.ReLU(),
            nn.Linear(self.shared_emb_dim, self.shared_emb_dim),
        )

        # NOTE(review): this module is constructed (and its parameters are
        # registered) but `encode_audio` below never calls it; it always
        # mean-pools after `audio_proj`. Confirm whether that is intended.
        self.pooling = _get_pooling_method(config)

        self.temp = config.temp
        self.embed_reg = config.embed_reg
        self.criterion = AudioTextContrastiveLoss()

    def encode_audio(self, x, x_lens):
        """
        Encode audio into L2-normalized embeddings.

        Args:
            x: Input passed to `forward_encoder`.
            x_lens: Lengths of `x` before padding, passed to `forward_encoder`.
        Returns:
            audio_embeds: Tensor [B, shared_emb_dim], unit L2 norm per row.
        """
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = encoder_output.encoder_out
        encoder_out_lens = encoder_output.encoder_out_lens

        # Project every frame, then average over the valid frames only.
        audio_feats = self.audio_proj(encoder_out)
        padding_mask = make_pad_mask(encoder_out_lens, max_len=audio_feats.size(1)).to(audio_feats.device)
        audio_feats = audio_feats.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        # clamp(min=1) avoids division by zero for empty sequences.
        frame_counts = (~padding_mask).sum(dim=1).clamp(min=1).unsqueeze(-1)
        audio_embeds = audio_feats.sum(dim=1) / frame_counts
        return F.normalize(audio_embeds, dim=-1)

    def encode_text(self, text, max_length: int = 30):
        """
        Args:
            text: List[str]
            max_length: Maximum number of tokens per text; longer inputs are
                truncated by the tokenizer.
        Returns:
            text_embeds: Tensor [B, embed_dim]
        """
        device = next(self.parameters()).device  # infer from model
        encoded = self.tokenizer(text, padding='longest', truncation=True,
                                max_length=max_length, return_tensors='pt').to(device)
        text_output = self.text_encoder(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask
        )[0]
        # NOTE(review): position 0 is the [CLS] token for BERT-style
        # encoders. For causal text encoders (e.g. GPT-2) the first position
        # cannot attend to later tokens - confirm this pooling is intended
        # for those entries of TEXT_MODELS.
        text_embeds = text_output[:, 0, :]  # [cls] token
        return F.normalize(self.text_proj(text_embeds), dim=-1)

    def forward(self, x, x_lens, text):
        """
        Compute the contrastive loss for a batch of paired audio and text.

        Args:
            x: Audio input, see `encode_audio`.
            x_lens: Audio lengths, see `encode_audio`.
            text: List[str]; `text[i]` is the caption paired with audio `i`.
        Returns:
            A tuple `(loss, audio_embeds, text_embeds)`. The embeddings are
            the local (un-gathered) ones for this process.
        """
        audio_embeds = self.encode_audio(x, x_lens)
        text_embeds = self.encode_text(text)

        # Gather embeddings across all processes if distributed
        if self.training and dist.is_initialized():
            audio_embeds_all = gather_embeddings(audio_embeds)
            text_embeds_all = gather_embeddings(text_embeds)
        else:
            audio_embeds_all = audio_embeds
            text_embeds_all = text_embeds

        # Construct similarity and target: item i of one modality matches
        # item i of the other, hence the identity matrix.
        sim_targets = torch.eye(audio_embeds_all.size(0), device=audio_embeds_all.device)

        sim_a2t = audio_embeds_all @ text_embeds_all.T / self.temp
        sim_t2a = text_embeds_all @ audio_embeds_all.T / self.temp

        loss = self.criterion(sim_a2t, sim_t2a, sim_targets)
        if self.embed_reg:
            # Out-of-place add: never mutate the tensor the criterion returned.
            loss = loss + (
                torch.mean(torch.abs(audio_embeds_all)) / torch.sqrt(torch.sum(audio_embeds_all**2)) +
                torch.mean(torch.abs(text_embeds_all)) / torch.sqrt(torch.sum(text_embeds_all**2))
            )
        return loss, audio_embeds, text_embeds

    def generate(self, input, text):
        """
        Score every audio item against every text.

        Args:
            input: Either a 2-tuple `(x, x_lens)` of ready-made features, or
                anything accepted by `self.extract_feature`.
            text: List[str]
        Returns:
            A tuple `(sim_a2t, sim_t2a)` of dot products between the
            normalized embeddings (no temperature applied), shaped
            [num_audio, num_text] and [num_text, num_audio].
        """
        # Handle flexible input
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)
        # Encode audio and text inputs
        audio_embeds = self.encode_audio(x, x_lens)
        text_embeds = self.encode_text(text)

        # Compute similarity scores (audio-to-text and text-to-audio)
        sim_a2t = (audio_embeds @ text_embeds.T)
        sim_t2a = (text_embeds @ audio_embeds.T)

        return sim_a2t, sim_t2a
    
def gather_embeddings(emb: torch.Tensor) -> torch.Tensor:
    """
    Differentiable all-gather that tolerates different local batch sizes.
    Each rank ends up with a (global_B, ...) tensor that
    participates fully in back-prop.

    Args:
        emb: Local tensor whose first dimension is the local batch size.
            Trailing dimensions must match across ranks.
    Returns:
        The per-rank tensors concatenated along dim 0 in rank order, or
        `emb` itself when no process group is initialized.
    """
    if not dist.is_initialized():
        return emb

    world_size = dist.get_world_size()
    # Keep the local size as a plain int so the comparison and the padding
    # shape below never operate on tensors.
    local_n = emb.size(0)

    # 1) Share local batch sizes
    size_t = torch.tensor([local_n], device=emb.device, dtype=torch.long)
    size_list = [torch.zeros_like(size_t) for _ in range(world_size)]
    dist.all_gather(size_list, size_t)
    # One host transfer for all sizes instead of one `.item()` per rank.
    sizes = torch.cat(size_list).tolist()
    max_n = max(sizes)

    # 2) Pad along dim 0 to max_n so every rank contributes the same shape
    if local_n < max_n:
        pad = emb.new_zeros((max_n - local_n, *emb.shape[1:]))
        emb = torch.cat([emb, pad], dim=0)

    # 3) Differentiable all-gather: one (max_n, ...) tensor per rank
    gathered = all_gather(emb)

    # 4) Remove the padding for each rank, then concat
    chunks = [g[:n] for g, n in zip(gathered, sizes)]
    return torch.cat(chunks, dim=0)