import torch
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from einops import repeat
from .layer import LearnableFourierPositionalEncoding, cal_bert_token, TemporalPositionalEncoding, BERT
from peft import LoraModel, LoraConfig, get_peft_model
from transformers import BertModel, BertTokenizer
import math

def LoRA_model(model, *, r=8, lora_alpha=32, target_modules=None, lora_dropout=0.01, task_type="SEQ_2_SEQ_LM"):
    """Wrap ``model`` with LoRA adapters using ``peft.get_peft_model``.

    All hyperparameters are keyword-only; their defaults are the values this
    function previously hard-coded, so ``LoRA_model(model)`` is unchanged.

    Args:
        model: the base model to adapt.
        r: LoRA attention dimension (rank of the update matrices).
        lora_alpha: LoRA scaling parameter.
        target_modules: names of the sub-modules to apply LoRA to;
            defaults to ``["query", "value"]``.
        lora_dropout: dropout probability inside the LoRA layers.
        task_type: peft task type string. NOTE(review): for an encoder-only
            model such as BERT, peft's "FEATURE_EXTRACTION" may be the better
            fit — kept as "SEQ_2_SEQ_LM" for backward compatibility; confirm.

    Returns:
        Whatever ``get_peft_model`` returns for ``model`` and the config.
    """
    # None sentinel avoids a shared mutable default list.
    if target_modules is None:
        target_modules = ["query", "value"]
    lora_config = LoraConfig(
        task_type=task_type,
        r=r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
    )
    return get_peft_model(model, lora_config)


class Decoder(nn.Module):
    """Prediction head turning encoder states into road-ID log-probs and moving rates.

    Args:
        id_size: number of road segments; the ID head emits ``id_size + 1`` scores.
        hidd_dim: width of the incoming hidden states.
    """

    def __init__(self, id_size, hidd_dim):
        super(Decoder, self).__init__()
        self.hidd_dim = hidd_dim
        self.id_size = id_size

        # Two-layer MLP classifier over road-segment IDs.
        self.road_fc = nn.Sequential(
            nn.Linear(hidd_dim, hidd_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidd_dim, id_size + 1),
        )
        # Single linear projection for the moving rate (squashed in forward).
        self.rate_fc = nn.Sequential(nn.Linear(hidd_dim, 1))

    def forward(self, hidden, prompt_length):
        """Decode every non-prompt position of ``hidden``.

        Args:
            hidden: encoder output; the first ``prompt_length`` positions along
                dim 1 belong to the prompt and are discarded.
            prompt_length: number of leading prompt positions to skip.

        Returns:
            ``(road_ID, road_rate)``: log-softmax scores over ``id_size + 1``
            classes, and a sigmoid rate in (0, 1) with a trailing dim of 1.
        """
        traj_hidden = hidden[:, prompt_length:]
        id_scores = self.road_fc(traj_hidden)
        rate_scores = self.rate_fc(traj_hidden)
        return F.log_softmax(id_scores, dim=-1), torch.sigmoid(rate_scores)

class Traj_conv(nn.Module):
    """Length-preserving 1-D convolution along the time axis of a (B, T, F) sequence.

    Args:
        hidd_dim: feature width F (used for both input and output channels).
        kernel_size: temporal kernel size; odd and even sizes both keep T.
    """

    def __init__(self, hidd_dim, kernel_size):
        super(Traj_conv, self).__init__()
        self.hidd_dim = hidd_dim
        self.kernel_size = kernel_size

        # Symmetric per-side padding for odd kernels; kept as an attribute for
        # backward compatibility with code that reads it.
        self.padding = (kernel_size - 1) // 2

        # padding='same' equals symmetric (k-1)//2 padding for odd kernels, and
        # for even kernels pads asymmetrically so the output still has T steps
        # (explicit (k-1)//2 padding would yield T-1 and break T-length masks).
        self.conv = nn.Conv1d(
            in_channels=self.hidd_dim,
            out_channels=self.hidd_dim,
            kernel_size=self.kernel_size,
            padding='same',
        )

    def forward(self, x):
        """Convolve ``x`` of shape (B, T, F) over T and return shape (B, T, F)."""
        # Conv1d expects channels first: (B, F, T).
        return self.conv(x.transpose(1, 2)).transpose(1, 2)

