#!/usr/bin/env python
# -*- coding: utf-8 -*-
 


import os
import sys
import glob
import h5py
import numpy as np
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from utils import eulerangles


# Absolute threshold used by computeLaplacian: Laplacian entries whose
# magnitude is below this value are replaced with exactly 0.
zeroTolerance = 1e-5

def download():
    """Download and unpack the ModelNet40 HDF5 archive if it is not present yet.

    The archive is fetched with ``wget`` and extracted with ``unzip`` in the
    current working directory. The extracted folder is then moved to
    ``datasets/modelnet40_ply_hdf5_2048``, a path relative to the current
    working directory, and the zip file is deleted.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``unzip`` fails.
    """
    import shutil
    import subprocess

    parent_dir = 'datasets'
    DATA_DIR = os.path.join(parent_dir, 'modelnet40_ply_hdf5_2048')

    # Only create the parent directory. Creating DATA_DIR itself here would
    # make the existence check below always false, so nothing would download.
    os.makedirs(parent_dir, exist_ok=True)
    if not os.path.exists(DATA_DIR):
        www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
        zipfile = os.path.basename(www)
        # check=True stops on failure instead of unzipping a missing archive.
        subprocess.run(['wget', www], check=True)
        subprocess.run(['unzip', zipfile], check=True)
        shutil.move(zipfile[:-4], DATA_DIR)
        os.remove(zipfile)


def load_data(partition, num_points, perturb):
    """Load point clouds and labels for one partition from HDF5 files.

    Reads every file matching ``<partition>*.h5`` in
    ``<directory of this file>/datasets/modelnet10_hdf5_2048`` and concatenates
    their ``data`` and ``label`` datasets along the first axis.

    Args:
        partition: file-name prefix to match, e.g. 'train' or 'test'.
        num_points: keep only the first ``num_points`` points of each cloud.
        perturb: if non-zero, add Gaussian jitter with ``sigma=perturb``,
            clipped to ``+/- 2 * perturb``.

    Returns:
        Tuple ``(all_data, all_Lap_mat, all_label)``. ``all_data`` is float32,
        ``all_Lap_mat`` is an empty list (Laplacians are not computed here),
        and ``all_label`` is int64.

    Raises:
        FileNotFoundError: if no HDF5 file matches the partition.
    """
    # download()
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    print(BASE_DIR)
    DATA_DIR = os.path.join(BASE_DIR, 'datasets')
    pattern = os.path.join(DATA_DIR, 'modelnet10_hdf5_2048', '%s*.h5' % partition)
    # ModelNet40 layout alternative:
    #   os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)
    # Sort the matches: glob order is filesystem-dependent, and the
    # concatenation order should be reproducible.
    h5_names = sorted(glob.glob(pattern))
    if not h5_names:
        raise FileNotFoundError('No HDF5 files match %s' % pattern)

    all_data = []
    all_label = []
    for h5_name in h5_names:
        # Open read-only explicitly (h5py < 3.0 did not default to 'r').
        # The context manager closes the file even if a read raises.
        with h5py.File(h5_name, 'r') as f:
            data = f['data'][:].astype('float32')
            label = f['label'][:].astype('int64')
        data = data[:, 0:num_points, :]
        if perturb != 0:
            data = jitter_pointcloud(data, sigma=perturb, clip=perturb * 2)
        all_data.append(data)
        all_label.append(label)
    all_data = np.concatenate(all_data, axis=0)
    # Laplacians are not precomputed; the disabled alternative was
    # computeLaplacian(computeAdj(all_data), True).
    all_Lap_mat = []
    all_label = np.concatenate(all_label, axis=0)
    return all_data, all_Lap_mat, all_label


def translate_pointcloud(pointcloud):
    """Randomly rescale and shift a point cloud along each axis.

    Each axis gets a scale drawn uniformly from [2/3, 3/2) and an offset drawn
    uniformly from [-0.2, 0.2). The scales are drawn first, then the offsets.

    Returns:
        A new float32 array; the input is not modified.
    """
    scale = np.random.uniform(low=2./3., high=3./2., size=[3])
    shift = np.random.uniform(low=-0.2, high=0.2, size=[3])
    return (pointcloud * scale + shift).astype('float32')

def computeAdj(x):
    """Compute a Gaussian affinity between every pair of points in each batch element.

    Args:
        x: array of shape (batch_size, num_points, num_features).

    Returns:
        Array of shape (batch_size, num_points, num_points) whose entry
        [b, i, j] is exp(-||x[b, i] - x[b, j]||^2).
    """
    # Expand ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 for all pairs at once.
    norms_sq = np.sum(np.square(x), axis=-1, keepdims=True)      # (B, N, 1)
    cross = -2 * np.matmul(x, np.transpose(x, axes=[0, 2, 1]))   # (B, N, N)
    dist_sq = norms_sq + cross + np.transpose(norms_sq, axes=[0, 2, 1])
    return np.exp(-dist_sq)

