import torch
from torch import nn
from d2l import torch as d2l

def getPi(mask, another_encryption=False):
    """Build the permutation matrix described by ``mask`` and its inverse.

    Row ``i`` of the returned matrix ``p`` has a single 1 in column
    ``mask[i]``, so ``p @ x`` reorders the rows of ``x`` as ``x[mask]`` and
    ``ip @ (p @ x)`` restores the original order.

    Args:
        mask: 1-D integer tensor holding a permutation of
            ``range(mask.shape[0])``.
        another_encryption: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(p, ip)`` of float32 CPU tensors of shape ``(dim, dim)``:
        the permutation matrix and its inverse.

    Raises:
        ValueError: If ``mask`` does not hit every column exactly once, i.e.
            it is not a permutation and the matrix would not be invertible.
    """
    dim = mask.shape[0]
    # The matrix is always built on the CPU, so keep the index there as well.
    index = torch.as_tensor(mask, dtype=torch.long, device="cpu")

    p = torch.zeros([dim, dim], dtype=torch.float)
    p[torch.arange(dim), index] = 1

    # A valid permutation places exactly one 1 in every column.
    if not bool((p.sum(dim=0) == 1).all()):
        raise ValueError("mask must be a permutation of range(mask.shape[0])")

    # The inverse of a permutation matrix is its transpose: exact and far
    # cheaper than a general matrix inversion. Copy it so that ``ip`` does not
    # share storage with ``p``.
    ip = p.t().clone(memory_format=torch.contiguous_format)
    return p, ip
def getPi_Random(dim=197):
    """Return a random ``dim`` x ``dim`` permutation matrix and its inverse.

    Args:
        dim: Number of rows/columns of the permutation matrix.

    Returns:
        The ``(p, ip)`` pair produced by ``getPi`` for a permutation drawn
        with ``torch.randperm(dim)``.
    """
    permutation = torch.randperm(dim)
    return getPi(permutation, False)
def getPi_M(dim = 197, bs = 32):
    """Return a batch of random permutation matrices and their inverses.

    Args:
        dim: Size of each square permutation matrix.
        bs: Number of independent permutations to draw.

    Returns:
        A tuple ``(p, ip)``. ``p`` is a float32 tensor of shape
        ``(bs, dim, dim)`` whose ``b``-th slice is the identity matrix with
        its rows reordered by the ``b``-th random permutation. ``ip`` is ``p``
        with its last two axes swapped (a view of ``p``), which is the
        per-matrix inverse.
    """
    # Draw the permutations one after another, then gather all the row
    # reorderings of the identity matrix in a single indexing operation.
    permutations = torch.stack(
        [torch.randperm(dim) for _ in range(bs)], dim=0)
    identity = torch.eye(dim, dtype=torch.float)
    p = identity[permutations]
    ip = p.transpose(1, 2)
    return p, ip

