from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from torchvision import transforms

import random
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd import Variable

# new is visual dynamics
from chainer import cuda

def my_collate(batch):
    """Collate a list of samples into batched tensors, zero-padding relations.

    Each sample is a sequence whose last two entries are relation tensors
    indexed as ``seq_length x n_rel x N``; every earlier entry is either a
    Python int or a tensor that can be stacked across the batch.

    :param batch: list of samples
    :return: tuple with one batched entry per sample field. Int fields become a
        ``LongTensor``, other leading fields are stacked into a ``FloatTensor``
        and the two relation fields are zero-padded along dim 1 to the largest
        ``n_rel`` found in the batch before being stacked.
    """
    n_fields = len(batch[0])
    n_rel_fields = 2

    collated = []
    for field in range(n_fields - n_rel_fields):
        column = [sample[field] for sample in batch]
        if isinstance(column[0], int):
            collated.append(torch.LongTensor(column))
        else:
            collated.append(torch.FloatTensor(torch.stack(column)))

    # Relations: B x seq_length x n_rel x (n_p + n_s)
    for offset in range(n_rel_fields):
        rels = [sample[offset - n_rel_fields] for sample in batch]
        seq_length, _, N = rels[0].size()
        widest = 0
        for rel in rels:
            widest = max(widest, rel.size(1))
        padded = [
            torch.cat([rel, torch.zeros(seq_length, widest - rel.size(1), N)], 1)
            for rel in rels
        ]
        collated.append(torch.FloatTensor(torch.stack(padded)))

    return tuple(collated)


def my_collate_extra(batch):
    """Collate samples whose particle counts differ, zero-padding every field.

    Sample layout: (attr, particles, n_particle, n_shape, scene_params, Rr, Rs)
    attr       (n_p + n_s) x attr_dim
    particles  seq_length x (n_p + n_s) x state_dim

    :param batch: list of samples with the layout above
    :return: tuple of batched tensors in the same field order. ``attr`` and
        ``particles`` are padded to the largest ``n_p + n_s`` in the batch, the
        middle fields are stacked as-is (ints become a ``LongTensor``) and the
        two trailing relation tensors are padded on both the relation axis and
        the last axis.
    """
    n_fields = len(batch[0])
    n_rel_fields = 2

    out = []

    # attr : (n_p + n_s) x attr_dim, padded along dim 0
    attrs = [sample[0] for sample in batch]
    max_n = max([sample[0].shape[0] for sample in batch])  # largest n_p + n_s
    attr_dim = attrs[0].shape[1]
    padded_attrs = [
        torch.cat([a, torch.zeros(max_n - a.shape[0], attr_dim)], 0)
        for a in attrs
    ]
    out.append(torch.FloatTensor(torch.stack(padded_attrs)))

    # particles : seq_length x (n_p + n_s) x state_dim, padded along dim 1
    states = [sample[1] for sample in batch]
    seq_length, _, state_dim = states[0].shape
    padded_states = [
        torch.cat([s, torch.zeros(seq_length, max_n - s.shape[1], state_dim)], 1)
        for s in states
    ]
    out.append(torch.FloatTensor(torch.stack(padded_states)))

    # remaining non-relation fields need no padding
    for field in range(2, n_fields - n_rel_fields):
        column = [sample[field] for sample in batch]
        if isinstance(column[0], int):
            out.append(torch.LongTensor(column))
        else:
            out.append(torch.FloatTensor(torch.stack(column)))

    # Relations: B x seq_length x n_rel x (n_p + n_s)
    for offset in range(n_rel_fields):
        rels = [sample[offset - n_rel_fields] for sample in batch]
        widest = 0
        for rel in rels:
            widest = max(widest, rel.size(1))

        padded_rels = []
        for rel in rels:
            steps, n_rel, width = rel.size()
            # pad the relation axis first, then the particle axis
            rel = torch.cat([rel, torch.zeros(steps, widest - n_rel, width)], 1)
            rel = torch.cat([rel, torch.zeros(steps, widest, max_n - width)], -1)
            padded_rels.append(rel)

        out.append(torch.FloatTensor(torch.stack(padded_rels)))

    return tuple(out)


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    :param seed: integer seed applied to every generator
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    :param model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``)
    :return: sum of ``numel()`` over parameters with ``requires_grad`` set
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    :param optimizer: object with a ``param_groups`` iterable of dicts
    :return: the ``'lr'`` entry of the first group, or ``None`` if there are
        no groups
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


class Tee(object):
    """Duplicate everything written to ``sys.stdout`` into a file.

    Creating an instance opens ``name`` with ``mode`` and installs the instance
    as ``sys.stdout``. ``close()`` (or garbage collection) restores the stream
    that was active at construction time and closes the file. Closing is
    idempotent: only the first call touches ``sys.stdout``, so a later
    redirection is not clobbered when the object is finalised after an
    explicit ``close()``.
    """

    def __init__(self, name, mode):
        # Marked closed until fully set up so __del__ is a no-op if open() fails.
        self._closed = True
        self.file = open(name, mode)
        self.stdout = sys.stdout
        sys.stdout = self
        self._closed = False

    def __del__(self):
        self.close()

    def write(self, data):
        """Write ``data`` to the log file and to the wrapped stream."""
        self.file.write(data)
        self.stdout.write(data)

    def flush(self):
        """Flush both the log file and the wrapped stream."""
        self.file.flush()
        self.stdout.flush()

    def close(self):
        """Restore the wrapped stream and close the file (safe to call twice)."""
        if getattr(self, '_closed', True):
            return
        self._closed = True
        sys.stdout = self.stdout
        self.file.close()

