import argparse
import os
import re
import subprocess
import sys
import time

import hashlib

import pickle

import numpy as np

import gzip

from rank_algos import inverse_significance

####################################################################################
# Plotting #########################################################################
####################################################################################

# When True, the legend is written once to its own legend.pdf and the
# per-dataset plots are drawn without an embedded legend.
SEPARATE_LEGEND = True

import matplotlib

# code for legend.pdf fails with MacOS backend currently, so I'm using agg.
# The backend is selected *before* pyplot is imported so that pyplot never
# has a chance to initialise a different (possibly GUI) backend first.
matplotlib.use('Agg')

import matplotlib.patches
import matplotlib.pyplot as plt
import matplotlib.ticker

print("Backend: ", plt.get_backend())

import seaborn as sb

# Global matplotlib appearance shared by every figure this script produces.
params = {
    'figure.figsize': [8,6],
    'figure.subplot.wspace': 0.02,
    'font.size': 32,
    'axes.labelsize': 32,
    'axes.titlesize': 32,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'legend.fontsize': 10,
    'text.usetex': False,
}
plt.rcParams.update(params)

# Settings layered on top of `params`.
plt.rcParams.update({
    'text.usetex': False,
    'font.family': 'sans-serif',
    # black borders
    'axes.edgecolor': "0.15",
    'axes.linewidth': 1.25,
})

# Width of the plotted loss curves and resolution used when saving figures.
LINEWIDTH = 4
DPI = 200

palette = sb.color_palette("Set1", 10)

# Per-algorithm drawing style: (line color, z-order rank).
# The insertion order here is also the order of the legend entries.
Styles = {
    'FastCB.L': (
        palette[0],
        2,
    ),
    'SquareCB.L': (
        palette[1],
        1,
    ),
    'Supervised.L': (
        sb.xkcd_rgb["brown"],
        0,
    ),
}

# plt.style.use('ggplot')

####################################################################################
# Config ###########################################################################
####################################################################################

# EDIT THIS TO POINT TO YOUR VW BINARY
VW_BINARY = '/MY BINARY PATH'

# Flags appended to every VW invocation, regardless of algorithm.
GENERIC_FLAGS = [
    '-b', '24',
    '--progress', '1',
    '-c',
]

# CACHE_DIR holds the gzip-pickled loss curves of finished runs;
# FIG_DIR receives the generated PDF figures.
CACHE_DIR, FIG_DIR = 'plotcache/', 'plots/'

# Number of shuffled copies ("_shuf0", "_shuf1", ...) evaluated per dataset.
N_SHUFFLES = 10

####################################################################################
# Datasets #########################################################################
####################################################################################

# If true, use fixed config for each dataset. Otherwise, use the individual configs in DATASETS below
USE_GRANULAR = False

# Flag fragments shared by the fixed ("granular") configurations below.
# The order of the assembled flags is significant: it is part of the
# key under which results are cached.
_GRANULAR_GAMMA = ['--squarecb', '--gamma_scale', '10', '--gamma_exponent', '0.25']
_GRANULAR_LOGISTIC = ['--loss_function', 'logistic', '--sigmoid']
_GRANULAR_CB_TAIL = (
    ['--learning_rate', '10.0', '--cb_type', 'mtr']
    + _GRANULAR_LOGISTIC
    + ['--tune_gamma']
)

# (algorithm name, VW flags) pairs applied to every dataset when USE_GRANULAR is set.
GRANULAR_ALGOS = [
    ('FastCB.L', _GRANULAR_GAMMA + ['--fast'] + _GRANULAR_CB_TAIL),
    ('SquareCB.L', _GRANULAR_GAMMA + _GRANULAR_CB_TAIL),
    ('Supervised.L', ['--learning_rate', '1.0'] + _GRANULAR_LOGISTIC),
]

# Currently empty and not read by the main loop.
GRANULAR_DATASETS = []

