import numpy as np
import torch
from torch import nn
from torch.nn import init
from math import sqrt


class AttentionLayer(nn.Module):
    """Multi-head attention wrapper around a pluggable inner attention module.

    Linearly projects ``queries``, ``keys`` and ``values`` into ``n_heads``
    heads, delegates the per-head computation to ``attention``, then projects
    the concatenated head outputs back to ``d_model``.

    Args:
        attention: Callable invoked as
            ``attention(q, k, v, attn_mask, tau=..., delta=...)`` on tensors of
            shape ``(B, L or S, n_heads, head_dim)``. It must return
            ``(out, attn)`` with ``out`` viewable as ``(B, L, -1)``.
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key size; defaults to ``d_model // n_heads``.
        d_values: Per-head value size; defaults to ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Run multi-head attention.

        Args:
            queries: Tensor of shape ``(B, L, d_model)``.
            keys: Tensor of shape ``(B, S, d_model)``.
            values: Tensor of shape ``(B, S, d_model)``.
            attn_mask: Forwarded unchanged to the inner attention.
            tau: Forwarded unchanged to the inner attention.
            delta: Forwarded unchanged to the inner attention.

        Returns:
            ``(output, attn)``: ``output`` has shape ``(B, L, d_model)``;
            ``attn`` is whatever the inner attention returned.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split the projected feature axis into H heads.
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask,
            tau=tau,
            delta=delta
        )
        # Merge the heads back into one feature axis before the output projection.
        out = out.view(B, L, -1)

        return self.out_projection(out), attn

class TriangularCausalMask:
    """Boolean causal mask of shape ``[B, 1, L, L]``.

    An entry is ``True`` strictly above the main diagonal, i.e. where the key
    position is later than the query position.
    """

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            strictly_upper = torch.triu(all_true, diagonal=1)
            self._mask = strictly_upper.to(device)

    @property
    def mask(self):
        """The ``[B, 1, L, L]`` boolean mask tensor built at construction."""
        return self._mask

class FullAttention(nn.Module):
    """Scaled dot-product attention over pre-split heads.

    Args:
        mask_flag: When true, masked score positions are set to ``-inf``.
            A ``TriangularCausalMask`` is built if no mask is supplied.
        factor: Accepted for interface compatibility; not used here.
        scale: Score multiplier; a falsy value falls back to ``1/sqrt(E)``.
        attention_dropout: Dropout probability applied to attention weights.
        output_attention: When true, ``forward`` also returns the weights.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.mask_flag = mask_flag
        self.scale = scale
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Attend ``queries`` over ``keys``/``values``.

        Args:
            queries: ``(B, L, H, E)`` tensor.
            keys: ``(B, S, H, E)`` tensor.
            values: ``(B, S, H, D)`` tensor.
            attn_mask: Object exposing a boolean ``.mask`` broadcastable to
                ``(B, H, L, S)``, or ``None``.
            tau: Unused.
            delta: Unused.

        Returns:
            ``(context, weights)``: ``context`` is a contiguous ``(B, L, H, D)``
            tensor; ``weights`` is ``(B, H, L, S)`` or ``None`` when
            ``output_attention`` is false.
        """
        batch, q_len, _, head_dim = queries.shape
        _, _, _, _ = values.shape
        temperature = self.scale or 1. / sqrt(head_dim)

        logits = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            mask = attn_mask if attn_mask is not None else TriangularCausalMask(
                batch, q_len, device=queries.device)
            logits.masked_fill_(mask.mask, -np.inf)

        weights = self.dropout(torch.softmax(temperature * logits, dim=-1))
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        return context, (weights if self.output_attention else None)

class ExternalAttentionLayer(nn.Module):
    """Multi-head wrapper for an inner attention that only consumes queries.

    Only ``queries`` are projected (into ``n_heads`` heads). ``keys`` and
    ``values`` are forwarded to the inner attention unprojected. The inner
    output is flattened per position and projected back to ``d_model``.

    Args:
        attention: Callable invoked as ``attention(q, keys=..., values=...,
            attn_mask=..., tau=..., delta=...)`` with ``q`` of shape
            ``(B, L, n_heads, d_keys)``. It must return ``(out, attn)`` where
            ``out`` reshapes to ``(B, L, d_values * n_heads)``.
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of heads.
        d_keys: Per-head query size; defaults to ``d_model // n_heads``.
        d_values: Per-head output size; defaults to ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(ExternalAttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Project queries into heads, run the inner attention, project back.

        Args:
            queries: Tensor of shape ``(B, L, d_model)``.
            keys: Forwarded unchanged (not projected) to the inner attention.
            values: Forwarded unchanged (not projected) to the inner attention.
            attn_mask: Forwarded unchanged to the inner attention.
            tau: Forwarded unchanged to the inner attention.
            delta: Forwarded unchanged to the inner attention.

        Returns:
            ``(output, attn)``: ``output`` has shape ``(B, L, d_model)``;
            ``attn`` is whatever the inner attention returned.
        """
        B, L, _ = queries.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        out, attn = self.inner_attention(
            queries,
            keys=keys,
            values=values,
            attn_mask=attn_mask,
            tau=tau,
            delta=delta
        )
        # reshape (not view): the inner output may be a non-contiguous permute.
        out = out.reshape(B, L, -1)

        return self.out_projection(out), attn

class ExternalAttention(nn.Module):
    """Per-head attention against two small learnable memories.

    ``mk`` maps each head's query vector to ``S`` memory-slot logits and
    ``mv`` maps the normalized slot weights back to the head dimension. The
    keys/values passed to ``forward`` are ignored. (This appears to follow
    the "external attention" / double-normalization scheme; confirm against
    the intended reference if exact parity matters.)

    Args:
        d_model: Model width; the per-head size is ``d_model // n_heads``.
        n_heads: Number of heads the queries are split into.
        S: Number of memory slots.
        mask_flag: Accepted for interface parity with ``FullAttention``; unused.
        factor: Accepted for interface parity; unused.
        scale: Accepted for interface parity; unused.
        attention_dropout: Dropout probability applied to the slot weights.
        output_attention: Accepted for interface parity; unused (the weights
            are always returned).
    """

    def __init__(self, d_model, n_heads, S = 4, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super().__init__()
        head_dim = d_model // n_heads
        self.mk = nn.Linear(head_dim, S, bias=False)
        self.mv = nn.Linear(S, head_dim, bias=False)
        self.init_weights()
        self.attn_drop = nn.Dropout(attention_dropout)

    def init_weights(self):
        """Initialise submodule parameters in place.

        Only ``nn.Linear`` layers exist in this module: their weights are drawn
        from N(0, 0.001^2) and any bias is zeroed. The Conv2d/BatchNorm2d
        branches only matter if a subclass adds such layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys=None, values=None, attn_mask=None, tau=None, delta=None):
        """Compute external attention for pre-split query heads.

        Args:
            queries: ``(B, N, H, d_model // n_heads)`` tensor.
            keys: Ignored.
            values: Ignored.
            attn_mask: Ignored.
            tau: Ignored.
            delta: Ignored.

        Returns:
            ``(out, attn)``: ``out`` is ``(B, N, H, d_model // n_heads)``
            (a permuted, possibly non-contiguous view). ``attn`` holds the
            ``(B, H, N, S)`` slot weights after dropout.
        """
        queries = queries.permute(0, 2, 1, 3)  # (B, H, N, head_dim)
        attn = self.mk(queries)  # (B, H, N, S)
        # Double normalisation: softmax across the N positions, then
        # L1-normalise across the S slots (eps guards against divide-by-zero).
        attn = attn.softmax(dim=-2)
        attn_sum = 1e-9 + torch.sum(attn, dim=-1, keepdim=True)
        attn = attn / attn_sum
        attn = self.attn_drop(attn)
        out = self.mv(attn)  # (B, H, N, head_dim)
        out = out.permute(0, 2, 1, 3)

        return out, attn



