import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.lines import Line2D
import scipy.signal

from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes

# Apply the tueplots NeurIPS 2022 rcParams bundle to every figure in this script.
plt.rcParams.update(bundles.neurips2022())

# Switch for the foot-deviation figure (foot_deviations.pdf); set to False to skip it.
DEVIATION = True

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    """Scale the ``'figure.figsize'`` entry of an rcParams-style dict in place.

    Args:
        fig_dict: Mapping holding a ``'figure.figsize'`` (width, height) pair.
            The entry is overwritten with a new scaled tuple.
        fraction_width: Factor applied to the figure width.
        fraction_height: Factor applied to the figure height (default 1,
            i.e. height unchanged).

    Returns:
        The very same ``fig_dict`` object, so the call can be passed straight
        to ``plt.rcParams.update``.
    """
    current = fig_dict['figure.figsize']
    new_width = current[0] * fraction_width
    new_height = current[1] * fraction_height
    fig_dict['figure.figsize'] = (new_width, new_height)
    return fig_dict

# Sample window [a, b) sliced out of every recorded array before plotting.
a = 800  # first sample index loaded
b = 1000  # one past the last sample index loaded
width = 1  # line width of the foot-deviation traces

# Per-method settings, in the same order as the panels of each figure.
titles = ['DEP-MPO', 'MPO', 'TD4']
colors = ['tab:orange', 'tab:blue', 'tab:green']
# Darker companion RGB colours (0..1 scale), used for the right-foot data.
colors_twin = [
    [204 / 255, 102 / 255, 1 / 255],
    [27 / 255, 0, 162 / 255],
    [0, 102 / 255, 0],
]

if DEVIATION:
    # Figure: left/right foot x-deviation over time, one panel per method.
    fig_dict = figsizes.neurips2022(ncols=3)
    plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.7, fraction_height=1.0))
    print(fig_dict)
    fig, axs = plt.subplots(1, 3, sharey=True)

    dt = 0.025  # seconds per recorded sample
    # One tick per whole second of the window, labelled with absolute time
    # (a * dt onwards) so the axis stays correct when a/b are changed.
    tick_pos = np.arange(0, round((b - a) * dt) + 1)
    tick_labels = [f'{a * dt + t:g}' for t in tick_pos]

    for i, descriptor in enumerate(['depmpo_fast', 'mpo_fast', 'td4_fast']):
        foot_left = np.load(f'foot_deviation_left_{descriptor}_ever.npy')[a:b]
        foot_right = np.load(f'foot_deviation_right_{descriptor}_ever.npy')[a:b]
        x = np.arange(foot_left.shape[0]) * dt  # time since window start (s)
        ax = axs[i]
        ax.plot(x, foot_left, label='left foot', color=colors[i], linewidth=width)
        ax.plot(x, foot_right, '--', label='right foot', color=colors_twin[i], linewidth=width)
        ax.set_ylim([-0.76, 0.65])
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(tick_labels)
        if i == 0:
            # y axis is shared, so only the leftmost panel carries the label.
            ax.set_ylabel('rel. foot-x (m)')
        ax.set_xlabel('time (s)')
    plt.savefig('foot_deviations.pdf')

def _contact_intervals(ground):
    """Return half-open ``(start, stop)`` index pairs of consecutive contact samples.

    A sample equal to 1 opens (or continues) an interval and a sample equal to
    0 closes it; other values leave the state unchanged. An interval still
    open at the end of ``ground`` is closed at ``len(ground)`` so the final
    contact phase in the window is not dropped.
    """
    intervals = []
    start_index = None
    for i, data in enumerate(ground):
        if data == 1 and start_index is None:
            start_index = i
        elif data == 0 and start_index is not None:
            intervals.append((start_index, i))
            start_index = None
    if start_index is not None:
        intervals.append((start_index, len(ground)))
    return intervals


# Figure: ground-contact pattern per foot, one row of axes per method.
fig_dict = figsizes.neurips2022(nrows=3)
# change_figsize mutates fig_dict in place and returns it, so one update suffices.
plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.3, fraction_height=0.25))
print(fig_dict)
fig, axs = plt.subplots(3, 1, sharex=True)

dt = 0.025  # seconds per recorded sample
samples_per_tick = round(1 / dt)  # one x tick per second of data
tick_pos = np.arange(0, b - a + samples_per_tick, samples_per_tick)
# Label ticks with absolute time so they stay correct when a/b are changed.
tick_labels = [f'{(a + t) * dt:g}' for t in tick_pos]

