from typing import TypeVar, Union, Optional, Tuple
from omegaconf import DictConfig
import torch
import numpy as np
import pandas as pd

import src.utils.utils as utils


# Columns that `_assert_correct_inputs` requires in an inverse-kinematics (IK)
# DataFrame. "tx"/"ty" are used as the root translation of the body model and
# "a_<joint>" as joint angles; the "d"/"dd" prefixed columns are consumed by the
# reconstruction code as velocities and accelerations respectively.
required_IK_data_columns = [
    "tx",
    "ty",
    "dtx",
    "dty",
    "ddtx",
    "ddty",
    "a_pelvis",
    "da_pelvis",
    "dda_pelvis",
    "a_hip_r",
    "da_hip_r",
    "dda_hip_r",
    "a_knee_r",
    "da_knee_r",
    "dda_knee_r",
    "a_ankle_r",
    "da_ankle_r",
    "dda_ankle_r",
    "a_hip_l",
    "da_hip_l",
    "dda_hip_l",
    "a_knee_l",
    "da_knee_l",
    "dda_knee_l",
    "a_ankle_l",
    "da_ankle_l",
    "dda_ankle_l",
]

# Segment names of the "seifer" body model.
# NOTE(review): not referenced anywhere in the visible part of this module.
required_segment_lengths_columns_seifer = [
    "pelvis",
    "thigh_r",
    "shank_r",
    "foot_r",
    "thigh_l",
    "shank_l",
    "foot_l",
]

# Columns that `_assert_correct_inputs` requires in the joints DataFrame.
required_joints_columns_koelewijn = ["subject", "joint", "location"]

# Joint names that must all occur in the "joint" column of the joints DataFrame.
required_joints_koelewijn_ordered = [
    "back",
    "ground_pelvis",
    "hip_r",
    "knee_r",
    "ankle_r",
    "hip_l",
    "knee_l",
    "ankle_l",
]

# Columns that `_assert_correct_inputs` requires in the contact-points DataFrame.
required_contact_points_columns = [
    "subject",
    "segment",
    "position",
]

# Maps each body segment to the IK joints whose angles (and angle derivatives)
# are summed to obtain the segment's cumulative orientation. The insertion order
# matters: the reconstruction functions enumerate this dict and use the position
# of each entry as the segment index (0 = back_offset ... 11 = foot_l).
segment_dependencies = {
    "back_offset": ["pelvis"],
    "torso": ["pelvis"],
    "hip_offset_r": ["pelvis"],
    "thigh_r": ["pelvis", "hip_r"],
    "shank_r": ["pelvis", "hip_r", "knee_r"],
    "ankle_r": ["pelvis", "hip_r", "knee_r", "ankle_r"],
    "foot_r": ["pelvis", "hip_r", "knee_r", "ankle_r"],
    "hip_offset_l": ["pelvis"],
    "thigh_l": ["pelvis", "hip_l"],
    "shank_l": ["pelvis", "hip_l", "knee_l"],
    "ankle_l": ["pelvis", "hip_l", "knee_l", "ankle_l"],
    "foot_l": ["pelvis", "hip_l", "knee_l", "ankle_l"],
}

# Constrained type variables: each one resolves to exactly one backend type.
# NP_T is either a NumPy array or a PyTorch tensor.
NP_T = TypeVar("NP_T", np.ndarray, torch.Tensor)
# DF_T is either a pandas DataFrame or a PyTorch tensor.
DF_T = TypeVar("DF_T", pd.DataFrame, torch.Tensor)


def _rotate_2d_data(data: NP_T, angle: NP_T) -> NP_T:
    """Rotate 2D vectors by per-sample angles using the standard rotation matrix.

    Args:
        data: np.ndarray or torch.Tensor whose axis 1 holds the x (index 0) and
            y (index 1) components, e.g. shape (n, 2) or (n, 2, k).
        angle: Rotation angles in radians, same container type as ``data`` and
            exactly one dimension fewer, e.g. shape (n,) or (n, k).

    Returns:
        A new container of the same type, shape and dtype as ``data`` holding
        the rotated vectors.

    Raises:
        TypeError: If ``data`` and ``angle`` are not both NumPy arrays or both
            PyTorch tensors.
    """
    # Select the numeric backend once; NumPy and PyTorch expose the same
    # cos / sin / empty_like names used below.
    if isinstance(data, np.ndarray) and isinstance(angle, np.ndarray):
        assert data.ndim == angle.ndim + 1
        backend = np
    elif isinstance(data, torch.Tensor) and isinstance(angle, torch.Tensor):
        assert data.dim() == angle.dim() + 1
        backend = torch
    else:
        raise TypeError("data and angle should be either np.ndarray or torch.Tensor")

    cos_angle = backend.cos(angle)
    sin_angle = backend.sin(angle)
    x_component = data[:, 0]
    y_component = data[:, 1]

    rotated = backend.empty_like(data)
    rotated[:, 0] = cos_angle * x_component - sin_angle * y_component  # type: ignore
    rotated[:, 1] = sin_angle * x_component + cos_angle * y_component  # type: ignore
    return rotated


def _rotate_2d_tensors(
    data: torch.Tensor, angle: torch.Tensor, device: torch.device
) -> torch.Tensor:
    """Rotate batched 2D vectors by per-element angles.

    Applies the standard rotation matrix ``[[cos, -sin], [sin, cos]]`` to the
    x/y components stored along dimension 2 of ``data``.

    Args:
        data (torch.Tensor): Vectors of shape (batch_size, seq_len, 2, n). A 3-D
            tensor is treated as unbatched and receives a leading batch
            dimension of size 1.
        angle (torch.Tensor): Rotation angles in radians of shape
            (batch_size, seq_len, n). A 2-D tensor is treated as unbatched and
            receives a leading batch dimension of size 1.
        device (torch.device): Device on which the rotation is computed and on
            which the result is returned.

    Returns:
        torch.Tensor: Rotated vectors with the (batched) shape and dtype of
        ``data``, located on ``device``.
    """
    if data.dim() == 3:
        data = data.unsqueeze(0)
    if angle.dim() == 2:
        angle = angle.unsqueeze(0)

    # Move both operands up front: previously only the trig terms and the output
    # buffer lived on `device`, so inputs on another device failed with a device
    # mismatch in the multiplications below.
    data = data.to(device=device)
    angle = angle.to(device=device)

    cos_angle = torch.cos(angle)
    sin_angle = torch.sin(angle)
    x = data[:, :, 0, :]
    y = data[:, :, 1, :]

    rotated_data = torch.empty_like(data)
    rotated_data[:, :, 0, :] = cos_angle * x - sin_angle * y
    rotated_data[:, :, 1, :] = sin_angle * x + cos_angle * y

    return rotated_data


