'''
@Author: Wenhao Ding
@Email: wenhaod@andrew.cmu.edu
@Date: 2019-11-18 22:13:22
LastEditTime: 2020-10-12 20:21:55
@Description: 
'''

import os
import numpy as np
import pandas as pd
import random
import copy
import matplotlib.colors as col
import matplotlib.cm as cm

import torch
from plyfile import PlyData, PlyElement

import torch
import torch.nn as nn
import torch.nn.init as init


def regist_colormap():
    """Register the project's custom colormap under the name ``'own_cm'``.

    The colormap is a linear interpolation through four colour stops
    (``#B22222`` -> ``#FFA500`` -> ``#00BFFF`` -> ``#9932CC``).  After calling
    this function the map can be requested from Matplotlib by name.

    Calling the function more than once is allowed; the existing ``'own_cm'``
    entry is simply replaced.
    """
    color_stops = [
        '#B22222',  # start colour
        '#FFA500',
        '#00BFFF',
        '#9932CC',  # end colour
    ]
    cmap2 = col.LinearSegmentedColormap.from_list('own_cm', color_stops)

    # ``cm.register_cmap`` was deprecated in Matplotlib 3.7 and removed in
    # 3.9.  Prefer the colormap registry when it exists and only fall back to
    # the legacy function on old Matplotlib versions.
    import matplotlib

    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        # force=True mirrors the legacy behaviour of overwriting a previously
        # registered user colormap instead of raising.
        registry.register(cmap2, force=True)
    else:
        cm.register_cmap(cmap=cmap2)
    

def setup_seed(seed):
    """Seed all random number generators used by this module.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    every CUDA device), and switches cuDNN to deterministic mode.

    Args:
        seed: integer seed shared by all generators.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN to use deterministic algorithms.
    torch.backends.cudnn.deterministic = True


class COLOR:
    """ANSI escape sequences for colouring terminal output.

    Prefix a string with one of these constants to switch the text style,
    e.g. ``print(COLOR.RED + 'error' + COLOR.WHITE)``.
    """

    # NOTE(review): code "0" is the SGR reset attribute rather than a white
    # foreground -- presumably used to return to the default style; confirm.
    WHITE = '\033[1;0m'
    # Bold foreground colours.
    PURPLE = '\033[1;35m'
    BLUE = '\033[1;34m'
    YELLOW = '\033[1;33m'
    GREEN = '\033[1;32m'
    RED = '\033[1;31m'


def weights_init(m):
    """Initialise a module's parameters from normal distributions, by class name.

    Intended to be passed to ``model.apply``.  The module is matched on
    substrings of its class name, checked in this order:

    * ``'Linear'``          -- weight ~ N(0, 0.02)
    * ``'BatchNorm'``       -- weight ~ N(1, 0.02), bias = 0
    * ``'Conv2d'``          -- weight ~ N(0, 0.0099999)
    * ``'ConvTranspose2d'`` -- weight ~ N(0, 0.0099999)

    Modules that match none of these are left untouched.  Parameters that are
    absent or ``None`` (for example ``BatchNorm`` layers built with
    ``affine=False``) are skipped instead of raising ``AttributeError``.

    Args:
        m: the ``torch.nn.Module`` to initialise in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Linear' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # affine=False leaves both weight and bias as None.
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif 'Conv2d' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.0099999)
    elif 'ConvTranspose2d' in classname:
        # 'ConvTranspose2d' does not contain the substring 'Conv2d', so this
        # branch is reached for transposed convolutions.
        if weight is not None:
            weight.data.normal_(0.0, 0.0099999)


def kaiming_init(m):
    """Initialise a module in place, chosen by its type.

    Intended to be passed to ``model.apply``:

    * ``nn.Linear``                       -- Xavier-normal weight, zero bias
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` -- Kaiming-normal weight, zero bias
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` -- weight = 1, bias = 0

    Note that, despite the function name, linear layers use Xavier
    initialisation.  Any parameter that is ``None`` (a layer built with
    ``bias=False``, or a batch-norm layer built with ``affine=False``) is
    skipped.  Other module types are left untouched.

    Args:
        m: the ``torch.nn.Module`` to initialise.
    """
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # With affine=False both parameters are None; guard the weight the
        # same way the bias is guarded.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)