def get_image_to_tensor_balanced(image_size=0):
    """Build the image preprocessing pipeline.

    :param image_size: when positive, a ``transforms.Resize(image_size)`` step
        is placed before the tensor conversion; otherwise nothing is resized
    :return: ``transforms.Compose`` whose last step is ``transforms.ToTensor()``
    """
    pipeline = [transforms.Resize(image_size)] if image_size > 0 else []
    pipeline.append(transforms.ToTensor())
    return transforms.Compose(pipeline)

class AverageMeter(object):
    """Keep the latest value and the running weighted average of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, average, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        :param val: newly observed value
        :param n: weight (number of observations) carried by ``val``
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def check_gradient(step):
    """Create a gradient hook that prints ``step`` and a slice of the gradient.

    :param step: label printed in front of the gradient summary
    :return: callable taking a gradient tensor; it prints ``step`` followed by
        the first four entries of the gradient's mean over dim 1 and returns
        ``None``
    """
    def hook(grad):
        mean_over_dim1 = torch.mean(grad, 1)
        print(step, mean_over_dim1[:4])

    return hook


def add_log(fn, content, is_append=True):
    """Write ``content`` to the file ``fn``.

    :param fn: path of the log file
    :param content: text to write
    :param is_append: append to the file when true, otherwise overwrite it
    """
    mode = "a+" if is_append else "w+"
    with open(fn, mode) as f:
        f.write(content)


def rand_int(lo, hi):
    """Draw a random integer from ``[lo, hi)`` using NumPy's global RNG."""
    return np.random.randint(low=lo, high=hi)


def rand_float(lo, hi):
    """Draw a random float from ``[lo, hi)`` using NumPy's global RNG."""
    span = hi - lo
    return lo + span * np.random.rand()


def count_parameters(model):
    """Count the elements of every parameter of ``model`` that requires grad.

    NOTE(review): an identical ``count_parameters`` is defined earlier in this
    file; this later definition is the one bound at import time.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    return sum(p.numel() for p in trainable)


def to_var(tensor, use_gpu, requires_grad=False):
    """Wrap ``tensor`` as a float ``Variable``, optionally moved to the GPU.

    :param tensor: data accepted by ``torch.FloatTensor``
    :param use_gpu: when true the tensor is moved with ``.cuda()`` first
    :param requires_grad: forwarded to ``Variable``
    :return: the resulting ``Variable``
    """
    data = torch.FloatTensor(tensor)
    if use_gpu:
        data = data.cuda()
    return Variable(data, requires_grad=requires_grad)


def make_graph(log, title, args):
    """Plot a loss curve and save it as ``<args.logf>/<title>.png``.

    :param log: sequence of loss values, plotted against their index
    :param title: figure title, also used as the output file stem
    :param args: object whose ``logf`` attribute is the output directory
    """
    plt.plot(log)
    plt.xlabel('iter')
    plt.ylabel('loss')

    # The original had a bare `title + '_loss_graph'` expression here whose
    # result was discarded; it is dropped so the title and file name stay
    # exactly what callers already get.
    plt.title(title)
    plt.savefig(os.path.join(args.logf, title + '.png'))
    plt.close()


def get_color_from_prob(prob, colors):
    """Blend instance colors according to per-instance probabilities.

    :param prob: sequence of weights, one per instance
    :param colors: sequence of colors; in the general case each has 4 entries
    :return: ``colors[0] * prob`` when there is a single color,
        ``colors * prob[0]`` when there is a single weight, otherwise the
        weighted sum ``sum(prob[i] * colors[i])`` as a length-4 array
    """
    # there's only one instance
    if len(colors) == 1:
        return colors[0] * prob
    if len(prob) == 1:
        return colors * prob[0]

    blended = np.zeros(4)
    for idx, weight in enumerate(prob):
        blended += weight * colors[idx]
    return blended


def create_instance_colors(n):
    """Return the first ``n`` rows of a fixed 5-color RGBA palette.

    :param n: number of colors requested; at most five rows are available
    :return: array of shape ``(min(n, 5), 4)`` for non-negative ``n``
    """
    # TODO: come up with a better way to initialize instance colors
    palette = np.array([
        [1., 0., 0., 1.],  # red
        [0., 1., 0., 1.],  # green
        [0., 0., 1., 1.],  # blue
        [1., 1., 0., 1.],  # yellow
        [1., 0., 1., 1.],  # magenta
    ])
    return palette[:n]


def convert_groups_to_colors(group, n_particles, n_rigid_instances, instance_colors, env=None):
    """
    Convert grouping to RGBA colors of shape (n_particles, 4)
    :param group: sequence starting with [p_rigid, p_instance, ...]; only
        p_instance (indexed per particle) is used here
    :param n_particles: number of rows to fill
    :param n_rigid_instances: unused
    :param instance_colors: colors passed on to ``get_color_from_prob``
    :param env: unused
    :return: array of shape (n_particles, 4) with one blended color per particle
    """
    # p_rigid: n_instance
    # p_instance: n_p x n_instance
    _p_rigid, p_instance = group[:2]

    colors = np.empty((n_particles, 4))
    for row in range(n_particles):
        colors[row, :] = get_color_from_prob(p_instance[row], instance_colors)

    return colors


def visualize_point_clouds(point_clouds, c=('b', 'r'), view=None, store=False, store_path=''):
    """Scatter-plot several 3-D point clouds in one figure.

    :param point_clouds: sequence of arrays indexed as ``[:, 0..2]``; the first
        cloud defines the cubic bounding box used to equalise the axes
    :param c: one color per point cloud (indexed in step with ``point_clouds``)
    :param view: ``(elev, azim)`` passed to ``view_init``; defaults to ``(0, 0)``
    :param store: when true the figure is saved as ``vis.png`` in ``store_path``
    :param store_path: output directory, created if it does not exist
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.set_aspect('equal')

    frame = plt.gca()
    frame.axes.xaxis.set_ticklabels([])
    frame.axes.yaxis.set_ticklabels([])
    frame.axes.zaxis.set_ticklabels([])

    for i, points in enumerate(point_clouds):
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=c[i], s=10, alpha=0.3)

    X, Y, Z = point_clouds[0][:, 0], point_clouds[0][:, 1], point_clouds[0][:, 2]

    # Invisible points on the corners of a cube around the first cloud force
    # equal axis ranges (the "fake bounding box" trick).
    max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max()
    corners = np.mgrid[-1:2:2, -1:2:2, -1:2:2]
    Xb = 0.5 * max_range * corners[0].flatten() + 0.5 * (X.max() + X.min())
    Yb = 0.5 * max_range * corners[1].flatten() + 0.5 * (Y.max() + Y.min())
    Zb = 0.5 * max_range * corners[2].flatten() + 0.5 * (Z.max() + Z.min())
    for xb, yb, zb in zip(Xb, Yb, Zb):
        ax.plot([xb], [yb], [zb], 'w')

    ax.grid(False)
    # NOTE(review): show() runs before view_init()/savefig() as in the original;
    # with a blocking backend the requested view is only applied afterwards.
    plt.show()

    if view is None:
        view = 0, 0
    ax.view_init(view[0], view[1])
    plt.draw()

    if store:
        # os.makedirs replaces the former `mkdir -p` shell call; an empty
        # store_path still means "save into the current directory".
        if store_path:
            os.makedirs(store_path, exist_ok=True)
        fig.savefig(os.path.join(store_path, "vis.png"), bbox_inches='tight')


