import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def _time_derivative(signal, fps):
    """Return the forward-difference time derivative of ``signal`` along its last axis.

    The result has the same length as the input. The final sample repeats the last
    difference rather than differencing against an implicit zero, which would put a
    large artificial spike at the end of the trial. Inputs shorter than two samples
    yield zeros.
    """
    if signal.shape[-1] < 2:
        return torch.zeros_like(signal)
    rate = torch.diff(signal, dim=-1) * fps
    return torch.cat([rate, rate[..., -1:]], dim=-1)


def get_data_IK(predict_dl, idx, filter = -1, fps_ = 102.4):
    """Load the OpenSim IK result for one recording and return positions and derivatives.

    The subject and trial are read from row ``idx`` of ``predict_dl.dataset.imu_data``
    (columns ``subject``, 0-based, and ``trial``, 'walking' or 'running'). The matching
    ``.mot`` file under ``data/weiss2022`` is read. Its 9 coordinates are placed into a
    27-channel tensor where coordinate k occupies channels 3k (position), 3k+1
    (velocity) and 3k+2 (acceleration). The coordinates, in order, are pelvis_tx,
    pelvis_ty, pelvis_tilt, hip/knee/ankle right, and hip/knee/ankle left. Angles are
    converted from degrees to radians; translations are kept as stored in the file.

    Args:
        predict_dl: object exposing ``dataset.imu_data`` with ``subject``/``trial`` columns.
        idx: row of ``imu_data`` that selects subject and trial.
        filter: cutoff in Hz of a 4th-order zero-phase Butterworth low-pass applied to
            the positions before differentiation; values <= 0 disable filtering.
        fps_: sample rate the output is linearly resampled to.

    Returns:
        float32 tensor of shape (1, num_samples, 27).

    Raises:
        ValueError: if the trial type is unknown or the file has fewer than two
            increasing time stamps.
    """
    trial_files = {
        'walking': ('walking_withoutWS', 'Walk_oN_Result_IK_Trial.mot'),
        'running': ('running_withoutWS', 'Run_oN_Result_IK_Trial.mot'),
    }
    imu_data = predict_dl.dataset.imu_data
    subject = str((imu_data.subject + 1).iloc[idx]).zfill(2)
    trial = imu_data.trial.iloc[idx]
    if trial not in trial_files:
        raise ValueError(f"Unsupported trial type {trial!r}; expected one of {sorted(trial_files)}")
    folder, fname = trial_files[trial]
    pathto = f"data/weiss2022/Subject_{subject}_oN/Results_OMC/{folder}/{fname}"
    case = pd.read_csv(pathto, delimiter = "\t", skiprows=10)

    n_frames = len(case)
    duration = case.time.iloc[-1] - case.time.iloc[0]
    if n_frames < 2 or duration <= 0:
        raise ValueError(f"{pathto}: need at least two samples with increasing time")
    # n samples span n-1 intervals; the time column need not start at 0.
    fps = (n_frames - 1) / duration

    deg2rad = np.pi / 180
    columns = [
        ('pelvis_tx', 1.0), ('pelvis_ty', 1.0), ('pelvis_tilt', deg2rad),
        ('hip_flexion_r', deg2rad), ('knee_angle_r', deg2rad), ('ankle_angle_r', deg2rad),
        ('hip_flexion_l', deg2rad), ('knee_angle_l', deg2rad), ('ankle_angle_l', deg2rad),
    ]
    IK_data = torch.zeros(1, n_frames, 3 * len(columns))
    for k, (name, scale) in enumerate(columns):
        IK_data[:, :, 3 * k] = scale * torch.tensor(case[name].to_numpy())

    if filter > 0:
        from scipy.signal import butter, filtfilt
        b, a = butter(4, filter / (fps / 2), 'low')
        filtered = filtfilt(b, a, IK_data.numpy(), axis=1)
        IK_data = torch.from_numpy(filtered.copy()).to(torch.float32)

    # Velocity from position, then acceleration from velocity, for every coordinate.
    for base in range(0, IK_data.shape[-1], 3):
        IK_data[:, :, base + 1] = _time_derivative(IK_data[:, :, base], fps)
        IK_data[:, :, base + 2] = _time_derivative(IK_data[:, :, base + 1], fps)

    num_samples = int(n_frames * fps_ / fps)
    # interpolate expects (N, C, L): move time to the last axis and back.
    IK_data = torch.nn.functional.interpolate(
        IK_data.transpose(2, 1).clone(), size=num_samples, mode='linear'
    ).transpose(1, 2)
    return IK_data


