import numpy as np
import torch
import torch.nn as nn
# from .re_transformer import BertModel, BertTokenizer
from transformers import BertModel, BertTokenizer
import math
from peft import LoraModel, LoraConfig, get_peft_model

def init_weights(self):
    """
    Re-initialise recurrent parameters following the Keras defaults, so results
    stay comparable with the Keras version of the model.
    Reference: https://github.com/vonfeng/DeepMove/blob/master/codes/model.py

    Parameters are selected by substring of their name:
      * 'weight_ih' -> Xavier/Glorot uniform
      * 'weight_hh' -> orthogonal
      * 'bias'      -> constant 0

    :param self: module whose ``named_parameters()`` are re-initialised in place
    """
    def _select(tag):
        # Raw tensors (``.data``) of every parameter whose name contains ``tag``.
        return [weight.data for label, weight in self.named_parameters() if tag in label]

    # The three groups are processed one after the other (not interleaved) so
    # the random generator is consumed in a fixed order.
    for tensor in _select('weight_ih'):
        nn.init.xavier_uniform_(tensor)

    for tensor in _select('weight_hh'):
        nn.init.orthogonal_(tensor)

    for tensor in _select('bias'):
        nn.init.constant_(tensor, 0)

class LearnableFourierPositionalEncoding(nn.Module):
    """
    Learnable Fourier Features from https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1).

    The input is projected by a learnable, bias-free linear map ``Wr``; the cosine
    and sine of that projection are concatenated, scaled by ``1 / sqrt(hidd_dim)``
    and passed through a two-layer MLP with a GELU in between.
    """

    def __init__(self, input_dim, hidd_dim):
        """
        :param input_dim: size of the last axis of the positions fed to ``forward``
        :param hidd_dim: size of the Fourier feature vector, of the MLP hidden layer
                         and of the returned encoding; must be even because it is
                         split into ``hidd_dim // 2`` cosines and as many sines
        :raises ValueError: if ``hidd_dim`` is odd
        """
        super(LearnableFourierPositionalEncoding, self).__init__()
        if hidd_dim % 2 != 0:
            # With an odd size the cos/sin concatenation has hidd_dim - 1 features
            # and the MLP below would fail on every forward call.
            raise ValueError(
                "hidd_dim must be even to hold cos/sin pairs, got {}".format(hidd_dim)
            )
        self.input_dim = input_dim
        self.hidd_dim = hidd_dim

        # Learnable projection; uses nn.Linear's default initialisation.
        self.Wr = nn.Linear(self.input_dim, self.hidd_dim // 2, bias=False)

        self.mlp = nn.Sequential(
            nn.Linear(self.hidd_dim, self.hidd_dim, bias=True),
            nn.GELU(),
            nn.Linear(self.hidd_dim, self.hidd_dim)
        )

    def forward(self, x):
        """
        Produce positional encodings from x.

        :param x: tensor whose last axis has size ``input_dim``; any number of
                  leading axes is accepted (e.g. [batch, time, input_dim])
        :return: tensor with the same leading axes as ``x`` and a last axis of
                 size ``hidd_dim``
        """
        # Step 1. Compute Fourier features (eq. 2)
        projected = self.Wr(x)
        scale = 1 / math.sqrt(self.hidd_dim)
        fourier = scale * torch.cat([torch.cos(projected), torch.sin(projected)], dim=-1)
        # Step 2. Compute projected Fourier features (eq. 6)
        return self.mlp(fourier)

def cal_bert_token(token=None, keep_ratio=None,
                   model_path='/data/WeiTongLong/code/llm/BERT/BERT-small'):
    """
    Look up pretrained BERT word embeddings for a special token or for the task prompt.

    :param token: "MASK" selects vocabulary id 103, "PAD" selects id 0; any other
                  value (including None) selects the tokenised task prompt
    :param keep_ratio: only used for the prompt; one of 0.25, 0.125 or 0.0625,
                       mapped to a sampling interval of 60, 120 or 240 seconds
    :param model_path: folder holding the pretrained BERT files (tokenizer and weights)
    :return: rows of the ``embeddings.word_embeddings.weight`` matrix indexed by the
             token ids: one row for "MASK"/"PAD" (index tensor of shape [1]), or
             one row per prompt token (index tensor of shape [1, num_tokens])
    :raises ValueError: if the prompt is requested with an unsupported ``keep_ratio``
    """
    if token == "MASK":
        # In the BERT vocabulary, [MASK] corresponds to id 103.
        tokens_tensor = torch.tensor([103])
    elif token == "PAD":
        tokens_tensor = torch.tensor([0])
    else:
        # Validate before any loading so a bad argument fails fast and clearly
        # instead of leaving the interval unbound.
        target_sec = None
        for ratio, seconds in ((0.25, 60), (0.125, 120), (0.0625, 240)):
            if keep_ratio == ratio:
                target_sec = seconds
        if target_sec is None:
            raise ValueError(
                "unsupported keep_ratio {!r}; expected 0.25, 0.125 or 0.0625".format(keep_ratio)
            )

        text = "Task: Sparse trajectory recovery. Target: Output the road segment and movement ratio for each point in the trajectory. Content: The sparse trajectory is sampled every {} seconds and aims to recover trajectory every 15 seconds. The sparse trajectory is: ".format(target_sec)

        # The tokenizer is only needed for the prompt, so it is loaded here only.
        tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)
        indexed_tokens = tokenizer.encode(text)
        tokens_tensor = torch.tensor([indexed_tokens])  # encoded text -> tensor

    model = BertModel.from_pretrained(model_path)  # load the pretrained model

    bert_token = model.state_dict()['embeddings.word_embeddings.weight']
    return bert_token[tokens_tensor]

class TemporalPositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional table to a sequence, then apply dropout."""

    def __init__(self, d_model, dropout, max_len=2000, lookup_index=None):
        """
        :param d_model: size of the feature (last) axis of the inputs
        :param dropout: dropout probability applied to the output
        :param max_len: number of positions for which encodings are precomputed
        :param lookup_index: optional indices into the position axis of the table;
                             when given, these positions are added instead of the
                             first ``T`` ones
        """
        super(TemporalPositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)
        self.lookup_index = lookup_index
        self.max_len = max_len

        # Compute the positional encodings once. Column c uses the divisor
        # 10000 ** (2 * c / d_model); even columns hold sines, odd columns cosines.
        # The table is built as plain Python floats and converted in one go,
        # which avoids a tensor write per element and also supports odd d_model.
        divisors = [10000 ** ((2 * col) / d_model) for col in range(d_model)]
        waves = (math.sin, math.cos)
        table = [
            [waves[col % 2](pos / div) for col, div in enumerate(divisors)]
            for pos in range(max_len)
        ]
        pe = torch.tensor(table, dtype=torch.get_default_dtype()).reshape(max_len, d_model)

        pe = pe.unsqueeze(0)  # (1, T_max, d_model)
        # register_buffer adds a persistent buffer to the module: it is saved
        # with the module state but is not a model parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        :param x: (batch_size, T, F_in)
        :return: (batch_size, T, F_out)
        '''
        if self.lookup_index is not None:
            # pe is (1, T_max, d_model): select positions on axis 1.
            x = x + self.pe[:, self.lookup_index, :]
        else:
            x = x + self.pe[:, :x.size(1), :]

        # NOTE(review): detach() stops gradients from flowing to whatever
        # produced x; kept as in the original -- confirm this is intended.
        return self.dropout(x.detach())


class BERT(nn.Module):
    """Pretrained BERT encoder wrapped with LoRA adapters and fed with embeddings directly."""

    def __init__(self):
        """Load the tokenizer and pretrained weights from ``MODEL_PATH`` and attach LoRA."""
        super().__init__()

        # Folder containing the pretrained BERT files.
        self.MODEL_PATH = '/data/WeiTongLong/code/llm/BERT/BERT-small'
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=self.MODEL_PATH)
        # Load the pretrained model.
        self.model = BertModel.from_pretrained(self.MODEL_PATH)

        self.lora_model = self.LoRA_model(self.model)

    def forward(self, x, padding_mask):
        """
        Run the LoRA-wrapped encoder on precomputed input embeddings.

        :param x: passed as ``inputs_embeds`` -- presumably (B, T, F); confirm against caller
        :param padding_mask: passed as ``attention_mask``
        :return: the last entry of the returned ``hidden_states``
        """
        outputs = self.lora_model(
            inputs_embeds=x,
            attention_mask=padding_mask,
            output_hidden_states=True,
        )
        return outputs.hidden_states[-1]

    def LoRA_model(self, model):
        """
        Wrap ``model`` in a ``LoraModel`` registered under the adapter name 'tinybert'.

        :param model: the module to receive LoRA layers
        :return: the ``LoraModel`` wrapping ``model``
        """
        adapter_settings = {
            "task_type": "SEQ_2_SEQ_LM",
            "r": 8,  # Lora attention dimension.
            "lora_alpha": 32,  # The alpha parameter for Lora scaling.
            "target_modules": ["query", "value"],  # The names of the modules to apply Lora to.
            "lora_dropout": 0.01,  # The dropout probability for Lora layers.
        }
        lora_config = LoraConfig(**adapter_settings)
        return LoraModel(model, lora_config, 'tinybert')
# if __name__ == '__main__':
#     bert = BERT()
#     bert()