def rotate_X(deg):
    """Return the 3x3 rotation matrix about the X axis for an angle in degrees.

    The matrix carries ``+sin`` above the diagonal and ``-sin`` below it.

    NOTE(review): this is the transpose of the usual column-vector rotation
    matrix -- presumably meant for row vectors or a frame rotation; confirm
    against callers.
    """
    theta = np.deg2rad(deg)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([
        [1.0, 0.0, 0.0],
        [0.0, c, s],
        [0.0, -s, c],
    ])


def rotate_Y(deg):
    """Return the 3x3 rotation matrix about the Y axis for an angle in degrees.

    The matrix carries ``-sin`` in the top-right corner and ``+sin`` in the
    bottom-left corner.

    NOTE(review): this is the transpose of the usual column-vector rotation
    matrix -- presumably meant for row vectors or a frame rotation; confirm
    against callers.
    """
    theta = np.deg2rad(deg)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([
        [c, 0.0, -s],
        [0.0, 1.0, 0.0],
        [s, 0.0, c],
    ])


def rotate_Z(deg):
    """Return the 3x3 rotation matrix about the Z axis for an angle in degrees.

    The matrix carries ``+sin`` above the diagonal and ``-sin`` below it.

    NOTE(review): this is the transpose of the usual column-vector rotation
    matrix -- presumably meant for row vectors or a frame rotation; confirm
    against callers.
    """
    theta = np.deg2rad(deg)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([
        [c, s, 0.0],
        [-s, c, 0.0],
        [0.0, 0.0, 1.0],
    ])