def _batch_diag(v):
    """Return a (B, N, N) array whose b-th slice is np.diag(v[b]), for a (B, N) array v."""
    n = v.shape[-1]
    out = np.zeros((v.shape[0], n, n), dtype=v.dtype)
    idx = np.arange(n)
    out[:, idx, idx] = v
    return out


def computeLaplacian(adj_mat, normalize, tol=None):
    """Compute the graph Laplacian for each adjacency matrix in a batch.

    Args:
        adj_mat: array of shape (batch_size, num_points, num_points).
        normalize: if True, return I - D^{-1/2} A D^{-1/2}; otherwise return
            D - A. D holds the degrees, computed as sums over axis 1.
        tol: entries whose absolute value is below ``tol`` are set to 0.
            Defaults to the module-level ``zeroTolerance``.

    Returns:
        Array of shape (batch_size, num_points, num_points).
    """
    if tol is None:
        tol = zeroTolerance
    adj_mat = np.asarray(adj_mat)
    # Degrees: sums over axis 1, shape (batch_size, num_points).
    D = np.sum(adj_mat, axis=1)
    if normalize:
        d_inv_sqrt = 1 / np.sqrt(D)
        eye = np.eye(D.shape[-1], dtype=D.dtype)
        # D^{-1/2} A D^{-1/2}: scale row i by d_i^{-1/2} and column j by
        # d_j^{-1/2} via broadcasting. This avoids building and multiplying
        # dense diagonal matrices.
        L = eye - d_inv_sqrt[:, :, None] * adj_mat * d_inv_sqrt[:, None, :]
    else:
        L = _batch_diag(D) - adj_mat
    # Snap numerically negligible entries to exactly 0.
    return np.where(abs(L) < tol, 0, L)


def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
    """Add clipped Gaussian noise to a point-cloud array, in place.

    Args:
        pointcloud: float numpy array of any shape, e.g. a batch (B, N, C) or
            a single cloud (N, C). It is modified in place.
        sigma: standard deviation of the noise.
        clip: each noise value is clipped to [-clip, clip].

    Returns:
        The same array object, after jittering.
    """
    # randn(*shape) draws the same values as randn(B, N, C) for rank-3 input,
    # and also accepts other ranks.
    noise = np.clip(sigma * np.random.randn(*pointcloud.shape), -1 * clip, clip)
    pointcloud += noise
    return pointcloud