if __name__ == '__main__':
    import time

    # Fall back to CPU so the benchmark does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    d_model = 512
    n_heads = 8
    S = 8
    # `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(50, 49, d_model, device=device)

    ea = ExternalAttentionLayer(ExternalAttention(d_model=d_model, n_heads=n_heads, S=S, mask_flag=False, factor=0.5, attention_dropout=0.1,
                                                  output_attention=False), d_model, n_heads).to(device)
    sa = AttentionLayer(FullAttention(False, factor=0.5, attention_dropout=0.1,
                                      output_attention=False), d_model, n_heads).to(device)

    def _sync():
        """Block until queued CUDA kernels finish (no-op on CPU)."""
        if device.type == "cuda":
            torch.cuda.synchronize()

    def _benchmark(layer):
        """Return ``(output, seconds)`` for one forward pass of *layer* on ``x``."""
        with torch.no_grad():
            layer(x, x, x, None)  # warm-up: keep one-time init cost out of the timing
            _sync()
            start = time.perf_counter()
            result = layer(x, x, x, None)
            # CUDA launches are asynchronous; wait before stopping the clock.
            _sync()
            return result, time.perf_counter() - start

    for layer in (ea, sa):
        output, elapsed = _benchmark(layer)
        print("total time: ", elapsed)
    