import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from typing import Tuple, Optional, Literal
from torch import einsum




def get_subsequent_mask(seq):
    """Build a causal (lower-triangular) boolean mask matching the length of ``seq``.

    Only ``seq.shape[1]`` (the sequence length) and ``seq.device`` are read.

    Returns:
        Bool tensor of shape ``(1, L, L)`` whose entry ``[0, i, j]`` is True iff
        ``j <= i``, i.e. position ``i`` may see itself and earlier positions.
    """
    _, length = seq.shape[:2]
    causal = torch.tril(torch.ones((1, length, length), device=seq.device))
    return causal.bool()


def get_subsequent_mask_with_batch_length(batch_length, device):
    """Return a ``(1, batch_length, batch_length)`` causal boolean mask on ``device``.

    Entry ``[0, i, j]`` is True iff ``j <= i`` (a position sees itself and its past).
    """
    ones = torch.ones((1, batch_length, batch_length), device=device)
    return ones.tril().bool()


def get_vector_mask(batch_length, device):
    """Return an all-True boolean mask of shape ``(1, 1, batch_length)`` on ``device``.

    No position is masked; the leading singleton axes allow broadcasting over the
    batch and query axes.
    """
    return torch.ones(1, 1, batch_length, dtype=torch.bool, device=device)


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax((q / temperature) @ k^T) @ v``.

    Args:
        temperature: divisor applied to ``q`` before the dot product
            (``MultiHeadAttention`` in this file passes ``sqrt(d_k)``).
        attn_dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        The key transpose swaps axes 2 and 3, so inputs must be at least 4-D
        (presumably ``(B, heads, L, d)``). Scores where ``mask == 0`` are replaced
        by ``-1e9`` before the softmax.

        Returns:
            ``(output, attn)``: the weighted values and the post-dropout weights.
        """
        scaled_q = q / self.temperature
        scores = torch.matmul(scaled_q, k.transpose(2, 3))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(attn, v), attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention with output projection, dropout, residual and post-LayerNorm.

    Args:
        n_head: number of attention heads.
        d_model: input/output feature size.
        d_k: per-head query/key size.
        d_v: per-head value size.
        dropout: dropout probability on the projected output (the attention weights
            use ``ScaledDotProductAttention``'s default dropout).

    ``forward(q, k, v, mask=None)`` returns ``(output, attn)``: ``output`` has the
    shape of ``q`` and ``attn`` has shape ``(B, n_head, Lq, Lk)``.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Attribute names/order are part of the state_dict; keep them stable.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        heads, dk, dv = self.n_head, self.d_k, self.d_v
        batch = q.size(0)
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        residual = q

        # Project, split features into heads and move heads next to batch:
        # (B, L, n_head * d) -> (B, n_head, L, d)
        q_heads = self.w_qs(q).view(batch, len_q, heads, dk).transpose(1, 2)
        k_heads = self.w_ks(k).view(batch, len_k, heads, dk).transpose(1, 2)
        v_heads = self.w_vs(v).view(batch, len_v, heads, dv).transpose(1, 2)

        # Insert a head axis so a single mask is broadcast to every head.
        head_mask = None if mask is None else mask.unsqueeze(1)

        context, attn = self.attention(q_heads, k_heads, v_heads, mask=head_mask)

        # (B, n_head, Lq, d_v) -> (B, Lq, n_head * d_v)
        merged = context.transpose(1, 2).contiguous().view(batch, len_q, -1)
        out = self.dropout(self.fc(merged))
        out += residual

        return self.layer_norm(out), attn
    
class AttentionBlockKVCache(nn.Module):
    """Transformer block: ``MultiHeadAttention`` followed by ``PositionwiseFeedForward``.

    Queries, keys and values are passed separately, so callers may attend over
    externally supplied keys/values (presumably cached past states, given the
    name — confirm at the call sites).
    """

    def __init__(self, feat_dim, hidden_dim, num_heads, dropout):
        super().__init__()
        head_dim = feat_dim // num_heads
        self.slf_attn = MultiHeadAttention(num_heads, feat_dim, head_dim, head_dim, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(feat_dim, hidden_dim, dropout=dropout)

    def forward(self, q, k, v, slf_attn_mask=None):
        """Return ``(block_output, attention_weights)``."""
        attended, attn_weights = self.slf_attn(q, k, v, mask=slf_attn_mask)
        return self.pos_ffn(attended), attn_weights

class MoEAttentionBlockKVCache(nn.Module):
    """Transformer block: ``MultiHeadAttention`` followed by a post-LN ``MoE`` sublayer.

    Same layout as ``AttentionBlockKVCache`` but with the position-wise FFN replaced
    by a mixture of experts. Unlike that block, ``forward`` returns only the block
    output (the attention weights are discarded).

    Args:
        feat_dim: model width; split evenly across ``num_heads``.
        hidden_dim: hidden size of each expert.
        num_heads: number of attention heads.
        n_routed_experts: number of routed experts in the MoE.
        n_activated_experts: experts used per token.
        dropout: dropout probability for attention and the MoE branch.
        task_dim: forwarded to ``MoE``.
    """

    def __init__(self, feat_dim, hidden_dim, num_heads, n_routed_experts, n_activated_experts, dropout, task_dim):
        super().__init__()
        head_dim = feat_dim // num_heads
        self.slf_attn = MultiHeadAttention(num_heads, feat_dim, head_dim, head_dim, dropout=dropout)
        self.moe = MoE(feat_dim, hidden_dim, n_routed_experts, n_activated_experts, task_dim)
        self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, taskemb, slf_attn_mask=None):
        """Attend, then apply ``LayerNorm(x + Dropout(MoE(x)))``; returns a tensor shaped like ``q``."""
        output, _ = self.slf_attn(q, k, v, mask=slf_attn_mask)
        # Dropout only on the MoE branch, then residual add + LayerNorm. This matches
        # PositionwiseFeedForward and MultiHeadAttention, and avoids randomly zeroing
        # the residual stream during training. Eval-mode output is unchanged.
        moe_out = self.dropout(self.moe(output, taskemb))
        return self.layer_norm(moe_out + output)
class Gate(nn.Module):
    """Softmax top-k router for the mixture-of-experts layer.

    Args:
        ndim: feature size of each input row.
        n_activated_experts: number of experts kept per row (k).
        n_all_experts: total number of experts scored.
    """

    def __init__(self, ndim, n_activated_experts, n_all_experts):
        super().__init__()
        self.ndim = ndim
        self.topk = n_activated_experts
        self.route_scale = 1  # multiplier applied to the selected routing weights
        self.weight = nn.Linear(ndim, n_all_experts)  # router logits (with bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score ``x`` against every expert and keep the k most probable.

        Returns:
            ``(weights, indices)``, both of shape ``x.shape[:-1] + (k,)``. ``weights``
            are the chosen experts' softmax probabilities over all experts (not
            renormalised over the k), in descending order, multiplied by
            ``route_scale`` and cast to ``x``'s dtype.
        """
        probs = self.weight(x).softmax(dim=-1)
        top_probs, top_idx = probs.topk(self.topk, dim=-1)
        return (top_probs * self.route_scale).type_as(x), top_idx

