import os, argparse, json, math, numpy as np, matplotlib.pyplot as plt


def init_bar_fig(ylim):
    """Open a new 10x5 figure for the summary bar chart and return its axes.

    The axes get a grid that is drawn beneath the plotted artists. If *ylim*
    is given it is applied with ``Axes.set_ylim``.

    Args:
        ylim: y-axis limits accepted by ``Axes.set_ylim`` (e.g. a
            ``(bottom, top)`` pair), or ``None`` to keep autoscaling.

    Returns:
        The ``Axes`` of the newly created figure.
    """
    fig = plt.figure(figsize=[10, 5])
    ax = fig.gca()
    ax.grid()
    ax.set_axisbelow(True)  # keep grid lines behind the bars
    # Identity check: `!= None` would be evaluated element-wise (and raise)
    # if ylim were passed as a numpy array.
    if ylim is not None:
        ax.set_ylim(ylim)
    return ax


def init_test_fig(label):
    """Open a wide figure with three gridded panels side by side.

    The panels are titled, from left to right, 'Learning', 'Test of u-agent'
    and 'Test of v-agent', and are therefore reachable as ``fig.axes[0..2]``.

    Args:
        label: figure-level title (drawn with font size 16).

    Returns:
        The new ``Figure``.
    """
    figure = plt.figure(figsize=[24, 6])
    figure.suptitle(label, fontsize=16)
    layout = figure.add_gridspec(1, 3)
    panel_titles = ('Learning', 'Test of u-agent', 'Test of v-agent')
    for column, panel_title in enumerate(panel_titles):
        panel = figure.add_subplot(layout[0, column], title=panel_title)
        panel.grid()
    return figure


# Directory entries whose name contains any of these substrings are never
# descended into by get_data (saved arrays, checkpoints, logs, OS clutter).
EXCLUDED_NAMES = ['.npy', '.pt', '.json', '.DS_Store', '.ipynb_checkpoints', 'model', 'log']


def get_data(path):
    """Recursively load saved results below *path* into a nested dict.

    The returned dict mirrors the directory tree:

    * ``trs.npy`` / ``tts.npy`` in a directory are loaded with ``np.load`` into
      the ``'trs'`` / ``'tts'`` keys of that directory's dict;
    * every sub-directory whose name contains none of the substrings in
      ``EXCLUDED_NAMES`` becomes a nested dict under its own name;
    * all other entries (excluded names and plain files) are ignored.

    Args:
        path: directory to scan.

    Returns:
        The nested dict described above (empty if nothing matched).
    """
    data = {}
    for file_name in os.listdir(path):
        full_path = os.path.join(path, file_name)
        if file_name == 'trs.npy':
            data['trs'] = np.load(full_path)
        elif file_name == 'tts.npy':
            data['tts'] = np.load(full_path)
        elif (all(substr not in file_name for substr in EXCLUDED_NAMES)
              and os.path.isdir(full_path)):
            # Only recurse into directories: stray files such as notes or
            # exported plots used to be passed to os.listdir and crash.
            data[file_name] = get_data(full_path)
    return data


def get_exp_values(exp_data):
    """Return ``(min over u-tests, max over v-tests)`` for one experiment.

    Every entry of *exp_data* other than ``'trs'`` / ``'tts'`` is treated as
    a test. A test whose name starts with ``'u'`` contributes the minimum of
    its ``'trs'`` array to the first value. One starting with ``'v'``
    contributes the maximum to the second value. Other names are ignored.

    A side with no contributing test keeps its sentinel: ``inf`` for the
    minimum, ``-inf`` for the maximum. If a u- (or v-) test has no ``'trs'``,
    scanning stops at once. That side is then returned as ``None``, and the
    other side holds whatever had been accumulated up to that point.
    """
    lowest = float('inf')
    highest = -float('inf')
    for name, entry in exp_data.items():
        if name in ('trs', 'tts'):
            continue
        prefix = name[0]
        if prefix == 'u':
            if 'trs' not in entry:
                return None, highest
            lowest = min(lowest, entry['trs'].min())
        elif prefix == 'v':
            if 'trs' not in entry:
                return lowest, None
            highest = max(highest, entry['trs'].max())
    return lowest, highest
    
    