def _assert_correct_inputs(
    IK_data: Optional[pd.DataFrame] = None,
    joints: Optional[pd.DataFrame] = None,
    contact_points: Optional[pd.DataFrame] = None,
    segment_positions: Optional[np.ndarray] = None,
    imu_offsets: Optional[np.ndarray] = None,
):
    """Validate the layout of whichever reconstruction inputs are supplied.

    Every argument is optional; arguments left as ``None`` are not checked.

    Args:
        IK_data (Optional[pd.DataFrame]): Must be a DataFrame containing every
            column listed in ``required_IK_data_columns``.
        joints (Optional[pd.DataFrame]): Must be a DataFrame with the columns in
            ``required_joints_columns_koelewijn``, contain every joint named in
            ``required_joints_koelewijn_ordered`` and stack to 8 locations with
            2 or 3 coordinates each.
        contact_points (Optional[pd.DataFrame]): Must be a DataFrame with the
            columns in ``required_contact_points_columns`` and stack to 4
            positions with 2 or 3 coordinates each.
        segment_positions (Optional[np.ndarray]): Must have shape (2, 12).
        imu_offsets (Optional[np.ndarray]): Must have shape (2, 7).

    Raises:
        AssertionError: If any supplied argument violates its requirements.
    """
    if IK_data is not None:
        assert isinstance(IK_data, pd.DataFrame)
        assert all(col in IK_data.columns for col in required_IK_data_columns)

    if joints is not None:
        assert isinstance(joints, pd.DataFrame)
        assert all(col in joints.columns for col in required_joints_columns_koelewijn)
        present_joints = joints["joint"].values
        assert all(
            name in present_joints for name in required_joints_koelewijn_ordered
        ), f"joints['joint'].values: {joints['joint'].values}"
        # One column per joint after stacking: rows are the coordinates.
        joint_pos = np.stack(joints["location"].values, axis=1).astype(np.float64)  # type: ignore
        assert joint_pos.shape in (
            (2, 8),
            (3, 8),
        ), f"joint_locations.shape: {joint_pos.shape}"

    if contact_points is not None:
        assert isinstance(contact_points, pd.DataFrame)
        assert all(
            col in contact_points.columns for col in required_contact_points_columns
        )
        # One column per contact point after stacking: rows are the coordinates.
        cp_pos = np.stack(contact_points["position"].values, axis=1).astype(np.float64)  # type: ignore
        assert cp_pos.shape in (
            (2, 4),
            (3, 4),
        ), f"contact_points.shape: {cp_pos.shape}"

    if segment_positions is not None:
        expected_segment_shape = (2, 12)
        assert (
            segment_positions.shape == expected_segment_shape
        ), f"segment_postions.shape: {segment_positions.shape}"

    if imu_offsets is not None:
        expected_offset_shape = (2, 7)
        assert (
            imu_offsets.shape == expected_offset_shape
        ), f"imu_offsets.shape: {imu_offsets.shape}"


def compose_segment_positions(
    joints: pd.DataFrame, contact_points: pd.DataFrame
) -> np.ndarray:
    """Assemble the (2, 12) segment-position matrix from joints and contact points.

    Only the first two coordinates of every location/position are used. The
    columns of the result are filled as follows:
    0 back, 1 ground_pelvis, 2 hip_r, 3 knee_r, 4 ankle_r,
    5-6 contact points of foot_r, 7 hip_l, 8 knee_l, 9 ankle_l,
    10-11 contact points of foot_l.

    Args:
        joints: pd.DataFrame with columns ["subject", "joint", "location"].
        contact_points: pd.DataFrame with columns ["subject", "segment", "position"].

    Returns:
        np.ndarray of shape (2, 12) with x coordinates in row 0 and y
        coordinates in row 1.
    """

    _assert_correct_inputs(joints=joints, contact_points=contact_points)

    def _joint_xy(joint_name):
        # First two coordinates of the first row registered for this joint.
        locations = joints[joints["joint"] == joint_name]["location"].values
        return np.stack(list(locations)).astype(np.float64)[0, :2]

    def _contact_xy(segment_name):
        # All contact points of the segment, transposed to (2, n_points).
        positions = contact_points[contact_points["segment"] == segment_name][
            "position"
        ].values
        return np.stack(list(positions)).astype(np.float64)[:, :2].T

    segment_positions = np.zeros((2, 12))

    right_chain = ("back", "ground_pelvis", "hip_r", "knee_r", "ankle_r")
    for column, joint_name in enumerate(right_chain):
        segment_positions[:, column] = _joint_xy(joint_name)
    segment_positions[:, 5:7] = _contact_xy("foot_r")

    left_chain = ("hip_l", "knee_l", "ankle_l")
    for column, joint_name in enumerate(left_chain, start=7):
        segment_positions[:, column] = _joint_xy(joint_name)
    segment_positions[:, 10:12] = _contact_xy("foot_l")

    _assert_correct_inputs(segment_positions=segment_positions)

    return segment_positions