class Expert(nn.Module):
    """Gated feed-forward expert computing ``w2(silu(w1(x)) * w3(x))`` with no biases.

    Args:
        dim: input/output feature size.
        inter_dim: hidden (intermediate) size.
    """

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)  # gated branch (passed through SiLU)
        self.w2 = nn.Linear(inter_dim, dim, bias=False)  # projection back to dim
        self.w3 = nn.Linear(dim, inter_dim, bias=False)  # linear branch multiplied by the gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)




class MoE(nn.Module):
    """Mixture-of-experts feed-forward layer with one always-active shared expert.

    Each row of the input (a vector along the last axis) is routed by ``Gate`` to
    ``n_activated_experts`` of the ``n_routed_experts`` experts. Their outputs are
    scaled by the gate weights and summed, then the shared expert's output is added.

    Args:
        ndim: feature size of the last axis of the input.
        hiddim: hidden size of each expert.
        n_routed_experts: number of routed experts.
        n_activated_experts: experts selected per row; must be <= ``n_routed_experts``.
        task_dim: stored for task-conditioned routing, which is currently disabled.
    """

    def __init__(self, ndim, hiddim, n_routed_experts, n_activated_experts, task_dim):
        super().__init__()
        self.ndim = ndim
        self.hiddim = hiddim
        self.n_routed_experts = n_routed_experts
        self.n_activated_experts = n_activated_experts

        assert self.n_activated_experts <= self.n_routed_experts
        self.task_dim = task_dim

        self.experts_start_idx = 0
        self.experts_end_idx = self.experts_start_idx + self.n_routed_experts
        # NOTE: the gate scores the token features; routing on `task_dim`-sized task
        # embeddings is not enabled.
        self.gate = Gate(ndim, self.n_activated_experts, n_routed_experts)
        self.experts = nn.ModuleList([Expert(ndim, hiddim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        self.shared_experts = Expert(ndim, hiddim)

    def forward(self, x, taskemb) -> torch.Tensor:
        """Route every row of ``x`` through its top-k experts plus the shared expert.

        Args:
            x: tensor whose last dimension is ``ndim``; any leading shape, and it
                does not need to be contiguous.
            taskemb: currently ignored (routing uses ``x`` itself).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``ndim``.
        """
        if x.size(-1) != self.ndim:
            # Without this check a last dim that is a multiple of ndim would be
            # silently reinterpreted as extra rows.
            raise ValueError(f"MoE expected last dimension {self.ndim}, got {x.size(-1)}")
        shape = x.size()
        # reshape (not view) so non-contiguous inputs such as transposed tensors work.
        x = x.reshape(-1, self.ndim)
        weights, indices = self.gate(x)

        y = torch.zeros_like(x)
        # Per-expert row counts as a Python list, used to skip experts with no rows.
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()

        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            # top-k indices are distinct within a row, so `rows` holds no duplicates
            # and the indexed in-place add does not drop contributions.
            rows, slot = torch.where(indices == i)
            y[rows] += expert(x[rows]) * weights[rows, slot, None]
        z = self.shared_experts(x)
        return (y + z).view(shape)


 
class CrossAttention(nn.Module):
    """Multi-head cross-attention from a broadcast query onto a context sequence.

    The query is expanded to ``(B, Lc, query_dim)`` (``Lc`` = context length), so
    every output position uses the same query vector. Rows can still differ in
    training because the attention weights go through dropout. The result is the
    output projection only: no residual, no normalisation. ``self.dropout`` is
    constructed but not applied.

    Args:
        n_head: number of heads.
        d_model: feature size of query, context and output.
        d_k: per-head query/key size.
        d_v: per-head value size.
        dropout: probability for the (unused) output dropout.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, context):
        """Return a ``(B, Lc, d_model)`` tensor.

        ``query`` must be expandable to ``(B, Lc, d_model)`` (presumably ``(B, 1, d_model)``);
        ``context`` is ``(B, Lc, d_model)``.
        """
        batch, length = context.size(0), context.size(1)
        heads = self.n_head
        query = query.expand(batch, length, query.size(-1))

        def to_heads(projected, width):
            # (B, Lc, heads * width) -> (B, heads, Lc, width)
            return projected.view(batch, length, heads, width).transpose(1, 2)

        q = to_heads(self.w_qs(query), self.d_k)
        k = to_heads(self.w_ks(context), self.d_k)
        v = to_heads(self.w_vs(context), self.d_v)

        attended, _ = self.attention(q, k, v)

        # (B, heads, Lc, d_v) -> (B, Lc, heads * d_v)
        merged = attended.transpose(1, 2).contiguous().view(batch, length, -1)
        return self.fc(merged)

class CrossAttention4(nn.Module):
    """Multi-head cross-attention from a broadcast query onto a context sequence.

    The query is expanded to the context length, so every output position uses
    the same query vector (rows can still differ in training because dropout is
    applied to the attention weights). The result is projected to ``context_dim``.

    Note: keys and values both use ``n_head * d_k`` channels; ``d_v`` is stored
    but not used.
    """

    def __init__(self, query_dim, context_dim, n_head=8, d_k=64, d_v=64, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        inner_dim = n_head * d_k

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)  # maps the query dimension
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)  # maps the context dimension to keys and values
        self.to_out = nn.Linear(inner_dim, context_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        self.scale = d_k ** -0.5  # 1 / sqrt(d_k) score scaling

    def forward(self, query, context):
        """
        query: [batch, 1, query_dim]  (b)
        context: [batch, seq_len, context_dim] (a)

        Returns [batch, seq_len, context_dim].
        """
        bsz, seq_len, _ = context.shape
        heads = self.n_head
        query = query.expand(bsz, seq_len, query.shape[-1])

        def split_heads(t):
            # (b, n, h*d) -> (b*h, n, d), heads as the outer factor of the last axis
            return (t.reshape(bsz, seq_len, heads, -1)
                     .permute(0, 2, 1, 3)
                     .reshape(bsz * heads, seq_len, -1))

        q = split_heads(self.to_q(query))
        k, v = (split_heads(t) for t in self.to_kv(context).chunk(2, dim=-1))

        scores = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
        weights = self.dropout(scores.softmax(dim=-1))
        out = torch.einsum("b i j, b j d -> b i d", weights, v)

        # (b*h, n, d) -> (b, n, h*d)
        out = (out.reshape(bsz, heads, seq_len, -1)
                  .permute(0, 2, 1, 3)
                  .reshape(bsz, seq_len, -1))
        return self.to_out(out)



class PositionwiseFeedForward(nn.Module):
    """Two-layer ReLU feed-forward sublayer with dropout, residual add and post-LayerNorm.

    Computes ``LayerNorm(x + Dropout(w_2(relu(w_1(x)))))`` independently at every position.

    Args:
        d_in: input/output feature size.
        d_hid: hidden feature size.
        dropout: dropout probability on the sublayer output.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # expand to the hidden size
        self.w_2 = nn.Linear(d_hid, d_in)  # project back to d_in
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        out = self.dropout(self.w_2(hidden))
        out += x
        return self.layer_norm(out)
 
 


class PositionalEncoding1D(nn.Module):
    """Learned absolute positional embeddings added to a sequence of features.

    Args:
        max_length: number of positions that have a learned embedding.
        embed_dim: embedding size; must match the feature size of the input.
    """

    def __init__(
        self,
        max_length: int,
        embed_dim: int
    ):
        super().__init__()
        self.max_length = max_length
        self.embed_dim = embed_dim

        self.pos_emb = nn.Embedding(self.max_length, embed_dim)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        """Add the embeddings of positions ``0 .. L-1`` to ``feat`` of shape ``(B, L, embed_dim)``.

        Raises:
            ValueError: if ``L > max_length``. An over-long sequence would otherwise
                fail with a broadcast error, or, when ``max_length == 1``, silently
                reuse position 0 at every step.
        """
        seq_len = feat.shape[1]
        if seq_len > self.max_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length {self.max_length}")
        # Look up only the rows needed and broadcast over the batch instead of
        # materialising a (B, max_length, D) copy of the whole table.
        positions = torch.arange(seq_len, device=feat.device)
        return feat + self.pos_emb(positions).unsqueeze(0)

    def forward_with_position(self, feat: torch.Tensor, position) -> torch.Tensor:
        """Add the embedding of a single ``position`` to ``feat`` of shape ``(B, 1, embed_dim)``.

        Args:
            feat: a length-1 sequence of features.
            position: integer index in ``[0, max_length)``.

        Raises:
            ValueError: if ``feat`` is not length 1, or ``position`` is out of range.
                An out-of-range slice would otherwise broadcast to an empty
                ``(B, 0, D)`` result without error.
        """
        if feat.shape[1] != 1:
            raise ValueError(f"expected a length-1 sequence, got length {feat.shape[1]}")
        if not 0 <= position < self.max_length:
            raise ValueError(
                f"position {position} is outside [0, {self.max_length})")
        index = torch.arange(position, position + 1, device=feat.device)
        return feat + self.pos_emb(index).unsqueeze(0)