import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
import math
import torch.nn.functional as F
from enum import IntEnum
import numpy as np
from .que_base_model import QueBaseModel,QueEmb
from pykt.utils import debug_print

import os 
import json 

class Dim(IntEnum):
    """Named axis positions, usable wherever an int dimension index is expected
    (e.g. ``x.size(Dim.seq)``)."""

    batch = 0    # axis 0
    seq = 1      # axis 1
    feature = 2  # axis 2

class QueEmbedder(nn.Module):
    """Question-embedding layer.

    This module decides 1) whether to read embeddings from a JSON file or to
    initialize them randomly, 2) whether to freeze or train them, and
    3) whether to apply a dimension change (one linear layer) when the loaded
    embedding size differs from ``emb_size``.
    """

    def __init__(self, num_q, emb_size, emb_path, flag_load_emb, flag_emb_freezed, model_name):
        """
        Input:
            num_q: number of questions
            emb_size: size of the returned embeddings (loaded embeddings of a different
                size are projected to emb_size by a linear layer)
            emb_path: path of a JSON file to read embeddings from
            flag_load_emb: if True (and emb_path is non-empty), embeddings are loaded from emb_path
            flag_emb_freezed: if True, the embedding table is fixed, i.e. won't be trained
            model_name: name of the calling model (used as a tag in debug output)

        Raises:
            ValueError: if the loaded file does not contain exactly num_q embeddings.
        """
        super().__init__()
        self.num_q = num_q
        self.emb_size = emb_size
        self.emb_path = emb_path
        self.flag_load_emb = flag_load_emb
        self.flag_emb_freezed = flag_emb_freezed
        self.model_name = model_name

        # init_embedding_layer() overwrites this when embeddings are loaded from
        # file; a mismatch with emb_size signals the need for a linear projection.
        self.loaded_emb_size = emb_size

        # Initialize embedding layer
        self.init_embedding_layer()

        if self.emb_size != self.loaded_emb_size:
            debug_print(f"Loaded embeddings' size is different than provided emb size. Linear layer will be applied.",fuc_name=self.model_name)
            self.projection_layer = nn.Linear(self.loaded_emb_size, self.emb_size)

        # For debug, count number of trainable params
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        debug_print(f"Que Embedder num trainable parameters: {num_params}",fuc_name=self.model_name)

    @staticmethod
    def _normalize_to_sqrt_dim(emb, eps=1e-12):
        """Rescale every row of a 2-D tensor to L2 norm sqrt(emb.shape[1]).

        Row norms are clamped to ``eps`` so an all-zero row stays zero instead
        of turning into NaNs.
        """
        norms = emb.norm(p=2, dim=1, keepdim=True).clamp_min(eps)
        return emb / norms * math.sqrt(emb.shape[1])

    def init_embedding_layer(self):
        """Build ``self.que_emb``, either randomly initialized or loaded from ``self.emb_path``.

        When loading, ``self.loaded_emb_size`` is set to the width found in the
        file and every row is rescaled to L2 norm sqrt(loaded_emb_size).
        """
        if self.emb_path == '' or not self.flag_load_emb:
            if self.flag_emb_freezed:
                debug_print(f"Embeddings are randomly initialized and freezed",fuc_name=self.model_name)
            else:
                debug_print(f"Embeddings are randomly initialized and not freezed",fuc_name=self.model_name)
            self.que_emb = nn.Embedding(self.num_q, self.emb_size)
            # Freezing only disables gradients; the random initialization is the same.
            self.que_emb.weight.requires_grad_(not self.flag_emb_freezed)
            return

        # Past the early return, emb_path is non-empty and flag_load_emb is set,
        # so loading from file is the only remaining case.
        with open(self.emb_path, 'r', encoding='utf-8') as f:
            precomputed_embeddings = json.load(f)
        # Keys are question ids as strings "0" .. str(N - 1).
        precomputed_embeddings_tensor = torch.tensor(
            [precomputed_embeddings[str(i)] for i in range(len(precomputed_embeddings))],
            dtype=torch.float)

        # IMPORTANT: the embedding width is taken from the file, not from emb_size.
        num_q_precomputed, self.loaded_emb_size = precomputed_embeddings_tensor.shape  # (Num questions x emb size)
        if num_q_precomputed != self.num_q:
            raise ValueError(
                f"{self.emb_path} contains {num_q_precomputed} embeddings, expected num_q={self.num_q}")

        # For debug
        orig_norm = precomputed_embeddings_tensor[0].norm()
        debug_print(f"The original norm of the embeddings provided is {orig_norm} .",fuc_name=self.model_name)

        # Normalize the lengths to 1, then scale to sqrt(embedding width).
        precomputed_embeddings_tensor = self._normalize_to_sqrt_dim(precomputed_embeddings_tensor)

        # For debug
        cur_norm = precomputed_embeddings_tensor[0].norm()
        debug_print(f"The norm of the embeddings are now scaled to {cur_norm} .",fuc_name=self.model_name)

        if self.flag_emb_freezed:
            debug_print(f"Embeddings are loaded from path and freezed",fuc_name=self.model_name)
        else:
            debug_print(f"Embeddings are loaded from path and not freezed",fuc_name=self.model_name)
        self.que_emb = nn.Embedding.from_pretrained(
            precomputed_embeddings_tensor, freeze=bool(self.flag_emb_freezed))

    def forward(self, q):
        """Look up embeddings for question ids ``q`` (projected to emb_size if needed)."""
        x = self.que_emb(q)
        if self.emb_size != self.loaded_emb_size:
            x = self.projection_layer(x)
        return x


