import os
import pandas
import matplotlib.pyplot as plt

# Line colors for plot_results: the i-th run in `codes` is drawn with plot_colors[i].
plot_colors = [
  "blue", "orange", "green", "red",
  "purple", "black", "cyan", "silver",
  "pink", "lime", "salmon", "dodgerblue",
  "crimson", "cadetblue", "skyblue", "yellow",
]

def smooth_nums(nums, num_hist=10):
  """Return the trailing moving average of `nums` over a window of `num_hist`.

  Element k of the result is the mean of nums[max(0, k - num_hist + 1) : k + 1],
  so the first few entries average over fewer than `num_hist` values.
  `nums` must support both iteration and integer indexing (list, numpy array).
  """
  running = 0
  out = []
  for pos, value in enumerate(nums):
    running += value
    # Drop the value that just slid out of the window, once the window is full.
    dropped = pos - num_hist
    if dropped >= 0:
      running -= nums[dropped]
    window = num_hist if pos + 1 > num_hist else pos + 1
    out.append(running / window)
  return out

# Display names for progress.csv columns; plot_results uses them for figure
# titles and PNG filenames when use_file_labels is set.
file_metric_names = {
  "Real Det Return": "Deterministic Return",
  "Real Sto Return": "Stochastic Return",
}

def _expert_count_for(env_name, pass_num_expert, index):
  """Return the N used in the `exp-<N>` log directory for run number `index`."""
  if pass_num_expert is None:
    return 4 if env_name.startswith("HopperFH") else 16
  if isinstance(pass_num_expert, list):
    return pass_num_expert[index]
  return pass_num_expert


def plot_results(mode, codes, labels, env_name="HalfCheetahFH-v2", smooth=False, num_hist=None, use_file_labels=False, file_env_name=None, save_folder=None, pass_num_expert=None, log_root="/content/drive/MyDrive/IL/Clean-f-IRL/logs"):
  """Plot each column of the runs' progress.csv files against "Iteration".

  One figure is produced per column of the first run's CSV, with one line per
  entry of `codes` (colored by `plot_colors`) and a legend built from `labels`.

  Args:
    mode: Sub-directory under `<log_root>/<env_name>/exp-<N>/` holding the runs.
    codes: Run directory names; each must contain a `progress.csv`.
    labels: Legend labels, one per entry of `codes`.
    env_name: Environment directory name under `log_root`.
    smooth: If True, pass each column through `smooth_nums` before plotting.
    num_hist: Smoothing window; None uses `smooth_nums`' default window.
    use_file_labels: If True, save every figure as a PNG in `save_folder`
      (created if missing) instead of showing it, and retitle columns listed in
      `file_metric_names` as "<file_env_name> - <display name>".
    file_env_name: Title/filename prefix used when `use_file_labels` is set.
    save_folder: Output directory for the PNGs when `use_file_labels` is set.
    pass_num_expert: Overrides N in `exp-<N>`: one value for all runs, or a
      list with one value per code. None means 4 for "HopperFH*" envs, else 16.
    log_root: Root directory of the experiment logs.

  Raises:
    AssertionError: more codes than colors, or `codes`/`labels` lengths differ.
  """
  assert len(codes) <= len(plot_colors)
  assert len(codes) == len(labels)
  if not codes:
    return
  if use_file_labels:
    # exist_ok: re-running into an existing output folder must not crash.
    os.makedirs(save_folder, exist_ok=True)

  results_csvs = []
  for i, code in enumerate(codes):
    num_expert = _expert_count_for(env_name, pass_num_expert, i)
    csv_path = os.path.join(log_root, env_name, f"exp-{num_expert}", mode, code, "progress.csv")
    results_csvs.append(pandas.read_csv(csv_path))

  for metric in results_csvs[0].keys():
    # Fresh per figure so the legend only references this figure's lines.
    handles = []
    for i, results_csv in enumerate(results_csvs):
      results = results_csv[metric].values
      if smooth:
        # num_hist=None would crash smooth_nums; fall back to its default window.
        results = smooth_nums(results) if num_hist is None else smooth_nums(results, num_hist=num_hist)
      handle, = plt.plot(results_csv["Iteration"].values, results, color=plot_colors[i])
      handles.append(handle)
    title = metric
    if use_file_labels and metric in file_metric_names:
      title = file_env_name + " - " + file_metric_names[metric]
    plt.title(title)
    plt.legend(handles, labels, bbox_to_anchor=(1, 0.07*len(labels)+0.06), loc=1)
    if use_file_labels:
      plt.savefig(os.path.join(save_folder, title)+".png", dpi=1200, bbox_inches="tight")
      plt.close()
    else:
      plt.show()
