import os
import argparse
import fnmatch
import re

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
# Global matplotlib defaults for every figure produced by this script:
# 13pt text rendered in Arial.
plt.rcParams.update({'font.size': 13})
plt.rcParams.update({'font.family':'Arial'})
from scipy.signal import savgol_filter
import seaborn as sns

def _str_to_bool(value):
    """Convert a command-line token such as ``true``/``False``/``1`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _int_from_number(text):
    """Parse an integer that may be written in float notation (e.g. ``5e7``)."""
    try:
        return int(float(text))
    except OverflowError:
        raise argparse.ArgumentTypeError(f'expected a finite number, got {text!r}')


def parse_args():
    """Build the command-line interface of the plotting script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--base_path', type=str, default='~/logs/repaired', help='Path to results directory.')
    parser.add_argument('-p', '--xpid_prefixes', type=str, nargs='+', default=['latest'])
    parser.add_argument('-l', '--labels', type=str, nargs='+', default=[None])
    parser.add_argument('--max_step', type=float, default=5e8)
    parser.add_argument('-e', '--env_type', type=str, choices=['minigrid', 'car_racing'], default='minigrid')
    parser.add_argument('-b', '--base', type=float, default=1e7)
    # The former default 'mean' was not one of the choices; downstream code
    # treats any 'mean*' value other than 'mean_std' as mean +/- SEM, so
    # 'mean_sem' is the same behaviour spelled with a valid choice.
    parser.add_argument('--plot_type', type=str, choices=['mean_sem', 'mean_std', 'median', 'mean_sem_median'], default='mean_sem')
    parser.add_argument('--show_paired_3b', action='store_true')
    parser.add_argument('--show_minimax_3b', action='store_true')
    parser.add_argument('--show_dr_3b', action='store_true')
    parser.add_argument('--num_dec', type=int, default=0)
    # Accept both '50000000' and '5e7'; plain int() rejects the latter.
    parser.add_argument('--xtick', type=_int_from_number, default=50000000)

    parser.add_argument('--y_norm_max', type=float, default=None)

    parser.add_argument('--cols', type=str, nargs='+', default=None)
    parser.add_argument('--col_names', type=str, nargs='+', default=None)
    parser.add_argument('--col_ylabels', type=str, nargs='+', default=None)
    parser.add_argument('--linestyle', type=str, default='solid')
    parser.add_argument('--no_legends', type=_str_to_bool, nargs='+', default=None)
    parser.add_argument('--cycle_colors', type=_str_to_bool, default=False)
    parser.add_argument('--figsize', type=str, default='3,5')
    parser.add_argument('--lgd_pos', type=str, choices=['top', 'right'], default='top')
    parser.add_argument('--lgd_fontsize', type=float, default=10)
    parser.add_argument('--fontsize', type=float, default=12)

    parser.add_argument('--y_bounds', type=str, nargs='+', default=None)

    parser.add_argument('-s', '--savename', type=str, default='latest.pdf')

    return parser.parse_args()

def custom_round(x, base=5):
    """Snap ``x`` to the nearest multiple of ``base`` and return it as an int."""
    multiples = round(float(x) / base)
    snapped = base * multiples
    return int(snapped)

def get_paths_for_prefix(base_path, prefix):
    """List the entries of ``base_path`` that belong to one experiment prefix.

    An entry is kept when its name matches the glob ``<prefix>*`` and also
    ends in an underscore followed by zero or more digits. Each kept name is
    returned joined onto ``base_path``.
    """
    numbered_suffix = re.compile('.*_[0-9]*$')
    candidates = fnmatch.filter(os.listdir(base_path), f"{prefix}*")

    selected = []
    for entry in candidates:
        if numbered_suffix.match(entry) is None:
            continue
        selected.append(os.path.join(base_path, entry))

    return selected

