import numpy as np
from matplotlib import pyplot as plt

from tueplots import bundles, figsizes, cycler
from tueplots.constants.color import palettes
from matplotlib import pyplot as plt
from pudb import set_trace
# Global matplotlib styling: the tueplots ICML 2022 bundle with smaller
# label/tick/legend/title fonts, followed by the half-column figure size.
# The bundle carries its own 'figure.figsize', so the half-column size must be
# applied after it for it to take effect.
bundle = bundles.icml2022()
bundle.update({
    'font.size': 14,
    'axes.labelsize': 10,
    'legend.fontsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.titlesize': 10,
})
plt.rcParams.update(bundle)
plt.rcParams.update(figsizes.icml2022_half())


def print_data(entropy, state_buff):
    """Print an entropy value followed by the per-index (min, max) of the states.

    The number of reported indices is ``state_buff[0].shape[0]``; column ``j``
    is gathered as ``sample[j]`` from every sample in ``state_buff``. The
    bounds are stored as float64 before printing.
    """
    n_joints = state_buff[0].shape[0]
    columns = [[sample[j] for sample in state_buff] for j in range(n_joints)]
    lows = np.array([np.min(col) for col in columns], dtype=float)
    highs = np.array([np.max(col) for col in columns], dtype=float)

    print(f'Entropy: {entropy}')
    print('(Minimum, Maximum):')
    for j, (lo, hi) in enumerate(zip(lows, highs)):
        print(f'j_{j}: ({lo}, {hi})')

def _strict_bin_index(values, edges):
    """Return, per value, the 1-based bin ``i`` with ``edges[i-1] < v < edges[i]``, else 0.

    Comparisons are strict on both sides, so values lying exactly on an edge,
    NaN, or values outside ``(edges[0], edges[-1])`` get 0. ``edges`` must be
    sorted ascending.
    """
    idx = np.searchsorted(edges, values, side='left')
    # side='left' guarantees edges[idx-1] < v <= edges[idx]; additionally
    # requiring v < edges[idx] rejects exact hits on an edge.
    safe = np.clip(idx, 0, len(edges) - 1)
    inside = (idx >= 1) & (idx <= len(edges) - 1) & (values < edges[safe])
    return np.where(inside, idx, 0)


def new_entorp(state_packet, typed, muscle=False, *, grid_size=20,
               x_range=(-0.9, 1.1), y_range=(-1.1, 1.1)):
    '''
    State-coverage measure over (q0, q1) pairs.

    Takes entries 0 and 1 of every state in ``state_packet``, bins them on a
    grid whose edges are ``np.linspace(*x_range, grid_size)`` and
    ``np.linspace(*y_range, grid_size)``, and returns the fraction of grid
    cells visited at least once. How often a cell was visited does not matter.
    Points exactly on a grid edge, NaN, or outside the ranges are ignored.

    NOTE(review): ``grid_size`` edges give ``grid_size - 1`` intervals per
    axis, and row/column 0 of the grid can never be filled, yet the count is
    divided by ``grid_size ** 2``. The maximum reachable value is therefore
    ``(grid_size - 1) ** 2 / grid_size ** 2`` (0.9025 for the default 20).
    This is kept as-is so values stay comparable with earlier results.

    Args:
        state_packet: iterable of states; only ``s[0]`` and ``s[1]`` are used.
        typed: unused, kept for call compatibility.
        muscle: unused, kept for call compatibility.
        grid_size: number of grid edges per axis.
        x_range: (start, end) of the q0 axis, start < end.
        y_range: (start, end) of the q1 axis, start < end.

    Returns:
        float: number of visited cells divided by ``grid_size ** 2``.
        The value is also printed.
    '''
    N = grid_size
    # reshape keeps an empty packet as a (0, 2) array instead of shape (0,).
    state = np.array([[s[0], s[1]] for s in state_packet], dtype=float).reshape(-1, 2)
    x = np.linspace(x_range[0], x_range[1], N)
    y = np.linspace(y_range[0], y_range[1], N)
    grid = np.zeros([N, N])
    ix = _strict_bin_index(state[:, 0], x)
    iy = _strict_bin_index(state[:, 1], y)
    hit = (ix > 0) & (iy > 0)
    grid[ix[hit], iy[hit]] = 1
    entropy = np.count_nonzero(grid) / (N * N)
    print(entropy)
    return entropy

