import pickle
import numpy as np
import os
from shapely.geometry import Polygon
import time


def trans_index(polygon):
    """Return *polygon* with its rows rotated to start at the vertex of smallest x + y.

    Columns 0 and 1 are summed to rank the vertices; ties go to the earliest
    vertex because ``np.argmin`` returns the first minimum. The cyclic vertex
    order and all columns are preserved, and a new array is returned.
    """
    coord_sum = polygon[:, 0] + polygon[:, 1]
    start = int(np.argmin(coord_sum))
    # A negative shift moves row `start` to position 0 and wraps the earlier rows to the end.
    return np.roll(polygon, -start, axis=0)

# Separator printed between pipeline stages.
_BANNER = '###############################################################################################'


def _to_object_array(items):
    """Pack *items* into a 1-D object array holding exactly one element per item.

    ``np.array(items, dtype=object)`` silently builds a higher-dimensional
    array when every item happens to share a shape (e.g. all polygons have the
    same vertex count). That breaks the per-element assignment and ``len``
    logic of the pipeline. Filling a pre-sized array element by element never
    does this.
    """
    out = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        out[i] = item
    return out


def _triangle_area(poly, k):
    """Return the area of the triangle spanned by vertices k, k+1 and k+2 of *poly*.

    Indices wrap around, so every vertex is the middle vertex of exactly one
    triple. The area is computed by shapely.
    """
    n = poly.shape[0]
    return Polygon(poly[[k, (k + 1) % n, (k + 2) % n]]).area


def _drop_thin_vertices(poly, area_tol=1):
    """Delete middle vertices of consecutive triples whose triangle area is <= *area_tol*.

    Triples are scanned from vertex 0. After each deletion the scan restarts
    from the beginning, and the polygon is returned once a full pass deletes
    nothing. The input array is not modified.

    NOTE(review): nothing stops a polygon from shrinking below three
    vertices; shapely may then reject the degenerate triangle ring. TODO:
    decide whether such polygons should stop at 3 vertices or be dropped.
    """
    poly = poly.copy()
    while True:
        n = poly.shape[0]
        for k in range(n):
            if _triangle_area(poly, k) <= area_tol:
                poly = np.delete(poly, (k + 1) % n, axis=0)
                break
        else:
            return poly


def _limit_vertices(poly, max_vertices=20):
    """Reduce *poly* to at most *max_vertices* vertices.

    While the polygon is too long, delete the middle vertex of the triple with
    the smallest triangle area. A reduced polygon is then rotated with
    ``trans_index``. Polygons already within the limit are returned unchanged
    (same object).
    """
    if poly.shape[0] <= max_vertices:
        # NOTE(review): these polygons are not passed through trans_index, so
        # their starting vertex follows a different convention than reduced
        # ones. Kept to preserve the original output; confirm intent.
        return poly
    while poly.shape[0] > max_vertices:
        n = poly.shape[0]
        areas = [_triangle_area(poly, k) for k in range(n)]
        # argmin returns the first minimum, matching the original strict "<"
        # scan. The old fixed 10000 starting value (with idx=0) deleted
        # vertex 0 whenever every triangle was >= 10000.
        poly = np.delete(poly, (int(np.argmin(areas)) + 1) % n, axis=0)
    return trans_index(poly)


def _select_windows(polys, window_s=500):
    """Return, for each polygon, the polygons lying strictly inside its window.

    The window is the *window_s* x *window_s* square centred on the polygon's
    vertex mean. A polygon belongs to it when every vertex satisfies
    ``abs(vertex - centre) < window_s / 2`` on both axes. Selected polygons are
    shifted so the window's minimum corner becomes the origin.

    Per-polygon min/max coordinates are computed once, so each window needs
    only one vectorised test instead of a Python loop over every polygon.
    Because float subtraction is monotone and sign-symmetric,
    ``max - c < h and c - min < h`` selects exactly the same polygons as the
    per-vertex ``abs`` test.
    """
    half = window_s / 2
    centers = np.array([np.mean(p, axis=0) for p in polys])
    lows = np.array([np.min(p, axis=0) for p in polys])
    highs = np.array([np.max(p, axis=0) for p in polys])
    windows = []
    for i, c in enumerate(centers):
        if i % 10000 == 0:
            print(i)
        inside = np.all(highs - c < half, axis=1) & np.all(c - lows < half, axis=1)
        origin = c - half
        windows.append([polys[j] - origin for j in np.flatnonzero(inside)])
    return windows


