import pyrootutils
from typing import List, Union

import numpy as np
from omegaconf import ListConfig, OmegaConf
import torch
import pandas as pd
from torch.utils.data import Dataset
from src.multibody_sim.utils import calculate_body_constants
from src.utils.data_processing import normalize_movementEst_data, normalize_sensor_data
import pickle
from scipy.signal import butter, filtfilt


class Dorschky2024EvalDataset(Dataset):
    """Dorschky2024_Dataset
    Evaluation dataset class for the Dorschky2024 dataset.

    Loads the pre-cut validation sequences (``dorschky_val_sequences.pkl``) and the
    per-participant metadata (``ParticipantInfo.csv``) from ``data_dir`` and
    precomputes, for every sequence, the tensors returned by ``__getitem__``.
    """

    # Per-participant IMU placement columns in ParticipantInfo.csv
    # (a PX and a PY column for each of the 7 sensors).
    _SENSOR_COLUMNS = ['IMU_PELVIS_PX', 'IMU_PELVIS_PY',
                       'IMU_FEMUR_R_PX', 'IMU_FEMUR_R_PY',
                       'IMU_TIBIA_R_PX', 'IMU_TIBIA_R_PY',
                       'IMU_FOOT_R_PX', 'IMU_FOOT_R_PY',
                       'IMU_FEMUR_L_PX', 'IMU_FEMUR_L_PY',
                       'IMU_TIBIA_L_PX', 'IMU_TIBIA_L_PY',
                       'IMU_FOOT_L_PX', 'IMU_FOOT_L_PY']

    # Columns of ``self.all_data``, in the order rows are built in ``__init__``.
    _COLUMNS = ['IMU_data', 'body_constants', 'imu_offsets', 'ground_contact_model',
                'triggertime', 'subject', 'trigger_no', 'trigger_idx', 'omc_start', 'omc_end']

    def __init__(self,
                 data_dir: str,
                 subjects = 'all'):
        """
        Load the validation sequences and precompute the per-sequence model inputs.

        :param data_dir: Directory containing ``dorschky_val_sequences.pkl`` and
            ``ParticipantInfo.csv``.
        :param subjects: ``'all'`` to keep every sequence, a single subject id, or an
            iterable of subject ids (list, ``ListConfig``, array, ...) to keep only
            the sequences of those subjects.
        """
        super().__init__()
        # NOTE(review): pickle.load can execute arbitrary code; only point data_dir
        # at trusted files.
        with open(f"{data_dir}/dorschky_val_sequences.pkl", 'rb') as f:
            imu_data = pickle.load(f)

        participant_info = pd.read_csv(f"{data_dir}/ParticipantInfo.csv")
        self.imu_data = imu_data  # full, unfiltered sequence table
        self.framerate = 100

        # Filter before the per-sequence work so unselected subjects are never processed.
        keep = self._resolve_subjects(subjects)
        if keep is not None:
            imu_data = imu_data[imu_data.subject.isin(keep)]

        rows, labels = [], []
        for idx, row in imu_data.iterrows():
            # Positional lookup: assumes subject ids are 1-based and that
            # ParticipantInfo.csv lists subjects in id order.
            part_info_row = participant_info.iloc[row.subject - 1]
            # NOTE(review): row.data (the IMU tensor) is passed as the first entry
            # here; confirm this matches calculate_body_constants' expected layout.
            body_constants = calculate_body_constants([[row.data, 0,
                                                        part_info_row.BODYHEIGHT, 1]])
            # row.data.shape[1] is treated as the sequence length throughout.
            seq_len = row.data.shape[1]
            # Repeat the subject's static sensor placement for every time step,
            # giving an array of shape (14, seq_len).
            imu_offsets = np.tile(part_info_row[self._SENSOR_COLUMNS].to_list(), (seq_len, 1)).T
            gc = get_gc_contact_points_generic(part_info_row.BODYHEIGHT, seq_len)

            rows.append([row.data.to(torch.float32),
                         torch.from_numpy(body_constants).to(torch.float32),
                         torch.from_numpy(imu_offsets).to(torch.float32),
                         gc, row.triggertime, row.subject, row.trigger_no,
                         row.trigger_idx, row.omc_start, row.omc_end])
            labels.append(idx)

        # Build the frame once; growing it row-by-row via .loc copies it on every insert.
        self.all_data = pd.DataFrame(rows, columns=self._COLUMNS, index=labels)

    @staticmethod
    def _resolve_subjects(subjects):
        """Normalize ``subjects`` to ``None`` (keep all) or a value ``Series.isin`` accepts."""
        if isinstance(subjects, str) and subjects == 'all':
            return None
        if isinstance(subjects, (int, np.integer)):
            # A single subject id; isin() only accepts list-like values.
            return [subjects]
        return subjects

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, idx, get_metadata = False):
        """Return the precomputed model inputs for the ``idx``-th sequence.

        :param idx: Positional index into the (possibly subject-filtered) sequences.
        :param get_metadata: If True, also return the sequence's trigger/OMC fields.
        :return: ``batch_data_dict``, or ``(batch_data_dict, metadata)`` when
            ``get_metadata`` is True.
        """
        datarow = self.all_data.iloc[idx]

        # Stored tensors have time on the last axis; the returned ones are time-major.
        imu_offsets = datarow.imu_offsets.T

        batch_data_dict = {
            'IMU_data': datarow.IMU_data.T,
            "body_constants": datarow.body_constants,
            "imu_offsets": imu_offsets,
            # Zero rotations, presumably one per sensor (7 sensors in _SENSOR_COLUMNS).
            # zeros_like instead of `* 0` so NaN offsets cannot leak in.
            "imu_rotations": torch.zeros_like(imu_offsets[:, :7]),
            "ground_contact_model": datarow.ground_contact_model,
        }

        metadata = {
            "triggertime": datarow.triggertime,
            "subject": datarow.subject,
            "trigger_no": datarow.trigger_no,
            "trigger_idx": datarow.trigger_idx,
            "omc_start": datarow.omc_start,
            "omc_end": datarow.omc_end
        }

        return batch_data_dict if not get_metadata else (batch_data_dict, metadata)

def get_gc_contact_points_generic(subject_height, seq_len, *, stiffness=100, damping=0.75, friction=0.5):
    """
    Get the (symmetric) ground contact points for a generic subject.

    Two contact points are generated, heel ("heel_r") first and then toe ("toe_r").
    Each contributes five columns, repeated for every time step:
    x offset, y offset, stiffness, damping, friction. The offsets scale linearly
    with ``subject_height``.

    :param subject_height: The height of the subject
    :param seq_len: The length of the sequence
    :param stiffness: Contact stiffness in BW/m (default 100)
    :param damping: Damping coefficient (default 0.75)
    :param friction: Friction coefficient (default 0.5)
    :return: Tensor of shape (seq_len, 10); heel in columns 0-4, toe in columns 5-9
    """
    gc_points = []
    # x offsets as fractions of body height: heel at -6/180, toe at +16.36/180.
    for x_rel in (-6 / 180, 16.36 / 180):
        point = torch.zeros(seq_len, 5)
        point[:, 0] = x_rel * subject_height
        point[:, 1] = -0.039 * subject_height
        point[:, 2] = stiffness
        point[:, 3] = damping
        point[:, 4] = friction
        gc_points.append(point)
    return torch.cat(gc_points, dim=-1)

