import multiprocessing as mp
import os
import sys
sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "emd"))
)
import cv2
import time

import h5py
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d, Axes3D
import numpy as np
import scipy.misc
import scipy.spatial as spatial
import torch
import torch.nn.functional as F
import pickle
from torch.autograd import Variable
from torch.utils.data import Dataset

from utils import rand_int, rand_float, get_color_info
from utils import FPS_sampling, capture_scene_image, farthest_point_sampling, capture_motion_image

# from pcd_matching import Hungarian

### from DPI

def store_data(data_names, data, path):
    hf = h5py.File(path, 'w')
    for i in range(len(data_names)):
        hf.create_dataset(data_names[i], data=data[i])
    hf.close()


def load_data(data_names, path):
    hf = h5py.File(path, 'r')
    data = []
    for i in range(len(data_names)):
        d = np.array(hf.get(data_names[i]))
        data.append(d)
    hf.close()
    return data


def combine_stat(stat_0, stat_1):
    mean_0, std_0, n_0 = stat_0[:, 0], stat_0[:, 1], stat_0[:, 2]
    mean_1, std_1, n_1 = stat_1[:, 0], stat_1[:, 1], stat_1[:, 2]

    mean = (mean_0 * n_0 + mean_1 * n_1) / (n_0 + n_1)
    std = np.sqrt((std_0 ** 2 * n_0 + std_1 ** 2 * n_1 + \
                (mean_0 - mean) ** 2 * n_0 + (mean_1 - mean) ** 2 * n_1) / (n_0 + n_1))
    n = n_0 + n_1

    return np.stack([mean, std, n], axis=-1)


def init_stat(dim):
    # mean, std, count
    return np.zeros((dim, 3))


def normalize(data, stat, var=False):
    if var:
        for i in range(len(stat)):
            stat[i][stat[i][:, 1] == 0, 1] = 1.
            s = Variable(torch.FloatTensor(stat[i]).cuda())

            stat_dim = stat[i].shape[0]
            n_rep = int(data[i].size(1) / stat_dim)
            data[i] = data[i].view(-1, n_rep, stat_dim)

            data[i] = (data[i] - s[:, 0]) / s[:, 1]

            data[i] = data[i].view(-1, n_rep * stat_dim)

    else:
        for i in range(len(stat)):
            stat[i][stat[i][:, 1] == 0, 1] = 1.

            stat_dim = stat[i].shape[0]
            n_rep = int(data[i].shape[1] / stat_dim)
            data[i] = data[i].reshape((-1, n_rep, stat_dim))

            data[i] = (data[i] - stat[i][:, 0]) / stat[i][:, 1]

            data[i] = data[i].reshape((-1, n_rep * stat_dim))

    return data


def denormalize(data, stat, var=False):
    if var:
        for i in range(len(stat)):
            s = Variable(torch.FloatTensor(stat[i]).cuda())
            data[i] = data[i] * s[:, 1] + s[:, 0]
    else:
        for i in range(len(stat)):
            data[i] = data[i] * stat[i][:, 1] + stat[i][:, 0]

    return data


def calc_rigid_transform(XX, YY):
    X = XX.copy().T
    Y = YY.copy().T

    mean_X = np.mean(X, 1, keepdims=True)
    mean_Y = np.mean(Y, 1, keepdims=True)
    X = X - mean_X
    Y = Y - mean_Y
    C = np.dot(X, Y.T)
    U, S, Vt = np.linalg.svd(C)
    D = np.eye(3)
    D[2, 2] = np.linalg.det(np.dot(Vt.T, U.T))
    R = np.dot(Vt.T, np.dot(D, U.T))
    T = mean_Y - np.dot(R, mean_X)

    '''
    YY_fitted = (np.dot(R, XX.T) + T).T
    print("MSE fit", np.mean(np.square(YY_fitted - YY)))
    '''

    return R, T


def normalize_scene_param(scene_params, param_idx, param_range, norm_range=(-1, 1)):
    normalized = np.copy(scene_params[:, param_idx])
    low, high = param_range
    if low == high:
        return normalized
    nlow, nhigh = norm_range
    normalized = nlow + (normalized - low) * (nhigh - nlow) / (high - low)
    return normalized


