import os
import pandas
import matplotlib.pyplot as plt
import numpy as np

# One (line color, fill color) pair per run group, selected by the group's
# position in `codes`: average_results draws the mean curve with the first
# color and the +/- std band with the second.
color_pairs = [
  ("blue", "lightblue"),
  ("green", "palegreen"),
  ("orange", "wheat"),
  ("purple", "orchid"),
  ("red", "salmon"),
  ("black", "gray"),
  ("teal", "paleturquoise"),
]

def smooth_nums(nums, num_hist=10):
  """Return the trailing moving average of ``nums``.

  Element ``i`` of the result is the mean of the last ``num_hist`` values
  ending at ``i``, i.e. of ``nums[max(0, i - num_hist + 1):i + 1]``; near the
  start of the sequence the window is shorter and the mean is taken over the
  values available so far.

  Args:
    nums: Sliceable sequence of numbers (e.g. a list or a 1-D numpy array).
    num_hist: Window length; must be a positive integer.

  Returns:
    A list with one smoothed value per element of ``nums``.

  Raises:
    ValueError: if ``num_hist`` is None or smaller than 1.
  """
  if num_hist is None or num_hist < 1:
    raise ValueError(f"num_hist must be a positive integer, got {num_hist!r}")
  smoothed = []
  for i in range(len(nums)):
    # Sum each window directly rather than maintaining a running total: a
    # running total never recovers once a NaN/inf enters it (NaN - NaN is
    # NaN), and it accumulates rounding error over long sequences.
    window = nums[max(0, i - num_hist + 1):i + 1]
    smoothed.append(sum(window) / len(window))
  return smoothed

# Maps a progress.csv column name to the human-readable title used when
# average_results is called with use_file_labels=True.
file_metric_names = {
  "Real Det Return":    "Deterministic Return",
  "Real Sto Return":    "Stochastic Return",
  "Running Forward KL": "Approximate Forward KL Divergence",
  "Running Reverse KL": "Approximate Reverse KL Divergence",
}

def average_results(mode, codes, labels, env_name="HalfCheetahFH-v2", smooth=False,
                    num_hist=None, use_file_labels=False, file_env_name=None,
                    save_folder=None, pass_num_expert=None, truncate=None,
                    log_root="/content/drive/MyDrive/IL/Clean-f-IRL/logs"):
  """Plot the across-run mean +/- std of every column of several run groups.

  Each group ``codes[i]`` is a list of run directory names. For every run,
  ``{log_root}/{env_name}/exp-{num_expert}/{mode}/{run}/progress.csv`` is
  loaded. Then, for every column of the first run's CSV (including
  "Iteration" itself), one figure is drawn with one curve per group: the mean
  over that group's runs, plus a shaded band of +/- one standard deviation.
  The x-axis is the "Iteration" column of the first run of the first group.

  Args:
    mode: Directory under ``exp-{num_expert}`` that contains the run folders.
    codes: List of run groups; each group is a list of run directory names.
      At most ``len(color_pairs)`` groups are supported.
    labels: Legend label for each group; same length as ``codes``.
    env_name: Environment directory name. Names starting with "HopperFH"
      default to 4 experts, all others to 16.
    smooth: If True, each run is smoothed with ``smooth_nums`` before
      averaging.
    num_hist: Smoothing window; None uses ``smooth_nums``'s default.
    use_file_labels: If True, titles of metrics listed in
      ``file_metric_names`` are prefixed with ``file_env_name``, and each
      figure is saved as a PNG in ``save_folder`` and closed instead of shown.
    file_env_name: Environment display name used in titles (see above).
    save_folder: Output directory, created if missing, when
      ``use_file_labels`` is True.
    pass_num_expert: Overrides the expert count derived from ``env_name``.
    truncate: Number of leading rows kept from each run; required. All runs
      must have at least this many rows so they can be stacked.
    log_root: Root of the log tree; defaults to the original Colab Drive path.
  """
  assert len(codes) <= len(color_pairs), "more run groups than available colors"
  assert len(codes) == len(labels), "need exactly one label per run group"
  assert truncate is not None, "truncate must be given"
  num_expert = 4 if env_name.startswith("HopperFH") else 16
  if pass_num_expert is not None:
    num_expert = pass_num_expert
  if use_file_labels:
    # exist_ok: re-running the same cell must not fail with FileExistsError.
    os.makedirs(save_folder, exist_ok=True)

  base_path = os.path.join(log_root, env_name, f"exp-{num_expert}", mode)
  results_csvs = [
    [pandas.read_csv(os.path.join(base_path, code, "progress.csv")) for code in group]
    for group in codes
  ]
  iteration_numbers = results_csvs[0][0]["Iteration"][:truncate]
  # Only forward num_hist when given, so smooth=True works with the default
  # num_hist=None instead of crashing inside smooth_nums.
  smooth_kwargs = {} if num_hist is None else {"num_hist": num_hist}

  for metric in results_csvs[0][0].keys():
    # Fresh list per figure so the legend only references this figure's lines.
    handles = []
    for i, runs in enumerate(results_csvs):
      stacked_values = []
      for run in runs:
        values = run[metric].values[:truncate]
        if smooth:
          values = smooth_nums(values, **smooth_kwargs)
        stacked_values.append(values)
      stacked_values = np.stack(stacked_values, axis=0)
      mean_values = np.mean(stacked_values, axis=0)
      std_values = np.std(stacked_values, axis=0)
      line_color, fill_color = color_pairs[i]
      handle, = plt.plot(iteration_numbers, mean_values, color=line_color)
      plt.fill_between(iteration_numbers, mean_values - std_values,
                       mean_values + std_values, color=fill_color)
      handles.append(handle)

    title = metric
    if use_file_labels and metric in file_metric_names:
      title = file_env_name + " - " + file_metric_names[metric]
    plt.title(title)
    plt.legend(handles, labels, bbox_to_anchor=(1, 0.07 * len(labels) + 0.06), loc=1)
    if use_file_labels:
      # A "/" in the title would otherwise be interpreted as a sub-directory.
      file_name = title.replace("/", "_") + ".png"
      plt.savefig(os.path.join(save_folder, file_name), dpi=1200, bbox_inches="tight")
      plt.close()
    else:
      plt.show()
