from typing import TypeVar, Union, Optional, Tuple
from omegaconf import DictConfig
import torch
import numpy as np
import pandas as pd

import src.utils.utils as utils

# Column names required in the inverse-kinematics (IK) data.
# The first six are the translation channels tx/ty with their "d"/"dd" prefixed
# variants (presumably first and second time derivatives); they are followed by
# angle ("a_"), "da_" and "dda_" channels for every joint, one joint at a time.
# NOTE(review): global_kinematics reads its column order from cfg, not from this
# list. Its per-joint state layout (see get_joint_position) is
# [x, dx, ddx, y, dy, ddy, a, da, dda], whereas this list orders the translation
# as tx, ty, dtx, dty, ddtx, ddty. Confirm that the cfg ordering matches the
# layout get_joint_position expects.
required_IK_data_columns = ["tx", "ty", "dtx", "dty", "ddtx", "ddty"] + [
    f"{prefix}a_{joint}"
    for joint in ("pelvis", "hip_r", "knee_r", "ankle_r", "hip_l", "knee_l", "ankle_l")
    for prefix in ("", "d", "dd")
]

# Segment names for which a length column is required (right then left leg).
required_segment_lengths_columns = "pelvis thigh_r shank_r foot_r thigh_l shank_l foot_l".split()

# Ground-contact model columns: x/y coordinates of the heel and toe points.
required_GC_model_columns = [
    f"{point}_{axis}" for point in ("heel", "toe") for axis in ("x", "y")
]

# Joint each IMU is attached to. The position in this list is the IMU index used to
# slice imu_offsets (2 columns per IMU), imu_rotations (1 column per IMU) and the
# synthesized imu_data (3 columns per IMU) in global_kinematics.
imu_parents = ["pelvis"] + [
    f"{segment}_{side}" for side in ("r", "l") for segment in ("hip", "knee", "ankle")
]

# Ground-contact points. "parent" is the joint the point is rigidly attached to;
# "key" is the prefix of its offset columns (f"{key}_x", f"{key}_y") in the
# ground-contact model data. Insertion order: r_heel, r_toe, l_heel, l_toe.
gc_model = {
    f"{side}_{key}": {"parent": f"ankle_{side}", "key": key}
    for side in ("r", "l")
    for key in ("heel", "toe")
}
# Symmetric kinematics model
# Planar kinematic chain: left and right legs use the same body-constant names.
# Each entry names the joint's parent and how the parent->joint offset is built.
# global_kinematics computes the offset as
#     body_constants[..., column named by "body_constant"] * offset_direction
# i.e. a length scaled along the (x, y) "offset_direction", expressed in the
# parent's rotated frame (see get_joint_position).
# Insertion order matters: global_kinematics walks this dict in order and looks up
# each parent's already-computed state, so parents must precede their children.
kinematics_model = {
    "pelvis": {
        "parent": None  # root of the chain; its state is taken directly from the IK data
    },
    "hip_r": {
        "parent": "pelvis",
        # Zero direction: the hip coincides with the pelvis position. The body
        # constant is multiplied by zero, but its column must still exist because
        # global_kinematics looks it up by name.
        "offset_direction": torch.tensor([[0.0, 0.0]]),
        "body_constant": "thigh_length"
    },
    "knee_r": {
        "parent": "hip_r",
        # One thigh length along -y of the hip's rotated frame.
        "offset_direction": torch.tensor([[0.0, -1.0]]),
        "body_constant": "thigh_length"
    },
    "ankle_r": {
        "parent": "knee_r",
        # One shank length along -y of the knee's rotated frame.
        "offset_direction": torch.tensor([[0.0, -1.0]]),
        "body_constant": "shank_length"
    },
    "hip_l": {
        "parent": "pelvis",
        # Same as hip_r: zero offset from the pelvis.
        "offset_direction": torch.tensor([[0.0, 0.0]]),
        "body_constant": "thigh_length"
    },
    "knee_l": {
        "parent": "hip_l",
        "offset_direction": torch.tensor([[0.0, -1.0]]),
        "body_constant": "thigh_length"
    },
    "ankle_l": {
        "parent": "knee_l",
        "offset_direction": torch.tensor([[0.0, -1.0]]),
        "body_constant": "shank_length"
    },
}