def gen_PyFleX(info):
    env, env_idx = info['env'], info['env_idx']
    thread_idx, data_dir, data_names = info['thread_idx'], info['data_dir'], info['data_names']
    n_rollout, time_step = info['n_rollout'], info['time_step']
    shape_state_dim, dt = info['shape_state_dim'], info['dt']

    gen_vision = info['gen_vision']
    vision_dir, vis_width, vis_height = info['vision_dir'], info['vis_width'], info['vis_height']

    np.random.seed(round(time.time() * 1000 + thread_idx) % 2 ** 32)

    # positions
    stats = [init_stat(3)]

    import pyflex
    pyflex.init()

    for i in range(n_rollout):

        if i % 10 == 0:
            print("%d / %d" % (i, n_rollout))

        rollout_idx = thread_idx * n_rollout + i
        rollout_dir = os.path.join(data_dir, str(rollout_idx))
        os.system('mkdir -p ' + rollout_dir)

        if env == 'RigidFall':
            g_low, g_high = info['physics_param_range']
            gravity = rand_float(g_low, g_high)
            print("Generated RigidFall rollout {} with gravity {} from range {} ~ {}".format(
                i, gravity, g_low, g_high))

            n_instance = 3
            draw_mesh = 1
            scene_params = np.zeros(n_instance * 3 + 3)
            scene_params[0] = n_instance
            scene_params[1] = gravity
            scene_params[-1] = draw_mesh


            low_bound = 0.09
            for j in range(n_instance):
                x = rand_float(0., 0.1)
                y = rand_float(low_bound, low_bound + 0.01)
                z = rand_float(0., 0.1)

                scene_params[j * 3 + 2] = x
                scene_params[j * 3 + 3] = y
                scene_params[j * 3 + 4] = z

                low_bound += 0.21

            pyflex.set_scene(env_idx, scene_params, thread_idx)
            pyflex.set_camPos(np.array([0.2, 0.875, 2.0]))

            n_particles = pyflex.get_n_particles()
            n_shapes = 1    # the floor

            positions = np.zeros((time_step, n_particles + n_shapes, 3), dtype=np.float32)
            shape_quats = np.zeros((time_step, n_shapes, 4), dtype=np.float32)

            for j in range(time_step):
                positions[j, :n_particles] = pyflex.get_positions().reshape(-1, 4)[:, :3]

                ref_positions = positions[0]

                for k in range(n_instance):
                    XX = ref_positions[64*k:64*(k+1)]
                    YY = positions[j, 64*k:64*(k+1)]

                    X = XX.copy().T
                    Y = YY.copy().T

                    mean_X = np.mean(X, 1, keepdims=True)
                    mean_Y = np.mean(Y, 1, keepdims=True)
                    X = X - mean_X
                    Y = Y - mean_Y
                    C = np.dot(X, Y.T)
                    U, S, Vt = np.linalg.svd(C)
                    D = np.eye(3)
                    D[2, 2] = np.linalg.det(np.dot(Vt.T, U.T))
                    R = np.dot(Vt.T, np.dot(D, U.T))
                    t = mean_Y - np.dot(R, mean_X)

                    YY_fitted = (np.dot(R, XX.T) + t).T
                    # print("MSE fit", np.mean(np.square(YY_fitted - YY)))

                    positions[j, 64*k:64*(k+1)] = YY_fitted

                if gen_vision:
                    pyflex.step(capture=True, path=os.path.join(rollout_dir, str(j) + '.tga'))
                else:
                    pyflex.step()

                data = [positions[j], shape_quats[j], scene_params]
                store_data(data_names, data, os.path.join(rollout_dir, str(j) + '.h5'))

            if gen_vision:
                images = np.zeros((time_step, vis_height, vis_width, 3), dtype=np.uint8)
                for j in range(time_step):
                    img_path = os.path.join(rollout_dir, str(j) + '.tga')
                    img = scipy.misc.imread(img_path)[:, :, :3][:, :, ::-1]
                    img = cv2.resize(img, (vis_width, vis_height), interpolation=cv2.INTER_AREA)
                    images[j] = img
                    os.system('rm ' + img_path)

                store_data(['positions', 'images', 'scene_params'], [positions, images, scene_params],
                           os.path.join(vision_dir, str(rollout_idx) + '.h5'))

        elif env == 'MassRope':
            s_low, s_high = info['physics_param_range']
            stiffness = rand_float(s_low, s_high)
            print("Generated MassRope rollout {} with gravity {} from range {} ~ {}".format(
                i, stiffness, s_low, s_high))

            x = 0.
            y = 1.0
            z = 0.
            length = 0.7
            draw_mesh = 1.

            scene_params = np.array([x, y, z, length, stiffness, draw_mesh])

            pyflex.set_scene(env_idx, scene_params, 0)
            pyflex.set_camPos(np.array([0.13, 2.0, 3.2]))

            action = np.zeros(3)

            # the last particle is the pin, regarded as shape
            n_particles = pyflex.get_n_particles() - 1
            n_shapes = 1    # the mass at the top of the rope

            positions = np.zeros((time_step + 1, n_particles + n_shapes, 3), dtype=np.float32)
            shape_quats = np.zeros((time_step + 1, n_shapes, 4), dtype=np.float32)

            action = np.zeros(3)
            for j in range(time_step + 1):
                positions[j] = pyflex.get_positions().reshape(-1, 4)[:, :3]
                if j >= 1:
                    # append the action (position of the pin) to the previous time step
                    positions[j - 1, -1, :] = positions[j, -1, :]

                ref_positions = positions[0]

                # apply rigid projection to the rigid object
                # cube: [0, 81)
                # rope: [81, 95)
                # pin: [95, 96)
                XX = ref_positions[:81]
                YY = positions[j, :81]

                X = XX.copy().T
                Y = YY.copy().T

                mean_X = np.mean(X, 1, keepdims=True)
                mean_Y = np.mean(Y, 1, keepdims=True)
                X = X - mean_X
                Y = Y - mean_Y
                C = np.dot(X, Y.T)
                U, S, Vt = np.linalg.svd(C)
                D = np.eye(3)
                D[2, 2] = np.linalg.det(np.dot(Vt.T, U.T))
                R = np.dot(Vt.T, np.dot(D, U.T))
                t = mean_Y - np.dot(R, mean_X)

                YY_fitted = (np.dot(R, XX.T) + t).T

                positions[j, :81] = YY_fitted

                scale = 0.1
                action[0] += rand_float(-scale, scale) - positions[j, -1, 0] * 0.1
                action[2] += rand_float(-scale, scale) - positions[j, -1, 2] * 0.1

                if gen_vision:
                    pyflex.step(action * dt, capture=True, path=os.path.join(rollout_dir, str(j) + '.tga'))
                else:
                    pyflex.step(action * dt)

                if j >= 1:
                    data = [positions[j - 1], shape_quats[j - 1], scene_params]
                    store_data(data_names, data, os.path.join(rollout_dir, str(j - 1) + '.h5'))

            if gen_vision:
                images = np.zeros((time_step, vis_height, vis_width, 3), dtype=np.uint8)
                for j in range(time_step):
                    img_path = os.path.join(rollout_dir, str(j) + '.tga')
                    img = scipy.misc.imread(img_path)[:, :, :3][:, :, ::-1]
                    img = cv2.resize(img, (vis_width, vis_height), interpolation=cv2.INTER_AREA)
                    images[j] = img
                    os.system('rm ' + img_path)

                store_data(['positions', 'images', 'scene_params'], [positions, images, scene_params],
                           os.path.join(vision_dir, str(rollout_idx) + '.h5'))

        else:
            raise AssertionError("Unsupported env")

        # change dtype for more accurate stat calculation
        # only normalize positions
        datas = [positions[:time_step].astype(np.float64)]

        for j in range(len(stats)):
            stat = init_stat(stats[j].shape[0])
            stat[:, 0] = np.mean(datas[j], axis=(0, 1))[:]
            stat[:, 1] = np.std(datas[j], axis=(0, 1))[:]
            stat[:, 2] = datas[j].shape[0] * datas[j].shape[1]
            stats[j] = combine_stat(stats[j], stat)

    pyflex.clean()

    return stats