class AKTQueNet(nn.Module):
    """AKT network operating on question ids with a configurable question embedder.

    Question embeddings come from ``QueEmbedder`` (random or loaded from
    ``emb_path``). Interaction embeddings are built from them and the
    responses. Both go through the AKT ``Architecture`` and a small MLP that
    outputs one sigmoid probability per position.
    """

    def __init__(self, num_q, num_c, emb_size, n_blocks, dropout, d_ff=256, 
            kq_same=1, final_fc_dim=512, num_attn_heads=8, separate_qa=False, l2=1e-5, emb_type="qid", emb_path="", flag_load_emb=False, flag_emb_freezed=False, pretrain_dim=768):
        """
        Input:
            num_q: number of questions
            num_c: number of concepts (stored)
            emb_size: dimension of the embeddings and of the attention blocks (d_model)
            n_blocks: number of encoder blocks (the retriever stack uses 2 * n_blocks)
            dropout: dropout rate
            d_ff: dimension for the fully connected net inside the basic block
            kq_same: if key query same, kq_same=1, else = 0
            final_fc_dim: dimension of the final fully connected net before prediction
            num_attn_heads: number of heads in multi-headed attention
            separate_qa: if True, interactions get their own embedding indexed by
                (question, response); otherwise a response embedding is added to the
                question embedding
            l2: weight of the L2 penalty on difficulty parameters (only used by the
                disabled code in forward)
            emb_type: stored as-is
            emb_path, flag_load_emb, flag_emb_freezed: forwarded to QueEmbedder
            pretrain_dim: accepted for interface compatibility; unused here
        """
        super().__init__()
        self.model_name = "akt_que"
        self.num_c = num_c
        self.dropout = dropout
        self.kq_same = kq_same
        self.num_q = num_q
        self.l2 = l2
        self.model_type = self.model_name
        self.separate_qa = separate_qa
        self.emb_type = emb_type
        self.emb_size = emb_size

        # Question embeddings: random or loaded, optionally frozen and projected.
        self.que_emb = QueEmbedder(num_q, emb_size, emb_path, flag_load_emb, flag_emb_freezed, self.model_name)

        if self.num_q > 0:
            # NOTE: the difficulty parameter below is just a random embedding value;
            # it could be better to initialize it with more informed values.
            self.difficult_param = nn.Embedding(self.num_q, 1)  # question difficulty
            # question emb summarizing the variation among problems (questions) that contain the current question (concept)
            self.q_embed_diff = nn.Embedding(self.num_q, self.emb_size)
            self.qa_embed_diff = nn.Embedding(2 * self.num_q, self.emb_size)  # interaction emb, same as above

        if self.separate_qa:
            # base_emb indexes this table with q + num_q * r; assuming r is 0/1 the
            # indices lie in [0, 2 * num_q), so the table needs 2 * num_q rows.
            self.qa_embed = nn.Embedding(2 * self.num_q, self.emb_size)  # interaction emb
        else:  # default: one embedding per response value, added to the question emb
            self.qa_embed = nn.Embedding(2, self.emb_size)

        # Architecture object: contains the stacks of attention blocks.
        # (Architecture derives the per-head size itself as d_model // n_heads.)
        self.model = Architecture(num_q=num_q, n_blocks=n_blocks, n_heads=num_attn_heads, dropout=dropout,
                                    d_model=self.emb_size, d_feature=self.emb_size // num_attn_heads, d_ff=d_ff,  kq_same=self.kq_same, model_type=self.model_type)

        self.out = nn.Sequential(
            nn.Linear(self.emb_size + self.emb_size,
                      final_fc_dim), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim, 256), nn.ReLU(
            ), nn.Dropout(self.dropout),
            nn.Linear(256, 1)
        )
        self.reset()

    def reset(self):
        """Zero-initialize the per-question difficulty scalars.

        The parameter is selected explicitly rather than by shape: a test on
        ``p.size(0)`` could also hit unrelated weights whose first dimension
        happens to coincide (e.g. the 256-unit Linear in ``self.out``).
        """
        if self.num_q > 0:
            torch.nn.init.constant_(self.difficult_param.weight, 0.)

    def base_emb(self, q, c, r):
        """Return (question embeddings, interaction embeddings) for ids q and responses r; c is unused."""
        q_embed_data = self.que_emb(q)  # BS, seqlen, d_model  # c_ct
        if self.separate_qa:
            qa_data = q + self.num_q * r
            qa_embed_data = self.qa_embed(qa_data)
        else:
            # BS, seqlen, d_model  # c_ct + g_rt = e_(ct,rt)
            qa_embed_data = self.qa_embed(r) + q_embed_data
        return q_embed_data, qa_embed_data

    def forward(self, q, c, r):
        """
        Input:
            q: question ids; r: responses; c: unused
        Output:
            preds: sigmoid probabilities from the output MLP, last dim squeezed
            c_reg_loss: 0. (the difficulty regularizer is disabled below)
        """
        # Batch first
        q_embed_data, qa_embed_data = self.base_emb(q, c, r)

        # The difficulty (Rasch) terms of the original AKT are disabled here.
        # They are useful when the model works on KC-level data yet uses
        # question-level difficulty differences (between questions having the same KCs).
        # if self.num_q > 0: # have problem id
        #     q_embed_diff_data = self.q_embed_diff(q)  # d_ct: variation among problems containing the current concept
        #     pid_embed_data = self.difficult_param(q)  # u_q: difficulty of the current problem
        #     q_embed_data = q_embed_data + pid_embed_data * \
        #         q_embed_diff_data  # u_q * d_ct + c_ct  # question encoder
        #
        #     qa_embed_diff_data = self.qa_embed_diff(
        #         r)  # f_(ct,rt) or h_rt (qt, rt) difference vector
        #     if self.separate_qa:
        #         qa_embed_data = qa_embed_data + pid_embed_data * \
        #             qa_embed_diff_data  # u_q * f_(ct,rt) + e_(ct,rt)
        #     else:
        #         qa_embed_data = qa_embed_data + pid_embed_data * \
        #             (qa_embed_diff_data+q_embed_diff_data)  # + u_q * (h_rt + d_ct)  # (q-response emb diff + question emb diff)
        #     c_reg_loss = (pid_embed_data ** 2.).sum() * self.l2  # Rasch-part loss
        # else:
        #     c_reg_loss = 0.

        # REMOVE below if you unblock the commented code above
        c_reg_loss = 0.

        # Pass to the decoder; output shape BS, seqlen, d_model
        d_output = self.model(q_embed_data, qa_embed_data)

        concat_q = torch.cat([d_output, q_embed_data], dim=-1)
        output = self.out(concat_q).squeeze(-1)
        preds = torch.sigmoid(output)

        return preds, c_reg_loss
       



