#! python

import csv
import math
import os
import matplotlib.pyplot as plt
from pathlib import Path

# Number of values each progress curve is truncated/padded to when AlgoData
# loads CSV runs. draw() overwrites this (via `global evs`) with the
# evaluation budget of the problem being plotted.
evs = 4000

def script_dir():
    """Return the absolute path of the directory holding this script, as a str."""
    here = Path(os.path.abspath(__file__))
    return str(here.parent)

def root_dir():
    """Return the absolute path of this script file, as a str.

    NOTE(review): despite the name, this returns the script's own file path,
    not a directory (an earlier version walked up to the grandparent
    directory). No caller in this file uses the result; confirm the intent
    before relying on it.
    """
    script_path = Path(os.path.abspath(__file__))
    return str(script_path)


def get_progress_csv_files(directory, funcname, algname):
    """Return paths of the CSV files in *directory* matching both name tokens.

    A file matches when its name contains *algname* and *funcname* as
    substrings and ends with ``.csv``. Only the top level of *directory* is
    searched.

    Args:
        directory: directory to list.
        funcname: substring identifying the problem (e.g. 'Lunar').
        algname: substring identifying the algorithm (e.g. 'BOHB').

    Returns:
        list of str: ``os.path.join(directory, name)`` for every match,
        ordered by file name. ``os.listdir`` order is arbitrary, and sorting
        makes repeated runs load (and sum) the files in the same order.

    Raises:
        OSError (e.g. FileNotFoundError): if *directory* cannot be listed.
    """
    return [
        os.path.join(directory, fn)
        for fn in sorted(os.listdir(directory))
        if algname in fn and funcname in fn and fn.endswith('.csv')
    ]



class AlgoData:
    """Per-evaluation statistics over several runs of one algorithm on one problem.

    Each CSV file in *csvlist* is one run. The first column of each row is
    read as a float, and row ``i`` becomes the value plotted at x = ``i``.
    Runs are cut to ``max_evals`` values, and non-empty shorter runs are
    padded with their last value. ``_calc`` then fills, for every index, the
    across-run mean (``avg``), the sample standard deviation (``std``) and
    the confidence-interval half-width (``conf_itv``).
    """

    # Two-sided standard-normal critical values for the supported CI levels.
    # ref https://www.mathsisfun.com/data/confidence-interval.html
    _Z_SCORES = {'0.95': 1.960, '0.99': 2.576, '0.995': 2.807, '0.999': 3.291}

    def __init__(self, algname, funcname, csvlist, perc='0.95', reverse_sign=False, max_evals=None):
        """Load every run in *csvlist* and compute per-evaluation statistics.

        Args:
            algname: algorithm name; used as the legend label by draw_curve.
            funcname: problem name.
            csvlist: iterable of CSV file paths, one run per file.
            perc: confidence level, one of '0.95', '0.99', '0.995', '0.999'.
            reverse_sign: negate every value read from the files.
            max_evals: number of values kept per run. Defaults to the
                module-level ``evs`` as it is at construction time.

        Raises:
            ValueError: if *perc* is unsupported and there is data to
                summarise, or if a CSV cell is not a number.
            OSError: if a CSV file cannot be opened.
        """
        if max_evals is None:
            max_evals = evs
        self.alg = algname
        self.func = funcname
        self.prog_data = []  # one list of floats per run (CSV file)

        self.avg = None
        self.std = None
        self.conf_itv = 0.0
        for csvfn in csvlist:
            self.prog_data.append(self._read_run(csvfn, max_evals, reverse_sign))

        self._calc(perc=perc)

    @staticmethod
    def _read_run(csvfn, max_evals, reverse_sign):
        """Return one run's values from the first column of *csvfn*, padded/truncated."""
        apath = []
        # newline='' is what the csv module expects for files it parses.
        with open(csvfn, newline='') as csvdata:
            for row in csv.reader(csvdata, delimiter=','):
                if not row:
                    # Blank lines (e.g. a trailing empty line) come back as []
                    # and would raise IndexError on row[0].
                    continue
                v = float(row[0])
                apath.append(-v if reverse_sign else v)
                if len(apath) >= max_evals:
                    break

        if apath and len(apath) < max_evals:
            # Hold a short run at its final value so every run spans the
            # same number of evaluations.
            apath.extend([apath[-1]] * (max_evals - len(apath)))
        return apath

    def _calc_ci(self, perc, mean, stdev, n):
        """Return the half-width Z * stdev / sqrt(n) of the *perc* confidence interval.

        *mean* is not used in the computation. It is kept so the signature
        stays the same.

        Raises:
            ValueError: if *perc* is not one of the supported levels.
        """
        z = self._Z_SCORES.get(perc)
        if z is None:
            raise ValueError('Undefined CI level: %s'%(perc))
        return z*stdev/math.sqrt(n)

    def _calc(self, perc='0.95'):
        """Fill ``avg``, ``std`` and ``conf_itv`` with one entry per evaluation index.

        At index ``i`` only runs with more than ``i`` values contribute. For
        a single contributing run, std and CI are 0.
        """
        self.avg = []
        self.std = []
        self.conf_itv = []
        n_points = max((len(run) for run in self.prog_data), default=0)
        for i in range(n_points):
            # Non-empty: i < n_points means at least one run is long enough.
            xs = [run[i] for run in self.prog_data if len(run) > i]
            x_n = len(xs)
            x_avg = sum(xs)/x_n
            xstd = 0.0
            if x_n > 1:
                # Sample (n-1) standard deviation.
                xstd = math.sqrt(sum((x - x_avg)**2 for x in xs)/(x_n - 1))
            self.avg.append(x_avg)
            self.std.append(xstd)
            self.conf_itv.append(self._calc_ci(perc, x_avg, xstd, x_n))

    def draw_curve(self, ax, linetype, linecolor):
        """Plot the mean curve on *ax* and shade the band mean +/- CI half-width.

        Args:
            ax: matplotlib Axes to draw into.
            linetype: accepted for interface compatibility but currently not
                applied; the line uses the Axes' default style.
            linecolor: color for both the line and the shaded band.
        """
        iters = list(range(len(self.avg)))
        up = [m + c for m, c in zip(self.avg, self.conf_itv)]
        down = [m - c for m, c in zip(self.avg, self.conf_itv)]
        ax.plot(iters, self.avg, color=linecolor, label=self.alg)
        ax.fill_between(iters, down, up, color=linecolor, alpha=0.1)
        