def axisEqual3D(ax):
    extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
    sz = extents[:, 1] - extents[:, 0]
    centers = np.mean(extents, axis=1)
    maxsize = max(abs(sz))
    r = maxsize / 2
    for ctr, dim in zip(centers, 'xyz'):
        getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)


def visualize_neighbors(anchors, queries, idx, neighbors, angle=100):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(queries[idx, 0], queries[idx, 2], queries[idx, 1], c='g', s=80, alpha=0.4)
    ax.scatter(anchors[neighbors, 0], anchors[neighbors, 2], anchors[neighbors, 1], c='r', s=40, alpha=0.4)
    ax.scatter(anchors[:, 0], anchors[:, 2], anchors[:, 1], alpha=0.2)
    ax.view_init(10, angle)
    axisEqual3D(ax)

    plt.show()


def find_relations_neighbor(pos, query_idx, anchor_idx, radius, order, var=False):
    if np.sum(anchor_idx) == 0:
        return []

    point_tree = spatial.cKDTree(pos[anchor_idx])
    neighbors = point_tree.query_ball_point(pos[query_idx], radius, p=order)

    '''
    for i in range(len(neighbors)):
        visualize_neighbors(pos[anchor_idx], pos[query_idx], i, neighbors[i])
    '''
    # print(neighbors)
    relations = []
    for i in range(len(neighbors)):
        count_neighbors = len(neighbors[i])
        if count_neighbors == 0:
            continue

        receiver = np.ones(count_neighbors, dtype=np.int) * query_idx[i]
        sender = np.array(anchor_idx[neighbors[i]])

        # receiver, sender, relation_type
        relations.append(np.stack([receiver, sender], axis=1))

    return relations


def find_k_relations_neighbor(k, positions, query_idx, anchor_idx, radius, order, var=False):
    """
    Same as find_relations_neighbor except that each point is only connected to the k nearest neighbors

    For each particle, only take the first min_neighbor neighbors, where
    min_neighbor = minimum number of neighbors among all particle's numbers of neighbors
    """
    if np.sum(anchor_idx) == 0:
        return []

    pos = positions.data.cpu().numpy() if var else positions

    point_tree = spatial.cKDTree(pos[anchor_idx])
    neighbors = point_tree.query_ball_point(pos[query_idx], radius, p=order)

    '''
    for i in range(len(neighbors)):
        visualize_neighbors(pos[anchor_idx], pos[query_idx], i, neighbors[i])
    '''

    relations = []
    min_neighbors = None
    for i in range(len(neighbors)):
        if min_neighbors is None:
            min_neighbors = len(neighbors[i])
        elif len(neighbors[i]) < min_neighbors:
            min_neighbors = len(neighbors[i])
        else:
            pass

    for i in range(len(neighbors)):
        receiver = np.ones(min_neighbors, dtype=np.int) * query_idx[i]
        sender = np.array(anchor_idx[neighbors[i][:min_neighbors]])

        # receiver, sender, relation_type
        relations.append(np.stack([receiver, sender], axis=1))

    return relations


def get_scene_info(data):
    """
    A subset of prepare_input() just to get number of particles
    for initialization of grouping
    """
    positions, shape_quats, scene_params = data
    n_shapes = shape_quats.shape[0]
    count_nodes = positions.shape[0]
    n_particles = count_nodes - n_shapes

    return n_particles, n_shapes, scene_params