class AKTQue(QueBaseModel):
    """Wrapper that plugs ``AKTQueNet`` into the ``QueBaseModel`` training interface.

    Any extra keyword arguments (``**kwargs``) are accepted and ignored.
    """

    def __init__(self, num_q, num_c, emb_size, n_blocks=1, dropout=0.1, emb_type='qid',
                 kq_same=1, final_fc_dim=512, num_attn_heads=8, separate_qa=False, l2=1e-5,
                 d_ff=256, emb_path="", flag_load_emb=False, flag_emb_freezed=False,
                 pretrain_dim=768, device='cpu', seed=0, **kwargs):
        model_name = "akt_que"
        super().__init__(model_name=model_name, emb_type=emb_type, emb_path=emb_path,
                         pretrain_dim=pretrain_dim, device=device, seed=seed)
        net_kwargs = dict(
            num_q=num_q, num_c=num_c, emb_size=emb_size, n_blocks=n_blocks,
            dropout=dropout, d_ff=d_ff, kq_same=kq_same, final_fc_dim=final_fc_dim,
            num_attn_heads=num_attn_heads, separate_qa=separate_qa, l2=l2,
            emb_type=emb_type, emb_path=emb_path, flag_load_emb=flag_load_emb,
            flag_emb_freezed=flag_emb_freezed, pretrain_dim=pretrain_dim,
        )
        self.model = AKTQueNet(**net_kwargs)
        self.model = self.model.to(device)
        # Mirror the network's embedding type on the wrapper.
        self.emb_type = self.model.emb_type
        self.loss_func = self._get_loss_func("binary_crossentropy")

    def train_one_step(self, data, process=True, weighted_loss=0):
        """Forward one batch; return (predictions, prediction loss + regularization loss)."""
        y, reg_loss, data_new = self.predict_one_step(data, return_details=True, process=process)
        base_loss = self.get_loss(y, data_new['rshft'], data_new['sm'], weighted_loss=weighted_loss)
        return y, base_loss + reg_loss

    def predict_one_step(self, data, return_details=False, process=True):
        """Forward one batch.

        Returns the predictions with position 0 of dim 1 dropped, or
        ``(predictions, reg_loss, batch_dict)`` when ``return_details`` is set.
        """
        data_new = self.batch_to_device(data, process=process)
        q_ids, c_ids, responses = (data_new[key].long() for key in ('cq', 'cc', 'cr'))
        y, reg_loss = self.model(q_ids, c_ids, responses)
        # Presumably aligns the predictions with the shifted targets ('rshft') used in train_one_step.
        y = y[:, 1:]
        return (y, reg_loss, data_new) if return_details else y