def mkdirs(path):
    """Create ``path`` (including missing parents) if it does not exist yet.

    Nothing happens when ``path`` already exists.  The directory is created
    first and "already exists" is handled afterwards, so two processes that
    create the same directory concurrently no longer race between the
    existence check and the creation.

    Args:
        path: directory path to create.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present (possibly created by another process in the
        # meantime) -- same silent no-op as before.
        pass


def CUDA(var):
    """Move ``var`` to the GPU when CUDA is available, else return it as is.

    Args:
        var: any object exposing a ``.cuda()`` method (tensor or module).

    Returns:
        ``var.cuda()`` if a CUDA device is available, otherwise ``var``.
    """
    if not torch.cuda.is_available():
        return var
    return var.cuda()


def CPU(var):
    """Convert a tensor to a NumPy array on the host.

    The tensor is detached from the autograd graph, copied to CPU memory if
    needed, and converted with ``.numpy()``.

    Args:
        var: a ``torch.Tensor``.

    Returns:
        The tensor's data as a ``numpy.ndarray``.
    """
    detached = var.detach()
    host_tensor = detached.cpu()
    return host_tensor.numpy()


def save_ply(save_path, points, text=True):
    """Write a point cloud to a PLY file as a single ``vertex`` element.

    Args:
        save_path: destination file path.
        points: 2-D array indexed as ``points[i, j]``.  With exactly three
            columns only ``x, y, z`` are written; otherwise the first six
            columns are written as ``x, y, z, red, green, blue``.
        text: passed to ``PlyData``; ``True`` writes an ASCII file.
    """
    xyz_fields = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    rgb_fields = [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    if points.shape[1] == 3:
        # Geometry only, no colour information.
        records = [(row[0], row[1], row[2]) for row in points]
        vertex = np.array(records, dtype=xyz_fields)
    else:
        # Geometry plus colour: coordinates as float32, colours as uint8.
        records = [(row[0], row[1], row[2], row[3], row[4], row[5]) for row in points]
        vertex = np.array(records, dtype=xyz_fields + rgb_fields)

    element = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    ply = PlyData([element], text=text)
    ply.write(save_path)


def read_ply(filename):
    """Read the first element of a PLY file into a dense float array.

    Every property of the first element (presumably the vertex element, e.g.
    ``x, y, z, red, green, blue`` -- this depends on the file) becomes one
    column, in the order the properties are declared.

    Args:
        filename: path of the PLY file.

    Returns:
        ``numpy.ndarray`` of dtype ``float64`` with one row per record and one
        column per property.
    """
    plydata = PlyData.read(filename).elements[0].data
    data_pd = pd.DataFrame(plydata)
    # ``np.float`` was removed in NumPy 1.24; it was an alias of the builtin
    # float, i.e. float64.
    data_np = np.zeros(data_pd.shape, dtype=np.float64)
    # Take the field names from the array dtype so that an element with zero
    # records does not raise IndexError.
    property_names = plydata.dtype.names
    for i, name in enumerate(property_names):
        data_np[:, i] = data_pd[name]
    return data_np


def read_obj(filename_obj):
    """Load vertices and triangulated faces from a Wavefront OBJ file.

    Only ``v`` (vertex) and ``f`` (face) records are read; everything else is
    ignored.  For ``f`` records only the vertex index of each
    ``v/vt/vn`` triple is used.  Faces with more than three vertices are split
    into a triangle fan around their first vertex.

    Args:
        filename_obj: path of the OBJ file.

    Returns:
        Tuple ``(vertices, faces_idx)``:

        * ``vertices``  -- ``float32`` array with one row per ``v`` record
          (first three coordinates).
        * ``faces_idx`` -- ``int32`` array of shape ``(n_triangles, 3)`` with
          0-based vertex indices.

        When the file contains no vertices or no faces, the corresponding
        array is empty with shape ``(0, 3)``.
    """
    vertices = []
    faces_idx = []

    # Single pass over the file; the context manager closes the handle.
    with open(filename_obj) as obj_file:
        for line in obj_file:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == 'v':
                vertices.append([float(v) for v in tokens[1:4]])
            elif tokens[0] == 'f':
                vs = tokens[1:]
                # Fan triangulation: (v0, v1, v2), (v0, v2, v3), ...
                v0 = int(vs[0].split('/')[0])
                for i in range(len(vs) - 2):
                    v1 = int(vs[i + 1].split('/')[0])
                    v2 = int(vs[i + 2].split('/')[0])
                    faces_idx.append((v0, v1, v2))

    if vertices:
        vertices = np.vstack(vertices).astype('float32')
    else:
        vertices = np.zeros((0, 3), dtype='float32')

    if faces_idx:
        # OBJ indices start at 1; convert to 0-based.
        faces_idx = np.vstack(faces_idx).astype('int32') - 1
    else:
        # e.g. a pure point cloud stored as OBJ.
        faces_idx = np.zeros((0, 3), dtype='int32')
    return vertices, faces_idx


def save_obj(filename, vertices, faces):
    """Write a triangle mesh to a Wavefront OBJ file.

    The file starts with a comment header holding the file's base name and a
    single group ``mesh``, followed by one ``v`` line per vertex (four
    decimals) and one ``f`` line per face.

    Args:
        filename: destination path.
        vertices: 2-D array; the first three entries of each row are written.
        faces: 2-D array of 0-based vertex indices; the first three entries of
            each row are written, shifted to the 1-based OBJ convention.
    """
    assert vertices.ndim == 2
    assert faces.ndim == 2

    header = '# %s\n#\n\ng mesh\n\n' % os.path.basename(filename)
    with open(filename, 'w') as out:
        out.write(header)
        out.writelines(
            'v  %.4f %.4f %.4f\n' % (v[0], v[1], v[2]) for v in vertices
        )
        out.write('\n')
        # OBJ indices are 1-based.
        out.writelines(
            'f  %d %d %d\n' % (tri[0] + 1, tri[1] + 1, tri[2] + 1) for tri in faces
        )


def pc_to_rangemap(points, fov_horizon_range, lower_fov, upper_fov, width, height, max_range, label=None):
    """Project a point cloud onto a ``height`` x ``width`` range image.

    The image size depends on the sensor, e.g. VLP-16: 16x1800,
    OS1-64: 64x1024.

    Rows encode elevation: row ``0`` corresponds to ``upper_fov`` and row
    ``height - 1`` to ``lower_fov``, which is the layout ``rangemap_to_pc``
    expects.  Columns encode azimuth, computed as ``arctan2(x, y)`` and
    shifted so that the smallest azimuth present in ``points`` falls in
    column 0 (the column assignment therefore depends on the data).

    Args:
        points: array whose first three columns are ``x, y, z``.
        fov_horizon_range: horizontal field of view covered by the image, in
            degrees.
        lower_fov: elevation of the lowest row, in degrees.
        upper_fov: elevation of the highest row, in degrees.
        width: number of image columns.
        height: number of image rows.
        max_range: ranges are clipped to this value; it is also the fill value
            of pixels that receive no point.
        label: optional per-point labels with ``label.shape[0]`` equal to the
            number of points.

    Returns:
        ``range_image`` (``float32``, shape ``(height, width)``) or, when
        ``label`` is given, the tuple ``(range_image, label_image)`` where
        ``label_image`` is ``int32`` and ``-1`` marks empty pixels.  When
        several points fall into one pixel, the last one in ``points`` wins.
    """
    image_rows_full = height
    image_cols = width
    ang_res_y = (upper_fov - lower_fov) / float(image_rows_full - 1)  # vertical resolution

    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]
    if label is not None:
        assert label.shape[0] == points.shape[0]

    # Row index.  The beam index counts upward from lower_fov
    # (0 .. height - 1); flipping it with "height - 1 - index" puts upper_fov
    # in row 0.  (Flipping with "height - index" would push the lowest beam
    # out of the image and leave row 0 empty.)
    vertical_angle = np.arctan2(z, np.sqrt(x**2 + y**2)) * 180.0 / np.pi
    relative_vertical_angle = vertical_angle - lower_fov
    beam_index = np.round(relative_vertical_angle / ang_res_y).astype(np.int_)
    row_id = image_rows_full - 1 - beam_index

    # Column index: azimuth measured from the +y axis, shifted so that the
    # smallest azimuth in the cloud starts at column 0.
    horizontal_angle = np.arctan2(x, y)
    if horizontal_angle.size > 0:  # np.min raises on an empty cloud
        horizontal_angle -= np.min(horizontal_angle)
    col_id = (horizontal_angle * image_cols / np.deg2rad(fov_horizon_range)).astype(np.int_)

    # Clip ranges to max_range.
    this_range = np.minimum(np.sqrt(x**2 + y**2 + z**2), max_range)

    # NOTE: pixels without any point keep the value max_range.
    range_image = max_range * np.ones((image_rows_full, image_cols), dtype=np.float32)
    label_image = None
    if label is not None:
        label_image = -1 * np.ones((image_rows_full, image_cols), dtype=np.int32)

    # Drop points that project outside the image, then write the remaining
    # ones in their original order so that later points overwrite earlier
    # ones sharing a pixel.
    inside = (
        (row_id >= 0) & (row_id < image_rows_full)
        & (col_id >= 0) & (col_id < image_cols)
    )
    for i in np.flatnonzero(inside):
        range_image[row_id[i], col_id[i]] = this_range[i]
        if label is not None:
            label_image[row_id[i], col_id[i]] = label[i]

    if label is not None:
        return range_image, label_image
    return range_image


def rangemap_to_pc(range_map, lower_fov, upper_fov, left_fov, right_fov):
    """Back-project a range image into 3-D points.

    Row ``0`` of ``range_map`` is assigned the elevation ``upper_fov`` and the
    last row ``lower_fov``; columns are spread evenly from ``left_fov`` to
    ``right_fov``.  All angles are in degrees.

    Args:
        range_map: 2-D array of ranges, shape ``(rows, cols)``.
        lower_fov: elevation of the last row.
        upper_fov: elevation of the first row.
        left_fov: azimuth of the first column.
        right_fov: azimuth of the last column.

    Returns:
        Array of shape ``(rows * cols, 3, 1)`` holding ``x, y, z`` along the
        second axis.  Points are ordered column by column (all rows of column
        0 first).

        NOTE(review): the trailing singleton axis comes from stacking
        ``(N, 1)`` columns along ``axis=1``; an ``(N, 3)`` result may have
        been intended -- confirm with callers before changing.
    """
    n_rows, n_cols = range_map.shape[0], range_map.shape[1]

    # One azimuth per column, in radians.
    azimuth = np.linspace(np.deg2rad(left_fov), np.deg2rad(right_fov), n_cols)

    # One elevation per row: beam index counts down so row 0 is upper_fov.
    beam_index = np.arange(n_rows - 1, -1, -1)
    elevation_deg = beam_index * (upper_fov - lower_fov) / (n_rows - 1) + lower_fov
    elevation = np.deg2rad(elevation_deg)

    # Work on the transposed image: shape (cols, rows), so the per-row
    # elevation broadcasts along the last axis.
    ranges = range_map.T
    z = ranges * np.sin(elevation)
    planar = ranges * np.cos(elevation)  # distance projected onto the xy plane
    cos_az = np.cos(azimuth)[:, None]
    sin_az = np.sin(azimuth)[:, None]
    x = cos_az * planar
    y = sin_az * planar

    columns = [coord.reshape((-1, 1)) for coord in (x, y, z)]
    return np.stack(columns, axis=1)