def get_bar_keys_names_labels(data, bar_info):
    """Select which top-level entries of *data* become bars, and how to label them.

    Args:
        data: mapping whose keys are candidate bar keys (only the keys are used).
        bar_info: ``None``, or a mapping ``{substring: label}``.

    Returns:
        Three parallel lists ``(bar_keys, bar_names, bar_labels)``:

        * ``bar_info is None``: every key of *data* in sorted order. The name
          is the key itself and the label is its index in that order.
        * otherwise: for each ``(substring, label)`` of *bar_info* in its
          iteration order, every sorted key containing the substring. The name
          is the substring and the label is the mapped label. A key matching
          several substrings appears once per match.
    """
    bar_keys, bar_names, bar_labels = [], [], []
    # Sort once; the previous version re-sorted all keys on every iteration.
    sorted_keys = sorted(data.keys())

    if bar_info is None:
        for bar_label, bar_key in enumerate(sorted_keys):
            bar_keys.append(bar_key)
            bar_names.append(bar_key)
            bar_labels.append(bar_label)
    else:
        for bar_name, bar_label in bar_info.items():
            for bar_key in sorted_keys:
                if bar_name in bar_key:
                    bar_keys.append(bar_key)
                    bar_names.append(bar_name)
                    bar_labels.append(bar_label)

    return bar_keys, bar_names, bar_labels
        

def show_bar_fig(data, sub_name, ylim, bar_info, max_skips=5):
    """Draw one group of floating bars per selected algorithm on a new chart.

    Bars are chosen with ``get_bar_keys_names_labels(data, bar_info)`` and
    kept only when their key contains *sub_name*. ``data[bar_key]`` is
    iterated as a mapping of experiments. Each experiment is reduced by
    ``get_exp_values`` to a ``(min over u-tests, max over v-tests)`` pair.
    From the per-experiment mins and maxs, three bars of one colour are
    stacked at ``bar_label``:

    * alpha 0.2: from the smallest min to the largest max,
    * alpha 0.4: from the mean min to the mean max,
    * alpha 1.0: from the largest min to the smallest max.

    An experiment whose pair still holds an infinite sentinel (no u-test or
    no v-test) is counted as a skip. A group is drawn only when all of these
    hold:

    * no experiment reported ``None`` (a test without ``trs``);
    * at least one experiment was usable;
    * fewer than *max_skips* experiments were skipped.

    Otherwise a message is printed instead.

    Args:
        data: nested result dict as produced by ``get_data``.
        sub_name: substring a bar key must contain to be plotted.
        ylim: y-limits forwarded to ``init_bar_fig`` (``None`` = not set).
        bar_info: ``{key substring: bar label}`` or ``None`` (see
            ``get_bar_keys_names_labels``).
        max_skips: a bar is dropped once this many experiments are skipped
            (default 5, the previously hard-coded threshold).

    Returns:
        None. The bars are drawn on the axes created by ``init_bar_fig``.
    """
    ax = init_bar_fig(ylim)
    bar_keys, bar_names, bar_labels = get_bar_keys_names_labels(data, bar_info)

    for bar_key, bar_name, bar_label in zip(bar_keys, bar_names, bar_labels):
        if sub_name not in bar_key:
            continue

        mins, maxs = [], []
        skips = 0
        for exp_data in data[bar_key].values():
            min_value, max_value = get_exp_values(exp_data)
            if min_value == float('inf') or max_value == -float('inf'):
                skips += 1
            else:
                mins.append(min_value)
                maxs.append(max_value)
        if skips > 0:
            print(f'Skips {skips} in {bar_label}: {bar_name}')

        if None in mins or None in maxs:
            print(f'None in {bar_label}: {bar_name}')
            continue
        # Without this guard np.min([]) raised when every experiment was
        # skipped (or there were none) but skips stayed below the limit.
        if not mins or skips >= max_skips:
            print(f'Not enough data in {bar_label}: {bar_name}')
            continue

        if bar_info is None:
            print(f'{bar_label}: {bar_name}')
        min_min, mean_min, max_min = np.min(mins), np.mean(mins), np.max(mins)
        min_max, mean_max, max_max = np.min(maxs), np.mean(maxs), np.max(maxs)

        # Raise the upper ends so the solid bar is at least 0.1 tall and each
        # bar's top is no lower than the top of the bar nested inside it.
        min_max = max(min_max, max_min + 0.1)
        mean_max = max(mean_max, min_max)
        max_max = max(max_max, mean_max)

        bar = ax.bar(bar_label, max_max - min_min, 0.5, min_min, alpha=0.2)
        color = bar.patches[0].get_facecolor()
        ax.bar(bar_label, mean_max - mean_min, 0.5, bottom=mean_min, color=color, alpha=0.4)
        ax.bar(bar_label, max(min_max - max_min, 0.01), 0.5, bottom=max_min, color=color, alpha=1)

    return None


