import argparse
import os
import pickle
import numpy as np
import matplotlib; matplotlib.use('Agg')
from matplotlib import pyplot as plt
import seaborn
# Text/font setup for the saved figures: Adobe core fonts in PS and PDF
# output, and LaTeX rendering of text (several panel titles below contain
# LaTeX math such as \mathrm{...}).
matplotlib.rcParams.update({
    'ps.useafm': True,
    'pdf.use14corefonts': True,
    'text.usetex': True,
})

# Options: -p is the directory the pickled result files are read from and the
# figures are saved to; -figs selects which figure layout is produced below.
# NOTE(review): parse_args([]) parses an EMPTY argument list, so sys.argv is
# ignored and the defaults below are always used (presumably deliberate so the
# script can be run without arguments); switch to parse_args() to drive it
# from the command line.
parser = argparse.ArgumentParser()
parser.add_argument('-p', default='./fpovi/toy_results_linear', type=str)
parser.add_argument('-figs', default='22', type=str)
args = parser.parse_args(args=[])
path = args.p


def plot(preds, xs, ys, alpha=0.2, s=None, mode='ci', base=1, plotMean=True, extra_sd=0):
    """Draw a summary of sampled predictions on the current pyplot axes.

    Args:
        preds: array whose axis 0 indexes prediction samples and whose
            remaining axis indexes input points (assumed
            (n_samples, n_points) -- TODO confirm against the code that
            writes the result files).
        xs: inputs matching the point axis of ``preds``; flattened before
            plotting.
        ys: unused; accepted for call-site compatibility.
        alpha: marker opacity in ``'scatter'`` mode (the ``'ci'`` band
            opacities are fixed at 0.15 and 0.2).
        s: scatter marker size; any falsy value (None, 0) selects
            ``0.2 / preds.shape[0]``.
        mode: ``'ci'`` draws the sample mean plus two +/-1.96 sd bands, the
            outer one widened by ``extra_sd``; ``'scatter'`` plots every
            sample as points; ``'groundtruth'`` draws dashed red
            mean +/- 3 sd curves.
        base: unused; accepted for call-site compatibility.
        plotMean: in ``'ci'`` mode, whether to draw the mean curve.
        extra_sd: standard deviation combined in quadrature,
            ``sqrt(sd**2 + extra_sd**2)``, when sizing the outer ``'ci'`` band.

    Raises:
        NotImplementedError: if ``mode`` is not one of the modes above.
    """
    if not s:
        s = 0.2 / preds.shape[0]
    if mode == 'ci':
        x_flat = xs.reshape((-1,))
        # Per-point sample mean and population (ddof=0) std across samples.
        mean_ = np.mean(preds, axis=0)
        sd = np.std(preds, axis=0)
        if plotMean:
            plt.plot(x_flat, mean_)
        # Outer band: sample spread and extra_sd combined in quadrature.
        wsd = 1.96 * (sd**2 + extra_sd**2)**0.5
        plt.fill_between(x_flat, mean_ - wsd, mean_ + wsd, facecolor='lightblue',
                         alpha=0.15, interpolate=True)
        # Inner band: sample spread only.
        plt.fill_between(x_flat, mean_ - 1.96 * sd, mean_ + 1.96 * sd, facecolor='blue',
                         alpha=0.2, interpolate=True)
    elif mode == 'scatter':
        # Repeat the inputs once per sample so every row of preds has x's.
        tiled_xs = np.tile(xs.reshape((1, -1)), [preds.shape[0], 1])
        plt.scatter(tiled_xs, preds, alpha=alpha, s=s)
    elif mode == 'groundtruth':
        x_flat = xs.reshape((-1,))
        mean_ = np.mean(preds, axis=0)
        sd = np.std(preds, axis=0)
        plt.plot(x_flat, mean_ - 3 * sd, linestyle='dashed', c='red', alpha=0.4)
        plt.plot(x_flat, mean_ + 3 * sd, linestyle='dashed', c='red', alpha=0.4)
    else:
        raise NotImplementedError('unknown plot mode: {!r}'.format(mode))

        
def sine_fn(x):
    """Return the toy regression target ``x + 1``.

    Despite the name, this is currently a linear function: the sinusoidal
    target ``np.sin(4*x) + np.sin(13*x)`` was commented out in its favour.
    Works element-wise on numpy arrays as well as on plain numbers.
    """
    shifted = x + 1
    return shifted