def plot_stick_2D(frame, ax = None):
    """Draw a single 2D stick-figure frame.

    Args:
        frame: mapping from segment/point name to a 1-D tensor. Index 0 is used as the
            x coordinate and index 3 as the y coordinate. For keys ending in 's'
            (e.g. 'pelvis'), index 6 is used as an angle to draw a 0.5-long trunk line.
        ax: axes to draw on; defaults to the current axes.

    Returns:
        The axes that were drawn on.

    Right-side keys (suffix 'r' or prefix 'r_') are drawn in blue, left-side keys in
    red, and anything else in black. Each key is joined to its parent. The parent is
    looked up in ``kinematics_model`` and, failing that, in ``gc_model`` of
    src.kinematics.kinematics_2d. Toe points are additionally joined to their heel.
    """
    if ax is None:
        ax = plt.gca()
    import src.kinematics.kinematics_2d as k2d
    kin_model = k2d.kinematics_model
    # Invisible (zero-width) floor line on the target axes; presumably anchors the
    # autoscaled limits around y=0.
    ax.plot([0, 1], [0, 0], 'k', linewidth=0.0)
    for key, point in frame.items():
        if key.endswith('s'):
            # Root segment: trunk line tilted by point[6].
            ax.plot([point[0], point[0] - torch.sin(point[6]) * 0.5],
                    [point[3], point[3] + torch.cos(point[6]) * 0.5], c='k')
            continue
        if key.endswith('r') or key.startswith('r_'):
            c = 'b'
        elif key.endswith('l') or key.startswith('l_'):
            c = 'r'
        else:
            # Without a default the previous key's colour leaked in (or NameError).
            c = 'k'
        # Line from parent to this segment.
        if key in kin_model:
            parent_key = kin_model[key]['parent']
        else:
            parent_key = k2d.gc_model[key]['parent']
        parent = frame[parent_key]
        ax.plot([parent[0], point[0]], [parent[3], point[3]], c=c)

        if key in ('r_toe', 'l_toe'):
            # Close the foot: connect the toe to the matching heel point.
            heel = frame[key.replace('toe', 'heel')]
            ax.plot([heel[0], point[0]], [heel[3], point[3]], c=c)

    return ax

def plot_stick_2d_data(reconstructed_data, gc_model, start_idx=0, end_idx=100, spacing = 1, noshow = False):
    """Overlay stick figures for every ``spacing``-th frame in ``[start_idx, end_idx)``.

    Each frame merges ``reconstructed_data`` and ``gc_model`` (both indexed as
    ``[0, frame]``). Entries from ``gc_model`` are merged last and therefore win on
    clashing keys. Unless ``noshow`` is set, the axes get an equal aspect ratio and the
    figure is shown.
    """
    sources = (reconstructed_data, gc_model)
    for frame_no in range(start_idx, end_idx, spacing):
        snapshot = {
            name: src[name][0, frame_no].detach().cpu()
            for src in sources
            for name in src.keys()
        }
        plot_stick_2D(snapshot)
    if noshow:
        return
    plt.gca().set_aspect('equal')
    plt.show()

