#!/usr/bin/env python

import numpy as np
import torch
from torch.utils.data import DataLoader


# Subject ids 0..9.
idSubs = list(range(10))
# Sampling frequency [Hz].
Fs = 1_000

# Acceleration due to gravity [m/s^2].
g_grav = 9.81


def calculate_body_constants(subject_data):
    """Build a table of anthropometric body-segment constants.

    Segment lengths, masses, centre-of-mass distances and moments of inertia
    for thigh, shank, foot and torso are derived from the subject's body
    height and body mass using fixed anthropometric ratios.

    Args:
        subject_data: Indexable subject record. ``subject_data[0][2]`` is read
            as body height (lengths below come out in [m]) and
            ``subject_data[0][3]`` as body mass [kg]. If ``subject_data[0][0]``
            is a ``torch.Tensor`` (or subclass), its ``shape[1]`` sets the
            number of output rows (presumably the sample count - confirm);
            otherwise ``subject_data.shape[0]`` does.

    Returns:
        np.ndarray of shape (n_rows, 16): the 16 constants listed in
        ``body_constants`` below, repeated on every row.
    """
    bodyheight = subject_data[0][2]
    bodymass = subject_data[0][3]

    lengths = np.array(
        [
            (0.53 - 0.285) * bodyheight,  # thigh_length [m]
            (0.285 - 0.039) * bodyheight,  # shank_length [m]
            0.152 * bodyheight,  # foot_length [m]
            (0.81 - 0.53) * bodyheight,  # torso_length [m]
        ]
    )

    masses = np.array(
        [
            0.1 * bodymass,  # thigh_mass [kg]
            0.0465 * bodymass,  # shank_mass [kg]
            0.0145 * bodymass,  # foot_mass [kg]
            0.678 * bodymass,  # torso_mass [kg]
        ]
    )

    com_lengths = np.array(
        [
            0.433 * lengths[0],  # thigh_com_length [m]
            0.433 * lengths[1],  # shank_com_length [m]
            0.5 * lengths[2],  # foot_com_length [m]
            0.626 * lengths[3],  # torso_com_length [m]
        ]
    )

    # Inertia = mass * (radius_of_gyration_ratio * segment_length)^2
    inertias = np.array(
        [
            masses[0] * (0.323 * lengths[0]) ** 2,  # thigh_inertia [kg*m^2]
            masses[1] * (0.302 * lengths[1]) ** 2,  # shank_inertia [kg*m^2]
            masses[2] * (0.475 * lengths[2]) ** 2,  # foot_inertia [kg*m^2]
            masses[3] * (0.496 * lengths[3]) ** 2,  # torso_inertia [kg*m^2]
        ]
    )

    body_constants = np.array(
        [
            lengths[0],  # thigh_length [m]
            com_lengths[0],  # thigh_com_length [m]
            masses[0],  # thigh_mass [kg]
            inertias[0],  # thigh_inertia [kg*m^2]
            lengths[1],  # shank_length [m]
            com_lengths[1],  # shank_com_length [m]
            masses[1],  # shank_mass [kg]
            inertias[1],  # shank_inertia [kg*m^2]
            lengths[2],  # foot_length [m]
            com_lengths[2],  # foot_com_length [m]
            masses[2],  # foot_mass [kg]
            inertias[2],  # foot_inertia [kg*m^2]
            # torso_length is intentionally omitted: not needed for Kane's method
            com_lengths[3],  # torso_com_length [m]
            masses[3],  # torso_mass [kg]
            inertias[3],  # torso_inertia [kg*m^2]
            g_grav,  # acceleration due to gravity [m/s^2]
        ]
    )

    # isinstance (not type(...) ==) so Tensor subclasses such as
    # torch.nn.Parameter take the tensor branch too.
    first_entry = subject_data[0][0]
    if isinstance(first_entry, torch.Tensor):
        n_rows = first_entry.shape[1]
    else:
        n_rows = subject_data.shape[0]
    return np.tile(body_constants, (n_rows, 1))