def quatFromAxisAngle(axis, angle):
    """Build the quaternion ``[x, y, z, w]`` for a rotation about ``axis``.

    The input axis is copied, so the caller's array is left untouched and
    lists / integer arrays are accepted. Computation is done in float64.

    :param axis: length-3 rotation axis (need not be normalised)
    :param angle: rotation angle in radians
    :return: ``np.ndarray`` of 4 elements, vector part first, scalar part last
    """
    unit = np.array(axis, dtype=np.float64)  # always a fresh copy
    unit = unit / np.linalg.norm(unit)

    half = angle * 0.5
    xyz = unit * np.sin(half)

    return np.array([xyz[0], xyz[1], xyz[2], np.cos(half)])


def quatFromAxisAngle_var(axis, angle):
    """Torch version of ``quatFromAxisAngle`` returning ``[x, y, z, w]``.

    Uses out-of-place arithmetic so the caller's ``axis`` tensor is not
    modified and leaf tensors that require grad can be passed in.

    :param axis: 1-D tensor holding the rotation axis (need not be normalised)
    :param angle: 1-D tensor with a single element, the angle in radians
    :return: 1-D tensor, the scaled axis followed by ``cos(angle / 2)``
    """
    unit = axis / torch.norm(axis)

    half = angle * 0.5
    xyz = unit * torch.sin(half)

    return torch.cat([xyz, torch.cos(half)])


class ChamferLoss(torch.nn.Module):
    def __init__(self):
        super(ChamferLoss, self).__init__()

    def chamfer_distance(self, x, y):
        # x: [N, D]
        # y: [M, D]
        x = x.repeat(y.size(0), 1, 1)  # x: [M, N, D]
        x = x.transpose(0, 1)  # x: [N, M, D]
        y = y.repeat(x.size(0), 1, 1)  # y: [N, M, D]
        dis = torch.norm(torch.add(x, -y), 2, dim=2)  # dis: [N, M]
        dis_xy = torch.mean(torch.min(dis, dim=1)[0])  # dis_xy: mean over N
        dis_yx = torch.mean(torch.min(dis, dim=0)[0])  # dis_yx: mean over M

        return dis_xy + dis_yx

    def __call__(self, pred, label):
        return self.chamfer_distance(pred, label)

# def ChamferInChamfer(x, y):
#     """
#
#     :param x: b * A * 3
#     :param y: b * B *3
#     :return:
#     """
#     assert x.shape[0] == y.shape[0]
#     b = x.shape[0]
#     A, B = x.shape[1], y.shape[1]
#     x = x.expand(B, b, A, 3).permute(1, 0, 2, 3)  # b * B * A * 3
#     x = x.transpose(1, 2) # b * A * B * 3
#     y = y.expand(A, b, B, 3).permute(1, 0, 2, 3)  # b * A * B * 3
#     dist = torch.norm(torch.add(x, -y), dim=-1)  # b * A * B * 1
#     weight_xy = torch.min(dist, dim=1)[0]  # b * B * 3
#     weight_yz = torch.min(dist, dim=2)[0]  # b * A * 3
#
#
#



def get_l2_loss(g):
    """Return the 2-norm of ``len(g) - ||g_i||`` taken over the rows of ``g``.

    :param g: 2-D tensor; row norms are computed along dim 1
    :return: scalar tensor
    """
    n_rows = len(g)
    row_norms = torch.norm(g, dim=1, keepdim=True)
    return torch.norm(n_rows - row_norms)