def animate_stick_2d(reconstructed_data, gc_model, start_idx=0, end_idx=100, fps = 1, subtrans = False, save = False, figsize = (10,5)):
    """Animate the 2D stick figure over frames ``[start_idx, end_idx)`` as embeddable HTML.

    Args:
        reconstructed_data: mapping of segment name -> tensor indexed as ``[0, frame]``.
            Index 0 is used as x and index 3 as y; for 'pelvis', index 6 is used as the
            trunk angle. A 'pelvis' entry is required (it sets the x-limits).
        gc_model: second mapping with the same indexing, merged after reconstructed_data.
        start_idx, end_idx: frame range to animate.
        fps: playback rate; consecutive frames are 1000/fps ms apart.
        subtrans: if True, pelvis x is subtracted every frame so the figure stays
            centred and the grey background lines scroll instead.
        save: if True, also write 'stick_2d.gif' using the 'imagemagick' writer.
        figsize: matplotlib figure size.

    Returns:
        IPython.display.HTML wrapping ``ani.to_jshtml()``.
    """
    from matplotlib.animation import FuncAnimation
    import matplotlib
    from IPython.display import HTML
    import src.kinematics.kinematics_2d as k2d

    kin_model = k2d.kinematics_model
    n_segments = 13
    n_background = 25
    fig, ax = plt.subplots(figsize=figsize)
    # One colour per segment line, assigned in the order update() visits frame keys.
    colors = ['k', 'b', 'b', 'b', 'r', 'r', 'r', 'b', 'b', 'b', 'r', 'r', 'r']
    lines = [ax.plot([], [], 'o-', lw=2, color=colors[i])[0] for i in range(n_segments)]
    # IMU marker artists: created for future use, never populated at present.
    imu_points = [ax.plot([], [], 'o', color=colors[i])[0] for i in range(n_segments)]
    # Vertical grey reference lines at x = 0..24 (shifted when subtrans is set).
    background = [ax.plot([], [], '-', lw=2, color='gray')[0] for _ in range(n_background)]
    if subtrans:
        ax.set_xlim(-1, 1)
    else:
        ax.set_xlim(-0.5, reconstructed_data['pelvis'][0, -1, 0].detach().cpu().numpy() + 1)
    ax.set_ylim(-0.5, 2)
    matplotlib.rcParams['animation.embed_limit'] = 200e6  # 200 MB, for to_jshtml()

    ax.set_aspect('equal')
    # Floor
    ax.plot([-2, 25], [0, 0], 'k')

    # With blit=True every artist that init/update touch must be returned.
    all_artists = lines + imu_points + background

    def init_anim():
        for artist in lines + background:
            artist.set_data([], [])
        return all_artists

    def update(i):
        t = i + start_idx
        frame = {}
        for source in (reconstructed_data, gc_model):
            for key in source.keys():
                frame[key] = source[key][0, t].detach().cpu()
        # Resolve the horizontal shift before drawing so every segment uses it.
        posx_0 = frame['pelvis'][0] if subtrans else 0
        for j, line in enumerate(background):
            line.set_data([j - posx_0, j - posx_0], [0, 1])

        count = 0
        for key, point in frame.items():
            if key == 'pelvis':
                # Trunk: 0.5-long line from the pelvis, tilted by point[6].
                lines[count].set_data(
                    [point[0] - posx_0, point[0] - posx_0 - torch.sin(point[6]) * 0.5],
                    [point[3], point[3] + torch.cos(point[6]) * 0.5])
                count += 1
                continue
            # Line from parent to segment.
            if key in kin_model:
                parent_key = kin_model[key]['parent']
            else:
                parent_key = k2d.gc_model[key]['parent']
            parent = frame[parent_key]
            lines[count].set_data([parent[0] - posx_0, point[0] - posx_0], [parent[3], point[3]])
            count += 1
            if key in ('r_toe', 'l_toe'):
                # Close the foot: connect the toe to the matching heel point.
                heel = frame[key.replace('toe', 'heel')]
                lines[count].set_data([heel[0] - posx_0, point[0] - posx_0], [heel[3], point[3]])
                count += 1
        return all_artists

    ani = FuncAnimation(fig, update, frames=(end_idx - start_idx), init_func=init_anim,
                        blit=True, repeat=True, interval=1000 / fps)

    if save:
        # Save the animation as a GIF (requires ImageMagick).
        ani.save('stick_2d.gif', writer='imagemagick')

    return HTML(ani.to_jshtml())



