# pylint: disable=C0114, C0301, R0913, C0303, C0115, C0116, R0914, E0402
from typing import Dict, Iterable, Optional, List
import torch
import torch.nn.functional as F
from torch import Tensor, nn
import numpy as np

from .utils import LayerNorm, Linear, sinusoids, Conv1d, make_pad_mask


def make_attn_mask(seq_lens, dtype):
    """Turn per-sequence lengths into an additive attention mask.

    The output of ``make_pad_mask(seq_lens)`` is cast to ``dtype`` and scaled
    by the most negative finite value representable in ``dtype``. A new axis
    is then inserted at position 1 and expanded to ``seq_lens.max()``.

    NOTE(review): presumably ``make_pad_mask`` returns a (batch, max_len)
    tensor that is truthy at padded positions, which would make the result
    (batch, max_len, max_len) with 0 at valid keys -- confirm against utils.
    """
    pad_flags = make_pad_mask(seq_lens).to(dtype)
    additive = pad_flags * torch.finfo(dtype).min
    longest = seq_lens.max()
    # expand() only creates a broadcast view; no per-row copy is made.
    return additive.unsqueeze(1).expand(-1, longest, -1)

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Queries, keys and values are all projected from the same input tensor;
    the key projection is created without a bias term.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
    ):
        """Attend over ``x`` and return the output projection of the result.

        ``mask`` (if given) is an additive mask handed to ``qkv_attention``.
        """
        attended = self.qkv_attention(
            self.query(x), self.key(x), self.value(x), mask
        )
        return self.out(attended)

    def forward_streaming(
        self,
        x: Tensor,
        mask: Tensor = None,
        cache: Tensor =  torch.zeros((0, 0, 0)),
    ):
        """Attend over ``x`` plus previously cached keys/values.

        ``cache`` holds past keys and values concatenated along the last
        dimension; a cache whose first dimension is empty means "no history".
        Returns ``(output, new_cache)`` where ``new_cache`` contains the keys
        and values of the history followed by those of ``x`` (axis 1).
        """
        q = self.query(x)

        has_history = cache.size(0) > 0
        if not has_history:
            k = self.key(x)
            v = self.value(x)
        else:
            # The two halves of the last dimension are keys and values.
            half_width = cache.size(-1) // 2
            past_k, past_v = torch.split(cache, half_width, dim=-1)
            k = torch.cat([past_k, self.key(x)], dim=1)
            v = torch.cat([past_v, self.value(x)], dim=1)

        new_cache = torch.cat((k, v), dim=-1)
        attended = self.qkv_attention(q, k, v, mask)
        return self.out(attended), new_cache

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """Scaled dot-product attention over ``n_head`` heads.

        ``q``, ``k`` and ``v`` are 3-D with the feature axis last. ``mask`` is
        an additive mask with 3 or 4 dimensions; a 3-D mask gets a head axis
        inserted at position 1. The softmax runs in float32 and the weights
        are cast back to the dtype of ``q`` before being applied to ``v``.
        """
        _, _, n_state = q.shape
        head_dim = n_state // self.n_head
        # Applied to both q and k, so their product is scaled by head_dim ** -0.5.
        scale = head_dim ** -0.25

        q = q.view(q.shape[0], q.shape[1], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(k.shape[0], k.shape[1], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(v.shape[0], v.shape[1], self.n_head, -1).permute(0, 2, 1, 3)

        scores = q @ k
        if mask is not None:
            assert mask.dim() in [3,4]
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            # In-place add, performed before the float32 upcast.
            scores += mask

        weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
        context = weights @ v
        # Move heads back next to the feature axis and merge them.
        return context.permute(0, 2, 1, 3).flatten(start_dim=2)
    
    
    
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by a GELU MLP.

    Each sub-layer is applied to a layer-normalised copy of the input and the
    result is added back to the un-normalised input. The MLP hidden width is
    four times ``n_state``.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
    ):
        """Run the block on a full sequence.

        Args:
            x: input activations.
            mask: optional additive attention mask, forwarded to the
                attention layer. Defaults to ``None`` (no masking) so the
                block can be called the same way as ``MultiHeadAttention``
                and the encoders, which already default their mask to None.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # ``.forward`` is called directly (not ``__call__``), as in the
        # original code, so module hooks on ``attn`` are not triggered.
        x = x + self.attn.forward(self.attn_ln(x), mask=mask)
        x = x + self.mlp(self.mlp_ln(x))
        return x

    def forward_streaming(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        cache: Tensor = torch.zeros((0, 0, 0)),
    ):
        """Run the block on a chunk, re-using cached attention keys/values.

        Args:
            x: input activations for the current chunk.
            mask: optional additive attention mask (``None`` = no masking).
            cache: attention cache from the previous call; the default empty
                tensor is only read, never modified, and is passed through
                to ``attn.forward_streaming``.

        Returns:
            ``(output, new_cache)`` where ``new_cache`` is whatever the
            attention layer returns as its updated cache.
        """
        attn_out, new_cache = self.attn.forward_streaming(
            self.attn_ln(x), mask=mask, cache=cache
        )
        x = x + attn_out
        x = x + self.mlp(self.mlp_ln(x))
        return x, new_cache


class AudioEncoder_MoShared(nn.Module):
    """Convolutional front-end plus a stack of residual attention blocks.

    Two 1-D convolutions (kernel size 3, no built-in padding) are each
    preceded by ``kernel_size - 1`` zero frames on the left of the time axis.
    Positional embeddings from ``sinusoids(n_ctx, n_state)`` are then added
    and the result is passed through ``n_layer`` attention blocks.
    """

    def __init__(
        self,
        n_mels: int = 80,
        n_ctx: int = 1500,
        n_conv1_stride: int = 1,
        n_conv2_stride: int = 2,
        n_state: int = 1024,
        n_head: int = 16,
        n_layer: int = 20,
    ):
        super().__init__()
        self.kernel_size = 3
        # Number of zero frames prepended before each convolution.
        self.lorder = self.kernel_size - 1
        self.conv1 = Conv1d(n_mels,  n_state, kernel_size = self.kernel_size, stride=n_conv1_stride, padding=0)
        self.conv2 = Conv1d(n_state, n_state, kernel_size = self.kernel_size, stride=n_conv2_stride, padding=0)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )

    def forward(self,
                x: Tensor,
                mask: Optional[Tensor] = None):
        """Encode a batch of mel spectrograms.

        Args:
            x: channels-first input; the last axis is the one that is
                left-padded and convolved over (the inline notes call the
                layout "B,D,T"). The channel axis is expected to match
                ``n_mels`` -- TODO confirm against the project ``Conv1d``.
            mask: optional additive attention mask handed unchanged to every
                block; ``None`` disables masking.

        Returns:
            Tensor laid out as (batch, frames, channels) after the
            convolutions, i.e. the conv output with its last two axes swapped.

        Side effect:
            The slice of the positional table used for this call is stored on
            ``self.input_positional_embedding``.
        """
        if self.lorder > 0:
            # Left-only padding on the time axis.
            x = F.pad(x, (self.lorder, 0), 'constant', 0.0)  # B,D,T
        x = F.gelu(self.conv1(x))
        if self.lorder > 0:
            x = F.pad(x, (self.lorder, 0), 'constant', 0.0)  # B,D,T//2
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        src_len = x.size(1)
        self.input_positional_embedding = self.positional_embedding[:src_len]
        # Fails when the sequence is longer than the positional table or the
        # channel width differs from it.
        assert x.shape[1:] == self.input_positional_embedding.shape, f"incorrect audio shape: {x.shape[1:], self.input_positional_embedding.shape}"
        x = (x + self.input_positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block.forward(x, mask)
        return x
        
        
class AudioEncoder_PaShared(nn.Module):
    """A stack of residual attention blocks followed by a final LayerNorm."""

    def __init__(
        self,
        n_state: int = 1024,
        n_head: int = 16,
        n_layer: int = 4,
    ):
        super().__init__()
        layers = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(layers)
        self.ln_post = LayerNorm(n_state)

    def forward(self,
                x: Tensor,
                mask:Tensor=None):
        """Apply every block in order, then the closing layer norm.

        x : torch.Tensor
            input activations; presumably (batch, n_ctx, n_state), since
            ``ln_post`` is built with width ``n_state`` -- TODO confirm
        mask : torch.Tensor or None
            attention mask handed unchanged to each block
        """
        hidden = x
        for layer in self.blocks:
            hidden = layer.forward(hidden, mask)
        return self.ln_post(hidden)

class TextEncoder(nn.Module):
    """Token encoder: embeddings plus positions, attention blocks, LayerNorm.

    Attention is masked both by sequence length (via ``make_attn_mask``) and,
    while ``self.causal`` is true, by an upper-triangular causal mask.
    """

    def __init__(
        self,
        n_state: int = 1024,
        n_head: int = 16,
        n_layer: int = 2,
        n_vocab: int = 51866,
        n_ctx: int = 448,
        token_embedding: Optional[nn.Module] = None,
        positional_embedding: Optional[Tensor] = None,
    ):
        """
        token_embedding : module mapping token ids to vectors. It is called
            like a layer and queried with ``named_parameters()``, so it must
            be an ``nn.Module`` (e.g. a shared ``nn.Embedding``). A fresh
            ``nn.Embedding(n_vocab, n_state)`` is created when omitted.
        positional_embedding : tensor indexed along its first axis by
            position. When omitted, an *uninitialised* ``(n_ctx, n_state)``
            parameter is allocated; presumably it is filled by loading a
            checkpoint -- TODO confirm.
        """
        super().__init__()
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)
        if token_embedding is None:
            self.token_embedding = nn.Embedding(n_vocab, n_state)
        else:
            self.token_embedding = token_embedding

        if positional_embedding is None:
            self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        else:
            self.positional_embedding = positional_embedding
        self.causal = True

    def forward(
            self,
            x: Tensor,
            x_lens: Tensor):
        """Encode a batch of token sequences.

        x : torch.Tensor
            token ids; the last axis is the token position (its size selects
            how many positional embeddings are used)
        x_lens : torch.Tensor
            per-sequence lengths passed to ``make_attn_mask``.
            NOTE(review): the length mask is sized by ``x_lens.max()``, so it
            presumably has to equal the padded length of ``x`` for the sum
            with the causal mask to broadcast -- confirm with callers.

        Returns the output of ``ln_post`` applied to the last block's result.
        """
        n_tokens = x.shape[-1]
        x = self.token_embedding(x) + self.positional_embedding[:n_tokens]
        mask = make_attn_mask(x_lens, next(self.parameters()).dtype)
        if self.causal:
            n_ctx = x.shape[1]
            # -inf strictly above the diagonal, 0 elsewhere: position i may
            # only attend to positions <= i.
            causal_mask = torch.empty(
                    n_ctx, n_ctx, device=x.device, dtype=x.dtype
                    ).fill_(-np.inf).triu_(1).unsqueeze(0)
            mask = mask + causal_mask
        for block in self.blocks:
            x = block.forward(x, mask)
        x = self.ln_post(x)

        return x

    def freeze_token_embedding(self):
        """Disable gradients for every parameter of the token embedding."""
        for param in self.token_embedding.parameters():
            param.requires_grad = False
    
    
        