from matplotlib.pylab import f
from omegaconf import DictConfig
from typing import Optional
import pyrootutils

# Locate the project root by searching from this file for a directory that
# contains `.git` or `pyproject.toml`. With `pythonpath=True` / `dotenv=True`,
# pyrootutils also adds that root to `sys.path` and loads its `.env` file.
# This runs before the `src.*` imports further down, presumably so that they
# resolve when this file is executed directly rather than as an installed package.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

import os
import xml.etree.ElementTree as ET

import hydra

import pandas as pd
import numpy as np
import torch
from scipy.signal import butter, filtfilt

from src.utils.utils import (
    COP_columns,
    GRF_COP_column_mapping,
    GRF_columns,
    IK_column_mapping,
)
from src.utils.reconstruction import (
    reconstruct_imus,
    compose_segment_positions,
)

# Subject IDs processed by this module (1-13; ID 4 is not included).
subjects = [subject for subject in range(1, 14) if subject != 4]
# Trial IDs loaded for every subject.
trials = list(range(1, 3))


# Load Segment Lengths
def _joint_translations_from_osim(osim_filepath: str) -> dict[str, list[str]]:
    """Map each named joint of an OpenSim model file to its offset translation.

    Only joints under ``JointSet/objects`` that have a ``PhysicalOffsetFrame``
    with a non-empty ``translation`` element are included. The value is the
    list of whitespace-separated translation components (as strings) with the
    last component dropped.
    """
    model_root = ET.parse(osim_filepath).getroot()
    jointset = model_root.find(".//JointSet")
    objects = jointset.find("objects") if jointset is not None else None

    translations: dict[str, list[str]] = {}
    if objects is None:
        return translations

    for joint in objects:
        joint_name = joint.get("name")
        if not joint_name:
            continue
        # The first PhysicalOffsetFrame found below the joint is used.
        frame = joint.find(".//PhysicalOffsetFrame")
        if frame is None:
            continue
        translation = frame.find("translation")
        if translation is None or not translation.text:
            continue
        translations[joint_name] = translation.text.split()[:-1]
    return translations


def _preprocess_segment_lengths(
    full_data_dir_path: str, segments_per_subject: int = 12
) -> pd.DataFrame:
    """Collect the joint offset translations of every subject's OpenSim model.

    For each ID in the module-level ``subjects`` list, the model
    ``subject_XX/Models/final.osim`` is parsed. Its joint offset translations
    are gathered into one table, which is cached at
    ``<full_data_dir_path>/preprocessed/segment_lengths.parquet``. The cache is
    reused when it has the expected row count and covers exactly the same
    subject IDs.

    Args:
        full_data_dir_path: Dataset root containing the ``subject_XX`` folders.
        segments_per_subject: Number of joints expected per model. It is used to
            validate both the cache and a freshly parsed table.

    Returns:
        DataFrame with columns ``subject``, ``segment`` (joint name) and
        ``length``. Despite its name, ``length`` holds the list of translation
        components (strings, last component dropped) rather than a scalar.

    Raises:
        ValueError: if the parsed models do not yield
            ``len(subjects) * segments_per_subject`` rows.
    """
    expected_rows = len(subjects) * segments_per_subject
    preprocessed_dir = os.path.join(full_data_dir_path, "preprocessed")
    preprocessed_segment_lengths_filepath = os.path.join(
        preprocessed_dir, "segment_lengths.parquet"
    )

    if os.path.exists(preprocessed_segment_lengths_filepath):
        cached = pd.read_parquet(preprocessed_segment_lengths_filepath)
        # Row count alone cannot detect a cache built for a different subject
        # list of the same size, so the subject IDs are compared as well.
        if (
            len(cached) == expected_rows
            and "subject" in cached.columns
            and set(cached["subject"]) == set(subjects)
        ):
            return cached

    frames = []
    for subject_id in subjects:
        osim_filepath = os.path.join(
            full_data_dir_path, f"subject_{subject_id:02d}", "Models", "final.osim"
        )
        translations = _joint_translations_from_osim(osim_filepath)
        frames.append(
            pd.DataFrame(
                {
                    "subject": subject_id,
                    "segment": list(translations.keys()),
                    "length": list(translations.values()),
                }
            )
        )

    # Concatenate once. Growing an initially empty frame per subject is
    # quadratic and relies on deprecated empty-entry concat behaviour.
    if frames:
        segment_lengths = pd.concat(frames)
    else:
        segment_lengths = pd.DataFrame(columns=["subject", "segment", "length"])

    if len(segment_lengths) != expected_rows:
        raise ValueError(
            f"Expected {expected_rows} segment rows "
            f"({segments_per_subject} per subject), got {len(segment_lengths)}"
        )

    # The cache directory does not exist yet on a fresh dataset.
    os.makedirs(preprocessed_dir, exist_ok=True)
    segment_lengths.to_parquet(preprocessed_segment_lengths_filepath, index=True)

    return segment_lengths