def get_dense_array(data):
    """Expand a sparse ``(tts, trs)`` log into one value per integer step.

    ``data['tts'][i]`` is an integer position (presumably a training
    timestep — confirm against the logger). ``data['trs'][i]`` is the value
    recorded there. Value ``trs[i]`` is repeated ``tts[i] - tts[i-1]`` times,
    with ``tts[-1]`` taken as 0. For increasing positions, the result
    therefore has one entry for every step in ``[0, tts[-1])``.

    Edge cases:

    * a non-increasing position contributes no entries;
    * as with ``zip``, surplus trailing entries of the longer array are
      ignored.

    Args:
        data: mapping with array-like ``'tts'`` (integers) and ``'trs'``.

    Returns:
        1-D numpy array with the expanded values.
    """
    tts = np.asarray(data['tts'])
    trs = np.asarray(data['trs'])
    n = min(len(tts), len(trs))
    tts, trs = tts[:n], trs[:n]
    # Gap to the previous position; clipping negative gaps to 0 reproduces the
    # former `range(prev_ts, ts)` loop, which added nothing for them.
    counts = np.maximum(np.diff(tts, prepend=0), 0)
    return np.repeat(trs, counts)
    

def get_fig_values(test_data, point_n=50):
    """Summarise one or more logged curves at roughly *point_n* x-positions.

    Every value of *test_data* is expanded with ``get_dense_array``. The rows
    are stacked into a matrix, which assumes they all have the same length
    (TODO confirm for multi-entry inputs; callers in this file pass one
    entry). The columns are cut into consecutive blocks of
    ``max(1, int(n_columns / point_n))`` columns; the last block may be
    shorter.

    Returns:
        ``(block_starts, block_mins, block_means, block_maxs)``. The first
        item is a numpy array with the starting column of each block. The
        other three are lists holding the min / mean / max over all entries
        of each block.
    """
    matrix = np.array([get_dense_array(entry) for entry in test_data.values()])
    n_columns = matrix.shape[1]
    step = max(1, int(n_columns / point_n))
    boundaries = np.arange(step, n_columns, step)
    chunks = np.split(matrix, boundaries, axis=1)
    starts = np.insert(boundaries, 0, 0)
    lows = [np.min(chunk) for chunk in chunks]
    avgs = [np.mean(chunk) for chunk in chunks]
    highs = [np.max(chunk) for chunk in chunks]
    return starts, lows, avgs, highs


def show_test_axes(axes, data, label):
    """Plot one logged curve on *axes*: its block means, plus a shaded min-max band.

    *data* must provide the ``'tts'`` / ``'trs'`` arrays read by
    ``get_dense_array``. The mean line carries *label*, and the axes legend
    is refreshed afterwards.
    """
    positions, lows, means, highs = get_fig_values({'_': data})
    axes.fill_between(positions, lows, highs, alpha=0.2)
    axes.plot(positions, means, label=label)
    axes.legend()


def show_test_fig(exp_data, title, test_subname):
    """Open a three-panel figure with the curves of one experiment.

    * Panel 0 ('Learning'): the experiment's own ``trs``/``tts`` curve, when
      both arrays are present.
    * Panel 1 ('Test of u-agent'): every test whose name starts with ``'v'``
      and contains *test_subname*, plus a black line at the maximum from
      ``get_exp_values``.
    * Panel 2 ('Test of v-agent'): every test whose name starts with ``'u'``
      and contains *test_subname*, plus a black line at the minimum from
      ``get_exp_values``.

    A reference line is drawn only when its value is finite. ``None`` (a
    test without ``trs``) and the ±inf sentinels (no such test) are skipped
    rather than crashing ``axhline``.

    Args:
        exp_data: one experiment's dict as produced by ``get_data``.
        title: figure title.
        test_subname: substring a test name must contain to be plotted.
    """
    fig = init_test_fig(title)
    learning_ax, u_agent_ax, v_agent_ax = fig.axes[0], fig.axes[1], fig.axes[2]

    # The learning curve needs experiment-level arrays; skip it instead of
    # raising KeyError when they were not saved.
    if 'trs' in exp_data and 'tts' in exp_data:
        show_test_axes(learning_ax, exp_data, label='learning')
    else:
        print(f'{title}: no learning data')

    for test_name in sorted(exp_data.keys()):
        if test_subname not in test_name:
            continue
        test_data = exp_data[test_name]
        if test_name.startswith('u'):
            show_test_axes(v_agent_ax, test_data, label=test_name)
        elif test_name.startswith('v'):
            show_test_axes(u_agent_ax, test_data, label=test_name)

    min_value, max_value = get_exp_values(exp_data)
    if min_value is not None and np.isfinite(min_value):
        v_agent_ax.axhline(min_value, color='black', label='min')
        v_agent_ax.legend()
    if max_value is not None and np.isfinite(max_value):
        u_agent_ax.axhline(max_value, color='black', label='max')
        u_agent_ax.legend()

    
