import pickle
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import shapely
from shapely.geometry import Polygon
from scipy.cluster import hierarchy
from scipy.cluster.hierarchy import leaves_list
from shapely.ops import unary_union
import cv2
from scipy.spatial import cKDTree

def trans_index(polygon):
    """Cyclically rotate ``polygon`` so the row minimising col0 + col1 comes first.

    The cyclic order of the rows is otherwise preserved. Ties go to the
    earliest row, because ``np.argmin`` returns the first minimum. A new
    array is returned and the input is not modified.
    """
    start = int(np.argmin(polygon[:, 0] + polygon[:, 1]))
    return np.roll(polygon, -start, axis=0)

data_folder = '../datasets/Manhattan_polygon2_45847_height.pickle'

# Load the raw polygon collection; each entry is indexed/sliced as a 2-D
# vertex array further down. NOTE(review): pickle can execute arbitrary code
# on load, so this path must only ever point at a trusted, local file.
with open(data_folder, mode='rb') as pickle_fh:
    polygon_list = pickle.Unpickler(pickle_fh).load()


# Per-polygon centre: the mean of all vertex rows, keeping only the first two
# coordinate columns (the same columns used for the neighbour search below).
center_mean = np.zeros((len(polygon_list), 2))
for poly_idx, verts in enumerate(polygon_list):
    center_mean[poly_idx] = np.mean(verts, axis=0)[:2]

# Number of polygons with at most 50 vertices.
# NOTE(review): this value is not used anywhere else in the script.
max_len = sum(1 for verts in polygon_list if verts.shape[0] <= 50)

# For every polygon, gather its nearest neighbours (by centre) and express
# their first two coordinate columns relative to that polygon's centre.
n_neighbours = 32
n_polys = len(polygon_list)
kdtree = cKDTree(center_mean)

# One batched query instead of one query per polygon. Row i holds the indices
# of the n_neighbours centres closest to centre i, nearest first (normally
# including polygon i itself at distance 0).
_, neighbour_idx = kdtree.query(center_mean, k=n_neighbours)

# Fill a preallocated (n_polys, n_neighbours) object array cell by cell.
# np.array(list_of_arrays, dtype=object) would instead build a 3-D object
# array whenever all neighbours of a polygon happen to have the same vertex
# count. Later code that stores a polygon with a different number of
# vertices back into a cell then fails with a broadcast error.
datas = np.empty((n_polys, n_neighbours), dtype=object)
for i in range(n_polys):
    if i % 1000 == 0:
        print(i)
    for slot, j in enumerate(neighbour_idx[i]):
        datas[i, slot] = polygon_list[j][:, :2] - center_mean[i]

os.makedirs("../datasets/data32/", exist_ok=True)
np.save(f"../datasets/data32/poly_32", datas)
###############################################################################################


# Normalise the starting vertex of every stored polygon: each one is rotated
# so the row with the smallest col0 + col1 comes first (see trans_index).
for row_idx, row in enumerate(datas):
    if row_idx % 1000 == 0:
        print(row_idx)
    # Iterate over a snapshot of the row while writing results back into it.
    for col_idx, verts in enumerate(row.copy()):
        row[col_idx] = trans_index(verts.copy())

def _first_flat_vertex(poly, area_tol=1.0):
    """Return the index of the first near-collinear vertex, or None.

    For each k (left to right, indices wrapping) the triangle
    (poly[k], poly[k+1], poly[k+2]) is measured. For the first k whose area
    is <= ``area_tol``, the middle vertex index (k+1) % n is returned. If no
    triangle is that small, None is returned. The area is half the absolute
    2-D cross product, i.e. the same value shapely's Polygon(...).area gives
    for a triangle, up to floating-point rounding.
    """
    pts = np.asarray(poly, dtype=float)
    e1 = np.roll(pts, -1, axis=0) - pts
    e2 = np.roll(pts, -2, axis=0) - pts
    areas = 0.5 * np.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])
    flat = np.flatnonzero(areas <= area_tol)
    if flat.size == 0:
        return None
    return (int(flat[0]) + 1) % pts.shape[0]