def get_joint_position(
        IK_data_prev_joint: torch.Tensor,
        IK_data_this_joint: torch.Tensor,
        joint_offset: torch.Tensor,
        device
    ) -> torch.Tensor:
    """
        Propagate planar kinematics from a parent joint to a point rigidly attached to it.

        The new position is the parent position plus ``joint_offset`` rotated by the parent's
        global angle. Velocity and acceleration follow from differentiating that rigid-body
        relation with the offset held constant in time. The new global angle, angular velocity
        and angular acceleration are the parent's values plus this joint's relative values.

        Per-joint state layout along the last axis (9 values), as indexed below:
            0: x, 1: dx, 2: ddx, 3: y, 4: dy, 5: ddy, 6: angle, 7: angular velocity,
            8: angular acceleration

        :param IK_data_prev_joint: Global kinematics of the parent joint, shape (batch_size, seq_len, 9)
        :param IK_data_this_joint: Relative rotational kinematics of this joint; only the last 3
            entries (angle, angular velocity, angular acceleration) are used,
            shape (batch_size, seq_len, >=3)
        :param joint_offset: Offset (x, y) from the parent to this joint, expressed in the parent's
            rotated frame, shape (batch_size, seq_len, 2)
        :param device: Device the returned tensor is moved to
        :return: Global kinematics of this joint, shape (batch_size, seq_len, 9), same layout as above.
    """
    next_joint = torch.zeros_like(IK_data_prev_joint).to(device)
    # Angles accumulate along the chain: global angle = parent global angle + relative angle.
    next_joint[:, :, -3:] = IK_data_prev_joint[:, :, -3:] + IK_data_this_joint[:, :, -3:]

    # Parent orientation and its derivatives. The trig terms are computed once and reused.
    theta = IK_data_prev_joint[:, :, 6]
    omega = IK_data_prev_joint[:, :, 7]
    alpha = IK_data_prev_joint[:, :, 8]
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    off_x = joint_offset[:, :, 0]
    off_y = joint_offset[:, :, 1]

    # Position: p = p_parent + R(theta) @ offset
    next_joint[:, :, 0] = off_x * cos_t - off_y * sin_t + IK_data_prev_joint[:, :, 0]
    next_joint[:, :, 3] = off_x * sin_t + off_y * cos_t + IK_data_prev_joint[:, :, 3]

    # Velocity: v = v_parent + (dR/dtheta @ offset) * omega
    next_joint[:, :, 1] = IK_data_prev_joint[:, :, 1] + (-sin_t * off_x - cos_t * off_y) * omega
    next_joint[:, :, 4] = IK_data_prev_joint[:, :, 4] + (cos_t * off_x - sin_t * off_y) * omega

    # Acceleration relative to the parent, in the parent's rotated frame:
    # centripetal term (-offset * omega^2) plus tangential term (alpha * perpendicular(offset)).
    # It is then rotated into the global frame and added to the parent's acceleration.
    acc_x_loc = -off_x * (omega ** 2) - alpha * off_y
    acc_y_loc = -off_y * (omega ** 2) + alpha * off_x
    next_joint[:, :, 2] = acc_x_loc * cos_t - acc_y_loc * sin_t + IK_data_prev_joint[:, :, 2]
    next_joint[:, :, 5] = acc_x_loc * sin_t + acc_y_loc * cos_t + IK_data_prev_joint[:, :, 5]
    return next_joint