# Per-dataset experiment table. Each entry is unpacked by the main loop as
#   (ds_name, ds_path_base, ds_na, algos)
#
#   ds_name      -- identifier used for the plot title, the output file name
#                   and the DATASET_NAMES / IGNORE_X / IGNORE_Y lookups
#                   (presumably the OpenML dataset id -- TODO confirm).
#   ds_path_base -- path template; its '{}' placeholder is filled with
#                   '_shuf<idx>' for each of the N_SHUFFLES shuffled copies.
#   ds_na        -- integer handed to VW's --oaa / --cbify flags
#                   (presumably the number of classes/actions -- TODO confirm).
#   algos        -- list of (algorithm name, VW flag list). The names must be
#                   keys of Styles, and 'Supervised.L' selects the --oaa
#                   reduction instead of --cb_explore_adf/--cbify.
#
# NOTE: the flag lists are hashed, in order, into the cache key, so editing or
# reordering any flag makes the corresponding runs execute again.
DATASETS = [('1041', 'multiclass_shuffled/ds_1041_10{}.vw.gz', 10,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '10', '--fast',  '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '100', '--learning_rate',  '1.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '0.3', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('1110', 'multiclass_shuffled/ds_1110_23{}.vw.gz', 23,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '700',  '--fast',  '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '1000',  '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid' ]),
                 ('Supervised.L', ['--learning_rate',  '10.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('1113', 'multiclass_shuffled/ds_1113_23{}.vw.gz', 23, 
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '100',  '--fast',  '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '1000',  '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '10.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('150', 'multiclass_shuffled/ds_150_7{}.vw.gz', 7,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '10',  '--fast',  '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '10', '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('153', 'multiclass_shuffled/ds_153_2{}.vw.gz', 2,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '400', '--gamma_exponent', '0.25', '--fast',  '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '10',  '--gamma_exponent', '0.25', '--learning_rate',  '1.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('157', 'multiclass_shuffled/ds_157_5{}.vw.gz', 5,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '50',  '--fast',  '--learning_rate',  '0.1', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '50',  '--gamma_exponent', '0.25', '--learning_rate',  '1.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '0.3', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('158', 'multiclass_shuffled/ds_158_5{}.vw.gz', 5,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '400',  '--gamma_exponent', '0.25',  '--fast',  '--learning_rate',  '0.3', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '400',  '--gamma_exponent', '0.25', '--learning_rate',  '0.3', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '0.3', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('161', 'multiclass_shuffled/ds_161_2{}.vw.gz', 2,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '1000',  '--fast',  '--learning_rate',  '0.3', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '50',  '--gamma_exponent', '0.25', '--learning_rate',  '0.3', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '0.3', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('162', 'multiclass_shuffled/ds_162_2{}.vw.gz', 2,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '700',  '--fast',  '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '700', '--learning_rate',  '1.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('293', 'multiclass_shuffled/ds_293_2{}.vw.gz', 2,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '1000',  '--fast',  '--learning_rate',  '1.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '1000', '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '3.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('389', 'multiclass_shuffled/ds_389_17{}.vw.gz', 17,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '1000',  '--fast',  '--learning_rate',  '1.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '700', '--learning_rate',  '3.0', '--gamma_exponent', '0.25',  '--cb_type', 'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('396', 'multiclass_shuffled/ds_396_6{}.vw.gz', 6,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '100',  '--fast',  '--learning_rate',  '10.0', '--gamma_exponent', '0.25', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '400', '--learning_rate',  '10.0',  '--cb_type', 'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('399', 'multiclass_shuffled/ds_399_10{}.vw.gz', 10,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '1000',  '--fast',  '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '50', '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '0.3', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('458', 'multiclass_shuffled/ds_458_4{}.vw.gz', 4,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '10',  '--fast',  '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '10', '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('554', 'multiclass_shuffled/ds_554_10{}.vw.gz', 10,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '10.0',  '--fast',  '--learning_rate',  '1.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '50', '--learning_rate',  '3.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '0.3', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('822', 'multiclass_shuffled/ds_822_2{}.vw.gz', 2,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '10',  '--fast',  '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '10',  '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
            ('971', 'multiclass_shuffled/ds_971_2{}.vw.gz', 2,
             [
                 ('FastCB.L', ['--squarecb', '--gamma_scale', '10', '--gamma_exponent', '0.25',  '--fast',  '--learning_rate',  '10.0', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('SquareCB.L', ['--squarecb', '--gamma_scale', '1000',  '--learning_rate',  '0.003', '--cb_type',  'mtr',  '--loss_function', 'logistic',  '--sigmoid']),
                 ('Supervised.L', ['--learning_rate',  '1.0', '--loss_function', 'logistic',  '--sigmoid'])
             ]),
]

####################################################################################
# Paper setup ######################################################################
####################################################################################

# The settings below only take effect when the script is run with --paper.

# Datasets whose y-axis label and tick labels are suppressed.
IGNORE_Y = [
    '293',
    '1113',
]

# Datasets whose x-axis label is made invisible (drawn with alpha=0).
IGNORE_X = [
    '293',
    '1041',
]

# Human-readable names shown in the plot title instead of the bare id.
DATASET_NAMES = {
    '1113': 'kddcup99',
    '971': 'mfeat-fourier',
    '1041': 'gina-prior2',
    '389': 'fbis-wc',
    '293': 'covertype',
}

# Either "bernoulli" or "gaussian"
CONFIDENCE_TYPE = "bernoulli"

####################################################################################
# Main #############################################################################
####################################################################################

def parse_vw_losses(output, n_header=10, n_footer=9):
    """Extract the per-example loss columns from VW's textual output.

    Args:
        output: Complete stdout/stderr text of one VW run.
        n_header: Number of leading lines that precede the loss table.
        n_footer: Number of trailing lines that follow the loss table.
            Both offsets may change with different VW versions.

    Returns:
        A dict with two numpy arrays of equal length:
        'pv_loss' (progressive validation loss, first column of each line)
        and 'inst_loss' (instantaneous loss, second column). Both are empty
        when the output is too short to contain any loss line.
    """
    lines = output.split('\n')
    # Explicit end index so that n_footer == 0 keeps the tail instead of
    # producing an empty slice.
    loss_lines = lines[n_header:len(lines) - n_footer]

    pv_loss = []
    inst_loss = []
    for line in loss_lines:
        vals = line.split()
        pv_loss.append(float(vals[0]))
        inst_loss.append(float(vals[1]))

    return {'pv_loss': np.array(pv_loss), 'inst_loss': np.array(inst_loss)}


def load_or_run(command, cache_path):
    """Return the loss curves of one VW run, executing VW only on a cache miss.

    Args:
        command: Argument list for the VW process.
        cache_path: File holding (or to receive) the gzip-pickled result.

    Returns:
        The dict produced by parse_vw_losses.

    Raises:
        subprocess.CalledProcessError: VW exited with a non-zero status.
        RuntimeError: VW ran but no loss lines could be extracted.
    """
    if os.path.exists(cache_path):
        # NOTE: unpickling is only acceptable because these files are
        # written by this script itself; never point CACHE_DIR at
        # untrusted data.
        with gzip.open(cache_path, 'rb') as cache_file:
            return pickle.load(cache_file)

    print("Running command ", command)
    try:
        raw_output = subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(e.returncode)
        print(e.output)
        # Without output there is nothing to parse; continuing would either
        # crash on an unbound variable or silently reuse a previous run.
        raise

    output = raw_output.decode('utf-8', errors='replace')
    result = parse_vw_losses(output)

    if len(result['pv_loss']) == 0:
        print("Error parsing output for command ", command)
        print("VW output:")
        print(output)
        # Fail before writing, so that an empty result never lands in the cache.
        raise RuntimeError("No loss lines found in VW output")

    with gzip.open(cache_path, 'wb') as cache_file:
        pickle.dump(result, cache_file)
    return result


def confidence_band(pv, kind):
    """Compute the mean loss curve and its confidence band across runs.

    Args:
        pv: 2-D array with one row per run and one column per example.
        kind: "gaussian" or "bernoulli".

    Returns:
        Tuple (mean, lower, upper) of arrays with one value per example.

    Raises:
        ValueError: `kind` is not a supported confidence type.
    """
    n_runs, n_examples = pv.shape
    pv_mean = np.mean(pv, axis=0)

    if kind == "gaussian":
        half_width = np.std(pv, axis=0)*2/np.sqrt(n_runs)
        return pv_mean, pv_mean - half_width, pv_mean + half_width

    if kind == "bernoulli":
        # Effective sample count at step t: t examples in each of the runs.
        ns = np.arange(1, n_examples+1)*n_runs
        (pv_upper, pv_lower) = inverse_significance(pv_mean, ns)
        return pv_mean, pv_lower, pv_upper

    raise ValueError("Unknown confidence type: {!r}".format(kind))


def save_legend():
    """Render the shared legend on its own and save it as legend.pdf in FIG_DIR."""
    figlegend = plt.figure()
    handles = [
        matplotlib.patches.Patch(color=color, label=name, linestyle='solid')
        for (name, (color, _order)) in Styles.items()
    ]
    legend = figlegend.legend(handles, list(Styles.keys()), loc='center', ncol=6, frameon=True)
    # The canvas must be drawn before the legend extent can be measured.
    figlegend.canvas.draw()

    # Crop the saved page to the legend box itself.
    legend_bbox = legend.get_window_extent().transformed(figlegend.dpi_scale_trans.inverted())
    figlegend.savefig(os.path.join(FIG_DIR, "legend.pdf"), format='pdf', dpi=DPI, bbox_inches=legend_bbox)
    plt.close(figlegend)


def plot_dataset(ds_name, results, ds_sz, paper):
    """Plot the mean progressive validation loss of every algorithm on one dataset.

    Args:
        ds_name: Dataset identifier; used for the title and the file name.
        results: Maps algorithm name to its list of per-shuffle result dicts.
        ds_sz: Number of examples (length of each 'pv_loss' array).
        paper: When True, apply the paper layout (IGNORE_X, IGNORE_Y,
            DATASET_NAMES).
    """
    # Subsample long curves down to roughly n_points plotted points.
    n_points = 1000
    res = int(ds_sz/n_points) if ds_sz > n_points else 1

    # Skip the first examples of longer runs when plotting; index 0 is always
    # skipped, which also keeps x = 0 off the logarithmic axis.
    filter_first = 100 if ds_sz > 200 else 1

    print("Plotting dataset ", ds_name)

    fig, ax = plt.subplots()
    x = range(ds_sz)

    for (algo, algo_results) in results.items():
        pv = np.zeros((len(algo_results), ds_sz))
        for idx, run in enumerate(algo_results):
            pv[idx,:] = run['pv_loss']

        pv_mean, pv_lower, pv_upper = confidence_band(pv, CONFIDENCE_TYPE)

        color, order = Styles[algo]

        # +2 ensures that this plots over grid lines
        li, = ax.semilogx(x[filter_first::res], pv_mean[filter_first::res], label=algo, color=color, linewidth=LINEWIDTH, zorder=order+2)

        ax.fill_between(x[filter_first::res], pv_lower[filter_first::res], pv_upper[filter_first::res], color=li.get_color(), alpha = 0.2)

    # X Axis
    if not paper or ds_name not in IGNORE_X:
        ax.set_xlabel('Number of examples')
    else:
        # Invisible label: keeps the vertical space so panels stay aligned.
        ax.set_xlabel('Number of examples', alpha=0)

    ax.grid(True, axis='x', which='both')

    locmaj = matplotlib.ticker.LogLocator(base=10,numticks=12)
    ax.xaxis.set_major_locator(locmaj)
    locmin = matplotlib.ticker.LogLocator(base=10.0,subs=(0.2,0.4,0.6,0.8),numticks=12)
    ax.xaxis.set_minor_locator(locmin)
    ax.xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())

    ax.set_xlim(left=x[filter_first], right=x[-1])

    # Y Axis
    ax.set_ylim(bottom=0.0, top=1.0)
    ax.grid(True, axis='y', which='both')
    if not paper or ds_name not in IGNORE_Y:
        ax.set_ylabel('Progressive val. loss')
    else:
        ax.set_yticklabels([])
        ax.tick_params(
            axis='y',          # changes apply to the y-axis
            which='both',      # both major and minor ticks are affected
            left=False,        # ticks along the left edge are off
            right=False,       # ticks along the right edge are off
        )

    # Title
    if not paper or ds_name not in DATASET_NAMES:
        ax.set_title('PV Loss for Dataset {}'.format(ds_name))
    else:
        ax.set_title('{} (#{})'.format(DATASET_NAMES[ds_name],ds_name))

    if not SEPARATE_LEGEND:
        ax.legend(loc='upper right')

    fig_path = os.path.join(FIG_DIR, "{}.pdf".format(ds_name))
    fig.savefig(fig_path, bbox_inches='tight', dpi=DPI, pad_inches=0)
    # Release the figure; otherwise one stays open per dataset.
    plt.close(fig)


def main():
    """Run (or load from cache) every configured VW experiment and plot it."""
    parser = argparse.ArgumentParser(description='Generate performance plots for specific datasets')
    parser.add_argument('--paper', action='store_true', default=False)
    args = parser.parse_args()

    start = time.time()

    # Both output directories are needed before anything is written.
    for directory in (CACHE_DIR, FIG_DIR):
        os.makedirs(directory, exist_ok=True)

    if SEPARATE_LEGEND is True:
        save_legend()

    for (ds_name, ds_path_base, ds_na, algos) in DATASETS:

        if USE_GRANULAR:
            algos = GRANULAR_ALGOS

        results = {}
        ds_sz = 0

        for (algo, algo_flags) in algos:

            if algo == 'Supervised.L':
                reduction_flags = ['--oaa', str(ds_na)]
            else:
                reduction_flags = ['--cb_explore_adf', '--cbify', str(ds_na)]

            results_algo = []

            for idx in range(N_SHUFFLES):
                ds_path = ds_path_base.format('_shuf{}'.format(idx))

                command = [VW_BINARY] + [ds_path] + reduction_flags + GENERIC_FLAGS + algo_flags

                # The cache key covers the data file and the algorithm-specific
                # flags only; VW_BINARY and GENERIC_FLAGS are not part of it, so
                # clear CACHE_DIR by hand after changing either of them.
                config_hash = hashlib.md5(str((ds_path, tuple(reduction_flags), tuple(algo_flags))).encode('utf-8')).hexdigest()
                cache_path = os.path.join(CACHE_DIR, str(config_hash))

                run = load_or_run(command, cache_path)
                results_algo.append(run)
                # Number of examples
                ds_sz = len(run['pv_loss'])

            results[algo] = results_algo

        plot_dataset(ds_name, results, ds_sz, args.paper)

    stop = time.time()
    print("Elapsed time: ", stop - start)


if __name__=='__main__':
    main()

    # plt.show()
