import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# Model identifiers. `query_models` labels the rows and `task_models` the
# columns of every per-dataset result table; both are also interpolated into
# the result CSV file names.
query_models = ['LR', 'RBFSVM', 'RandomForest']
task_models = list(query_models)  # same names, but an independent list object

# Dataset names; each one is used as a key of the result dictionaries below
# and as the prefix of its result CSV file name.
datasets = [
    "appendicitis", "sonar", "parkinsons", "ex8b", "heart", "haberman",
    "ionosphere", "clean1", "breast", "wdbc", "australian", "diabetes",
    "mammographic", "ex8a", "tic", "german", "splice", "gcloudb",
    "gcloudub", "checkerboard", "spambase", "banana", "phoneme",
    "ringnorm", "twonorm", "phishing",
]

# dataset name -> DataFrame (query model x task model) of mean / std AUBC.
aubc_avg_3x3, aubc_std_3x3 = {}, {}

# Reported mean AUBC values, keyed by (dataset, query model, task model).
# Only means are recorded here (no standard deviations), and there is no
# (RBFSVM, RBFSVM) entry for any dataset.
#
# Each row below lists the seven values of one dataset in this column order:
#   (LR, LR), (LR, RBFSVM), (LR, RandomForest),
#   (RBFSVM, LR), (RBFSVM, RandomForest),
#   (RandomForest, LR), (RandomForest, RBFSVM)
reported_aubc_avg = {
    (dataset, q_model, t_model): value
    for dataset, row in (
        ('parkinsons',   (0.8366, 0.8467, 0.8446, 0.8440, 0.8603, 0.8259, 0.8491)),
        ('ex8b',         (0.9067, 0.8920, 0.8675, 0.9056, 0.8692, 0.9032, 0.8965)),
        ('heart',        (0.8165, 0.8098, 0.8030, 0.8174, 0.8067, 0.8149, 0.8121)),
        ('clean1',       (0.7949, 0.8287, 0.8072, 0.7857, 0.8165, 0.7875, 0.8387)),
        ('wdbc',         (0.9703, 0.9581, 0.9524, 0.9698, 0.9518, 0.9692, 0.9622)),
        ('australian',   (0.8542, 0.8450, 0.8581, 0.8542, 0.8587, 0.8541, 0.8472)),
        ('diabetes',     (0.7606, 0.7418, 0.7423, 0.7592, 0.7442, 0.7598, 0.7450)),
        ('mammographic', (0.8287, 0.8117, 0.7984, 0.8223, 0.8018, 0.8222, 0.8145)),
        ('ex8a',         (0.6889, 0.8251, 0.8906, 0.6789, 0.9280, 0.6733, 0.8623)),
        ('german',       (0.7449, 0.7394, 0.7440, 0.7462, 0.7459, 0.7449, 0.7409)),
        ('splice',       (0.7694, 0.8125, 0.9136, 0.7676, 0.9181, 0.7621, 0.8159)),
        ('gcloudub',     (0.9569, 0.9413, 0.9355, 0.9538, 0.9468, 0.9542, 0.9508)),
        ('spambase',     (0.9224, 0.9037, 0.9358, 0.9170, 0.9398, 0.9120, 0.9134)),
        ('phoneme',      (0.7566, 0.8233, 0.8699, 0.7161, 0.8802, 0.7325, 0.8293)),
        ('ringnorm',     (0.7580, 0.9766, 0.9442, 0.7499, 0.9454, 0.6833, 0.9780)),
        ('phishing',     (0.9279, 0.9380, 0.9587, 0.9268, 0.9652, 0.9257, 0.9439)),
    )
    for (q_model, t_model), value in zip(
        (
            ('LR', 'LR'),
            ('LR', 'RBFSVM'),
            ('LR', 'RandomForest'),
            ('RBFSVM', 'LR'),
            ('RBFSVM', 'RandomForest'),
            ('RandomForest', 'LR'),
            ('RandomForest', 'RBFSVM'),
        ),
        row,
    )
}

def _read_aubc_stats(csv_path):
    """Return ``(mean, std)`` of the ``res_tst_score`` column of a result CSV.

    Raises whatever ``pd.read_csv`` or the column lookup raises (for example
    ``FileNotFoundError`` for a missing file or ``KeyError`` for a missing
    column); callers decide whether a fallback applies.
    """
    scores = pd.read_csv(csv_path)['res_tst_score']
    return scores.mean(), scores.std()


