import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def _finite_difference(signal, fps):
    """Return the forward difference of ``signal`` along its last axis, scaled by ``fps``.

    The result has the same length as the input. The final sample repeats the
    last available difference, rather than differencing against an artificial
    trailing zero (which would create a spike proportional to the last value).
    A signal with fewer than two samples yields zeros.
    """
    if signal.shape[-1] < 2:
        return torch.zeros_like(signal)
    deriv = torch.diff(signal, dim=-1) * fps
    return torch.cat([deriv, deriv[..., -1:]], dim=-1)


def get_data_IK(predict_dl, idx, filter = -1, fps_ = 102.4):
    """Load the OpenSim IK result for sample ``idx`` as a 27-channel tensor.

    The subject number and trial type ('walking' or 'running') of row ``idx`` are
    read from ``predict_dl.dataset.imu_data``. The matching ``.mot`` file under
    ``data/weiss2022`` is then loaded (tab separated, 10 header lines skipped).

    Channel layout: nine coordinates sit at every third channel (0, 3, ..., 24):
    pelvis_tx, pelvis_ty (stored as read), then pelvis_tilt and hip/knee/ankle
    angles for the right and then the left leg (converted from degrees to
    radians). The two channels following each coordinate hold its first and
    second finite-difference derivatives, multiplied by the estimated file rate.

    Args:
        predict_dl: loader whose ``dataset.imu_data`` has ``subject`` (the code
            adds 1 to form the folder number) and ``trial`` columns.
        idx: positional row index into ``imu_data``.
        filter: cutoff for an optional 4th-order zero-phase Butterworth low-pass
            applied before differentiation, in the same units as the estimated
            sampling rate (presumably Hz). Values ``<= 0`` disable filtering.
        fps_: target sampling rate; the result is linearly resampled to it.

    Returns:
        float32 CPU tensor of shape ``(1, num_samples, 27)``.

    Raises:
        ValueError: if the trial type is neither 'walking' nor 'running'.
    """
    imu_data = predict_dl.dataset.imu_data
    subject = imu_data.subject.iloc[idx] + 1
    trial = imu_data.trial.iloc[idx]

    trial_dirs = {
        'walking': ('walking_withoutWS', 'Walk'),
        'running': ('running_withoutWS', 'Run'),
    }
    if trial not in trial_dirs:
        raise ValueError(
            f"Unsupported trial type {trial!r}; expected 'walking' or 'running'"
        )
    folder, prefix = trial_dirs[trial]
    pathto = (
        f"data/weiss2022/Subject_{str(subject).zfill(2)}_oN/Results_OMC/"
        f"{folder}/{prefix}_oN_Result_IK_Trial.mot"
    )

    case = pd.read_csv(pathto, delimiter = "\t", skiprows=10)
    n_frames = len(case)
    # NOTE(review): this estimate assumes the time column effectively starts at 0
    # and slightly overestimates the rate (N / T rather than (N - 1) / T). It is
    # kept unchanged because it also fixes the resampled output length below.
    fps = n_frames / case['time'].iloc[-1]

    IK_data = torch.zeros(1, n_frames, 27).to('cpu')

    # Translations are stored as read from the file.
    for channel, column in ((0, 'pelvis_tx'), (3, 'pelvis_ty')):
        IK_data[:, :, channel] = torch.tensor(case[column].to_numpy())

    # Rotational coordinates: the file values are converted degrees -> radians.
    angle_columns = (
        (6, 'pelvis_tilt'),
        (9, 'hip_flexion_r'), (12, 'knee_angle_r'), (15, 'ankle_angle_r'),
        (18, 'hip_flexion_l'), (21, 'knee_angle_l'), (24, 'ankle_angle_l'),
    )
    for channel, column in angle_columns:
        IK_data[:, :, channel] = np.pi / 180 * torch.tensor(case[column].to_numpy())

    if filter > 0:
        from scipy.signal import butter, filtfilt
        # Cutoff is normalised by the Nyquist frequency of the estimated file rate.
        b, a = butter(4, filter / (fps / 2), 'low')
        filtered = filtfilt(b, a, IK_data.numpy(), axis=1)
        IK_data = torch.from_numpy(filtered.copy()).to('cpu').to(torch.float32)

    # For every coordinate channel: next channel = velocity, the one after = acceleration.
    for base in range(0, 27, 3):
        IK_data[:, :, base + 1] = _finite_difference(IK_data[:, :, base], fps)
        IK_data[:, :, base + 2] = _finite_difference(IK_data[:, :, base + 1], fps)

    # Resample along time: interpolate expects (batch, channels, length).
    num_samples = int(n_frames * fps_ / fps)
    IK_data = torch.nn.functional.interpolate(
        IK_data.transpose(2, 1).clone(), size=num_samples, mode='linear'
    ).transpose(1, 2)

    return IK_data


