import pyrootutils
from typing import List, Union

import numpy as np
from omegaconf import ListConfig, OmegaConf
import torch
import pandas as pd
from torch.utils.data import Dataset
from src.multibody_sim.utils import calculate_body_constants
from src.utils.data_processing import normalize_movementEst_data, normalize_sensor_data
import pickle
from scipy.signal import butter, filtfilt


class Dorschky2024Dataset(Dataset):
    """Dataset class for the Dorschky2024 dataset.

    Loads IMU sequences from ``dorschky_sequences.pkl`` and per-subject metadata
    from ``ParticipantInfo.csv``, and precomputes, for every trial, the tensors
    that :meth:`__getitem__` returns (IMU data, body constants, IMU offsets,
    IMU rotations, ground-contact model and left/right speed).
    """

    # Per-subject sensor placement columns read from ParticipantInfo.csv
    # (presumably x/y placement coordinates of 7 IMUs — PX/PY pairs).
    _SENSOR_COLUMNS = ['IMU_PELVIS_PX', 'IMU_PELVIS_PY',
                       'IMU_FEMUR_R_PX', 'IMU_FEMUR_R_PY',
                       'IMU_TIBIA_R_PX', 'IMU_TIBIA_R_PY',
                       'IMU_FOOT_R_PX', 'IMU_FOOT_R_PY',
                       'IMU_FEMUR_L_PX', 'IMU_FEMUR_L_PY',
                       'IMU_TIBIA_L_PX', 'IMU_TIBIA_L_PY',
                       'IMU_FOOT_L_PX', 'IMU_FOOT_L_PY']
    # Column layout of ``self.all_data``; must match the order of ``_build_entry``.
    _COLUMNS = ['IMU_data', 'body_constants', 'imu_offsets', 'imu_rotations',
                'ground_contact_model', 'speed_r', 'speed_l']

    def __init__(self,
                 mode: str,
                 data_dir: str,
                 noise: float = 0.0,
                 cutoff_frequency: float = -1,
                 seq_len: int = 100,
                 subjects = 'all'):
        """
            :param mode: One of "train", "val", "test". Only validated here; it does
                not change which data is loaded.
            :param data_dir: Directory containing ``dorschky_sequences.pkl`` and
                ``ParticipantInfo.csv``.
            :param noise: Stored as ``self.noise``; not applied anywhere in this class.
            :param cutoff_frequency: If > 0, the IMU data is low-pass filtered with
                ``butter_lowpass_filter`` using this cutoff against the sampling rate
                ``self.framerate`` (100). Values <= 0 disable filtering.
            :param seq_len: Window length (frames) returned per item. If <= 0, whole
                trials are returned instead.
            :param subjects: 'all', or a list-like of subject ids used to filter the
                ``subject`` column of the loaded sequences.
        """
        super().__init__()

        assert mode in ["train", "val", "test"]
        with open(f"{data_dir}/dorschky_sequences.pkl", 'rb') as f:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            imu_data = pickle.load(f)

        participant_info = pd.read_csv(f"{data_dir}/ParticipantInfo.csv")
        self.imu_data = imu_data
        # The isinstance guard avoids an ambiguous element-wise comparison when
        # `subjects` is an array.
        if not (isinstance(subjects, str) and subjects == 'all'):
            self.imu_data = self.imu_data[self.imu_data.subject.isin(subjects)]

        self.noise = noise
        self.seq_len = seq_len
        self.framerate = 100
        self.all_data = pd.DataFrame(columns=self._COLUMNS)
        for idx, row in self.imu_data.iterrows():
            # Keep the original index label so all_data rows line up with imu_data.
            self.all_data.loc[idx] = self._build_entry(row, participant_info, cutoff_frequency)

    def _build_entry(self, row, participant_info, cutoff_frequency):
        """Build one ``all_data`` row (ordered as ``_COLUMNS``) for a single trial."""
        # NOTE(review): assumes subject ids are 1-based row positions in
        # ParticipantInfo.csv — confirm against the CSV.
        part_info_row = participant_info.iloc[row.subject - 1]

        # Per-frame constants are sized to the sampled window when windowing is
        # enabled, otherwise to the full trial.
        if self.seq_len > 0:
            n_frames = self.seq_len
            body_input = row.data[:, :self.seq_len]
        else:
            n_frames = row.data.shape[1]
            body_input = row.data

        body_constants = calculate_body_constants([[body_input, 0,
                                                    part_info_row.BODYHEIGHT, 1]])
        # Subject-specific sensor placements repeated for every frame -> (14, n_frames).
        imu_offsets = np.tile(part_info_row[self._SENSOR_COLUMNS].to_list(), (n_frames, 1)).T
        # All-zero rotations, one row per IMU (7 IMUs) -> (7, n_frames). zeros_like
        # avoids NaN placements leaking in (NaN * 0 is NaN).
        imu_rotations = np.zeros_like(imu_offsets[:7])
        gc = get_gc_contact_points_generic(part_info_row.BODYHEIGHT, n_frames)

        imu = row.data.to(torch.float32)
        if cutoff_frequency > 0:
            # .copy() gives torch.from_numpy a fresh, contiguous array.
            imu = torch.from_numpy(
                butter_lowpass_filter(imu, cutoff_frequency, self.framerate).copy()
            ).to(torch.float32)

        return [imu,
                torch.from_numpy(body_constants).to(torch.float32),
                torch.from_numpy(imu_offsets).to(torch.float32),
                torch.from_numpy(imu_rotations).to(torch.float32),
                gc,
                row.speed_r.to(torch.float32),
                row.speed_l.to(torch.float32)]

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, idx, get_start_idx = False, start_idx = None):
        """Return one (possibly windowed) trial as a dict of tensors.

        :param idx: Positional index into ``self.all_data``.
        :param get_start_idx: If True, return ``(batch_data_dict, start_idx)``.
        :param start_idx: Window start. If None, it is drawn uniformly from all
            valid starts when ``seq_len > 0``, and is 0 otherwise.
        :raises ValueError: if a window must be sampled but the trial is shorter
            than ``seq_len``.
        """
        datarow = self.all_data.iloc[idx]

        if self.seq_len > 0:
            if start_idx is None:
                n_frames = datarow.IMU_data.shape[-1]
                # Valid starts are 0 .. n_frames - seq_len inclusive; randint's
                # upper bound is exclusive, hence the +1.
                n_starts = n_frames - self.seq_len + 1
                if n_starts < 1:
                    raise ValueError(
                        f"Trial {idx} has {n_frames} frames, fewer than seq_len={self.seq_len}."
                    )
                start_idx = torch.randint(0, n_starts, (1,)).item()
            window = slice(start_idx, start_idx + self.seq_len)
            imu_data = datarow.IMU_data[:, window]
            speed_r = datarow.speed_r[window]
            speed_l = datarow.speed_l[window]
        else:
            # The whole trial is returned, so it always starts at frame 0.
            if start_idx is None:
                start_idx = 0
            imu_data = datarow.IMU_data
            speed_r = datarow.speed_r
            speed_l = datarow.speed_l

        speed = torch.stack([speed_r, speed_l], dim=0)

        # Tensors are stored channels-first and returned time-first (transposed).
        batch_data_dict = {
            'IMU_data': imu_data.T,
            "body_constants": datarow.body_constants,
            "imu_offsets": datarow.imu_offsets.T,
            "imu_rotations": datarow.imu_rotations.T,
            "ground_contact_model": datarow.ground_contact_model,
            "speed": speed.T
        }

        return (batch_data_dict, start_idx) if get_start_idx else batch_data_dict