def get_env_group(args, n_particles, n_shapes, scene_params, use_gpu=False):
    # n_particles (int)
    # scene_params: B x param_dim
    B = scene_params.shape[0]

    max_len = max([n_particles[i] + n_shapes[i] for i in range(len(n_particles))])

    p_rigid = torch.zeros(B, args.n_instance)
    p_instance = torch.zeros(B, max_len, args.n_instance)
    physics_param = torch.zeros(B, max_len)

    if args.env == 'RigidFall':
        norm_g = normalize_scene_param(scene_params, 1, args.physics_param_range)
        physics_param[:] = torch.FloatTensor(norm_g).view(B, 1)

        p_rigid[:] = 1

        for i in range(args.n_instance):
            p_instance[:, 64 * i:64 * (i + 1), i] = 1

    elif args.env == 'MassRope':
        norm_stiff = normalize_scene_param(scene_params, 4, args.physics_param_range)
        physics_param[:] = torch.FloatTensor(norm_stiff).view(B, 1)

        n_rigid_particle = 81

        p_rigid[:, 0] = 1
        p_instance[:, :n_rigid_particle, 0] = 1
        p_instance[:, n_rigid_particle:, 1] = 1

    elif args.env == 'pour' or args.env == 'pour_extra':
        for i in range(len(n_particles)):
            p_instance[i, :n_particles[i]] = 1
    elif args.env == 'shake_extra':

        for i in range(len(n_particles)):

            p_num = n_particles[i]

            p_instance[i, :p_num-(27*3), 0] = 1  # water

            p_instance[i, p_num - (27 * 3):p_num - (27*2), 1] = 1 # red box
            p_instance[i, p_num- (27 * 2):p_num - 27, 2] = 1 # green box
            p_instance[i, p_num - 27:p_num, 3] = 1 # yellow box

            p_rigid[:, 1] = 1
            p_rigid[:, 2] = 1
            p_rigid[:, 3] = 1



    elif args.env == 'granular_push':
        for i in range(len(n_particles)):
            p_instance[i, :n_particles[i]] = 1


    else:
        raise AssertionError("Unsupported env")

    if use_gpu:
        p_rigid = p_rigid.cuda()
        p_instance = p_instance.cuda()
        physics_param = physics_param.cuda()

    # p_rigid: B x n_instance
    # p_instance: B x n_p x n_instance
    # physics_param: B x n_p
    return [p_rigid, p_instance, physics_param]


def vis_shape_neigh(positions, par_idx, neighbours, output_pth='o.png', n_s=330, n_p=200):
    n = positions.shape[0]
    c = [(0.8, 0.5, 0.7)] * n_p  + [(0, 1, 1)] * n_s
    c[par_idx] = (1, 0, 0)
    shape_idx = neighbours[neighbours[:, 0] == par_idx][:, 1]
    print(shape_idx)
    for i in list(shape_idx):
        c[i] = (0, 1, 0)
    capture_scene_image(positions, color=c, angle=100, output_pth=output_pth)


def vis_rels_container_to_partilce(pcd, rels, container_id, n_p, n_s, output_pth='debug.png', angle=100):
    par_id = (rels[:, 1] == container_id).nonzero()[0]
    c = ['y'] * n_p + ['b'] * n_s

    c[container_id] = 'g'

    par_id = list(par_id)

    for i in par_id:
        c[rels[i][0]] = 'r'


    print(c)

    capture_scene_image(pcd, color=c, output_pth=output_pth, angle=angle)




