import numpy as np
import torch

def random_masking(x, pos, info, remain_num, h=None, max_build=60, max_poly=20, *, rng=None):
    """Randomly split each sample's valid entries into a kept set and a target set.

    For every sample ``i`` in the batch, the first ``L = int(info[i, 0])``
    entries along dim 1 are randomly permuted. The first ``remain_num``
    permuted entries are copied into the "reserve" outputs and the remaining
    ``L - remain_num`` into the zero-padded "target" outputs.

    Args:
        x: Tensor indexed as ``x[i, :L]``; each selected entry must fit a
            ``(max_poly, 2)`` slot (presumably ``(B, N, max_poly, 2)`` — TODO confirm).
        pos: Tensor indexed as ``pos[i, :L]``; each entry must fit a ``(2,)`` slot.
        info: Tensor where ``info[i, 0]`` is the number of valid entries of
            sample ``i`` and ``info[i, j + 1]`` is a per-entry value for entry
            ``j`` (presumably a polygon length — verify against caller).
        remain_num: Number of entries kept per sample.
        h: Optional tensor indexed as ``h[i, :L]`` whose entries are written
            into scalar slots (presumably ``(B, N)`` — TODO confirm).
        max_build: Capacity along dim 1; the target outputs have
            ``max_build - remain_num`` rows.
        max_poly: Size of dim 2 of the polygon outputs.
        rng: Optional ``numpy.random.Generator`` used to draw the shuffle noise.
            When ``None`` (default) the global ``np.random`` state is used.

    Returns:
        ``(poly_reserve, pos_reserve, poly_tar, pos_tar, len_tar)``, or, when
        ``h`` is given,
        ``(poly_reserve, pos_reserve, h_reserve, poly_tar, pos_tar, h_tar, len_tar)``.
        Output tensors are freshly created with ``torch.zeros`` defaults
        (float32 on CPU unless the torch defaults were changed). Target rows
        beyond ``L - remain_num`` stay zero. ``len_tar`` is a list holding, per
        sample, ``info[i, ids_tar + 1]`` cast to ``long`` in target order.

    Raises:
        ValueError: if some sample's ``L`` is outside
            ``[remain_num, max_build]``. Otherwise torch either fails with an
            opaque shape error or, for ``L == 1``, silently broadcasts the single
            entry into every reserve row.
    """
    batch = x.shape[0]
    num_tar = max_build - remain_num
    # The default preserves the original behaviour of drawing from the global NumPy RNG.
    draw = np.random.rand if rng is None else rng.random

    poly_reserve = torch.zeros([batch, remain_num, max_poly, 2])
    pos_reserve = torch.zeros([batch, remain_num, 2])
    poly_tar = torch.zeros([batch, num_tar, max_poly, 2])
    pos_tar = torch.zeros([batch, num_tar, 2])
    if h is not None:
        h_reserve = torch.zeros([batch, remain_num, 1])
        h_tar = torch.zeros([batch, num_tar, 1])

    len_tar = []
    for i in range(batch):
        L = int(info[i, 0])
        if not remain_num <= L <= max_build:
            raise ValueError(
                f"sample {i} has {L} valid entries; expected "
                f"remain_num ({remain_num}) <= L <= max_build ({max_build})"
            )

        # argsort of i.i.d. uniform noise is a uniformly random permutation of range(L).
        ids_shuffle = np.argsort(draw(L), axis=0)
        ids_keep = ids_shuffle[:remain_num]
        ids_tar = ids_shuffle[remain_num:]
        n_tar = len(ids_tar)

        # Indexing with an index array already produces a copy, so no clone is needed.
        x_valid = x[i, :L]
        pos_valid = pos[i, :L]
        poly_reserve[i] = x_valid[ids_keep]
        pos_reserve[i] = pos_valid[ids_keep]
        poly_tar[i, :n_tar] = x_valid[ids_tar]
        pos_tar[i, :n_tar] = pos_valid[ids_tar]
        if h is not None:
            h_valid = h[i, :L]
            h_reserve[i, :, 0] = h_valid[ids_keep]
            h_tar[i, :n_tar, 0] = h_valid[ids_tar]
        # Column j + 1 of info holds the per-entry value for entry j (column 0 is L).
        len_tar.append(info[i, ids_tar + 1].long())

    if h is not None:
        return poly_reserve, pos_reserve, h_reserve, poly_tar, pos_tar, h_tar, len_tar
    return poly_reserve, pos_reserve, poly_tar, pos_tar, len_tar


