import json
import numpy as np
import seaborn as sns
sns.set()
import matplotlib.pyplot as plt
import pickle
from copy import deepcopy
from common import args


def load_fold(fold_f, idx, type, dataset_dir=None, k_list=None, beam_size=50):
    """Compute top-k recovery rates for one fold of predictions.

    Ground truth is read from ``<dataset_dir>/cond2monomer_gt_info.json``.
    Every entry whose ``'group'`` equals ``idx`` contributes its ``type``
    field, in file order. ``fold_f`` holds one prediction per line, arranged
    as consecutive chunks of ``beam_size`` lines. There is one chunk per
    selected ground-truth entry, in the same order.

    Args:
        fold_f: Path to the prediction file for this fold.
        idx: Group identifier compared against each entry's ``'group'``.
        type: Key of the ground-truth field compared with the predictions.
            The name shadows the builtin; it is kept for caller compatibility.
        dataset_dir: Directory containing the ground-truth JSON. Defaults to
            ``args.dataset_dir``.
        k_list: Top-k cut-offs to evaluate. Defaults to
            ``[1, 2, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]``.
        beam_size: Number of prediction lines per ground-truth entry.

    Returns:
        Dict mapping each k to the fraction of entries whose ground truth
        appears among the first k predictions of its chunk. If no entries are
        selected, the value is NaN (NumPy warns about the empty mean).

    Raises:
        AssertionError: If the line count is not a multiple of ``beam_size``,
            or if the chunk count differs from the number of ground-truth
            entries.
    """
    if dataset_dir is None:
        dataset_dir = args.dataset_dir
    if k_list is None:
        k_list = [1, 2, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]

    # Use a context manager so the JSON file handle is closed promptly.
    with open(dataset_dir + '/cond2monomer_gt_info.json', 'r') as f:
        gt_info = json.load(f)
    gt = [info[type] for info in gt_info.values() if info['group'] == idx]

    with open(fold_f, 'r') as f:
        lines = [line.strip('\n') for line in f]

    # Each ground-truth entry owns a contiguous run of beam_size predictions;
    # a trailing partial run means the file is truncated or misaligned.
    assert len(lines) % beam_size == 0, (
        '%d prediction lines is not a multiple of beam size %d'
        % (len(lines), beam_size))
    predictions = [lines[start:start + beam_size]
                   for start in range(0, len(lines), beam_size)]
    assert len(predictions) == len(gt), (
        '%d prediction chunks but %d ground-truth entries'
        % (len(predictions), len(gt)))

    result = {}
    for k in k_list:
        recovered = [target in beam[:k] for target, beam in zip(gt, predictions)]
        result[k] = np.array(recovered).mean()

    return result


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file.

    NOTE(review): unpickling can execute arbitrary code; only load result
    files produced by this project's own pipeline.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _plot_recovery_curves(results, names, algs, metric):
    """Draw one seaborn line per algorithm from ``summary[k][metric]``.

    Each value in ``results[alg]['summary'][k][metric]`` (presumably one per
    group/fold; confirm against the pickle producer) becomes a
    ``(k, value * 100)`` point. Seaborn then aggregates the repeated x values
    per algorithm. Five ``(0, 0)`` points are prepended so every curve starts
    at the origin.

    Returns:
        The Axes returned by the last ``sns.lineplot`` call, or None if
        ``algs`` is empty.
    """
    ax = None
    for alg in algs:
        summary = results[alg]['summary']
        x, y = [0] * 5, [0] * 5
        for k, per_k in summary.items():
            values = per_k[metric]
            x.extend([k] * len(values))
            y.extend(v * 100 for v in values.values())
        ax = sns.lineplot(x=x, y=y, label=names[alg])
    return ax


def plot_line_charts():
    """Plot top-k recovery curves for every method and save them as PDFs.

    Loads ``<alg>_real.pkl`` for polyretro/uspto/random and
    ``upperbound_real.pkl`` from ``args.final_result_dir``. It then writes
    two charts into the same directory:

    * ``unit-polymer-recovery.pdf``, from ``summary[k]['recovered']``.
    * ``monomer-recovery.pdf``, from ``summary[k]['monomer_recovered']``,
      with dashed upper-bound reference lines.

    The current pyplot figure is cleared after each save.
    """
    results = {}
    for alg in ['polyretro', 'uspto', 'random']:
        f = args.final_result_dir + '/%s_real.pkl' % alg
        print('loading %s results from %s' % (alg, f))
        results[alg] = _load_pickle(f)

    # The Transformer curve is a placeholder. It copies the 'random' result
    # structure (same k values and keys) with every recovery rate set to 0.
    transformer = deepcopy(results['random'])
    for per_k in transformer['summary'].values():
        for metric in ('recovered', 'monomer_recovered'):
            for g in per_k[metric]:
                per_k[metric][g] = 0.
    results['Transformer'] = transformer

    upperbound_f = args.final_result_dir + '/upperbound_real.pkl'
    print('loading upperbound results from %s' % upperbound_f)
    upperbounds = _load_pickle(upperbound_f)

    names = {
        'polyretro': 'PolyRetro',
        'uspto': 'PolyRetro-USPTO',
        'random': 'Random Proposal',
        'Transformer': 'Transformer'
    }
    algs = ['polyretro', 'uspto', 'random', 'Transformer']

    # Figure 1: unit-polymer recovery.
    _plot_recovery_curves(results, names, algs, 'recovered')
    plt.xlabel('Top-k')
    plt.ylabel('% recovered')
    plt.title('% of unit polymers recovered in top-k prediction')
    fig = args.final_result_dir + '/unit-polymer-recovery.pdf'
    print('saving line chart to %s' % fig)
    plt.savefig(fig)
    plt.clf()

    # Figure 2: monomer recovery with dashed upper-bound reference lines.
    ax = _plot_recovery_curves(results, names, algs, 'monomer_recovered')
    # NOTE(review): the label y-positions (72/66/58) are hard-coded and do not
    # follow the upper-bound values; re-tune them if the upper bounds change.
    ax.axhline(upperbounds['full_units'] * 100, ls='--')
    ax.text(0, 72, "Recursive")
    ax.axhline(upperbounds['monomers'] * 100, ls='--')
    ax.text(0, 66, "Stability")
    ax.axhline(upperbounds['synthesizability'] * 100, ls='--')
    ax.text(0, 58, "Synthesizability")
    plt.xlabel('Top-k')
    plt.ylabel('% recovered')
    plt.title('% of monomers recovered in top-k prediction')
    fig = args.final_result_dir + '/monomer-recovery.pdf'
    print('saving line chart to %s' % fig)
    plt.savefig(fig)
    # Clear so a later plot (or a second call) does not draw onto this figure.
    plt.clf()


if __name__ == '__main__':
    # Script entry point: regenerate both recovery charts.
    plot_line_charts()