def prepare_input(positions, n_particle, n_shape, args, var=False):
    # positions: (n_p + n_s) x 3

    verbose = args.verbose_data

    count_nodes = n_particle + n_shape

    if verbose:
        print("prepare_input::positions", positions.shape)
        print("prepare_input::n_particle", n_particle)
        print("prepare_input::n_shape", n_shape)

    ### object attributes
    attr = np.zeros((count_nodes, args.attr_dim))

    ##### add env specific graph components
    rels = []
    if args.env == 'RigidFall':
        # object attr:
        # [particle, floor]
        attr[n_particle, 1] = 1
        pos = positions.data.cpu().numpy() if var else positions

        # conncetion between floor and particles when they are close enough
        dis = pos[:n_particle, 1] - pos[n_particle, 1]
        nodes = np.nonzero(dis < args.dist_cp)[0]

        floor = np.ones(nodes.shape[0], dtype=np.int) * n_particle
        rels += [np.stack([nodes, floor], axis=1)]


    elif args.env == 'MassRope':
        pos = positions.data.cpu().numpy() if var else positions
        dis = np.sqrt(np.sum((pos[n_particle] - pos[:n_particle])**2, 1))
        nodes = np.nonzero(dis < args.neighbor_radius)[0]

        '''
        if verbose:
            visualize_neighbors(pos, pos, 0, nodes)
            print(np.sort(dis)[:10])
        '''

        pin = np.ones(nodes.shape[0], dtype=np.int) * n_particle
        rels += [np.stack([nodes, pin], axis=1)]

    elif args.env == 'pour' or args.env == 'pour_extra':
        attr[n_particle:, 1] = 1
        pos = positions.data.cpu().numpy() if var else positions

        # conncetion between floor and particles when they are close enough
        # dis = pos[:n_particle, 1] - pos[n_particle:, 1]

        dis_x = pos[:n_particle, 0][:, None] - pos[n_particle:, 0][None, :]
        dis_y = pos[:n_particle, 1][:, None] - pos[n_particle:, 1][None, :]
        dis_z = pos[:n_particle, 2][:, None] - pos[n_particle:, 2][None, :]

        dis = np.sqrt(dis_x ** 2 + dis_y ** 2 + dis_z ** 2)  # n_p * n_s


        nodes_p, nodes_s = np.nonzero(dis < args.dist_cp) # n_rel * 1

        rels += [np.stack([nodes_p, nodes_s + n_particle], axis=1)]

    elif args.env == 'shake_extra':
        # [particles + boxes + shape] * 3
        # [0， 0， 1] : shape
        # [0， 1， 0] : boxes
        # [1， 0， 0] : particles

        attr[n_particle:, 2] = 1
        attr[:(n_particle -  27*3), 0] = 1
        attr[n_particle-27*3:n_particle, 1] = 1

        pos = positions.data.cpu().numpy() if var else positions

        dis_x = pos[:n_particle, 0][:, None] - pos[n_particle:, 0][None, :]
        dis_y = pos[:n_particle, 1][:, None] - pos[n_particle:, 1][None, :]
        dis_z = pos[:n_particle, 2][:, None] - pos[n_particle:, 2][None, :]

        dis = np.sqrt(dis_x ** 2 + dis_y ** 2 + dis_z ** 2)  # n_p * n_s

        nodes_p, node_s = np.nonzero(dis < args.dist_cp)

        rels += [np.stack([nodes_p, node_s + n_particle], axis=1)]

    elif args.env == 'granular_push':

        attr[n_particle:, 1] = 1
        pos = positions.data.cpu().numpy() if var else positions

        # conncetion between floor and particles when they are close enough
        # dis = pos[:n_particle, 1] - pos[n_particle:, 1]

        dis_x = pos[:n_particle, 0][:, None] - pos[n_particle:, 0][None, :]
        dis_y = pos[:n_particle, 1][:, None] - pos[n_particle:, 1][None, :]
        dis_z = pos[:n_particle, 2][:, None] - pos[n_particle:, 2][None, :]

        dis = np.sqrt(dis_x ** 2 + dis_y ** 2 + dis_z ** 2)  # n_p * n_s

        nodes_p, node_s = np.nonzero(dis < args.dist_cp)  # n_rel * 1

        rels += [np.stack([nodes_p, node_s + n_particle], axis=1)]


    else:
        AssertionError("Unsupported env %s" % args.env)

    ##### add relations between leaf particles

    if args.env in ['RigidFall', 'MassRope', 'pour', 'shake', 'pour_extra', 'shake_extra', 'granular_push']:
        queries = np.arange(n_particle)
        anchors = np.arange(n_particle)

    rels += find_relations_neighbor(pos, queries, anchors, args.neighbor_radius, 2, var)
    # rels += find_k_relations_neighbor(args.neighbor_k, pos, queries, anchors, args.neighbor_radius, 2, var)

    if len(rels) > 0:
        rels = np.concatenate(rels, 0)

    if verbose:
        print("Relations neighbor", rels.shape)

    n_rel = rels.shape[0]
    Rr = torch.zeros(n_rel, n_particle + n_shape)
    Rs = torch.zeros(n_rel, n_particle + n_shape)
    Rr[np.arange(n_rel), rels[:, 0]] = 1
    Rs[np.arange(n_rel), rels[:, 1]] = 1

    if verbose:
        print("Object attr:", np.sum(attr, axis=0))
        print("Particle attr:", np.sum(attr[:n_particle], axis=0))
        print("Shape attr:", np.sum(attr[n_particle:n_particle + n_shape], axis=0))

    if verbose:
        print("Particle positions stats")
        print("  Shape", positions.shape)
        print("  Min", np.min(positions[:n_particle], 0))
        print("  Max", np.max(positions[:n_particle], 0))
        print("  Mean", np.mean(positions[:n_particle], 0))
        print("  Std", np.std(positions[:n_particle], 0))

    if var:
        particle = positions
    else:
        particle = torch.FloatTensor(positions)

    if verbose:
        for i in range(count_nodes - 1):
            if np.sum(np.abs(attr[i] - attr[i + 1])) > 1e-6:
                print(i, attr[i], attr[i + 1])

    attr = torch.FloatTensor(attr)
    assert attr.size(0) == count_nodes
    assert attr.size(1) == args.attr_dim

    # attr: (n_p + n_s) x attr_dim
    # particle (unnormalized): (n_p + n_s) x state_dim
    # Rr, Rs: n_rel x (n_p + n_s)
    return attr, particle, Rr, Rs

