import torch
from torch import Tensor, nn
import numpy as np
from typing import Dict, Iterable, Optional, List
from torch.nn import functional as F
from torch.nn import Linear
from .utils import LayerNorm, Linear, make_pad_mask

class ZeroAdaLN(nn.Module):
    """Modality-conditioned modulation with a zero-initialised projection.

    A two-entry embedding table (one row per modality id) is passed through
    SiLU and a linear layer that widens it to ``6 * n_embd`` features. Both
    the weight and the bias of that linear layer start at zero, so all six
    returned tensors are zero until the layer has been trained.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        # Only ids 0 and 1 are valid lookups into this table.
        self.modal_emb = nn.Embedding(2, n_embd, padding_idx=None)
        self.silu = nn.SiLU()
        self.linear = Linear(n_embd, n_embd * 6, bias=True)
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)

    def forward(self, x, modal_id: int):
        """Return six modulation tensors for ``x``.

        Args:
            x: Tensor whose leading dimensions define where the modality
                embedding is broadcast; only its shape and device are used.
                The final split is done on dim 2, so ``x`` is expected to be
                3-D.
            modal_id: Row of the modality embedding table to use.

        Returns:
            A tuple of six tensors, each with ``x.shape[:-1]`` leading
            dimensions and ``n_embd`` features.
        """
        # One id per position of x (everything except the feature axis).
        ids = x.new_full(x.shape[:-1], modal_id, dtype=torch.long)
        modulation = self.linear(self.silu(self.modal_emb(ids)))
        return modulation.chunk(6, dim=2)
    

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional additive mask.

    Queries, keys and values are all projected from the same input; the key
    projection has no bias.
    """

    def __init__(self, n_state: int, n_head: int):
        """
        Args:
            n_state: Feature size of the input and output.
            n_head: Number of attention heads; must divide ``n_state``.

        Raises:
            ValueError: If ``n_head`` is not positive or does not divide
                ``n_state`` (the per-head reshape would fail otherwise).
        """
        super().__init__()
        if n_head <= 0 or n_state % n_head != 0:
            raise ValueError(
                f"n_state ({n_state}) must be divisible by a positive n_head ({n_head})"
            )
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
    ):
        """Attend over ``x`` and apply the output projection.

        Args:
            x: Input of shape ``(batch, n_ctx, n_state)``.
            mask: Optional additive mask, see :meth:`qkv_attention`.

        Returns:
            Tensor of shape ``(batch, n_ctx, n_state)``.
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """Scaled dot-product attention over already projected q, k and v.

        Args:
            q, k, v: Tensors of shape ``(batch, n_ctx, n_state)``.
            mask: Optional mask that is *added* to the attention logits
                (use ``-inf`` to block a position). Accepted ranks:
                2-D ``(n_ctx, n_ctx)``, 3-D ``(batch, n_ctx, n_ctx)`` (a head
                axis is inserted) or 4-D ``(batch, n_head, n_ctx, n_ctx)``;
                anything that broadcasts against the logits works.

        Returns:
            Tensor of shape ``(batch, n_ctx, n_state)`` with heads merged.

        Raises:
            ValueError: If ``mask`` has a rank other than 2, 3 or 4.
        """
        # Scaling q and k by d**-0.25 each gives the usual 1/sqrt(d) on q @ k.
        scale = (q.shape[-1] // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if mask.dim() not in (2, 3, 4):
                raise ValueError(
                    f"mask must be 2-D, 3-D or 4-D, got {mask.dim()} dimensions"
                )
            if mask.dim() == 3:
                # (batch, n_ctx, n_ctx) -> broadcast over the head axis.
                mask = mask.unsqueeze(1)
            qk += mask
        # Softmax is computed in float32 and cast back to the working dtype.
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
    

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block with optional modality-conditioned modulation.

    Each of the two sub-layers (self-attention, then a 4x-wide GELU MLP) is
    computed as ``res + sublayer(norm(x) * (1 + scale) + shift) * gate``.
    With ``use_adaln=True`` the scale/shift/gate values come from
    ``ZeroAdaLN``; otherwise they are the constants ``0 / 0 / 1``, which
    reduces the block to a plain pre-norm residual block.
    """

    def __init__(self, n_state: int, n_head: int, use_adaln: bool = True):
        """
        Args:
            n_state: Feature size of the block.
            n_head: Number of attention heads.
            use_adaln: Whether to create the ``ZeroAdaLN`` modulation module.
        """
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)
        self.adaLN = ZeroAdaLN(n_state) if use_adaln else None

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        modal_id: int = 0,
    ):
        """Apply the attention and MLP sub-layers to ``x``.

        Args:
            x: Input features; the last dimension is ``n_state``.
            mask: Optional additive attention mask forwarded to the
                attention module.
            modal_id: Modality id passed to the modulation module; ignored
                when the block was built with ``use_adaln=False``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.adaLN is not None:
            # Order of the six chunks: scale, shift, gate for attention,
            # then scale, shift, gate for the MLP.
            ns1, nb1, fs1, ns2, nb2, fs2 = self.adaLN(x, modal_id)
        else:
            ns1, nb1, fs1, ns2, nb2, fs2 = 0, 0, 1, 0, 0, 1

        res = x
        x = self.attn_ln(x)
        x = x * (1 + ns1) + nb1
        # Call the module (not .forward) so registered hooks are honoured.
        x = res + self.attn(x, mask=mask) * fs1

        res = x
        x = self.mlp_ln(x)
        x = x * (1 + ns2) + nb2
        x = res + self.mlp(x) * fs2
        return x
    
class ShareEncoder(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers followed by a LayerNorm."""

    def __init__(
        self,
        n_state: int = 1024,
        n_head: int = 16,
        n_layer: int = 2,
        use_adaln: bool = True,
    ):
        """
        Args:
            n_state: Feature size of every block and of the final norm.
            n_head: Number of attention heads per block.
            n_layer: Number of blocks in the stack.
            use_adaln: Forwarded to every block.
        """
        super().__init__()
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, use_adaln) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        modal_id: int = 0,
    ):
        """Run ``x`` through every block and the final LayerNorm.

        Args:
            x: Input features whose last dimension is ``n_state``
                (presumably ``(batch_size, n_ctx, n_state)``).
            mask: Optional additive attention mask handed unchanged to each
                block (presumably ``(batch_size, n_ctx, n_ctx)``).
            modal_id: Modality id handed unchanged to each block.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for block in self.blocks:
            # Call the module (not .forward) so registered hooks are honoured.
            x = block(x, mask, modal_id)
        return self.ln_post(x)
    
class TextEncoder(ShareEncoder):
    """Token encoder: token + positional embeddings followed by attention blocks.

    NOTE(review): ``super().__init__`` first builds blocks with
    ``use_adaln=False``; they are then replaced by blocks built with the
    default ``use_adaln=True``. The inherited ``ln_post`` is created but is
    not applied in :meth:`forward`. Both are kept as-is to preserve parameter
    names and seeded initialisation — confirm whether this is intended.
    """

    def __init__(
        self,
        n_state: int = 1024,
        n_head: int = 16,
        n_layer: int = 2,
        n_vocab: int = 51866,
        n_ctx: int = 448,
        token_embedding: Optional[nn.Module] = None,
        positional_embedding: Optional[Tensor] = None,
        causal: bool = True,
    ):
        """
        Args:
            n_state: Feature size of the embeddings and blocks.
            n_head: Number of attention heads per block.
            n_layer: Number of blocks.
            n_vocab: Vocabulary size of the default token embedding.
            n_ctx: Number of rows of the default positional embedding.
            token_embedding: Optional embedding module to use instead of a
                freshly created ``nn.Embedding`` (it is called on the token
                ids, so it must be a module, not a raw tensor).
            positional_embedding: Optional positional table to use instead of
                a freshly created parameter.
            causal: Whether :meth:`forward` adds a causal mask.
        """
        super().__init__(n_state, n_head, n_layer, use_adaln=False)
        if token_embedding is None:
            self.token_embedding = nn.Embedding(n_vocab, n_state)
        else:
            self.token_embedding = token_embedding

        if positional_embedding is None:
            # Zeros rather than torch.empty: an uninitialised table holds
            # arbitrary memory (possibly NaN/Inf) whenever no checkpoint
            # overwrites it.
            self.positional_embedding = nn.Parameter(torch.zeros(n_ctx, n_state))
        else:
            self.positional_embedding = positional_embedding
        self.causal = causal

        # Replaces the blocks created by ShareEncoder.__init__ (see class note).
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )

    def forward(self,
                x: Tensor,
                mask: Optional[Tensor] = None):
        """Embed token ids and run them through the blocks.

        Args:
            x: Integer token ids; the last dimension is the sequence length
                (presumably ``(batch_size, n_ctx)``).
            mask: Optional additive attention mask. When ``causal`` is set, a
                causal mask of shape ``(1, n_ctx, n_ctx)`` is added to it (or
                used alone when ``mask`` is ``None``).

        Returns:
            Block outputs with ``n_state`` features per token; ``ln_post`` is
            not applied.
        """
        x = self.token_embedding(x) + self.positional_embedding[:x.shape[-1]]
        if self.causal:
            n_ctx = x.shape[1]
            # -inf strictly above the diagonal, 0 elsewhere.
            token_mask = torch.full(
                (n_ctx, n_ctx), float("-inf"), device=x.device, dtype=x.dtype
            ).triu_(1).unsqueeze(0)
            mask = token_mask if mask is None else mask + token_mask
        for block in self.blocks:
            # Call the module (not .forward) so registered hooks are honoured.
            x = block(x, mask)
        return x

    def freeze_token_embedding(self):
        """Disable gradients for every parameter of the token embedding."""
        for param in self.token_embedding.parameters():
            param.requires_grad = False

class Adapter(nn.Module):
    """Sequence of ``AdapterLayer`` modules that also tracks sequence lengths."""

    def __init__(
            self,
            n_state,
            n_head,
            n_layers=1,
            kernel_size=4,
            stride=4,
            dropout=0.0,
            pad=False):
        """
        Args:
            n_state: Feature size handed to every layer.
            n_head: Number of attention heads handed to every layer.
            n_layers: Number of adapter layers.
            kernel_size: Convolution kernel size of every layer.
            stride: Convolution stride of every layer.
            dropout: Dropout probability handed to every layer.
            pad: Whether the layers pad their convolutions.
        """
        super().__init__()

        self.layers = nn.ModuleList(
            AdapterLayer(
                n_state,
                n_head,
                kernel_size,
                stride,
                dropout,
                pad,
                ) for _ in range(n_layers))

        self.kernel_size = kernel_size
        self.stride = stride
        self.pad = pad

    def _compute_sub_sample_lengths(self, seq_lens):
        """Return the per-sequence lengths after one adapter layer.

        Uses the Conv1d output-length formula
        ``floor((L + 2 * padding - kernel_size) / stride) + 1``.

        Args:
            seq_lens: Tensor of input lengths, or ``None``.

        Returns:
            Floating-point tensor of floored output lengths, or ``None`` when
            ``seq_lens`` is ``None``.
        """
        if seq_lens is None:
            return None
        # Must mirror the padding AdapterLayer gives its convolutions
        # (stride // 2); using kernel_size // 2 here over-counts frames
        # whenever kernel_size != stride.
        pad = self.stride // 2 if self.pad else 0
        seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
        return seq_lens.floor()

    def forward(self, x, seq_lens=None):
        """Apply every layer, updating ``seq_lens`` before each one.

        Args:
            x: Input features passed to the first layer.
            seq_lens: Optional tensor of valid lengths for ``x``.

        Returns:
            Tuple ``(x, seq_lens)``: the output of the last layer and the
            matching lengths as a long tensor (``None`` if none were given).
        """
        for layer in self.layers:
            if seq_lens is not None:
                seq_lens = self._compute_sub_sample_lengths(seq_lens).long()
            x = layer(x, seq_lens=seq_lens)

        return x, seq_lens


class AdapterLayer(nn.Module):
    """Strided-convolution adapter layer with self-attention and a feed-forward net.

    The input is shortened along the sequence axis by two parallel
    ``LayerNorm -> Conv1d -> GLU`` branches. One branch serves as the
    residual; the other is fed to self-attention (with dropout) and added to
    the residual. A 4x-wide GELU feed-forward block with its own residual
    connection follows.
    """

    def __init__(
            self,
            n_state,
            n_head,
            kernel_size=4,
            stride=4,
            dropout=0.0,
            pad=False,
            ):
        """
        Args:
            n_state: Feature size of the input and output.
            n_head: Number of attention heads.
            kernel_size: Kernel size of both convolutions.
            stride: Stride of both convolutions.
            dropout: Dropout probability applied to the attention output.
            pad: If true, both convolutions pad by ``stride // 2``.
        """
        super().__init__()

        self.kernel_size = kernel_size
        self.stride = stride
        conv_padding = self.stride // 2 if pad else 0

        # Residual branch. The conv emits 2 * n_state channels; GLU over the
        # channel axis halves them back to n_state.
        self.residual_norm = LayerNorm(n_state)
        self.residual_conv = nn.Conv1d(
            n_state,
            2 * n_state,
            self.kernel_size,
            stride=self.stride,
            padding=conv_padding,
        )
        self.activation = nn.GLU(dim=1)

        # Self-attention branch.
        self.self_attn_norm = LayerNorm(n_state)
        self.self_attn_conv = nn.Conv1d(
            n_state,
            2 * n_state,
            self.kernel_size,
            stride=self.stride,
            padding=conv_padding,
        )
        self.self_attn = MultiHeadAttention(n_state, n_head)
        self.self_attn_dropout = nn.Dropout(dropout)

        # Feed-forward block.
        self.ffn_norm = LayerNorm(n_state)
        self.ffn = nn.Sequential(
            Linear(n_state, n_state * 4),
            nn.GELU(),
            Linear(n_state * 4, n_state)
        )

    def _pool(self, norm, conv, x):
        """Normalise ``x``, then convolve and gate it along the sequence axis."""
        # Conv1d wants channels on dim 1, so swap the last two axes around it.
        channels_first = norm(x).transpose(1, 2)
        return self.activation(conv(channels_first)).transpose(1, 2)

    def forward(
        self,
        x,
        seq_lens: Optional[torch.Tensor] = None,
    ):
        """Shorten ``x`` and refine it with attention and a feed-forward net.

        Args:
            x: Input of shape ``(batch, time, n_state)``.
            seq_lens: Optional lengths of the *shortened* sequences, used to
                build the attention padding mask.

        Returns:
            Tensor of shape ``(batch, shortened_time, n_state)``.
        """
        shortcut = self._pool(self.residual_norm, self.residual_conv, x)
        hidden = self._pool(self.self_attn_norm, self.self_attn_conv, x)

        attn_mask = None
        if seq_lens is not None:
            # make_pad_mask presumably marks padded positions with True —
            # TODO confirm against .utils. The same key mask is repeated for
            # every query row.
            is_pad = make_pad_mask(seq_lens).unsqueeze(1).expand(-1, seq_lens.max(), -1)
            additive = torch.zeros_like(is_pad, dtype=hidden.dtype)
            attn_mask = additive.masked_fill(is_pad, float("-inf")).to(hidden.device)

        attended = self.self_attn_dropout(self.self_attn(hidden, attn_mask))
        hidden = attended + shortcut

        return self.ffn(self.ffn_norm(hidden)) + hidden
    