def _joints(subject_id: int, data_dir: str) -> pd.DataFrame:
    """Return the rows of ``<data_dir>/consts/joint_list.parquet`` for one subject.

    Rows are selected by equality on the ``subject`` column; the original
    index of the selected rows is kept.
    """
    table = pd.read_parquet(os.path.join(data_dir, "consts", "joint_list.parquet"))
    belongs_to_subject = table["subject"] == subject_id
    return table[belongs_to_subject]


def _contact_points(subject_id: int, data_dir: str) -> pd.DataFrame:
    """Return the rows of ``<data_dir>/consts/cp_list.parquet`` for one subject.

    Rows are selected by equality on the ``subject`` column; the original
    index of the selected rows is kept.
    """
    cp_table = pd.read_parquet(os.path.join(data_dir, "consts", "cp_list.parquet"))
    subject_rows = cp_table["subject"] == subject_id
    return cp_table[subject_rows]


def _butterworth_filter_data(
    data: pd.DataFrame, axis=0, cutoff_freq=6, sampling_rate=100, order=2
):
    nyquist = 0.5 * sampling_rate
    normal_cutoff = cutoff_freq / nyquist
    b, a = butter(order, normal_cutoff, btype="low", analog=False)
    return pd.DataFrame(filtfilt(b, a, data, axis=0), columns=data.columns)


def _filtered_IK_koelewijn(
    data_dir: str, subject_id: int, trial_id: int
) -> pd.DataFrame:
    """Return the low-pass filtered inverse-kinematics data of one trial.

    The result is cached as parquet under ``<data_dir>/IK_prep`` and returned
    directly when that file exists. Otherwise the raw CSV in
    ``<data_dir>/IK_orig`` is processed in these steps:

    1. Read the CSV.
    2. Keep the columns named by the non-empty values of ``IK_column_mapping``.
    3. Rename those columns to the mapping's keys.
    4. Filter with ``_butterworth_filter_data``.
    5. Write the result to the cache.

    Raises:
        FileNotFoundError: if neither the cache nor the raw CSV exists.
    """
    cache_dir = os.path.join(data_dir, "IK_prep")
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(
        cache_dir,
        f"IK_subject{subject_id}_trial{trial_id}.parquet",
    )
    if os.path.exists(cache_path):
        return pd.read_parquet(cache_path)

    raw_path = os.path.join(
        data_dir,
        "IK_orig",
        f"IK_subject{subject_id}_trial{trial_id}.csv",
    )
    if not os.path.exists(raw_path):
        raise FileNotFoundError(f"File not found: {raw_path}")

    raw = pd.read_csv(raw_path, sep=",", skiprows=0)
    # Mapping entries with an empty source name are skipped (presumably they
    # have no counterpart column in the CSV).
    wanted = [source for source in IK_column_mapping.values() if source != ""]
    renamed = raw[wanted].copy().rename(
        columns={source: target for target, source in IK_column_mapping.items()}
    )

    filtered = _butterworth_filter_data(renamed, axis=0)
    filtered.to_parquet(cache_path, index=False)
    return filtered.copy()