def prepare_input_boundary_free(positions, n_particle, n_shape, args, var=False, norm=None):
    # positions: (n_p + n_s) x 3
    verbose = args.verbose_data

    count_nodes = n_particle

    if verbose:
        print("prepare_input::positions", positions.shape)
        print("prepare_input::n_particle", n_particle)
        print("prepare_input::n_shape", n_shape)

    ### object attributes
    attr = np.zeros((count_nodes, args.attr_dim))

    ##### add env specific graph components
    rels = []
    if args.env == 'pour' or args.env == 'pour_extra':
        attr[n_particle:, 1] = 1
        pos = positions.data.cpu().numpy() if var else positions

        # conncetion between floor and particles when they are close enough
        # dis = pos[:n_particle, 1] - pos[n_particle:, 1]
        # norm n_s * 3

        dis_x = pos[:n_particle, 0][:, None] - pos[n_particle:, 0][None, :] # n_p * n_s
        dis_y = pos[:n_particle, 1][:, None] - pos[n_particle:, 1][None, :]
        dis_z = pos[:n_particle, 2][:, None] - pos[n_particle:, 2][None, :]

        dis = np.sqrt(dis_x ** 2 + dis_y ** 2 + dis_z ** 2)  # n_p * n_s


        ## method 1 : mask + sum vec

        boundary_mask = dis < args.dist_cp  # n_p * n_s

        boundary_mask_repeat = np.concatenate([boundary_mask[:, :, None] for i in range(3)], -1)  # n_p * n_s * 3

        # diff_s_p = np.stack([dis_x, dis_y, dis_z], -1)  # n_p * n_s * 3
        #
        # v_p = diff_s_p * boundary_mask_repeat   # n_p * n_s * 3
        #
        # v_p = v_p.sum(1) / (boundary_mask_repeat.sum(1)+1)  # n_p * 3


        v_p = norm * boundary_mask_repeat # n_p * n_s * 3

        v_p = v_p.sum(1) / (boundary_mask_repeat.sum(1) + 1)

        # from utils import capture_scene_image_with_norm
        #
        # capture_scene_image_with_norm(pos[:n_particle], v_p, output_pth='v_p_norm_.jpg')

    elif args.env == 'shake' or args.env == 'shake_extra':
        attr[n_particle:, 2] = 1
        attr[:64, 0] = 1
        attr[64:n_particle, 1] = 1
        pos = positions.data.cpu().numpy() if var else positions

        dis_x = pos[:n_particle, 0][:, None] - pos[n_particle:, 0][None, :]
        dis_y = pos[:n_particle, 1][:, None] - pos[n_particle:, 1][None, :]
        dis_z = pos[:n_particle, 2][:, None] - pos[n_particle:, 2][None, :]

        dis = np.sqrt(dis_x ** 2 + dis_y ** 2 + dis_z ** 2)  # n_p * n_s

        nodes_p, node_s = np.nonzero(dis < args.dist_cp)

        rels += [np.stack([nodes_p, node_s + n_particle], axis=1)]

    elif args.env == 'granular_push':

        attr[n_particle:, 1] = 1
        pos = positions.data.cpu().numpy() if var else positions

        # conncetion between floor and particles when they are close enough
        # dis = pos[:n_particle, 1] - pos[n_particle:, 1]

        dis_x = pos[:n_particle, 0][:, None] - pos[n_particle:, 0][None, :]
        dis_y = pos[:n_particle, 1][:, None] - pos[n_particle:, 1][None, :]
        dis_z = pos[:n_particle, 2][:, None] - pos[n_particle:, 2][None, :]

        dis = np.sqrt(dis_x ** 2 + dis_y ** 2 + dis_z ** 2)  # n_p * n_s

        nodes_p, node_s = np.nonzero(dis < args.dist_cp)  # n_rel * 1

        rels += [np.stack([nodes_p, node_s + n_particle], axis=1)]


    else:
        AssertionError("Unsupported env %s" % args.env)

    ##### add relations between leaf particles

    if args.env in ['RigidFall', 'MassRope', 'pour', 'shake', 'pour_extra', 'shake_extra', 'granular_push']:
        queries = np.arange(n_particle)
        anchors = np.arange(n_particle)

    rels += find_relations_neighbor(pos, queries, anchors, args.neighbor_radius, 2, var)
    # rels += find_k_relations_neighbor(args.neighbor_k, pos, queries, anchors, args.neighbor_radius, 2, var)

    if len(rels) > 0:
        rels = np.concatenate(rels, 0)

    if verbose:
        print("Relations neighbor", rels.shape)

    n_rel = rels.shape[0]
    Rr = torch.zeros(n_rel, n_particle)
    Rs = torch.zeros(n_rel, n_particle)
    Rr[np.arange(n_rel), rels[:, 0]] = 1
    Rs[np.arange(n_rel), rels[:, 1]] = 1

    if verbose:
        print("Object attr:", np.sum(attr, axis=0))
        print("Particle attr:", np.sum(attr[:n_particle], axis=0))
        print("Shape attr:", np.sum(attr[n_particle:n_particle + n_shape], axis=0))

    if verbose:
        print("Particle positions stats")
        print("  Shape", positions.shape)
        print("  Min", np.min(positions[:n_particle], 0))
        print("  Max", np.max(positions[:n_particle], 0))
        print("  Mean", np.mean(positions[:n_particle], 0))
        print("  Std", np.std(positions[:n_particle], 0))


    if var:
        particle = torch.cat([positions[:n_particle], torch.from_numpy(v_p).to(positions.device)], -1)  # n_p * 6
    else:
        particle = torch.FloatTensor(np.concatenate([positions[:n_particle], v_p], -1))

    if verbose:
        for i in range(count_nodes - 1):
            if np.sum(np.abs(attr[i] - attr[i + 1])) > 1e-6:
                print(i, attr[i], attr[i + 1])

    attr = torch.FloatTensor(attr)
    assert attr.size(0) == count_nodes
    assert attr.size(1) == args.attr_dim

    # attr: (n_p) x attr_dim
    # particle (unnormalized): (n_p) x state_dim
    # Rr, Rs: n_rel x (n_p)
    return attr, particle, Rr, Rs




