import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np

def show_stat(dict_stat, figname):
    """Save an 8x8-inch bar chart of per-class counts to ``figname``.

    Args:
        dict_stat: mapping from class label to its (numeric) count; one bar
            is drawn per entry, in the dict's iteration order.
        figname: output path handed to ``Figure.savefig``.

    The current pyplot figure is cleared and reused for the chart.
    """
    # Size the figure explicitly: rcParams["figure.figsize"] only applies to
    # figures created afterwards, so it cannot resize the figure reused here.
    fig = plt.gcf()
    fig.clf()
    fig.set_size_inches(8, 8)
    ax = fig.add_subplot(111)

    df_cnt = pd.DataFrame(list(dict_stat.items()), columns=['class', 'count'])

    sns.barplot(data=df_cnt, x='class', y='count', ax=ax)
    # Tilt class names so long labels do not collide.
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, horizontalalignment='right')

    fig.savefig(figname)


def show_split(verb_noun_iid, verb_noun_ood, figname='fig/split.png', save=True):
    """Draw a verb-by-noun heatmap contrasting two splits (IID vs. OOD).

    Args:
        verb_noun_iid: dict keyed by ``(verb, noun)`` pairs of strings; each
            value must support ``len()``, which is the only thing used from it.
        verb_noun_ood: same layout as ``verb_noun_iid`` for the other split.
        figname: output path used when ``save`` is true.
        save: if true, save the figure to ``figname`` (300 dpi, tight bbox)
            and return ``None``; otherwise return the heatmap ``Axes``.

    Each cell holds ``len(iid[pair]) - len(ood[pair])``, a pair missing from
    a split counting as 0, so positive cells lean IID and negative cells lean
    OOD. The colour scale is fixed to ``[-3, 3]``. The current pyplot figure
    is cleared and resized to ``(n_nouns, n_verbs)`` inches.
    """
    # Union of both splits' pairs, ordered by the joined "verb_noun" string.
    pairs = sorted({**verb_noun_iid, **verb_noun_ood},
                   key=lambda pair: pair[0] + '_' + pair[1])

    # Rows (verbs) and columns (nouns) are numbered by first appearance in
    # that ordering.
    verblist = list(dict.fromkeys(pair[0] for pair in pairs))
    nounlist = list(dict.fromkeys(pair[1] for pair in pairs))
    verb_idx = {verb: i for i, verb in enumerate(verblist)}
    noun_idx = {noun: j for j, noun in enumerate(nounlist)}
    cnt_verb, cnt_noun = len(verblist), len(nounlist)

    mat_iid = np.zeros((cnt_verb, cnt_noun))
    mat_ood = np.zeros((cnt_verb, cnt_noun))
    for split, mat in ((verb_noun_iid, mat_iid), (verb_noun_ood, mat_ood)):
        for key, value in split.items():
            mat[verb_idx[key[0]], noun_idx[key[1]]] = len(value)
    mat_split = mat_iid - mat_ood

    # Clear and resize the current figure explicitly. Drawing on plt.gca()
    # alone would overlay any plot already in it, and rcParams["figure.figsize"]
    # only applies to figures created later.
    fig = plt.gcf()
    fig.clf()
    fig.set_size_inches(cnt_noun, cnt_verb)
    ax = fig.add_subplot(111)

    rdgn = sns.diverging_palette(h_neg=10, h_pos=240, sep=1, as_cmap=True)
    sns.heatmap(mat_split, cmap=rdgn, vmin=-3, vmax=3, square=True, cbar=False,
                xticklabels=nounlist, yticklabels=verblist, ax=ax)

    # Shrink the text as the grid grows.
    fs = 18 - 0.125 * max(cnt_verb, cnt_noun)

    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, horizontalalignment='center', fontsize=fs)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=fs)
    ax.set_xlabel('object class', fontsize=fs)
    ax.set_ylabel('action class', fontsize=fs)

    if save:
        fig.savefig(figname, bbox_inches='tight', dpi=300)
    else:
        return ax