import pickle
from tqdm import tqdm
import math
import random
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence

def split_train_val_test(data_type, train_ratio=0.8, val_ratio=0.1):
    """Load a pickled trajectory dataset, filter it by length and split it in order.

    Args:
        data_type: dataset name; selects both the input file and the allowed
            trajectory length range. Supported: "Porto" (20-100 points) and
            "Chengdu" (20-60 points), bounds inclusive.
        train_ratio: fraction of the filtered trajectories used for training.
        val_ratio: fraction used for validation; the remainder is the test set.

    Returns:
        ``(train_set, val_set, test_set)`` as consecutive, non-shuffled slices
        of the filtered data.

    Raises:
        ValueError: if ``data_type`` has no configured length bounds.
    """
    # Inclusive (min_length, max_length) bounds per dataset.
    length_bounds = {"Porto": (20, 100), "Chengdu": (20, 60)}
    if data_type not in length_bounds:
        # Previously this crashed later with UnboundLocalError after loading the file.
        raise ValueError("Unsupported data_type {!r}; expected one of {}".format(
            data_type, sorted(length_bounds)))
    min_length, max_length = length_bounds[data_type]

    data_path = '/data/traj_gen/A_new_dataset/{}/PLM_traj_rec_prec/pro_data_debug.bin'.format(data_type)
    # Road-segment sequences. NOTE: pickle can execute arbitrary code; only load trusted files.
    with open(data_path, 'rb') as f:
        all_data = pickle.load(f)

    process_data = [traj for traj in tqdm(all_data) if min_length <= len(traj) <= max_length]

    total = len(process_data)
    val_split, test_split = int(total * train_ratio), int(total * (train_ratio + val_ratio))

    train_set = process_data[: val_split]
    val_set = process_data[val_split : test_split]
    test_set = process_data[test_split :]

    print(len(train_set), len(val_set), len(test_set), len(process_data))

    return train_set, val_set, test_set

def repeat_time(data, time):
    """Return ``data`` repeated ``time`` times using the ``*`` operator.

    For a list this concatenates ``time`` copies of the element references,
    so the same element objects appear multiple times in the result.
    """
    return data * time

def load_data(dataset, types, id_size, repeat_times=1, debug=False, data_path=None):
    """Load one pickled split and attach a dense candidate row to every point.

    Each point dict ``tr`` must carry ``tr['candi_id']``: a possibly-empty
    sequence of equal-length rows where column 0 is used as an integer index
    (presumably a road-segment id) and column 2 as the value stored at that
    index (presumably a probability — confirm against the preprocessing).
    The function sets ``tr['candi_id_list']`` in place to a ``(1, id_size)``
    float array that is zero everywhere except at those indices.

    Args:
        dataset: dataset name used to build the default path.
        types: split name ("train", "val", "test", ...) used in the path.
        id_size: width of the dense candidate row (number of ids).
        repeat_times: for ``types == "train"``, how many times to repeat the data.
        debug: if True (and ``data_path`` is None), read the local debug file.
        data_path: optional explicit file path; overrides ``dataset``/``types``/``debug``.

    Returns:
        The loaded list of trajectories. For the "train" split it is repeated
        ``repeat_times`` times and shuffled.
    """
    if data_path is None:
        if debug:
            data_path = "./data/traj_recover/{}/PLM_traj_rec_prec/{}/pro_data_debug.bin".format(dataset, types)
        else:
            data_path = "/data/traj_recover/{}/PLM_traj_rec_prec/{}/pro_data.bin".format(dataset, types)
    # Road-segment sequences. NOTE: pickle can execute arbitrary code; only load trusted files.
    with open(data_path, 'rb') as f:
        all_data = pickle.load(f)

    lengths = [len(traj) for traj in all_data]
    print(min(lengths, default=1e9), max(lengths, default=-1))

    for traj in all_data:
        for tr in traj:  # trajectory points
            candi_row = np.zeros((1, id_size))
            if len(tr['candi_id']) != 0:
                candi = np.array(tr['candi_id'])
                # np.long was removed in NumPy 1.24; int64 is the portable integer index type.
                ids = candi[:, 0].astype(np.int64)
                candi_row[0, ids] = candi[:, 2]
            tr['candi_id_list'] = candi_row

    print(types, max(lengths, default=0))

    if types == "train":
        all_data = repeat_time(all_data, repeat_times)
        random.shuffle(all_data)

    return all_data


def next_batch(data, batch_size):
    """Yield consecutive slices of ``data`` holding at most ``batch_size`` items.

    Slices with fewer than two items are skipped, so a trailing single-item
    batch is dropped (and ``batch_size == 1`` yields nothing).
    """
    total = len(data)
    n_chunks = math.ceil(total / batch_size)
    for chunk_start in (k * batch_size for k in range(n_chunks)):
        chunk_stop = min(chunk_start + batch_size, total)
        if chunk_stop - chunk_start <= 1:
            continue
        yield data[chunk_start:chunk_stop]

def gps2grid(lat, lng, mbr, grid_size):
    """Map a GPS point to 1-based grid-cell indices measured from ``mbr``'s lower corner.

    Args:
        lat, lng: point coordinates.
        mbr: mapping providing at least 'min_lat' and 'min_lng'.
        grid_size: cell side length in meters.

    Returns:
        ``(row, col)``, each computed as ``int(offset / cell_size) + 1``.
        ``int`` truncates toward zero, so a point less than one cell below the
        lower corner also maps to index 1.
    """
    # Approximate degrees per meter along latitude / longitude.
    deg_lat_per_m = 8.993203677616966e-06
    deg_lng_per_m = 1.1700193970443768e-05
    cell_lat = deg_lat_per_m * grid_size
    cell_lng = deg_lng_per_m * grid_size

    row = int((lat - mbr['min_lat']) / cell_lat) + 1
    col = int((lng - mbr['min_lng']) / cell_lng) + 1
    return row, col