def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Assemble the BERT input token list and its segment ids.

    Args:
        tokens_a: List of tokens of the first sequence.
        tokens_b: Optional list of tokens of the second sequence.

    Returns:
        A tuple ``(tokens, segments)``. ``tokens`` is
        ``['<cls>'] + tokens_a + ['<sep>']``, followed by
        ``tokens_b + ['<sep>']`` when ``tokens_b`` is given. ``segments`` has
        the same length and holds 0 for the first part and 1 for the second.
    """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * len(tokens)

    # Single-sequence input: nothing more to append.
    if tokens_b is None:
        return tokens, segments

    tokens = tokens + tokens_b + ['<sep>']
    segments.extend([1] * (len(tokens_b) + 1))
    return tokens, segments

#
# Module-level random permutation matrix P and its inverse IP for sequences of
# length 128. They are moved to the GPU only when one is available, so that
# importing this module does not fail on a CPU-only machine.
P, IP = getPi_Random(128)
if torch.cuda.is_available():
    P, IP = P.to('cuda'), IP.to('cuda')
    
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over linearly projected inputs.

    NOTE(review): despite the class name, ``forward`` performs a single
    attention computation over the full ``num_hiddens`` width. ``num_heads``
    is only stored, the ``d2l.DotProductAttention`` submodule (and therefore
    ``dropout``) is never called, and ``valid_lens`` / ``R`` are ignored.

    Args:
        key_size: Feature size of ``keys``; also sets the score scale
            ``key_size ** -0.5``.
        query_size: Feature size of ``queries``.
        value_size: Feature size of ``values``.
        num_hiddens: Output size of every projection.
        num_heads: Stored on the module; not used by ``forward``.
        dropout: Passed to ``d2l.DotProductAttention``.
        bias: Whether the four linear projections carry a bias term.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        # Kept as a registered submodule; forward() does not call it.
        self.attention = d2l.DotProductAttention(dropout)

        # Normalizes the attention scores over the key axis.
        self.attend = nn.Softmax(dim=-1)
        self.scale = key_size ** -0.5

        # Input projections, then the output projection.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens, R=False):
        """Attend from ``queries`` to ``keys`` / ``values``.

        Args:
            queries: Tensor whose last dimension is ``query_size``.
            keys: Tensor whose last dimension is ``key_size``.
            values: Tensor whose last dimension is ``value_size``.
            valid_lens: Accepted for interface compatibility; unused.
            R: Accepted for interface compatibility; unused.

        Returns:
            The attention output after the ``W_o`` projection; its last
            dimension is ``num_hiddens``.
        """
        projected_q = self.W_q(queries)
        projected_k = self.W_k(keys)
        projected_v = self.W_v(values)

        # Scaled similarity of every query position with every key position.
        scores = (projected_q @ projected_k.transpose(-2, -1)) * self.scale
        weights = self.attend(scores)

        return self.W_o(weights @ projected_v)


class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Args:
        ffn_num_input: Size of the last dimension of the input.
        ffn_num_hiddens: Width of the hidden layer.
        ffn_num_outputs: Size of the last dimension of the output.
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        """Return ``dense2(relu(dense1(X)))``."""
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)


class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Args:
        normalized_shape: Shape passed to ``nn.LayerNorm``.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Return ``LayerNorm(dropout(Y) + X)``.

        Args:
            X: Residual (sublayer input).
            Y: Sublayer output, same shape as ``X``.
        """
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention and a feed-forward network.

    Each sublayer is wrapped in an ``AddNorm`` (residual + layer norm).

    Args:
        key_size, query_size, value_size: Feature sizes forwarded to
            ``MultiHeadAttention``.
        num_hiddens: Output size of the attention layer and of the
            feed-forward network.
        norm_shape: Shape normalized by both ``AddNorm`` layers.
        ffn_num_input: Input size of the feed-forward network.
        ffn_num_hiddens: Hidden size of the feed-forward network.
        num_heads: Forwarded to ``MultiHeadAttention``.
        dropout: Dropout probability for the attention and ``AddNorm`` layers.
        use_bias: Whether the attention projections carry a bias term.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads,
            dropout, bias=use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens, R = False):
        """Apply self-attention and then the feed-forward network to ``X``.

        Args:
            X: Input tensor; used as queries, keys and values.
            valid_lens: Forwarded to the attention layer.
            R: Forwarded to the attention layer.

        Returns:
            The block output after the second ``AddNorm``.
        """
        attended = self.attention(X, X, X, valid_lens, R)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)
        
class BERTEncoder(nn.Module):
    """BERT encoder with an optional random permutation of the sequence axis.

    Token, segment and learnable positional embeddings are summed and passed
    through ``num_layers`` ``EncoderBlock`` layers. With ``R=True`` the
    positions are shuffled by a fresh random permutation matrix before the
    blocks and un-shuffled by its inverse afterwards.

    Args:
        vocab_size: Number of rows of the token embedding table.
        num_hiddens: Embedding size and encoder width.
        norm_shape: Shape normalized inside every ``EncoderBlock``.
        ffn_num_input: Input size of each block's feed-forward network.
        ffn_num_hiddens: Hidden size of each block's feed-forward network.
        num_heads: Forwarded to every ``EncoderBlock``.
        num_layers: Number of encoder blocks.
        dropout: Dropout probability forwarded to every ``EncoderBlock``.
        max_len: Number of positions in the positional embedding.
        key_size, query_size, value_size: Forwarded to every ``EncoderBlock``.
    """

    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}", EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))

        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))
        # Initial shuffle key for length-128 sequences. forward() replaces
        # self.p / self.ip whenever it is called with R=True.
        self.m = torch.randperm(128)
        self.p, self.ip = getPi(self.m, False)
        # Use the GPU only when one exists so construction works on CPU too.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.p, self.ip = self.p.to(device), self.ip.to(device)

    def forward(self, tokens, segments, valid_lens, R = False):
        """Encode a batch of token / segment id sequences.

        Args:
            tokens: Integer tensor of token ids; positions are on dim 1.
            segments: Integer tensor of segment ids (0 or 1), same shape as
                ``tokens``.
            valid_lens: Forwarded unchanged to every encoder block.
            R: When true, shuffle the positions with a freshly drawn random
                permutation before the blocks and restore the order after.

        Returns:
            The encoded sequence, with the same shape as the summed embeddings.
        """
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        # Index the parameter itself (not ``.data``) so that the positional
        # embedding takes part in autograd and is actually trained.
        X = X + self.pos_embedding[:, :X.shape[1], :]

        if R:
            # Draw the permutation only when it is used, sized to the actual
            # sequence length and placed on the same device / dtype as X.
            p, ip = getPi_Random(X.shape[1])
            self.p = p.to(device=X.device, dtype=X.dtype)
            self.ip = ip.to(device=X.device, dtype=X.dtype)
            X = torch.matmul(self.p, X)

        for blk in self.blks:
            X = blk(X, valid_lens, R = R)

        if R:
            # Undo the shuffle so that outputs line up with input positions.
            X = torch.matmul(self.ip, X)
        return X
        
        

