import pickle
from tqdm import tqdm
import math
import random
import torch
import numpy as np
import datetime 

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of trajectories loaded from a pickled file.

    Every item of ``self.data`` is one trajectory: a sequence of per-point
    dicts. ``__getitem__`` converts a trajectory into the 13-element tuple
    produced by ``process_data``.
    """

    # Trajectories with more points than this are dropped at load time.
    MAX_TRAJ_LEN = 100
    # Number of cells per axis of the road-condition grid laid over the MBR.
    ROAD_CONDITION_GRID = 64

    def __init__(self, dataset, id_size, datatype, mbr, repeat_times=1, debug=False):
        """Load the trajectories of one split and repeat them ``repeat_times`` times.

        Args:
            dataset: dataset name, used as a directory component of the data path.
            id_size: width of the per-point candidate vector built by
                ``candi_list_map``.
            datatype: split name (train, val, test), also a directory component.
            mbr: mapping with keys ``min_lat``, ``max_lat``, ``min_lng``, ``max_lng``.
            repeat_times: how many times the loaded list is concatenated with itself.
            debug: stored on the instance; not read anywhere in this class.
        """
        super().__init__()

        self.dataset = dataset
        self.id_size = id_size
        self.datatype = datatype  # train, val, test
        self.repeat_times = repeat_times
        self.debug = debug
        self.mbr = mbr

        self.data = self.load_data(self.dataset, self.datatype, self.id_size, self.repeat_times, self.debug)
        self.data = self.repeat_data(self.repeat_times)

    def load_data(self, dataset, types, id_size, repeat_times=1, debug=False):
        """Read the pickled trajectories of one split and drop the over-long ones.

        Also records the longest kept trajectory length in ``self.maxn_length``
        and prints a one-line summary.

        Args:
            dataset: dataset name (directory component).
            types: split name (directory component).
            id_size, repeat_times, debug: accepted for interface compatibility;
                not used by the loader.

        Returns:
            list of trajectories whose length is at most ``MAX_TRAJ_LEN``.
        """
        data_path = "./load_data/data/{}/{}/pro_data_example.bin".format(dataset, types)

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at files that come from a trusted source.
        with open(data_path, 'rb') as f:  # road-segment sequences
            raw_trajs = pickle.load(f)

        all_data = [traj for traj in raw_trajs if len(traj) <= self.MAX_TRAJ_LEN]
        self.maxn_length = max((len(traj) for traj in all_data), default=0)

        print(types, self.maxn_length, len(all_data))

        return all_data

    def repeat_data(self, time):
        """Return ``self.data`` concatenated with itself ``time`` times."""
        return self.data * time

    def __len__(self):
        """Number of trajectories (after repetition)."""
        return len(self.data)

    def __getitem__(self, index):
        """Generate one sample of data: the ``process_data`` tuple of trajectory ``index``."""
        return self.process_data(self.data[index])

    def gps2grid(self, lat, lng, grid_size=50):
        """Convert a GPS coordinate into 1-based grid cell indices.

        The grid origin is ``(self.mbr['min_lat'], self.mbr['min_lng'])``.

        Args:
            lat, lng: coordinate in degrees.
            grid_size: int. cell edge length in meter.

        Returns:
            (locgrid_x, locgrid_y): cell index along latitude and longitude.
        """
        LAT_PER_METER = 8.993203677616966e-06
        LNG_PER_METER = 1.1700193970443768e-05
        lat_unit = LAT_PER_METER * grid_size
        lng_unit = LNG_PER_METER * grid_size

        locgrid_x = int((lat - self.mbr['min_lat']) / lat_unit) + 1
        locgrid_y = int((lng - self.mbr['min_lng']) / lng_unit) + 1
        return locgrid_x, locgrid_y

    def candi_list_map(self, tr):
        """Scatter one point's candidate list into a dense ``(1, id_size)`` row.

        Each row of ``tr['candi_id']`` is read as ``[index, _, value]``:
        column 0 selects the position and column 2 is the value stored there.
        A point without candidates yields an all-zero row.
        """
        candi_id_list = torch.zeros((1, self.id_size))
        if len(tr['candi_id']) != 0:
            tmp_candi_ls = torch.from_numpy(np.array(tr['candi_id'])).float()
            ids = tmp_candi_ls[:, 0].long().tolist()
            prob = tmp_candi_ls[:, 2]
            candi_id_list[0, ids] = prob
        return candi_id_list

    def process_data(self, data):
        """Split one trajectory into parallel per-point lists.

        Args:
            data: non-empty sequence of point dicts with keys ``road_id``,
                ``rate``, ``mm_lat``, ``mm_lng``, ``ori_lat``, ``ori_lng``,
                ``ori_time`` (datetime-like) and ``candi_id``.

        Returns:
            Tuple of 13 elements, in order:
            road_id, road_rate, mm_lat, mm_lng: values copied from each point;
            src_lat, src_lng: ``ori_lat`` / ``ori_lng`` rescaled so that the MBR
                bounds map to -1 and 1;
            src_candi_id: float tensor of shape ``(len(data), id_size)``;
            times: ``ori_time.timestamp()`` modulo one day, in seconds;
            road_condition_x_index, road_condition_y_index: grid cell indices
                on a ``ROAD_CONDITION_GRID``-cells-per-axis grid over the MBR;
            road_condition_t_index: hour of ``ori_time``;
            start_times, end_times: ``[hour, minute, weekday]`` of the first
                and last point.
        """
        road_id, road_rate, mm_lat, mm_lng = [], [], [], []
        src_lat, src_lng = [], []
        times = []
        road_condition_x_index, road_condition_y_index, road_condition_t_index = [], [], []

        min_lat, max_lat = self.mbr['min_lat'], self.mbr['max_lat']
        min_lng, max_lng = self.mbr['min_lng'], self.mbr['max_lng']

        interval = self.ROAD_CONDITION_GRID
        lng_interval = abs(max_lng - min_lng) / interval
        lat_interval = abs(max_lat - min_lat) / interval

        traj_start = data[0]['ori_time']
        traj_end = data[-1]['ori_time']
        start_times = [traj_start.hour, traj_start.minute, traj_start.weekday()]
        end_times = [traj_end.hour, traj_end.minute, traj_end.weekday()]

        for tr in data:
            road_id.append(tr['road_id'])
            road_rate.append(tr['rate'])
            mm_lat.append(tr['mm_lat'])
            mm_lng.append(tr['mm_lng'])

            # Rescale the raw GPS position so the MBR spans [-1, 1] on each axis.
            src_lat.append(2 * (tr['ori_lat'] - min_lat) / (max_lat - min_lat) - 1)
            src_lng.append(2 * (tr['ori_lng'] - min_lng) / (max_lng - min_lng) - 1)

            # Seconds since midnight of the POSIX timestamp.
            times.append(tr['ori_time'].timestamp() % (24 * 60 * 60))

            # Offsets below the MBR are clamped to cell 0.
            # NOTE(review): there is no upper clamp, so a point exactly on
            # max_lat / max_lng gets index == ROAD_CONDITION_GRID — confirm
            # the consumer allows that index.
            road_condition_x_index.append(int(max(tr['ori_lat'] - min_lat, 0) / lat_interval))
            road_condition_y_index.append(int(max(tr['ori_lng'] - min_lng, 0) / lng_interval))
            # Hour of day, read directly instead of parsing str(datetime).
            road_condition_t_index.append(int(tr['ori_time'].hour))

        # Each row is (1, id_size); concatenating gives (len(data), id_size).
        # No squeeze here: squeezing axis 1 would drop the id axis when
        # id_size == 1 and change the rank of the result.
        src_candi_id = torch.cat([self.candi_list_map(tr) for tr in data], dim=0)

        return road_id, road_rate, mm_lat, mm_lng, src_lat, src_lng, src_candi_id, times, road_condition_x_index, road_condition_y_index, road_condition_t_index, start_times, end_times


def distance_1(lat1, lng1, lat2, lng2):
    """
    Calculate haversine distance between two GPS points in kilometers
    Args:
    -----
        lat1, lng1: latitude and longitude of the first point, in degrees
        lat2, lng2: latitude and longitude of the second point, in degrees
    Returns:
    --------
        d: float. haversine distance in kilometer (the meter value computed
           from the earth radius is divided by 1000 before it is returned)
    """
    EARTH_MEAN_RADIUS_METER = 6371008.7714

    if lat1 == lat2 and lng1 == lng2:
        return 0.0

    sin_half_dlat = math.sin(math.radians(lat2 - lat1) / 2.0)
    sin_half_dlng = math.sin(math.radians(lng2 - lng1) / 2.0)
    h = sin_half_dlat * sin_half_dlat + math.cos(math.radians(lat1)) * math.cos(
        math.radians(lat2)) * sin_half_dlng * sin_half_dlng

    # Floating-point rounding can push h marginally outside [0, 1] (e.g. for
    # antipodal points), which would make math.sqrt raise ValueError. Explicit
    # comparisons are used so that a NaN input still propagates as NaN.
    if h > 1.0:
        h = 1.0
    elif h < 0.0:
        h = 0.0

    c = 2.0 * math.atan2(math.sqrt(h), math.sqrt(1.0 - h))
    return EARTH_MEAN_RADIUS_METER * c / 1000

# Use for DataLoader
def collate_fn(data):
    """Pad a list of per-trajectory samples into one batch.

    Args:
        data: list of 13-element samples, each laid out as
            (road_id, road_rate, mm_lat, mm_lng, src_lat, src_lng,
            src_candi_id, times, road_condition_x_index,
            road_condition_y_index, road_condition_t_index,
            start_times, end_times). The per-point fields are lists and
            ``src_candi_id`` is a 2-D tensor with one row per point.

    Returns:
        A 14-element tuple. The ten per-point list fields come back as lists
        of ``np.ndarray``, each right-padded with 0 to the longest trajectory
        of the batch. ``src_candi_id`` becomes one zero-padded tensor of
        shape ``(batch, max_len, id_size)``. ``length`` (the unpadded
        lengths) is inserted before ``start_times`` and ``end_times``, which
        are passed through as lists.
    """
    (road_id, road_rate, mm_lat, mm_lng, src_lat, src_lng, src_candi_id, times,
     road_condition_x_index, road_condition_y_index, road_condition_t_index,
     start_times, end_times) = (list(field) for field in zip(*data))  # unzip data

    length = [len(traj) for traj in road_id]
    maxn_length = max(length, default=0)
    id_size = src_candi_id[0].shape[1]

    def pad(seq, fill_num):
        # Append fill_num zeros, then let numpy infer the dtype from the list.
        return np.array(seq + [0] * fill_num)

    # Every field here is padded the same way, so handle them in one loop.
    padded_fields = (
        road_id, road_rate, mm_lat, mm_lng,
        src_lat, src_lng,
        road_condition_x_index, road_condition_y_index, road_condition_t_index,
        times,
    )

    res_src_candi = torch.zeros((len(road_id), maxn_length, id_size))

    for i, traj_len in enumerate(length):
        fill_num = maxn_length - traj_len  # trajectories shorter than the max are zero-padded at the end
        for field in padded_fields:
            field[i] = pad(field[i], fill_num)
        res_src_candi[i, :traj_len] = src_candi_id[i]

    return road_id, road_rate, mm_lat, mm_lng, src_lat, src_lng, res_src_candi, times, road_condition_x_index, road_condition_y_index, road_condition_t_index, length, start_times, end_times

def construct_mask(src_lat, src_lng, mbr, length, data_type, sample_num, keep_ratio=None):
    """Build the sub-sampling masks and per-trajectory statistics of a batch.

    Every ``int(1 / keep_ratio)``-th point of a trajectory, plus its last
    point, is treated as observed; every other real point is masked.

    Args:
        src_lat, src_lng: array-like of normalised coordinates indexed as
            ``[sample][point]``; they are mapped back to degrees with ``mbr``.
            Must support element-wise arithmetic (e.g. ``np.ndarray``).
        mbr: mapping with keys ``min_lat``, ``max_lat``, ``min_lng``, ``max_lng``.
        length: unpadded length of each trajectory.
        data_type: accepted for interface compatibility; not used.
        sample_num: number of trajectories to process (rows of the outputs).
        keep_ratio: fraction of points kept; required.

    Returns:
        Tuple of 8 elements. The first six are float arrays of shape
        ``(sample_num, max(length))``:
        mask_index: 1 where a real point is masked;
        padd_index: 1 on padding positions (beyond the trajectory length);
        forward_delta_t: for a masked point, its step offset (1, 2, ...) from
            the previous observed point;
        backward_delta_t: for a masked point, its step offset to the next
            observed point;
        forward_index: index of the nearest observed point at or before each
            position;
        backward_index: index of the nearest observed point at or after each
            position;
        traj_total_time: list of ``[minutes, seconds]`` computed from
            ``length * 15`` seconds;
        traj_total_dis: list with the summed ``distance_1`` between
            consecutive observed points of each trajectory.

    Raises:
        NotImplementedError: if ``keep_ratio`` is None.
    """
    if keep_ratio is None:
        raise NotImplementedError('Mask val or test set, must set keep_ratio')

    step = int(1 / keep_ratio)
    maxn_length = max(length)

    # Undo the [-1, 1] normalisation to get coordinates back in degrees.
    lat = (src_lat + 1) / 2 * (mbr['max_lat'] - mbr['min_lat']) + mbr['min_lat']
    lng = (src_lng + 1) / 2 * (mbr['max_lng'] - mbr['min_lng']) + mbr['min_lng']

    shape = (sample_num, maxn_length)
    mask_index = np.zeros(shape)  # records which positions are masked; 1 means masked
    padd_index = np.zeros(shape)

    # Step offsets to the previous / next observed point.
    forward_delta_t = np.zeros(shape)
    backward_delta_t = np.zeros(shape)

    # Index of the nearest observed GPS point before / after each position.
    forward_index = np.zeros(shape)
    backward_index = np.zeros(shape)

    traj_total_time = []
    traj_total_dis = []

    for i in range(sample_num):
        traj_len = length[i]  # total length of the current trajectory
        total_time = traj_len * 15
        traj_total_time.append([total_time // 60, total_time % 60])  # travel time: minutes, seconds

        traj_i_mask = list(range(traj_len))

        # Indices kept as observed points; the last point is always kept.
        if (traj_len - 1) % step == 0:
            src_traj_index = traj_i_mask[::step]
        else:
            src_traj_index = traj_i_mask[::step] + [traj_i_mask[-1]]

        # Every real point that is not observed is masked.
        mask_index[i, :traj_len] = 1
        mask_index[i, src_traj_index] = 0
        padd_index[i, traj_len:] = 1

        kept = np.array(src_traj_index)
        positions = np.arange(traj_len)
        # First kept index >= position.
        backward_index[i, :traj_len] = kept[np.searchsorted(kept, positions)]
        # Last kept index <= position.
        forward_index[i, :traj_len] = kept[np.searchsorted(kept, positions, side='right') - 1]

        traj_dis = 0
        for start_idx, end_idx in zip(src_traj_index[:-1], src_traj_index[1:]):
            traj_dis += distance_1(lat[i][start_idx], lng[i][start_idx], lat[i][end_idx], lng[i][end_idx])

            # Offsets are generated for the actual gap size, so gaps of any
            # length work (a fixed 15-entry table would fail on larger gaps).
            offsets = np.arange(1, end_idx - start_idx)
            forward_delta_t[i, start_idx + 1: end_idx] = offsets
            backward_delta_t[i, start_idx + 1: end_idx] = offsets[::-1]
        traj_total_dis.append(traj_dis)

    return mask_index, padd_index, forward_delta_t, backward_delta_t, forward_index, backward_index, traj_total_time, traj_total_dis

def cal_prompt_token():
    """Look up BERT word embeddings for the fixed prompt fragments.

    Loads the tokenizer and model found under ``./PLM/BERT``, takes the
    model's ``embeddings.word_embeddings.weight`` matrix and indexes it with
    the token ids of each prompt string.

    Returns:
        (task_prompt_tensor, _time_prompt, (travel_prompt_1, travel_prompt_2),
        traj_prompt) where ``task_prompt_tensor`` maps each keep ratio
        (0.25, 0.125, 0.0625) to the embedding of its task description and
        the remaining entries are the embeddings of the fixed text fragments.
    """
    import torch
    from transformers import BertModel, BertTokenizer

    MODEL_PATH = './PLM/BERT'  # directory holding the pretrained BERT files
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_PATH)
    model = BertModel.from_pretrained(MODEL_PATH)  # load the pretrained model
    bert_token = model.state_dict()['embeddings.word_embeddings.weight']

    def encode_prompt(text):
        # Token ids as a (1, n_tokens) tensor, used to index the embedding matrix.
        token_ids = torch.tensor([tokenizer.encode(text)])
        return bert_token[token_ids]

    # Sampling interval, in seconds, announced in the prompt for each keep ratio.
    seconds_by_ratio = {0.25: 60, 0.125: 120, 0.0625: 240}
    task_template = "Task: Sparse trajectory recovery. Target: Output the road segment and movement ratio for each point in the trajectory. Content: The sparse trajectory is sampled every {} seconds and aims to recover trajectory every 15 seconds. "

    task_prompt_tensor = {}
    for keep_ratio, target_sec in seconds_by_ratio.items():
        task_prompt_tensor[keep_ratio] = encode_prompt(task_template.format(target_sec))

    _time_prompt = encode_prompt("The trajectory started at ")
    travel_prompts = (
        encode_prompt("Total time cost: "),
        encode_prompt("Total space transfer distance: "),
    )
    traj_prompt = encode_prompt("The sparse trajectory is: ")

    return task_prompt_tensor, _time_prompt, travel_prompts, traj_prompt