def process_batch(data, mbr, id_size, grid_size=50, data_type="train", keep_ratio = 0.0625):
    """Convert a batch of trajectories into zero-padded model inputs and targets.

    Args:
        data: list of trajectories; each is a list of point dicts read via the
            keys 'road_id', 'rate', 'mm_lat', 'mm_lng', 'ori_lat', 'ori_lng',
            'ori_time' (an object with ``.timestamp()``, presumably a
            ``datetime``) and 'candi_id_list'.
        mbr: mapping with 'min_lat', 'max_lat', 'min_lng', 'max_lng' used to
            rescale raw GPS coordinates (points inside the box map to [-1, 1]).
        id_size: width of the all-zero candidate row used for padding.
        grid_size: not used by this function; kept for interface compatibility.
        data_type: "train" masks a random number of randomly chosen interior
            points. "val"/"test" keep every ``int(1 / keep_ratio)``-th point
            plus the last point and mask the rest.
        keep_ratio: down-sampling ratio for "val"/"test".

    Returns:
        ``(src_lat, src_lng, src_time, src_candi_id, mask_index, trg_id,
        trg_rate, trg_lat, trg_lng, traj_length, padd_index)``.
        ``src_lat``, ``src_lng``, ``src_time``, ``trg_id``, ``trg_rate``,
        ``mask_index`` and ``padd_index`` are numpy arrays of shape
        (batch, max_len). ``trg_lat``/``trg_lng`` are zero-padded Python lists.
        ``traj_length`` lists the unpadded lengths. ``mask_index`` is 1 where a
        point was masked (its ``src_lat``/``src_lng`` are set to 0), and
        ``padd_index`` is 1 at padding positions. ``src_candi_id`` is a single
        all-zero ``(1, id_size)`` row (see the NOTE below).

    Raises:
        ValueError: if ``data_type`` is not "train", "val" or "test".
    """
    if data_type not in ("train", "val", "test"):
        # Previously an unknown value fell through both mask branches and
        # crashed with UnboundLocalError at the return statement.
        raise ValueError(
            "data_type must be 'train', 'val' or 'test', got {!r}".format(data_type))

    traj_length = [len(traj) for traj in data]
    maxn_length = max(traj_length, default=0)

    lat_span = mbr['max_lat'] - mbr['min_lat']
    lng_span = mbr['max_lng'] - mbr['min_lng']
    day_seconds = 24 * 60 * 60
    # All-zero candidate row used to pad each trajectory's candidate list.
    add_candi_id_list = np.zeros((1, id_size))

    trg_id, trg_rate, trg_lat, trg_lng = [], [], [], []
    src_lat, src_lng, src_time = [], [], []
    src_candi_id = []
    for traj in data:
        # Pad every sequence with trailing zeros up to the longest trajectory in the batch.
        fill_num = maxn_length - len(traj)
        fill_lst = [0] * fill_num
        trg_id.append([tr['road_id'] for tr in traj] + fill_lst)
        trg_rate.append([tr['rate'] for tr in traj] + fill_lst)
        trg_lat.append([tr['mm_lat'] for tr in traj] + fill_lst)
        trg_lng.append([tr['mm_lng'] for tr in traj] + fill_lst)
        # Min-max rescale raw GPS coordinates relative to the MBR.
        src_lat.append([2 * (tr['ori_lat'] - mbr['min_lat']) / lat_span - 1 for tr in traj] + fill_lst)
        src_lng.append([2 * (tr['ori_lng'] - mbr['min_lng']) / lng_span - 1 for tr in traj] + fill_lst)
        # Epoch timestamp modulo one day. NOTE(review): for naive datetimes
        # .timestamp() assumes local time — confirm this is the intended clock.
        src_time.append([tr['ori_time'].timestamp() % day_seconds for tr in traj] + fill_lst)
        src_candi_id.append([tr['candi_id_list'] for tr in traj] + [add_candi_id_list] * fill_num)

    # NOTE(review): the padded per-point candidate rows built above are discarded
    # and a single all-zero (1, id_size) row is returned instead. The disabled
    # alternative was ``np.array(src_candi_id).squeeze(2)``. Behavior preserved;
    # confirm whether this is intentional (e.g. to save memory).
    src_candi_id = add_candi_id_list

    # 1 marks a masked point (mask_index) or a padding position (padd_index).
    mask_index = np.zeros((len(data), maxn_length))
    padd_index = np.zeros((len(data), maxn_length))
    for i, traj_len in enumerate(traj_length):
        if data_type == "train":
            # Never mask the first or last point. Draw how many interior points
            # go missing, then pick that many *distinct* interior positions.
            interior = range(1, traj_len - 1)
            curr_mask_num = random.randint(0, len(interior))
            mask_index[i][random.sample(interior, curr_mask_num)] = 1
        else:
            # Keep every step-th point plus the final point; mask the rest.
            step = int(1 / keep_ratio)
            kept = set(range(0, traj_len, step))
            kept.add(traj_len - 1)
            mask_index[i][[j for j in range(traj_len) if j not in kept]] = 1
        padd_index[i][traj_len:] = 1

    src_lat, src_lng, src_time = np.array(src_lat), np.array(src_lng), np.array(src_time)
    trg_id, trg_rate = np.array(trg_id), np.array(trg_rate)
    src_lat[mask_index == 1] = 0
    src_lng[mask_index == 1] = 0

    return src_lat, src_lng, src_time, src_candi_id, mask_index, trg_id, trg_rate, trg_lat, trg_lng, traj_length, padd_index