def get_gc_contact_points_generic(subject_height, seq_len):
    """
    Build the (symmetric) generic ground-contact model for one subject.

    Two contact points (heel, then toe) are described by five values each:
    x position, y position, stiffness, damping and friction. The x/y positions
    scale with the subject height; the other parameters are fixed.

    :param subject_height: The height of the subject
    :param seq_len: Number of frames; the same row is repeated for every frame
    :return: float32 tensor of shape (seq_len, 10), heel values in columns 0-4
        and toe values in columns 5-9
    """
    stiffness = 100   # Stiffness in BW/m
    damping = 0.75    # Damping coefficient
    friction = 0.5    # Friction coefficient
    y_position = float(-0.039 * subject_height)

    template = []
    # Relative x positions of heel_r and toe_r, in that order.
    for rel_x in (-6 / 180, 16.36 / 180):
        template += [float(rel_x * subject_height), y_position, stiffness, damping, friction]

    row = torch.tensor(template, dtype=torch.float32)
    return row.unsqueeze(0).repeat(seq_len, 1)

def butter_lowpass_filter(data, cutoff, fs, order=3):
    """Zero-phase Butterworth low-pass filter along the last axis.

    :param data: Array-like signal; filtering runs along axis -1.
    :param cutoff: Cutoff frequency, in the same units as ``fs``.
    :param fs: Sampling frequency.
    :param order: Butterworth filter order (default 3).
    :return: Filtered numpy array with the same shape as ``data``.
    """
    # Passing fs lets scipy normalise the cutoff by the Nyquist frequency itself.
    numerator, denominator = butter(order, cutoff, btype='low', analog=False, fs=fs)
    # filtfilt runs the filter forward and backward, so no phase shift.
    return filtfilt(numerator, denominator, data, axis=-1)