class Prompt_get(nn.Module):
    """Tokenize the fixed text prompt describing the trajectory-recovery task.

    Args:
        model_path: directory holding the BERT tokenizer files; defaults to the
            previously hard-coded location.
    """

    # Default tokenizer directory (the folder holding the BERT-small files).
    DEFAULT_MODEL_PATH = '/data/WeiTongLong/code/llm/BERT/BERT-small'

    # Supported keep ratios -> target interval in seconds inserted into the prompt.
    _TARGET_SECONDS = {0.25: 60, 0.125: 120, 0.0625: 240}

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        super(Prompt_get, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)

    @classmethod
    def _target_seconds(cls, keep_ratio):
        """Map ``keep_ratio`` to its target interval in seconds.

        Raises:
            ValueError: if ``keep_ratio`` is not one of the supported ratios
                (previously this surfaced as an ``UnboundLocalError``).
        """
        try:
            # float() lets 0-d tensors / numpy scalars match the float keys.
            return cls._TARGET_SECONDS[float(keep_ratio)]
        except (KeyError, TypeError, ValueError):
            raise ValueError(
                f"unsupported keep_ratio {keep_ratio!r}; expected one of {sorted(cls._TARGET_SECONDS)}"
            ) from None

    def forward(self, keep_ratio):
        """Return the encoded prompt for ``keep_ratio`` as a (1, num_tokens) tensor.

        Raises:
            ValueError: for an unsupported ``keep_ratio``.
        """
        target_sec = self._target_seconds(keep_ratio)

        # NOTE(review): the wording ("sampled every 15 seconds ... recover every
        # {} seconds") looks inverted relative to a sparse input; left unchanged
        # because the prompt text is model input — confirm with the author.
        text = "Task: The sparse trajectory recovery. Target: Output each trajectory point's road segment and moving ratio. Content: The sparse trajectory is sampled every 15 seconds and aims to recover trajectory every {} seconds. The observed trajectory is: ".format(target_sec)

        indexed_tokens = self.tokenizer.encode(text)
        # Previously computed but never returned.
        return torch.tensor([indexed_tokens])

        