# Repeatedly drop vertices that form a (near-)degenerate triangle with their
# neighbours, then re-normalise the starting vertex. `count` tallies the
# number of vertices removed and is only used for progress output.
count = 0
for i in range(datas.shape[0]):
    if i % 1000 == 0:
        print(i)
        print(count)
    pan = datas[i].copy()
    for j in range(pan.shape[0]):
        poly = pan[j].copy()
        # Never shrink below a triangle. Without this guard, a polygon whose
        # triangles are all small collapses to 1 vertex (shapely rejects the
        # ring) or to 0 vertices (trans_index cannot take argmin of an empty
        # array).
        while poly.shape[0] > 3:
            idx = _first_flat_vertex(poly)
            if idx is None:
                break
            poly = np.delete(poly, idx, axis=0)
            count += 1
        poly = trans_index(poly)
        datas[i][j] = poly


def _min_area_vertex(poly):
    """Return the index of the vertex spanning the smallest corner triangle.

    For each k (indices wrapping) the triangle (poly[k], poly[k+1], poly[k+2])
    is measured, and the middle vertex (k+1) % n of the smallest one is
    returned. Ties go to the earliest k (np.argmin returns the first
    minimum). Removing that vertex repeatedly is a Visvalingam-Whyatt-style
    simplification.
    """
    pts = np.asarray(poly, dtype=float)
    e1 = np.roll(pts, -1, axis=0) - pts
    e2 = np.roll(pts, -2, axis=0) - pts
    areas = 0.5 * np.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])
    # np.argmin has no artificial upper bound. The previous `min_area =
    # 1000000` sentinel silently deleted vertex 0 whenever every triangle was
    # at least that large.
    return (int(np.argmin(areas)) + 1) % pts.shape[0]


# Simplify every polygon with more than 20 vertices down to exactly 20 by
# repeatedly removing the least significant vertex, then re-normalise its
# starting vertex. `count` tallies how many polygons needed simplification.
count = 0
for i in range(datas.shape[0]):
    if i % 1000 == 0:
        print(i)
        print(count)
    pan = datas[i].copy()
    for j in range(pan.shape[0]):
        poly = pan[j].copy()
        if poly.shape[0] > 20:
            count += 1
            while poly.shape[0] > 20:
                poly = np.delete(poly, _min_area_vertex(poly), axis=0)
            poly = trans_index(poly)
            assert poly.shape[0] <= 20
            datas[i][j] = poly


# Draw a random subset of neighbourhoods (without replacement) and pack it
# into dense, zero-padded arrays:
#   data_poly[b, i, :n, :] holds the n vertices of polygon i in sample b,
#   data_info[b, i] holds that vertex count n.
# The pad width 20 matches the vertex limit enforced on every stored polygon
# above.
n_samples = 45000
tr_idx = random.sample(range(len(datas)), n_samples)
datain = datas[tr_idx]

# Persist the sampled subset (the filename advertises the 45000-sample set).
# Previously the full `datas` array was written here by mistake.
np.save(f"../datasets/data32/poly_32_45000", datain)

data_poly = np.zeros([n_samples, 32, 20, 2])
data_info = np.zeros([n_samples, 32])

for batch in range(datain.shape[0]):
    if batch % 1000 == 0:
        print(batch)
    pan = datain[batch].copy()
    for i in range(pan.shape[0]):
        poly = pan[i].copy()
        # Use a dedicated name; the old `len = ...` shadowed the builtin.
        n_vert = poly.shape[0]
        data_poly[batch, i, :n_vert, :] = poly
        data_info[batch, i] = n_vert


np.save(f"../datasets/data32/poly_np", data_poly)
np.save(f"../datasets/data32/polyinfo_np", data_info)





