import numpy as np
from matplotlib import pyplot as plt
from pudb import set_trace

from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes
# Global matplotlib style, applied once at import time and affecting every
# figure created afterwards in this process:
#   1. tueplots' ICML 2022 bundle of rcParams,
#   2. the `icml2022_half` figure size (per its name, presumably a half-width
#      figure - confirm against the tueplots docs),
#   3. a color cycle built from the `palettes.tue_plot` colors.
# Later updates override earlier ones for any overlapping rcParams keys.
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_half())
plt.rcParams.update(cycler.cycler(color=palettes.tue_plot))
# Optional: uncomment to raise the figure DPI (e.g. for on-screen inspection).
#plt.rcParams.update({"figure.dpi": 150})


def print_data(entropy, state_buff):
    """Print ``entropy`` and the per-component (min, max) range of ``state_buff``.

    Args:
        entropy: Printed verbatim on the first line.
        state_buff: Non-empty sequence of indexable states; the number of
            components examined is taken from ``state_buff[0].shape[0]``
            (so the first entry must expose ``.shape``).

    Output lines look like ``j_<i>: (<min>, <max>)``, one per component.
    """
    n_joints = state_buff[0].shape[0]
    lows = np.zeros(n_joints)
    highs = np.zeros(n_joints)
    for joint in range(n_joints):
        # Gather component `joint` across all stored states.
        column = [sample[joint] for sample in state_buff]
        highs[joint] = np.max(column)
        lows[joint] = np.min(column)

    print(f'Entropy: {entropy}')
    print('(Minimum, Maximum):')
    for joint, (lo, hi) in enumerate(zip(lows, highs)):
        print(f'j_{joint}: ({lo}, {hi})')


def compute_entropy_shoulder(states, *, n_edges=20, low=-3.0, high=3.0):
    """Return the sum over state components of each component's histogram entropy.

    For every component ``d`` (``0 <= d < states[0].shape[0]``), the values
    ``s[d]`` for ``s`` in ``states`` are binned using ``n_edges`` evenly
    spaced edges over ``[low, high]``, i.e. ``n_edges - 1`` bins. A value is
    counted in a bin only when it lies *strictly* between that bin's edges.
    Values equal to an edge, or outside ``(low, high)``, are therefore ignored.

    The per-component entropy is ``-sum(p * log(p + 1e-5))`` over the
    empirical bin probabilities ``p``. The ``1e-5`` keeps ``log`` finite
    for empty bins. As a side effect, a single fully-occupied bin yields a
    tiny negative value (``-log(1 + 1e-5)``) rather than exactly 0.

    Args:
        states: Sequence of indexable states; the number of components is
            taken from ``states[0].shape[0]``.
        n_edges: Number of bin edges (default 20, as before).
        low: Lower end of the binned range (default -3).
        high: Upper end of the binned range (default 3).

    Returns:
        The summed entropy. It is 0.0 for empty ``states``. A component with
        no value inside any bin contributes 0 (previously this produced NaN
        through a 0/0 division).
    """
    if len(states) == 0:
        return 0.0

    edges = np.linspace(low, high, n_edges)
    lower, upper = edges[:-1], edges[1:]

    total_entr = 0.0
    for dim in range(states[0].shape[0]):
        # Column vector so it broadcasts against the (n_bins,) edge arrays:
        # inside[i, b] is True iff sample i falls strictly inside bin b.
        values = np.asarray([s[dim] for s in states], dtype=float)[:, None]
        inside = (values > lower) & (values < upper)
        counts = inside.sum(axis=0)

        total_count = counts.sum()
        if total_count == 0:
            # No in-range samples: empty distribution, contributes no entropy.
            continue

        p = counts / total_count
        total_entr -= float(np.sum(p * np.log(p + 1e-5)))
    return total_entr

def compute_mc(states, action, next_states):
    """Placeholder that is not implemented yet.

    All three arguments are accepted and ignored; the function always
    returns ``None``.
    """
    return None


def _plot_column(axes, series, prefix, kappa, ylim=None):
    """Plot component ``i`` of every entry of ``series`` onto ``axes[i]``, labelled ``<prefix>_<i>``."""
    for idx, ax in enumerate(axes):
        ax.plot([x[idx] for x in series], label=f'{kappa=}')
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.set_ylabel(f'{prefix}_{idx}')
        ax.set_xlabel('time')
        # Only the top axis of the column carries the kappa title.
        if idx == 0:
            ax.set_title(f'{kappa=}')