# NOTE: this part probably needs revisiting.
def calculate_hip_data(hip_info, add_data, center_of_press):
    """Assemble hip kinematics and centre-of-pressure channels into one array.

    Args:
        hip_info: Array of shape (batch, nSample, >=5). Channels 0-4 are read
            as hip x, hip y, pelvis tilt, hip velocity x and hip velocity y.
        add_data: Array of shape (batch, >=1). Column 0 is the trial duration
            in seconds.
        center_of_press: Array of shape (batch, nSample, >=4). Channels 0-1
            are the right and channels 2-3 the left centre of pressure (x, y).

    Returns:
        np.ndarray (float64) of shape (batch, nSample, 13) with channels:
        pelvis_tilt, hip_x, hip_y, der_x, der_y, hip_vel_x, hip_vel_y,
        der_hip_vel_x, der_hip_vel_y, right_cop_x, right_cop_y,
        left_cop_x, left_cop_y.
        ``der_x``/``der_y`` are the hip velocity channels themselves, not
        recomputed. ``der_hip_vel_*`` are forward differences of the hip
        velocities.
    """
    hip_x = hip_info[:, :, 0]
    hip_y = hip_info[:, :, 1]
    pelvis_tilt = hip_info[:, :, 2]
    hip_vel_x = hip_info[:, :, 3]
    hip_vel_y = hip_info[:, :, 4]

    # Per-trial sampling rate.
    # NOTE(review): load_data builds the time axis as
    # np.linspace(0, duration, nSample), i.e. nSample - 1 intervals, which
    # would give fs = (nSample - 1) / duration. The original nSample / duration
    # scaling is kept so outputs are unchanged - confirm which is intended.
    n_sample = hip_x.shape[1]
    fs = n_sample / add_data[:, 0]

    def _rate_of_change(signal):
        # Forward difference along time, scaled by each trial's fs. Sample 0
        # copies sample 1 so the output keeps the input's length.
        deriv = np.zeros(signal.shape)
        deriv[:, 1:] = np.diff(signal, axis=1) * fs[:, None]
        deriv[:, 0] = deriv[:, 1]
        return deriv

    var_list = [
        pelvis_tilt,
        hip_x,
        hip_y,
        hip_vel_x,  # der_x: hip velocity reused as the position derivative
        hip_vel_y,  # der_y
        hip_vel_x,
        hip_vel_y,
        _rate_of_change(hip_vel_x),  # der_hip_vel_x
        _rate_of_change(hip_vel_y),  # der_hip_vel_y
        center_of_press[:, :, 0],  # right centre of pressure x
        center_of_press[:, :, 1],  # right centre of pressure y
        center_of_press[:, :, 2],  # left centre of pressure x
        center_of_press[:, :, 3],  # left centre of pressure y
    ]

    # Fill a float64 array channel by channel (keeps the dtype independent of
    # whether the inputs are numpy arrays or tensors).
    hip_data = np.zeros((hip_info.shape[0], hip_info.shape[1], len(var_list)))
    for i, var in enumerate(var_list):
        hip_data[:, :, i] = var

    return hip_data


## This method calculates the first and second derivatives of the joint angles
## add_data has 'duration in s' and 'speed in m/s'
def calculate_derivatives(joint_angles, joint_angles_left, aux_data, add_data):
    """Compute time derivatives of joint angles and angular velocities.

    Args:
        joint_angles: Array of shape (batch, nSample, >=3). Channels 0-2 are
            the right hip, knee and ankle angles.
        joint_angles_left: Same layout for the left side.
        aux_data: Array of shape (batch, nSample, >=12). Channel 2 is read as
            the pelvis angle, 5 as omega_torso, 6-8 as right hip/knee/ankle
            omega and 9-11 as left hip/knee/ankle omega.
        add_data: Array of shape (batch, >=1). Column 0 is the trial duration
            in seconds.

    Returns:
        np.ndarray (float64) of shape (batch, nSample, 21) with channels:
         0-2  d/dt right hip/knee/ankle angle
         3-5  d/dt left hip/knee/ankle angle
         6    d/dt pelvis angle
         7-9  right hip/knee/ankle omega (copied from aux_data)
        10-12 left hip/knee/ankle omega (copied from aux_data)
        13    omega_torso (copied from aux_data)
        14-16 d/dt right hip/knee/ankle omega
        17-19 d/dt left hip/knee/ankle omega
        20    d/dt omega_torso
    """
    theta_pelvis = aux_data[:, :, 2]
    omega_torso = aux_data[:, :, 5]
    r_thetas = [joint_angles[:, :, k] for k in range(3)]  # hip, knee, ankle
    l_thetas = [joint_angles_left[:, :, k] for k in range(3)]  # hip, knee, ankle
    r_omegas = [aux_data[:, :, k] for k in (6, 7, 8)]  # hip, knee, ankle
    l_omegas = [aux_data[:, :, k] for k in (9, 10, 11)]  # hip, knee, ankle

    # Per-trial sampling rate.
    # NOTE(review): load_data builds the time axis as
    # np.linspace(0, duration, nSample), i.e. nSample - 1 intervals, which
    # would give fs = (nSample - 1) / duration. The original nSample / duration
    # scaling is kept so outputs are unchanged - confirm which is intended.
    n_sample = aux_data.shape[1]
    fs = n_sample / add_data[:, 0]

    def _rate_of_change(signal):
        # Forward difference along time, scaled by each trial's fs. Sample 0
        # copies sample 1 so the output keeps the input's length.
        deriv = np.zeros(signal.shape)
        deriv[:, 1:] = np.diff(signal, axis=1) * fs[:, None]
        deriv[:, 0] = deriv[:, 1]
        return deriv

    # Order must match the channel layout documented above.
    var_list = (
        [_rate_of_change(theta) for theta in r_thetas]
        + [_rate_of_change(theta) for theta in l_thetas]
        + [_rate_of_change(theta_pelvis)]
        # The d/dt angle channels and the omega channels are expected to carry
        # the same quantity (per the original author's note).
        + r_omegas
        + l_omegas
        + [omega_torso]
        + [_rate_of_change(omega) for omega in r_omegas]
        + [_rate_of_change(omega) for omega in l_omegas]
        + [_rate_of_change(omega_torso)]
    )

    deriv_data = np.zeros((aux_data.shape[0], aux_data.shape[1], len(var_list)))
    for i, var in enumerate(var_list):
        deriv_data[:, :, i] = var

    return deriv_data


