import numpy as np
import random
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Dict, Iterable, Optional, List

from .utils import LayerNorm, Linear


class MultiHeadAttention(nn.Module):
    """Multi-head attention usable as self- or cross-attention.

    ``att`` selects where keys/values come from: ``'self'`` projects them from
    the query input ``x``; ``'cross'`` projects them from ``xa``. ``forward``
    is the stateless path; ``forward_inference`` also threads a key/value
    cache laid out as ``torch.cat((k, v), dim=-1)``.
    """

    def __init__(self, n_state: int, n_head: int, att: str):
        super().__init__()
        # Validate once here instead of asserting on every forward call
        # (asserts are stripped under ``python -O``).
        if att not in ('self', 'cross'):
            raise ValueError(f"att must be 'self' or 'cross', got {att!r}")
        self.n_head = n_head
        self.att = att
        self.query = Linear(n_state, n_state)
        # No key bias: it would add the same q.b term to every score of a
        # given query, which the softmax ignores.
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Tensor = None,
        mask: Tensor = None,
    ):
        """Attend from ``x`` to itself (``'self'``) or to ``xa`` (``'cross'``).

        Args:
            x: query input, rank 3 (batch, n_q, n_state) as unpacked in
                ``qkv_attention``.
            xa: key/value source for cross-attention; unused for self-attention.
            mask: optional additive mask, see ``qkv_attention``.

        Returns:
            The attention output after the ``out`` projection.
        """
        q = self.query(x)
        source = x if self.att == 'self' else xa
        k = self.key(source)
        v = self.value(source)
        return self.out(self.qkv_attention(q, k, v, mask))

    def forward_inference(
        self,
        x: Tensor,
        xa: Tensor = None,
        mask: Tensor = None,
        cache: Tensor = torch.zeros((0, 0, 0)),
    ):
        """Attention step that also returns the updated key/value cache.

        A ``cache`` with ``size(0) == 0`` (the default) means "no cache yet".
        Otherwise it holds previously computed keys and values packed along
        the last dim as ``cat((k, v), -1)``.

        - self-attention: keys/values of ``x`` are appended to the cache along
          dim 1 and ``x`` attends over all of them.
        - cross-attention: keys/values of ``xa`` (when given) are appended to
          the cache; with ``xa=None`` the cache is reused unchanged.

        Returns:
            ``(output, new_cache)``.

        Raises:
            ValueError: cross-attention with neither ``xa`` nor a cache.
        """
        q = self.query(x)
        has_cache = cache.size(0) > 0

        if self.att == 'cross' and xa is None:
            # Nothing new to project: attend over the cached keys/values only.
            if not has_cache:
                raise ValueError(
                    "cross-attention needs either xa or a non-empty cache"
                )
            k, v = torch.split(cache, cache.size(-1) // 2, dim=-1)
            new_cache = cache
        else:
            source = x if self.att == 'self' else xa
            k = self.key(source)
            v = self.value(source)
            if has_cache:
                # New positions go after the cached ones (dim 1).
                k_cached, v_cached = torch.split(
                    cache, cache.size(-1) // 2, dim=-1
                )
                k = torch.cat([k_cached, k], dim=1)
                v = torch.cat([v_cached, v], dim=1)
            new_cache = torch.cat((k, v), dim=-1)

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv), new_cache

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """Scaled dot-product attention split over ``self.n_head`` heads.

        ``q`` must be rank 3 (batch, n_q, n_state); ``k``/``v`` are viewed the
        same way with their own sequence length n_k.

        ``mask`` is an optional additive mask of rank 3 or 4. A rank-3 mask is
        broadcast over heads. Only its last ``n_q`` rows are used, so a full
        causal mask can be passed while decoding incrementally. Its last dim
        must broadcast against n_k.

        Raises:
            ValueError: ``mask`` is not rank 3 or 4.
        """
        _, _, n_state = q.shape
        # q and k are each scaled by d**-0.25, i.e. the logits by 1/sqrt(d).
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if mask.dim() not in (3, 4):
                raise ValueError("mask must be a 3-D or 4-D tensor")
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # broadcast over heads
            # Out-of-place on purpose: an in-place add cannot broadcast qk up
            # to a larger mask batch, and would round the mask to qk's dtype.
            qk = qk + mask[:, :, -qk.shape[2]:]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
    
    


    