class Dynamics_Dataset(Dataset):

    def __init__(self, args, phase):
        self.args = args
        self.phase = phase
        self.data_dir = os.path.join(self.args.dataf, phase)
        self.vision_dir = self.data_dir + '_vision'
        self.stat_path = os.path.join(self.args.dataf, 'stat.h5')
        self.using_gt = args.using_gt

        if args.gen_data:
            os.system('mkdir -p ' + self.data_dir)
        if args.gen_vision:
            os.system('mkdir -p ' + self.vision_dir)

        # if args.env in ['RigidFall', 'MassRope']:
        #     self.data_names = ['positions', 'shape_quats', 'scene_params']
        # else:
        #     raise AssertionError("Unsupported env")

        ratio = self.args.train_valid_ratio

        if phase == 'train':
            self.n_rollout = int(self.args.n_rollout * ratio)
        elif phase == 'valid':
            self.n_rollout = self.args.n_rollout - int(self.args.n_rollout * ratio)
        else:
            raise AssertionError("Unknown phase")

        # load in all position : n_rollout * n_frame * fps_n
        if args.using_gt:
            p_data_path = os.path.join(args.root, 'dataset/GT_particle_fps/{}/fps_{}/'.format(args.env, args.fps))
        else:
            p_data_path = os.path.join(args.root, 'dataset/particle_fps_new/{}/fps_{}/'.format(args.env, args.fps))


        if args.env == 'shake_extra':
            invalid_id = [316]
        else:
            invalid_id = []


        if phase == 'train':
            rollout_list = range(int(self.n_rollout * ratio))
        else:
            rollout_list = range(int(self.args.n_rollout * ratio), args.n_rollout)

        if 'extra' in args.env or 'granular' in args.env:
            if phase == 'train':
                n_roll_per_scene = int(self.n_rollout) // 4
                rollout_list = []
                for i in range(4):
                    rollout_list += [item for item in range(i * 100, i * 100 + n_roll_per_scene)]
            else:
                rollout_list = [item for item in range(400, 400 + self.n_rollout)]

        self.particle_states = []  # len = n_rollout, item : time_step * N * state_dim

        for roll_ids in rollout_list:
            if roll_ids in invalid_id:
                rollout_list.remove(roll_ids)


        self.rollout_list = rollout_list



        for rollout_idx in rollout_list:
            if args.using_gt:
                if args.env == 'pour':
                    if not args.same_order:
                        dump = np.load(p_data_path + "{}.npy".format(rollout_idx)) # 200 * 1
                        info_path = os.path.join(args.raw_data_path, str(rollout_idx), "info.p")
                        info = pickle.load(open(info_path, 'rb'))['particles']
                        batch_idx = np.array(range(0, dump.shape[0])).repeat(dump.shape[1])
                        self.particle_states.append(info[batch_idx,dump.flatten(),:].reshape(dump.shape[0], dump.shape[1], -1))
                    else:
                        dump = np.load(p_data_path + "0.npy".format(rollout_idx))[0]  # 200 * 1
                        info_path = os.path.join(args.raw_data_path, str(rollout_idx), "info.p")
                        info = pickle.load(open(info_path, 'rb'))['particles']
                        self.particle_states.append(info[:, dump, :])
                elif args.env == 'shake':
                    dump = np.load(p_data_path + "0.npy".format(rollout_idx))[0] # 200 * 1
                    info_path = os.path.join(args.raw_data_path, str(rollout_idx), "info.p")
                    info = pickle.load(open(info_path, 'rb'))['particles']
                    self.particle_states.append(info[:, dump, :])
            else:
                dump = torch.load(p_data_path + "/{}.bin".format(rollout_idx))
                self.particle_states.append(dump)


        # load shape info
        if 'extra' not in args.env: # same shape for all traj, granular push is in this category

            shape_data_path = os.path.join(args.root, f'dataset/box_sampling/{args.env}/sampling_8_interval_{args.interval}/')
            self.shape_states = np.load(shape_data_path + "pkg.npy")[rollout_list, ...]  # 500 * 300 * n_s * 3
            self.n_p, self.n_s = self.args.fps, self.shape_states.shape[2]
        else:
            if args.env == 'pour_extra':
                if args.add_norm_vector or args.boundary_free:
                     shape_data_path = os.path.join(args.root, f'dataset/box_sampling_with_norm/{args.env}/sampling_8_interval_0.15/')
                else:
                    shape_data_path = os.path.join(args.root,
                                                   f'dataset/box_sampling/{args.env}/sampling_8_interval_0.15/')

                shape_states = []
                for i in range(5):
                    p = np.load(shape_data_path + f"pkg{i}.npy")
                    shape_states += [p[o] for o in range(p.shape[0])]

                self.shape_states = []
                for roll_id in self.rollout_list:
                    self.shape_states.append(shape_states[roll_id])
            elif args.env == 'shake_extra':
                if args.add_norm_vector or args.boundary_free:
                    pass
                else:
                    shape_data_path = os.path.join(args.root,
                                                   f'dataset/box_sampling/{args.env}/sampling_8_interval_0.2/')

                shape_states = []
                for i in range(5):
                    p = np.load(shape_data_path + f"{i}.npy")
                    shape_states += [p[o] for o in range(p.shape[0])]

                self.shape_states = []
                for roll_id in self.rollout_list:
                    self.shape_states.append(shape_states[roll_id])


        self.start_cut = 0
        if 'extra' in args.env:
            self.start_cut = 5
        if 'granular' in args.env:
            self.start_cut = 35


    def __len__(self):
        """
        Each data point is consisted of a whole trajectory
        """

        args = self.args
        return len(self.rollout_list) * (args.time_step // args.frame_jump - args.sequence_length + 1 - self.start_cut)

    def load_data(self, name):
        print("Loading stat from %s ..." % self.stat_path)
        self.stat = load_data(self.data_names[:1], self.stat_path)


    def __getitem__(self, idx):
        """
        Load a trajectory of length sequence_length
        """
        args = self.args


        offset = args.time_step // args.frame_jump - args.sequence_length + 1 - self.start_cut

        idx_rollout = idx // offset
        st_idx = idx % offset + self.start_cut
        ed_idx = st_idx + args.sequence_length




        if idx == 9 * args.batch_size:
            a = 0


        if args.stage in ['dy']:
            # load ground truth data
            attrs, particles, Rrs, Rss = [], [], [], []

            if args.env == 'shake_extra':
                box_info = get_color_info(self.rollout_list[idx_rollout])
                box_info = torch.FloatTensor(box_info)
            else:
                box_info = torch.FloatTensor([0, 0, 0])

            max_n_rel = 0
            for t in range(st_idx, ed_idx):


                data_p, data_s = self.particle_states[idx_rollout][t*args.frame_jump], \
                                 self.shape_states[idx_rollout][t*args.frame_jump]


                if t == ed_idx + args.align_step:
                    anchor_p = data_p

                # print(data_p.shape, data_s.shape)

                if args.add_norm_vector or args.boundary_free:
                    data_s_norm = data_s[:, 3:] # norm
                    data_s = data_s[:, :3]


                data = np.concatenate([data_p, data_s], 0)


                # load scene param
                if t == st_idx:
                    n_particle, n_shape, scene_params = data_p.shape[0], data_s.shape[0], np.array([0, 0, 0])

                # attr: (n_p + n_s) x attr_dim
                # particle (unnormalized): (n_p + n_s) x state_dim
                # Rr, Rs: n_rel x (n_p + n_s)
                if not args.boundary_free:
                    attr, particle, Rr, Rs = prepare_input(data, n_particle, n_shape, self.args)
                else:
                    attr, particle, Rr, Rs = prepare_input_boundary_free(data, n_particle, n_shape, self.args, norm=data_s_norm, var=False)
                max_n_rel = max(max_n_rel, Rr.size(0))

                attrs.append(attr)


                if args.add_norm_vector:
                    particle = torch.cat([particle,
                                          torch.cat([torch.zeros(n_particle, 3), torch.from_numpy(data_s_norm)], 0)],
                                         -1)


                particles.append(particle.numpy())
                Rrs.append(Rr)
                Rss.append(Rs)

            # 0, 1, 2, 3, ..., n_his - 1, n_seq - 1
            # using n_his - 1 to be the anchor value

            # if args.mode != 'only_shape' and (not args.using_gt):
            #
            #     for t in range(st_idx, ed_idx + args.align_step):
            #         t -= st_idx
            #         data_p = particles[t][:self.n_p]  # fps_num * 3
            #         data_s = particles[t][self.n_p:]
            #         rearranged_order = []
            #         dis_x = anchor_p[:, 0][:, None] - data_p[:, 0][None, :]
            #         dis_y = anchor_p[:, 1][:, None] - data_p[:, 1][None, :]
            #         dis_z = anchor_p[:, 2][:, None] - data_p[:, 2][None, :]
            #
            #         dist = dis_x ** 2 + dis_y ** 2 + dis_z ** 2
            #
            #         reorder = dist.argmin(1)  # [200]
            #
            #         data_p_new = data_p[reorder, :]
            #
            #         particles[t] = np.concatenate([data_p_new, data_s], 0)

                    # soft_align

        '''
        add augmentation
        '''
        if args.stage in ['dy']:
            for t in range(args.sequence_length):
                if t == args.n_his - 1:
                    # set anchor for transforming rigid objects
                    particle_anchor = particles[t].copy()

                if t < args.n_his:
                    # add noise to observation frames - idx smaller than n_his
                    noise = np.random.randn(n_particle, 3) * args.std_d * args.augment_ratio
                    particles[t][:n_particle, :3] += noise

                else:
                    # for augmenting rigid object,
                    # make sure the rigid transformation is the same before and after augmentation
                    if args.env == 'RigidFall':
                        for k in range(args.n_instance):
                            XX = particle_anchor[64 * k:64 * (k + 1)]
                            XX_noise = particles[args.n_his - 1][64 * k:64 * (k + 1)]

                            YY = particles[t][64 * k:64 * (k + 1)]

                            R, T = calc_rigid_transform(XX, YY)

                            particles[t][64 * k:64 * (k + 1)] = (np.dot(R, XX_noise.T) + T).T

                            '''
                            # checking the correctness of the implementation
                            YY_noise = particles[t][64*k:64*(k+1)]
                            RR, TT = calc_rigid_transform(XX_noise, YY_noise)
                            print(R, T)
                            print(RR, TT)
                            '''

                    elif args.env == 'shake':
                        n_rigid_particle = 64

                        XX = particle_anchor[:n_rigid_particle]
                        XX_noise = particles[args.n_his - 1][:n_rigid_particle]
                        YY = particles[t][:n_rigid_particle]

                        R, T = calc_rigid_transform(XX, YY)

                        particles[t][:n_rigid_particle] = (np.dot(R, XX_noise.T) + T).T

                        '''
                        # checking the correctness of the implementation
                        YY_noise = particles[t][:n_rigid_particle]
                        RR, TT = calc_rigid_transform(XX_noise, YY_noise)
                        print(R, T)
                        print(RR, TT)
                        '''

        elif args.stage in ['pos', 'param']:
            # for position/group refinement and parameter estimation, the input is the perception results
            # no need to add position augmentation
            pass

        else:
            AssertionError("Unknown stage %s" % args.stage)


        # attr: (n_p + n_s) x attr_dim
        # particles (unnormalized): seq_length x (n_p + n_s) x state_dim
        # scene_params: param_dim
        attr = torch.FloatTensor(attrs[0])
        particles = torch.FloatTensor(np.stack(particles))
        scene_params = torch.FloatTensor(scene_params)

        # pad the relation set
        # Rr, Rs: seq_length x n_rel x (n_p + n_s)
        if args.stage in ['dy']:
            if not args.boundary_free:
                padding_shape = n_particle + n_shape
            else:
                padding_shape = n_particle
            for i in range(len(Rrs)):
                Rr, Rs = Rrs[i], Rss[i]
                Rr = torch.cat([Rr, torch.zeros(max_n_rel - Rr.size(0), padding_shape)], 0)
                Rs = torch.cat([Rs, torch.zeros(max_n_rel - Rs.size(0), padding_shape)], 0)
                Rrs[i], Rss[i] = Rr, Rs
            Rr = torch.FloatTensor(np.stack(Rrs))
            Rs = torch.FloatTensor(np.stack(Rss))


        if args.stage in ['dy']:

            return attr, particles, n_particle, n_shape, scene_params, box_info, Rr, Rs