def plot(axs, axs2, axs4, kappa, states, actions, activations, kidx=None):
    """Draw states, actions and activations for one ``kappa`` into column ``kidx``.

    Args:
        axs, axs2, axs4: 2-D arrays of axes (rows x columns). Row ``i`` of
            column ``kidx`` receives component ``i`` of the states, actions
            and activations respectively.
        kappa: Value shown in the legend label and in the top-row titles.
        states: Sequence of state vectors (labelled ``j_<i>``).
        actions: Sequence of action vectors (labelled ``u_<i>``; y-range fixed
            to [-0.1, 1.1]).
        activations: Sequence of activation vectors (labelled ``act_<i>``;
            y-range fixed to [-0.1, 1.1]).
        kidx: Column index to draw into. The original code read an undefined
            name ``kidx`` and raised ``NameError`` unless a module-level
            ``kidx`` had been injected. When omitted, such a module-level
            value is still honoured; otherwise column 0 is used.
    """
    if kidx is None:
        kidx = globals().get('kidx', 0)
    _plot_column(axs[:, kidx].flatten(), states, 'j', kappa)
    _plot_column(axs2[:, kidx].flatten(), actions, 'u', kappa, ylim=[-0.1, 1.1])
    _plot_column(axs4[:, kidx].flatten(), activations, 'act', kappa, ylim=[-0.1, 1.1])


def _plot_components(buff, prefix, ylim):
    """Open a figure with one row per component of ``buff[0]`` and plot each component over time."""
    n_rows = len(buff[0])
    # squeeze=False always yields a 2-D array of Axes. With the default,
    # n_rows == 1 returns a bare Axes object and indexing it would fail.
    _, axes = plt.subplots(n_rows, 1, squeeze=False)
    for idx in range(n_rows):
        ax = axes[idx, 0]
        ax.plot([x[idx] for x in buff])
        ax.set_ylabel(f'{prefix}_{idx}')
        ax.set_ylim(ylim)


def plot_trajs(action_buff, state_buff):
    """Show action and state trajectories in two figures, one subplot row per component.

    Args:
        action_buff: Sequence of action vectors; row ``i`` is labelled
            ``a_<i>`` with y-range [-1.1, 1.1].
        state_buff: Sequence of state vectors; row ``i`` is labelled
            ``j_<i>`` with y-range [-3.6, 3.6].

    Works for 1-component vectors too (previously an indexing error).
    Blocks in ``plt.show()`` as before.
    """
    _plot_components(action_buff, 'a', [-1.1, 1.1])
    _plot_components(state_buff, 'j', [-3.6, 3.6])
    plt.show()
            
            
def plot_ee_arm26(state_buff, env, episodes, title='default', name='default'):
    """Plot one end-effector path per episode, save it as PDF and show it.

    Args:
        state_buff: Flat sequence of states from consecutive episodes. From
            each state the last three entries are taken, and of those the
            first two are plotted as (x, y). Presumably these are end-effector
            coordinates, per the axis labels; confirm with the env.
        env: Object exposing ``max_episode_steps``. Each episode is assumed to
            span ``max_episode_steps + 1`` consecutive states; a trailing
            partial episode is dropped.
        episodes: Currently unused.
        title: Prefix of the plot title.
        name: Inserted into the output file name ``ee_<name>_arm26.pdf``.

    Side effects: prints the number of states, writes the PDF to the
    current directory, then calls ``plt.show()``.
    """
    _, ax = plt.subplots(1, 1)
    episode_len = env.max_episode_steps + 1
    ee_points = [state[-3:] for state in state_buff]
    print(len(ee_points))

    n_full_episodes = len(ee_points) // episode_len
    segments = [
        ee_points[k * episode_len:k * episode_len + episode_len]
        for k in range(n_full_episodes)
    ]
    for segment in segments:
        ax.plot([p[0] for p in segment], [p[1] for p in segment], alpha=0.3)

    ax.set_title(f'{title} - arm2dof6m')
    ax.set_xlim([-0.8, 1.1])
    ax.set_ylim([-0.25, 1.1])
    ax.set_xlabel('endeffector - x')
    ax.set_ylabel('endeffector - y')
    plt.savefig(f'ee_{name}_arm26.pdf')

    plt.show()