def plot_imu_pre_after(imu_reconstructed, imu_raw,cfg):
    """Plot simulated (solid) against raw (dashed) IMU signals.

    Channels are grouped three per subplot on a 4x2 grid (so at most 24 channels) and
    coloured r/g/b within a group. A legend is added once a group is complete.

    Args:
        imu_reconstructed: tensor indexed as ``[0, time, channel]``; moved to CPU here.
        imu_raw: array/tensor with the same ``[0, time, channel]`` indexing.
        cfg: config providing ``fps`` and ``datamodule.dataset_variables.IMU_data``
            (channel names; characters from index 4 onwards appear in the legend).
    """
    imu_reconstructed = imu_reconstructed.cpu().clone().detach()
    n_samples = imu_raw.shape[1]
    # Sample k is at k / fps seconds (linspace(0, n/fps, n) stretched the axis by n/(n-1)).
    x_ = np.arange(n_samples) / cfg.fps
    plt.figure(figsize=(5/3*imu_raw.shape[-1], 15))
    colors = ['r', 'g', 'b']
    ikey = cfg.datamodule.dataset_variables.IMU_data
    for idx, imu in enumerate(ikey):
        group, slot = divmod(idx, 3)
        plt.subplot(4, 2, group + 1)
        color = colors[slot]
        plt.plot(x_, imu_reconstructed[0, :, idx], color)
        plt.plot(x_, imu_raw[0, :, idx], f'{color}--')
        plt.xlabel('time in [s]')
        plt.ylabel('signal in [m/s^2 | 1/s]')
        # Explicit True: a bare plt.grid() toggles, leaving it off after an even number of calls.
        plt.grid(True)
        if slot == 2:
            plt.legend([f'{ikey[idx-2][4:]} sim', f'{ikey[idx-2][4:]} raw',
                        f'{ikey[idx-1][4:]} sim', f'{ikey[idx-1][4:]} raw',
                        f'{imu[4:]} sim', f'{imu[4:]} raw'])

def plot_kinematics(IK_data_sim, IK_data_gt, start_idx, cfg, offset):
    """Plot simulated kinematics, optionally against a ground-truth recording.

    Only every third entry of ``cfg.datamodule.dataset_variables.IK_data`` gets a panel
    on a 3x3 grid; these are the position slots 0, 3, 6, ... in the layout built by
    get_data_IK. Panel 0 shows channel 1 (horizontal pelvis velocity) instead of
    position. Channel 3 is plotted unscaled, and channels above 3 are converted from
    radians to degrees.

    Args:
        IK_data_sim: tensor indexed as ``[0, time, channel]``.
        IK_data_gt: array/tensor with the same channel layout, or None. Its window
            ``[start_idx + offset, start_idx + offset + T)`` is plotted, where T is
            the simulated length.
        start_idx, offset: sample offsets into IK_data_gt.
        cfg: config providing ``fps`` and ``datamodule.dataset_variables.IK_data``.
    """
    IK_data_sim = IK_data_sim.cpu().detach().numpy()
    plt.figure(figsize=(15, 5))
    siglen = IK_data_sim.shape[1]
    # Sample k is at k / fps seconds.
    x_ = np.arange(siglen) / cfg.fps
    window = slice(start_idx + offset, start_idx + siglen + offset)
    count = 1
    for idx, ik in enumerate(cfg.datamodule.dataset_variables.IK_data):
        if idx % 3 != 0:
            continue
        plt.subplot(3, 3, count)
        if idx == 0:
            # Velocity panel: compare like with like (previously gt position was
            # plotted here, and a missing `fac` crashed when gt was given).
            plt.plot(x_, IK_data_sim[0, :, 1])
            labels = ['v_sim (m/s)']
            if IK_data_gt is not None:
                plt.plot(x_, IK_data_gt[0, window, 1])
                labels.append('v_raw (m/s)')
        else:
            fac = 1 if idx <= 3 else 180 / np.pi
            plt.plot(x_, IK_data_sim[0, :, idx] * fac)
            labels = [f'{ik} sim']
            if IK_data_gt is not None:
                plt.plot(x_, IK_data_gt[0, window, idx] * fac)
                labels.append(f'{ik} raw')
        plt.xlabel('time in [s]')
        plt.ylabel('signal in [m | °]')
        plt.legend(labels)
        plt.gca().yaxis.set_minor_locator(plt.MultipleLocator(5))  # Minor ticks every 5 units
        plt.grid(True, which='major', axis='both', linestyle='-', linewidth=0.5)  # Major grid lines
        plt.grid(True, which='minor', axis='y', linestyle='--', linewidth=0.3)
        count += 1