class Architecture(nn.Module):
    """AKT attention stacks: an interaction encoder (blocks_1) and a knowledge retriever (blocks_2)."""

    SUPPORTED_MODEL_TYPES = frozenset({'akt', 'akt_que'})

    def __init__(self, num_q,  n_blocks, d_model, d_feature,
                 d_ff, n_heads, dropout, kq_same, model_type):
        """
            num_q : unused (kept for interface compatibility)
            n_blocks : number of stacked blocks in the encoder; the retriever has 2 * n_blocks
            d_model : dimension of attention input/output
            d_feature : ignored; each head uses d_model // n_heads (n_heads * d_feature = d_model)
            d_ff, n_heads, dropout, kq_same : forwarded to every TransformerLayer
            model_type : must be 'akt' or 'akt_que'

        Raises:
            ValueError: for any other model_type, so the problem surfaces at
                construction instead of as a missing attribute in forward.
        """
        super().__init__()
        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected one of {sorted(self.SUPPORTED_MODEL_TYPES)}")
        self.d_model = d_model
        self.model_type = model_type

        def make_layer():
            # All layers share the same configuration.
            return TransformerLayer(d_model=d_model, d_feature=d_model // n_heads,
                                    d_ff=d_ff, dropout=dropout, n_heads=n_heads, kq_same=kq_same)

        self.blocks_1 = nn.ModuleList([make_layer() for _ in range(n_blocks)])
        self.blocks_2 = nn.ModuleList([make_layer() for _ in range(n_blocks * 2)])

    def forward(self, q_embed_data, qa_embed_data):
        """Encode interactions, then alternate question self-attention and retrieval; return the question stream."""
        y = qa_embed_data
        x = q_embed_data

        # Encoder: self-attention over the interactions (mask=1: current and past).
        for block in self.blocks_1:
            y = block(mask=1, query=y, key=y, values=y)  # y_t^

        # Retriever: even-indexed layers peek at the current question,
        # odd-indexed layers must not peek at the current response.
        for idx, block in enumerate(self.blocks_2):
            if idx % 2 == 0:
                # No FFN: plain self-attention over questions, corresponding to x_t^.
                x = block(mask=1, query=x, key=x, values=x, apply_pos=False)
            else:
                # FFN + residual + layer norm; attention from questions to past
                # interactions (the Knowledge Retriever). With mask=0 the current
                # response is hidden and the first row is zero-padded, so the first
                # question carries only question information, no interaction information.
                x = block(mask=0, query=x, key=x, values=y, apply_pos=True)
        return x