def FPS_sampling(pcl, sampling_n):
    """
    Farthest point sampling, always starting from index 0.

    Each further pick is the unselected point whose squared distance to its
    nearest already-selected point is largest.

    :param pcl: N * 3 points (numpy arrays are converted to a torch tensor)
    :param sampling_n: number of points to sample, must be smaller than N
    :return: [index, pcl_sampled]
    """
    if isinstance(pcl, np.ndarray):
        pcl = torch.from_numpy(pcl)
    N = pcl.shape[0]
    assert sampling_n < N

    chosen = []
    remaining = list(range(N))
    for step in range(sampling_n):
        if step == 0:
            pick = 0
        else:
            picked_pts = pcl[chosen, :].unsqueeze(0)  # 1 * a * 3
            candidate_pts = pcl[remaining, :].unsqueeze(1)  # b * 1 * 3

            delta = candidate_pts - picked_pts  # b * a * 3
            sq_dist = delta[:, :, 0] ** 2 + delta[:, :, 1] ** 2 + delta[:, :, 2] ** 2
            nearest_sq, _ = sq_dist.min(1)  # per candidate: distance to closest pick
            _, best = nearest_sq.max(0)
            pick = remaining[best.item()]
        remaining.remove(pick)
        chosen.append(pick)
    return chosen, pcl[chosen]

def l2_norm(x, y):
    """Calculate the squared Euclidean distance between `x` and `y`.

    No square root is taken: the result is the sum of squared differences
    over axis 2.
    Args:
        x (numpy.ndarray or cupy): (batch_size, num_point, coord_dim)
        y (numpy.ndarray): (batch_size, num_point, coord_dim)
    Returns (numpy.ndarray): (batch_size, num_point,)
    """
    diff = x - y
    return (diff ** 2).sum(axis=2)

def farthest_point_sampling(pts, k, initial_idx=None, metrics=l2_norm,
                            skip_initial=False, indices_dtype=np.int32,
                            distances_dtype=np.float32):
    """Batch operation of farthest point sampling
    Code referenced from below link by @Graipher
    https://codereview.stackexchange.com/questions/179561/farthest-point-algorithm-in-python
    Args:
        pts (numpy.ndarray or cupy.ndarray): 2-dim array (num_point, coord_dim)
            or 3-dim array (batch_size, num_point, coord_dim)
            When input is 2-dim array, it is treated as 3-dim array with
            `batch_size=1`.
        k (int): number of points to sample
        initial_idx (int): initial index to start farthest point sampling.
            `None` starts from index 0 (this implementation does not pick a
            random start).
        metrics (callable): metrics function, indicates how to calc distance.
        skip_initial (bool): If True, initial point is skipped to store as
            farthest point. It stabilizes the function output.
        indices_dtype (): dtype of output `indices`
        distances_dtype (): dtype of output `distances`
    Returns (tuple): `indices` and `distances`.
        indices (numpy.ndarray or cupy.ndarray): 2-dim array (batch_size, k, )
            indices of sampled farthest points.
            `pts[indices[i, j]]` represents `i-th` batch element of `j-th`
            farthest point.
        distances (numpy.ndarray or cupy.ndarray): 3-dim array
            (batch_size, k, num_point)
    """
    if pts.ndim == 2:
        # insert batch_size axis
        pts = pts[None, ...]
    assert pts.ndim == 3
    # array module (numpy or cupy) matching the input array
    xp = cuda.get_array_module(pts)
    batch_size, num_point, coord_dim = pts.shape
    indices = xp.zeros((batch_size, k, ), dtype=indices_dtype)

    # distances[bs, i, j] is distance between i-th farthest point `pts[bs, i]`
    # and j-th input point `pts[bs, j]`.
    distances = xp.zeros((batch_size, k, num_point), dtype=distances_dtype)
    indices[:, 0] = 0 if initial_idx is None else initial_idx

    batch_indices = xp.arange(batch_size)
    farthest_point = pts[batch_indices, indices[:, 0]]
    # Minimum distances to the sampled farthest points so far. A failure in
    # `metrics` now propagates instead of dropping into an IPython shell and
    # then failing later on an undefined `min_distances`.
    min_distances = metrics(farthest_point[:, None, :], pts)

    if skip_initial:
        # Override 0-th `indices` by the farthest point of `initial_idx`
        indices[:, 0] = xp.argmax(min_distances, axis=1)
        farthest_point = pts[batch_indices, indices[:, 0]]
        min_distances = metrics(farthest_point[:, None, :], pts)

    distances[:, 0, :] = min_distances
    for i in range(1, k):
        indices[:, i] = xp.argmax(min_distances, axis=1)
        farthest_point = pts[batch_indices, indices[:, i]]
        dist = metrics(farthest_point[:, None, :], pts)
        distances[:, i, :] = dist
        min_distances = xp.minimum(min_distances, dist)
    return indices, distances