class MaskLM(nn.Module):
    """Masked-language-model head: predicts tokens at given positions.

    Args:
        vocab_size: Number of output classes per predicted position.
        num_hiddens: Hidden width of the prediction MLP.
        num_inputs: Feature size of the encoder output fed to the MLP.
    """

    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.Linear(num_hiddens, vocab_size))

    def forward(self, X, pred_positions):
        """Score the vocabulary at every requested position.

        Args:
            X: Encoder output of shape ``(batch, seq_len, num_inputs)``.
            pred_positions: Integer tensor of shape
                ``(batch, num_pred_positions)`` with indices into dim 1 of
                ``X``.

        Returns:
            Tensor of shape ``(batch, num_pred_positions, vocab_size)``.
        """
        batch_size = X.shape[0]
        num_pred_positions = pred_positions.shape[1]
        positions = pred_positions.reshape(batch_size, num_pred_positions)

        # A (batch, 1) column of batch indices broadcasts against the
        # (batch, num_pred) positions, gathering one feature vector per
        # prediction. It is created on X's device to avoid a host-to-device
        # copy, and needs no ``-1`` reshape, so empty batches work too.
        batch_idx = torch.arange(batch_size, device=X.device).unsqueeze(1)
        masked_X = X[batch_idx, positions]

        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat
        

class NextSentencePred(nn.Module):
    """Next-sentence-prediction head: a single two-way linear classifier.

    Args:
        num_inputs: Feature size of the input representation.
    """

    def __init__(self, num_inputs, **kwargs):
        super().__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        """Return the two class logits for every row of ``X``."""
        logits = self.output(X)
        return logits
        

class BERTModel(nn.Module):
    """BERT encoder with masked-LM and next-sentence-prediction heads.

    Args:
        vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
            num_heads, num_layers, dropout, max_len, key_size, query_size,
            value_size: Forwarded to ``BERTEncoder``. ``vocab_size`` and
            ``num_hiddens`` also size the masked-LM head.
        hid_in_features: Input size of the hidden layer that precedes the
            next-sentence classifier.
        mlm_in_features: Input size of the masked-LM head.
        nsp_in_features: Input size of the next-sentence classifier.
    """

    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super().__init__()
        self.encoder = BERTEncoder(
            vocab_size, num_hiddens, norm_shape, ffn_num_input,
            ffn_num_hiddens, num_heads, num_layers, dropout,
            max_len=max_len, key_size=key_size, query_size=query_size,
            value_size=value_size)
        self.hidden = nn.Sequential(
            nn.Linear(hid_in_features, num_hiddens),
            nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None,
                pred_positions=None):
        """Run the encoder and both pre-training heads.

        Args:
            tokens: Token ids passed to the encoder.
            segments: Segment ids passed to the encoder.
            valid_lens: Passed through to the encoder.
            pred_positions: Positions to predict with the masked-LM head, or
                ``None`` to skip that head.

        Returns:
            A tuple ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``; ``mlm_Y_hat`` is
            ``None`` when ``pred_positions`` is ``None``.
        """
        encoded_X = self.encoder(tokens, segments, valid_lens)

        mlm_Y_hat = None
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)

        # The representation at position 0 feeds the next-sentence classifier.
        first_position = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(first_position))
        return encoded_X, mlm_Y_hat, nsp_Y_hat