'''
@Author: 
@Email: 
@Date: 2019-11-18 22:13:22
LastEditTime: 2021-05-30 18:40:04
@Description: 
'''

import os
import numpy as np
import random
import matplotlib.colors as col
import matplotlib.cm as cm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.init as init


def get_parameter_number(model):
    """Count the parameters of a model.

    Args:
        model: object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        A ``(total_num, trainable_num)`` tuple: the number of scalar elements
        over all parameters, and over the parameters whose ``requires_grad``
        flag is set.
    """
    total_num = 0
    trainable_num = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_num += n_elements
        if param.requires_grad:
            trainable_num += n_elements
    return total_num, trainable_num


def regist_colormap():
    """Register the custom ``'own_cm'`` colormap with Matplotlib.

    The colormap interpolates linearly through four anchor colors
    (firebrick red -> orange -> sky blue -> purple). After this call it can be
    requested by name, e.g. ``cmap='own_cm'``.

    ``matplotlib.cm.register_cmap`` was deprecated in Matplotlib 3.7 and
    removed in 3.9, so the ``matplotlib.colormaps`` registry is preferred
    whenever it exists; the legacy call is only used on old releases.
    """
    # Local import: the file only imports Matplotlib submodules by alias.
    import matplotlib

    anchor_colors = [
        '#B22222',  # start
        '#FFA500',  # middle 1
        '#00BFFF',  # middle 2
        '#9932CC',  # end
    ]
    cmap2 = col.LinearSegmentedColormap.from_list('own_cm', anchor_colors)

    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None:
        # force=True keeps repeated calls harmless instead of raising
        # ValueError because 'own_cm' is already registered.
        registry.register(cmap2, force=True)
    else:
        cm.register_cmap(cmap=cmap2)


def setup_seed(seed):
    """Seed every random number generator used by this project.

    Seeds Python's ``random``, NumPy's global generator, and PyTorch's CPU and
    CUDA generators with the same value, then asks cuDNN to use deterministic
    algorithms.

    Args:
        seed: integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


class COLOR:
    """ANSI escape sequences for coloring terminal output.

    Each constant is a string meant to be prepended to printed text. The
    colored entries select the bold attribute (``1``) plus a foreground color.
    """

    # Foreground colors, in ascending ANSI color-code order (31..35).
    RED = '\033[1;31m'
    GREEN = '\033[1;32m'
    YELLOW = '\033[1;33m'
    BLUE = '\033[1;34m'
    PURPLE = '\033[1;35m'

    # Ends with the reset code (0), so this restores the terminal's default
    # style rather than selecting a literal white foreground.
    WHITE = '\033[1;0m'


def kaiming_init(m):
    """Initialize the parameters of a single layer in place.

    Intended to be used as ``model.apply(kaiming_init)``.

    * ``nn.Linear``: Xavier-normal weights (despite the function name).
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Kaiming-normal weights.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weights set to 1.
    * In all of the above, the bias (when present) is set to 0.

    Any other module type is left untouched.

    Args:
        m: the module to initialize.
    """
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.kaiming_normal_(m.weight)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # With affine=False the layer has neither weight nor bias (both are
        # None); skip instead of failing on ``None.data``.
        if m.weight is not None:
            init.ones_(m.weight)
    else:
        return

    # It is important to start every handled layer with a zero bias.
    if m.bias is not None:
        init.zeros_(m.bias)


def mkdirs(path):
    """Create ``path`` (including missing parents) if it does not exist yet.

    Nothing happens when ``path`` already exists. The directory is created
    directly and the "already exists" error is ignored, instead of checking
    first; this avoids a failure when another process creates the same
    directory between the check and the creation.

    Args:
        path: directory path to create.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present (possibly created concurrently): nothing to do.
        pass


def CUDA(var):
    """Move ``var`` to the GPU when CUDA is available.

    Args:
        var: object exposing ``.cuda()`` (e.g. a tensor or a module).

    Returns:
        ``var.cuda()`` if ``torch.cuda.is_available()``, otherwise ``var``
        itself, unchanged.
    """
    if not torch.cuda.is_available():
        return var
    return var.cuda()