def load_shape_info(dset, root, id=None):
    """Read the shape description file of a dataset.

    Lines starting with ``asset`` are only counted. Every other non-blank line
    is expected to look like ``[4.   0.02 4.  ] 0 [0.9 0.9 0.9]``: the first
    bracket holds three floats and the character two positions after its
    closing bracket is the visibility flag.

    :param dset: one of 'pour', 'shake', 'pour_extra', 'shake_extra',
        'granular_push'
    :param root: directory containing ``datasets/shape_info/``
    :param id: identifier inserted into the file name; required for the
        ``*_extra`` datasets
    :return: ``(mesh_num, shape_info)`` where ``mesh_num`` is the number of
        ``asset`` lines and ``shape_info`` is a list of ``[(x, y, z), visible]``
    :raises ValueError: if ``dset`` is not one of the supported names
    """
    # shape : some env objects like container and robo-arms
    fixed_names = {
        'pour': 'FluidPour.txt',
        'shake': 'FluidShake.txt',
        'granular_push': 'granular_push.txt',
    }
    per_id_names = {
        'pour_extra': 'shapes_FluidPourExtra_{}.txt',
        'shake_extra': 'shapes_FluidShakeExtra_{}.txt',
    }
    if dset in fixed_names:
        file_name = fixed_names[dset]
    elif dset in per_id_names:
        assert id is not None
        file_name = per_id_names[dset].format(id)
    else:
        raise ValueError('unsupported dataset: {!r}'.format(dset))
    info_pth = os.path.join(root, 'datasets/shape_info/', file_name)

    mesh_num = 0
    shape_info = []
    with open(info_pth) as f:
        for line in f:
            if line.startswith('asset'):
                mesh_num += 1
                continue
            if not line.strip():
                # tolerate blank lines (e.g. a trailing newline)
                continue

            end_pos = line.find("]")
            visible = int(line[end_pos + 2])

            # Drop the opening bracket wherever it sits so both "[4. ..." and
            # "[ 4. ..." parse.
            x, y, z = (float(v) for v in line[:end_pos].replace('[', ' ').split())

            shape_info.append([(x, y, z), visible])
    return mesh_num, shape_info


