import matplotlib.pyplot as plt
from matplotlib import rc
from copy import deepcopy
import numpy as np
import matplotlib
import pickle
import glob
import sys
import os

# Base font size (in points) applied to every figure drawn by this script.
font = dict(size=12)
matplotlib.rc('font', **font)

def main():
    """Configure the runs to compare and plot them.

    Declares the environment, the logged items to plot, and one entry per
    algorithm (display name plus one run directory per seed), then hands
    everything to ``draw`` using a single-row (horizontal) subplot layout.
    """
    env_name = "MITCheetah"
    item_list = ['score', 'cost_0', 'cost_1', 'cost_2']

    # Each algorithm entry: legend label + run directories (one per seed)
    # whose curves are pooled together.
    algo_list = [
        {
            'name': 'TQC',
            'logs': [f'results/TQC-Cheetah_s{seed}' for seed in [1]],
        },
    ]

    # Per-algorithm styling, indexed like ``algo_list``.
    color_list = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    linestyle_list = ['-'] * 5

    draw(
        env_name,
        item_list,
        algo_list,
        fig_size=4,
        window_size=200,
        interp_steps=1000,
        color_list=color_list,
        linestyle_list=linestyle_list,
        is_horizon=True,
    )


def draw(env_name, item_list, algo_list, fig_size, window_size, interp_steps, color_list, linestyle_list, is_horizon=False):
    """Plot smoothed learning curves, one subplot per logged item.

    Within each subplot, every algorithm in ``algo_list`` is drawn as the mean
    curve returned by ``parse``, with a shaded +/- one-std band in the same
    color. The figure is saved to ``./imgs/<env_name>_<item>&<item>....png``
    (spaces stripped from ``env_name``) and then shown.

    Args:
        env_name: Environment label; selects the x-axis range and names the
            output file.
        item_list: Logged item names to plot, e.g. ``'score'`` or ``'cost_0'``.
        algo_list: Dicts with ``'name'`` (legend label) and ``'logs'`` (run
            directories, one per seed).
        fig_size: Size in inches of a single subplot.
        window_size: Smoothing window forwarded to ``parse``.
        interp_steps: Interpolation grid spacing forwarded to ``parse``.
        color_list: Per-algorithm colors, indexed like ``algo_list``.
        linestyle_list: Per-algorithm line styles, indexed like ``algo_list``.
        is_horizon: If True, lay subplots out in one row and put the legend on
            the first subplot only; otherwise lay them out in one column.
    """
    n_items = len(item_list)
    if is_horizon:
        fig, ax_list = plt.subplots(nrows=1, ncols=n_items, figsize=(fig_size*n_items*1.1, fig_size))
    else:
        fig, ax_list = plt.subplots(nrows=n_items, ncols=1, figsize=(fig_size*1.1, fig_size*n_items))
    if n_items == 1:
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
        ax_list = [ax_list]

    fontsize = 15
    lowered_env = env_name.lower()
    x_max = 3e6 if ('cheetah' in lowered_env or 'laikago' in lowered_env) else 6e6

    for item_idx, (ax, item_name) in enumerate(zip(ax_list, item_list)):
        min_value = np.inf
        max_value = -np.inf
        # In the horizontal layout only the first subplot gets a legend, so
        # only its lines need labels.
        add_labels = item_idx == 0 or not is_horizon

        for algo_idx, algo_dict in enumerate(algo_list):
            algo_name = algo_dict['name']
            if item_name != 'score' or 'sdac' in algo_name.lower():
                log_name = f'{item_name}_log'
            else:
                # For non-SDAC algorithms the score curve is read from the
                # cost_3 log instead of score_log.
                log_name = 'cost_3_log'
            algo_dirs = [f'{dir_item}/{log_name}' for dir_item in algo_dict['logs']]
            linspace, means, stds = parse(algo_dirs, item_name, window_size, interp_steps)

            color = color_list[algo_idx]
            line_kwargs = dict(lw=2, color=color, linestyle=linestyle_list[algo_idx])
            if add_labels:
                line_kwargs['label'] = algo_name
            ax.plot(linspace, means, **line_kwargs)
            # Pass the color explicitly: fill_between otherwise takes the next
            # color from its own property cycle, which only matches the line
            # by coincidence.
            ax.fill_between(linspace, means - stds, means + stds, color=color, alpha=0.15)
            max_value = max(max_value, np.max(means))
            min_value = min(min_value, np.min(means))

        ax.set_xlabel('Steps', fontsize=fontsize)
        if item_idx == 0 and is_horizon:
            ax.legend(bbox_to_anchor=(0.69, 0.01, 0.3, 0.5), loc='lower right', ncol=1, mode="expand", borderaxespad=0.01)

        if item_name == "score":
            ax.set_title('True Reward Sum', fontsize=fontsize)
            ax.set_ylim(min_value, 0.0)
        elif "cost" in item_name:
            cost_idx = item_name.split('_')[1]
            ax.set_title(f'Cost {cost_idx}', fontsize=fontsize)
            ax.set_ylim(0, max_value)
        else:
            ax.set_ylabel(item_name)

        ax.set_xlim(0, x_max)
        ax.grid()

    fig.tight_layout()
    save_dir = "./imgs"
    os.makedirs(save_dir, exist_ok=True)
    file_env_name = env_name.replace(' ', '')
    item_names = '&'.join(item_list)
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(f'{save_dir}/{file_env_name}_{item_names}.png')
    plt.show()


