"""Logs language as table"""

import pandas as pd
import matplotlib.pyplot as plt
from ncc import compositionality_metrics


def cross_table(protocol):
    """Render a protocol as a matplotlib table of message counts.

    Each entry of ``protocol`` is a ``(labels, msg)`` pair, where ``labels``
    holds exactly two ``"name=value"`` strings. The resulting table has one
    row per value of the first label and one column per value of the second
    label; each cell holds a ``{msg: count}`` dict of the messages observed
    for that label combination (NaN where the combination never occurs).

    Args:
        protocol: Iterable of ``(labels, msg)`` pairs. ``msg`` must be
            hashable so that occurrences can be counted.

    Returns:
        The matplotlib ``Figure`` containing the table. The caller is
        responsible for showing, saving, or closing it.

    Raises:
        ValueError: If an entry does not carry exactly two labels.
        IndexError: If a label contains no ``'='``.
    """
    rows = []
    for labels, msg in protocol:
        if len(labels) != 2:
            raise ValueError(
                f'expected exactly 2 labels per message, got {len(labels)}: '
                f'{labels!r}')
        # Keep only the value part of "name=value"; maxsplit=1 preserves
        # values that themselves contain '='.
        first, second = (label.split('=', 1)[1] for label in labels)
        rows.append({'l1': first, 'l2': second, 'msg': msg})

    # Build the frame in one go: DataFrame.append was removed in pandas 2.0
    # and appending row by row was quadratic anyway.
    df = pd.DataFrame(rows, columns=['l1', 'l2', 'msg'])
    counts = df.groupby(['l1', 'l2']).agg(
        lambda x: x.value_counts().to_dict()).unstack()

    fig, ax = plt.subplots(figsize=(20, 5))
    ax.axis('off')

    n_rows, n_cols = counts.shape
    grey = 'xkcd:silver'
    table = pd.plotting.table(
        ax, counts, bbox=[0, 0, 1, 1],
        rowColours=[grey] * n_rows,
        colColours=[grey] * n_cols,
        # Uniform cell background; a per-cell grid leaves room to highlight
        # individual label combinations later.
        cellColours=[[grey] * n_cols for _ in range(n_rows)],
    )
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    return fig


def main():
    """Build a small example protocol and display it with ``cross_table``."""
    # NOTE(review): the meaning of these arguments is defined by
    # compositionality_metrics.get_protocol; confirm there.
    proto = compositionality_metrics.get_protocol(
        [(0, 0), (0, 0), (0, 1), (1, 0), (1, 1),
         (1, 2), (2, 0), (2, 1), (2, 2)],
        [(0, 0), (0, 0), (0, 0), (1, 0), (1, 1),
         (1, 2), (2, 0), (2, 1), (2, 2)],
        {'color': 'ups', 'shape': 'pups'},
    )
    cross_table(proto)
    # cross_table only builds the figure; without showing it the script
    # exits having produced no visible output.
    plt.show()


# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
