import math
import logging
from typing import Optional, List, Union
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask

from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from transformers import (
    T5Tokenizer, CLIPTokenizer, BartTokenizer
)

# Supported tokenizers: Hugging Face checkpoint name -> tokenizer class used to
# load it. Iteration order is the order listed in "unsupported tokenizer" errors.
TOKENIZERS = dict([
    ('t5-small', T5Tokenizer),
    ('facebook/bart-base', BartTokenizer),
    ('openai/clip-vit-base-patch32', CLIPTokenizer),
])

# Special tokens every decoder tokenizer must provide; a token is added only
# when the tokenizer does not already define the corresponding attribute.
SPECIAL_TOKENS = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'mask_token': '<mask>',
}

def ensure_special_tokens(tokenizer, required_special_tokens=None):
    """
    Register any required special tokens that the tokenizer is missing.

    Args:
        tokenizer: object exposing special-token attributes (e.g. ``bos_token``)
            and an ``add_special_tokens(mapping)`` method.
        required_special_tokens: mapping of special-token attribute name to the
            token string to add when that attribute is missing or None.
            Defaults to ``SPECIAL_TOKENS``.

    Tokens the tokenizer already defines are left untouched, and
    ``add_special_tokens`` is called at most once, only if something is missing.
    """
    if required_special_tokens is None:
        required_special_tokens = SPECIAL_TOKENS
    # getattr with a default: treat an attribute that does not exist at all
    # the same as one explicitly set to None.
    missing = {
        name: token
        for name, token in required_special_tokens.items()
        if getattr(tokenizer, name, None) is None
    }
    if missing:
        tokenizer.add_special_tokens(missing)

def causal_mask(length: int, device: Optional[torch.device] = None) -> torch.BoolTensor:
    """Boolean [length, length] mask for *batch-first* causal self-attention.

    Entry ``[i, j]`` is True (attention blocked) exactly when ``j > i``, i.e.
    every query position may only look at itself and earlier positions.
    """
    positions = torch.arange(length, device=device)
    # Column index strictly greater than row index -> future position.
    return positions.unsqueeze(0) > positions.unsqueeze(1)

class SinusoidalPositionalEncoding(nn.Module):
    """Classic positional embedding (fixed, non-trainable).

    Row ``p`` of the table holds ``sin(p * w_k)`` in even columns ``2k`` and
    ``cos(p * w_k)`` in odd columns ``2k + 1``, with
    ``w_k = 10000 ** (-2k / d_model)``. Odd ``d_model`` is supported.
    """
    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        # Non-persistent: recomputed at construction, not saved in state_dict.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """
        x: [B, T, ...] tensor (typically token ids); only ``x.size(1)`` and
        ``x.device`` are used. Returns [1, T, D], broadcastable over the batch.

        Raises:
            ValueError: if T exceeds the table length ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding "
                f"table size {self.pe.size(0)}"
            )
        return self.pe[:seq_len].unsqueeze(0).to(x.device)