class ReprogrammingLayer(nn.Module):
    """Multi-head cross-attention from a batched target sequence onto a shared source bank.

    Queries come from ``target_embedding`` of shape (B, L, d_model); keys and
    values come from unbatched (S, d_model) tensors shared across the batch.

    Args:
        d_model: input and output feature width.
        n_heads: number of attention heads.
        d_keys: per-head width; defaults to ``d_model // n_heads`` (previously
            this argument was accepted but ignored).
        attention_dropout: dropout applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, d_keys=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()

        # Honour an explicit per-head width; fall back to the old fixed value.
        d_keys = d_keys or (d_model // n_heads)

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_model)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        """Attend from ``target_embedding`` (B, L, d_model) over the (S, d_model) source.

        ``value_embedding`` must have the same length S as ``source_embedding``.
        Returns a (B, L, d_model) tensor.
        """
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        # reshape (not view) so non-contiguous inputs are accepted.
        queries = self.query_projection(target_embedding).reshape(B, L, H, -1)
        keys = self.key_projection(source_embedding).reshape(S, H, -1)
        values = self.value_projection(value_embedding).reshape(S, H, -1)

        out = self.reprogramming(queries, keys, values)
        return self.out_projection(out.reshape(B, L, -1))

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        """Scaled dot-product attention: (B, L, H, E) x (S, H, E) -> (B, L, H, E)."""
        E = target_embedding.shape[-1]
        scale = 1. / math.sqrt(E)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        return torch.einsum("bhls,she->blhe", A, value_embedding)

class spatialTemporalConv(nn.Module):
    """Spatial 2-D convolutions per time step followed by a temporal 1-D convolution per grid cell.

    Args:
        in_channel: input channels of ``start_conv``. ``forward`` feeds a single
            channel (``unsqueeze(1)``), so this must be 1.
        base_channel: feature width C of every convolution.
    """

    def __init__(self, in_channel, base_channel):
        super(spatialTemporalConv, self).__init__()
        self.start_conv = nn.Conv2d(in_channel, base_channel, 1, 1, 0)
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(base_channel, base_channel, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channel, base_channel, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(base_channel)
        )
        self.temporal_conv = nn.Sequential(
            nn.Conv1d(base_channel, base_channel, 3, 1, 1),
            nn.ReLU(inplace=True)
        )

    def forward(self, road_condition):
        """Encode a (T, H, W) grid sequence into (T, H, W, C) features.

        The grid no longer has to be square: the old ``T, N, N = shape`` unpack
        silently took N = W and mis-reshaped (or crashed) when H != W.
        """
        T, H, W = road_condition.shape
        start = self.start_conv(road_condition.unsqueeze(1))  # (T, C, H, W)
        spatial = self.spatial_conv(start)  # (T, C, H, W)
        # One temporal sequence per grid cell: (H*W, C, T), cell index = h*W + w.
        per_cell = spatial.reshape(T, -1, H * W).permute(2, 1, 0)
        temporal = self.temporal_conv(per_cell)  # (H*W, C, T)
        temporal = temporal.reshape(H, W, -1, T).permute(3, 2, 0, 1)  # (T, C, H, W)
        # Residual connection, then channels-last.
        return (start + temporal).permute(0, 2, 3, 1)


class PTR(nn.Module):
    """Prompt-conditioned sparse-trajectory recovery model.

    Embeds a sparse GPS trajectory, fills masked steps with a learned prompt
    built from road-condition features and neighbouring time gaps, attends over
    learnable soft prompts, prepends an embedded prompt-token sequence, encodes
    everything with ``self.bert`` and decodes, per trajectory step, road-segment
    log-probabilities and a moving rate.

    Args:
        hidd_dim: hidden width shared by all sub-modules.
        id_size: number of road segments (rows of ``road_embed``).
        device: device on which learnable tokens and masks are created.
        road_candi: if truthy, fuse candidate-road embeddings with the GPS
            embedding and run ``Traj_conv``.
        conv_kernel: kernel size of ``Traj_conv``.
        soft_traj_num: number of learnable soft trajectory prompts.
        drop_out: dropout passed to ``TemporalPositionalEncoding``.
    """

    def __init__(self, hidd_dim, id_size, device, road_candi, conv_kernel, soft_traj_num, drop_out=0.3):
        super(PTR, self).__init__()
        self.hidd_dim = hidd_dim
        self.id_size = id_size
        self.drop_out = drop_out
        self.device = device
        # NOTE: sub-modules are created in the original order so random
        # initialisation stays reproducible under a fixed seed. Some of them
        # (input_embed, local_input_protext, local_cat_layer,
        # ReprogrammingLayer_cat) are unused by forward() but are kept so
        # existing checkpoints still load.
        self.LFF = LearnableFourierPositionalEncoding(2, self.hidd_dim)
        self.input_embed = nn.Linear(2, self.hidd_dim)
        self.ST_conv = spatialTemporalConv(1, self.hidd_dim)

        self.road_candi = road_candi
        self.global_input_protext = nn.Linear(self.hidd_dim, self.hidd_dim)

        self.local_input_protext = nn.Linear(self.hidd_dim, self.hidd_dim)

        # Presumably BERT's [MASK]/[PAD] token embeddings (see cal_bert_token) — confirm.
        self.MASK_token = cal_bert_token(token="MASK").to(self.device)
        self.PAD_token = cal_bert_token(token="PAD").to(self.device)

        self.road_embed = nn.Parameter(torch.randn(self.id_size, self.hidd_dim).to(self.device), requires_grad=True)
        self.position_embed = TemporalPositionalEncoding(self.hidd_dim, self.drop_out)
        if self.road_candi:
            self.input_layer = nn.Linear(self.hidd_dim * 2, self.hidd_dim)

        self.prompt_layer = nn.Sequential(
                nn.Linear(self.hidd_dim, self.hidd_dim),
                nn.ReLU(),
                nn.Linear(self.hidd_dim, self.hidd_dim)
        )
        self.Traj_conv = Traj_conv(self.hidd_dim, conv_kernel)

        # Map a scalar time gap to per-feature decay rates.
        self.forward_delay = nn.Linear(1, self.hidd_dim)
        self.backward_delay = nn.Linear(1, self.hidd_dim)

        self.local_cat_layer = nn.Linear(self.hidd_dim * 2, self.hidd_dim)

        self.bert = BERT()
        self.Decoder = Decoder(self.id_size, self.hidd_dim)
        self.soft_traj_num = soft_traj_num

        self.learnable_mask_token = nn.Parameter(torch.randn(1, self.hidd_dim).to(self.device), requires_grad=True)

        self.soft_traj_prompt = nn.Parameter(torch.randn(soft_traj_num, self.hidd_dim).to(self.device), requires_grad=True)

        self.ReprogrammingLayer = ReprogrammingLayer(self.hidd_dim, 8)
        # Called here (not at the end) to keep the original RNG consumption order.
        self.reset_parameters()
        self.ReprogrammingLayer_cat = nn.Linear(self.hidd_dim * 2, self.hidd_dim)
        self.time_prompt_embed = nn.Linear(2, self.hidd_dim)

        self.road_condition_merge = nn.Sequential(
            nn.Linear(self.hidd_dim * 2, self.hidd_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidd_dim, self.hidd_dim)
        )
        self.road_out = nn.Linear(self.hidd_dim, self.hidd_dim)

    def reset_parameters(self):
        """Re-initialise the soft prompts and mask token uniformly in ±1/sqrt(hidd_dim)."""
        stdv = 1. / math.sqrt(self.soft_traj_prompt.shape[1])
        self.soft_traj_prompt.data.uniform_(-stdv, stdv)

        stdv1 = 1. / math.sqrt(self.learnable_mask_token.shape[1])
        self.learnable_mask_token.data.uniform_(-stdv1, stdv1)

    def _interpolate_masked(self, values, forward_delta_t, backward_delta_t, forward_index, backward_index, mask_index):
        """Fill masked steps with a time-decay-weighted mix of their forward/backward neighbours.

        Args:
            values: (B, T, C) per-step features; not modified.
            forward_delta_t, backward_delta_t: (B, T) time gaps to the neighbours.
            forward_index, backward_index: (B, T) long indices of the neighbour
                step used for each position.
            mask_index: (B, T); non-zero marks a step to be filled.

        Returns:
            (B, T, C). Masked steps get ``w_f * values[fwd] + w_b * values[bwd]``
            with ``w ∝ exp(-relu(delay(delta_t)))``; unmasked steps reduce to
            their own value.
        """
        batch_idx = torch.arange(values.shape[0], device=forward_index.device)[:, None]
        mask = mask_index.bool().to(values.device).unsqueeze(-1)
        # Separate tensors: the old code aliased both names to one tensor, so
        # the backward fill overwrote the forward fill (and mutated the input).
        forward_vals = torch.where(mask, values[batch_idx, forward_index], values)
        backward_vals = torch.where(mask, values[batch_idx, backward_index], values)

        log_w_f = -F.relu(self.forward_delay(forward_delta_t.unsqueeze(-1)))
        log_w_b = -F.relu(self.backward_delay(backward_delta_t.unsqueeze(-1)))
        # Softmax over log-weights equals exp(a) / (exp(a) + exp(b)) but cannot
        # produce 0/0 = NaN when both exponentials underflow for large gaps.
        weights = torch.softmax(torch.stack((log_w_f, log_w_b)), dim=0)
        return weights[0] * forward_vals + weights[1] * backward_vals

    def GPS_road_embed(self, src_lat, src_lng, mask_index, padd_index, src_candi_id, learned_mask_prompt=None):
        """Embed every trajectory step into the model's input sequence.

        Observed steps are encoded from (lat, lng) with ``self.LFF`` and
        ``global_input_protext``. Steps with ``mask_index == 1`` are replaced by
        ``learned_mask_prompt`` (or ``self.MASK_token`` when it is None), and
        steps with ``padd_index == 1`` by ``self.PAD_token``. When
        ``self.road_candi`` is set, a candidate-road embedding is fused in the
        same way and the result is passed through ``Traj_conv``.
        """
        src_data = torch.cat((src_lat.unsqueeze(-1), src_lng.unsqueeze(-1)), -1)
        # Learnable Fourier features of the coordinates, projected into the
        # space the mask/pad tokens live in.
        src_gps_hidden = self.global_input_protext(self.LFF(src_data))

        masked = mask_index == 1
        padded = padd_index == 1
        if learned_mask_prompt is not None:
            src_gps_hidden[masked] = learned_mask_prompt[masked]
        else:
            src_gps_hidden[masked] = self.MASK_token
        src_gps_hidden[padded] = self.PAD_token

        if not self.road_candi:
            return src_gps_hidden

        # Candidate-road embedding: src_candi_id @ road_embed, normalised by the
        # per-step candidate total (eps guards steps with no candidates).
        src_road_canid = torch.matmul(src_candi_id, self.road_embed)
        src_road_canid = src_road_canid / (src_candi_id.sum(2).unsqueeze(-1) + 1e-6)

        if learned_mask_prompt is not None:
            src_road_canid[masked] = learned_mask_prompt[masked]
        else:
            src_road_canid[masked] = self.MASK_token
        src_road_canid[padded] = self.PAD_token

        src_input = self.input_layer(torch.cat((src_gps_hidden, src_road_canid), -1))
        return self.Traj_conv(src_input)

    def local_model(self, src_input, forward_delta_t, backward_delta_t, forward_index, backward_index, mask_index, padd_index):
        """Interpolate masked steps of ``src_input`` (B, T, C) from their time neighbours.

        ``padd_index`` is accepted for interface compatibility and not used.
        ``src_input`` is no longer modified in place.
        """
        return self._interpolate_masked(src_input, forward_delta_t, backward_delta_t,
                                        forward_index, backward_index, mask_index)

    def mask_prompt(self, road_condition, road_condition_xyt_index, forward_delta_t, backward_delta_t, forward_index, backward_index, mask_index, padd_index):
        """Build the per-step mask prompt fed to ``GPS_road_embed``.

        Road-condition features from ``ST_conv`` are gathered at each step's
        (x, y, t) index, interpolated across masked steps, projected by
        ``road_out`` and merged with an embedding of the forward/backward time
        gaps (plus ``learnable_mask_token``). ``padd_index`` is not used.
        """
        times_embed = self.time_prompt_embed(
            torch.cat((forward_delta_t.unsqueeze(-1), backward_delta_t.unsqueeze(-1)), -1)
        ) + self.learnable_mask_token

        road_condition_conv = self.ST_conv(road_condition)  # (T', H, W, hidd_dim)

        # Per the original author's note, unknown steps carry index 0 here.
        x = road_condition_xyt_index[:, :, 0]
        y = road_condition_xyt_index[:, :, 1]
        t = road_condition_xyt_index[:, :, 2]
        trajectory_road_condition = road_condition_conv[t, x, y]  # (B, T, hidd_dim)

        road_condition_out = self._interpolate_masked(trajectory_road_condition, forward_delta_t, backward_delta_t,
                                                      forward_index, backward_index, mask_index)
        road_condition_out = self.road_out(road_condition_out)

        return self.road_condition_merge(torch.cat((road_condition_out, times_embed), -1))

    def forward(self, src_lat, src_lng, src_time, mask_index, src_candi_id, traj_length, padd_index, keep_ratio, prompt_token, road_condition, road_condition_xyt_index, forward_delta_t, backward_delta_t, forward_index, backward_index, flag=False):
        """Predict road IDs and moving rates for every trajectory step.

        Args:
            src_lat, src_lng: (B, T) coordinates.
            src_time, keep_ratio: accepted for interface compatibility; unused here.
            prompt_token: (B, P, hidd_dim) prompt embeddings prepended to the sequence.
            traj_length: per-sample true lengths (list or tensor of B ints).
            flag: debugging switch — prints the BERT output and exits the process.
            Remaining arguments are forwarded to ``mask_prompt``/``GPS_road_embed``.

        Returns:
            ``(outputs_id, outputs_rate)`` of shapes (T, B, id_size + 1) and
            (T, B, 1); positions at or beyond ``traj_length`` are zeroed.
        """
        B, T = src_lat.shape

        prompt_token = self.prompt_layer(prompt_token)
        prompt_length = prompt_token.shape[1]

        learned_mask_prompt = self.mask_prompt(road_condition, road_condition_xyt_index, forward_delta_t, backward_delta_t, forward_index, backward_index, mask_index, padd_index)
        src_input = self.GPS_road_embed(src_lat, src_lng, mask_index, padd_index, src_candi_id, learned_mask_prompt)
        src_input = self.ReprogrammingLayer(src_input, self.soft_traj_prompt, self.soft_traj_prompt)

        src_input = torch.cat((prompt_token, src_input), dim=1)
        src_input = src_input + self.position_embed(src_input)

        # (B, T) validity mask shared by the BERT attention mask and the output mask.
        lengths = torch.as_tensor(traj_length, device=self.device)
        valid = torch.arange(T, device=self.device)[None, :] < lengths[:, None]

        prompt_mask = torch.ones((B, prompt_length), device=self.device)
        attention_mask = torch.cat((prompt_mask, valid.float()), 1)

        bert_out = self.bert(src_input, attention_mask)
        if flag:
            # Debugging hook: dump the encoder output and stop the process.
            print(bert_out)
            raise SystemExit
        outputs_id, outputs_rate = self.Decoder(bert_out, prompt_length)

        # Zero predictions beyond each trajectory's true length.
        output_mask = valid.unsqueeze(-1)
        outputs_id = outputs_id * output_mask
        outputs_rate = outputs_rate * output_mask

        # (B, T, *) -> (T, B, *)
        return outputs_id.permute(1, 0, 2), outputs_rate.permute(1, 0, 2)



        