def CPU(var):
    """Convert a tensor to a NumPy array on the host.

    The tensor is detached from the autograd graph, moved to the CPU and then
    converted with ``.numpy()``.

    Args:
        var: tensor to convert.

    Returns:
        The NumPy array holding the tensor's values.
    """
    detached = var.detach()
    on_host = detached.cpu()
    return on_host.numpy()


def MSE_Loss(x, x_recon):
    """Summed squared error between ``x_recon`` and ``x``, averaged over dim 0.

    Args:
        x: target tensor.
        x_recon: reconstructed tensor compared against ``x``.

    Returns:
        Scalar tensor: the sum of squared differences divided by ``x.size(0)``.
    """
    first_dim = x.size(0)
    summed_sq_error = torch.nn.functional.mse_loss(x_recon, x, reduction='sum')
    return summed_sq_error / first_dim


def BCE_Loss(prediction, label):
    """Summed binary cross-entropy, averaged over dim 0 of ``prediction``.

    Args:
        prediction: predicted probabilities, passed directly to
            ``binary_cross_entropy`` (no sigmoid is applied here).
        label: target tensor with the same shape as ``prediction``.

    Returns:
        Scalar tensor: the summed loss divided by ``prediction.size(0)``.
    """
    first_dim = prediction.size(0)
    summed_bce = torch.nn.functional.binary_cross_entropy(
        prediction, label, reduction='sum'
    )
    return summed_bce / first_dim


def CE_Loss(prediction, label):
    """Cross-entropy loss plus each element's normalized share of it.

    Args:
        prediction: unnormalized class scores, as expected by
            ``torch.nn.functional.cross_entropy``.
        label: target class indices.

    Returns:
        A tuple ``(mean_loss, normalized)``:

        * ``mean_loss``: scalar tensor, the mean of the unreduced losses.
        * ``normalized``: the unreduced losses divided by their sum, with a
          new axis inserted at position 1.
    """
    per_element = torch.nn.functional.cross_entropy(prediction, label, reduction='none')
    mean_loss = per_element.mean()
    # 10e-10 (== 1e-9) keeps the division finite when every loss is zero.
    denominator = per_element.sum() + 10e-10
    normalized = per_element / denominator
    return mean_loss, normalized.unsqueeze(1)


def KLD_Loss(mu, logvar):
    """KL-divergence term ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.

    The sum runs over all elements and the result is divided by
    ``mu.size(0)``.

    Args:
        mu: mean tensor.
        logvar: log-variance tensor, broadcastable with ``mu``.

    Returns:
        Scalar tensor with the averaged divergence.
    """
    first_dim = mu.size(0)
    elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * torch.sum(elementwise) / first_dim


def vehicle_box(pose, length, width):
    """Compute the closed outline of a rotated rectangle.

    Args:
        pose: sequence ``(x, y, heading)``; ``heading`` is in radians
            (it is fed to ``np.cos`` / ``np.sin``).
        length: rectangle extent along its local x axis.
        width: rectangle extent along its local y axis.

    Returns:
        Array of shape ``(5, 2)``: the four corners in world coordinates with
        the first corner repeated at the end, so the outline can be drawn as
        one closed polyline.
    """
    center = pose[0:2]
    heading = pose[2]
    half_l = length / 2
    half_w = width / 2

    # Corners in the vehicle frame; the first is repeated to close the loop.
    local_corners = np.array([
        [-half_l, -half_w],
        [-half_l, half_w],
        [half_l, half_w],
        [half_l, -half_w],
        [-half_l, -half_w],
    ])
    cos_h = np.cos(heading)
    sin_h = np.sin(heading)
    rotation = np.array([[cos_h, -sin_h], [sin_h, cos_h]])

    # Rotate into the world frame, then translate to the vehicle position.
    return rotation.dot(local_corners.T).T + center


def plot_ego_vehicle(length, width):
    """Draw the ego vehicle as a green rectangle centered at the origin.

    The rectangle is axis-aligned and drawn on the current Matplotlib axes.

    Args:
        length: rectangle extent along the x axis.
        width: rectangle extent along the y axis.
    """
    half_l = length / 2
    half_w = width / 2
    # Closed outline: the first corner is repeated as the last point.
    outline = np.array([
        [-half_l, -half_w],
        [-half_l, half_w],
        [half_l, half_w],
        [half_l, -half_w],
        [-half_l, -half_w],
    ])
    xs = outline[:, 0]
    ys = outline[:, 1]
    plt.plot(xs, ys, 'g-')