class TransformerLayer(nn.Module):
    """Basic block of the Transformer paper.

    Contains one multi-head attention object, followed by layer norm, an
    optional position-wise feed-forward net, and dropout layers.
    """

    def __init__(self, d_model, d_feature,
                 d_ff, n_heads, dropout,  kq_same):
        super().__init__()
        kq_same = kq_same == 1
        # Multi-Head Attention Block
        self.masked_attn_head = MultiHeadAttention(
            d_model, d_feature, n_heads, dropout, kq_same=kq_same)

        # Two layer norm layers and two dropout layers
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, mask, query, key, values, apply_pos=True):
        """
        Input:
            mask : 0 means the block can peek only at past values; 1 means it can
                peek at current and past values
            query, key, values : attention inputs; dim 1 is used as the sequence length
            apply_pos : if True, apply the position-wise feed-forward sub-layer

        Output:
            query: the input updated by the layer.
        """
        seqlen = query.size(1)
        # Build the mask directly on the input's device, without a NumPy allocation
        # and a host-to-device copy on every call. True marks positions that may be
        # attended: key j < query i when mask == 0, j <= i when mask == 1.
        nopeek_mask = torch.triu(
            torch.ones(seqlen, seqlen, device=query.device), diagonal=mask)
        src_mask = (nopeek_mask == 0)[None, None, :, :]  # 1, 1, seqlen, seqlen
        # With mask == 0 the first position sees nothing (not even itself), so its
        # attention row is zero-padded: the first question gets no interaction info.
        query2 = self.masked_attn_head(
            query, key, values, mask=src_mask, zero_pad=(mask == 0))

        query = query + self.dropout1(query2)  # residual 1
        query = self.layer_norm1(query)  # layer norm
        if apply_pos:
            query2 = self.linear2(self.dropout(  # FFN
                self.activation(self.linear1(query))))
            query = query + self.dropout2(query2)  # residual 2
            query = self.layer_norm2(query)  # layer norm
        return query


