import os.path

import json
import os.path as osp
import open3d as o3d
import numpy as np
from pprint import pprint
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import ArtImage_data.pipelines_sapien as pipelines
import ArtImage_data.arti_utils as arti_utils


np.set_printoptions(threshold=np.inf)

# Category names indexed by the ``category_id`` stored in the annotation files
# (index 0 is background). DataGen looks each name up in PART_LABEL_MAPS, so
# every non-background entry must be unique and a key of that dict.
INSTANCE_CLASSES = ('BG', 'laptop', 'eyeglasses', 'dishwasher', 'drawer', 'scissors')

# Part (link) names per category; index 0 is background.
# NOTE(review): 'drawer' lists 4 names here but PART_LABEL_MAPS['drawer'] has
# 5 labels -- confirm which one is intended.
PART_CLASSES = {'laptop': ('BG', 'base_link', 'link1'),
                'scissors': ('BG', 'base_link', 'link1'),
                'eyeglasses': ('BG', 'base_link', 'link1', 'link2'),
                'dishwasher': ('BG', 'base_link', 'link1'),
                'drawer': ('BG', 'link1', 'link2', 'link3')}

# Per-category part label ids; DataGen stores the selected tuple as
# results['label_map'].
PART_LABEL_MAPS = {'laptop': (0, 1, 2),
                   'scissors': (0, 1, 2),
                   'eyeglasses': (0, 1, 2, 3),
                   'dishwasher': (0, 1, 2),
                   'drawer': (0, 1, 2, 3, 4)}

def DataGen(file_path, cat):
    """Build the per-point network inputs for a single annotated frame.

    Args:
        file_path: path to one annotation JSON file. Everything before the
            last ``'/annotations'`` in this path is used as ``img_prefix``.
        cat: category name (e.g. ``'eyeglasses'``); selects
            ``data/urdf_metas/<cat>/urdf_metas.json`` and the category-specific
            rest-pose handling in ``fetch_rest_trans``.

    Returns:
        dict with numpy arrays ``camcs_per_point``, ``seg_per_point``,
        ``npcs_per_point``, ``naocs_per_point``, ``heatmap_per_point``,
        ``joint_cls_per_point``, ``unitvec_per_point``, ``axis_per_point``,
        plus ``joint_type`` taken unchanged from the pipeline's
        ``joint_type_gt`` entry.
    """
    root = 'data/'
    results = {'camera_intrinsic_path': osp.join(root, 'camera_intrinsic.json')}
    joint_param_path = osp.join(root, 'urdf_metas', cat, 'urdf_metas.json')
    # The image root is the annotation path with its '/annotations...' tail removed.
    results['img_prefix'] = file_path.rsplit('/annotations', 1)[0]

    results = fecth_instances(results, file_path)
    results = fetch_joint_params(results, joint_param_path, cat)
    results = fetch_rest_trans(cat, results['urdf_id'], results)
    category_name = INSTANCE_CLASSES[results['category_id']]
    results['label_map'] = PART_LABEL_MAPS[category_name]

    # Pipeline stages run in order; each consumes and returns the results dict.
    point_data_creator = pipelines.CreatePointDataSapien(downsample_voxel=0.005, with_rgb=False)
    results = point_data_creator(results)

    # Debug output (kept so the script's console output is unchanged).
    print(results['parts_parent_joint'])
    print(results['parts_child_joint'])

    nocs_data_creator = pipelines.LoadArtiNOCSDataSapien()
    results = nocs_data_creator(results)

    joint_data_creator = pipelines.LoadArtiJointDataSapien()
    results = joint_data_creator(results)

    joint_GT_creator = pipelines.CreateArtiJointGTSapien()
    results = joint_GT_creator(results)

    return {
        'camcs_per_point': np.array(results['parts_pts']),
        'seg_per_point': np.array(results['parts_cls']),
        'npcs_per_point': np.array(results['nocs_p']),
        'naocs_per_point': np.array(results['nocs_g']),
        'heatmap_per_point': np.array(results['offset_heatmap']),
        'joint_cls_per_point': np.array(results['joint_cls']),
        'unitvec_per_point': np.array(results['offset_unitvec']),
        'axis_per_point': np.array(results['joint_orient']),
        'joint_type': results['joint_type_gt'],
    }