def _load_records(log_dir):
    """Return the concatenated contents of every ``*.pkl`` file in ``log_dir``.

    Files are read in sorted path order. Each one is expected to unpickle to
    a sequence of records, which is appended to the result.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on
    trusted, locally produced log directories.
    """
    records = []
    # os.path.join (rather than prefixing './') keeps absolute paths working.
    for record_path in sorted(glob.glob(os.path.join(log_dir, '*.pkl'))):
        with open(record_path, 'rb') as f:
            records += pickle.load(f)
    return records


def _truncate_records(records, max_steps):
    """Keep the records logged before the cumulative step count exceeds ``max_steps``.

    ``record[0]`` is the number of steps a record covers. The record whose
    addition pushes the running total past ``max_steps``, and all records
    after it, are dropped. If the total never exceeds ``max_steps``, every
    record is kept.
    """
    total = 0
    for idx, record in enumerate(records):
        total += record[0]
        if total > max_steps:
            return records[:idx]
    return records


def parse(algo_dirs, item_name, window_size, interp_steps, max_steps=6e6):
    """Load one logged item for several runs, resample it and smooth it.

    For every directory in ``algo_dirs``:

    1. Load the pickled records.
    2. Truncate them at ``max_steps`` cumulative steps.
    3. Turn them into a step-indexed curve starting at (0, 0.0).
    4. Linearly interpolate onto a grid spaced roughly ``interp_steps`` apart.

    All curves are then cut to the shortest one and passed to ``smoothing``.

    Each record is read as ``record[0]`` = steps it covers and ``record[1]`` =
    logged value. Item-specific handling:

    * ``'metric'``: value / (1 + matching value from the sibling ``cv`` log,
      found by replacing ``'score'`` with ``'cv'`` in the directory path);
    * names containing ``'total'``: values are accumulated (running sum);
    * anything else: values are used as-is.

    Args:
        algo_dirs: Log directories, one per run (seed).
        item_name: Name of the logged item; selects the handling above.
        window_size: Window length forwarded to ``smoothing``.
        interp_steps: Approximate step spacing of the interpolation grid.
        max_steps: Cumulative step budget; later records are dropped. Also
            applied to the ``cv`` log so it stays index-aligned with the main
            log.

    Returns:
        ``(linspace, means, stds)``: the shortest run's grid plus the smoothed
        mean and std arrays of the same length.

    Raises:
        ValueError: If ``algo_dirs`` is empty.
    """
    if not algo_dirs:
        raise ValueError('parse() needs at least one log directory')
    print(f'[parsing] {algo_dirs}')

    algo_datas = []
    linspaces = []
    for algo_dir in algo_dirs:
        record = _truncate_records(_load_records(algo_dir), max_steps)
        cv_record = None
        if item_name == 'metric':
            # NOTE(review): assumes the cv log has a record for every
            # record in the main log - confirm for your logger.
            cv_record = _truncate_records(_load_records(algo_dir.replace('score', 'cv')), max_steps)

        steps = [0]
        data = [0.0]
        for step_idx, rec in enumerate(record):
            steps.append(steps[-1] + rec[0])
            if item_name == 'metric':
                data.append(rec[1] / (cv_record[step_idx][1] + 1))
            elif 'total' in item_name:
                data.append(data[-1] + rec[1])
            else:
                data.append(rec[1])

        num_points = int((steps[-1] - steps[0]) / interp_steps + 1)
        linspace = np.linspace(steps[0], steps[-1], num_points)
        linspaces.append(linspace)
        algo_datas.append(np.interp(linspace, steps, data))

    # Runs may have different lengths: keep only the prefix they all share,
    # using the (first) shortest grid as the common x-axis.
    min_linspace = min(linspaces, key=len)
    algo_len = len(min_linspace)
    algo_datas = [data[:algo_len] for data in algo_datas]

    smoothed_means, smoothed_stds = smoothing(algo_datas, window_size)
    return min_linspace, smoothed_means, smoothed_stds

def smoothing(data, window_size):
    """Trailing moving-window mean and std, pooled across several series.

    For each position ``end`` in ``1..len(data[0])``, take the slice
    ``[start:end]`` of every series, where ``start = max(0, end - window_size)``
    (the window is shorter at the beginning). Concatenate those slices and
    record their mean and standard deviation.

    Returns:
        ``(means, stds)`` as numpy arrays of length ``len(data[0])``.
    """
    n_points = len(data[0])
    window_means = []
    window_stds = []
    for end in range(1, n_points + 1):
        start = max(0, end - window_size)
        pooled = np.concatenate([series[start:end] for series in data])
        window_means.append(np.mean(pooled))
        window_stds.append(np.std(pooled))
    return np.array(window_means), np.array(window_stds)

# Entry point: build and save/show the plots only when run as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
