import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import os
import pandas as pd

# Set matplotlib settings

import matplotlib
from matplotlib.backends.backend_pgf import FigureCanvasPgf
# Save '.pdf' files through the PGF backend, so figure text is typeset by LaTeX
# with the preamble configured below.
matplotlib.backend_bases.register_backend('pdf', FigureCanvasPgf)

import matplotlib.pyplot as plt
# Reset every rcParam to matplotlib's built-in defaults before customizing.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# LaTeX preamble shared by the usetex and PGF text paths: font/maths packages
# plus the two named colours used in the x tick labels.
latex_preamble = ''.join([
    r'\usepackage{xcolor}',
    r'\usepackage[scaled]{helvet}',
    r'\usepackage{amssymb}',
    r'\usepackage{amsmath}',
    r'\usepackage{bm}',
    r'\definecolor{offpolicycolor}{RGB}{112, 48, 160}',
    r'\definecolor{onpolicycolor}{RGB}{68, 114, 196}',
])
pgf_with_latex = {
    "text.usetex": True,    # typeset all text with LaTeX
    "pgf.rcfonts": True,    # let the PGF backend set up fonts from rcParams
    # "font.family": "phv",
    # "font.serif": 'Computer Modern Roman',
    "text.latex.preamble": latex_preamble,
    "pgf.preamble": latex_preamble,
}
# NOTE(review): plt.clf() implicitly creates a figure here (under the default
# rc settings); nothing below appears to draw into it — confirm it is needed.
plt.clf()
matplotlib.rcParams.update(pgf_with_latex)

# Choose colors
# Tableau 20 palette as 0-255 RGB triples, rescaled below to the 0-1 float
# range that matplotlib accepts for colours.
tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]
tableau20 = [(r / 255., g / 255., b / 255.) for r, g, b in tableau20]

# Colour assigned to each plotted method (indices into tableau20).
on_policy_color = tableau20[10]
vanilla_color = tableau20[14]
bc_color = tableau20[12]
is_color = tableau20[6]
cql_color = tableau20[4]
pop_color = tableau20[2]

###


# Load every pickled run in save_dir into one DataFrame, one row per run.
save_dir = "results"
# Runs whose 'step' log has fewer entries than this are flagged as diverged.
# NOTE(review): presumably a completed run logs 200 steps — confirm against
# the training script.
diverged_min_steps = 200
items = []
# os.listdir order is arbitrary; sort so the row order is reproducible.
for i, fname in enumerate(sorted(os.listdir(save_dir))):
    if not fname.endswith(".pkl"):
        continue
    # Close each file promptly instead of leaking the handle.
    # NOTE: pickle.load can execute arbitrary code — only load trusted results.
    with open(os.path.join(save_dir, fname), "rb") as f:
        results = pkl.load(f)
    variant = results['variant']
    policy_return = results['log']['policy_return']
    diverged = len(results['log']['step']) < diverged_min_steps
    items.append(pd.DataFrame({'alg': variant['alg'],
                               'p': variant['proportion_opt_dataset'],
                               'seed': variant['seed'],
                               'policy_return': policy_return,
                               'diverged': diverged},
                              index=[i]))
if not items:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(f"no .pkl result files found in {save_dir!r}")
df = pd.concat(items, ignore_index=True)

# # load results with pickle
# with open('results.pkl', 'rb') as f:
#     results = pkl.load(f)

# all_ps = [0, 0.1, 0.2, 0.5, 1.0]
# Distinct dataset mixtures p present in the results, largest first.
# NOTE(review): pandas groupby() orders groups by ascending p, so passing this
# descending list as x next to a groupby summary pairs the points in reverse
# order — confirm which orientation is intended.
all_ps = sorted(df['p'].unique(), reverse=True)

# Figure size and base font size.
plt.figure(figsize=(6, 4))
plt.rcParams.update({'font.size': 16})

# Axis limits: x covers [0, 1]; y covers [0, 1] padded by 5% of its span on
# each side, so values at exactly 0 or 1 are not drawn on the frame.
plt.xlim(0, 1)
pad_percent = 0.05
y_max, y_min = 1, 0
pad = pad_percent * (y_max - y_min)
y_max += pad
y_min -= pad
plt.ylim(y_min, y_max)

# Three labelled x ticks, coloured with the names defined in the LaTeX
# preamble; the \quad padding presumably keeps the end labels inside the figure.
plt.xticks([0, 0.5, 1.0], [r'\quad \quad \textcolor{offpolicycolor}{100\% Off-Policy}', r'\textcolor{offpolicycolor}{50\%}/\textcolor{onpolicycolor}{50\%}', r'\textcolor{onpolicycolor}{100\% On-Policy} \quad \quad'])

plt.ylabel('Expected Return')

# Light-grey frame on all four sides.
ax = plt.gca()
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('lightgray')

# Hide the x tick marks; the tick labels remain.
ax.xaxis.set_ticks_position('none')


