import torch.nn as nn
import torch
import math


class PositionalEmbedding(nn.Module):
    """Sum of four learned embeddings, looked up per sequence step.

    ``forward`` returns the element-wise sum of:

    * ``emb_pos_absolute``: a position index that counts down from
      ``max_len - 1``, shifted per row by ``max_len - sum(padding_mask)``
      and floored at 0;
    * ``emb_subseg_type``: the sub-segment type id;
    * ``emb_pos_absoluteDis``: ``ds_to_sub_end`` bucketized against
      ``absolute_bins``;
    * ``emb_pos_absoluteDenseDis``: a 0/1 flag for
      ``ds_to_sub_end <= DENSE_DISTANCE_MAX``.

    Keys read from ``args``:
        max_len: number of positions (rows of ``emb_pos_absolute``).
        embeded_all_feature_cnt: embedding width shared by all tables.
        device: device the tables are created on.
        client_discrete_feature.client_position_feature.subseg_type:
            number of sub-segment types (rows of ``emb_subseg_type``).

    The ``max_len`` keyword argument is ignored; the length comes from
    ``args['max_len']``. It is kept only for call-site compatibility.
    """

    # Distances at or below this value select row 1 of emb_pos_absoluteDenseDis.
    DENSE_DISTANCE_MAX = 600

    def __init__(self, args, max_len=8):
        super().__init__()

        # [1] parameters
        self.max_length = args['max_len']
        self.feature_hidden = args['embeded_all_feature_cnt']
        self.device = args["device"]
        self.subseg_type = args["client_discrete_feature"]["client_position_feature"]["subseg_type"]

        # [2] embedding tables
        self.emb_pos_absolute = nn.Embedding(self.max_length, self.feature_hidden).to(self.device)
        self.emb_subseg_type = nn.Embedding(self.subseg_type, self.feature_hidden).to(self.device)

        # Bucket boundaries for ds_to_sub_end. A non-persistent buffer follows
        # module.to(...) / DataParallel replication without adding a
        # state_dict key, so existing checkpoints still load.
        self.register_buffer(
            "absolute_bins",
            torch.Tensor([50, 100, 150, 200, 250, 300, 400, 500, 600, 800, 1000, 1500, 2000, 3000]).to(self.device),
            persistent=False,
        )

        # bucketize yields indices 0..len(bins), hence len(bins) + 1 rows.
        self.emb_pos_absoluteDis = nn.Embedding(self.absolute_bins.size(0) + 1, self.feature_hidden).to(self.device)
        # Two rows: 1 if ds_to_sub_end <= DENSE_DISTANCE_MAX, else 0.
        self.emb_pos_absoluteDenseDis = nn.Embedding(2, self.feature_hidden).to(self.device)

    def forward(self, batch_size, client_feature_dict, ds_to_sub_end, padding_mask):
        """Return ``position_emb + subtype_emb + absolute_emb + absolute_dense_emb``.

        Args:
            batch_size: number of rows the position index is repeated to.
            client_feature_dict: mapping whose ``"client_feature_subtype"``
                entry holds sub-segment type ids.
            ds_to_sub_end: distances bucketized against ``absolute_bins``.
            padding_mask: summed over its last dimension per row.

        All tensor inputs are moved to the embedding tables' device first.
        NOTE(review): the index-like inputs are assumed to be
        (batch, max_len) so that the four embeddings broadcast; confirm
        against callers.
        """
        # Use the tables' actual device so this keeps working after module.to(...).
        device = self.emb_pos_absolute.weight.device
        padding_mask = padding_mask.to(device)
        ds_to_sub_end = ds_to_sub_end.to(device)  # Tensor.to is not in-place: rebind.

        # [1] absolute position: descending indices max_len-1 .. 0, shifted
        # down by (max_len - padding_sum) per row.
        position = torch.arange(self.max_length - 1, -1, step=-1, device=device)
        padding_sum = torch.sum(padding_mask, -1)
        max_gap = torch.unsqueeze(self.max_length - padding_sum, -1)
        gap_position = position.unsqueeze(0).repeat(batch_size, 1) - max_gap
        # Keep indices inside the valid range [0, max_len - 1] of the table.
        position = torch.clamp(gap_position, 0, self.max_length - 1)
        position_emb = self.emb_pos_absolute(position.long())

        # [2] sub-segment type
        subtype_id = client_feature_dict["client_feature_subtype"].to(device)
        subtype_emb = self.emb_subseg_type(subtype_id.long())

        # [3] bucketed distance and the "dense" (<= threshold) flag
        absolute_bin_index = self.distance_bins(ds_to_sub_end, self.absolute_bins)
        absolute_emb = self.emb_pos_absoluteDis(absolute_bin_index)
        absolute_dense_index = (ds_to_sub_end <= self.DENSE_DISTANCE_MAX).long()
        absolute_dense_emb = self.emb_pos_absoluteDenseDis(absolute_dense_index)

        return position_emb + subtype_emb + absolute_emb + absolute_dense_emb

    def distance_bins(self, distance, bins):
        """Return index ``i`` such that ``bins[i-1] <= distance < bins[i]``.

        Values below ``bins[0]`` map to 0; values at or above ``bins[-1]``
        map to ``len(bins)``.
        """
        return torch.bucketize(distance, bins, right=True)


class AbsPositionalEmbedding(nn.Module):
    """Learned absolute-position embedding, read out in reverse order.

    ``forward(batch_size)`` returns rows ``max_len-1, ..., 0`` of
    ``emb_pos_absolute`` for every batch entry. The result has shape
    ``(batch_size, speed_max_len, speed_att_hid)``.

    Keys read from ``args``:
        speed_max_len: number of positions (rows of the table).
        speed_att_hid: embedding width.
        device: device the table is created on.

    The ``max_len`` keyword argument is ignored; it is kept only for
    call-site compatibility.
    """

    def __init__(self, args, max_len=10):
        super().__init__()

        # [1] parameters
        self.max_length = args['speed_max_len']
        self.feature_hidden = args['speed_att_hid']
        self.device = args["device"]
        # [2] embedding table
        self.emb_pos_absolute = nn.Embedding(self.max_length, self.feature_hidden).to(self.device)

    def forward(self, batch_size):
        """Return the reversed position embeddings, repeated ``batch_size`` times."""
        # Build indices on the table's actual device, so this still works after module.to(...).
        device = self.emb_pos_absolute.weight.device
        # Descending indices max_len-1 .. 0 are already inside [0, max_len),
        # so no clamping is needed.
        position = torch.arange(self.max_length - 1, -1, step=-1, device=device)
        position = position.unsqueeze(0).repeat(batch_size, 1)
        return self.emb_pos_absolute(position)