def compute_entropy(states, typed, muscle=False, *, n_packets=5):
    """Split ``states`` into equal consecutive packets and score each with ``new_entorp``.

    ``states`` is converted with ``np.array`` and reshaped to
    ``(n_packets, len(states) // n_packets, states.shape[-1])``. This assumes a
    2-D ``(T, D)`` input whose length is divisible by ``n_packets``; otherwise
    ``reshape`` raises ValueError.

    Args:
        states: sequence of state vectors.
        typed: forwarded to ``new_entorp``.
        muscle: forwarded to ``new_entorp``.
        n_packets: number of packets to split ``states`` into (default 5).

    Returns:
        list: one ``new_entorp`` value per packet, in order.
    """
    states = np.array(states)
    packets = states.reshape((n_packets, states.shape[0] // n_packets, states.shape[-1]))
    return [new_entorp(packet, typed, muscle) for packet in packets]


def _plot_column(column_axes, series, label, ylabel_prefix, ylim=None):
    """Plot component ``idx`` of every entry of ``series`` on the ``idx``-th axis of a column."""
    for idx, ax in enumerate(column_axes.flatten()):
        ax.plot([x[idx] for x in series], label=label)
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.set_ylabel(f'{ylabel_prefix}_{idx}')
        ax.set_xlabel('time')
        if idx == 0:
            # Only the top panel of each column carries the title.
            ax.set_title(label)


def plot(axs, axs2, axs4, kappa, states, actions, activations, kidx=None):
    """Plot states, actions and activations for one ``kappa`` into column ``kidx``.

    ``axs``, ``axs2`` and ``axs4`` are 2-D arrays of axes. Row ``idx`` of column
    ``kidx`` receives component ``idx`` of the states, actions and activations
    respectively. Actions and activations use a fixed y-range of [-0.1, 1.1].

    Args:
        axs, axs2, axs4: 2-D arrays of matplotlib axes for states, actions
            and activations.
        kappa: value shown in the line labels and column titles.
        states, actions, activations: sequences of per-step vectors.
        kidx: column index to draw into. If omitted, a module-level ``kidx``
            is used for backward compatibility.

    Raises:
        NameError: if ``kidx`` is omitted and no module-level ``kidx`` exists.
    """
    if kidx is None:
        # Older callers set a global ``kidx`` instead of passing it.
        try:
            kidx = globals()['kidx']
        except KeyError:
            raise NameError(
                "plot() needs a column index: pass kidx=... or define a module-level 'kidx'"
            ) from None
    label = f'{kappa=}'
    _plot_column(axs[:, kidx], states, label, 'j')
    _plot_column(axs2[:, kidx], actions, label, 'u', ylim=[-0.1, 1.1])
    _plot_column(axs4[:, kidx], activations, label, 'act', ylim=[-0.1, 1.1])


def plot_trajs(action_buff, state_buff, *, muscles=(5, 6, 7, 8, 9, 10), dt=0.025,
               max_steps=903, filename='actions_arm750.pdf'):
    """Plot selected action channels and all state components over time.

    Figure 1 has one row per entry of ``muscles``. Row ``idx`` shows
    ``(x[muscle] + 1) / 2`` for every action ``x`` against time ``k * dt``,
    truncated to the first ``max_steps`` steps, and the figure is saved to
    ``filename``. Figure 2 has one row per component of ``state_buff[0]``
    (y-range [-3.6, 3.6]) and is displayed with ``plt.show()``.

    Args:
        action_buff: sequence of per-step action vectors.
        state_buff: sequence of per-step state vectors.
        muscles: indices into each action vector to plot.
        dt: time between consecutive steps, in seconds.
        max_steps: number of leading steps to plot (``None`` for all).
        filename: output path for the actions figure.
    """
    # Derive time from integer step indices. np.arange with a float stop/step
    # can yield one sample more than len(action_buff) due to rounding, which
    # made x and y lengths disagree for buffers shorter than max_steps.
    time = (np.arange(len(action_buff)) * dt)[:max_steps]

    # squeeze=False keeps a 2-D axes array even for a single row.
    fig, axs = plt.subplots(len(muscles), 1, figsize=(6, 8), squeeze=False)
    axs = axs[:, 0]
    last = len(muscles) - 1
    for idx, muscle in enumerate(muscles):
        ax = axs[idx]
        activation = [(x[muscle] + 1) / 2 for x in action_buff][:max_steps]
        ax.plot(time, activation, color='tab:blue')
        ax.set_ylabel(fr'$\mathrm{{a}}_{{{idx + 1}}}$', fontsize=16)
        ax.set_ylim([-0.1, 1.1])
        ax.set_yticks([0.0, 0.5, 1.0])
        if idx != last:
            # Only the bottom panel shows time tick labels.
            ax.set_xticklabels([])
    axs[-1].set_xlabel('time (s)', fontsize=16)
    fig.tight_layout()
    fig.savefig(filename)

    n_states = len(state_buff[0])
    fig2, axs2 = plt.subplots(n_states, 1, squeeze=False)
    for idx, ax in enumerate(axs2[:, 0]):
        ax.plot([x[idx] for x in state_buff])
        ax.set_ylabel(f'j_{idx}')
        ax.set_ylim([-3.6, 3.6])
    plt.show()


def plot_ee_arm26(ax, state_buff, prefix, title='default', name='default', title_bool=False,
                  left=True, *, length=1001):
    """Overlay 2-D traces from ``state_buff`` on ``ax``, one faint grey line per episode.

    The last three entries of each state are taken as tracking data, and
    entries 0 and 1 of that slice are plotted as x and y. ``state_buff`` is cut
    into consecutive episodes of ``length`` steps; a trailing partial episode
    is dropped. Axis limits are x in [-0.8, 1.1] and y in [-0.25, 1.1], or
    y in [-1.1, 1.1] when ``prefix == 'torque'``. Ticks are removed. The number
    of states is printed.

    Args:
        ax: matplotlib axes to draw on.
        state_buff: sequence of state vectors.
        prefix: ``'torque'`` selects the wider y-range.
        title, name, title_bool, left: unused, kept for call compatibility.
        length: steps per episode.
    """
    trackings = [s[-3:] for s in state_buff]
    print(len(trackings))
    for ep in range(len(trackings) // length):
        segment = trackings[ep * length:(ep + 1) * length]
        ax.plot([p[0] for p in segment], [p[1] for p in segment],
                alpha=0.1, color='grey', linewidth=0.1)
    ax.set_xlim([-0.8, 1.1])
    ax.set_ylim([-1.1, 1.1] if prefix == 'torque' else [-0.25, 1.1])
    ax.set_xticks([])
    ax.set_yticks([])


def compute_mc(states, action, next_states):
    """Placeholder with no implementation yet; ignores its arguments and returns None."""
    return None


def discretize(array, w_bins):
    """Discretise each column of ``array`` independently into ``w_bins`` equal-width bins.

    Args:
        array: 2-D array-like; each column is binned over its own [min, max].
            A 1-D input is treated as a single column.
        w_bins: number of bins per column.

    Returns:
        np.ndarray of int with the same shape as ``array``, values in
        ``0 .. w_bins - 1``.
    """
    array = np.asarray(array)
    if array.ndim == 1:
        return discretiseMatrix(array, w_bins)
    discrete_array = np.zeros(array.shape, dtype=int)
    for idx in range(array.shape[1]):
        discrete_array[:, idx] = discretiseMatrix(array[:, idx], w_bins)
    # Return the result (it used to be computed and then discarded).
    return discrete_array


def discretiseMatrix(data, bins):
    """Map values to equal-width bin indices ``0 .. bins - 1`` over ``[min(data), max(data)]``.

    Constant input (max == min) maps every value to bin 0; empty input gives an
    empty int array.
    """
    data = np.asarray(data, dtype=float)
    if data.size == 0:
        return np.zeros(data.shape, dtype=int)
    lo, hi = np.min(data), np.max(data)
    span = hi - lo
    if span == 0:
        # Avoid 0/0 when all values are equal.
        return np.zeros(data.shape, dtype=int)
    # Scale by ``bins`` before flooring and floor element-wise. Flooring the
    # [0, 1] ratio first gives only 0 or ``bins``, and int() fails on arrays.
    indices = np.floor((data - lo) / span * bins).astype(int)
    # The maximum maps to exactly ``bins``; clamp it into the last bin.
    return np.minimum(bins - 1, indices)