def plot_box(points, color):
    """Draw a polyline through ``points`` on the current Matplotlib axes.

    Args:
        points: 2-D array whose column 0 holds x and column 1 holds y
            coordinates.
        color: Matplotlib single-letter color code (e.g. ``'r'``); it is
            combined with ``'-'`` to form a solid-line format string.
    """
    line_format = color + '-'
    xs = points[:, 0]
    ys = points[:, 1]
    plt.plot(xs, ys, line_format)


def save_sample(type, pos, grid_center, prefix=''):
    """Render each frame of a grid-based sample and save it as a PNG.

    For every frame, every object whose ``type`` value exceeds 0.5 is drawn as
    a red 3 x 1.5 box, together with the ego vehicle at the origin. Images are
    written to ``./samples/<prefix>_<frame index>.png``; the directory is
    created when missing.

    Args:
        type: array indexed as ``[frame, grid, object]`` holding a per-object
            score. (The name shadows the builtin but is kept so existing
            keyword callers keep working.)
        pos: array indexed as ``[frame, grid, object]`` yielding a pose
            ``(x, y, heading)``. It is NOT modified by this function.
            NOTE(review): assumed to be a NumPy array - confirm with callers.
        grid_center: per-grid offset, indexed by grid, subtracted from the
            ``(x, y)`` part of each pose.
        prefix: string prepended to every output file name.
    """
    lidar_range = 40
    threshold = 0.5
    output_dir = './samples/'
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    # for each frame
    for f_i in range(type.shape[0]):
        plt.figure(figsize=(6, 6))
        # for each grid
        for t_i in range(type.shape[1]):
            # for each object
            for o_i in range(type.shape[2]):
                if type[f_i, t_i, o_i] > threshold:  # vehicle
                    # Work on a copy: indexing ``pos`` returns a view, so
                    # shifting it in place would corrupt the caller's data
                    # (and shift it again on every subsequent call).
                    object_pos = np.array(pos[f_i, t_i, o_i])
                    object_pos[0:2] = object_pos[0:2] - grid_center[t_i]
                    plot_box(vehicle_box(object_pos, 3, 1.5), 'r')

        plot_ego_vehicle(3, 1.5)
        plt.xlabel('x (m)')
        plt.ylabel('y (m)')
        plt.title('Sample '+str(f_i))
        plt.xlim([-lidar_range, lidar_range])
        plt.ylim([-lidar_range, lidar_range])
        plt.tight_layout()
        plt.grid()
        plt.savefig(output_dir+prefix+'_'+str(f_i)+'.png')
        plt.close('all')


def save_tree_sample(frame_list, position_scale, prefix=''):
    """Render each frame of a list-based sample and save it as a PNG.

    Every entry of a frame is a normalized pose ``(x, y, heading)``. It is
    scaled by ``(position_scale, position_scale, pi)`` and drawn as a red
    3 x 1.5 box, together with the ego vehicle at the origin. Images are
    written to ``./samples/<prefix>_<frame index>.png``; the directory is
    created when missing.

    Args:
        frame_list: sequence of frames, each an iterable of normalized poses.
        position_scale: factor applied to the x and y components of each pose.
        prefix: string prepended to every output file name.
    """
    pose_scale = np.array([position_scale, position_scale, np.pi])
    lidar_range = 40
    output_dir = './samples/'
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    # for each frame
    for f_i, frame in enumerate(frame_list):
        plt.figure(figsize=(6, 6))
        # for each object in the frame
        for normalized_pose in frame:
            # Multiplication builds a new array; the input is left untouched.
            scaled_pose = normalized_pose * pose_scale
            plot_box(vehicle_box(scaled_pose, 3, 1.5), 'r')

        plot_ego_vehicle(3, 1.5)
        plt.xlabel('x (m)')
        plt.ylabel('y (m)')
        plt.title('Sample '+str(f_i))
        plt.xlim([-lidar_range, lidar_range])
        plt.ylim([-lidar_range, lidar_range])
        plt.tight_layout()
        plt.grid()
        plt.savefig(output_dir+prefix+'_'+str(f_i)+'.png')
        plt.close('all')