def box_sampling(box_info, N, same_interval, interval):
    """
    Sample a regular grid of points inside a box and move it to world space.

    box_info contains [scale, rot, center]
    assume that the box is something like a plane, with a relatively small thickness.
    With ``same_interval`` false every axis gets ``N`` samples and the thinnest
    axis is collapsed to 0; otherwise each axis gets ``extent // interval``
    samples (at least 2).

    :return: array of shape (Nx * Ny * Nz, 3): grid points multiplied by ``rot``
        and shifted by ``center``
    """
    scale, rot, center = box_info
    if isinstance(scale, list):
        scale = np.array(scale)

    # index of the thinnest dimension
    thin_axis = np.argmin(np.abs(scale))
    extents = (scale[0], scale[1], scale[2])

    if same_interval:
        counts = [int(extent // interval) for extent in extents]
        # every axis needs at least two samples
        Nx, Ny, Nz = [2 if count <= 1 else count for count in counts]
        ticks = [
            np.linspace(-extent / 2, extent / 2, count)
            for extent, count in zip(extents, (Nx, Ny, Nz))
        ]
    else:
        Nx = Ny = Nz = N
        ticks = [np.linspace(-extent / 2, extent / 2, N) for extent in extents]
        ticks[thin_axis] = [0] * N

    X, Y, Z = np.meshgrid(ticks[0], ticks[1], ticks[2])
    mesh = np.stack((X, Y, Z), axis=-1).reshape((Nx * Ny * Nz, 3))

    # rotate every point as a column vector, then translate
    mesh = np.matmul(rot, mesh[:, :, None])[:, :, 0]
    mesh += center

    return mesh
def capture_scene_image(feature, output_pth='debug.png', angle=100, color=None, dset='pour', dpi=200, axis='on', h=10):
    """Render a point set as a 3-D scatter plot and save it to ``output_pth``.

    :param feature: array indexed as ``[:, 0..2]``; column 1 is drawn on the
        vertical plot axis
    :param output_pth: image file to write
    :param angle: azimuth passed to ``view_init``
    :param color: forwarded to ``scatter``
    :param dset: selects fixed axis limits for 'pour', 'shake' or any name
        containing 'granular'; other names keep automatic limits
    :param dpi: resolution of the saved image
    :param axis: forwarded to ``plt.axis``
    :param h: elevation passed to ``view_init``
    """
    xs, ys, zs = feature[:, 0], feature[:, 1], feature[:, 2]
    ax3D = plt.figure().add_subplot(111, projection='3d')
    ax3D.scatter(xs, zs, ys, s=5, marker='o', color=color)
    ax3D.view_init(h, angle)
    ax3D.set_xlabel('x')
    ax3D.set_ylabel('y')
    ax3D.set_zlabel('z')

    if dset == 'pour':
        bounds = ((-1.5, 1.5), (-1.5, 1.5), (0.5, 3.5))
    elif dset == 'shake':
        bounds = ((-1, 1), (-1, 1), (0, 2))
    elif 'granular' in dset:
        bounds = ((-3, 3), (-3, 3), (0, 6))
    else:
        bounds = None

    if bounds is not None:
        ax3D.set_xlim3d(*bounds[0])
        ax3D.set_ylim3d(*bounds[1])
        ax3D.set_zlim3d(*bounds[2])

    plt.axis(axis)
    plt.savefig(output_pth, dpi=dpi, bbox_inches='tight')
    plt.close()


def capture_scene_image_with_norm(feature, norm, output_pth='debug.png', angle=100, color=None, dset='pour', dpi=200, axis='on'):
    """Render a point set with one short line per point along its normal.

    :param feature: array indexed as ``[:, 0..2]``; column 1 is drawn on the
        vertical plot axis
    :param norm: per-point vectors with the same first dimension as ``feature``;
        each is drawn scaled by 1/5 starting at its point
    :param output_pth: image file to write
    :param angle: azimuth passed to ``view_init`` (elevation is fixed at 10)
    :param color: forwarded to ``scatter``
    :param dset: selects fixed axis limits for 'pour', 'shake' or any name
        containing 'granular'; other names keep automatic limits
    :param dpi: resolution of the saved image
    :param axis: forwarded to ``plt.axis``
    """
    assert feature.shape[0] == norm.shape[0]

    xs, ys, zs = feature[:, 0], feature[:, 1], feature[:, 2]
    ax3D = plt.figure().add_subplot(111, projection='3d')
    ax3D.scatter(xs, zs, ys, s=2, marker='o', color=color)
    ax3D.view_init(10, angle)
    ax3D.set_xlabel('x')
    ax3D.set_ylabel('y')
    ax3D.set_zlabel('z')

    # one line segment per point, using the same column swap as the scatter
    for point, normal in zip(feature, norm):
        ax3D.plot([point[0], point[0] + normal[0] / 5],
                  [point[2], point[2] + normal[2] / 5],
                  [point[1], point[1] + normal[1] / 5], linewidth=1)

    if dset == 'pour':
        bounds = ((-1.5, 1.5), (-1.5, 1.5), (0.5, 3.5))
    elif dset == 'shake':
        bounds = ((-1, 1), (-1, 1), (0, 2))
    elif 'granular' in dset:
        bounds = ((-3, 3), (-3, 3), (0, 6))
    else:
        bounds = None

    if bounds is not None:
        ax3D.set_xlim3d(*bounds[0])
        ax3D.set_ylim3d(*bounds[1])
        ax3D.set_zlim3d(*bounds[2])

    plt.axis(axis)
    plt.savefig(output_pth, dpi=dpi, bbox_inches='tight')
    plt.close()

def capture_motion_image(feature_list, output_pth='motion.png', angle=100, color=None, point=None):
    """Plot consecutive frames of a point set and the motion between them.

    Every frame is scattered and, from the second frame on, a blue segment
    joins each tracked point to its position in the previous frame.

    :param feature_list: sequence of arrays indexed as ``[:, 0..2]``
    :param output_pth: image file to write (saved at dpi=2000)
    :param angle: unused
    :param color: optional per-frame scatter colors, indexed by frame; ``None``
        lets matplotlib choose
    :param point: optional iterable of point indices to track; by default every
        point of the first frame is tracked
    """
    fig = plt.figure()
    ax3D = fig.add_subplot(111, projection='3d')
    for i, feature in enumerate(feature_list):
        x, y, z = feature[:, 0], feature[:, 1], feature[:, 2]
        # The default color=None used to fail on color[i]; fall back to
        # matplotlib's own choice instead.
        frame_color = None if color is None else color[i]
        ax3D.scatter(x, z, y, s=1, marker='o', color=frame_color)
        if i > 0:
            previous = feature_list[i - 1]
            tracked = point if point is not None else range(feature_list[0].shape[0])
            for o in tracked:
                x_start, x_end = feature[o][0], previous[o][0]
                y_start, y_end = feature[o][1], previous[o][1]
                z_start, z_end = feature[o][2], previous[o][2]
                ax3D.plot(xs=[x_start, x_end], ys=[z_start, z_end], zs=[y_start, y_end], c='b', linewidth=1)
    plt.savefig(output_pth, dpi=2000)
    plt.close()



def rotation_matrix_from_quaternion(quant):
    """Convert a quaternion into a 3 x 3 rotation matrix.

    The quaternion is normalised first. The matrix is meant to be multiplied
    from the right-hand side, i.e. it is the transpose of the conventional one.

    Reference
    http://www.euclideanspace.com/maths/geometry/rotations/conversions/quaternionToMatrix/index.htm

    :param quant: 4 elements ordered w, x, y, z (numpy array or torch tensor)
    :return: torch tensor of shape (3, 3)
    """
    if isinstance(quant, np.ndarray):
        quant = torch.from_numpy(quant)
    # NOTE(review): kept from the original; a 3-D input still cannot pass the
    # `.view(1, 1)` calls below.
    if len(quant.shape) == 3:
        quant = quant.unsqueeze(1)
    # Allocate the constant on the input's device so non-CPU tensors work.
    one = torch.ones(1, 1, device=quant.device)

    quant = quant / torch.norm(quant)
    w, x, y, z = (quant[k].view(1, 1) for k in range(4))

    row0 = torch.cat((one - y * y * 2 - z * z * 2, x * y * 2 + z * w * 2, x * z * 2 - y * w * 2), 1)
    row1 = torch.cat((x * y * 2 - z * w * 2, one - x * x * 2 - z * z * 2, y * z * 2 + x * w * 2), 1)
    row2 = torch.cat((x * z * 2 + y * w * 2, y * z * 2 - x * w * 2, one - x * x * 2 - y * y * 2), 1)

    return torch.cat((row0, row1, row2), 0)


def mutual_distance_3d_batch(pcd, sqrt=False):
    """
    Pairwise distances between the points of each cloud in a batch.

    :param pcd: B * N * 3
    :param sqrt: return Euclidean distances when true, squared distances otherwise
    :return: B * N * N
    """
    def squared_gap(axis):
        # squared difference of one coordinate for every ordered point pair
        coord = pcd[:, :, axis]
        return (coord[:, :, None] - coord[:, None, :]) ** 2

    dist_all = squared_gap(0) + squared_gap(1) + squared_gap(2)
    return torch.sqrt(dist_all) if sqrt else dist_all



def coalition_loss_batch(pcd, loss_bar, grad_k, n_p):
    """
    Penalise points whose nearest neighbour is closer than ``loss_bar``.

    :param pcd: B * N * 3, batch of point cloud
    :param loss_bar: threshold compared against the squared nearest-neighbour
        distance
    :param grad_k: unused
    :param n_p: only the first ``n_p`` points of every cloud contribute
    :return: scalar, sum of ``relu(loss_bar - nn_dist) ** 2``
    """
    n_points = pcd.shape[1]

    pair_dist = mutual_distance_3d_batch(pcd, sqrt=False)  # B * N * N
    # push the zero self-distances out of the way before taking the minimum
    diag = list(range(n_points))
    pair_dist[:, diag, diag] += 999
    nearest = pair_dist.min(-1)[0][:, :n_p]

    violation = torch.nn.ReLU()(loss_bar - nearest)
    return (violation ** 2).sum()


from chamferdist import ChamferDistance

def dy_loss(gt_pos, pred_pos, n_particles, args, box_info=None):
    """Compute the dynamics training loss for one batch.

    For every sample a bidirectional Chamfer term
    ``crit(gt, pred) ** 2 + args.chamfer_ratio * crit(pred, gt) ** 2`` is
    accumulated and the sum is halved. Optional EMD and coalition terms are
    then added on top.

    :param gt_pos: ground-truth positions, indexed as [batch, particle, ...]
    :param pred_pos: predicted positions, indexed like ``gt_pos``
    :param n_particles: per-sample particle counts (``shape[0]`` is the batch size)
    :param args: options; reads ``env``, ``chamfer_ratio``, ``emd_loss_ratio``,
        ``coalition_loss``, ``coalition_bar`` and ``coal_weight``
    :param box_info: per-sample box masks, indexed for env 'shake_extra'
    :return: ``(pred_loss, col_loss, loss)``: the halved Chamfer sum, the
        coalition term (``None`` when disabled) and the total. ``pred_loss`` is
        no longer modified by the later additions to ``loss``.
    :raises ValueError: for an unsupported ``args.env`` or an empty batch
    """
    crit = ChamferDistance()

    def sym_chamfer(gt, pred):
        # forward term plus weighted reverse term, both squared
        return crit(gt, pred) ** 2 + args.chamfer_ratio * crit(pred, gt) ** 2

    b = n_particles.shape[0]
    if b == 0:
        raise ValueError('dy_loss needs at least one sample')

    is_pour = args.env == 'pour' or args.env == 'pour_extra'
    loss = None

    if is_pour or args.env == 'granular_push':
        for i in range(b):
            n_particle = n_particles[i]
            term = sym_chamfer(gt_pos[[i], :n_particle], pred_pos[[i], :n_particle])
            loss = term if loss is None else loss + term

    elif args.env == 'shake_extra':
        for i in range(b):
            n_particle = n_particles[i]
            # the last 3 * 27 particles are the red, green and yellow boxes
            water_index_range = [j for j in range(n_particle - 27 * 3)]
            red_box_index_range = [n_particle - 27 * 3 + j for j in range(27)]
            green_box_index_range = [n_particle - 27 * 2 + j for j in range(27)]
            yellow_box_index_range = [n_particle - 27 * 1 + j for j in range(27)]
            index_range_list = [water_index_range, red_box_index_range,
                                green_box_index_range, yellow_box_index_range]
            box_mask = box_info[i]
            for k, idx in enumerate(index_range_list):
                if i == 0 and k == 0:
                    loss = sym_chamfer(gt_pos[[i], idx].unsqueeze(0), pred_pos[[i], idx].unsqueeze(0))
                elif k == 0 or box_mask[k - 1]:
                    # NOTE(review): as in the original, the water term of every
                    # sample after the first is weighted by 0 - confirm intent.
                    scale = 0 if k == 0 else 5
                    loss = loss + scale * sym_chamfer(gt_pos[[i], idx].unsqueeze(0),
                                                      pred_pos[[i], idx].unsqueeze(0))

    else:
        raise ValueError('unsupported env: {!r}'.format(args.env))

    loss = loss / 2
    pred_loss = loss

    # All additions below rebind `loss` instead of updating it in place, so
    # `pred_loss` keeps the pure Chamfer value. `n_particle` is the count of
    # the last sample in the batch, as before.
    if is_pour and args.emd_loss_ratio > 0:
        # NOTE(review): pour envs add the EMD term here and again below, as in
        # the original - looks like a copy-paste leftover, confirm.
        # NOTE(review): earth_mover_distance is not defined or imported in the
        # visible part of this file.
        loss = loss + args.emd_loss_ratio * earth_mover_distance(gt_pos[:, :n_particle],
                                                                 pred_pos[:, :n_particle],
                                                                 transpose=False).sum()

    if args.emd_loss_ratio > 0:
        loss = loss + args.emd_loss_ratio * earth_mover_distance(gt_pos[:, :n_particle],
                                                                 pred_pos[:, :n_particle],
                                                                 transpose=False).sum()

    col_loss = None
    if args.coalition_loss:
        col_loss = coalition_loss_batch(pred_pos, loss_bar=args.coalition_bar, grad_k=1, n_p=n_particle)
        if col_loss is not None:
            loss = loss + col_loss * args.coal_weight

    return pred_loss, col_loss, loss



def dy_vis(args, vis_pth, pred_pos, n_particles, gt_pos, n_shape, i):
    """Save a picture comparing prediction and ground truth for the first sample.

    Predicted points are drawn together with the ground-truth points via
    ``capture_scene_image`` and written to ``<vis_pth>/<i>.png``. Nothing
    happens for environments other than 'pour', 'pour_extra', 'shake' and
    'granular_push'.

    :param args: options; reads ``env`` and, for the pour envs, ``boundary_free``
    :param vis_pth: output directory
    :param pred_pos: predicted positions, first batch element is used
    :param n_particles: per-sample particle counts, first entry is used
    :param gt_pos: ground-truth positions, first batch element is used
    :param n_shape: number of shape points following the particles in ``gt_pos``
    :param i: index used as the file name
    """
    env = args.env
    if env not in ('pour', 'pour_extra', 'shake', 'granular_push'):
        return

    n_p = n_particles[0]
    pred_points = pred_pos[0, :n_p, :3].detach().cpu().numpy()
    out_file = vis_pth + "/{}.png".format(i)

    if env == 'shake':
        # fixed layout: 64 highlighted particles first, 208 shape points last
        palette = (['orange'] * 64 + ['r'] * (n_p - 64)
                   + ['lime'] * 64 + ['b'] * (n_p - 64)
                   + ['g'] * 208)
        capture_scene_image(np.concatenate([pred_points, gt_pos[0].cpu().numpy()]),
                            dset='shake',
                            angle=100,
                            output_pth=out_file,
                            color=palette)
        return

    if env == 'granular_push':
        gt_points = gt_pos[0, :(n_p + n_shape), :3].cpu().numpy()
        capture_scene_image(np.concatenate([pred_points, gt_points]),
                            angle=100,
                            output_pth=out_file,
                            color=['r'] * n_p + ['b'] * n_p + ['g'] * n_shape,
                            dset='granular_push')
        return

    # pour / pour_extra: shape points are shown unless running boundary-free
    if args.boundary_free:
        gt_size, n_shape_shown = n_p, 0
    else:
        gt_size, n_shape_shown = n_p + n_shape, n_shape

    gt_points = gt_pos[0, :(gt_size), :3].cpu().numpy()
    capture_scene_image(np.concatenate([pred_points, gt_points]),
                        angle=100,
                        output_pth=out_file,
                        color=['r'] * n_p + ['b'] * n_p + ['g'] * n_shape_shown)



def _save_line_plot(out_path, curves, legend=None):
    """Plot each ``(values, color)`` pair, save the figure to ``out_path`` and close it."""
    for values, color in curves:
        if color is None:
            plt.plot(values)
        else:
            plt.plot(values, color=color)
    if legend is not None:
        plt.legend(legend)
    plt.savefig(out_path)
    plt.close()


def dy_plot(train_loss_list, valid_loss_list, best_valid_loss_list, valid_loss_epc, train_loss_epc,
            train_loss_list_1, train_loss_list_2, train_loss_list_3, valid_loss_list_1,
            valid_loss_list_2, valid_loss_list_3, args):
    """Save the training / validation loss curves under ``<args.outf>/curve/``.

    Files written: ``train_iter.png``, ``val_iter.png``, ``best_valid.png``,
    ``train_epc.png``, ``valid_epc.png`` and, when ``args.train_extra_steps``
    is set, ``train_div.png`` / ``valid_div.png`` with one curve per rolling
    step (the third curve only if ``args.train_extra_steps_num > 1``).
    ``best_valid.png`` used to be plotted and saved twice; it is now written
    once with the same content.

    :param args: options; reads ``outf``, ``train_extra_steps`` and
        ``train_extra_steps_num``
    """
    curve_dir = args.outf + "/curve/"

    _save_line_plot(curve_dir + "train_iter.png", [(train_loss_list, 'r')])
    _save_line_plot(curve_dir + "val_iter.png", [(valid_loss_list, 'b')])
    _save_line_plot(curve_dir + "best_valid.png", [(best_valid_loss_list, None)])
    _save_line_plot(curve_dir + "train_epc.png", [(train_loss_epc, None)])
    _save_line_plot(curve_dir + "valid_epc.png", [(valid_loss_epc, None)])

    if args.train_extra_steps:
        rolling = (
            ("train_div.png", (train_loss_list_1, train_loss_list_2, train_loss_list_3)),
            ("valid_div.png", (valid_loss_list_1, valid_loss_list_2, valid_loss_list_3)),
        )
        for file_name, (first, second, third) in rolling:
            curves = [(first, 'r'), (second, 'b')]
            legend = ["rolling=1", "rolling=2"]
            if args.train_extra_steps_num > 1:
                curves.append((third, 'y'))
                legend.append("rolling=3")
            _save_line_plot(curve_dir + file_name, curves, legend)


def list_min_index(a):
    """Return the index of the first smallest element of ``a`` (0 if empty)."""
    return min(range(len(a)), key=lambda i: a[i], default=0)


def get_color_info(traj_id, data_root='/home/htxue/datasets/data_FluidShakeExtra_new'):
    """Report which colored boxes appear in a trajectory.

    Reads ``<data_root>/<traj_id>/info.p`` and looks at its ``'scene_params'``
    entry. A length of 41 means all three colors are present; for lengths 31
    and 21 the colors are decoded from three-value codes near the end of the
    parameters ('100' red, '010' green, '110' yellow). Any other length
    reports no color.

    :param traj_id: trajectory identifier (sub-directory name)
    :param data_root: dataset directory; defaults to the previously hard-coded
        location
    :return: ``(v_red, v_green, v_yellow)`` with 1 for present, 0 for absent
    :raises KeyError: if a decoded color code is not one of the three above
    """
    import pickle

    info_path = '{}/{}/info.p'.format(data_root, traj_id)
    # NOTE: pickle can run arbitrary code while loading; only point this at
    # trusted dataset files.
    with open(info_path, 'rb') as f:
        scene_info = pickle.load(f)['scene_params']

    slot_of_code = {
        '100': 0,  # red
        '010': 1,  # green
        '110': 2,  # yellow
    }
    present = [0, 0, 0]  # red, green, yellow

    n_params = len(scene_info)
    if n_params == 41:
        present = [1, 1, 1]
        code_starts = ()
    elif n_params == 31:
        code_starts = (-5, -15)
    elif n_params == 21:
        code_starts = (-5,)
    else:
        code_starts = ()

    for start in code_starts:
        code = ''.join(str(int(scene_info[start + k])) for k in range(3))
        present[slot_of_code[code]] = 1

    return tuple(present)