# For every dataset, fill a table (rows: query model, columns: task model)
# with the mean and the standard deviation of the recorded test scores.
for data in datasets:
    aubc_avg_3x3[data] = pd.DataFrame(index=query_models, columns=task_models)
    aubc_std_3x3[data] = pd.DataFrame(index=query_models, columns=task_models)
    for q_model in query_models:
        for t_model in task_models:
            if q_model == 'RBFSVM' and t_model == 'RBFSVM':
                # This pair is read from a differently named file and has no
                # reported fallback value, so a read failure is not caught.
                aubc_avg, aubc_std = _read_aubc_stats(
                    f'aubc_detail-model_compatibility/{data}-margin-zhan-google-zhan-zhan-RS_noFix_scale-aubc.csv')
            else:
                try:
                    aubc_avg, aubc_std = _read_aubc_stats(
                        f'aubc_detail-model_compatibility/{data}-margin-zhan-{q_model}-{t_model}-RS_noFix_scale-aubc.csv')
                except (OSError, KeyError, ValueError):
                    # Result file missing or unusable: OSError covers a
                    # missing/unreadable file, ValueError covers pandas'
                    # empty-file and parser errors, KeyError a missing
                    # score column. Fall back to the reported mean; it is
                    # None when no value was reported for this combination.
                    # No spread is known for reported values, so 0 is stored.
                    aubc_avg = reported_aubc_avg.get((data, q_model, t_model))
                    aubc_std = 0
            aubc_avg_3x3[data].loc[q_model, t_model] = aubc_avg
            aubc_std_3x3[data].loc[q_model, t_model] = aubc_std

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """Draw ``data`` as a labelled image grid and return the image.

    Parameters
    ----------
    data
        Two-dimensional array of cell values. It is passed to ``imshow`` and
        its ``shape`` determines the tick positions.
    row_labels, col_labels
        Labels for the rows (y ticks) and the columns (x ticks).
    ax
        Axes to draw on. When omitted, the current axes (``plt.gca()``) are
        used.
    cbar_kw, cbarlabel
        Kept in the signature but currently unused: no colorbar is created
        by this function.
    **kwargs
        Forwarded unchanged to ``imshow``.

    Returns
    -------
    tuple
        ``(im, cbar)`` where ``im`` is the object returned by ``imshow`` and
        ``cbar`` is always ``None``.
    """
    target = plt.gca() if ax is None else ax

    # Render the cells.
    image = target.imshow(data, **kwargs)

    n_rows = data.shape[0]
    n_cols = data.shape[1]

    # One major tick per column / row, labelled with the given names.
    target.set_xticks(np.arange(n_cols), labels=col_labels, fontsize=18)
    target.set_yticks(np.arange(n_rows), labels=row_labels, fontsize=18)

    # Put the column labels above the grid instead of below it.
    target.tick_params(top=True, bottom=False,
                       labeltop=True, labelbottom=False)

    # Hide the frame; the white minor grid drawn on the cell borders
    # separates the cells visually.
    target.spines[:].set_visible(False)

    target.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    target.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    target.grid(which="minor", color="w", linestyle='-', linewidth=3)
    target.tick_params(which="minor", bottom=False, left=False)

    return image, None

def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """Write the value of every cell of a heatmap image into that cell.

    Parameters
    ----------
    im
        Image to annotate; its ``norm``, ``axes`` and (when ``data`` is not
        given) ``get_array()`` are used.
    data
        Values to print. A list or ``numpy.ndarray`` is used as given (a list
        is converted to an array first); anything else, including ``None``,
        means "use ``im.get_array()``".
    valfmt
        Either a format string using ``x`` as the field name (wrapped in a
        ``matplotlib.ticker.StrMethodFormatter``) or a callable invoked as
        ``valfmt(value, None)``.
    textcolors
        Pair of colors: the first is used for cells whose normalized value is
        at or below the threshold, the second for cells above it.
    threshold
        Value in data units that separates the two text colors. Defaults to
        half of the normalized maximum of ``data``.
    **textkw
        Extra keyword arguments for every created text; they may override
        the default centered alignment, but ``color`` is always set from
        ``textcolors``.

    Returns
    -------
    list
        The created text objects, in row-major order.
    """
    if isinstance(data, list):
        # A plain list has no .max()/.shape and no [i, j] indexing, all of
        # which are needed below.
        data = np.asarray(data)
    elif not isinstance(data, np.ndarray):
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Center the text by default, but let textkw override the alignment.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Turn a format string into a formatter callable.
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Create one text per cell; its color depends on which side of the
    # threshold the cell's normalized value falls.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

# Default size for text that does not set its own size. Set once, before any
# figure is created, so every figure is drawn with the same setting.
plt.rcParams.update({'font.size': 24})

# Display labels for the rows (query models) and columns (task models).
# They are positional, so they must stay in the same order as the index and
# columns of the result tables.
idx_name = ['LR(C=1)', 'SVM(RBF)', 'RF']
col_name = ['LR(C=1)', 'SVM(RBF)', 'RF']

# For every dataset: print its table and best model pair, then save the table
# as an annotated heatmap image.
for data in aubc_avg_3x3:
    print(data)
    print(aubc_avg_3x3[data])

    # The tables hold objects (missing entries may be None); convert once.
    avg_scores = aubc_avg_3x3[data].astype(float)
    bst_q_model = avg_scores.max(axis=1).idxmax()
    bst_t_model = avg_scores.max().idxmax()
    # Sanity check: the best row and the best column meet at the global maximum.
    assert avg_scores.loc[bst_q_model, bst_t_model] == avg_scores.max().max()
    print(f'Query model: {bst_q_model: <12} X Task model: {bst_t_model}')

    fig, ax = plt.subplots(figsize=(12,8))
    values = avg_scores.values
    im, _ = heatmap(values, idx_name, col_name, ax=ax, cmap="YlGn")
    annotate_heatmap(im, valfmt="{x:.2%}")

    export_name = f'diffmodels-{data}'
    fig.tight_layout()
    plt.savefig(f'images/{export_name}.png', bbox_inches='tight')
    # Close the figure instead of only clearing it; otherwise one figure per
    # dataset stays open until the script ends.
    plt.close(fig)