def _discretize(windows, grid_size=50, cell_size=10):
    """Snap polygon centres to a grid and keep only windows without cell collisions.

    Each polygon's vertex mean is floor-divided by *cell_size* to give an
    (x, y) cell. A window is kept only if no two of its polygons fall into the
    same cell.

    Returns:
        (positions, kept): for every kept window, a (k, 2) float array of the
        (x, y) cell of each polygon, plus the kept windows in the same order.

    Raises:
        ValueError: if a cell coordinate is not <= grid_size - 1 (including NaN).
    """
    positions, kept = [], []
    for b, window in enumerate(windows):
        if b % 10000 == 0:
            print(b)
        # One small occupancy grid per window instead of an (N, grid, grid) array.
        occupied = np.zeros((grid_size, grid_size), dtype=bool)
        cells = np.zeros([len(window), 2])
        for i, poly in enumerate(window):
            x, y = np.mean(poly, axis=0) // cell_size
            if not (x <= grid_size - 1 and y <= grid_size - 1):
                raise ValueError(
                    f'polygon centre cell ({x}, {y}) is outside the '
                    f'{grid_size}x{grid_size} grid'
                )
            row, col = int(y), int(x)
            if occupied[row, col]:
                break
            occupied[row, col] = True
            cells[i] = (x, y)
        else:
            kept.append(window)
            positions.append(cells)
    return positions, kept


def _pad(windows, positions, max_polys=60, max_vertices=20):
    """Copy ragged windows into zero-padded dense arrays.

    Returns:
        (data_pos, data_poly, data_info) with shapes (N, max_polys, 2),
        (N, max_polys, max_vertices, 2) and (N, max_polys + 1). Row b of
        data_info holds the polygon count of window b, followed by the vertex
        count of each polygon.
    """
    n = len(windows)
    data_poly = np.zeros([n, max_polys, max_vertices, 2])
    data_pos = np.zeros([n, max_polys, 2])
    data_info = np.zeros([n, max_polys + 1])
    for b, (window, cells) in enumerate(zip(windows, positions)):
        if b % 10000 == 0:
            print(b)
        data_info[b, 0] = len(window)
        for i, poly in enumerate(window):
            leng = poly.shape[0]
            data_poly[b, i, :leng, :] = poly
            data_pos[b, i, :] = cells[i]
            data_info[b, i + 1] = leng
    return data_pos, data_poly, data_info


def main(data_path='../datasets/Manhattan_polygon2_45847_height.pickle',
         save_dir='../datasets/manhattan/'):
    """Run the polygon preprocessing pipeline, saving every stage under *save_dir*.

    Stages are chained in memory; each ``.npy`` file is a checkpoint:

    1. ``poly_3to1``: first two columns of every pickled polygon, with the
       middle vertex of near-degenerate triples (triangle area <= 1) removed.
    2. ``poly_leq20``: polygons reduced to at most 20 vertices.
    3. ``poly_500``: for each polygon, the polygons strictly inside the
       500 x 500 window centred on its vertex mean, shifted so the window's
       minimum corner is the origin.
    4. ``poly_500_6_60``: windows holding more than 6 and fewer than 60
       polygons.
    5. ``poly_pos`` / ``poly``: per-polygon grid cells (centre // 10) and the
       windows in which no two polygon centres share a cell.
    6. ``polypos_np`` / ``poly_np`` / ``polyinfo_np``: zero-padded dense
       arrays of shape (N, 60, 2), (N, 60, 20, 2) and (N, 61).
    """
    os.makedirs(save_dir, exist_ok=True)

    def out_path(name):
        return os.path.join(save_dir, name)

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(data_path, 'rb') as f:
        polygon_list = pickle.load(f)
    print(len(polygon_list))

    # Keep only the first two coordinate columns of each polygon.
    polys = _to_object_array([polygon[:, :2] for polygon in polygon_list])

    print('start_3to1')
    for i in range(len(polys)):
        if i % 10000 == 0:
            print(i)
        polys[i] = _drop_thin_vertices(polys[i])
    print(polys.shape[0])
    np.save(out_path('poly_3to1'), polys)

    print(_BANNER)

    print('start_leq20')
    for i in range(len(polys)):
        if i % 10000 == 0:
            print(i)
        polys[i] = _limit_vertices(polys[i], max_vertices=20)
    print(polys.shape[0])
    np.save(out_path('poly_leq20'), polys)

    print(_BANNER)

    print('start_select')
    windows = _select_windows(polys, window_s=500)
    np.save(out_path('poly_500'), _to_object_array(windows))

    print(_BANNER)

    print('start_discard_max_min')
    print(len(windows))
    windows = [w for w in windows if len(w) > 6]
    print('-min', len(windows))
    windows = [w for w in windows if len(w) < 60]
    print('-max', len(windows))
    np.save(out_path('poly_500_6_60'), _to_object_array(windows))
    print(_BANNER)

    print('start_discretize')
    positions, windows = _discretize(windows, grid_size=50, cell_size=10)
    print(len(positions))
    print(len(windows))
    np.save(out_path('poly_pos'), _to_object_array(positions))
    np.save(out_path('poly'), _to_object_array(windows))

    print(_BANNER)

    print('start_padding')
    data_pos, data_poly, data_info = _pad(windows, positions, max_polys=60, max_vertices=20)
    np.save(out_path('polypos_np'), data_pos)
    np.save(out_path('poly_np'), data_poly)
    np.save(out_path('polyinfo_np'), data_info)


if __name__ == '__main__':
    main()