def plot_sine(fil, xlimmax=15.5, xlimmin=-15.5):
    """Load one pickled result file and draw it on the current axes.

    Args:
        fil: file name, resolved relative to the module-level results
            directory ``path``.
        xlimmax: right x-axis limit.
        xlimmin: left x-axis limit (defaults to the previously hard-coded
            -15.5, so existing calls are unaffected).

    The pickle must hold a 6-tuple ``(pred, ylogstd, xs, ys, xtrs, ytrs)``.
    ``pred``/``xs``/``ys`` are drawn with :func:`plot` in ``'ci'`` mode, and
    the ``(xtrs, ytrs)`` points are overlaid as red '+' markers (presumably
    the training data -- confirm). ``ylogstd`` is read but not used.
    """
    fil = os.path.join(path, fil)
    plt.xlim(xlimmin, xlimmax)
    plt.ylim(-20, 20)
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # you produced yourself. `with` ensures the file handle is closed.
    with open(fil, 'rb') as f:
        pred, _ylogstd, xs, ys, xtrs, ytrs = pickle.load(f)
    # extra_sd=1.0 widens the outer band by a unit std in quadrature
    # (presumably the toy data's observation noise -- confirm).
    plot(pred, xs, ys, alpha=0.5, mode='ci', extra_sd=1.**0.5)
    plt.scatter(xtrs, ytrs, s=5, c='red', marker='+')


# Figure dispatch on args.figs. Each branch lays out one multi-panel figure,
# fills every panel from one pickled result file under `path` (via
# plot_sine), and saves the composed figure into `path`.
# Titles containing LaTeX are raw strings so '\m' is not parsed as an
# (invalid) escape sequence; their runtime values are unchanged.
# NOTE(review): where both subplots_adjust(hspace=0.3) and tight_layout() are
# called, tight_layout() recomputes spacing afterwards, so the hspace setting
# is likely superseded -- confirm which is intended.
if args.figs == '14':

    plt.figure(figsize=(10, 5), facecolor='w')
    plt.subplot(231)
    plt.title('w-SGLD')
    plot_sine('wsgld.bin')
    plt.subplot(234)
    plt.title('f-wSGLD')
    plot_sine('fwsgld.bin')
    plt.subplot(232)
    plt.title('pi-SGLD')
    plot_sine('pisgld.bin')
    plt.subplot(235)
    plt.title('f-piSGLD')
    plot_sine('fpisgld.bin')
    plt.subplot(233)
    plt.title('HMC')
    plot_sine('hmc.bin')
    plt.subplot(236)
    plt.title('Cov')
    plot_sine('fgfsf.bin')
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(os.path.join(path, 'fig4.png'))

elif args.figs == '17':

    plt.figure(figsize=(10, 5), facecolor='w')
    plt.subplot(241)
    plt.title('MAP(Ensemble)')
    plot_sine('map.bin')
    plt.subplot(245)
    plt.title('HMC')
    plot_sine('hmc.bin')
    plt.subplot(242)
    plt.title('SVGD')
    plot_sine('svgd.bin')
    plt.subplot(246)
    plt.title(r'$\mathrm{PAC}^2_E$')
    plot_sine('pred3.bin')
    plt.subplot(247)
    plt.title('Cov(Proposed)')
    # alternative input: plot_sine('cov_not_S_log3.bin')
    plot_sine('cov_prop.bin')
    plt.subplot(243)
    plt.title('GFSF')
    plot_sine('gfsf.bin')
    plt.subplot(244)
    plt.title('f-svgd')
    plot_sine('fsvgd.bin')
    plt.subplot(248)
    plt.title('Cov-svgd(Proposed)')
    # alternative input: plot_sine('cov_not_S_log3_svgd.bin')
    plot_sine('cov_prop_svgd.bin')
    plt.subplots_adjust(hspace=0.3)
    plt.tight_layout()
    # NOTE(review): figure '18' also saves to 'fig7.eps', so running both
    # overwrites one with the other -- confirm the intended file names.
    plt.savefig(os.path.join(path, 'fig7.eps'))