def makedf(base_path='~/logs/repaired',
           xpid_prefix='latest',
           env_type='minigrid',
           name='Labyrinth',
           columns=('test_returns:MultiGrid-Labyrinth-v0',),
           plot_type='mean_sem',
           base=1e7,
           max_x=500000,
           y_norm_max=None):
    """Aggregate the ``logs.csv`` files of all runs sharing ``xpid_prefix``.

    Every run directory returned by ``get_paths_for_prefix`` is read, its
    ``steps`` column is snapped to multiples of ``base`` with ``custom_round``,
    and the runs are pooled. The pooled rows are grouped by ``steps`` and a
    smoothed centre line plus a smoothed spread is computed per column.

    Args:
        base_path: directory containing the run directories.
        xpid_prefix: prefix selecting the runs to pool.
        env_type: unused; kept for call compatibility.
        name: unused; kept for call compatibility.
        columns: names of the CSV columns to aggregate.
        plot_type: any value starting with ``'mean'`` yields an exponentially
            smoothed mean plus ``<col>_std`` (for ``'mean_std'``) or
            ``<col>_sem`` (otherwise); ``'median'`` yields a Savitzky-Golay
            smoothed median plus ``<col>_q1`` / ``<col>_q3``.
        base: rounding granularity applied to ``steps``.
        max_x: rows with ``steps`` greater than this are dropped.
        y_norm_max: optional divisor applied to every column. When omitted,
            the ``y_norm_max`` attribute of a module-level ``args`` object is
            used if such an object exists, which is how the script passed
            this value before the parameter was introduced.

    Returns:
        pandas.DataFrame with ``steps``, one column per entry of ``columns``
        and the spread columns described above.

    Raises:
        ValueError: if ``plot_type`` is not recognised or no log file could
            be read for ``xpid_prefix``.
    """
    columns = list(columns)

    if y_norm_max is None:
        # Backward compatibility with the script entry point, which stores the
        # parsed CLI options in a module-level ``args``.
        y_norm_max = getattr(globals().get('args'), 'y_norm_max', None)

    if plot_type.startswith('mean'):
        use_mean = True
    elif plot_type == 'median':
        use_mean = False
    else:
        raise ValueError(f'Unsupported plot_type: {plot_type!r}')

    print(xpid_prefix)

    frames = []
    for xpid_path in get_paths_for_prefix(base_path, xpid_prefix):
        try:
            d = pd.read_csv(f'{xpid_path}/logs.csv')
        except (OSError, ValueError) as err:
            # pandas' EmptyDataError / ParserError derive from ValueError.
            # Unreadable runs are skipped on purpose; Ctrl-C still propagates.
            print("Can't open file")
            print(xpid_path, flush=True)
            print(err, flush=True)
            continue
        print('opened:', xpid_path)

        if d.empty:
            # A header-only log has nothing to contribute.
            print('no rows:', xpid_path, flush=True)
            continue

        d['steps'] = d.steps.apply(lambda x: custom_round(x, base=base))
        print(f"steps: {d['steps'].iloc[-1]}")
        frames.append(d)

    if not frames:
        raise ValueError(
            f'No readable logs.csv found under {base_path!r} for prefix {xpid_prefix!r}')

    # Concatenate once instead of growing a frame inside the loop.
    df = pd.concat(frames, ignore_index=True)

    if y_norm_max:
        df[columns] = df[columns] / y_norm_max

    # Restrict the aggregation to the requested columns: it avoids reducing
    # unrelated (possibly non-numeric) log columns and is computed only once.
    grouped = df.groupby('steps')[columns]
    out = (grouped.mean() if use_mean else grouped.median()).reset_index()

    if use_mean:
        if plot_type == 'mean_std':
            spread, suffix, spread_alpha = grouped.std(), 'std', 0.9
        else:
            spread, suffix, spread_alpha = grouped.sem(), 'sem', 0.1

        for col in columns:
            out[col] = out[col].ewm(alpha=0.1, ignore_na=True).mean()
            smoothed = spread[col].ewm(alpha=spread_alpha, ignore_na=True).mean()
            # Assign positionally: ``spread`` is indexed by steps while ``out``
            # has a fresh range index; both are in the same group order.
            out[f'{col}_{suffix}'] = smoothed.to_numpy()
    else:
        smooth_params = (15, 1)  # savgol window length and polynomial order
        q1 = grouped.quantile(.25)
        q3 = grouped.quantile(.75)

        for col in columns:
            out[col] = savgol_filter(out[col], *smooth_params)
            out[f'{col}_q1'] = savgol_filter(q1[col].to_numpy(), *smooth_params)
            out[f'{col}_q3'] = savgol_filter(q3[col].to_numpy(), *smooth_params)

    out = out[out['steps'] <= max_x]

    return out

