from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import random
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd import Variable

import cv2
import h5py
import pickle

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt



def visualize_point_cloud(pc0, pc1=None, store_path=None, n_particles_to_show=1000):
    """Show one or two point clouds in a 3D scatter plot, optionally saving a rotating video.

    Up to ``n_particles_to_show`` points are sampled at random (without
    replacement) from each cloud; clouds with fewer points are shown in full.
    Column 0 is plotted on the x axis, column 2 on the y axis and column 1 on
    the vertical z axis.

    Args:
        pc0: 2D array-like with at least 3 columns; plotted in blue.
        pc1: optional second cloud with the same column layout; plotted in red.
        store_path: if given, frames rendered every 3 degrees of azimuth are
            saved as PNGs into the directory ``store_path[:-4]`` (store_path with
            its last four characters -- presumably the file extension -- removed)
            and then encoded into an MJPG video at ``store_path`` (20 fps,
            640x480). The PNG frames are left on disk.
        n_particles_to_show: maximum number of points drawn from each cloud.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Some matplotlib versions reject 'equal' for 3D axes; the invisible
        # bounding cube drawn below yields an equal aspect ratio anyway.
        pass

    pc0 = _subsample_rows(pc0, n_particles_to_show)
    X, Y, Z = pc0[:, 0], pc0[:, 2], pc0[:, 1]
    ax.scatter(X, Y, Z, c='b', s=20)

    if pc1 is not None:
        pc1 = _subsample_rows(pc1, n_particles_to_show)
        X1, Y1, Z1 = pc1[:, 0], pc1[:, 2], pc1[:, 1]
        ax.scatter(X1, Y1, Z1, c='r', s=20)

        # The bounding cube must enclose both clouds.
        X = np.concatenate([X, X1])
        Y = np.concatenate([Y, Y1])
        Z = np.concatenate([Z, Z1])

    _draw_bounding_cube(ax, X, Y, Z)

    if store_path is not None:
        _save_turntable_video(fig, ax, store_path)

    plt.show()


def _subsample_rows(pc, n):
    """Return at most ``n`` rows of ``pc`` chosen uniformly at random without replacement."""
    # choice(..., replace=False) raises if asked for more rows than exist.
    n = min(n, pc.shape[0])
    idx = np.random.choice(pc.shape[0], n, replace=False)
    return pc[idx]


def _draw_bounding_cube(ax, X, Y, Z):
    """Plot the 8 corners of the data's bounding cube in white to force an equal aspect ratio."""
    max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max()
    # corners[k] holds the k-th coordinate (-1 or +1) of the 8 corners of a cube.
    corners = np.mgrid[-1:2:2, -1:2:2, -1:2:2].reshape(3, -1)
    Xb = 0.5 * max_range * corners[0] + 0.5 * (X.max() + X.min())
    Yb = 0.5 * max_range * corners[1] + 0.5 * (Y.max() + Y.min())
    Zb = 0.5 * max_range * corners[2] + 0.5 * (Z.max() + Z.min())
    for xb, yb, zb in zip(Xb, Yb, Zb):
        ax.plot([xb], [yb], [zb], 'w')


def _save_turntable_video(fig, ax, store_path):
    """Render ``fig`` at azimuths 0, 3, ..., 357 and encode the frames as an MJPG video."""
    frame_size = (640, 480)  # (width, height) the video writer is opened with
    frame_dir = store_path[:-4]
    if not os.path.isdir(frame_dir):
        # os.makedirs instead of shelling out to `mkdir -p`: portable, and a path
        # containing spaces or shell metacharacters cannot break or inject a command.
        os.makedirs(frame_dir)

    frame_paths = []
    for azim in range(0, 360, 3):
        ax.view_init(elev=30., azim=azim)
        frame_path = '%s/img_%d.png' % (frame_dir, azim)
        fig.savefig(frame_path)
        frame_paths.append(frame_path)

    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(store_path, fourcc, 20, frame_size)
    try:
        for frame_path in frame_paths:
            img = cv2.imread(frame_path)
            if (img.shape[1], img.shape[0]) != frame_size:
                # A non-default figure size/dpi renders at a different resolution;
                # the writer may silently drop frames whose size does not match.
                img = cv2.resize(img, frame_size)
            out.write(img)
    finally:
        out.release()



def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``.

    Also puts cuDNN into deterministic, non-benchmarking mode so runs are
    reproducible.
    """
    cudnn = torch.backends.cudnn
    # Prefer reproducible kernels over cuDNN's autotuned (faster) choice.
    cudnn.deterministic = True
    cudnn.benchmark = False

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group, or None if there are none."""
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']



def store_data_pickle(data_names, data, path):
    """Pickle a dict mapping each name in ``data_names`` to the item at the same index of ``data``.

    Args:
        data_names: sequence of keys.
        data: indexable sequence of values; items beyond ``len(data_names)`` are ignored.
        path: destination file, overwritten in binary mode.

    Raises:
        IndexError: if ``data`` is shorter than ``data_names``; no file is written then.
    """
    d = {name: data[i] for i, name in enumerate(data_names)}
    # A context manager guarantees the file is flushed and closed, even if
    # pickling fails part-way (the original leaked the handle).
    with open(path, 'wb') as f:
        pickle.dump(d, f)



class AverageMeter(object):
    """Keep the latest value and a running, count-weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weight it as ``n`` samples, and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def rand_int(lo, hi):
    """Draw an integer uniformly from ``[lo, hi)`` using NumPy's global RNG."""
    return np.random.randint(low=lo, high=hi)


def rand_float(lo, hi):
    """Draw a float uniformly between ``lo`` (inclusive) and ``hi`` using NumPy's global RNG."""
    span = hi - lo
    return lo + span * np.random.rand()


def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters."""
    trainable = (p for p in model.parameters() if p.requires_grad)
    total = 0
    for param in trainable:
        total += param.numel()
    return total


class ChamferLoss(torch.nn.Module):
    """Symmetric Chamfer distance between two unbatched point sets.

    For ``x`` of shape [N, D] and ``y`` of shape [M, D], the loss is the mean
    Euclidean distance from each point of ``x`` to its nearest point in ``y``,
    plus the mean distance from each point of ``y`` to its nearest point in ``x``.
    """

    def __init__(self):
        super(ChamferLoss, self).__init__()

    def chamfer_distance(self, x, y):
        """Return the symmetric Chamfer distance between ``x`` ([N, D]) and ``y`` ([M, D])."""
        # Broadcasting yields all pairwise differences [N, M, D] directly,
        # instead of materialising two repeated copies of the inputs first.
        diff = x.unsqueeze(1) - y.unsqueeze(0)
        dis = torch.norm(diff, 2, dim=2)                # dis: [N, M]
        dis_xy = torch.mean(torch.min(dis, dim=1)[0])   # x -> nearest y, mean over N
        dis_yx = torch.mean(torch.min(dis, dim=0)[0])   # y -> nearest x, mean over M
        return dis_xy + dis_yx

    def forward(self, pred, label):
        """Compute the loss; invoked as ``loss_module(pred, label)``."""
        # Implementing forward (rather than overriding __call__) keeps
        # nn.Module's hook machinery working.
        return self.chamfer_distance(pred, label)

