import os
import numpy as np
import pandas as pd
import seaborn as sns

from natsort import natsorted
from scipy.spatial.transform import Rotation as R
from pytransform3d import rotations, batch_rotations, transformations, trajectories

def quat_to_axis_angle_action(action):
    """
    Convert quaternion set-point actions to axis-angle (rotation-vector) actions.

    Args:
        action: array-like of shape (..., 8) laid out as
            [x, y, z, qx, qy, qz, qw, jaw]. The quaternion is scalar-last,
            which is the order ``scipy.spatial.transform.Rotation.from_quat``
            expects by default (and the column order used by ``compute_diffs``:
            orientation.x, .y, .z, .w). Quaternions need not be normalized.

    Returns:
        np.ndarray of shape (..., 7) laid out as [x, y, z, rx, ry, rz, jaw],
        where (rx, ry, rz) is the rotation vector (unit axis * angle in radians).
        Position and jaw values are copied through unchanged.

    Raises:
        ValueError: if the last dimension of ``action`` is not 8, or (raised by
            scipy) if any quaternion has zero norm.
    """
    action = np.asarray(action, dtype=float)
    if action.ndim == 0 or action.shape[-1] != 8:
        raise ValueError(
            f"expected action with trailing dimension 8, got shape {action.shape}"
        )

    # Flatten all leading dimensions so scipy sees a plain (N, 4) batch, then
    # restore them at the end.
    lead_shape = action.shape[:-1]
    flat = action.reshape(-1, 8)

    axis_angle_actions = np.zeros((flat.shape[0], 7))
    axis_angle_actions[:, 0:3] = flat[:, 0:3]  # position, copied as-is
    if flat.shape[0]:
        # Skip scipy for empty batches: older scipy versions reject (0, 4) input.
        axis_angle_actions[:, 3:6] = R.from_quat(flat[:, 3:7]).as_rotvec()
    axis_angle_actions[:, 6] = flat[:, 7]  # jaw, copied as-is

    return axis_angle_actions.reshape(lead_shape + (7,))



def _episode_root(data_dir, episode_id, phantoms):
    """Return the directory holding all recordings of one tissue/phantom."""
    prefix = "phantom" if phantoms else "tissue"
    return os.path.join(data_dir, f"{prefix}_{episode_id}")


def _resolve_csv_path(root, phase, sample):
    """Return the path of the ``ee_csv.csv`` file for one demo.

    Demos named "Corrections" keep their CSV one directory deeper; the first
    entry of that directory in natural sort order is used (deterministic,
    unlike raw ``os.listdir`` order).
    """
    sample_dir = os.path.join(root, phase, sample)
    if sample == "Corrections":
        entries = natsorted(os.listdir(sample_dir))
        if entries:
            sample_dir = os.path.join(sample_dir, entries[0])
    return os.path.join(sample_dir, "ee_csv.csv")


def _chunked_actions(csv, cols_psm1, cols_psm2, chunk_size):
    """Return one zero-padded (chunk_size, 14) action window per CSV row."""
    n_rows = len(csv)
    if n_rows == 0:
        return []
    # Convert the whole demo once; the quaternion conversion is row-wise, so
    # this yields exactly the values the per-window conversion would.
    stacked = np.column_stack((
        quat_to_axis_angle_action(csv[cols_psm1].to_numpy()),
        quat_to_axis_angle_action(csv[cols_psm2].to_numpy()),
    ))
    # Trailing zero rows reproduce the padding of windows that run past the
    # end of the demo.
    padded = np.zeros((n_rows + chunk_size - 1, stacked.shape[1]))
    padded[:n_rows] = stacked
    return [padded[jj:jj + chunk_size] for jj in range(n_rows)]