def plot_kinetics(torques, cfg, id_df = None):
    """Plot estimated joint torques, optionally against OpenSim inverse dynamics.

    Args:
        torques: tensor indexed as ``[0, time, joint]``.
        cfg: config providing ``fps`` and ``model.estimated_variables.torques``
            (one subplot per name on a 2x3 grid).
        id_df: optional tuple ``(dataframe, mass, start_idx)``. The dataframe needs a
            ``time`` column and the OpenSim ``*_moment`` columns listed below. Its
            moments are divided by ``mass`` before plotting.
    """
    # Assumed to be in the same joint order as cfg.model.estimated_variables.torques.
    osim_torque_names = ['hip_flexion_r_moment', 'knee_angle_r_moment', 'ankle_angle_r_moment',
                         'hip_flexion_l_moment', 'knee_angle_l_moment', 'ankle_angle_l_moment']
    if id_df is not None:
        id_df, mass, start_idx = id_df
        # NOTE(review): start_idx / 100 assumes 100 Hz and +0.03 s is an empirical
        # alignment shift — confirm against cfg.fps.
        id_time = id_df['time'] + start_idx / 100 + 0.03
    torques = torques.cpu().detach().numpy()
    # Sample k is at k / fps seconds.
    x_ = np.arange(torques.shape[1]) / cfg.fps
    plt.figure(figsize=(15, 5))
    for idx, ik in enumerate(cfg.model.estimated_variables.torques):
        plt.subplot(2, 3, idx + 1)
        plt.plot(x_, torques[0, :, idx])
        if id_df is not None:
            plt.plot(id_time, id_df[osim_torque_names[idx]] / mass)
            plt.legend([f'{ik} IMU', f'{ik} ID'])
        else:
            plt.legend([f'{ik}'])
        plt.grid(True)
        plt.xlabel('time in [s]')
        plt.ylabel('torque in Nm/kg(BW)')

def plot_grf(grf, _ ,cfg, grf_gt = None, bodymass = None):
    """Plot ground reaction forces in body weights, optionally against measured GRF.

    Args:
        grf: tensor indexed as ``[0, time, channel]`` with 4 channels
            (x_r, y_r, x_l, y_l) or 8 channels (heel/toe pairs, summed per foot). It is
            divided by 9.81, i.e. presumably given in N/kg.
        _: unused positional argument (kept for call compatibility).
        cfg: config providing ``fps``.
        grf_gt: optional dataframe with a ``time`` column and the OpenSim force columns
            named in ``grf_gt_names``; it is shifted by +1.03 s before plotting.
        bodymass: body mass in kg; required when ``grf_gt`` is given.
    """
    grf = grf.cpu().detach().numpy() / 9.81  # in BW
    idx_heel = [0, 1, 4, 5]
    idx_toe = [2, 3, 6, 7]
    split_feet = grf.shape[-1] == 8
    grf_ = grf[:, :, idx_heel] + grf[:, :, idx_toe] if split_feet else grf
    plt.figure(figsize=(10, 5))
    # Sample k is at k / fps seconds.
    x_ = np.arange(grf.shape[1]) / cfg.fps
    grf_names = ['grf_x_r', 'grf_y_r', 'grf_x_l', 'grf_y_l']
    grf_gt_names = ['ground_force_1_vz', 'ground_force_1_vy', None, None]
    for i in range(4):
        plt.subplot(2, 2, i + 1)
        plt.plot(x_, grf_[0, :, i])
        labels = [grf_names[i]]
        if split_feet:
            plt.plot(x_, grf[0, :, idx_heel[i]], 'k--')
            plt.plot(x_, grf[0, :, idx_toe[i]], 'k-.')
            labels += ['heel', 'toe']
        gt_name = grf_gt_names[i]
        if grf_gt is not None and gt_name is not None:
            plt.plot(grf_gt.time + 1.03, grf_gt[gt_name] / bodymass / 9.81)
            labels.append(gt_name)
        # One legend with a list of labels (a bare string was split into characters).
        plt.legend(labels)
        plt.xlabel('time in [s]')
        plt.ylabel('signal in [BW]')


