'''
@Author: 
@Email: 
@Date: 2019-11-18 22:13:22
LastEditTime: 2021-05-07 15:07:52
@Description: 
'''

import os
import numpy as np
import pandas as pd
import random
import matplotlib.colors as col
import matplotlib.cm as cm
import matplotlib as mpl

import torch
from plyfile import PlyData, PlyElement

import yaml
import torch
import torch.nn as nn
import torch.nn.init as init


def load_config(config_path="config.yml"):
    """Load and parse a YAML configuration file.

    Args:
        config_path: path to the YAML file (default ``"config.yml"``).

    Returns:
        Whatever ``yaml.load`` produces for the document (e.g. a dict for a
        top-level mapping).

    Raises:
        FileNotFoundError: if ``config_path`` is not an existing file. This
            subclasses ``Exception``, so callers catching ``Exception`` are
            unaffected.
    """
    if not os.path.isfile(config_path):
        raise FileNotFoundError("Configuration file is not found in the path: " + config_path)
    # `with` closes the handle; the previous version leaked it.
    # NOTE(review): FullLoader can build some Python objects; prefer
    # yaml.safe_load if config files may ever come from untrusted sources.
    with open(config_path) as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def regist_colormap():
    """Register the custom four-color colormap ``'own_cm'`` with matplotlib.

    The map runs #B22222 -> #FFA500 -> #00BFFF -> #9932CC. After this call it
    can be referenced by name (e.g. ``cmap='own_cm'``). Calling it more than
    once does not raise; the existing entry is replaced.
    """
    startcolor = '#B22222'
    middle_1 = '#FFA500'
    middle_2 = '#00BFFF'
    middle_3 = '#9932CC'

    cmap2 = col.LinearSegmentedColormap.from_list('own_cm', [startcolor, middle_1, middle_2, middle_3])
    registry = getattr(mpl, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        # matplotlib >= 3.6: cm.register_cmap is deprecated (removed in 3.9) and
        # registering an existing name without force=True raises ValueError.
        registry.register(cmap2, force=True)
    else:
        # Older matplotlib without ColormapRegistry.register.
        cm.register_cmap(cmap=cmap2)
    

def setup_seed(seed):
    """Seed every random number generator used by this project.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU generator
    and all CUDA generators, then asks cuDNN to use deterministic algorithms.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


class COLOR:
    """Bold ANSI (SGR) escape sequences for coloring terminal output.

    Prefix a string with one of these and append ``COLOR.WHITE`` afterwards.
    ``WHITE`` is ``'\\033[1;0m'``: SGR 1 (bold) followed by SGR 0 (reset all),
    so in practice it restores the terminal's default style rather than
    selecting white text.
    """
    RED = '\033[1;31m'
    GREEN = '\033[1;32m'
    YELLOW = '\033[1;33m'
    BLUE = '\033[1;34m'
    PURPLE = '\033[1;35m'
    WHITE = '\033[1;0m'


def weights_init(m):
    """Initialize a module's parameters based on its class name (for ``model.apply``).

    Class-name substring rules:
      * ``'Linear'``: weight ~ N(0, 0.02); bias untouched.
      * ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.
      * ``'Conv2d'`` / ``'ConvTranspose2d'``: weight ~ N(0, 0.0099999); bias untouched.

    Modules whose name matches but whose ``weight``/``bias`` is missing or
    ``None`` are skipped instead of raising. Examples are
    ``BatchNorm(affine=False)`` and a container class whose name merely
    contains ``'Linear'``.

    Args:
        m: the ``nn.Module`` to initialize in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Linear' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif 'Conv2d' in classname or 'ConvTranspose2d' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.0099999)


def kaiming_init(m):
    """Initialize a module's parameters by type (for ``model.apply``).

    * ``nn.Linear``: Xavier-normal weight (despite the function name).
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Kaiming-normal weight.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight = 1.
    In each of these cases the bias, when present, is set to 0. Other module
    types are left unchanged.

    Args:
        m: the ``nn.Module`` to initialize in place.
    """
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.kaiming_normal_(m.weight)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # BatchNorm built with affine=False has weight = bias = None; the
        # previous version crashed on `m.weight.data` there.
        if m.weight is not None:
            m.weight.data.fill_(1)
    else:
        return
    if m.bias is not None:
        m.bias.data.fill_(0)


def rotate_X(deg):
    """Return a 3x3 rotation matrix about the X axis for an angle in degrees.

    The matrix is the transpose of the conventional right-handed Rx. For row
    vectors, ``points @ R`` rotates by +deg. For a column vector, ``R @ v``
    rotates by -deg.
    """
    theta = np.deg2rad(deg)
    c = np.cos(theta)
    s = np.sin(theta)
    return np.array([
        [1.0, 0.0, 0.0],
        [0.0, c, s],
        [0.0, -s, c],
    ])


def rotate_Y(deg):
    """Return a 3x3 rotation matrix about the Y axis for an angle in degrees.

    The matrix is the transpose of the conventional right-handed Ry. For row
    vectors, ``points @ R`` rotates by +deg. For a column vector, ``R @ v``
    rotates by -deg.
    """
    theta = np.deg2rad(deg)
    c = np.cos(theta)
    s = np.sin(theta)
    return np.array([
        [c, 0.0, -s],
        [0.0, 1.0, 0.0],
        [s, 0.0, c],
    ])


def rotate_Z(deg):
    """Return a 3x3 rotation matrix about the Z axis for an angle in degrees.

    The matrix is the transpose of the conventional right-handed Rz. For row
    vectors, ``points @ R`` rotates by +deg. For a column vector, ``R @ v``
    rotates by -deg.
    """
    theta = np.deg2rad(deg)
    c = np.cos(theta)
    s = np.sin(theta)
    return np.array([
        [c, s, 0.0],
        [-s, c, 0.0],
        [0.0, 0.0, 1.0],
    ])


def mkdirs(path):
    """Create directory *path*, including missing parents, if it does not exist.

    An already-existing path, whether a directory or not, is left untouched.

    Args:
        path: directory path to create.
    """
    if os.path.exists(path):
        return
    # exist_ok=True closes the check-then-create race: if another process
    # creates the directory between the check above and this call, the
    # previous version raised FileExistsError.
    os.makedirs(path, exist_ok=True)


def CUDA(var):
    """Move *var* to the current CUDA device when CUDA is available.

    Args:
        var: any object with a ``.cuda()`` method (tensor or module).

    Returns:
        ``var.cuda()`` if CUDA is available, otherwise *var* itself.
    """
    if not torch.cuda.is_available():
        return var
    return var.cuda()


def CPU(var):
    """Convert a tensor to a NumPy array on the host.

    The tensor is detached from the autograd graph and copied to CPU memory
    if needed, then converted.

    Args:
        var: a ``torch.Tensor`` on any device.

    Returns:
        ``numpy.ndarray`` with the tensor's values.
    """
    host_tensor = var.detach().to('cpu')
    return host_tensor.numpy()


def save_ply_with_gradient(grad_x_path, grad_y_path, points):
    """Write two colored copies of a point cloud, colored by gradient magnitude.

    Points are colored on a red (zero) to green (largest magnitude) scale.
    Each file is normalized independently.

    Args:
        grad_x_path: output PLY path, colored by ``|points.grad[:, 1]|``.
        grad_y_path: output PLY path, colored by ``|points.grad[:, 2]|``.
        points: tensor of shape [N, 3] whose ``.grad`` has been populated.

    Raises:
        ValueError: if ``points.grad`` is None (no backward pass has run).
    """
    # NOTE(review): the "x"/"y" outputs use gradient columns 1 and 2, not 0 and
    # 1; presumably intentional for this data's axis convention -- confirm.
    if points.grad is None:
        raise ValueError("points.grad is None; run backward() before saving gradients")

    # Use the colormap object directly instead of registering it globally by
    # name. Re-registering 'gradient' raises ValueError on the second call with
    # matplotlib >= 3.6, and cm.register_cmap no longer exists in >= 3.9.
    cmap = col.LinearSegmentedColormap.from_list('gradient', ['#FF0000', '#00FF00'])

    def _to_rgb255(values):
        # Map values in [0, max(values)] onto the colormap; keep RGB, drop alpha.
        norm = mpl.colors.Normalize(vmin=0, vmax=np.max(values))
        return cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(values)[:, 0:3] * 255

    grad = np.abs(CPU(points.grad))  # [N, 3]
    xyz = CPU(points)
    xyz_rgb_x = np.concatenate([xyz, _to_rgb255(grad[:, 1])], axis=1)
    xyz_rgb_y = np.concatenate([xyz, _to_rgb255(grad[:, 2])], axis=1)

    save_ply(grad_x_path, xyz_rgb_x)
    save_ply(grad_y_path, xyz_rgb_y)


def save_ply(save_path, points, text=True):
    """Write a point cloud to a PLY file as a single ``vertex`` element.

    Args:
        save_path: output file path.
        points: array-like of shape (N, 3) for XYZ only, or (N, >=6) for
            XYZ + RGB. Columns 0-2 are stored as float32 x/y/z and columns
            3-5 as uint8 red/green/blue; any further columns are ignored.
        text: write ASCII PLY when True, binary PLY otherwise.
    """
    points = np.asarray(points)
    fields = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    if points.shape[1] != 3:
        # With color information.
        fields += [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
    # Fill the structured array column by column; this avoids building one
    # Python tuple per point, which was slow for large clouds.
    vertex = np.empty(points.shape[0], dtype=fields)
    for column, (name, _) in enumerate(fields):
        vertex[name] = points[:, column]
    el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    PlyData([el], text=text).write(save_path)


def read_ply(filename):
    """Read the first element of a PLY file (e.g. XYZRGB vertices) into an array.

    Args:
        filename: path to the PLY file (str or path-like).

    Returns:
        float64 array of shape (N, P): one row per record of the file's first
        element, one column per property in file order. For files named
        ``background_2.ply`` / ``background_3.ply``, column 2 is shifted down by
        0.3 / 1.1 (per-file background adjustment).
    """
    # NOTE(review): assumes the vertex element is the first element in the file.
    plydata = PlyData.read(filename).elements[0].data
    # Read names from the array dtype, not plydata[0], so an element with
    # zero rows does not raise IndexError.
    property_names = plydata.dtype.names
    # np.float was removed in NumPy 1.24; np.float64 is what it aliased.
    data_np = np.zeros((len(plydata), len(property_names)), dtype=np.float64)
    for i, name in enumerate(property_names):
        data_np[:, i] = plydata[name]

    # Adjust different background files; background_1.ply is used as-is.
    # os.path.basename also handles OS-specific separators and path-likes.
    z_offsets = {'background_2.ply': 0.3, 'background_3.ply': 1.1}
    offset = z_offsets.get(os.path.basename(filename))
    if offset is not None:
        data_np[:, 2] = data_np[:, 2] - offset
    return data_np


def read_obj(filename_obj):
    """Load vertex positions and triangulated faces from a Wavefront OBJ file.

    Only ``v`` and ``f`` records are used; ``vn``/``vt`` and all other records
    are ignored. Polygons are fan-triangulated around their first vertex.
    Face references may be positive (1-based) or negative (relative to the
    vertices read so far, as allowed by the OBJ format).

    Args:
        filename_obj: path to the ``.obj`` file.

    Returns:
        (vertices, faces_idx):
            ``vertices``: float32 array of shape (V, 3).
            ``faces_idx``: int32 array of shape (F, 3) of 0-based vertex indices.
        Either array is an empty (0, 3) array when the file has no such records.
    """
    vertices = []
    faces_idx = []
    # Single pass with a context manager; the previous version opened and read
    # the file twice and never closed either handle.
    with open(filename_obj) as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == 'v':
                vertices.append([float(v) for v in tokens[1:4]])
            elif tokens[0] == 'f':
                # Tokens look like "v", "v/vt", "v//vn" or "v/vt/vn"; keep "v".
                refs = [int(tok.split('/')[0]) for tok in tokens[1:]]
                # Negative references count back from the latest vertex (-1 == last).
                refs = [ref if ref >= 0 else len(vertices) + ref + 1 for ref in refs]
                faces_idx.extend((refs[0], v1, v2) for v1, v2 in zip(refs[1:-1], refs[2:]))

    if vertices:
        vertices = np.vstack(vertices).astype('float32')
    else:
        vertices = np.zeros((0, 3), dtype=np.float32)

    # obj format save the index of vertices from 1
    if faces_idx:
        faces_idx = np.vstack(faces_idx).astype('int32') - 1
    else:
        faces_idx = np.zeros((0, 3), dtype=np.int32)
    return vertices, faces_idx


def save_obj(filename, vertices, faces):
    """Write a triangle mesh to *filename* in Wavefront OBJ text format.

    The file starts with a short comment header naming the file and a single
    ``g mesh`` group line.

    Args:
        filename: output path (overwritten).
        vertices: 2-D array; the first three entries of each row are written as
            a ``v`` record with 4 decimal places.
        faces: 2-D array of 0-based vertex indices; the first three entries of
            each row are written as a 1-based ``f`` record.
    """
    assert vertices.ndim == 2
    assert faces.ndim == 2
    header = '# %s\n#\n\ng mesh\n\n' % os.path.basename(filename)
    with open(filename, 'w') as out:
        out.write(header)
        out.writelines('v  %.4f %.4f %.4f\n' % (row[0], row[1], row[2]) for row in vertices)
        out.write('\n')
        out.writelines('f  %d %d %d\n' % (tri[0] + 1, tri[1] + 1, tri[2] + 1) for tri in faces)


def pc_to_rangemap(points, fov_horizon_range, lower_fov, upper_fov, width, height, max_range):
    """Project a colored point cloud into a range image and a label image.

    Args:
        points: array of shape (N, >=6). Columns 0-2 are x, y, z; columns 3-5
            are r, g, b. Points colored exactly (0, 0, 142) get label 1; all
            others get label 0.
        fov_horizon_range: horizontal field of view in degrees, spread over
            ``width`` columns. Azimuth is measured from the smallest azimuth
            present in ``points``, so the column offset depends on the data.
        lower_fov, upper_fov: vertical field-of-view limits in degrees.
        width, height: range image size (depends on the sensor, e.g. VLP-16:
            16x1800, OS1-64: 64x1024 as rows x cols).
        max_range: ranges are clipped to this value, and empty pixels hold it.

    Returns:
        (range_image, label_image): float32 and int32 arrays of shape
        (height, width).
        - Row 0 corresponds to ``upper_fov`` and row ``height - 1`` to ``lower_fov``.
        - Empty label pixels are -1.
        - Points falling outside the image are dropped.
        - When several points land on the same pixel, the last one in ``points`` wins.
    """
    image_rows_full = height
    image_cols = width
    ang_res_y = (upper_fov - lower_fov) / float(image_rows_full - 1)  # degrees per row

    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]
    r = points[:, 3]
    g = points[:, 4]
    b = points[:, 5]
    # 1-D labels so each pixel receives a scalar. An (N, 1) array made every
    # assignment an array-to-scalar conversion, deprecated in NumPy 1.25.
    label = np.zeros(points.shape[0], dtype=np.int32)
    label[(r == 0) & (g == 0) & (b == 142)] = 1

    # Row id: upper_fov -> row 0, lower_fov -> row height-1 (the mapping that
    # rangemap_to_pc inverts). The previous `height - round(...)` was off by
    # one: row 0 stayed empty and points at lower_fov were dropped.
    vertical_angle = np.arctan2(z, np.sqrt(x**2 + y**2)) * 180.0 / np.pi
    relative_vertical_angle = vertical_angle - lower_fov
    # np.round_ was removed in NumPy 2.0; np.round is the same function.
    row_id = (image_rows_full - 1) - np.int_(np.round(relative_vertical_angle / ang_res_y))

    # Column id (np.min would raise on an empty cloud).
    horizontal_angle = np.arctan2(x, y)
    if horizontal_angle.size:
        horizontal_angle -= np.min(horizontal_angle)
    col_id = np.int_(horizontal_angle * image_cols / np.deg2rad(fov_horizon_range))

    this_range = np.minimum(np.sqrt(x**2 + y**2 + z**2), max_range)

    # NOTE: the initial range map should be max_range
    range_image = np.full((image_rows_full, image_cols), max_range, dtype=np.float32)
    label_image = np.full((image_rows_full, image_cols), -1, dtype=np.int32)

    inside = (row_id >= 0) & (row_id < image_rows_full) & (col_id >= 0) & (col_id < image_cols)
    rows, cols = row_id[inside], col_id[inside]
    ranges, labels = this_range[inside], label[inside]
    # NumPy does not define which value wins when a fancy-index assignment
    # repeats an index. Keep only the last point per pixel explicitly, which is
    # what the former sequential per-point loop did.
    pixel = rows * image_cols + cols
    _, first_from_end = np.unique(pixel[::-1], return_index=True)
    keep = pixel.size - 1 - first_from_end
    range_image[rows[keep], cols[keep]] = ranges[keep]
    label_image[rows[keep], cols[keep]] = labels[keep]

    return range_image, label_image


def rangemap_to_pc(range_map, lower_fov, upper_fov, left_fov, right_fov):
    """Back-project a range image into 3-D points.

    Rows are spaced evenly in elevation from ``upper_fov`` (row 0) down to
    ``lower_fov`` (last row). Columns are spaced evenly in azimuth from
    ``left_fov`` to ``right_fov``. All angles are in degrees. A range ``d`` at
    elevation ``e`` and azimuth ``a`` becomes
    ``x = cos(a) * d * cos(e)``, ``y = sin(a) * d * cos(e)``, ``z = d * sin(e)``.

    Args:
        range_map: 2-D array of ranges, shape (H, W).
        lower_fov, upper_fov: vertical field-of-view limits in degrees.
        left_fov, right_fov: horizontal angles of the first and last column.

    Returns:
        Array of shape (H * W, 3, 1), ordered column by column (all rows of
        column 0 first).
        NOTE(review): the trailing singleton axis comes from stacking (N, 1)
        arrays along axis 1. Callers presumably squeeze it; confirm before
        changing the shape to (N, 3).
    """
    n_rows = range_map.shape[0]
    n_cols = range_map.shape[1]
    azimuth = np.linspace(np.deg2rad(left_fov), np.deg2rad(right_fov), n_cols)
    steps_above_bottom = np.arange(n_rows - 1, -1, -1)
    elevation = np.deg2rad(steps_above_bottom * (upper_fov - lower_fov) / (n_rows - 1) + lower_fov)

    per_column = range_map.T  # (W, H): one row of ranges per azimuth
    heights = per_column * np.sin(elevation)
    planar = per_column * np.cos(elevation)

    xs = (np.cos(azimuth)[:, None] * planar).reshape((-1, 1))
    ys = (np.sin(azimuth)[:, None] * planar).reshape((-1, 1))
    zs = heights.reshape((-1, 1))
    return np.stack([xs, ys, zs], axis=1)