class DecoderModel(nn.Module):
    """
    Transformer text decoder wrapping PyTorch's ``nn.TransformerDecoder``.

    Holds the token embedding, a fixed sinusoidal positional encoding, a
    ``batch_first`` decoder stack and ``output_proj``, the linear map from
    hidden states to vocabulary logits. ``forward`` returns hidden states;
    callers apply ``output_proj`` themselves.

    Args:
        config: read for ``decoder_nhead`` and ``decoder_shared_emb``
            (required) and, when present, ``d_model``, ``dim_feedforward``,
            ``decoder_dropout``, ``decoder_activation``, ``decoder_norm_first``,
            ``decoder_bias`` and ``num_decoder_layers``. ``config.encoder_dim``
            is read only when ``d_model`` is absent.
        tokenizer: anything supporting ``len()``; its length is the vocabulary size.
    """
    def __init__(self, config, tokenizer):
        super().__init__()

        self.vocab_size = len(tokenizer)

        # Basic model parameters. d_model stays a conditional (not getattr with a
        # default) so config.encoder_dim is only touched when d_model is absent.
        self.d_model = config.d_model if hasattr(config, 'd_model') else max(config.encoder_dim)
        self.decoder_nhead = config.decoder_nhead
        self.text_embed = nn.Embedding(self.vocab_size, self.d_model)
        self.output_proj = nn.Linear(self.d_model, self.vocab_size, bias=False)
        if config.decoder_shared_emb:
            # Weight tying: input embedding and output projection share one matrix.
            self.output_proj.weight = self.text_embed.weight

        # Other configurable parameters (defaults used when absent from config).
        self.dim_feedforward = getattr(config, 'dim_feedforward', 4 * self.d_model)
        self.dropout = getattr(config, 'decoder_dropout', 0.1)
        self.activation = getattr(config, 'decoder_activation', 'relu')
        self.norm_first = getattr(config, 'decoder_norm_first', False)
        # Defaults to False here, whereas nn.TransformerDecoderLayer defaults to bias=True.
        self.bias = getattr(config, 'decoder_bias', False)

        self.num_layers = getattr(config, 'num_decoder_layers', 6)
        self.positional_encoding = SinusoidalPositionalEncoding(self.d_model)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.d_model,
            nhead=self.decoder_nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=self.norm_first,
            bias=self.bias,
        )
        # NOTE(review): no final `norm` is passed, so with norm_first=True the
        # stack output is not layer-normalised — confirm this is intended.
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=self.num_layers)

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """
        Embed target token ids and run the decoder stack.

        Args:
            tgt: LongTensor [B, T], target token ids (looked up in ``text_embed``).
            memory: Tensor [B, S, D], encoder output embeddings.
            tgt_mask: Optional[Tensor], self-attention mask for the target
                (e.g. a [T, T] causal mask, or a [B * nhead, T, T] per-sample mask).
            tgt_key_padding_mask: Optional[Tensor] [B, T], padding mask for the target.
            memory_key_padding_mask: Optional[Tensor] [B, S], padding mask for the encoder output.

        Returns:
            Tensor [B, T, D], decoder hidden states (apply ``output_proj`` for logits).
        """
        tgt_emb = self.text_embed(tgt) + self.positional_encoding(tgt)
        return self.decoder(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

class ZipformerForAudioCaptioningModel(ZipformerEncoderModel):
    """
    Audio captioning model: Zipformer audio encoder + Transformer text decoder.

    In training mode ``forward`` computes a teacher-forced, summed
    cross-entropy loss against reference captions; in eval mode it greedily
    decodes captions.
    """
    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """
        Build the model from the config stored in ``exp_dir`` and load the
        parameters from ``exp_dir / checkpoint_filename``.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        model = cls(config)
        ckpt_path = Path(exp_dir) / checkpoint_filename
        load_model_params(model, ckpt_path)
        return model

    def __init__(self, config):
        super().__init__(config)

        self.text_tokenizer_type = config.text_tokenizer_type
        self._build_tokenizer(self.text_tokenizer_type)
        self.text_decoder = DecoderModel(config, self.tokenizer)

        # Bridge encoder and decoder widths only when they differ.
        encoder_dim = max(config.encoder_dim)
        decoder_dim = self.text_decoder.d_model
        self.enc_dec_proj = (
            nn.Linear(encoder_dim, decoder_dim, bias=False) if encoder_dim != decoder_dim else None
        )
        # reduction='sum': the loss is not averaged here; forward also returns
        # text_lens, presumably so callers can normalise — TODO confirm.
        # NOTE(review): if the tokenizer's pad id equals its EOS id (believed to be
        # the case for the CLIP tokenizer — confirm), ignore_index also drops every
        # EOS target, so the model is never trained to stop.
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=self.id_pad,
            label_smoothing=getattr(config, 'label_smoothing', 0.1),
            reduction='sum',
        )

    # --------------------------- Tokenizer initialization ----------------- #
    def _build_tokenizer(self, tokenizer_type: str):
        """
        Load the tokenizer named ``tokenizer_type`` and cache its special-token ids.

        Sets ``self.tokenizer``, ``self.id_bos``, ``self.id_eos``, ``self.id_pad``
        and ``self.vocab_size``; returns nothing.

        Raises:
            ValueError: if ``tokenizer_type`` is not a key of ``TOKENIZERS``.
        """
        if tokenizer_type not in TOKENIZERS:
            raise ValueError(f"Unsupported tokenizer type '{tokenizer_type}'. Supported: {list(TOKENIZERS.keys())}")

        self.tokenizer = TOKENIZERS[tokenizer_type].from_pretrained(tokenizer_type)
        # Add BOS/EOS/PAD/MASK when the pretrained tokenizer lacks them, so that
        # the ids read below are defined.
        ensure_special_tokens(self.tokenizer)

        self.id_bos, self.id_eos, self.id_pad = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id
        self.vocab_size = len(self.tokenizer)

    # --------------------------- Encoder wrapper -------------------------- #
    def encode_audio(self, x, x_lens):
        """
        Run the Zipformer encoder.

        Returns:
            encoder_out: Tensor [B, S, C], encoder output frames.
            padding_mask: mask from ``make_pad_mask(encoder_out_lens, S)`` on the
                encoder's device; presumably True at padded frames — confirm.
        """
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out, encoder_out_lens = encoder_output.encoder_out, encoder_output.encoder_out_lens
        padding_mask = make_pad_mask(encoder_out_lens, encoder_out.size(1)).to(encoder_out.device)

        return encoder_out, padding_mask

    def _encode_and_project(self, x, x_lens):
        """Encode audio and map it to the decoder width; returns (audio_embs, padding_mask)."""
        audio_embs, padding_mask = self.encode_audio(x, x_lens)
        if self.enc_dec_proj is not None:
            audio_embs = self.enc_dec_proj(audio_embs)
        return audio_embs, padding_mask

    # --------------------------- Text helper ------------------------------ #
    def _tokenise(self, text: List[str]):
        """Tokenise captions to id lists without special tokens, truncated to 128 ids."""
        # BOS/EOS are added later by _prepare_tgt, hence add_special_tokens=False.
        ids = self.tokenizer(text, add_special_tokens=False, truncation=True, max_length=128).input_ids
        return ids

    def _prepare_tgt(self, text: List[str]):
        """
        Prepare teacher-forcing inputs for the causal decoder.

        Args:
            text: List[str], reference captions.
        Returns:
            tgt_in: LongTensor [B, T], ``[BOS] + ids`` right-padded with PAD.
            tgt_out: LongTensor [B, T], ``ids + [EOS]`` right-padded with PAD.
            tgt_pad_mask: BoolTensor [B, T], True at PAD positions of ``tgt_in``.
            tgt_mask: BoolTensor [T, T], causal mask (True = blocked).
            text_lens: List[int], caption lengths in tokens without BOS/EOS.
        """
        PAD, BOS, EOS = self.id_pad, self.id_bos, self.id_eos
        device = next(self.parameters()).device

        ids = self._tokenise(text)
        text_lens = [len(s) for s in ids]

        tgt_in = [torch.tensor([BOS] + s, dtype=torch.long) for s in ids]
        tgt_out = [torch.tensor(s + [EOS], dtype=torch.long) for s in ids]
        tgt_in = nn.utils.rnn.pad_sequence(tgt_in, batch_first=True, padding_value=PAD).to(device)
        tgt_out = nn.utils.rnn.pad_sequence(tgt_out, batch_first=True, padding_value=PAD).to(device)

        tgt_pad_mask = tgt_in.eq(PAD)  # True -> ignore position
        tgt_mask = causal_mask(tgt_in.size(1), device=device)

        return tgt_in, tgt_out, tgt_pad_mask, tgt_mask, text_lens

    # --------------------------- Forward ---------------------------------- #
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        text: Optional[List[str]] = None,
    ):
        """
        Args:
            x: Tensor [B, T, D], input features.
            x_lens: Tensor [B], valid lengths of ``x``.
            text: List[str] or None, reference captions (required in training).
        Returns:
            A ``(loss, output_text, text_lens)`` tuple:
            - training: ``(summed CE loss, None, caption lengths without BOS/EOS)``
            - eval: ``(None, List[str] of greedy captions, None)``
        Raises:
            ValueError: in training mode when ``text`` is None.
        """
        if self.training and text is None:
            raise ValueError("Text must be provided during training")

        audio_embs, padding_mask = self._encode_and_project(x, x_lens)

        if not self.training:
            output_text = self._generate(audio_embs, padding_mask)
            return None, output_text, None

        tgt_in, tgt_out, tgt_pad_mask, tgt_mask, text_lens = self._prepare_tgt(text)
        decoder_out = self.text_decoder(
            tgt=tgt_in,
            memory=audio_embs,
            tgt_mask=tgt_mask,                      # causal mask
            tgt_key_padding_mask=tgt_pad_mask,      # padding mask for decoder input
            memory_key_padding_mask=padding_mask,   # padding mask for encoder output
        )
        logits = self.text_decoder.output_proj(decoder_out)
        loss = self.criterion(logits.view(-1, self.vocab_size), tgt_out.view(-1))
        return loss, None, text_lens

    @torch.inference_mode()
    def _generate(self, audio_embs, padding_mask, max_len=128):
        """
        Greedy autoregressive decoding.

        Args:
            audio_embs: Tensor [B, S, D], encoder output at decoder width.
            padding_mask: Tensor [B, S], padding mask for ``audio_embs``.
            max_len: int, maximum number of decoding steps.
        Returns:
            List[str], one decoded caption per batch element.
        """
        B = audio_embs.size(0)
        device = audio_embs.device
        tokens = torch.full((B, 1), self.id_bos, dtype=torch.long, device=device)  # [B, 1]
        finished = torch.zeros(B, dtype=torch.bool, device=device)

        for _ in range(max_len):
            tgt_mask = causal_mask(tokens.size(1), device=device)
            dec_out = self.text_decoder(
                tgt=tokens,
                memory=audio_embs,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=padding_mask,
            )
            # The whole prefix is re-decoded every step; only the last position is used.
            next_tok = self.text_decoder.output_proj(dec_out[:, -1]).argmax(-1)
            tokens = torch.cat([tokens, next_tok[:, None]], dim=1)
            finished |= next_tok.eq(self.id_eos)
            if finished.all():
                break

        # Drop the BOS prompt, then replace everything from each row's first EOS
        # onward with PAD. Searching only generated positions keeps this correct
        # even for a tokenizer whose BOS and EOS ids coincide.
        generated = tokens[:, 1:]
        after_eos = generated.eq(self.id_eos).long().cumsum(dim=1) > 0
        generated = generated.masked_fill(after_eos, self.id_pad)

        return self.tokenizer.batch_decode(
            generated.cpu(), skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

    def generate(self, input):
        """
        Caption audio with greedy decoding.

        Args:
            input: either an ``(x, x_lens)`` tuple of features, or raw input that
                is passed to ``self.extract_feature``.
        Returns:
            List[str], generated captions.
        """
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        audio_embs, padding_mask = self._encode_and_project(x, x_lens)
        return self._generate(audio_embs, padding_mask)
        
class ZipformerForAudioCaptioningWithMaskingModel(ZipformerForAudioCaptioningModel):
    """
    Zipformer audio captioning model that mixes causal and parallel (masked) decoding.

    In training, a ``parallel_decoding_prob`` fraction of each batch is fed
    MASK tokens with unrestricted self-attention and trained to predict the
    whole caption at once; the remaining samples use the BOS-shifted causal
    setup. Caption generation always uses causal greedy decoding.
    """
    def __init__(self, config):
        super().__init__(config)
        prob = config.parallel_decoding_prob
        # Explicit check rather than assert so validation survives `python -O`.
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"Parallel decoding probability must be in [0.0, 1.0], got {prob}")
        self.parallel_decoding_prob = prob
        if prob == 0.0:
            logging.info("Parallel decoding is disabled (probability set to 0.0). "
                         "This will use causal decoding for all inputs.")
        elif prob == 1.0:
            logging.info("Parallel decoding is enabled (probability set to 1.0). "
                         "This will use parallel decoding for all inputs.")

        self.id_mask = self.tokenizer.mask_token_id

    @staticmethod
    def _num_parallel_samples(batch_size: int, prob: float) -> int:
        """Number of batch samples to decode in parallel: ``batch_size * prob``, stochastically rounded."""
        expected = batch_size * prob
        n = math.floor(expected)
        frac = expected - n
        # Stochastic rounding keeps E[n] == batch_size * prob. Plain round() would
        # never pick a parallel sample whenever batch_size * prob < 0.5. No random
        # draw happens when the product is already an integer.
        if frac > 0 and torch.rand(()).item() < frac:
            n += 1
        return n

    def _prepare_tgt(self, text: List[str], parallel_decoding_prob: Optional[float] = None):
        """
        Build decoder inputs/targets, splitting the batch into causal and parallel samples.

        Causal samples: input ``[BOS] + ids``, target ``ids + [EOS]``, causal
        self-attention. Parallel samples: input ``[MASK] * len(ids)``, target
        ``ids`` (no EOS), unrestricted self-attention. Rows are right-padded
        with PAD to a common length T.

        Args:
            text: List[str], reference captions.
            parallel_decoding_prob: fraction of the batch decoded in parallel;
                defaults to ``self.parallel_decoding_prob``.
        Returns:
            tgt_in: LongTensor [B, T], decoder input.
            tgt_out: LongTensor [B, T], loss targets (PAD where ignored).
            tgt_pad_mask: BoolTensor [B, T], True at PAD positions of ``tgt_in``.
            tgt_mask: BoolTensor [B * nhead, T, T], per-sample self-attention mask
                (True = blocked): causal for causal samples, all False for parallel ones.
            text_lens: List[int], caption lengths in tokens without BOS/EOS.
        """
        if parallel_decoding_prob is None:
            parallel_decoding_prob = self.parallel_decoding_prob
        PAD, BOS, EOS, MASK = self.id_pad, self.id_bos, self.id_eos, self.id_mask
        device = next(self.parameters()).device

        ids = self._tokenise(text)
        text_lens = [len(s) for s in ids]
        B = len(ids)

        # Random split: the first n_parallel entries of perm are decoded in parallel.
        perm = torch.randperm(B, device=device)
        n_parallel = self._num_parallel_samples(B, parallel_decoding_prob)
        parallel_ids = perm[:n_parallel]
        causal_ids = perm[n_parallel:]
        # Plain-int copies for Python list indexing, instead of converting 0-d
        # (possibly CUDA) tensors one element at a time.
        parallel_list = parallel_ids.tolist()
        causal_list = causal_ids.tolist()

        # +1: causal rows carry an extra BOS (input side) / EOS (target side).
        max_causal_len = max((text_lens[i] + 1 for i in causal_list), default=0)
        max_parallel_len = max((text_lens[i] for i in parallel_list), default=0)
        max_len = max(max_causal_len, max_parallel_len)

        tgt_in = torch.full((B, max_len), PAD, dtype=torch.long, device=device)  # [B, T]
        tgt_out = torch.full_like(tgt_in, PAD)  # [B, T]

        if causal_list:
            causal_in = [torch.tensor([BOS] + ids[i], dtype=torch.long) for i in causal_list]
            causal_out = [torch.tensor(ids[i] + [EOS], dtype=torch.long) for i in causal_list]
            causal_in = nn.utils.rnn.pad_sequence(causal_in, batch_first=True, padding_value=PAD).to(device)
            causal_out = nn.utils.rnn.pad_sequence(causal_out, batch_first=True, padding_value=PAD).to(device)

            T_causal = causal_in.size(1)
            tgt_in[causal_ids, :T_causal] = causal_in
            tgt_out[causal_ids, :T_causal] = causal_out

        if parallel_list:
            parallel_out = [torch.tensor(ids[i], dtype=torch.long) for i in parallel_list]  # no EOS token
            parallel_in = [torch.full_like(t, MASK) for t in parallel_out]  # MASK at every position
            parallel_out = nn.utils.rnn.pad_sequence(parallel_out, batch_first=True, padding_value=PAD).to(device)
            parallel_in = nn.utils.rnn.pad_sequence(parallel_in, batch_first=True, padding_value=PAD).to(device)

            T_parallel = parallel_out.size(1)
            tgt_in[parallel_ids, :T_parallel] = parallel_in
            tgt_out[parallel_ids, :T_parallel] = parallel_out

        tgt_pad_mask = tgt_in.eq(PAD)  # True -> ignore position

        per_sample_mask = torch.zeros((B, max_len, max_len), dtype=torch.bool, device=device)  # [B, T, T]
        if causal_list:
            per_sample_mask[causal_ids] = causal_mask(max_len, device=device)

        # nn.MultiheadAttention takes a 3-D attn_mask of shape (B * nhead, T, T)
        # indexed batch-major (row b * nhead + h), i.e. each sample's mask
        # repeated nhead times consecutively.
        tgt_mask = per_sample_mask.repeat_interleave(self.text_decoder.decoder_nhead, dim=0)

        return tgt_in, tgt_out, tgt_pad_mask, tgt_mask, text_lens

    # --------------------------- Forward ---------------------------------- #
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        text: Optional[List[str]] = None,
    ):
        """
        Args:
            x: Tensor [B, T, D], input features.
            x_lens: Tensor [B], valid lengths of ``x``.
            text: List[str] or None, reference captions (required in training).
        Returns:
            A ``(loss, output_text, text_lens)`` tuple:
            - training: ``(summed CE loss over mixed causal/parallel samples, None, text_lens)``
            - eval: ``(parallel-decoding CE loss, greedy captions, text_lens)``, where
              the loss and text_lens are None when ``text`` is None.
        Raises:
            ValueError: in training mode when ``text`` is None.
        """
        if self.training and text is None:
            raise ValueError("Text must be provided during training")

        audio_embs, padding_mask = self.encode_audio(x, x_lens)
        if self.enc_dec_proj is not None:
            audio_embs = self.enc_dec_proj(audio_embs)

        if self.training:
            tgt_in, tgt_out, tgt_pad_mask, tgt_mask, text_lens = self._prepare_tgt(text, self.parallel_decoding_prob)
            decoder_out = self.text_decoder(
                tgt=tgt_in,
                memory=audio_embs,
                tgt_mask=tgt_mask,                      # per-sample causal / parallel mask
                tgt_key_padding_mask=tgt_pad_mask,      # padding mask for decoder input
                memory_key_padding_mask=padding_mask,   # padding mask for encoder output
            )
            logits = self.text_decoder.output_proj(decoder_out)
            loss = self.criterion(logits.view(-1, self.vocab_size), tgt_out.view(-1))
            return loss, None, text_lens

        output_text = self._generate(audio_embs, padding_mask, force_causal=True)
        if text is None:
            # No references to score against: return captions only.
            return None, output_text, None
        ce_loss, text_lens = self._generate(audio_embs, padding_mask, force_causal=False, text=text)
        return ce_loss, output_text, text_lens

    @torch.inference_mode()
    def _generate(self, audio_embs, padding_mask, force_causal, text=None, max_len=128):
        """
        Either decode captions causally or score references with parallel decoding.

        Args:
            audio_embs: Tensor [B, S, D], encoder output at decoder width.
            padding_mask: Tensor [B, S], padding mask for ``audio_embs``.
            force_causal: bool; True -> greedy causal decoding (base class),
                False -> one all-MASK decoder pass scored against ``text``.
            text: List[str], reference captions; required when ``force_causal`` is False.
            max_len: int, maximum number of causal decoding steps.
        Returns:
            force_causal=True: List[str], generated captions.
            force_causal=False: ``(summed CE loss, text_lens)``.
        Raises:
            ValueError: if ``force_causal`` is False and ``text`` is None.
        """
        if force_causal:
            return super()._generate(audio_embs, padding_mask, max_len)
        if text is None:
            raise ValueError("Text must be provided for non-causal generation")
        # Probability 1.0: every sample takes the parallel (MASK-input) path.
        tgt_in, tgt_out, tgt_pad_mask, tgt_mask, text_lens = self._prepare_tgt(text, 1.0)
        decoder_out = self.text_decoder(
            tgt=tgt_in,
            memory=audio_embs,
            tgt_mask=tgt_mask,                      # parallel mask
            tgt_key_padding_mask=tgt_pad_mask,      # padding mask for decoder input
            memory_key_padding_mask=padding_mask,   # padding mask for encoder output
        )
        logits = self.text_decoder.output_proj(decoder_out)
        loss = self.criterion(logits.view(-1, self.vocab_size), tgt_out.view(-1))
        return loss, text_lens

    def generate(self, input):
        """
        Caption audio with causal greedy decoding.

        Args:
            input: either an ``(x, x_lens)`` tuple of features, or raw input that
                is passed to ``self.extract_feature``.
        Returns:
            List[str], generated captions.
        """
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        audio_embs, padding_mask = self.encode_audio(x, x_lens)
        if self.enc_dec_proj is not None:
            audio_embs = self.enc_dec_proj(audio_embs)

        return self._generate(audio_embs, padding_mask, force_causal=True)