def plot_kanesloss(loss_k, pow=1):
    """Plot Kane's-equation residuals in three panels (pelvis, right leg, left leg).

    Args:
        loss_k: tensor indexed as ``[0, timestep, dof]`` with at least 9 DOFs, ordered
            pelvis x/y/angle, then right hip/knee/ankle, then left hip/knee/ankle.
            It may live on any device; it is moved to CPU here.
        pow: exponent applied element-wise before plotting.
    """
    loss_k = loss_k.detach().cpu().numpy()
    fig = plt.figure(figsize=(20, 5))
    panels = [
        (['pelvis_x', 'pelvis_y', 'pelvis_a'], 'loss in [N/kg(BW)]/[Nm/kg(BW)]'),
        (['hip_r', 'knee_r', 'ankle_r'], 'loss in [Nm/kg(BW)]'),
        (['hip_l', 'knee_l', 'ankle_l'], 'loss in [Nm/kg(BW)]'),
    ]
    for panel, (names, ylabel) in enumerate(panels):
        fig.add_subplot(1, 3, panel + 1)
        for k, name in enumerate(names):
            plt.plot(loss_k[0, :, 3 * panel + k] ** pow, label=name)
        plt.legend()
        plt.grid(True)
        plt.xlabel('timestep')
        plt.ylabel(ylabel)


def plot_euler(IK_data, fps=100):
    """Visually check that stored derivative channels match finite differences.

    For each of the 9 coordinate triplets (channels 3i, 3i+1, 3i+2) one subplot shows:
    black solid is diff(channel 3i) * fps and black dashed is channel 3i+1, both divided
    by the std of channel 3i+1. Red solid is diff(channel 3i+1) * fps and red dashed is
    channel 3i+2, both divided by the std of channel 3i+2.

    Args:
        IK_data: tensor indexed as ``[0, time, channel]`` with at least 27 channels.
        fps: sample rate used to scale the finite differences (default is the
            previously hard-coded 100).
    """
    a = IK_data.detach().cpu()
    fig = plt.figure(figsize=(20, 10))
    for i in range(9):
        fig.add_subplot(3, 3, i + 1)
        pos, vel, acc = a[0, :, 3 * i], a[0, :, 3 * i + 1], a[0, :, 3 * i + 2]

        vel_std = torch.std(vel)
        plt.plot(pos.diff(1) * fps / vel_std, 'k-')
        plt.plot(vel / vel_std, 'k--')

        acc_std = torch.std(acc)
        plt.plot(vel.diff(1) * fps / acc_std, 'r-')
        plt.plot(acc / acc_std, 'r--')
    plt.tight_layout()

def target_zones(trigger_value):
    """Return the target-zone label string for a trigger code in 7..12.

    Any other code raises KeyError.
    """
    labels = ("0.9-1.0", "1.2-1.4", "1.8-2.0", "3.0-3.3", "3.9-4.1", "4.7-4.9")
    zone_by_trigger = dict(enumerate(labels, start=7))
    return zone_by_trigger[trigger_value]