class FuncData:
    """The AlgoData curves for one problem, their plot colors, and drawing helpers."""

    def __init__(self, funcname):
        self.func = funcname
        self.algdata = {}  # algorithm name -> AlgoData, kept in insertion order
        self.colors = {}   # algorithm name -> matplotlib color

    def append(self, algo_data):
        """Register *algo_data* under its ``alg`` name, overwriting its ``func`` with ours."""
        algo_data.func = self.func
        self.algdata[algo_data.alg] = algo_data

    def set_color(self, alg, color):
        """Use *color* when drawing algorithm *alg*."""
        self.colors[alg] = color

    def _draw_curves(self, ax, alg_seq, ybottom, ytop):
        """Draw the selected algorithms on *ax*, then apply the requested y-limits.

        Every algorithm in *alg_seq* must have a color (KeyError otherwise).
        An empty or None *alg_seq* means all algorithms, in insertion order.
        The limits are applied after plotting because ``set_ylim`` with only
        one bound keeps the other at its current value and disables y
        autoscaling. On an empty Axes that would pin the free bound to the
        default (0, 1) range.
        """
        if alg_seq is None or len(alg_seq) < 1:
            alg_seq = self.algdata.keys()
        for alg in alg_seq:
            self.algdata[alg].draw_curve(ax, '--', self.colors[alg])
        if ybottom is not None:
            ax.set_ylim(bottom=ybottom)
        if ytop is not None:
            ax.set_ylim(top=ytop)

    def draw(self, figfn, alg_seq=None, ybottom=None, ytop=None, figsize=(8,2)):
        """Draw the curves in a new figure and save it to *figfn* at 600 dpi.

        The figure is closed afterwards so that repeated calls do not
        accumulate open pyplot figures.
        """
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        try:
            self._draw_curves(ax, alg_seq, ybottom, ytop)
            ax.legend(loc = 'lower center', ncol=3)
            ax.set_xlabel('Number of evaluations')
            ax.set_ylabel('Value')
            ax.set_title(self.func)
            fig.savefig(figfn, dpi=600)
        finally:
            plt.close(fig)

    def sub_draw(self, ax, alg_seq=None, ybottom=None, ytop=None, title=None, xlabel=True):
        """Draw the curves into an existing Axes (e.g. a subplot); no legend is added.

        *title* defaults to the problem name and is placed inside the Axes
        (y=0.825). The x-axis label is set only when *xlabel* is true.
        """
        if title is None:
            title = self.func
        self._draw_curves(ax, alg_seq, ybottom, ytop)
        ax.set_ylabel('Value')
        if xlabel:
            ax.set_xlabel('Number of evaluations')
        ax.set_title(title, x=0.5, y=0.825)