def fecth_instances(results, ann_file_path):
    """Copy image metadata and the first annotated instance into ``results``.

    Reads the annotation JSON at ``ann_file_path``. Only ``instances[0]`` is
    used. Both ``n_parts`` and ``n_max_parts`` are set to the number of
    ``links`` of that instance.

    Returns:
        The same ``results`` dict, mutated in place.
    """
    with open(ann_file_path, 'r') as handle:
        annotation = json.load(handle)

    first = annotation['instances'][0]
    results['instance_info'] = first
    link_count = len(first['links'])
    results['n_parts'] = link_count
    results['category_id'] = first['category_id']
    results['img_height'] = annotation['height']
    results['img_width'] = annotation['width']
    results['color_path'] = annotation['color_path']
    results['depth_path'] = annotation['depth_path']
    results['urdf_id'] = first['urdf_id']
    results['bbox'] = first['bbox']
    results['n_max_parts'] = link_count
    return results

def fetch_ReArt_state(cat, urdf_id, results):
    """Load the per-joint rest state of one URDF model (ReArt entry point).

    Behaves exactly like ``fetch_Art_state``: it reads the same rest-state
    file and applies the same per-category unit handling. It delegates to
    that function so the two loaders cannot drift apart.

    Returns:
        list of length ``results['n_parts']``, as produced by ``fetch_Art_state``.
    """
    return fetch_Art_state(cat, urdf_id, results)

def fetch_Art_state(cat, urdf_id, results, rest_state_path='data/urdf/rest_state.json'):
    """Load the per-joint rest state of one URDF model.

    Args:
        cat: category name. For ``'drawer'`` the stored values are returned
            unchanged, because they are used as translation amounts in
            ``fetch_rest_trans``. For every other category they are treated
            as angles in degrees and converted to radians with ``np.radians``.
        urdf_id: key of the model in the rest-state JSON (looked up as a string).
        results: dict providing ``n_parts``, the length of the returned list.
        rest_state_path: JSON file shaped ``{urdf_id: {part_id: {'state': v}}}``.

    Returns:
        list of length ``results['n_parts']``; index ``k`` holds the state of
        part ``k``. Part ``'0'`` is skipped, so index 0 stays ``None``, as does
        any part missing from the file.
    """
    states = [None] * results['n_parts']
    with open(rest_state_path, 'r') as handle:
        rest_states = json.load(handle)

    for part_id, entry in rest_states[str(urdf_id)].items():
        if part_id == '0':
            continue
        value = entry['state']
        if cat != 'drawer':
            value = np.radians(value)  # stored in degrees
        states[int(part_id)] = value
    return states

def fetch_rest_trans(cat, urdf_id, results):
    """Attach a 4x4 rest-pose transform per part to ``results``.

    The rest states come from ``fetch_Art_state``. Part 0 keeps the identity.
    For each other part ``k``, joint ``k`` of ``results['joint_ins']`` (its
    ``xyz`` origin and ``axis`` direction) is combined with the state as follows:

    * eyeglasses / scissors: ``arti_utils.RotateAnyAxis(xyz, xyz + axis, -state)``
    * laptop / dishwasher:   ``arti_utils.RotateAnyAxis(xyz, xyz + axis, +state)``
    * any other category (e.g. drawer): identity with translation ``axis * state``

    The list is stored as ``results['rest_transformation']``; ``results`` is returned.
    """
    rest_states = fetch_Art_state(cat, urdf_id, results)
    part_total = results['n_parts']
    transforms = [np.eye(4) for _ in range(part_total)]
    origins = np.array(results['joint_ins']['xyz'])
    # NOTE: the 'axis' entry holds a direction vector despite the original 'rpy' naming.
    directions = np.array(results['joint_ins']['axis'])

    sign_flipped = ('eyeglasses', 'scissors')
    revolute = sign_flipped + ('laptop', 'dishwasher')
    for k in range(1, part_total):
        if cat in revolute:
            angle = -rest_states[k] if cat in sign_flipped else rest_states[k]
            transforms[k] = arti_utils.RotateAnyAxis(origins[k], origins[k] + directions[k], angle)
        else:
            amount = rest_states[k]
            shift = np.eye(4)
            shift[:3, 3] = directions[k] * amount
            transforms[k] = shift

    results['rest_transformation'] = transforms
    return results