elif args.figs == '18':

    plt.figure(figsize=(13, 5), facecolor='w')
    plt.subplot(251)
    plt.title('MAP(Ensemble)')
    plot_sine('map.bin')
    plt.subplot(256)
    plt.title('HMC')
    plot_sine('hmc.bin')
    plt.subplot(252)
    plt.title('SVGD')
    plot_sine('svgd.bin')
    plt.subplot(257)
    plt.title('GFSF')
    plot_sine('gfsf.bin')
    plt.subplot(253)
    plt.title('f-svgd')
    plot_sine('fsvgd.bin')
    plt.subplot(258)
    plt.title(r'$\mathrm{PAC}^2_E$')
    plot_sine('pred3.bin')
    plt.subplot(254)
    plt.title(r'Cov($\mathrm{h}_m$)(Proposed)')
    plot_sine('cov_prop.bin')
    plt.subplot(259)
    plt.title(r'Cov-svgd($\mathrm{h}_m$)(Proposed)')
    plot_sine('cov_prop_svgd.bin')

    plt.subplot(255)
    plt.title('Cov(h)(Proposed)')
    plot_sine('cov_not_S_log3.bin')
    # Three-argument form: the ten-panel grid has no single-int code for 10.
    plt.subplot(2, 5, 10)
    plt.title('Cov-svgd(h)(Proposed)')
    plot_sine('cov_not_S_log3_svgd.bin')

    plt.subplots_adjust(hspace=0.3)
    plt.tight_layout()
    # NOTE(review): same output file as figure '17' -- see note there.
    plt.savefig(os.path.join(path, 'fig7.eps'))

elif args.figs == '19':

    plt.figure(figsize=(13, 5), facecolor='w')
    plt.subplot(241)
    plt.title('MAP(Ensemble)')
    plot_sine('map.bin')
    plt.subplot(245)
    plt.title('HMC')
    plot_sine('hmc.bin')
    plt.subplot(242)
    plt.title('SVGD')
    plot_sine('svgd.bin')
    plt.subplot(246)
    plt.title('GFSF')
    plot_sine('gfsf.bin')
    plt.subplot(243)
    plt.title('f-SVGD')
    plot_sine('fsvgd.bin')
    plt.subplot(247)
    plt.title(r'$\mathrm{PAC}^2_\mathrm{E}$')
    plot_sine('pred3.bin')
    plt.subplot(248)
    plt.title(r'VAR($\mathrm{h}_m$)(Proposed)')
    plot_sine('cov_prop.bin')
    plt.subplot(244)
    plt.title('VAR(h)(Proposed)')
    plot_sine('cov_not_S_log3.bin')

    plt.subplots_adjust(hspace=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(path, 'fig8.eps'))

elif args.figs == '22':

    plt.figure(figsize=(13, 5), facecolor='w')
    plt.subplot(162)
    plt.title('MAP(Ensemble)')
    plot_sine('map.bin')
    plt.subplot(161)
    plt.title('HMC')
    plot_sine('hmc.bin')
    plt.subplot(163)
    plt.title('SVGD')
    plot_sine('svgd.bin')
    plt.subplot(164)
    plt.title(r'$\mathrm{PAC}^2_\mathrm{E}$')
    plot_sine('cov_pred.bin')
    plt.subplot(165)
    plt.title(r'VAR($\mathrm{h}$)')
    # alternative input: plot_sine('cov_not_S_log3.bin')
    plot_sine('cov_prop.bin')
    plt.subplot(166)
    plt.title(r'VAR-svgd($\mathrm{h}$)')
    plot_sine('cov_prop_sv.bin')

    plt.subplots_adjust(hspace=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(path, 'fig10.eps'))

elif args.figs == '5':
    # Top row: weight-space SVGD; bottom row: function-space SVGD, one column
    # per value of n. Only the leftmost column keeps its y tick labels and the
    # top row drops its x tick labels.
    plt.figure(figsize=(12, 6), facecolor='w')
    for i, v in enumerate([5, 10, 50, 100]):
        plt.subplot(2, 4, i+5)
        if i != 0:
            plt.yticks([])
        plt.title('function space, n='+str(v))
        plot_sine('fsvgd{}.bin'.format(v))
        plt.subplot(2, 4, i+1)
        plt.xticks([])
        if i != 0:
            plt.yticks([])
        plt.title('weight space, n='+str(v))
        plot_sine('svgd{}.bin'.format(v))
    plt.savefig(os.path.join(path, 'fig5.png'))

else:
    # NotImplemented is a comparison sentinel, not an exception class; calling
    # it raised TypeError. Raise the intended exception with a useful message.
    raise NotImplementedError('unknown figure id: {!r}'.format(args.figs))