def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25,
                     xrot=0, yrot=0, zrot=0, switch_xyz=(0, 1, 2), normalize=True):
    """Render a point cloud to a grayscale image.

    Each point is stamped onto the canvas as a Gaussian disk. Points are drawn
    from largest to smallest (rotated, rescaled) z, so points with smaller z
    are blended on top and get a larger weight (max_depth - z).

    Args:
        input_points: Nx3 numpy array (+y is up direction).
        canvasSize: side length in pixels of the square output image.
        space: scale factor from point units to pixels.
        diameter: diameter in pixels of the disk stamped for each point.
        xrot, yrot, zrot: rotation angles passed to
            ``eulerangles.euler2mat(zrot, yrot, xrot)``. point_cloud_isoview
            passes radians.
        switch_xyz: column permutation applied to the input before rotating.
        normalize: if True, center the cloud and scale it into the unit sphere.

    Returns:
        (canvasSize, canvasSize) float array scaled so its maximum is 1. It is
        all zeros if the input is None/empty or every stamped value is 0.
    """
    image = np.zeros((canvasSize, canvasSize))
    if input_points is None or input_points.shape[0] == 0:
        return image

    points = input_points[:, list(switch_xyz)]
    M = eulerangles.euler2mat(zrot, yrot, xrot)
    points = (np.dot(M, points.transpose())).transpose()

    # Normalize the point cloud: center it and scale it to fit in a unit sphere.
    if normalize:
        centroid = np.mean(points, axis=0)
        points -= centroid
        furthest_distance = np.max(np.sqrt(np.sum(abs(points) ** 2, axis=-1)))
        # A single point (or coincident points) has zero extent; avoid 0/0.
        if furthest_distance > 0:
            points /= furthest_distance

    # Pre-compute the Gaussian disk: exp(-r^2 / radius^2) inside the circle,
    # 0 outside.
    radius = (diameter - 1) / 2.0
    disk = np.zeros((diameter, diameter))
    for i in range(diameter):
        for j in range(diameter):
            if (i - radius) * (i - radius) + (j - radius) * (j - radius) <= radius * radius:
                disk[i, j] = np.exp((-(i - radius) ** 2 - (j - radius) ** 2) / (radius ** 2))
    mask = np.argwhere(disk > 0)
    dx = mask[:, 0]
    dy = mask[:, 1]
    dv = disk[disk > 0]

    # Order points by z-buffer and rescale z to [0, 1].
    zorder = np.argsort(points[:, 2])
    points = points[zorder, :]
    z_min = np.min(points[:, 2])
    z_range = np.max(points[:, 2]) - z_min
    if z_range > 0:
        points[:, 2] = (points[:, 2] - z_min) / z_range
    else:
        # Flat cloud (or a single point): avoid 0/0 and put every point at depth 0.
        points[:, 2] = 0.0
    max_depth = np.max(points[:, 2])

    # Walk from the largest z to the smallest so smaller-z points are blended last.
    for j in range(points.shape[0] - 1, -1, -1):
        xc = int(np.round(canvasSize / 2 + (points[j, 0] * space)))
        yc = int(np.round(canvasSize / 2 + (points[j, 1] * space)))

        px = dx + xc
        py = dy + yc
        # Drop disk pixels that fall off the canvas. Indices >= canvasSize
        # would raise IndexError, and negative ones would wrap to the far edge.
        inside = (px >= 0) & (px < canvasSize) & (py >= 0) & (py < canvasSize)
        px = px[inside]
        py = py[inside]

        image[px, py] = image[px, py] * 0.7 + dv[inside] * (max_depth - points[j, 2]) * 0.3

    peak = np.max(image)
    # An all-zero image would otherwise become all NaN.
    if peak > 0:
        image = image / peak
    return image

def point_cloud_isoview(points):
    """Render a point cloud from one fixed oblique viewpoint.

    Args:
        points: Nx3 numpy array; +y is the up direction.

    Returns:
        The grayscale image from draw_point_cloud at its default canvas size
        (500x500).
    """
    def to_rad(degrees):
        return degrees / 180.0 * np.pi

    # Columns are permuted with switch_xyz=[1, 2, 0] before rotating.
    # Author's convention: xrot is azimuth, yrot is in-plane, zrot is elevation.
    return draw_point_cloud(
        points,
        xrot=to_rad(-10),
        yrot=to_rad(15),
        zrot=to_rad(0),
        switch_xyz=[1, 2, 0],
    )

class ModelNet40(Dataset):
    """Map-style dataset over the point clouds returned by load_data.

    Each item is a tuple ``(pointcloud, lap_matrix, label)``. ``lap_matrix``
    is always an empty list, because load_data does not compute Laplacians.
    """

    def __init__(self, num_points, partition='train', perturb = 0):
        self.data, self.lap_mat, self.label = load_data(partition, num_points, perturb)
        self.num_points = num_points
        self.partition = partition

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, item):
        # Train-time augmentation (translate_pointcloud + shuffling) is
        # intentionally not applied here.
        return self.data[item, :, :], [], self.label[item]


if __name__ == '__main__':
    import copy

    train = ModelNet40(300)
    # Loaded only to check that the test partition can be read.
    test = ModelNet40(80, 'test')

    # Visualise the first training sample whose label is 1.
    for data, mat, label in train:
        if label != 1:
            continue

        # 2-D rendered view of the point cloud.
        plt.figure()
        ax = plt.axes()
        plt.xlim([0, 550])
        plt.ylim([20, 500])
        image = point_cloud_isoview(data)
        # Mask near-empty pixels so they are drawn in the colormap's "bad" colour.
        image = np.ma.masked_where(image < 0.0005, image)
        # Use a copy: calling set_bad on the shared plt.cm.summer instance
        # would change that colormap for every later user in the process.
        cmap = copy.copy(plt.cm.summer)
        cmap.set_bad(color='white')
        ax.imshow(image, cmap=cmap)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.axis('off')
        plt.show()

        # 3-D scatter of the raw points, also joined in storage order by a line.
        plt.figure()
        ax = plt.axes(projection='3d')
        xdata, ydata, zdata = data[:, 0], data[:, 1], data[:, 2]
        # A fixed colour is used, so vmin/vmax would be ignored; they are omitted.
        ax.scatter(xdata, ydata, zdata, marker='.', color='b')
        ax.plot(xdata, ydata, zdata)
        plt.show()
        break
   