def fetch_rot_trans(cat, results):
    """Return one 4x4 transform per part for the states in ``results['state_act']``.

    Entry 0 is the identity. For part ``k >= 1``, with ``xyz``/``axis`` taken
    from ``results['joint_ins']``:

    * ``cat == 'drawer'``: identity with translation ``axis[k] * state_act[k]``;
    * otherwise: ``arti_utils.RotateAnyAxis(xyz[k], xyz[k] + axis[k], state_act[k])``.
    """
    part_total = results['n_parts']
    joint_states = results['state_act']
    transforms = [np.eye(4) for _ in range(part_total)]
    origins = np.array(results['joint_ins']['xyz'])
    directions = np.array(results['joint_ins']['axis'])

    for k in range(part_total):
        value = joint_states[k]  # index 0 is read but never used
        if k == 0:
            continue
        if cat != 'drawer':
            transforms[k] = arti_utils.RotateAnyAxis(origins[k], origins[k] + directions[k], value)
            continue
        offset = np.eye(4)
        offset[:3, 3] = directions[k] * value
        transforms[k] = offset

    return transforms

def fetch_joint_params(results, joint_param_path, cat):
    """Load joint parameters and normalization data for ``results['urdf_id']``.

    Finds the first entry of the file's ``urdf_metas`` list whose ``id``
    equals ``results['urdf_id']`` and stores into ``results``:

    * ``norm_factors`` / ``corner_pts``: numpy arrays copied from the entry;
    * ``joint_ins``: dict of lists ``xyz``, ``axis``, ``type``, ``parent`` and
      ``child``. Index 0 is a placeholder (zeros / ``None``) and index
      ``n + 1`` describes joint ``n`` of the entry.

    Coordinates are remapped from the file's layout: ``joint_xyz = (x, y, z)``
    is stored as ``[y, z, x]``, and ``joint_rpy = (r, p, y)`` as the axis
    ``[-p, y, r]``.

    Args:
        results: dict providing ``urdf_id``; mutated in place and returned.
        joint_param_path: path to the category's ``urdf_metas.json``.
        cat: category name (currently unused).

    Returns:
        ``results``. If no entry matches, it is returned without the keys above.

    Raises:
        ValueError: if the matching entry's joint lists differ in length.
    """
    with open(joint_param_path, 'r') as file:
        urdf_metas = json.load(file)

    for meta in urdf_metas['urdf_metas']:
        if meta['id'] != results['urdf_id']:
            continue

        results['norm_factors'] = np.array(meta['norm_factors'])
        results['corner_pts'] = np.array(meta['corner_pts'])
        joint_lists = (meta['joint_types'], meta['joint_parents'], meta['joint_children'],
                       meta['joint_xyz'], meta['joint_rpy'])
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len({len(values) for values in joint_lists}) != 1:
            raise ValueError(
                f"inconsistent joint list lengths for urdf id {meta['id']!r}: "
                f"{[len(values) for values in joint_lists]}")

        joint_ins = dict(xyz=[[0., 0., 0.]],
                         axis=[[0., 0., 0.]],
                         type=[None],
                         parent=[None],
                         child=[None])
        for jtype, parent, child, (x, y, z), (roll, pitch, yaw) in zip(*joint_lists):
            joint_ins['xyz'].append([y, z, x])
            joint_ins['axis'].append([-pitch, yaw, roll])
            joint_ins['type'].append(jtype)
            joint_ins['parent'].append(parent)
            joint_ins['child'].append(child)

        results['joint_ins'] = joint_ins
        # Stop at the first match: appending a duplicate id's joints would
        # misalign joint indices with part indices.
        break

    return results

if __name__ == '__main__':
    # Usage: python <this script> [annotation_json] [category]
    # Without arguments the original hard-coded values are used.
    # NOTE(review): the fallback 'data/' is a directory, but DataGen opens
    # file_path as an annotation JSON file -- pass a real annotation path.
    ann_file = sys.argv[1] if len(sys.argv) > 1 else 'data/'
    category = sys.argv[2] if len(sys.argv) > 2 else 'eyeglasses'
    input_data = DataGen(ann_file, category)