def scale_bodyweights(ground_force, subject_data):
    """Multiply ``ground_force`` by the subject's body weight.

    Body weight is body mass (``subject_data[0, 3]``, [kg]) times
    ``g_grav`` [m/s^2]. Presumably ``ground_force`` is normalised to body
    weight, so the result is in newtons - confirm with the data source.
    """
    body_weight_n = subject_data[0, 3] * g_grav
    return ground_force * body_weight_n


## This method gets one batch of data from the dataset.
def load_data(batch_size=32, iBatch=1, useReal=False):
    """Load the ``iBatch``-th batch (1-based) of the test split.

    Args:
        batch_size: Number of samples per batch.
        iBatch: 1-based index of the batch to return (``1`` = first batch).
        useReal: Currently unused; kept for backward compatibility.

    Returns:
        Tuple ``(batch_data, t)``. ``batch_data`` is the collated batch from
        ``BiomechEstDataset``. ``t`` is ``nSample`` evenly spaced time points
        over ``[0, duration]``, where ``nSample`` is the last and ``duration``
        the first column of ``batch_data["subject_data"][0]``.

    Raises:
        ValueError: If ``iBatch`` is below 1 or beyond the number of batches
            the loader yields.
    """
    # Validate before importing project modules so bad arguments fail fast
    # (previously iBatch < 1 left batch_data unbound -> UnboundLocalError).
    if iBatch < 1:
        raise ValueError(f"iBatch must be >= 1, got {iBatch}")

    from src.datamodules.components.biomechEst_dataset import BiomechEstDataset

    # Dataset selection - edit here to change the data source, subjects and speeds.
    mode = "test"
    data_dir = "data/seifer_dataset"
    # data_dir = "data/fukuchi_dataset_randPos"
    # data_dir = "data/fukuchi_dataset_randAlign"

    Subs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    speeds = [
        "slowwalking",
        "normwalking",
        "fastwalking",
        "slowrunning",
        "normrunning",
        "fastrunning",
    ]
    # Subs = [1, 2]
    # speeds = ["normrunning"]
    # speeds = ["runT45"]

    dataset = BiomechEstDataset(
        mode=mode, Subs=Subs, data_dir=data_dir, speeds=speeds
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=False,
        shuffle=False,
    )

    # Advance to the requested batch; convert exhaustion into a clear error
    # instead of letting a bare StopIteration escape.
    batches = iter(loader)
    try:
        for _ in range(iBatch):
            batch_data = next(batches)
    except StopIteration:
        raise ValueError(
            f"iBatch={iBatch} exceeds the number of available batches"
        ) from None

    # Number of samples per trial and trial duration from the first subject row.
    nSample = int(batch_data["subject_data"][0, -1])
    duration = batch_data["subject_data"][0, 0].cpu().numpy()
    t = np.linspace(0, duration, nSample)

    return batch_data, t