def random_masking_test(x, pos, info, remain_num, max_build=60, max_poly=20, *, rng=None):
    """Randomly split a single sample into kept and target entries, also returning the kept ids.

    Works like ``random_masking`` without ``h``, but only for a batch of size
    1, and additionally returns the NumPy index array of the kept entries.

    Args:
        x: Tensor with ``x.shape[0] == 1`` indexed as ``x[0, :L]``; each entry
            must fit a ``(max_poly, 2)`` slot (presumably ``(1, N, max_poly, 2)`` — TODO confirm).
        pos: Tensor indexed as ``pos[0, :L]``; each entry must fit a ``(2,)`` slot.
        info: Tensor where ``info[0, 0]`` is the number of valid entries ``L``
            and ``info[0, j + 1]`` is a per-entry value for entry ``j``
            (presumably a polygon length — verify against caller).
        remain_num: Number of entries kept.
        max_build: Capacity along dim 1; target outputs have
            ``max_build - remain_num`` rows.
        max_poly: Size of dim 2 of the polygon outputs.
        rng: Optional ``numpy.random.Generator`` used to draw the shuffle noise.
            When ``None`` (default) the global ``np.random`` state is used.

    Returns:
        ``(poly_reserve, pos_reserve, poly_tar, pos_tar, len_tar, ids_keep)``.
        The tensors are freshly created with ``torch.zeros`` defaults and
        zero-padded past ``L - remain_num`` target rows. ``len_tar`` is a
        one-element list holding ``info[0, ids_tar + 1]`` cast to ``long``.
        ``ids_keep`` is the NumPy array of kept entry indices in reserve order.

    Raises:
        ValueError: if the batch size is not 1, or if ``L`` is outside
            ``[remain_num, max_build]``. Otherwise torch either fails with an
            opaque shape error or, for ``L == 1``, silently broadcasts the single
            entry into every reserve row.
    """
    # An explicit check, unlike `assert`, is not stripped when Python runs with -O.
    if x.shape[0] != 1:
        raise ValueError(f"random_masking_test expects a batch of size 1, got {x.shape[0]}")

    num_tar = max_build - remain_num
    poly_reserve = torch.zeros([1, remain_num, max_poly, 2])
    pos_reserve = torch.zeros([1, remain_num, 2])
    poly_tar = torch.zeros([1, num_tar, max_poly, 2])
    pos_tar = torch.zeros([1, num_tar, 2])

    L = int(info[0, 0])
    if not remain_num <= L <= max_build:
        raise ValueError(
            f"sample has {L} valid entries; expected "
            f"remain_num ({remain_num}) <= L <= max_build ({max_build})"
        )

    # argsort of i.i.d. uniform noise is a uniformly random permutation of range(L).
    noise = np.random.rand(L) if rng is None else rng.random(L)
    ids_shuffle = np.argsort(noise, axis=0)
    ids_keep = ids_shuffle[:remain_num]
    ids_tar = ids_shuffle[remain_num:]
    n_tar = len(ids_tar)

    # Indexing with an index array already produces a copy, so no clone is needed.
    x_valid = x[0, :L]
    pos_valid = pos[0, :L]
    poly_reserve[0] = x_valid[ids_keep]
    pos_reserve[0] = pos_valid[ids_keep]
    poly_tar[0, :n_tar] = x_valid[ids_tar]
    pos_tar[0, :n_tar] = pos_valid[ids_tar]

    # Column j + 1 of info holds the per-entry value for entry j (column 0 is L).
    len_tar = [info[0, ids_tar + 1].long()]

    return poly_reserve, pos_reserve, poly_tar, pos_tar, len_tar, ids_keep