def plot_stick_2D(frame, ax = None):
    """Draw one 2-D stick-figure frame onto ``ax``.

    Args:
        frame: mapping from segment/marker name to an indexable vector. Element 0
            is used as x and element 3 as y. Keys ending in 's' are drawn as a
            0.5-long segment whose direction comes from element 6 (sin -> x,
            cos -> y).
        ax: matplotlib axes to draw on; ``plt.gca()`` when omitted.

    Colours: right side ('...r' or 'r_...') blue, left side ('...l' or 'l_...')
    red, anything else black. Every other key is connected to its parent, taken
    from ``kinematics_model`` or else ``gc_model`` in
    ``src.kinematics.kinematics_2d``. Toe markers are additionally linked to
    the corresponding heel.

    Returns:
        The axes drawn on.
    """
    if ax is None:
        ax = plt.gca()
    import src.kinematics.kinematics_2d as k2d
    kin_model = k2d.kinematics_model
    # Floor: drawn on the target axes, not whichever axes happens to be current.
    ax.plot([0, 2], [0, 0], 'k')
    for key in frame.keys():
        if key[-1] == 's':
            # Oriented root segment (presumably 'pelvis'): a 0.5-long line at angle element 6.
            ax.plot([frame[key][0], frame[key][0] + torch.sin(frame[key][6]) * 0.5],
                    [frame[key][3], frame[key][3] + torch.cos(frame[key][6]) * 0.5], c='k')
            continue
        if key[-1] == 'r' or key[:2] == 'r_':
            c = 'b'
        elif key[-1] == 'l' or key[:2] == 'l_':
            c = 'r'
        else:
            # Explicit default: previously c was unbound or leaked from the prior key.
            c = 'k'
        # Line from the parent segment to this one.
        if key in kin_model:
            parent_key = kin_model[key]['parent']
        else:
            parent_key = k2d.gc_model[key]['parent']
        ax.plot([frame[parent_key][0], frame[key][0]], [frame[parent_key][3], frame[key][3]], c=c)

        if key in ['r_toe', 'l_toe']:
            # Close the foot by also linking toe to the matching heel.
            heel_key = key.replace('toe', 'heel')
            ax.plot([frame[heel_key][0], frame[key][0]], [frame[heel_key][3], frame[key][3]], c=c)

    return ax

def plot_stick_2d_data(reconstructed_data, gc_model, start_idx=0, end_idx=100, spacing = 1):
    """Overlay stick figures for frames ``start_idx, start_idx + spacing, ...`` below ``end_idx``.

    For each selected time index, the entries of ``reconstructed_data`` and then
    ``gc_model`` (sample 0 of the batch, detached and moved to the CPU) are merged
    into one frame and drawn with ``plot_stick_2D`` on the current axes. The axes
    is then given an equal aspect ratio and the figure is shown.
    """
    for frame_idx in range(start_idx, end_idx, spacing):
        snapshot = {
            name: values[0, frame_idx].detach().cpu()
            for name, values in reconstructed_data.items()
        }
        # Ground-contact entries are added after (and override) reconstructed ones.
        snapshot.update({
            name: values[0, frame_idx].detach().cpu()
            for name, values in gc_model.items()
        })
        plot_stick_2D(snapshot)
    current_axes = plt.gca()
    current_axes.set_aspect('equal')
    plt.show()