class DataFiles:
    """Locate one algorithm's CSV runs for one problem and load them as AlgoData."""

    def __init__(self, directory, func_token, alg_token, alg_rename=None):
        """
        Args:
            directory: directory searched for the CSV files.
            func_token: substring identifying the problem in file names.
            alg_token: substring identifying the algorithm in file names.
            alg_rename: legend name to use instead of *alg_token*, if given.
        """
        self.directory = directory
        self.func_tk = func_token
        self.alg_tk = alg_token
        self.alg_show_name = alg_token if alg_rename is None else alg_rename

    def get_data(self, perc='0.95', reverse_sign=False):
        """Return an AlgoData built from every matching CSV file.

        *perc* and *reverse_sign* are forwarded to AlgoData. Their defaults
        reproduce the previous no-argument behaviour.
        """
        fns = get_progress_csv_files(self.directory, self.func_tk, self.alg_tk)
        return AlgoData(self.alg_show_name, self.func_tk, fns,
                        perc=perc, reverse_sign=reverse_sign)


def draw(fname, ax=None, figsize=(8,2)):
    """Plot every algorithm's progress curve for benchmark problem *fname*.

    CSV runs are read from ``<script dir>/PRL/<algorithm>/``.

    Args:
        fname: one of 'push', 'rover', 'lunar'.
        ax: optional matplotlib Axes. When given, the curves are drawn into
            it instead of into a new figure.
        figsize: figure size, used only when *ax* is None.

    Side effects:
        Sets the module-level ``evs`` to the problem's evaluation budget
        before any data is loaded. When *ax* is None, saves
        ``<fname>-<dim>d.png`` relative to the current working directory.

    Raises:
        ValueError: if *fname* is not a known problem.
    """
    global evs
    # fname -> (file-name token, y-bottom, y-top, evaluations, label, dimension)
    settings = {
        'push':  ('Robot-Push', 0.0, 13.0, 10000, '14D Robot pushing', 14),
        'rover': ('Rover', -6.0, 6.5, 20000, '60D Rover trajectory planning', 60),
        'lunar': ('Lunar', 0.0, 350.0, 1500, '12D Lunar landing', 12),
    }
    if fname not in settings:
        raise ValueError('Unknown function [%s]'%fname)
    functoken, ybottom, ytop, n_evals, flabel, dim = settings[fname]
    # AlgoData reads this global to truncate/pad each run; set it before loading.
    evs = n_evals
    ci_perc = '0.95'
    data_base = os.path.join(script_dir(), 'PRL')

    # (sub-directory, file-name token, legend name, color), in legend order.
    algorithms = [
        ('bohb',  'BOHB',  'BOHB',  'pink'),
        ('cobbo', 'CobBO', 'CobBO', 'black'),
        ('turbo', 'TuRBO', 'TuRBO', 'blue'),
        ('tpe',   '_TPE_', 'TPE',   'orange'),
    ]
    if dim < 20:
        # ATPE is only shown for the lower-dimensional problems.
        algorithms.append(('atpe', 'ATPE', 'ATPE', 'red'))
    algorithms.append(('cmaes', 'CMAES', 'CMAES', 'fuchsia'))
    # Diff-Evo ('diffevo', 'DiffEvo', 'Diff-Evo', 'purple') is not plotted, so
    # its CSVs are not loaded either.

    fdata = FuncData(functoken)
    for subdir, token, legend, color in algorithms:
        csvs = get_progress_csv_files(os.path.join(data_base, subdir), functoken, token)
        algo = AlgoData(legend, flabel, csvs, perc=ci_perc)
        fdata.append(algo)
        fdata.set_color(algo.alg, color)

    if ax is None:
        figname = '%s-%dd.png'%(fname,dim)
        fdata.draw(figname, ybottom=ybottom, ytop=ytop, figsize=figsize)
    else:
        fdata.sub_draw(ax, ybottom=ybottom, ytop=ytop, xlabel=True)


if __name__ == '__main__':
    # Render each benchmark problem into its own PNG file.
    funcs = ['push','rover','lunar']
    for fn in funcs:
        print("Drawing ", fn, " ...")
        draw(fn, figsize=(6.0, 3.0))