for ax_idx, descriptor in enumerate(['depmpo_fast', 'mpo_fast', 'td4_fast']):
    ax = axs[ax_idx]
    ground_left = np.load(f'ground_left_{descriptor}_ever.npy')[a:b]
    ground_right = np.load(f'ground_right_{descriptor}_ever.npy')[a:b]
    # Left foot on the upper row (beside the 'LF' label at y=1.2), right foot
    # on the lower row (beside the 'RF' label at y=0.1).
    rows = [
        (ground_left, 1.1, colors[ax_idx]),
        (ground_right, 0, colors_twin[ax_idx]),
    ]
    for ground, y0, color in rows:
        boxes = [
            Rectangle((start, y0), stop - start, 1.0)
            for start, stop in _contact_intervals(ground)
        ]
        ax.add_collection(PatchCollection(boxes, facecolor=color, alpha=1., edgecolor='none'))

    # Axis setup and labels are per axes, not per foot, so they are applied once.
    ax.set_xlim([0, b - a])
    ax.set_ylim([0, 2.0])
    ax.set_yticks([])
    if ax_idx != 2:
        ax.set_xticks([])
        label_x = -22
    else:
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel('time (s)')
        label_x = -22.2
    ax.text(label_x, 1.2, 'LF', fontsize=8)
    ax.text(label_x, 0.1, 'RF', fontsize=8)
plt.savefig('foot_patterns.pdf')

# Figure: foot trajectory (horizontal vs. vertical position), one panel per method.
fig_dict = figsizes.neurips2022(ncols=3)
# change_figsize mutates fig_dict in place and returns it, so one update suffices.
plt.rcParams.update(change_figsize(fig_dict, fraction_width=0.7, fraction_height=1.105))
fig, axs = plt.subplots(1, 3, sharey=True)
for ax_idx, descriptor in enumerate(['depmpo_fast', 'mpo_fast', 'td4_fast']):
    ax = axs[ax_idx]
    # The horizontal coordinate comes from the foot_deviation_* files (the same
    # data as the deviation figure); the vertical one from foot_locus_*_z.
    fl_lx = np.load(f'foot_deviation_left_{descriptor}_ever.npy')[a:b]
    fl_lz = np.load(f'foot_locus_left_z_{descriptor}_ever.npy')[a:b]
    fl_rx = np.load(f'foot_deviation_right_{descriptor}_ever.npy')[a:b]
    fl_rz = np.load(f'foot_locus_right_z_{descriptor}_ever.npy')[a:b]
    ax.plot(fl_lx, fl_lz, color=colors[ax_idx])
    ax.plot(fl_rx, fl_rz, '--', color=colors_twin[ax_idx])
    ax.set_xlim([-1.0, 0.7])
    ax.set_ylim([-0.11, 1.0])
    if ax_idx == 0:
        # y axis is shared, so only the leftmost panel carries the label.
        ax.set_ylabel('foot  - z  (m)')
    # NOTE(review): x data is the relative foot deviation; confirm this label
    # should not read like the 'rel. foot-x (m)' label of the deviation figure.
    ax.set_xlabel('foot  - x (m)')
plt.savefig('foot_locus.pdf')
# twin_colors = [[27/255, 0, 162/255],[138/255, 88/255, 1/255]]
# colors = ['tab:blue', 'tab:orange']
# speeds = []
# for descriptor, color, twin_color in zip(['best_depmpo_fast', 'best_dep_depmpo_fast'], colors, twin_colors):
#     fp_l = np.load(f'{descriptor}_foot_pend_left.npy')
#     fp_r = np.load(f'{descriptor}_foot_pend_right.npy')
#     fl_lx = np.load(f'{descriptor}_foot_locus_left_x.npy')
#     fl_lz = np.load(f'{descriptor}_foot_locus_left_z.npy')
#     fl_rx= np.load(f'{descriptor}_foot_locus_right_x.npy')
#     fl_rz = np.load(f'{descriptor}_foot_locus_right_z.npy')
#     speed = np.load(f'{descriptor}_speed.npy')
#     
#     fig, ax = plt.subplots(1,1)
#     ax.plot(fp_l, color=color)
#     ax.plot(fp_r, '--', color=twin_color)
#     plt.show()
#     fig, ax = plt.subplots(1,1)
#     ax.plot(fl_lx, fl_lz, color=color)
#     ax.plot(fl_rx, fl_rz, '--', color=twin_color)
#     plt.show()
#     speeds.append(np.mean(speed))
# fig, ax = plt.subplots(1, 1)
# x = np.arange(len(speeds))
# bars = ax.bar(x, speeds)
# for color, bar in zip(colors, bars):
#     bar.set_color(color)
# plt.show()
    