def to_local_coordinates_imu(IK_data, g, device):
    """
    Express a point's global planar kinematics as IMU-frame signals.

    The global acceleration, with ``g`` added to its y component, is rotated by the
    negative of the point's angle into the point's local frame. The angular velocity
    is passed through unchanged.

    :param IK_data: Global state of the IMU point, indexed along the last axis as
        2: ddx, 5: ddy, 6: angle, 7: angular velocity (see get_joint_position);
        assumes shape (batch_size, seq_len, 9)
    :param g: Value added to the y acceleration; scalar or tensor broadcastable to
        (batch_size, seq_len). Presumably the gravity magnitude (positive) — confirm
        against the data source.
    :param device: Device the returned tensor is moved to
    :return: Tensor of shape (batch_size, seq_len, 3): local x acceleration,
        local y acceleration (including g), angular velocity
    """
    theta = IK_data[:, :, 6]
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    acc_x = IK_data[:, :, 2]
    acc_y = g + IK_data[:, :, 5]  # gravity contribution on the vertical axis

    local = torch.zeros_like(IK_data[:, :, -3:]).to(device)
    # Rotation by -theta: global frame -> IMU frame
    local[:, :, 0] = acc_x * cos_t + acc_y * sin_t
    local[:, :, 1] = -acc_x * sin_t + acc_y * cos_t
    local[:, :, 2] = IK_data[:, :, 7]
    return local