def show_test_figs(data, exp_subname, test_subname):
    """Open one ``show_test_fig`` figure per experiment of every matching bar.

    Top-level keys of *data* are visited in sorted order and kept when they
    contain *exp_subname*. Their experiments are also visited in sorted
    order. An experiment whose dict is empty is reported on stdout instead of
    being plotted.
    """
    matching_names = [name for name in sorted(data.keys()) if exp_subname in name]
    for bar_name in matching_names:
        runs = data[bar_name]
        for exp_i in sorted(runs.keys()):
            exp_data = runs[exp_i]
            if exp_data == {}:
                print(f'{bar_name}:{exp_i} is empty!')
                continue
            show_test_fig(exp_data, f'{bar_name}: {exp_i}', test_subname)

# Default {directory-name substring: bar label} mapping used by show().
# Its iteration order is the order in which get_bar_keys_names_labels emits bars.
default_bar_info = {
    '_2xDDQN': '2xDDQN',
    '_RARL': 'RARL',
    '_NashDQN': 'NashDQN',
    '_MADDPG': 'MADDPG',
    '_MADQN': 'MADQN',
    '_CounterDQN': 'CounterDQN',
    '_IDQN': 'IDQN',
    '_DIDQN': 'DIDQN',
}
    
def show(data_path, exp_subname='', test_subname='', show_tests=0,
         ylim=None, bar_info=default_bar_info, true_value=None):
    """Load results from *data_path*, draw the summary bar chart and display it.

    Args:
        data_path: root directory scanned by ``get_data``.
        exp_subname: substring an algorithm key must contain to be plotted.
        test_subname: substring a test name must contain in the per-experiment
            figures.
        show_tests: when exactly 1, also open one figure per experiment via
            ``show_test_figs``.
        ylim: y-limits for the bar chart, or ``None``.
        bar_info: ``{key substring: bar label}`` mapping, or ``None`` to plot
            every key labelled by index. The default dict is shared and is
            only read here.
        true_value: if not ``None``, a dashed black reference line is drawn
            at this y value on the bar chart.
    """
    data = get_data(data_path)
    show_bar_fig(data, exp_subname, ylim, bar_info)

    if true_value is not None:
        # plt.* draws on the current axes, which is the bar chart that
        # show_bar_fig just created via plt.figure().
        plt.axhline(y=true_value, color="black", linestyle="--", label='True Value')
        plt.legend()

    if show_tests == 1:
        show_test_figs(data, exp_subname, test_subname)

    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Plot summary bars (and optionally per-experiment test curves) '
                    'for saved experiment results.')
    # os.listdir(None) lists the working directory, so '.' keeps the old default explicit.
    parser.add_argument('--data_path', type=str, default='.',
                        help='root directory of the saved results (default: current directory)')
    parser.add_argument('--exp_subname', type=str, default='',
                        help='only plot algorithms whose directory name contains this substring')
    parser.add_argument('--test_subname', type=str, default='',
                        help='only plot tests whose name contains this substring')
    parser.add_argument('--show_tests', type=int, default=0,
                        help='1 to also open one test figure per experiment')
    parser.add_argument('--ylim', type=float, nargs=2, default=None, metavar=('BOTTOM', 'TOP'),
                        help='y-axis limits of the bar chart')
    parser.add_argument('--true_value', type=float, default=None,
                        help='draw a dashed reference line at this value on the bar chart')
    args = parser.parse_args()
    show(args.data_path, args.exp_subname, args.test_subname, args.show_tests,
         ylim=args.ylim, true_value=args.true_value)
