import torch

def pad_poly(poly_list, max_poly = 20, discre_size = 10):
    """Zero-pad a list of polygons into one batched tensor.

    Args:
        poly_list: Sequence of vertex tensors, each of shape ``(n_i, 2)``.
            Each ``n_i`` must not exceed ``max_poly``; otherwise the slice
            assignment fails.
        max_poly: Number of vertex slots per polygon in the padded output.
        discre_size: Cell size used to discretize each polygon's mean vertex.

    Returns:
        A tuple ``(padded, cells)``:
          * ``padded`` has shape ``(1, len(poly_list), max_poly, 2)``. Unused
            vertex slots are left as zeros.
          * ``cells`` has shape ``(1, len(poly_list), 2)``. It holds the mean
            vertex of each polygon, floor-divided by ``discre_size``.
    """
    count = len(poly_list)
    padded = torch.zeros([count, max_poly, 2])
    cells = torch.zeros([count, 2])

    for row, poly in enumerate(poly_list):
        n_vertices = poly.shape[0]
        # Copy the real vertices; the trailing slots stay zero as padding.
        padded[row, :n_vertices] = poly
        cells[row] = poly.mean(dim=0) // discre_size

    # Prepend a leading batch dimension of size 1.
    return padded[None], cells[None]

def in_poly_idx(poly, discre = 50):
    """Rasterize a polygon onto a ``discre x discre`` integer grid.

    Uses the even-odd (ray casting) rule. A horizontal ray is cast from each
    grid point towards +x, and the polygon edges it crosses are counted. An
    odd count means the point is inside.

    Args:
        poly: Polygon vertices as a sequence of ``(x, y)`` pairs, e.g. a list
            of tuples or an ``(n, 2)`` tensor. The polygon is closed
            implicitly: the last vertex connects back to the first.
        discre: Grid resolution. Grid points have integer coordinates in
            ``[0, discre)`` along each axis.

    Returns:
        A flat long tensor of length ``discre * discre``. Entry ``k`` refers
        to grid point ``(x, y) = (k // discre, k % discre)``. The entry is
        ``0`` when the point is inside the polygon and ``1`` when it is
        outside. An empty ``poly`` marks every point as outside.
    """
    coords = torch.arange(discre)
    px = coords.repeat_interleave(discre)  # x varies slowest
    py = coords.repeat(discre)             # y varies fastest
    inside = torch.zeros(discre * discre, dtype=torch.bool)

    n = len(poly)
    for i in range(n):
        x1, y1 = poly[i]
        x2, y2 = poly[(i + 1) % n]
        if y1 == y2:
            # A horizontal edge can never satisfy the half-open span test
            # below, so it never contributes a crossing. Skipping it also
            # avoids dividing by zero in the intersection formula.
            continue
        # Edges whose y-range contains the ray's y coordinate. The interval is
        # half-open so vertices lying exactly on a ray are handled consistently.
        spans = (min(y1, y2) < py) & (py <= max(y1, y2))
        # x coordinate where the edge meets the horizontal line y = py.
        x_cross = x1 + (py - y1) * (x2 - x1) / (y2 - y1)
        # Every crossing strictly to the right of the point flips its parity.
        inside ^= spans & (x_cross > px)

    return torch.where(inside, 0, 1)