def global_kinematics(
    IK_data: torch.Tensor,
    body_constants: torch.Tensor,
    imu_offsets: torch.Tensor,
    imu_rotations: torch.Tensor,
    ground_contact_model: torch.Tensor,
    cfg: DictConfig,
    device: Optional[torch.device] = None,
    debug: bool = False,
) -> Tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """
    Reconstruct global joint kinematics, synthetic IMU signals and ground-contact point states.

    :param IK_data: Inverse-kinematics data, shape (batch_size, seq_len, len(IK_data_columns)),
        where the column names come from cfg. Must contain "tx".."dda_pelvis" and
        "a_<joint>".."dda_<joint>" for every joint in kinematics_model.
    :param body_constants: Shape (batch_size, seq_len, len(body_constant_columns)). Must contain
        a "g" column and every "body_constant" named in kinematics_model.
    :param imu_offsets: (x, y) offset of each IMU from its parent joint, in the parent's rotated
        frame, shape (batch_size, seq_len, 2 * len(imu_parents)), IMUs ordered as imu_parents.
    :param imu_rotations: Angle of each IMU relative to its parent joint,
        shape (batch_size, seq_len, len(imu_parents)).
    :param ground_contact_model: Offsets of the ground-contact points, with columns named by
        cfg (f"{key}_x", f"{key}_y" for each gc_model key), shape (batch_size, seq_len, n_cols).
    :param cfg: Configuration holding the dataset column names.
    :param device: Device for the computed tensors. Defaults to IK_data's device, which is the
        only device that works because results are combined with slices of IK_data.
    :param debug: If True, read column names from cfg.dataset_variables and skip the model-type
        check. Otherwise read them from cfg.datamodule.dataset_variables and require
        cfg.modeltype == "2dc".
    :return: Tuple (reconstructed_data, imu_data, gc_positions, ankle_imus):
        - reconstructed_data: joint name -> global state, shape (batch_size, seq_len, 9), layout
          [x, dx, ddx, y, dy, ddy, angle, angular velocity, angular acceleration]
          (see get_joint_position). The "pelvis" entry is the raw IK slice "tx".."dda_pelvis".
        - imu_data: shape (batch_size, seq_len, 3 * len(imu_parents)); per IMU, the local x/y
          acceleration (gravity included) and angular velocity from to_local_coordinates_imu.
        - gc_positions: ground-contact point name -> global state, shape (batch_size, seq_len, 9).
        - ankle_imus: ankle IMU name -> global state of that IMU before conversion to the
          local frame, shape (batch_size, seq_len, 9).
    """
    if device is None:
        device = IK_data.device

    if debug:
        variables = cfg.dataset_variables
    else:
        assert cfg.modeltype == "2dc"  # 2D opensim style model is currently not implemented
        # 2DC-like IMU reconstruction on winter model; segment positions depend on body constants
        variables = cfg.datamodule.dataset_variables
    IK_data_columns = variables.IK_data
    imu_offsets_columns = variables.imu_offsets
    body_constant_columns = variables.body_constants
    gc_model_columns = variables.ground_contact_model

    assert IK_data.dim() == 3, f"Expected 3 dimensions, got {IK_data.dim()}"
    assert IK_data.size(-1) == len(IK_data_columns)
    assert imu_offsets.dim() == 3
    assert imu_offsets.size(-1) == len(imu_offsets_columns)

    batch_size, seq_len = IK_data.size(0), IK_data.size(1)

    # Root of the chain: the pelvis state is taken directly from the IK data.
    # NOTE(review): get_joint_position indexes this 9-column slice as
    # [x, dx, ddx, y, dy, ddy, a, da, dda]; confirm that the cfg column order matches
    # (required_IK_data_columns lists tx, ty, dtx, ... instead).
    pelvis_start = IK_data_columns.index("tx")
    pelvis_stop = IK_data_columns.index("dda_pelvis") + 1
    reconstructed_data = {"pelvis": IK_data[:, :, pelvis_start:pelvis_stop]}

    # Walk the chain in dict order; each parent has already been reconstructed.
    for joint, spec in kinematics_model.items():
        if joint == "pelvis":
            continue
        length_idx = body_constant_columns.index(spec["body_constant"])
        # (B, T, 1) @ (1, 1, 2) -> (B, T, 2): segment length scaled along the offset direction
        joint_offset = torch.matmul(
            body_constants[:, :, length_idx:length_idx + 1],
            spec["offset_direction"].to(device).unsqueeze(0),
        )
        this_start = IK_data_columns.index(f"a_{joint}")
        this_stop = IK_data_columns.index(f"dda_{joint}") + 1
        reconstructed_data[joint] = get_joint_position(
            reconstructed_data[spec["parent"]],
            IK_data[:, :, this_start:this_stop],
            joint_offset,
            device,
        )

    # Synthesize IMU signals: each IMU is treated as a point rigidly attached to its parent
    # joint with a fixed offset and a fixed relative rotation.
    g = body_constants[:, :, body_constant_columns.index("g")]
    imu_data = torch.zeros((batch_size, seq_len, 3 * len(imu_parents)), device=device)
    ankle_imus = ("ankle_r", "ankle_l")
    ankle_imus_ = {}
    for i, imu in enumerate(imu_parents):
        imu_offset = imu_offsets[:, :, 2 * i:2 * i + 2]
        # Relative rotational state (angle, angular velocity, angular acceleration): only a
        # constant angle offset, no extra angular motion.
        imu_angles = torch.zeros_like(IK_data[:, :, :3])
        imu_angles[:, :, 0] = imu_rotations[:, :, i]
        imu_state = get_joint_position(reconstructed_data[imu], imu_angles, imu_offset, device)
        if imu in ankle_imus:
            ankle_imus_[imu] = imu_state  # global state of the ankle IMUs is returned too
        imu_data[:, :, 3 * i:3 * i + 3] = to_local_coordinates_imu(imu_state, g, device)

    # Ground-contact points: rigid offsets from their parent joint with no relative rotation.
    zero_rotation = torch.zeros_like(IK_data[:, :, :3])
    gc_positions = {}
    for gc, spec in gc_model.items():
        key = spec["key"]
        gc_start = gc_model_columns.index(f"{key}_x")
        gc_stop = gc_model_columns.index(f"{key}_y") + 1
        gc_positions[gc] = get_joint_position(
            reconstructed_data[spec["parent"]],
            zero_rotation,
            ground_contact_model[:, :, gc_start:gc_stop],
            device,
        )

    return reconstructed_data, imu_data, gc_positions, ankle_imus_