def compute_diffs(ids, data_dir, chunk_size=100, phantoms=False):
    """Compute per-dimension action statistics over all demos.

    For every row of every demo's ``ee_csv.csv``, the next ``chunk_size``
    set-point rows for PSM1 and PSM2 are converted to
    [x, y, z, rx, ry, rz, jaw] (rotation as axis-angle). They are stacked side
    by side, giving 14 columns, and zero-padded to ``chunk_size`` rows. The
    statistics are taken over all rows of all such windows, including the
    padding rows.

    Args:
        ids: tissue/phantom identifiers; each maps to ``<data_dir>/tissue_<id>``
            (or ``phantom_<id>`` when ``phantoms`` is True).
        data_dir: dataset root directory.
        chunk_size: number of future actions per window (must be >= 1).
        phantoms: read ``phantom_*`` folders instead of ``tissue_*``.

    Returns:
        (mean, std, min, max), each an array of shape (14,): PSM1 columns 0-6,
        PSM2 columns 7-13. ``std`` is clipped to [1e-2, 10].

    Raises:
        ValueError: if ``chunk_size`` < 1 or no CSV rows were found.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    sp_psm1 = ["psm1_sp.position.x", "psm1_sp.position.y", "psm1_sp.position.z",
               "psm1_sp.orientation.x", "psm1_sp.orientation.y", "psm1_sp.orientation.z", "psm1_sp.orientation.w",
               "psm1_jaw_sp"]

    sp_psm2 = ["psm2_sp.position.x", "psm2_sp.position.y", "psm2_sp.position.z",
               "psm2_sp.orientation.x", "psm2_sp.orientation.y", "psm2_sp.orientation.z", "psm2_sp.orientation.w",
               "psm2_jaw_sp"]

    # Pass 1: index the demos of every phase sub-directory and report counts.
    samples = {}
    grand_total = 0
    for episode_id in ids:
        root = _episode_root(data_dir, episode_id, phantoms)
        print(root)
        phases = natsorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        samples[episode_id] = {
            phase: natsorted(os.listdir(os.path.join(root, phase)))
            for phase in phases
        }
        total_demo_num = sum(len(demos) for demos in samples[episode_id].values())
        grand_total += total_demo_num
        print(episode_id, ", total demo num =", total_demo_num)
    print("total demo num =", grand_total)

    # Pass 2: build the padded action windows for every demo.
    diffs = []
    for episode_id in ids:
        print("id:", episode_id)
        root = _episode_root(data_dir, episode_id, phantoms)
        for phase, demos in samples[episode_id].items():
            for demo in demos:
                csv = pd.read_csv(_resolve_csv_path(root, phase, demo))
                diffs.extend(_chunked_actions(csv, sp_psm1, sp_psm2, chunk_size))
        print(len(diffs))

    if not diffs:
        raise ValueError(f"no action samples found for ids {list(ids)} under {data_dir!r}")

    diffs_np = np.concatenate(diffs, axis=0)
    mean = diffs_np.mean(axis=0)
    # Floor/ceiling the std so downstream normalization never divides by ~0.
    std = diffs_np.std(axis=0).clip(1e-2, 10)
    mins = diffs_np.min(axis=0)
    maxs = diffs_np.max(axis=0)

    return mean, std, mins, maxs

# Entry point: compute action normalization statistics and save them to disk.
def generate_task_config():
    """Compute action statistics for tissues 1-9 and write them to ./std_mean.txt.

    The dataset root is read from the ``PATH_TO_DATASET`` environment variable.
    The mean, std, min and max returned by ``compute_diffs`` are printed and
    written as comma-separated lines, preceded by the list of tissue ids.

    Raises:
        RuntimeError: if ``PATH_TO_DATASET`` is not set.
    """
    ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    data_dir = os.getenv('PATH_TO_DATASET')
    if data_dir is None:
        raise RuntimeError("environment variable PATH_TO_DATASET must point to the dataset root")

    mean, std, mins, maxs = compute_diffs(ids, data_dir)

    # Insertion order fixes both the print order and the file's line order.
    stats = {"mean": mean, "std": std, "min": mins, "max": maxs}
    formatted = {name: ', '.join(map(str, values)) for name, values in stats.items()}

    for name, text in formatted.items():
        print(f"{name}:", text)

    # write the results into a txt file
    with open("./std_mean.txt", "w", encoding="utf-8") as f:
        f.write(f"tissue ids: {ids}\n")
        for name, text in formatted.items():
            f.write(f"{name}: {text}\n")


# Generate the statistics only when run as a script, not when imported.
if __name__ == "__main__":
    generate_task_config()