class ResidualAttentionBlock(nn.Module):
    """Decoder layer: self-attention, cross-attention, then an MLP.

    Each sub-layer normalises its input with a LayerNorm and adds its output
    back onto the residual stream.

    Note: ``cross_attention`` is accepted for interface compatibility but is
    never read; the cross-attention sub-layer is always constructed.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        # Sub-layer 1: attention over the token sequence itself.
        self.attn = MultiHeadAttention(n_state, n_head, att='self')
        self.attn_ln = LayerNorm(n_state)

        # Sub-layer 2: attention from the tokens to the encoder output ``xa``.
        self.cross_attn = MultiHeadAttention(n_state, n_head, att='cross')
        self.cross_attn_ln = LayerNorm(n_state)

        # Sub-layer 3: feed-forward network with a 4x hidden expansion.
        hidden = 4 * n_state
        self.mlp = nn.Sequential(
            Linear(n_state, hidden),
            nn.GELU(),
            Linear(hidden, n_state),
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        self_att_mask: Optional[Tensor] = None,
        cross_att_mask: Optional[Tensor] = None,
    ):
        """Apply the three residual sub-layers without any key/value cache."""
        self_out = self.attn.forward(self.attn_ln(x), xa=None, mask=self_att_mask)
        x = x + self_out

        cross_out = self.cross_attn.forward(
            self.cross_attn_ln(x), xa, mask=cross_att_mask
        )
        x = x + cross_out

        return x + self.mlp(self.mlp_ln(x))

    def forward_inference(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        self_att_mask: Optional[Tensor] = None,
        cross_att_mask: Optional[Tensor] = None,
        self_att_cache: Tensor = torch.zeros(( 0, 0, 0)),
        cross_att_cache: Tensor = torch.zeros(( 0, 0, 0)),
    ):
        """Cached variant of ``forward``.

        Returns:
            ``(hidden, new_self_att_cache, new_cross_att_cache)``, where the
            caches are whatever the two attention sub-layers' own
            ``forward_inference`` calls returned.
        """
        self_out, updated_self_cache = self.attn.forward_inference(
            self.attn_ln(x), xa=None, mask=self_att_mask, cache=self_att_cache
        )
        x = x + self_out

        cross_out, updated_cross_cache = self.cross_attn.forward_inference(
            self.cross_attn_ln(x), xa, mask=cross_att_mask, cache=cross_att_cache
        )
        x = x + cross_out

        mlp_out = self.mlp(self.mlp_ln(x))
        return x + mlp_out, updated_self_cache, updated_cross_cache


class LiteSelfAttention(nn.Module):
    """Self-attention that owns only query and output projections.

    Keys and values are never projected here. They are read from ``cache``,
    which packs them along the last dim as ``cat((k, v), dim=-1)`` (the same
    layout ``MultiHeadAttention.forward_inference`` produces).
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        mask: Tensor = None,
        cache:Tensor = torch.zeros((0, 0, 0)),
    ):
        """Attend from ``x`` over the keys/values stored in ``cache``.

        Args:
            x: query input, rank 3 (batch, n_q, n_state).
            mask: optional additive mask, see ``qkv_attention``.
            cache: packed ``cat((k, v), -1)`` keys/values to attend over.

        Returns:
            The attention output after the ``out`` projection. The cache is
            not modified or returned.

        Raises:
            ValueError: ``cache`` is empty (``size(0) == 0``). This layer has
                no key/value projections, so it cannot run without one.
        """
        if cache.size(0) == 0:
            raise ValueError(
                "LiteSelfAttention requires a non-empty key/value cache"
            )
        q = self.query(x)
        k, v = torch.split(cache, cache.size(-1) // 2, dim=-1)
        return self.out(self.qkv_attention(q, k, v, mask))

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """Scaled dot-product attention split over ``self.n_head`` heads.

        ``q`` must be rank 3 (batch, n_q, n_state); ``k``/``v`` are viewed the
        same way with their own length n_k. ``mask`` is an optional additive
        mask of rank 3 (broadcast over heads) or rank 4. Only its last ``n_q``
        rows are used.

        Raises:
            ValueError: ``mask`` is not rank 3 or 4.
        """
        _, _, n_state = q.shape
        # q and k are each scaled by d**-0.25, i.e. the logits by 1/sqrt(d).
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if mask.dim() not in (3, 4):
                raise ValueError("mask must be a 3-D or 4-D tensor")
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # broadcast over heads
            # Out-of-place on purpose: an in-place add cannot broadcast qk up
            # to a larger mask batch, and would round the mask to qk's dtype.
            qk = qk + mask[:, :, -qk.shape[2]:]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

class LiteAttentionBlock(nn.Module):
    """Decoder layer whose self-attention has no key/value projections.

    The self-attention sub-layer (``LiteSelfAttention``) reads its keys and
    values from a caller-supplied cache. It is followed by cross-attention to
    ``xa`` and a feed-forward MLP. Every sub-layer is pre-LayerNorm with a
    residual add.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()

        self.attn = LiteSelfAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head, att='cross')
        self.cross_attn_ln = LayerNorm(n_state)

        hidden = 4 * n_state
        self.mlp = nn.Sequential(
            Linear(n_state, hidden),
            nn.GELU(),
            Linear(hidden, n_state),
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
        self_att_cache: Tensor,
        cross_att_cache: Tensor = torch.zeros((0, 0, 0)),
        self_att_mask: Optional[Tensor] = None,
        cross_att_mask: Optional[Tensor] = None,
    ):
        """Apply the layer and return only the hidden states.

        The cross-attention cache produced along the way is discarded.
        """
        lite_out = self.attn.forward(
            self.attn_ln(x), mask=self_att_mask, cache=self_att_cache
        )
        x = x + lite_out

        cross_out, _ = self.cross_attn.forward_inference(
            self.cross_attn_ln(x), xa, mask=cross_att_mask, cache=cross_att_cache
        )
        x = x + cross_out

        return x + self.mlp(self.mlp_ln(x))

    def forward_inference(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        self_att_mask: Optional[Tensor] = None,
        cross_att_mask: Optional[Tensor] = None,
        self_att_cache: Tensor = torch.zeros(( 0, 0, 0)),
        cross_att_cache: Tensor = torch.zeros(( 0, 0, 0)),
    ):
        """Apply the layer and also return the updated cross-attention cache.

        Returns:
            ``(hidden, new_cross_att_cache)``. The self-attention cache is only
            read, never updated.
        """
        lite_out = self.attn.forward(
            self.attn_ln(x), mask=self_att_mask, cache=self_att_cache
        )
        x = x + lite_out

        cross_out, updated_cross_cache = self.cross_attn.forward_inference(
            self.cross_attn_ln(x), xa, mask=cross_att_mask, cache=cross_att_cache
        )
        x = x + cross_out

        mlp_out = self.mlp(self.mlp_ln(x))
        return x + mlp_out, updated_cross_cache




class TextDecoder(nn.Module):
    """Token decoder over encoder features.

    The pipeline is: token embedding + positional embedding, a stack of
    ``ResidualAttentionBlock`` layers, a final LayerNorm, and logits computed
    against the token-embedding matrix.

    With ``use_lite_blocks=True`` it also holds three ``LiteAttentionBlock``
    layers (``blocks_lite`` / ``ln_lite``) used by ``forward_cache_lite``.
    Their self-attention reads keys/values from the main stack's cache instead
    of projecting its own.
    """

    def __init__(
        self,
        n_state: int = 1280,
        n_head: int = 16,
        n_layer: int = 4,
        n_vocab: int = 51866,
        n_ctx: int = 448,
        token_embedding: Tensor = None,
        positional_embedding: Tensor = None,
        use_lite_blocks: bool = False
    ):
        super().__init__()

        # Embeddings can be supplied by the caller (e.g. shared) or built here.
        if token_embedding is None:
            self.token_embedding = nn.Embedding(n_vocab, n_state)
        else:
            self.token_embedding = token_embedding

        if positional_embedding is None:
            # Allocated uninitialised (torch.empty); presumably filled from a
            # checkpoint - TODO confirm.
            self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        else:
            self.positional_embedding = positional_embedding

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        if use_lite_blocks:
            self.blocks_lite: Iterable[LiteAttentionBlock] = nn.ModuleList(
                [LiteAttentionBlock(n_state, n_head) for _ in range(3)]
            )
            self.ln_lite = LayerNorm(n_state)
        else:
            self.blocks_lite = self.ln_lite = self.merge_weight = None

        if positional_embedding is not None:
            # A caller-supplied positional embedding is kept frozen.
            self.positional_embedding.requires_grad = False

    def forward(self,
                x: Tensor,
                xa: Tensor,
                token_mask: Tensor=None,
                mel_mask: Tensor=None,
                ):
        """Run the whole token sequence through the main stack, without a cache.

        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        token_mask : optional additive self-attention mask (rank 3 or 4).
            When omitted, a causal mask of shape (1, n_tokens, n_tokens) is
            built, with ``-inf`` above the diagonal.
        mel_mask : optional additive cross-attention mask (rank 3 or 4). Its
            last dim indexes the positions of ``xa``.

        Returns float32 logits of shape (batch_size, n_tokens, n_vocab).
        """
        if token_mask is None:
            n_tokens = x.size(1)
            token_mask = torch.empty(
                n_tokens, n_tokens, device=xa.device, dtype=xa.dtype
            ).fill_(-np.inf).triu_(1).unsqueeze(0)

        x = self.token_embedding(x)
        x = x + self.positional_embedding[:x.shape[1]]
        x = x.to(xa.dtype)

        for block in self.blocks:
            # The per-layer caches are not needed on this uncached path.
            x, _, _ = block.forward_inference(
                x, xa, self_att_mask=token_mask, cross_att_mask=mel_mask
            )

        return self._logits(self.ln(x))

    def forward_cache(
        self,
        x: Tensor,
        xa: Tensor,
        sa_cache: Tensor = torch.zeros((0, 0, 0, 0)),
        ca_cache: Tensor = torch.zeros((0, 0, 0, 0)),
        ):
        """Run one incremental decoding step through the main stack.

        x : torch.LongTensor, shape = (batch_size, n_tokens)
            The full token prefix. Only tokens beyond the ``sa_cache.size(2)``
            already cached are embedded and processed.
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            Encoder output. Only frames beyond the ``ca_cache.size(2)`` already
            cached are projected and appended to the cross-attention cache.
        sa_cache / ca_cache : per-layer caches stacked on dim 0, i.e.
            (n_layer, batch_size, cached_len, 2 * n_state) holding
            ``cat((k, v), -1)``. An empty cache (``size(0) == 0``) means
            this is the first call.

        Returns ``(logits, new_sa_cache, new_ca_cache)``. On the first call,
        logits cover only the last token, (batch_size, 1, n_vocab). Afterwards
        they cover every newly processed token.
        """
        n_tokens = x.size(1)
        # Causal mask over the whole prefix. qkv_attention keeps only its last
        # rows, which match the queries actually processed below.
        token_mask = torch.empty(
            n_tokens, n_tokens, device=xa.device, dtype=xa.dtype
        ).fill_(-np.inf).triu_(1).unsqueeze(0)

        offset = sa_cache.size(2)      # tokens already in the self-attn cache
        offset_ca = ca_cache.size(2)   # audio frames already in the cross cache
        xa = xa[:, offset_ca:]
        if offset > 0:
            x = x[:, offset:]

        x = self.token_embedding(x)
        x = x + self.positional_embedding[offset:offset + x.shape[1]]
        # Same as `forward`: run the stack in the encoder output's dtype so the
        # queries match the keys projected from xa.
        x = x.to(xa.dtype)

        new_sa_cache = []
        new_ca_cache = []
        for i, block in enumerate(self.blocks):
            x, sa, ca = block.forward_inference(
                x,
                xa,
                self_att_mask=token_mask,
                cross_att_mask=None,
                self_att_cache=sa_cache[i] if sa_cache.size(0) > 0 else sa_cache,
                cross_att_cache=ca_cache[i] if ca_cache.size(0) > 0 else ca_cache,
            )
            new_sa_cache.append(sa)
            new_ca_cache.append(ca)

        logits = self._logits(self.ln(x))
        if offset == 0:
            logits = logits[:, [-1]]

        return (
            logits,
            torch.stack(new_sa_cache, dim=0).to(x.device),
            torch.stack(new_ca_cache, dim=0).to(x.device),
        )

    def forward_cache_lite(
        self,
        x: Tensor,
        xa: Tensor,
        sa_cache: Tensor = torch.zeros((0, 0, 0, 0)),
        ca_cache: Tensor = torch.zeros((0, 0, 0, 0)),
        ):
        """Run one decoding step through the lite blocks.

        Only the last token of ``x`` is processed.

        ``sa_cache`` is the main stack's self-attention cache, shaped
        (n_layer, batch_size, cached_len, 2 * n_state). Its layers are split
        into ``len(self.blocks_lite)`` consecutive groups. Lite block g
        attends over the concatenated caches of group g.

        ``ca_cache`` holds the lite blocks' own cross-attention caches (empty
        on the first call).

        Returns ``(logits, None, new_ca_cache)``, with logits of shape
        (batch_size, 1, n_vocab).

        Raises:
            RuntimeError: the decoder was built with ``use_lite_blocks=False``.
            ValueError: ``sa_cache`` is empty, or its layer count is not a
                multiple of the number of lite blocks (the grouping reshape
                would otherwise mix layers and time steps).
        """
        if self.blocks_lite is None or self.ln_lite is None:
            raise RuntimeError("forward_cache_lite requires use_lite_blocks=True")
        n_lite = len(self.blocks_lite)

        offset_ca = ca_cache.size(2)
        x = x[:, [-1]]
        xa = xa[:, offset_ca:]
        # kv_dim is 2 * n_state: keys and values are packed on the last dim.
        n_layer, bsz, _, kv_dim = sa_cache.shape
        if n_layer == 0 or n_layer % n_lite != 0:
            raise ValueError(
                "sa_cache must hold a non-zero number of layers divisible by "
                "the number of lite blocks"
            )

        # (n_layer, B, T, C) -> (B, T, n_layer, C) -> split layers into n_lite
        # consecutive groups -> (B, T * n_layer // n_lite, n_lite, C).
        grouped = sa_cache.permute(1, 2, 0, 3)
        grouped = (
            grouped.reshape(bsz, -1, n_lite, n_layer // n_lite, kv_dim)
            .transpose(2, 3)
            .reshape(bsz, -1, n_lite, kv_dim)
        )

        # NOTE(review): unlike forward_cache, no positional embedding is added
        # here and the cached length (sa_cache.size(2)) is unused - confirm
        # this is intended.
        x = self.token_embedding(x)
        x = x.to(xa.dtype)

        new_ca_cache = []
        for i, block in enumerate(self.blocks_lite):
            x, ca = block.forward_inference(
                x, xa,
                self_att_cache=grouped[:, :, i],
                cross_att_cache=ca_cache[i] if ca_cache.size(0) > 0 else ca_cache,
            )
            new_ca_cache.append(ca)

        logits = self._logits(self.ln_lite(x))
        return logits, None, torch.stack(new_ca_cache, dim=0).to(x.device)

    def _logits(self, x: Tensor) -> Tensor:
        """Project hidden states onto the token-embedding matrix, as float32."""
        weight = self.token_embedding.weight.to(x.dtype)
        return (x @ torch.transpose(weight, 0, 1)).float()