def _summarize(series, reduction, bounds):
    """Compute the centre line and band edges drawn by plot_results.

    Returns ``(center, lower, upper)``. For a ``pd.Series`` these are scalars;
    for a ``SeriesGroupBy`` they are Series indexed by group key, in the
    groupby's group order (sorted keys unless the groupby used ``sort=False``).

    Raises:
        ValueError: on an unknown ``reduction`` or ``bounds``, or when
            ``bounds='std_err'`` and ``series`` is neither a Series nor a
            SeriesGroupBy.
    """
    if reduction == 'mean':
        center = series.mean()
    elif reduction == 'median':
        center = series.median()
    else:
        raise ValueError('reduction must be mean or median')

    if bounds == 'minmax':
        return center, series.min(), series.max()
    if bounds == 'std':
        spread = series.std()
        return center, center - spread, center + spread
    if bounds == 'std_err':
        if not isinstance(series, (pd.Series, pd.core.groupby.generic.SeriesGroupBy)):
            raise ValueError('series must be pandas Series or SeriesGroupBy')
        # Use each sample's own size. count() skips NaN (as std() does) and,
        # for a SeriesGroupBy, returns per-group sizes aligned with std(), so
        # groups of unequal size are handled correctly and no particular
        # group key (such as 0) has to exist.
        std_err = series.std() / np.sqrt(series.count())
        return center, center - std_err, center + std_err
    if bounds == 'percentile':
        # Interquartile range.
        return center, series.quantile(0.25), series.quantile(0.75)
    raise ValueError('bounds must be minmax, std, std_err, or percentile')


def plot_results(x, series, reduction='mean', bounds='minmax', color='b', label=None, marker=None, linestyle=None):
    """Draw a summary line with a shaded uncertainty band on the current axes.

    Args:
        x: x coordinates. A plain ``pd.Series`` summarises to a single value
            drawn as a flat segment, so ``x`` must then hold exactly two
            points (e.g. ``[0, 1]``). For a ``SeriesGroupBy`` it must hold one
            point per group, in the groupby's group order.
        series: a ``pd.Series`` of values, or a ``SeriesGroupBy`` giving one
            summary point per group.
        reduction: 'mean' or 'median' for the centre line.
        bounds: band type: 'minmax', 'std' (centre +/- std), 'std_err'
            (centre +/- std / sqrt(n)), or 'percentile' (25th-75th percentile).
        color, label, marker, linestyle: forwarded to ``plt.plot``; ``color``
            is also used for the band, drawn at alpha 0.2.

    Returns:
        The line object created by ``plt.plot`` (usable as a legend handle).

    Raises:
        ValueError: for an unknown ``reduction``/``bounds``, or
            ``bounds='std_err'`` with an unsupported ``series`` type.
    """
    center, lower, upper = _summarize(series, reduction, bounds)

    if isinstance(series, pd.Series):
        # A Series summarises to scalars; repeat each so it spans both x points.
        center, lower, upper = [center, center], [lower, lower], [upper, upper]

    line, = plt.plot(x, center,
                     label=label, linestyle=linestyle, marker=marker, color=color)
    plt.fill_between(x, lower, upper, alpha=0.2, color=color)
    return line

reduction = 'mean'
bounds = 'std_err'


def _runs_of(alg_name):
    """Return the rows of ``df`` for one algorithm, sorted by ``p`` high to low."""
    return df[df['alg'] == alg_name].sort_values(by=['p'], ascending=False)


def _plot_offline(alg_df, color):
    """Plot one offline method's per-``p`` returns against ``all_ps``.

    NOTE(review): the groupby summary is ordered by ascending p while
    ``all_ps`` is descending, so points are paired in reverse order — confirm
    this orientation is intended.
    """
    per_p_returns = alg_df.groupby(['p'])['policy_return']
    return plot_results(all_ps, per_p_returns, reduction=reduction,
                        bounds=bounds, color=color, marker='o')


# On-policy returns are pooled over all p and drawn as one flat dashed line
# across the whole x range.
on_policy_df = _runs_of('on_policy')
on_policy_line = plot_results([0, 1], on_policy_df['policy_return'],
                              reduction=reduction, bounds=bounds,
                              color=on_policy_color, linestyle='--')

vanilla_df = _runs_of('vanilla')
vanilla_line = _plot_offline(vanilla_df, vanilla_color)

# Exact importance sampling is currently left out of the figure:
# is_df = _runs_of('is')
# is_line = plot_results(all_ps, is_df.groupby(['p'])['policy_return'], reduction=reduction, bounds=bounds, color=is_color, linestyle='--')

bc_df = _runs_of('bc')
bc_line = _plot_offline(bc_df, bc_color)

cql_df = _runs_of('cql')
cql_line = _plot_offline(cql_df, cql_color)

pop_df = _runs_of('pop')
pop_line = _plot_offline(pop_df, pop_color)

# Legend anchored by its upper-left corner at the axes' upper-right corner,
# i.e. outside the plot area on the right.
legend_entries = [
    (on_policy_line, 'On-Policy SAC'),
    (vanilla_line, 'Offline SAC'),
    (bc_line, 'Behavior Cloning'),
    (cql_line, 'CQL'),
    (pop_line, 'POP-QL (Our Method)'),
]
plt.legend([handle for handle, _ in legend_entries],
           [text for _, text in legend_entries],
           loc='upper left', bbox_to_anchor=(1, 1))
plt.savefig('frozen_lake_return.pdf', bbox_inches='tight')
plt.close('all')
