import json
import numpy as np
import matplotlib.pyplot as plt

# Axis labels, indexed by column position within a result row: the same
# integer id selects both the value to plot (row[id]) and its label
# (names[id]).
names = [
    'i',
    'lastl_norm',
    'jac_repr',
    'jac_all',
    'robust_acc',
    'transfer_acc',
]

def plot_fig(results, x_id, y_id, fig_name):
    """Plot one column of `results` against another and save the figure.

    Args:
        results: Iterable of indexable rows; it is traversed twice, once
            per axis.
        x_id: Index of the row entry used for the x axis; also selects the
            x-axis label from the module-level `names` list.
        y_id: Index of the row entry used for the y axis; also selects the
            y-axis label from `names`.
        fig_name: Destination passed to `savefig`.
    """
    x_values = [row[x_id] for row in results]
    y_values = [row[y_id] for row in results]

    figure = plt.figure()
    # A freshly created figure has no axes yet; gca() creates the default one.
    axes = figure.gca()
    axes.plot(x_values, y_values, 'o')  # 'o': circle markers, no connecting line
    axes.set_xlabel(names[x_id])
    axes.set_ylabel(names[y_id])

    figure.savefig(fig_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(figure)

def _load_results(path, skip=20):
    """Load tuning records from a JSON file and reshape them for plotting.

    Args:
        path: Path of a JSON file holding a list of indexable records.
        skip: Number of leading records to discard (default 20).

    Returns:
        A list of 6-tuples built from raw columns 0, 2, 3, 4, 5 and 6 of
        each remaining record (raw column 1 is dropped), with raw column 5
        multiplied by 100. The tuple positions line up with the entries of
        the module-level `names` list.
    """
    with open(path) as inf:
        records = json.load(inf)
    # NOTE(review): the reason for discarding the first `skip` records is not
    # visible in this file -- confirm against whatever writes these JSON files.
    # NOTE(review): the x100 on raw column 5 (labelled 'robust_acc') presumably
    # converts a fraction to a percentage, while raw column 6 is left as-is --
    # confirm the two accuracies end up on the same scale.
    return [(x[0], x[2], x[3], x[4], x[5] * 100, x[6]) for x in records[skip:]]


def main():
    """Plot two tuning runs against each other and save the comparison figure.

    Loads two result files from ``figures/``, plots column `xid` against
    column `yid` for both runs on a single set of axes (red markers for the
    first file, blue for the second), labels the axes from `names`, and
    writes the figure to `fig_name`.
    """
    # --- Disabled: single-model mode, kept for reference -------------------
    # Plots every column pairing for one model via plot_fig and then stops
    # (the trailing `assert 0` aborted before the comparison plot below).
    #MODEL_NAME = 'jacLAST_0.003'
    ##MODEL_NAME = 'jacREPR_0.001'
    ##MODEL_NAME = 'jacALL_0.0001'
    #with open('figures/tune_%s.json'%MODEL_NAME) as inf:
    #    results = json.load(inf)
    #results = results[10:]
    #print (results)
    #plot_fig(results, 4,5, 'figures/tune-%s-rob-transf.pdf'%MODEL_NAME)
    #plot_fig(results, 3,4, 'figures/tune-%s-jacall-rob.pdf'%MODEL_NAME)
    #plot_fig(results, 3,5, 'figures/tune-%s-jacall-transf.pdf'%MODEL_NAME)
    #plot_fig(results, 2,4, 'figures/tune-%s-jacrepr-rob.pdf'%MODEL_NAME)
    #plot_fig(results, 2,5, 'figures/tune-%s-jacrepr-transf.pdf'%MODEL_NAME)
    #assert 0

    # NOTE(review): the file named '...approxrepr...' is plotted under the
    # 'JacALL models' legend label and the '...approx...' file under
    # 'JacREPR models'. The labels date from the earlier inputs listed below;
    # confirm they still describe the files being loaded.
    # Earlier input: 'figures/tune_jacREPR_0.001.json'
    repr_results = _load_results('figures/tune_alpha-approx-tune-3.0.json')
    # Earlier input: 'figures/tune_jacALL_0.0001.json'
    all_results = _load_results('figures/tune_alpha-approxrepr-tune-10.0.json')

    # Column pairing to plot; the alternatives are kept for quick switching.
    #xid,yid,fig_name = 4,5,'figures/tune-cmp-rob-transf.pdf'
    #xid,yid,fig_name = 3,4,'figures/tune-cmp-jacall-rob.pdf'
    #xid,yid,fig_name = 3,5,'figures/tune-cmp-jacall-transf.pdf'
    #xid,yid,fig_name = 2,4,'figures/tune-cmp-jacrepr-rob.pdf'
    xid, yid, fig_name = 2, 5, 'figures/tune-cmp-jacrepr-transf.pdf'

    fig = plt.figure()
    series = (
        (repr_results, 'ro', 'JacREPR models'),
        (all_results, 'bo', 'JacALL models'),
    )
    for rows, marker, label in series:
        plt.plot([r[xid] for r in rows], [r[yid] for r in rows],
                 marker, label=label)
    plt.xlabel(names[xid])
    plt.ylabel(names[yid])
    plt.legend()
    fig.savefig(fig_name)
    # Release the figure once it has been written to disk.
    plt.close(fig)

# Generate the comparison figure only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