def _filtered_grf_cop_koelewijn(
    data_dir: str, subject_id: int, trial_id: int
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return the filtered ground reaction forces and centres of pressure of a trial.

    Both frames are cached as parquet files under ``<data_dir>/GRF_COP_prep``
    and returned directly when both cache files exist. Otherwise the raw
    ``.mot`` file in ``<data_dir>/GRF_COP_orig`` is processed in these steps:

    1. Read the file (tab-separated, first 6 lines skipped).
    2. Low-pass filter it with ``_butterworth_filter_data``.
    3. Recompute ``*ground_force_px`` for both plates as
       ``ground_torque_z / ground_force_vy`` and set ``*ground_force_py`` to 0.
    4. Keep the columns listed in ``GRF_COP_column_mapping`` and rename them
       to the mapping's keys.
    5. Split the result into ``GRF_columns`` and ``COP_columns``, cache both
       and return them.

    Returns:
        ``(grf, cop)`` DataFrames.

    Raises:
        FileNotFoundError: if the cache is incomplete and the raw file is missing.
    """
    cache_dir = os.path.join(data_dir, "GRF_COP_prep")
    os.makedirs(cache_dir, exist_ok=True)
    grf_cache_path = os.path.join(
        cache_dir,
        f"GRF_subject{subject_id}_trial{trial_id}.parquet",
    )
    cop_cache_path = os.path.join(
        cache_dir,
        f"COP_subject{subject_id}_trial{trial_id}.parquet",
    )
    if all(os.path.exists(path) for path in (grf_cache_path, cop_cache_path)):
        return pd.read_parquet(grf_cache_path), pd.read_parquet(cop_cache_path)

    raw_path = os.path.join(
        data_dir,
        "GRF_COP_orig",
        f"GRF_COP_subject{subject_id}_trial{trial_id}.mot",
    )
    if not os.path.exists(raw_path):
        raise FileNotFoundError(f"File not found: {raw_path}")

    raw = pd.read_csv(raw_path, sep="\t", skiprows=6)
    signals = _butterworth_filter_data(raw, axis=0)

    # Centre of pressure, computed from the already-filtered signals:
    # p_x = T_z / F_y per plate, p_y fixed at 0.
    # NOTE(review): where F_y is ~0 (presumably no foot contact) this ratio is
    # unbounded and may become inf/NaN; confirm downstream code masks those samples.
    for cop_x, torque_z, force_y in (
        ("ground_force_px", "ground_torque_z", "ground_force_vy"),
        ("1_ground_force_px", "1_ground_torque_z", "1_ground_force_vy"),
    ):
        signals[cop_x] = signals[torque_z] / signals[force_y]
    for cop_y in ("ground_force_py", "1_ground_force_py"):
        signals[cop_y] = 0

    renamed = signals[GRF_COP_column_mapping.values()].copy().rename(
        columns={val: key for key, val in GRF_COP_column_mapping.items()}
    )

    grf = renamed[GRF_columns].copy()
    cop = renamed[COP_columns].copy()

    grf.to_parquet(grf_cache_path, index=False)
    cop.to_parquet(cop_cache_path, index=False)

    return grf, cop


def _reconstructed_IMU_koelewijn(
    data_dir: str,
    subject_id: int,
    trial_id: int,
    imu_offsets: np.ndarray,
    segment_positions: np.ndarray,
) -> pd.DataFrame:
    """Return reconstructed IMU data for one trial, using an on-disk cache.

    The IMU data and global IMU translations are produced by
    ``reconstruct_imus`` from the trial's filtered IK data. Both are written to
    ``<data_dir>/IMU_prep``, together with the ``imu_offsets`` they were
    computed with. A cached result is reused only when the stored offsets
    equal ``imu_offsets``.

    NOTE(review): the cache key covers ``imu_offsets`` only; a change in
    ``segment_positions`` does not invalidate the cached IMU data.

    Returns:
        The reconstructed IMU data.
    """
    prep_dir = os.path.join(data_dir, "IMU_prep")
    os.makedirs(prep_dir, exist_ok=True)
    IMU_data_file = os.path.join(
        prep_dir,
        f"IMU_subject{subject_id}_trial{trial_id}.parquet",
    )
    global_imu_translations_file = os.path.join(
        prep_dir,
        f"IMU_translations_subject{subject_id}_trial{trial_id}.parquet",
    )
    imu_offsets_file = os.path.join(
        prep_dir,
        f"imu_offsets_subject{subject_id}_trial{trial_id}.txt",
    )

    offsets = np.asarray(imu_offsets)
    if os.path.exists(imu_offsets_file) and os.path.exists(IMU_data_file):
        cached_imu_offsets = np.loadtxt(imu_offsets_file)
        # savetxt/loadtxt does not round-trip the shape (e.g. a single row is
        # read back as 1-D). Compare element counts first, then values in the
        # caller's shape; this also avoids silent broadcasting.
        if cached_imu_offsets.size == offsets.size and np.array_equal(
            cached_imu_offsets.reshape(offsets.shape), offsets
        ):
            return pd.read_parquet(IMU_data_file)

    # Invalidate the cache before overwriting the data files. Otherwise an
    # interrupted run could leave matching offsets next to stale or partial data.
    if os.path.exists(imu_offsets_file):
        os.remove(imu_offsets_file)

    IMU_data, global_imu_translations = reconstruct_imus(
        _filtered_IK_koelewijn(
            data_dir=data_dir, subject_id=subject_id, trial_id=trial_id
        ),
        segment_positions,
        imu_offsets,
    )
    assert isinstance(IMU_data, pd.DataFrame)
    assert isinstance(global_imu_translations, pd.DataFrame)

    IMU_data.to_parquet(IMU_data_file, index=False)
    global_imu_translations.to_parquet(global_imu_translations_file, index=False)
    # Written last, so the cache only becomes valid once both data files exist
    # for exactly these offsets.
    np.savetxt(imu_offsets_file, offsets)

    return IMU_data


def _load_body_constants(
    data_dir: str, subject_id: int, segment_lengths: np.ndarray
) -> np.ndarray:
    """Loads body contants from consts folder

    Args:
        data_dir (str): _description_
        subject_id (int): _description_
        segment_lengths (np.ndarray): lengths of segments: thigh, shank, foot shape(3,)

    Returns:
        np.ndarray: shape (16, )

        [
        ['thigh_length', 'thigh_com_dist', 'thigh_mass','thigh_inertia'],
        ['shank_length', 'shank_com_dist', 'shank_mass', 'shank_inertia'],
        ['foot_length', 'foot_com_dist', 'foot_mass', 'foot_inertia'],
        ['torso_com_dist', 'torso_mass', 'torso_inertia', 'g']
        ]
    """
    assert segment_lengths.shape == (3,)
    consts_file = os.path.join(data_dir, "consts", "segment_list.parquet")
    df = pd.read_parquet(consts_file)
    df = df[df["subject"] == subject_id]
    df["mass_center"] = df["mass_center"].apply(lambda x: np.array(x, dtype=np.float32)[:2])  # type: ignore
    body_constants = np.zeros((4, 4), dtype=np.float32)
    body_constants[:3, 0] = segment_lengths
    for i, segment in enumerate(["femur_r", "tibia_r", "foot_r", "torso"]):
        # thigh: com_dist, mass, inertia
        body_constants[i, 1] = np.sqrt(
            np.sum(np.square(df[df["segment"] == segment]["mass_center"].values[0]))
        )
        body_constants[i, 2] = df[df["segment"] == segment]["mass"].values[0]
        body_constants[i, 3] = df[df["segment"] == segment]["inertia"].values[0]

    body_constants[3, 0:3] = body_constants[3, 1:4]
    body_constants[3, 3] = 9.81

    return body_constants


def load_data_koelewijn(
    data_dir: str,
    subject_id: int,
    trial_id: int,
    imu_offsets: np.ndarray,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    """Load every input/target of one Koelewijn trial as float32 tensors.

    Returns, in this order:

    1. Reconstructed IMU data.
    2. Filtered IK.
    3. Segment positions, flattened column by column.
    4. Body constants, flattened.
    5. Filtered GRF.
    6. Filtered COP.
    """
    ik = _filtered_IK_koelewijn(data_dir, subject_id, trial_id)
    grf, cop = _filtered_grf_cop_koelewijn(data_dir, subject_id, trial_id)

    segment_positions = compose_segment_positions(
        _joints(subject_id, data_dir),
        _contact_points(subject_id, data_dir),
    )
    # The transpose is required for correct data: flattening the transpose
    # orders the positions column by column.
    segment_positions_flat = segment_positions.T.flatten()

    imu = _reconstructed_IMU_koelewijn(
        data_dir, subject_id, trial_id, imu_offsets, segment_positions
    )

    # Thigh, shank and foot lengths: Euclidean norms of segment-position
    # columns 3, 4 and 6.
    segment_lengths = np.array(
        [np.sqrt(np.sum(np.square(segment_positions[:, col]))) for col in (3, 4, 6)],
        dtype=np.float64,
    )
    body_constants_flat = _load_body_constants(
        data_dir, subject_id, segment_lengths
    ).flatten()

    def as_tensor(values: np.ndarray) -> torch.Tensor:
        return torch.tensor(values, dtype=torch.float32)

    return (
        as_tensor(imu.values.copy()),
        as_tensor(ik.values.copy()),
        as_tensor(segment_positions_flat),
        as_tensor(body_constants_flat),
        as_tensor(grf.values.copy()),
        as_tensor(cop.values.copy()),
    )


@hydra.main(version_base="1.3", config_path="../../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Instantiate the configured datamodule and print one training batch's shapes.

    The test and predict dataloaders are requested too, although their results
    are not used. Each ``(key, value.shape)`` of the first training batch is
    printed; nothing is returned.
    """
    datamodule = hydra.utils.instantiate(cfg.datamodule)

    train_loader = datamodule.train_dataloader()
    datamodule.test_dataloader()
    datamodule.predict_dataloader()

    exhausted = object()
    batches = iter(train_loader)
    first_batch = next(batches, exhausted)
    if first_batch is exhausted:
        return None

    for key, value in first_batch.items():
        print(key, value.shape)
    return None


if __name__ == "__main__":
    # `hydra` is already imported at module level. The @hydra.main decorator
    # supplies `cfg`, so `main` is called without arguments.
    main()
