from matplotlib.pyplot import axis
import numpy as np
import os
import random
import torch
import pickle

from torch import nn
from scipy.interpolate import interp1d
import open3d as o3d
from bps_torch.bps import bps_torch

def quaternion_to_rotation_matrix_for_f(q):
    """
    Convert a single quaternion into a 3x3 rotation matrix.

    The quaternion is normalised before use and its components are read in
    (x, y, z, w) order.

    Args:
        q (torch.Tensor): Quaternion of shape (4, ).

    Returns:
        torch.Tensor: Rotation matrix of shape (3, 3), allocated on
            ``q.device`` with torch's default dtype.
    """
    unit = q / torch.norm(q)
    qx, qy, qz, qw = unit

    # Pairwise products shared by several matrix entries.
    xx, xy, xz, xw = qx * qx, qx * qy, qx * qz, qx * qw
    yy, yz, yw = qy * qy, qy * qz, qy * qw
    zz, zw = qz * qz, qz * qw

    # Row-major list of the nine matrix entries.
    entries = torch.stack([
        1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw),
        2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw),
        2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy),
    ])

    # Copy into a default-dtype buffer so the output dtype does not depend on
    # the dtype of the input quaternion.
    rotation = torch.zeros(3, 3, device=q.device)
    rotation.copy_(entries.reshape(3, 3))
    return rotation

def calculate_frobenius_norm_of_rotation_difference(q_pred, q_gt, device=torch.device('cuda:0')):
    """
    Measure how far apart two rotations are, as the Frobenius norm of the
    difference between their rotation matrices.

    Args:
        q_pred (torch.Tensor): Predicted quaternion of shape (4, ).
        q_gt (torch.Tensor): Reference quaternion of shape (4, ).
        device (torch.device): Device both quaternions are moved to before
            the computation. Defaults to ``cuda:0``.

    Returns:
        float: Frobenius norm of ``R(q_pred) - R(q_gt)``.
    """
    # Both quaternions go through the same conversion, prediction first.
    rotations = [
        quaternion_to_rotation_matrix_for_f(quat.to(device))
        for quat in (q_pred, q_gt)
    ]
    residual = rotations[0] - rotations[1]
    return torch.norm(residual, p='fro').item()