def animate_stick_2d(reconstructed_data, gc_model, start_idx=0, end_idx=100, fps = 1, subtrans = False, save = False):
    """Animate the 2-D stick figure over time indices ``start_idx .. end_idx - 1``.

    Each animation frame merges the ``reconstructed_data`` and ``gc_model``
    entries at ``[0, i + start_idx]``. It then draws one line per segment (parent
    -> child, plus toe -> heel links) and 25 grey vertical reference lines at
    x = 0..24.

    Args:
        reconstructed_data, gc_model: mappings from segment name to tensors,
            presumably indexed ``[batch, time, feature]``. Feature 0 is x and 3 is
            y; for 'pelvis', feature 6 is an angle that sets the direction of its
            0.5-long segment.
        start_idx, end_idx: time range to animate (end exclusive).
        fps: playback rate; the frame interval is ``1000 / fps`` ms.
        subtrans: if True, subtract the pelvis x position so the figure stays
            centred and the reference lines scroll instead.
        save: if True, also write the animation to 'stick_2d.gif' via ImageMagick.

    Returns:
        IPython ``HTML`` object containing the JavaScript animation.
    """
    from matplotlib.animation import FuncAnimation
    from IPython.display import HTML
    import src.kinematics.kinematics_2d as k2d

    kin_model = k2d.kinematics_model
    # Number of segment line objects; assumes a frame needs at most this many.
    nlines = 13
    fig, ax = plt.subplots(figsize=(10, 5))
    colors = ['k', 'b', 'b', 'b', 'r', 'r', 'r', 'b', 'b', 'b', 'r', 'r', 'r']
    lines = [ax.plot([], [], 'o-', lw=2, color=colors[i])[0] for i in range(nlines)]
    # Never updated below; kept (and returned) so the blitted artist set is stable.
    imu_points = [ax.plot([], [], 'o', color=colors[i])[0] for i in range(nlines)]
    background = [ax.plot([], [], '-', lw=2, color='gray')[0] for _ in range(25)]
    if subtrans:
        ax.set_xlim(-1, 1)
    else:
        ax.set_xlim(-0.5, 7)
    ax.set_ylim(-0.5, 2)
    ax.set_aspect('equal')
    # Floor
    ax.plot([-2, 7], [0, 0], 'k')

    # With blit=True, init and update must return every artist they modify.
    artists = lines + imu_points + background

    def init_anim():
        for line in lines:
            line.set_data([], [])
        for line in background:
            line.set_data([], [])
        return artists

    def update(i):
        t = i + start_idx
        frame = {}
        for key in reconstructed_data.keys():
            frame[key] = reconstructed_data[key][0, t].detach().cpu()
        for key in gc_model.keys():
            frame[key] = gc_model[key][0, t].detach().cpu()

        # Resolve the horizontal offset up front so every segment uses it,
        # regardless of where 'pelvis' falls in the key order.
        posx_0 = frame['pelvis'][0] if subtrans and 'pelvis' in frame else 0

        count = 0
        for key in frame.keys():
            if key == 'pelvis':
                for j, bg_line in enumerate(background):
                    bg_line.set_data([j - posx_0, j - posx_0], [0, 1])
                lines[count].set_data(
                    [frame[key][0] - posx_0, frame[key][0] - posx_0 + torch.sin(frame[key][6]) * 0.5],
                    [frame[key][3], frame[key][3] + torch.cos(frame[key][6]) * 0.5])
                count += 1
                continue
            # Line from the parent segment to this one.
            if key in kin_model:
                parent_key = kin_model[key]['parent']
            else:
                parent_key = k2d.gc_model[key]['parent']
            lines[count].set_data([frame[parent_key][0] - posx_0, frame[key][0] - posx_0],
                                  [frame[parent_key][3], frame[key][3]])
            count += 1
            if key in ['r_toe', 'l_toe']:
                # Close the foot by also linking toe to the matching heel.
                heel_key = key.replace('toe', 'heel')
                lines[count].set_data([frame[heel_key][0] - posx_0, frame[key][0] - posx_0],
                                      [frame[heel_key][3], frame[key][3]])
                count += 1
        return artists

    ani = FuncAnimation(fig, update, frames=(end_idx - start_idx), init_func=init_anim,
                        blit=True, repeat=True, interval=1000 / fps)

    if save:
        # Save the animation as an animated GIF using the ImageMagick writer.
        ani.save('stick_2d.gif', writer='imagemagick')

    return HTML(ani.to_jshtml())


def plot_imu_pre_after(imu_reconstructed, imu_raw,cfg, debug = False):
    """Plot reconstructed (solid) against raw (dashed) IMU channels, three per subplot.

    Channels are taken in the order of the configured IMU channel names and drawn
    on a 4x2 subplot grid. Each consecutive group of three channels shares a
    subplot, with colours r/g/b and a legend built from the names with their
    first four characters stripped.

    Args:
        imu_reconstructed: tensor indexed ``[batch, time, channel]``; batch 0 is
            plotted.
        imu_raw: tensor or array with the same indexing as ``imu_reconstructed``.
        cfg: config object. Normally ``cfg.fps`` gives the sampling rate and
            ``cfg.datamodule.dataset_variables.IMU_data`` the channel names. In
            debug mode ``cfg.dataset_variables.IMU_data`` is used and the rate is
            fixed at 100.
        debug: select the debug config layout described above.
    """
    imu_reconstructed = imu_reconstructed.cpu().clone().detach()
    if isinstance(imu_raw, torch.Tensor):
        # Treat the raw signal like the reconstruction so GPU / grad-tracking
        # tensors can be handed to matplotlib.
        imu_raw = imu_raw.detach().cpu()

    if debug:
        fps = 100
        ikey = cfg.dataset_variables.IMU_data
    else:
        fps = cfg.fps
        ikey = cfg.datamodule.dataset_variables.IMU_data

    # Time vector in seconds for the plotted samples.
    n_frames = imu_reconstructed.shape[1]
    x_ = np.linspace(0, n_frames / fps, n_frames)

    plt.figure(figsize=(5/3*imu_reconstructed.shape[-1], 15))
    colors = ['r', 'g', 'b']
    count = 1
    for idx, imu in enumerate(ikey):
        color = colors[idx % 3]
        plt.subplot(4, 2, count)
        plt.plot(x_, imu_reconstructed[0, :, idx], color)
        plt.plot(x_, imu_raw[0, :, idx], f'{color}--')
        plt.xlabel('time in [s]')
        plt.ylabel('signal in [m/s^2 | 1/s]')
        # Explicit True: a bare plt.grid() toggles, so visibility would depend on call parity.
        plt.grid(True)
        if (idx + 1) % 3 == 0:
            # Third channel of the group: label all three and move to the next subplot.
            plt.legend([f'{ikey[idx-2][4:]} sim',f'{ikey[idx-2][4:]} raw',f'{ikey[idx-1][4:]} sim',f'{ikey[idx-1][4:]} raw',f'{imu[4:]} sim',f'{imu[4:]} raw'])
            count += 1