def reformat_large_tick_values(tick_val, pos=None):
    """
    Turns large tick values (in the billions, millions and thousands) such as 4500 into 4.5K and also appropriately turns 4000 into 4K (no zero after the decimal).

    Negative values are formatted by magnitude and prefixed with a minus sign,
    so -1500 becomes -1.5K.

    Adapted from: https://dfrieds.com/data-visualizations/how-format-large-tick-values.html

    Args:
        tick_val: the tick value to format.
        pos: unused; present so the function has the two-argument signature
            expected by ``matplotlib.ticker.FuncFormatter``.

    Returns:
        str: the formatted tick label.
    """
    magnitude = abs(tick_val)
    sign = '-' if tick_val < 0 else ''
    suffix = ''

    if magnitude >= 1000000000:
        number, suffix = round(magnitude / 1000000000, 1), 'B'
    elif magnitude >= 1000000:
        number, suffix = round(magnitude / 1000000, 1), 'M'
    elif magnitude >= 1000:
        number, suffix = round(magnitude / 1000, 1), 'K'
    elif magnitude >= 10:
        number = round(magnitude, 1)
    elif magnitude >= 1:
        number = round(magnitude, 2)
    elif magnitude >= 1e-4:
        number = round(magnitude, 3)
    elif magnitude >= 1e-8:
        number = round(magnitude, 8)
    else:
        number = magnitude

    text = str(number)

    # Cap plain decimals at six characters. Scaled values are already rounded
    # to one decimal, and truncating them could cut off the K/M/B suffix or
    # leave a dangling decimal point, so they are left alone.
    if not suffix and 'e' not in text:
        text = text[:6]

    # Drop a ".0" for values whose magnitude is at least 10 (or exactly zero),
    # e.g. "4.0" + "K" -> "4K", "250.0" -> "250".
    dot = text.find('.')
    if dot != -1 and text[dot + 1:dot + 2] == '0' and (magnitude >= 10 or magnitude == 0):
        text = text[:dot] + text[dot + 2:]

    return sign + text + suffix

# Fixed plot colours for known method labels; labels that are not listed here
# are coloured from the generated palette instead.
# All three spellings of the PLR label share one colour.
_PLR_COLOR = (0.9637256003082545, 0.40964669235271706, 0.7430230442501574)

LABEL_COLORS = dict([
    ('DR', 'black'),
    ('SAMPLR', (0.3711152842731098, 0.6174124752499043, 0.9586047646790773)),
    ('LP-SAMPLR', (0.2745098, 0.76862745, 0.30196078)),
    ('PLR', _PLR_COLOR),
    ('Robust PLR', _PLR_COLOR),
    ('PLR$^{⟂}$', _PLR_COLOR),
])