@torch.jit.script
def quat_mul(a, b):
    """Multiply two equally shaped quaternion tensors element-wise.

    Components are read in (x, y, z, w) order along the last axis. Inputs are
    flattened to (-1, 4) for the computation and the product is returned in
    the shape of the inputs.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)

    ax, ay, az, aw = lhs[:, 0], lhs[:, 1], lhs[:, 2], lhs[:, 3]
    bx, by, bz, bw = rhs[:, 0], rhs[:, 1], rhs[:, 2], rhs[:, 3]

    # Shared partial products reused by the four output components.
    t_w = (az + ax) * (bx + by)
    t_y = (aw - ay) * (bw + bz)
    t_z = (aw + ay) * (bw - bz)
    t_x = t_w + t_y + t_z
    half = 0.5 * (t_x + (az - ax) * (bx - by))

    out_w = half - t_w + (az - ay) * (by - bz)
    out_x = half - t_x + (ax + aw) * (bx + bw)
    out_y = half - t_y + (aw - ax) * (by + bz)
    out_z = half - t_z + (az + ay) * (bw - bx)

    product = torch.stack([out_x, out_y, out_z, out_w], dim=-1).view(out_shape)

    return product

def _move_first_quat_component_last(quat):
    """Return a copy of an (N, 4) quaternion batch with column 0 moved to the end.

    ``quat_mul`` reads the scalar part from column 3, so this is applied to
    the output of the ``rot`` helpers before multiplying.
    NOTE(review): presumably ``rot`` returns (w, x, y, z) -- confirm.
    """
    return torch.cat((quat[:, 1:4], quat[:, 0:1]), dim=1)


def _process_arctic_sequence(mano_data, obj_path, processed_path, obj_index, device):
    """Turn one raw ARCTIC sequence into the tensors stored by ``load_arctic_data``.

    Args:
        mano_data (dict): Content of the ``*.mano.npy`` file; only
            ``["right"]["trans"]`` and ``["left"]["trans"]`` are read.
        obj_path (str): Path of the matching ``*.object.npy`` file.
        processed_path (str): Path of the processed sequence file whose
            ``"cam_coord"`` entry holds the camera-frame rotations.
        obj_index (int): Index of the object in the list of loaded objects.
        device: Torch device all returned tensors live on.

    Returns:
        dict: Tensors keyed by ``rot_r``, ``trans_r``, ``rot_l``, ``trans_l``,
            ``obj_params``, ``obj_rot_quat``, ``rot_r_quat``, ``rot_l_quat``
            and ``obj_bps``.
    """
    # Project-local rotation helpers; imported only once a sequence was found.
    import dexteroushandenvs.utils.rot as rot

    # Extrinsics of camera view 1; only the inverted rotation block is used.
    view_idx = 1
    world2cam_matrix = torch.tensor([[ 0.8946, -0.4464,  0.0197,  0.1542],
                                     [-0.1109, -0.2646, -0.9580,  0.9951],
                                     [ 0.4328,  0.8548, -0.2862,  4.6415],
                                     [ 0.0000,  0.0000,  0.0000,  1.0000]])
    cam2world_rot = world2cam_matrix[:3, :3].inverse()

    # NOTE: allow_pickle=True executes pickled content; only load trusted files.
    cam_coord = np.load(processed_path, allow_pickle=True).item()["cam_coord"]
    num_cam_frames = cam_coord["obj_rot_cam"].shape[0]
    quat_cam2world = rot.matrix_to_quaternion(
        cam2world_rot.repeat(num_cam_frames, 1, 1)
    ).to(device)

    def cam_to_world(key):
        # Camera-frame axis-angle of the selected view -> world-frame axis-angle.
        cam_axis_angle = torch.FloatTensor(cam_coord[key][:, view_idx, :]).to(device)
        cam_quat = rot.axis_angle_to_quaternion(cam_axis_angle)
        return rot.quaternion_to_axis_angle(
            rot.quaternion_multiply(quat_cam2world, cam_quat)
        )

    obj_r_world = cam_to_world("obj_rot_cam")
    rot_r_world = cam_to_world("rot_r_cam")
    rot_l_world = cam_to_world("rot_l_cam")

    obj_rot_quat = rot.axis_angle_to_quaternion(obj_r_world)
    rot_r_quat = rot.axis_angle_to_quaternion(rot_r_world)
    rot_l_quat = rot.axis_angle_to_quaternion(rot_l_world)

    trans_r = torch.FloatTensor(mano_data["right"]["trans"])
    trans_l = torch.FloatTensor(mano_data["left"]["trans"])

    obj_params = torch.FloatTensor(np.load(obj_path, allow_pickle=True))
    # NOTE(review): columns 4:7 look like a translation in millimetres that is
    # converted to metres here -- confirm against the ARCTIC data format.
    obj_params[:, 4:7] /= 1000
    # Columns 1:4 are replaced by the world-frame object rotation (axis-angle).
    obj_params[:, 1:4] = obj_r_world

    # Drop the first frames of every sequence and move everything to `device`.
    begin_frame = 40

    def trim(tensor):
        return tensor.detach().to(device)[begin_frame:]

    rot_r = trim(rot_r_world)
    trans_r = trim(trans_r)
    rot_l = trim(rot_l_world)
    trans_l = trim(trans_l)
    obj_params = trim(obj_params)

    obj_rot_quat = _move_first_quat_component_last(trim(obj_rot_quat))
    rot_r_quat = _move_first_quat_component_last(trim(rot_r_quat))
    rot_l_quat = _move_first_quat_component_last(trim(rot_l_quat))

    # Fixed per-hand rotation applied on the right of the hand orientation
    # ("transform quat for arm" in the original code).
    right_transform_quat = torch.tensor(
        [0.0, -0.707, 0.0, 0.707], dtype=torch.float, device=device
    ).repeat((rot_r_quat.shape[0], 1))
    left_transform_quat = torch.tensor(
        [0.707, 0.0, 0.707, 0.0], dtype=torch.float, device=device
    ).repeat((rot_l_quat.shape[0], 1))
    rot_l_quat = quat_mul(rot_l_quat, left_transform_quat)
    rot_r_quat = quat_mul(rot_r_quat, right_transform_quat)

    # Insert `interpolate_time` linearly interpolated frames between neighbours.
    interpolate_time = 1
    rot_r = interpolate_tensor(rot_r, interpolate_time)
    trans_r = interpolate_tensor(trans_r, interpolate_time)
    rot_l = interpolate_tensor(rot_l, interpolate_time)
    trans_l = interpolate_tensor(trans_l, interpolate_time)
    obj_params = interpolate_tensor(obj_params, interpolate_time)
    obj_rot_quat = interpolate_tensor(obj_rot_quat, interpolate_time)
    rot_r_quat = interpolate_tensor(rot_r_quat, interpolate_time)
    rot_l_quat = interpolate_tensor(rot_l_quat, interpolate_time)

    # Suppress sudden jumps of the object orientation: a frame whose rotation
    # differs too much from the previously kept one is replaced by it.
    last_obj_rot_global = None
    for i in range(obj_rot_quat.shape[0]):
        if i > 0:
            jump = calculate_frobenius_norm_of_rotation_difference(
                obj_rot_quat[i], last_obj_rot_global, device=device
            )
            if jump > 0.5:
                obj_rot_quat[i] = last_obj_rot_global.clone()
        last_obj_rot_global = obj_rot_quat[i].clone()

    # Constant positional offsets applied to both hands.
    trans_r[:, 2] += -0.07
    trans_r[:, 0] += 0.07
    trans_l[:, 2] += -0.05
    trans_l[:, 0] += -0.07

    trans_r[:, 1] += 0.04
    trans_l[:, 1] += 0.04

    # "obj_bps" currently holds the object index repeated per frame, not a
    # BPS encoding (the BPS path is disabled in the original code).
    obj_bps = torch.tensor(
        [obj_index], dtype=torch.float, device=device
    ).repeat((rot_r_quat.shape[0], 1))

    return {
        "rot_r": rot_r, "trans_r": trans_r, "rot_l": rot_l,
        "trans_l": trans_l, "obj_params": obj_params, "obj_rot_quat": obj_rot_quat,
        "rot_r_quat": rot_r_quat, "rot_l_quat": rot_l_quat, "obj_bps": obj_bps,
    }


def load_arctic_data(object_name="box", device="cuda:0"):
    """Load and pre-process ARCTIC sequences for one object or for all objects.

    For every combination of object, functional ("use", "grab"), subject and
    take, the raw sequence under ``~/arctic/data/arctic_data/data/raw_seqs``
    and the processed sequence under ``~/arctic/outputs/processed/seqs`` are
    read. Combinations whose raw ``*.mano.npy`` file cannot be opened are
    skipped.

    Args:
        object_name (str): Name of a single object, or ``"all"`` to load the
            built-in list of training objects.
        device (str | torch.device): Device the returned tensors are placed
            on. All tensor work happens on this device, so ``"cpu"`` works
            without a GPU.

    Returns:
        list[dict]: One dictionary of tensors per loaded sequence, with keys
            ``rot_r``, ``trans_r``, ``rot_l``, ``trans_l``, ``obj_params``,
            ``obj_rot_quat``, ``rot_r_quat``, ``rot_l_quat`` and ``obj_bps``.
    """
    home = os.path.expanduser('~')
    raw_root = "{}/arctic/data/arctic_data/data/raw_seqs".format(home)
    processed_root = "{}/arctic/outputs/processed/seqs".format(home)

    functionals_list = ["use", "grab"]
    seq_list = ["01", "02", "04", "07", "08", "09"]
    take_list = ["01", "02", "03", "04"]

    if object_name == "all":
        used_training_objects = [
            "box",
            "capsulemachine",
            "espressomachine",
            "ketchup",
            # "laptop",
            "microwave",
            "mixer",
            "notebook",
            # "phone",
            "scissors",
            "waffleiron",
        ]
    else:
        used_training_objects = [object_name]

    arctic_data = []
    for obj_i, obj_name in enumerate(used_training_objects):
        for functional in functionals_list:
            for seq_i in seq_list:
                for j in take_list:
                    stem = "s{}/{}_{}_{}".format(seq_i, obj_name, functional, j)
                    mano_p = "{}/{}.mano.npy".format(raw_root, stem)
                    obj_p = "{}/{}.object.npy".format(raw_root, stem)
                    mano_processed = "{}/{}.npy".format(processed_root, stem)

                    try:
                        # NOTE: allow_pickle=True executes pickled content;
                        # only load trusted files.
                        mano_data = np.load(mano_p, allow_pickle=True).item()
                    except OSError:
                        # Raw sequence missing or unreadable: skip this combination.
                        continue

                    arctic_data.append(
                        _process_arctic_sequence(
                            mano_data, obj_p, mano_processed, obj_i, device
                        )
                    )

    print("seq_num: ", len(arctic_data))

    return arctic_data

def get_bps_representation_from_mesh(object_name):
    """Encode the template mesh of an ARCTIC object with a random BPS basis.

    The object's ``mesh.obj`` is read from its template directory, rescaled in
    place by ``scale_mesh_to_unit_sphere``, sampled into 100000 points and
    encoded against a 256-point ``random_uniform`` basis of radius 1.

    Args:
        object_name (str): One of the known ARCTIC object names; an unknown
            name raises ``KeyError``.

    Returns:
        The result of ``bps.decode`` applied to the encoded ``deltas``.
    """
    template_root = "../assets/arctic_assets/object_vtemplates/"
    known_objects = (
        "box",
        "scissors",
        "microwave",
        "laptop",
        "capsulemachine",
        "ketchup",
        "mixer",
        "notebook",
        "phone",
        "waffleiron",
        "espressomachine",
    )
    asset_dirs = {name: template_root + name + "/" for name in known_objects}

    # Load the template mesh and normalise its scale in place.
    mesh = o3d.io.read_triangle_mesh(asset_dirs[object_name] + "mesh.obj")
    scale_mesh_to_unit_sphere(mesh)
    # o3d.visualization.draw_geometries([mesh])

    point_cloud = mesh_to_pointcloud(mesh, number_of_points=100000)

    encoder = bps_torch(bps_type='random_uniform',
                        n_bps_points=256,
                        radius=1.,
                        n_dims=3,
                        custom_basis=None)

    encoding = encoder.encode(point_cloud,
                              feature_type=['dists','deltas'],
                              x_features=None,
                              custom_basis=None)

    # Only the delta vectors are decoded and returned.
    return encoder.decode(encoding['deltas'])

def scale_mesh_to_unit_sphere(mesh):
    """Scale ``mesh`` in place so its bounding box fits a sphere of radius 1.

    The sphere is centred on the centre of the mesh's axis-aligned bounding
    box; the mesh is scaled about that centre and is not translated.

    Args:
        mesh: Object exposing ``get_axis_aligned_bounding_box()`` and
            ``scale(factor, center)`` (e.g. an Open3D triangle mesh).

    Returns:
        None. The mesh is modified in place. A mesh whose bounding box has
        zero extent is left untouched.
    """
    bbox = mesh.get_axis_aligned_bounding_box()
    center = np.asarray(bbox.get_center())
    corners = np.asarray(bbox.get_box_points())

    # The largest centre-to-corner distance already is the radius of the
    # bounding sphere (half the box diagonal); it must not be halved again.
    radius = float(np.linalg.norm(corners - center, axis=1).max())
    if radius == 0.0:
        # Degenerate bounding box: nothing sensible to scale.
        return

    mesh.scale(1.0 / radius, center)

def mesh_to_pointcloud(mesh, number_of_points):
    """Uniformly sample a mesh surface and return the points as a float32 tensor.

    Args:
        mesh: Object exposing ``sample_points_uniformly(number_of_points=...)``
            whose result has a ``points`` attribute.
        number_of_points (int): Number of points to sample.

    Returns:
        torch.Tensor: The sampled point coordinates as ``torch.float32``.
    """
    sampled = mesh.sample_points_uniformly(number_of_points=number_of_points)
    coordinates = np.array(sampled.points)
    return torch.tensor(coordinates, dtype=torch.float32)

def stack_arctic_data(data, batch_size, num_stack):
    """Build overlapping windows of ``num_stack`` consecutive samples.

    For every start index ``i`` in ``range(batch_size - num_stack)`` the rows
    ``data[i:i + num_stack]`` are concatenated along axis 0 into one flat
    sample; all samples are then stacked into a single array.

    Args:
        data (np.ndarray): Per-frame samples, indexed by frame on axis 0.
        batch_size (int): Number of frames to consider.
        num_stack (int): Number of consecutive frames merged into one sample.

    Returns:
        np.ndarray: Array with ``batch_size - num_stack`` stacked windows.
    """
    window_starts = range(batch_size - num_stack)
    windows = [
        np.concatenate(data[start:start + num_stack], axis=0)
        for start in window_starts
    ]
    return np.stack(windows)

def interpolate_tensor(input_tensor, interpolate_time):
    """
    Linearly up-sample a tensor along its first (frame) axis.

    ``interpolate_time`` evenly spaced frames are inserted between every pair
    of neighbouring frames, so an input with ``N`` frames yields
    ``(N - 1) * interpolate_time + N`` frames. The first and last frame are
    preserved.

    Args:
        input_tensor (torch.Tensor): Tensor of shape (N, D); frames are
            indexed on axis 0.
        interpolate_time (int): Number of frames inserted between two
            neighbouring frames.

    Returns:
        torch.Tensor: Interpolated tensor with the dtype and device of
            ``input_tensor``. Inputs with fewer than two frames are returned
            as an unchanged copy because there is nothing to interpolate.
    """
    batch_size = input_tensor.size(0)
    if batch_size < 2:
        return input_tensor.clone()

    new_batch_size = (batch_size - 1) * interpolate_time + batch_size

    # Sample positions of the original and of the up-sampled frames on a
    # common axis; both are shared by every column.
    x = np.linspace(0, batch_size - 1, batch_size)
    new_x = np.linspace(0, batch_size - 1, new_batch_size)

    # One host copy and one interpolation call for all columns at once.
    values = input_tensor.detach().cpu().numpy()
    interpolated_data = interp1d(x, values, kind='linear', axis=0)(new_x)

    return torch.tensor(interpolated_data, dtype=input_tensor.dtype, device=input_tensor.device)


if __name__ == "__main__":
    import json
    import h5py

    # ------------------------------------------------------------------
    # Part 1: inspect a reference robomimic dataset to see its layout.
    # ------------------------------------------------------------------
    dataset_path = "/home/user/robomimic/datasets/lift/ph/low_dim_v141.hdf5"

    # The context manager closes the reference file once inspection is done.
    with h5py.File(dataset_path, "r") as f:
        # each demonstration is a group under "data"
        demos = list(f["data"].keys())
        num_demos = len(demos)

        print("hdf5 file {} has {} demonstrations".format(dataset_path, num_demos))

        # each demonstration is named "demo_#" where # is a number.
        # Let's put the demonstration list in increasing episode order
        inds = np.argsort([int(elem[5:]) for elem in demos])
        demos = [demos[i] for i in inds]

        for ep in demos:
            num_actions = f["data/{}/actions".format(ep)].shape[0]
            print("{} has {} samples".format(ep, num_actions))

            print(f["data/{}".format(ep)].attrs['num_samples'])

        # look at first demonstration
        demo_key = demos[0]
        demo_grp = f["data/{}".format(demo_key)]

        # Read the observations, next observations and reward of the first
        # 5 timesteps of this trajectory; only the reward is printed.
        for t in range(5):
            print("timestep {}".format(t))
            # each observation modality is stored as a subgroup
            obs_t = {k: demo_grp["obs/{}".format(k)][t] for k in demo_grp["obs"]}
            next_obs_t = {k: demo_grp["next_obs/{}".format(k)][t] for k in demo_grp["next_obs"]}
            reward_t = demo_grp["rewards"][t]

            print("rewards")
            print(reward_t)

    # ------------------------------------------------------------------
    # Part 2: convert the ARCTIC sequences into the same HDF5 layout.
    # ------------------------------------------------------------------
    arctic_datasets = load_arctic_data(object_name="all", device="cpu")

    # NOTE(review): the original opened the output file with an empty path,
    # which h5py cannot create. Adjust this to the desired output location.
    output_path = "arctic_dataset.hdf5"

    # Number of frames used per sequence and number of consecutive frames
    # merged into one sample.
    traj_len = 1000
    num_stack = 10
    num_samples = traj_len - num_stack

    with h5py.File(output_path, 'w') as f:
        # Create the "data" group
        data_group = f.create_group('data')

        # Add attributes to the "data" group
        data_group.attrs['total'] = len(arctic_datasets)
        env_args = {
            'env_name': 'arctic',
            'type': 1,
            'env_kwargs': {}  # Update with actual kwargs
        }
        data_group.attrs['env_args'] = json.dumps(env_args)  # Convert dictionary to JSON string

        # One demo group per loaded ARCTIC sequence.
        for i, sequence in enumerate(arctic_datasets):
            demo_group = data_group.create_group(f'demo_{i}')

            demo_group.attrs['num_samples'] = num_samples
            demo_group.attrs['model_file'] = 'xml_string_here'  # Update with actual XML string

            # Actions: stacked hand orientations (quaternions) and translations.
            rot_r = sequence["rot_r_quat"].numpy()
            trans_r = sequence["trans_r"].numpy()
            rot_l = sequence["rot_l_quat"].numpy()
            trans_l = sequence["trans_l"].numpy()

            stacked_rot_r = stack_arctic_data(rot_r, traj_len, num_stack)
            stacked_trans_r = stack_arctic_data(trans_r, traj_len, num_stack)
            stacked_rot_l = stack_arctic_data(rot_l, traj_len, num_stack)
            stacked_trans_l = stack_arctic_data(trans_l, traj_len, num_stack)

            concatenated_actions = np.concatenate((stacked_rot_r, stacked_trans_r, stacked_rot_l, stacked_trans_l), axis=1)

            print(concatenated_actions.shape)

            # Observations: object column 0, object columns 4:7, object
            # orientation quaternion and the per-frame object index.
            obj_joint = sequence["obj_params"][:, 0:1].numpy()
            obj_trans = sequence["obj_params"][:, 4:7].numpy()
            obj_quat = sequence["obj_rot_quat"].numpy()
            obj_bps = sequence["obj_bps"].cpu().numpy()

            stacked_obj_joint = stack_arctic_data(obj_joint, traj_len, num_stack)
            stacked_obj_trans = stack_arctic_data(obj_trans, traj_len, num_stack)
            stacked_obj_quat = stack_arctic_data(obj_quat, traj_len, num_stack)
            stacked_obj_bps = stack_arctic_data(obj_bps, traj_len, num_stack)

            concatenated_obs = np.concatenate((stacked_obj_joint, stacked_obj_trans, stacked_obj_quat, stacked_obj_bps), axis=1)

            print(concatenated_obs.shape)

            # States and actions hold the stacked data; rewards and dones are
            # zero placeholders with one row per sample so that every
            # per-demo dataset shares the same leading dimension.
            demo_group.create_dataset('states', data=concatenated_obs)
            demo_group.create_dataset('actions', data=concatenated_actions)
            demo_group.create_dataset('rewards', data=np.zeros((num_samples, 1)))
            demo_group.create_dataset('dones', data=np.zeros((num_samples, 1)))

            # Observations and next observations currently store the same
            # stacked arrays under identical keys.
            obs_group = demo_group.create_group('obs')
            next_obs_group = demo_group.create_group('next_obs')

            stacked_observations = {
                'obj_joint': stacked_obj_joint,
                'obj_trans': stacked_obj_trans,
                'obj_quat': stacked_obj_quat,
                'obj_bps': stacked_obj_bps,
            }
            for key, values in stacked_observations.items():
                obs_group.create_dataset(key, data=values)
                next_obs_group.create_dataset(key, data=values)

        # Create the "mask" group (optional); left empty for now.
        mask_group = f.create_group('mask')
        mask_group.create_dataset('valid', data=[])