#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps 
import matplotlib
matplotlib.use('agg')
from python.active_bakeoff_setting import setting, fname

# Load the results table; `fname` is supplied by the imported settings module.
resdf = pd.read_csv(fname)
# Optional filter, left disabled: drop rows with very large N.
#resdf = resdf[resdf['N']<1e6]

# Distinct values of the 'P' and 'est' columns, in sorted order.  Each P gets
# its own panel(s) and each estimator its own curve / marker colour.
Ps = sorted(set(resdf['P']))
ests = sorted(set(resdf['est']))

# Colormap handle (not referenced by the plotting code visible in this file).
cm = colormaps['spring']

# Layout constants: three P values per row, and twice as many columns for the
# figure that shows two panels per P.
ncol = 2 * 3
nrow = int(np.ceil(len(Ps) / 3))

# Line/marker colour and legend label for every estimator key in 'est'.
cols = dict(gp='red', pra='blue', tree='orange', dasm='cyan')
label_pretty = dict(tree='TBAS', gp='GP', pra='PRA', dasm='DASM')

## Plot the two individually.
# One row of panels: first an "Error" panel for every P value, then a "Time"
# panel for every P value.  Each panel shows, per estimator, the mean curve
# (solid) and the 25% / 75% quantile curves (dashed) as a function of N.
pn = {'IP':'Error','Time':'Time'}
targets = ['IP', 'Time']

# Transformation applied to each plotted column before drawing:
#   'IP'   -> pi/2 * (1 - IP), labelled as an error in radians
#             (presumably IP is an inner-product score in [0, 1] -- confirm);
#   'Time' -> log10(seconds).
trans_by_target = {
    'IP': lambda x: np.pi/2*(1-x),
    'Time': lambda x: np.log10(x),
}

# Aggregate once per P (the statistics do not depend on the target), and only
# over the columns that are plotted, so that any extra non-numeric column in
# the CSV cannot make mean()/quantile() fail.
group_keys = ['N', 'P', 'est']
agg_cols = group_keys + targets
stats_per_P = []
for P in Ps:
    grouped = resdf.loc[resdf['P'] == P, agg_cols].groupby(group_keys, as_index=False)
    stats_per_P.append((grouped.mean(), grouped.quantile(0.25), grouped.quantile(0.75)))

# Size the figure from the data: one panel per (target, P) pair, 2 inches wide
# each (12 x 2 inches for three P values).  squeeze=False keeps `axes`
# indexable for any panel count; at least one panel is created so an empty
# results table still yields a (blank) figure instead of an error.
n_panels = max(1, len(targets) * len(Ps))
fig, axes = plt.subplots(ncols=n_panels, figsize=[2*n_panels, 2], squeeze=False)
axes = axes[0]

for ti, target in enumerate(targets):
    trans = trans_by_target[target]
    for pi, (P, (cdfi, cdfi_lb, cdfi_ub)) in enumerate(zip(Ps, stats_per_P)):
        ai = ti * len(Ps) + pi
        ax = axes[ai]

        for est in ests:
            dcdf = cdfi[cdfi['est'] == est]
            dcdf_lb = cdfi_lb[cdfi_lb['est'] == est]
            dcdf_ub = cdfi_ub[cdfi_ub['est'] == est]

            # Label only the first panel so the figure legend lists each
            # estimator exactly once.
            label = label_pretty[est] if ai == 0 else None
            ax.plot(dcdf['N'], trans(dcdf[target]), c=cols[est], label=label)
            ax.plot(dcdf_lb['N'], trans(dcdf_lb[target]), c=cols[est], linestyle='--', alpha=0.5)
            ax.plot(dcdf_ub['N'], trans(dcdf_ub[target]), c=cols[est], linestyle='--', alpha=0.5)

        ax.set_title(pn[target]+' P='+str(int(P)))
        ax.set_xlabel("N")
        if target == 'IP':
            ax.set_ylabel("Radians")
        if target == 'Time':
            ax.set_ylabel("log Seconds", labelpad=-1)
        if setting == 'big':
            ax.set_xscale('log')

# Lay out the panels first, then reserve the right margin for the legend.
plt.tight_layout()
fig.legend(loc=7, prop={'size':12})
fig.subplots_adjust(right=0.90, wspace=0.4, hspace=0)
add = '_big' if setting=='big' else ''
# NOTE(review): this produces 'active_bakeoff_.pdf' / 'active_bakeoff__big.pdf'
# (stray underscore).  Kept unchanged because other files may reference the
# existing name -- confirm before renaming.
fig.savefig('active_bakeoff_'+add+'.pdf')
plt.close(fig)


## Compare the two.
# Scatter of mean execution time against mean error, one panel per P value and
# one colour per estimator.  Panels are arranged on a grid three columns wide
# with as many rows as needed, so more than three P values no longer index past
# the end of `axes`.
ncol = 3
n_rows = max(1, int(np.ceil(len(Ps) / ncol)))
fig, axes = plt.subplots(nrows=n_rows, ncols=ncol,
                         figsize=[3*ncol, 3*n_rows], squeeze=False)
axes = axes.ravel()  # row-major, so panel index == position in Ps

for pi, P in enumerate(Ps):
    ax = axes[pi]

    # Mean per (N, P, estimator), restricted to the plotted columns so that
    # extra non-numeric columns in the CSV cannot make mean() fail.
    cdf = resdf.loc[resdf['P'] == P, ['N', 'P', 'est', 'IP', 'Time']]
    cdfi = cdf.groupby(['N','P','est'], as_index=False).mean()

    for est in ests:
        dcdf = cdfi[cdfi['est'] == est]
        # Label only the first panel so the figure legend has one entry per
        # estimator.
        lab = label_pretty[est] if pi == 0 else None
        # Same error transform as the per-target figure: pi/2 * (1 - IP).
        ax.scatter(dcdf['Time'], np.pi/2*(1-dcdf['IP']), c=cols[est], label=lab)

    ax.set_xlabel("Execution Time (s)")
    if pi % ncol == 0:
        # y-label on the left-most panel of every row.
        ax.set_ylabel("Error (Radians)")
    ax.set_title(' P='+str(int(P)))
    ax.set_xscale('log')

# Hide grid cells that have no P value to show.
for ax in axes[len(Ps):]:
    ax.set_visible(False)

fig.legend(loc=7, prop={'size':12})
plt.tight_layout()
# Reserve the right margin for the figure-level legend.
fig.subplots_adjust(right=0.85)
add = '_big' if setting=='big' else ''
fig.savefig('active_bakeoff_pareto'+add+'.pdf')
plt.close(fig)