if __name__ == '__main__':
    # Usage:
    #   python plot_flex.py \
    #     -r ~/logs/samplr \
    #     -p <xpid_1> <xpid_2> <xpid_3> \
    #     -l <label_1> <label_2> <label_3> \
    #     --cols <csv_column> ... --col_names <panel_title> ...
    #
    # Draws one panel per entry of --cols and one curve per xpid prefix, then
    # saves the figure as a PDF under figures/plots/ and shows it.
    import ast  # only the script entry point needs it

    # NOTE: ``args`` must stay a module-level name; makedf() reads it.
    args = parse_args()

    base_path = os.path.expandvars(os.path.expanduser(args.base_path))
    env_type = args.env_type
    base = args.base
    plot_type = args.plot_type

    # --cols / --col_names drive the panel layout, so fail with a readable
    # message instead of a NameError / AssertionError further down.
    if args.cols is None or args.col_names is None:
        raise SystemExit('--cols and --col_names are both required.')
    if len(args.cols) != len(args.col_names):
        raise SystemExit(
            f'--cols ({len(args.cols)}) and --col_names ({len(args.col_names)}) '
            'must have the same number of entries.')
    cols = args.cols
    names = args.col_names

    # Fallback colours for labels without an entry in LABEL_COLORS. Six hues
    # are kept as the baseline so existing figures do not change; more are
    # generated only when there are more than six runs.
    num_colors = len(args.xpid_prefixes)
    colors = sns.husl_palette(max(6, num_colors), h=.1)

    # literal_eval parses "3,5" / "(3, 5)" without executing arbitrary code.
    fig, axs = plt.subplots(1, len(names), figsize=ast.literal_eval(args.figsize), sharex=True)
    fig.patch.set_facecolor('white')

    # plt.subplots returns a bare Axes for a single panel; normalise so the
    # code below can always index it.
    if not hasattr(axs, '__iter__'):
        axs = [axs]

    labels = args.labels

    # Pad the label list by repeating the last label.
    if len(labels) < len(args.xpid_prefixes):
        labels = labels + [labels[-1],]*(len(args.xpid_prefixes) - len(labels))

    # NOTE(review): PAIRED_3B_*, MINIMAX_3B_* and DR_3B_* used by the
    # --show_*_3b options are not defined in the part of the file reviewed;
    # confirm they exist before relying on those flags.

    for i, xpid_prefix in enumerate(args.xpid_prefixes):
        # Create a dataframe with mean/var statistics per xpid_prefix
        label = labels[i]

        if label == 'Robust PLR':
            label = 'PLR$^{⟂}$'

        if label in LABEL_COLORS:
            color = LABEL_COLORS[label]
        else:
            color = colors[i]

        df = makedf(
            base_path=base_path,
            xpid_prefix=xpid_prefix,
            env_type=env_type,
            name=label,
            columns=cols,
            plot_type=plot_type,
            base=base,
            max_x=args.max_step)

        if plot_type == 'mean_sem_median':
            median_df = makedf(
                base_path=base_path,
                xpid_prefix=xpid_prefix,
                env_type=env_type,
                name=label,
                columns=cols,
                plot_type='median',
                base=base,
                max_x=args.max_step)

        if label == 'PLR Robust':
            label = r'PLR$^{\bot}$'

        # Plot each col
        for col_idx, col in enumerate(cols):
            name = names[col_idx]
            ax = axs[col_idx]

            col_q1 = f'{col}_q1'
            col_q3 = f'{col}_q3'

            ax.plot(df['steps'], df[col], label=label, color=color, linewidth=2, linestyle=args.linestyle)

            if plot_type.startswith('mean'):
                if plot_type == 'mean_std':
                    col_err = f'{col}_std'
                else:
                    col_err = f'{col}_sem'
                ax.fill_between(df['steps'], df[col] + df[col_err], df[col]-df[col_err], alpha=0.2, color=color, lw=0)
            else:
                ax.fill_between(df['steps'], df[col_q1], df[col_q3], alpha=0.2, color=color, lw=0)

            if plot_type == 'mean_sem_median':
                ax.plot(median_df['steps'], median_df[col], label='_nolegend_', color=color, linewidth=2, linestyle='dashed', alpha=0.5)

            if args.show_paired_3b:
                baseline_x = [0, df['steps'].iloc[-1]]
                baseline_y = [PAIRED_3B_SOLVE_RATE[name],]*2
                ax.plot(
                    baseline_x, baseline_y, label='_nolegend_',
                    alpha=0.35, **PAIRED_3B_LINESTYLE)

            if args.show_minimax_3b:
                baseline_x = [0, df['steps'].iloc[-1]]
                baseline_y = [MINIMAX_3B_SOLVE_RATE[name],]*2
                ax.plot(
                    baseline_x, baseline_y, label='_nolegend_',
                    alpha=0.35, **MINIMAX_3B_LINESTYLE)

            if args.show_dr_3b:
                baseline_x = [0, df['steps'].iloc[-1]]
                baseline_y = [DR_3B_SOLVE_RATE[name],]*2
                ax.plot(
                    baseline_x, baseline_y, label='_nolegend_',
                    alpha=0.35, **DR_3B_LINESTYLE)

            ax.set_title(name, fontsize=args.fontsize, pad=args.fontsize)
            ax.set_xticks(list([x*args.xtick for x in range(0,round(args.max_step/args.xtick))]) + [args.max_step,])
            ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(reformat_large_tick_values))

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            if args.y_bounds is not None and col_idx < len(args.y_bounds):
                # Each entry is a literal pair such as "0,1".
                y_bounds = ast.literal_eval(args.y_bounds[col_idx])
                ax.set_ylim(*y_bounds)
                if 0 <= min(y_bounds) <= 1 and 0 <= max(y_bounds) <= 1:
                    if min(y_bounds) > 0 or max(y_bounds) < 1.0 or args.num_dec > 0:
                        num_dec = args.num_dec
                        ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter(f'%.{num_dec}f'))

            if args.col_ylabels is not None and col_idx < len(args.col_ylabels):
                ax.set_ylabel(args.col_ylabels[col_idx], fontsize=args.fontsize)

    fig.supxlabel('Steps', fontsize=args.fontsize, va='bottom', y=0.07)

    # All panels carry the same curves, so the first one provides the legend.
    handles, labels = axs[0].get_legend_handles_labels()

    if plot_type == 'mean_sem_median':
        line = mlines.Line2D([0], [0], color='black', alpha=0.5, linewidth=2, linestyle='dashed')
        handles.append(line)
        labels.append('Median')

    if args.show_dr_3b:
        line = mlines.Line2D([0], [0], alpha=1, **DR_3B_LINESTYLE)
        handles.insert(0,line)
        labels.insert(0,'DR, 3B steps')

    if args.show_minimax_3b:
        line = mlines.Line2D([0], [0], alpha=1, **MINIMAX_3B_LINESTYLE)
        handles.insert(0,line)
        labels.insert(0,'Minimax, 3B steps')

    if args.show_paired_3b:
        line = mlines.Line2D([0], [0], alpha=1, **PAIRED_3B_LINESTYLE)
        handles.insert(0,line)
        labels.insert(0,'PAIRED, 3B steps')

    plt.rcParams.update({'font.size': 10})

    if args.lgd_pos == 'top':
        fig.legend(handles=handles, labels=labels,
                frameon=False, ncol=len(args.labels),
                loc='upper center', bbox_to_anchor=[0.5, 1.14], framealpha = 1, fancybox = False,
                fontsize=args.lgd_fontsize
                )
    else:
        plt.legend(handles=handles, labels=labels,
                frameon=False, ncol=1, loc='center left', bbox_to_anchor=[1,0.6], framealpha = 1, fancybox = False,
                fontsize=args.lgd_fontsize
                )

    plt.tight_layout()

    # Append the extension only when it is missing (the default savename
    # already ends in ".pdf") and make sure the output directory exists.
    savename = args.savename
    if not savename.lower().endswith('.pdf'):
        savename = f'{savename}.pdf'
    save_path = os.path.join('figures', 'plots', savename)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    plt.savefig(save_path, bbox_inches='tight')
    plt.show()