def _reconstruct_imus_koelewijn(
    IK_data: pd.DataFrame,
    segment_points: np.ndarray,
    imu_offsets: np.ndarray,
    fps: float = 100,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Simulate planar IMU readings from IK data (NumPy / pandas implementation).

    For each of the seven IMUs the local-frame acceleration is the root
    acceleration plus gravity, rotated into the segment frame, plus the
    acceleration caused by the segment's rotation at the IMU mounting offset.
    The angular velocity of the carrying segment is reported as gyroscope value.

    Args:
        IK_data (pd.DataFrame): IK data containing ``required_IK_data_columns``.
        segment_points (np.ndarray): Segment positions of shape (2, 12).
        imu_offsets (np.ndarray): IMU offsets relative to the start of the
            segment each IMU is attached to; row 0 holds x, row 1 holds y, one
            column per IMU.
        fps (float, optional): Framerate of the data. Must be positive; it is
            only forwarded to the segment reconstruction. Defaults to 100.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: The simulated IMU readings with
        columns ``imu_<name>_ddx|ddy|da`` and the global IMU positions with
        columns ``imu_<name>_global_x|y``.
    """
    _assert_correct_inputs(IK_data=IK_data, segment_positions=segment_points)
    assert fps > 0

    # IMU names and the index of the segment each one is mounted on.
    imu_names = [
        "pelvis",
        "thigh_r",
        "shank_r",
        "foot_r",
        "thigh_l",
        "shank_l",
        "foot_l",
    ]
    seg_imu_idxs = [1, 3, 4, 6, 8, 9, 11]
    num_imus = len(seg_imu_idxs)

    t = IK_data.shape[0]

    # Cumulative segment angle, angular velocity and angular acceleration:
    # the sum over all joints each segment depends on.
    num_a = len(segment_dependencies)
    cum_a = np.zeros((t, num_a))
    cum_da = np.zeros((t, num_a))
    cum_dda = np.zeros((t, num_a))
    for i, dependencies in enumerate(segment_dependencies.values()):
        cum_a[:, i] = IK_data[[f"a_{dep}" for dep in dependencies]].sum(axis=1)
        cum_da[:, i] = IK_data[[f"da_{dep}" for dep in dependencies]].sum(axis=1)
        cum_dda[:, i] = IK_data[[f"dda_{dep}" for dep in dependencies]].sum(axis=1)

    imu_a = cum_a[:, seg_imu_idxs]
    imu_da = cum_da[:, seg_imu_idxs]
    imu_dda = cum_dda[:, seg_imu_idxs]

    # Root translational acceleration, shape (t, 2).
    ddt_0 = np.stack([IK_data["ddtx"].values, IK_data["ddty"].values], axis=1)

    p_x, p_y = imu_offsets[0], imu_offsets[1]

    # Translational acceleration + gravity, replicated for every IMU.
    ddt_0_g_global = np.repeat(ddt_0[:, :, np.newaxis], num_imus, axis=2)
    ddt_0_g_global[:, 1] += 9.81  # Adding gravity
    ddt_0_g_rot = _rotate_2d_data(
        ddt_0_g_global, np.negative(imu_a)
    )  # Rotate by -a to get into local frame

    # Acceleration induced by the segment rotation at the IMU offset:
    # centripetal term (-omega^2 * p) plus tangential term (alpha x p).
    ddt_cor = np.zeros_like(ddt_0_g_rot)
    ddt_cor[:, 0] += -p_x * np.square(imu_da) - imu_dda * p_y
    ddt_cor[:, 1] += -p_y * np.square(imu_da) + imu_dda * p_x

    reconstructed_imus = np.zeros((t, 3, num_imus))
    reconstructed_imus[:, 0:2] = ddt_0_g_rot + ddt_cor
    reconstructed_imus[:, 2] = imu_da  # Angular velocities

    segment_starts, _ = _reconstruct_segments_koelewijn(
        IK_data, segment_points, fps=fps
    )
    # Calculate global imu translations for visualization
    global_imu_translations = np.copy(segment_starts[:, :, seg_imu_idxs])
    global_imu_translations[:, 0] += p_x * np.cos(imu_a) - p_y * np.sin(imu_a)
    global_imu_translations[:, 1] += p_x * np.sin(imu_a) + p_y * np.cos(imu_a)

    # Fortran-order flattening makes the channel index vary fastest, i.e. all
    # channels of IMU 0 come first, then IMU 1, ... which matches the column
    # names generated below.
    imu_data_df = pd.DataFrame(
        reconstructed_imus.reshape((t, -1), order="F"),
        columns=[
            f"imu_{name}_{channel}"
            for name in imu_names
            for channel in ("ddx", "ddy", "da")
        ],
    )

    global_imu_t_df = pd.DataFrame(
        global_imu_translations.reshape((t, -1), order="F"),
        columns=[
            f"imu_{name}_global_{axis}" for name in imu_names for axis in ("x", "y")
        ],
    )

    assert imu_data_df.size == reconstructed_imus.size
    assert global_imu_t_df.size == global_imu_translations.size

    return imu_data_df, global_imu_t_df


def _reconstruct_imus_koelewijn_tensors(
    IK_data: torch.Tensor,
    segment_points: torch.Tensor,
    imu_offsets: torch.Tensor,
    device: torch.device,
    fps: float = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Simulate planar IMU readings from IK data (PyTorch implementation).

    Args:
        IK_data (torch.Tensor): IK data of shape (batch_size, t, n_columns); a
            2-D tensor is treated as a single batch element. Columns are looked
            up by name in ``utils.GLOBAL_CFG.datamodule.dataset_variables.IK_data``.
        segment_points (torch.Tensor): Segment positions whose trailing axes are
            (2, 12); forwarded to ``_reconstruct_segments_koelewijn_tensors``.
        imu_offsets (torch.Tensor): IMU offsets whose trailing axes are
            (2, num_imus). The leading axes must be broadcastable against
            (batch_size, t).
        device (torch.device): Device on which intermediate and result tensors
            are allocated.
        fps (float, optional): Forwarded to the segment reconstruction.
            Defaults to 100.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The simulated IMU readings of shape
        (batch_size, t, 3 * num_imus) and the global IMU positions of shape
        (batch_size, t, 2 * num_imus), both flattened in Fortran order so that
        all channels of one IMU are adjacent.
    """
    # treat unbatched data as data with batch size 1
    if IK_data.dim() == 2:
        IK_data = IK_data.unsqueeze(0)
    if segment_points.dim() == 2:
        segment_points = segment_points.unsqueeze(0)
    if imu_offsets.dim() == 2:
        imu_offsets = imu_offsets.unsqueeze(0)

    IK_columns = utils.GLOBAL_CFG.datamodule.dataset_variables.IK_data

    batch_size = IK_data.size(0)
    t = IK_data.size(1)

    # ddt_0.shape (batch_size, t, 2)
    ddt_0 = torch.stack(
        [
            IK_data[:, :, IK_columns.index("ddtx")],
            IK_data[:, :, IK_columns.index("ddty")],
        ],
        dim=2,
    )

    num_a = len(segment_dependencies)
    cum_a = torch.zeros((batch_size, t, num_a), device=device)
    cum_da = torch.zeros((batch_size, t, num_a), device=device)
    cum_dda = torch.zeros((batch_size, t, num_a), device=device)
    seg_imu_idxs = [1, 3, 4, 6, 8, 9, 11]
    num_imus = len(seg_imu_idxs)

    # Cumulative angle / angular velocity / angular acceleration per segment.
    for i, dependencies in enumerate(segment_dependencies.values()):
        cum_a[:, :, i] = IK_data[
            :, :, [IK_columns.index(f"a_{dep}") for dep in dependencies]
        ].sum(dim=-1)
        cum_da[:, :, i] = IK_data[
            :, :, [IK_columns.index(f"da_{dep}") for dep in dependencies]
        ].sum(dim=-1)
        cum_dda[:, :, i] = IK_data[
            :, :, [IK_columns.index(f"dda_{dep}") for dep in dependencies]
        ].sum(dim=-1)

    imu_a = cum_a[:, :, seg_imu_idxs]
    imu_da = cum_da[:, :, seg_imu_idxs]
    imu_dda = cum_dda[:, :, seg_imu_idxs]

    # Translational Acceleration + gravity
    # ddt_0_g_global.shape (batch_size, t, 2, num_imus)
    ddt_0_g_global = ddt_0.unsqueeze(-1).repeat(1, 1, 1, num_imus).to(device=device)
    ddt_0_g_global[:, :, 1] += 9.81  # Adding gravity
    # ddt_0_g_rot.shape (batch_size, t, 2, num_imus)
    ddt_0_g_rot = _rotate_2d_tensors(
        ddt_0_g_global, imu_a.negative(), device=device
    )  # Rotate by -a to get into local frame

    # Index the offsets from their trailing (2, num_imus) axes. For 4-D offsets
    # this equals p[:, :, 0] / p[:, :, 1]; for the unsqueezed unbatched case
    # (1, 2, num_imus) it yields (1, num_imus), which broadcasts correctly
    # instead of selecting along the wrong axis.
    p = imu_offsets.clone().to(device=device)
    p_x = p[..., 0, :]
    p_y = p[..., 1, :]

    # Acceleration induced by the segment rotation at the IMU offset:
    # centripetal term (-omega^2 * p) plus tangential term (alpha x p).
    # ddt_cor.shape (batch_size, t, 2, num_imus)
    ddt_cor = torch.zeros_like(ddt_0_g_rot, device=device)
    ddt_cor[:, :, 0] += -p_x * torch.square(imu_da) - imu_dda * p_y
    ddt_cor[:, :, 1] += -p_y * torch.square(imu_da) + imu_dda * p_x

    reconstructed_imus = torch.zeros((batch_size, t, 3, num_imus), device=device)
    reconstructed_imus[:, :, 0:2] = ddt_0_g_rot + ddt_cor
    reconstructed_imus[:, :, 2] = imu_da  # Angular velocities

    segment_starts, _ = _reconstruct_segments_koelewijn_tensors(
        IK_data, segment_points, fps=fps, device=device
    )
    # Calculate global imu translations for visualization
    global_imu_translations = (
        segment_starts[:, :, :, seg_imu_idxs].clone().to(device=device)
    )
    global_imu_translations[:, :, 0] += p_x * torch.cos(imu_a) - p_y * torch.sin(
        imu_a
    )
    global_imu_translations[:, :, 1] += p_x * torch.sin(imu_a) + p_y * torch.cos(
        imu_a
    )

    # Use the module-level helper rather than a nested duplicate of it.
    reconstructed_imus = _reshape_fortran(reconstructed_imus, (batch_size, t, -1))
    global_imu_translations = _reshape_fortran(
        global_imu_translations, (batch_size, t, -1)
    )

    return reconstructed_imus, global_imu_translations


def _reconstruct_segments_koelewijn_tensors(
    IK_data: torch.Tensor,
    segment_points: torch.Tensor,
    device: torch.device,
    fps: float = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute global start and end points of all 12 body segments (PyTorch).

    Segment indices follow the order of ``segment_dependencies``:
    0 back_offset, 1 torso, 2 hip_offset_r, 3 right thigh, 4 right shank,
    5 right ankle, 6 right foot, 7 hip_offset_l, 8 left thigh, 9 left shank,
    10 left ankle, 11 left foot.

    Args:
        IK_data (torch.Tensor): IK data of shape (batch_size, t, n_columns); a
            2-D tensor is treated as a single batch element. Columns are looked
            up by name in ``utils.GLOBAL_CFG.datamodule.dataset_variables.IK_data``.
        segment_points (torch.Tensor): Local segment vectors whose trailing axes
            are (2, 12). The leading axes must be broadcastable against
            (batch_size, t).
        device (torch.device): Device on which the result tensors are allocated.
        fps (float, optional): Unused by this function. Defaults to 100.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``segment_starts`` and
        ``segment_ends``, each of shape (batch_size, t, 2, 12).
    """
    if IK_data.dim() == 2:
        IK_data = IK_data.unsqueeze(0)
    if segment_points.dim() == 2:
        segment_points = segment_points.unsqueeze(0)
    # The cumulative angles live on `device`; keep the operands together.
    segment_points = segment_points.to(device=device)

    IK_columns = utils.GLOBAL_CFG.datamodule.dataset_variables.IK_data

    batch_size = IK_data.size(0)
    t = IK_data.size(1)

    # Cumulative orientation of every segment: sum of the joint angles it
    # depends on.
    cum_a = torch.zeros((batch_size, t, 12), device=device)
    for i, dependencies in enumerate(segment_dependencies.values()):
        cum_a[:, :, i] = IK_data[
            :, :, [IK_columns.index(f"a_{dep}") for dep in dependencies]
        ].sum(dim=2)

    # Index x/y from the trailing (2, 12) axes. For 4-D input this equals
    # segment_points[:, :, 0] / [:, :, 1]; for the unsqueezed unbatched case
    # (1, 2, 12) it yields (1, 12), which broadcasts against cum_a instead of
    # selecting along the wrong axis.
    points_x = segment_points[..., 0, :]
    points_y = segment_points[..., 1, :]

    # Local segment vectors rotated into the global frame.
    segment_offsets_x = points_x * torch.cos(cum_a) - points_y * torch.sin(cum_a)
    segment_offsets_y = points_x * torch.sin(cum_a) + points_y * torch.cos(cum_a)
    segment_offsets = torch.stack([segment_offsets_x, segment_offsets_y], dim=2)

    # Kinematic tree: (parent segment, attach at the parent's end?). The root
    # (None) starts at the pelvis translation (tx, ty). Both hip offsets attach
    # to the START of the back offset; the foot starts at the shank end, not at
    # the ankle end.
    attachments = (
        None,  # 0 back offset
        (0, True),  # 1 torso
        (0, False),  # 2 hip offset right
        (2, True),  # 3 right thigh
        (3, True),  # 4 right shank
        (4, True),  # 5 right ankle
        (4, True),  # 6 right foot
        (0, False),  # 7 hip offset left
        (7, True),  # 8 left thigh
        (8, True),  # 9 left shank
        (9, True),  # 10 left ankle
        (9, True),  # 11 left foot
    )

    segment_starts = torch.zeros((batch_size, t, 2, 12), device=device)
    segment_ends = torch.zeros((batch_size, t, 2, 12), device=device)

    segment_starts[:, :, :, 0] = IK_data[
        :, :, [IK_columns.index(dep) for dep in ["tx", "ty"]]
    ]
    # Parents always have a smaller index than their children, so a single
    # forward pass resolves the whole tree.
    for seg_id, attachment in enumerate(attachments):
        if attachment is not None:
            parent, at_parent_end = attachment
            anchor = segment_ends if at_parent_end else segment_starts
            segment_starts[:, :, :, seg_id] = anchor[:, :, :, parent]
        segment_ends[:, :, :, seg_id] = (
            segment_starts[:, :, :, seg_id] + segment_offsets[:, :, :, seg_id]
        )

    return segment_starts, segment_ends


def _reconstruct_segments_koelewijn(
    IK_data: pd.DataFrame,
    segment_points: np.ndarray,
    fps: float = 100,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute global start and end points of all 12 body segments (NumPy).

    Segment indices follow the order of ``segment_dependencies``:
    0 back_offset, 1 torso, 2 hip_offset_r, 3 right thigh, 4 right shank,
    5 right ankle, 6 right foot, 7 hip_offset_l, 8 left thigh, 9 left shank,
    10 left ankle, 11 left foot.

    Args:
        IK_data (pd.DataFrame): IK data containing ``required_IK_data_columns``.
        segment_points (np.ndarray): Local segment vectors of shape (2, 12),
            x in row 0 and y in row 1.
        fps (float, optional): Must be positive; otherwise unused here.
            Defaults to 100.

    Returns:
        Tuple[np.ndarray, np.ndarray]: ``segment_starts`` and ``segment_ends``,
        each of shape (t, 2, 12) with t the number of rows in ``IK_data``.
    """
    _assert_correct_inputs(IK_data=IK_data, segment_positions=segment_points)
    assert fps > 0

    n_frames = len(IK_data)

    # Cumulative orientation of every segment: sum of the joint angles it
    # depends on.
    cum_a = np.zeros((n_frames, 12))
    for column, joint_names in enumerate(segment_dependencies.values()):
        angle_columns = [f"a_{joint}" for joint in joint_names]
        cum_a[:, column] = IK_data[angle_columns].sum(axis=1)

    # Local segment vectors rotated into the global frame.
    local_x, local_y = segment_points[0], segment_points[1]
    cos_a, sin_a = np.cos(cum_a), np.sin(cum_a)
    segment_offsets = np.stack(
        [local_x * cos_a - local_y * sin_a, local_x * sin_a + local_y * cos_a],
        axis=1,
    )

    # Kinematic tree: (parent segment, attach at the parent's end?). The root
    # (None) starts at the pelvis translation (tx, ty). Both hip offsets attach
    # to the START of the back offset; the foot starts at the shank end.
    attachments = (
        None,  # 0 back offset
        (0, True),  # 1 torso
        (0, False),  # 2 hip offset right
        (2, True),  # 3 right thigh
        (3, True),  # 4 right shank
        (4, True),  # 5 right ankle
        (4, True),  # 6 right foot
        (0, False),  # 7 hip offset left
        (7, True),  # 8 left thigh
        (8, True),  # 9 left shank
        (9, True),  # 10 left ankle
        (9, True),  # 11 left foot
    )

    segment_starts = np.zeros((n_frames, 2, 12))
    segment_ends = np.zeros((n_frames, 2, 12))

    segment_starts[:, :, 0] = IK_data[["tx", "ty"]].values
    # Parents always precede their children, so one forward pass is enough.
    for seg_id, attachment in enumerate(attachments):
        if attachment is not None:
            parent, at_parent_end = attachment
            anchor = segment_ends if at_parent_end else segment_starts
            segment_starts[:, :, seg_id] = anchor[:, :, parent]
        segment_ends[:, :, seg_id] = (
            segment_starts[:, :, seg_id] + segment_offsets[:, :, seg_id]
        )

    return segment_starts, segment_ends


def _as_static_xy_layout(values: np.ndarray, n_points: int) -> np.ndarray:
    """Bring a single set of NumPy points into the ``(2, n_points)`` layout."""
    if values.shape[-2:] != (2, n_points):
        # Interleaved (x0, y0, x1, y1, ...) storage -> (n_sets, 2, n_points).
        values = values.reshape(-1, n_points, 2).transpose(0, 2, 1)
    if values.ndim > 2 and values.size == 2 * n_points:
        # Exactly one set with singleton leading axes: drop them.
        values = values.reshape(2, n_points)
    return values


def reconstruct_imus(
    IK_data: Union[torch.Tensor, pd.DataFrame],
    segment_pos: Union[torch.Tensor, np.ndarray],
    imu_offsets: Union[torch.Tensor, np.ndarray],
    device: torch.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    ),
    fps: float = 100,
) -> Tuple[Union[torch.Tensor, pd.DataFrame], Union[torch.Tensor, pd.DataFrame]]:
    """Reconstruction of IMU sensor data from IK data.

    Dispatches to the NumPy/pandas implementation when given a DataFrame with
    NumPy arrays, and to the PyTorch implementation when given tensors only.

    Args:
        IK_data (Union[torch.Tensor, pd.DataFrame]): Inverse kinematics data.
        segment_pos (Union[torch.Tensor, np.ndarray]): Segment positions. NumPy
            input must describe a single (2, 12) set, either directly or in
            interleaved (x0, y0, x1, y1, ...) form. Tensor input must have
            trailing axes (2, 12) or be of shape (batch_size, t, 24).
        imu_offsets (Union[torch.Tensor, np.ndarray]): IMU offsets relative to
            the segment start they are attached to, with the same layout rules
            as ``segment_pos`` but 7 points.
        device (torch.device, optional): Device used by the PyTorch
            implementation. Defaults to CUDA if available at import time,
            otherwise CPU.
        fps (float, optional): Framerate of the provided data. Defaults to 100.

    Raises:
        TypeError: If the argument types do not match one of the two supported
            combinations.
        AssertionError: If the inputs do not have the expected layout.

    Returns:
        Tuple[Union[torch.Tensor, pd.DataFrame], Union[torch.Tensor, pd.DataFrame]]:
        The reconstructed IMU sensor readings and the global IMU translations
        (used for visualizing movement).
    """
    if (
        isinstance(IK_data, pd.DataFrame)
        and isinstance(segment_pos, np.ndarray)
        and isinstance(imu_offsets, np.ndarray)
    ):
        # The NumPy implementation works on one static (2, n) layout. The former
        # reshape always produced a 4-D array here, which the assertion below
        # rejects, so that branch could never succeed.
        segment_pos = _as_static_xy_layout(segment_pos, 12)
        imu_offsets = _as_static_xy_layout(imu_offsets, 7)
        _assert_correct_inputs(
            IK_data=IK_data,
            segment_positions=segment_pos,
            imu_offsets=imu_offsets,
        )
        return _reconstruct_imus_koelewijn(IK_data, segment_pos, imu_offsets, fps=fps)
    elif (
        isinstance(IK_data, torch.Tensor)
        and isinstance(segment_pos, torch.Tensor)
        and isinstance(imu_offsets, torch.Tensor)
    ):
        # Only the tensor path needs the globally configured column order.
        IK_columns = utils.GLOBAL_CFG.datamodule.dataset_variables.IK_data

        if segment_pos.shape[-2:] != (2, 12):
            segment_pos = segment_pos.reshape(
                segment_pos.shape[0], segment_pos.shape[1], 12, 2
            ).permute(0, 1, 3, 2)
        if imu_offsets.shape[-2:] != (2, 7):
            imu_offsets = imu_offsets.reshape(
                imu_offsets.shape[0], imu_offsets.shape[1], 7, 2
            ).permute(0, 1, 3, 2)
        assert (
            IK_data.shape[-1] == IK_columns.index("dda_lumbar") + 1
        ), f"IK_data.shape: {IK_data.shape} | Expected: (..., {IK_columns.index('dda_lumbar') + 1})"
        assert segment_pos.shape[-2:] == (
            2,
            12,
        ), f"segment_positions.shape: {segment_pos.shape} | Expected: (..., 2, 12)"
        assert imu_offsets.shape[-2:] == (
            2,
            7,
        ), f"imu_offsets.shape: {imu_offsets.shape} | Expected: (..., 2, 7)"
        return _reconstruct_imus_koelewijn_tensors(
            IK_data, segment_pos, imu_offsets, fps=fps, device=device
        )
    else:
        raise TypeError(
            "IK_data, segment_positions and imu_offsets should be either np.ndarray or torch.Tensor"
        )


def reconstruct_segments(
    IK_data: pd.DataFrame,
    segment_pos: np.ndarray,
    fps: float = 100,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the global start and end points of all body segments.

    Args:
        IK_data (pd.DataFrame): IK data containing ``required_IK_data_columns``.
        segment_pos (np.ndarray): A single set of segment positions, either in
            (2, 12) layout or in interleaved (x0, y0, x1, y1, ...) form with 24
            values (for example shape (24,) or (12, 2)).
        fps (float, optional): Framerate of the provided data. Defaults to 100.

    Raises:
        AssertionError: If the inputs do not have the expected layout.

    Returns:
        Tuple[np.ndarray, np.ndarray]: ``segment_starts`` and ``segment_ends``,
        each of shape (t, 2, 12).
    """
    if segment_pos.shape[-2:] != (2, 12):
        # Interleaved storage -> (n_sets, 2, 12).
        segment_pos = segment_pos.reshape(-1, 12, 2).transpose(0, 2, 1)
    if segment_pos.ndim == 3 and segment_pos.shape[0] == 1:
        # The check below only accepts a plain (2, 12) array; without dropping
        # the singleton axis the reshaped input could never pass it.
        segment_pos = segment_pos[0]
    _assert_correct_inputs(IK_data=IK_data, segment_positions=segment_pos)

    return _reconstruct_segments_koelewijn(IK_data, segment_pos, fps=fps)


def _reconstruct_imus(
    IK_data: torch.Tensor,
    IK_data_columns: list[str],
    cum_a: torch.Tensor,
    cum_da: torch.Tensor,
    cum_dda: torch.Tensor,
    p: torch.Tensor,
    segment_imu_idxs: list[int],
    device: torch.device,
):
    """Simulate local-frame IMU readings from precomputed cumulative angles.

    Args:
        IK_data (torch.Tensor): IK data of shape
            (batch_size, t, len(IK_data_columns)).
        IK_data_columns (list[str]): Column names of the last axis of
            ``IK_data``; must contain "ddtx" and "ddty".
        cum_a (torch.Tensor): Cumulative segment angles of shape
            (batch_size, t, len(segment_dependencies)).
        cum_da (torch.Tensor): Cumulative angular velocities, same shape.
        cum_dda (torch.Tensor): Cumulative angular accelerations, same shape.
        p (torch.Tensor): IMU offsets of shape (batch_size, t, 2, num_imus).
        segment_imu_idxs (list[int]): Index of the segment each IMU is
            attached to.
        device (torch.device): Device on which the result is allocated.

    Returns:
        torch.Tensor: Readings of shape (batch_size, t, 3 * num_imus), flattened
        in Fortran order so the three channels (x acceleration, y acceleration,
        angular velocity) of each IMU are adjacent.
    """
    # IK_data.shape (batch_size, t, len(IK_data_columns))
    assert IK_data.ndim == 3
    assert IK_data.size(-1) == len(IK_data_columns)
    # cum_*.shape (batch_size, t, num_angles)
    assert cum_a.ndim == cum_da.ndim == cum_dda.ndim == 3
    assert (
        cum_a.size(-1)
        == cum_da.size(-1)
        == cum_dda.size(-1)
        == len(segment_dependencies)
    )
    # p.shape (batch_size, t, 2, num_imus)
    assert p.ndim == 4
    assert p.size(-1) == len(
        segment_imu_idxs
    ), f"Expected {len(segment_imu_idxs)} IMUs, got {p.size(-1)}"
    assert (
        p.size(-2) == 2
    ), f"Expected 2 dimensions for the IMU offsets, got {p.size(-2)}"

    batch_size = IK_data.size(0)
    t = IK_data.size(1)
    num_imus = len(segment_imu_idxs)

    imu_a = cum_a[:, :, segment_imu_idxs]
    imu_da = cum_da[:, :, segment_imu_idxs]
    imu_dda = cum_dda[:, :, segment_imu_idxs]
    p_x = p[:, :, 0]
    p_y = p[:, :, 1]

    # Translational Acceleration + gravity
    # ddt_0.shape (batch_size, t, 2)
    ddt_0 = torch.stack(
        [
            IK_data[:, :, IK_data_columns.index("ddtx")],
            IK_data[:, :, IK_data_columns.index("ddty")],
        ],
        dim=2,
    )
    # ddt_0_g_global.shape (batch_size, t, 2, num_imus)
    ddt_0_g_global = ddt_0.unsqueeze(-1).repeat(1, 1, 1, num_imus).to(device=device)
    ddt_0_g_global[:, :, 1] += 9.81  # Adding gravity
    # ddt_0_g_rot.shape (batch_size, t, 2, num_imus)
    ddt_0_g_rot = _rotate_2d_tensors(
        ddt_0_g_global, imu_a.negative(), device=device
    )  # Rotate by -a to get into local frame

    # Acceleration induced by the segment rotation at the IMU offset:
    # centripetal term (-omega^2 * p) plus tangential term (alpha x p).
    # ddt_cor.shape (batch_size, t, 2, num_imus)
    ddt_cor = torch.zeros_like(ddt_0_g_rot, device=device)
    ddt_cor[:, :, 0] += -p_x * torch.square(imu_da) - imu_dda * p_y
    ddt_cor[:, :, 1] += -p_y * torch.square(imu_da) + imu_dda * p_x

    # Compose the reconstructed IMU data. The buffer is sized from the number
    # of IMUs rather than a hard-coded 7 so it agrees with the checks above.
    reconstructed_imu_data = torch.zeros(
        (batch_size, t, 3, num_imus), device=device
    )
    reconstructed_imu_data[:, :, 0:2] = ddt_0_g_rot + ddt_cor
    reconstructed_imu_data[:, :, 2] = imu_da  # Angular velocities
    reconstructed_imu_data = _reshape_fortran(
        reconstructed_imu_data, (batch_size, t, -1)
    )
    return reconstructed_imu_data


def _reconstruct_segments(
    IK_data: torch.Tensor,
    IK_data_columns: list[str],
    cum_a: torch.Tensor,
    da: torch.Tensor,
    segment_pos: torch.Tensor,
    segment_imu_idxs: list[int],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute global segment start/end positions and their time derivatives.

    Segment indices:
    0 back_offset, 1 torso, 2 hip_offset_r, 3 right thigh, 4 right shank,
    5 right ankle, 6 right foot, 7 hip_offset_l, 8 left thigh, 9 left shank,
    10 left ankle, 11 left foot.

    Args:
        IK_data (torch.Tensor): IK data of shape
            (batch_size, t, len(IK_data_columns)).
        IK_data_columns (list[str]): Column names of the last axis of
            ``IK_data``; must contain "tx", "ty", "dtx" and "dty".
        cum_a (torch.Tensor): Cumulative segment angles, shape
            (batch_size, t, 12).
        da (torch.Tensor): Angular velocity multiplied onto the differentiated
            offsets; must broadcast against ``cum_a``.
        segment_pos (torch.Tensor): Local segment vectors with
            batch_size * t * 24 elements, stored interleaved as
            (x0, y0, x1, y1, ...).
        segment_imu_idxs (list[int]): Unused; kept for interface compatibility.
        device (torch.device): Device on which the result tensors are allocated.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        ``segment_starts``, ``segment_ends``, ``d_segment_starts`` and
        ``d_segment_ends``, each of shape (batch_size, t, 2, 12).
    """
    batch_size = IK_data.size(0)
    t = IK_data.size(1)

    # Kinematic tree: (parent segment, attach at the parent's end?). The root
    # (None) starts at the pelvis translation. Both hip offsets attach to the
    # START of the back offset, matching the koelewijn reference
    # implementations in this module; they were previously chained to its end.
    segment_structure = {
        0: None,  # 'Back' starts from the initial data
        1: (0, True),  # 'Torso' starts from 'Back' end
        2: (0, False),  # 'Hip offset right' starts from 'Back' start
        3: (2, True),  # 'Right thigh' starts from 'Hip offset right' end
        4: (3, True),  # 'Right shank' starts from 'Right thigh' end
        5: (4, True),  # 'Right ankle' starts from 'Right shank' end
        6: (4, True),  # 'Right foot' starts from 'Right shank' end
        7: (0, False),  # 'Hip offset left' starts from 'Back' start
        8: (7, True),  # 'Left thigh' starts from 'Hip offset left' end
        9: (8, True),  # 'Left shank' starts from 'Left thigh' end
        10: (9, True),  # 'Left ankle' starts from 'Left shank' end
        11: (9, True),  # 'Left foot' starts from 'Left shank' end
    }

    def _chain(root: torch.Tensor, offsets: torch.Tensor):
        # Walk the tree once; parents always precede their children.
        starts = torch.zeros((batch_size, t, 2, 12), device=device)
        ends = torch.zeros((batch_size, t, 2, 12), device=device)
        starts[:, :, :, 0] = root
        for seg_id, attachment in segment_structure.items():
            if attachment is not None:
                parent, at_parent_end = attachment
                anchor = ends if at_parent_end else starts
                starts[:, :, :, seg_id] = anchor[:, :, :, parent]
            ends[:, :, :, seg_id] = starts[:, :, :, seg_id] + offsets[:, :, :, seg_id]
        return starts, ends

    def _columns(names):
        # Look every column up in the list describing IK_data itself, so the
        # positions and velocities cannot be read with different orderings.
        return IK_data[:, :, [IK_data_columns.index(name) for name in names]]

    reshaped_segment_pos = segment_pos.reshape(batch_size, t, 12, 2).permute(0, 1, 3, 2)
    pos_x = reshaped_segment_pos[:, :, 0]
    pos_y = reshaped_segment_pos[:, :, 1]
    cos_a = torch.cos(cum_a)
    sin_a = torch.sin(cum_a)

    # Local segment vectors rotated into the global frame.
    segment_offsets = torch.stack(
        [cos_a * pos_x - sin_a * pos_y, sin_a * pos_x + cos_a * pos_y], dim=2
    )
    segment_starts, segment_ends = _chain(_columns(["tx", "ty"]), segment_offsets)

    # Time derivative of the rotated vectors (chain rule with d(angle)/dt = da).
    d_segment_offsets = torch.stack(
        [
            -sin_a * pos_x * da - cos_a * pos_y * da,
            +cos_a * pos_x * da - sin_a * pos_y * da,
        ],
        dim=2,
    )
    d_segment_starts, d_segment_ends = _chain(
        _columns(["dtx", "dty"]), d_segment_offsets
    )

    return segment_starts, segment_ends, d_segment_starts, d_segment_ends


def _reconstruct_global_imu_data(
    segment_starts: torch.Tensor,
    p: torch.Tensor,
    cum_a: torch.Tensor,
    seg_imu_idxs: list[int],
    device: torch.device,
):
    """Compute the global 2D position of every IMU.

    Each IMU position is the start point of the segment it is attached to
    plus the IMU offset ``p`` rotated by that segment's cumulative angle.

    Args:
        segment_starts: Segment start points, indexed as
            (batch, time, x/y, segment).
        p: IMU offsets, indexed as (batch, time, x/y, imu).
        cum_a: Cumulative segment angles, indexed as (batch, time, segment).
        seg_imu_idxs: Segment index of each IMU; selects the entries of
            ``segment_starts`` and ``cum_a`` that carry an IMU.
        device: Device the selected segment starts are moved to.

    Returns:
        Tensor of shape (batch, time, 2 * number of IMUs). The x/y and IMU
        axes are merged in column-major order, so the last axis reads
        ``x_0, y_0, x_1, y_1, ...``.
    """
    n_batch, n_steps = segment_starts.size(0), segment_starts.size(1)

    # Start from the origin of each IMU-carrying segment
    imu_positions = segment_starts[:, :, :, seg_imu_idxs].clone().to(device=device)

    # Rotation terms of the segments the IMUs sit on
    imu_angles = cum_a[:, :, seg_imu_idxs]
    cos_a = torch.cos(imu_angles)
    sin_a = torch.sin(imu_angles)
    offset_x, offset_y = p[:, :, 0], p[:, :, 1]

    # Add the IMU offset rotated into the global frame
    imu_positions[:, :, 0] += offset_x * cos_a - offset_y * sin_a
    imu_positions[:, :, 1] += offset_x * sin_a + offset_y * cos_a

    return _reshape_fortran(imu_positions, (n_batch, n_steps, -1))


def _reshape_fortran(x, shape):
    """Reshape a tensor using column-major (Fortran) element order.

    Reversing the axes, doing a regular row-major reshape to the reversed
    target shape, and reversing the axes again yields the same result as a
    column-major reshape.

    Args:
        x: Tensor to reshape.
        shape: Target shape as a sequence of ints; may contain ``-1``.

    Returns:
        The reshaped tensor (a permuted result, not necessarily contiguous).
    """
    if x.dim() > 0:
        # Flip the axis order so the row-major reshape walks x column-major
        x = x.permute(*range(x.dim() - 1, -1, -1))
    flipped_shape = tuple(reversed(shape))
    flipped = x.reshape(*flipped_shape)
    # Flip the axes back to obtain the requested shape
    return flipped.permute(*range(len(shape) - 1, -1, -1))


def reconstruct_IK_data(
    IK_data: torch.Tensor,
    segment_pos: torch.Tensor,
    imu_offsets: torch.Tensor,
    cfg: DictConfig,
    device: torch.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    ),
) -> dict[str, torch.Tensor]:
    """Reconstruct the body model (segments and IMUs) from IK data.

    Args:
        IK_data (torch.Tensor): Tensor of shape (batch_size, t, columns), laid
            out as described in configs/datamodule/koelewijn.yaml.
        segment_pos (torch.Tensor): 3-dimensional tensor as described in
            configs/datamodule/koelewijn.yaml.
        imu_offsets (torch.Tensor): Tensor of shape (batch_size, t, 14) as
            described in configs/datamodule/koelewijn.yaml; holds an x/y
            offset pair for each of the 7 IMUs.
        cfg (DictConfig): Configuration providing the column-name lists
            ``cfg.datamodule.dataset_variables.IK_data``, ``.segment_pos``
            and ``.imu_offsets``.
        device (torch.device): Device on which the angle tensors and the IMU
            offsets are placed; also forwarded to the helper functions.

    Returns:
        dict[str, torch.Tensor]: Dictionary containing the reconstructed data:
        "a", "da", "dda" (per segment, the angle columns and their first and
        second derivative columns of the last joint in the segment's
        dependency chain), "IMU_data" (from ``_reconstruct_imus``),
        "segment_starts", "segment_ends", "d_segment_starts",
        "d_segment_ends" (from ``_reconstruct_segments``) and
        "global_IMU_data" (from ``_reconstruct_global_imu_data``).

    Raises:
        AssertionError: If an input is not 3-dimensional or its last
            dimension does not match the configured number of columns.
    """
    reconstructed_data = {}
    dataset_variables = cfg.datamodule.dataset_variables
    IK_data_columns = dataset_variables.IK_data
    segment_pos_columns = dataset_variables.segment_pos
    imu_offsets_columns = dataset_variables.imu_offsets

    # Validate rank and column count of every input
    assert IK_data.dim() == 3, f"Expected 3 dimensions, got {IK_data.dim()}"
    assert IK_data.size(-1) == len(IK_data_columns)
    assert segment_pos.dim() == 3
    assert segment_pos.size(-1) == len(segment_pos_columns)
    assert imu_offsets.dim() == 3
    assert imu_offsets.size(-1) == len(imu_offsets_columns)

    batch_size, t = IK_data.size(0), IK_data.size(1)

    # Segment index of each of the 7 IMUs
    seg_imu_idxs = [1, 3, 4, 6, 8, 9, 11]

    # (batch_size, t, 14) -> (batch_size, t, imu, x/y) -> (batch_size, t, x/y, imu)
    p = (
        imu_offsets.reshape(batch_size, t, len(seg_imu_idxs), 2)
        .permute(0, 1, 3, 2)
        .clone()
        .to(device=device)
    )

    num_a = len(segment_dependencies)
    # Column prefixes: angle, first derivative, second derivative
    prefixes = ("a", "da", "dda")
    column_index = IK_data_columns.index
    dependency_chains = list(segment_dependencies.values())

    # Per-segment values: copy the columns of the last joint in each chain
    own = {
        prefix: torch.zeros((batch_size, t, num_a), device=device)
        for prefix in prefixes
    }
    for seg, chain in enumerate(dependency_chains):
        for prefix, target in own.items():
            target[:, :, seg] = IK_data[:, :, column_index(f"{prefix}_{chain[-1]}")]
    reconstructed_data.update(own)

    # Cumulative values: sum the columns of every joint along each chain
    cumulative = {
        prefix: torch.zeros((batch_size, t, num_a), device=device)
        for prefix in prefixes
    }
    for seg, chain in enumerate(dependency_chains):
        for prefix, target in cumulative.items():
            columns = [column_index(f"{prefix}_{joint}") for joint in chain]
            target[:, :, seg] = IK_data[:, :, columns].sum(dim=-1)
    cum_a, cum_da, cum_dda = (cumulative[prefix] for prefix in prefixes)

    reconstructed_data["IMU_data"] = _reconstruct_imus(
        IK_data=IK_data,
        IK_data_columns=IK_data_columns,
        cum_a=cum_a,
        cum_da=cum_da,
        cum_dda=cum_dda,
        p=p,
        segment_imu_idxs=seg_imu_idxs,
        device=device,
    )

    # The cumulative first derivative is handed over via the `da` parameter
    seg_starts, seg_ends, d_seg_starts, d_seg_ends = _reconstruct_segments(
        IK_data=IK_data,
        IK_data_columns=IK_data_columns,
        cum_a=cum_a,
        da=cum_da,
        segment_pos=segment_pos,
        segment_imu_idxs=seg_imu_idxs,
        device=device,
    )
    reconstructed_data["segment_starts"] = seg_starts
    reconstructed_data["segment_ends"] = seg_ends
    reconstructed_data["d_segment_starts"] = d_seg_starts
    reconstructed_data["d_segment_ends"] = d_seg_ends

    reconstructed_data["global_IMU_data"] = _reconstruct_global_imu_data(
        segment_starts=seg_starts,
        p=p,
        cum_a=cum_a,
        seg_imu_idxs=seg_imu_idxs,
        device=device,
    )

    return reconstructed_data