class MultiHeadAttention(nn.Module):
    """Multi-head attention with AKT's distance decay.

    Has projection layers for keys, queries and values, followed by ``attention``
    and an output linear layer.
    """

    def __init__(self, d_model, d_feature, n_heads, dropout, kq_same, bias=True):
        """
        Input:
            d_model: input/output dimension
            d_feature: per-head dimension (expected n_heads * d_feature == d_model,
                since forward splits d_model into n_heads chunks of d_feature)
            n_heads: number of heads
            dropout: dropout rate applied to the attention weights
            kq_same: if truthy, queries reuse the key projection (no q_linear is created)
            bias: whether the linear projections have a bias
        """
        super().__init__()
        self.d_model = d_model
        self.d_k = d_feature
        self.h = n_heads
        self.kq_same = kq_same

        self.v_linear = nn.Linear(d_model, d_model, bias=bias)
        self.k_linear = nn.Linear(d_model, d_model, bias=bias)
        # Truthiness test (not `is False`) so an integer 0 also gets its own query projection.
        if not kq_same:
            self.q_linear = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.proj_bias = bias
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        # One decay parameter per head.
        self.gammas = nn.Parameter(torch.zeros(n_heads, 1, 1))
        torch.nn.init.xavier_uniform_(self.gammas)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init the input projections and zero their biases (and the output bias)."""
        xavier_uniform_(self.k_linear.weight)
        xavier_uniform_(self.v_linear.weight)
        if not self.kq_same:
            xavier_uniform_(self.q_linear.weight)

        if self.proj_bias:
            constant_(self.k_linear.bias, 0.)
            constant_(self.v_linear.bias, 0.)
            if not self.kq_same:
                constant_(self.q_linear.bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self, q, k, v, mask, zero_pad):
        """Project, split into heads, attend, merge heads and apply the output projection."""
        bs = q.size(0)

        # Perform the linear projections and split into h heads.
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        if not self.kq_same:
            q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        else:
            q = self.k_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # Transpose to bs * h * seqlen * d_k.
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        gammas = self.gammas
        scores = attention(q, k, v, self.d_k,
                           mask, self.dropout, zero_pad, gammas)

        # Concatenate heads and put through the final linear layer.
        concat = scores.transpose(1, 2).contiguous()\
            .view(bs, -1, self.d_model)

        output = self.out_proj(concat)

        return output


def attention(q, k, v, d_k, mask, dropout, zero_pad, gamma=None):
    """
    Scaled dot-product attention with AKT's distance-based decay.

    This is called by the Multi-head attention object to compute the attended values.

    Input:
        q, k, v: per-head queries/keys/values; scores are q @ k^T over the last two
            dims (MultiHeadAttention passes them as bs, heads, seqlen, d_k)
        d_k: per-head dimension, used for the 1/sqrt(d_k) scaling
        mask: boolean tensor broadcastable to the score matrix; True marks positions
            that may be attended
        dropout: module applied to the attention weights
        zero_pad: if True, the attention weights of the first query row are set to zero
        gamma: per-head decay parameter (heads, 1, 1); passed through Softplus and
            negated. Required in practice: Softplus cannot take the None default.
    Output:
        attention weights @ v
    """
    # d_k: dim of each head
    scores = torch.matmul(q, k.transpose(-2, -1)) / \
        math.sqrt(d_k)  # BS, heads, seqlen, seqlen
    bs, head, seqlen = scores.size(0), scores.size(1), scores.size(2)

    # Index grids built directly on q's device (no host round trip).
    x1 = torch.arange(seqlen, device=q.device).expand(seqlen, -1)
    x2 = x1.transpose(0, 1).contiguous()

    with torch.no_grad():
        scores_ = scores.masked_fill(mask == 0, -1e32)
        scores_ = F.softmax(scores_, dim=-1)  # BS, heads, seqlen, seqlen
        scores_ = scores_ * mask.float().to(q.device)  # zero the masked weights exactly
        distcum_scores = torch.cumsum(scores_, dim=-1)  # bs, heads, sl, sl
        # ~1 for rows with any unmasked position, 0 for fully masked rows
        disttotal_scores = torch.sum(
            scores_, dim=-1, keepdim=True)  # bs, heads, sl, 1
        # |i - j| position difference, float32 on q's device; 1, 1, seqlen, seqlen
        position_effect = torch.abs(x1 - x2)[None, None, :, :].float()
        # bs, heads, sl, sl positive distance; negative values are set to 0
        dist_scores = torch.clamp(
            (disttotal_scores - distcum_scores) * position_effect, min=0.)
        dist_scores = dist_scores.sqrt().detach()
    m = nn.Softplus()
    gamma = -1. * m(gamma).unsqueeze(0)  # 1, heads, 1, 1: one gamma per head (theta in the paper)
    # exp(gamma * distance), clamped to [1e-5, 1e5]; the added term of Eq. 1 in the paper
    total_effect = torch.clamp((dist_scores * gamma).exp(), min=1e-5, max=1e5)
    scores = scores * total_effect

    scores.masked_fill_(mask == 0, -1e32)
    scores = F.softmax(scores, dim=-1)  # BS, heads, seqlen, seqlen
    if zero_pad:
        pad_zero = torch.zeros(bs, head, 1, seqlen, device=q.device)
        scores = torch.cat([pad_zero, scores[:, :, 1:, :]], dim=2)  # zero the first row of scores
    scores = dropout(scores)
    output = torch.matmul(scores, v)
    return output


class LearnablePositionalEmbedding(nn.Module):
    """Trainable positional table of shape (1, max_len, d_model).

    Entries start as 0.1 * N(0, 1) samples and are learned. ``forward`` returns
    the first ``x.size(Dim.seq)`` positions.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        table = torch.randn(max_len, d_model).mul(0.1).unsqueeze(0)
        self.weight = nn.Parameter(table, requires_grad=True)

    def forward(self, x):
        seq_len = x.size(Dim.seq)
        return self.weight[:, :seq_len, :]  # (1, seq, feature)


class CosinePositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding, stored as a frozen parameter of shape (1, max_len, d_model).

    Even columns hold sin(pos * w_i), odd columns cos(pos * w_i), with
    w_i = 10000^(-2i / d_model). Works for odd d_model as well.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        # Every entry is overwritten below; the random draw is kept so the global
        # RNG stream consumed at construction stays unchanged.
        pe = 0.1 * torch.randn(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # Frequencies computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.weight = nn.Parameter(pe, requires_grad=False)

    def forward(self, x):
        """Return the encodings for the first x.size(Dim.seq) positions."""
        return self.weight[:, :x.size(Dim.seq), :]  # ( 1,seq,  Feature)



