import os
import pandas
import matplotlib.pyplot as plt
import numpy as np

def smooth_nums(nums, num_hist=10):
  """Return the trailing moving average of ``nums``.

  Each output entry is the mean of the current value and the values before
  it, looking back over at most ``num_hist`` entries in total. Near the
  start, where fewer than ``num_hist`` values have been seen, the mean is
  taken over the values seen so far.

  Args:
    nums: Indexable sequence of numbers.
    num_hist: Maximum number of entries in the averaging window.

  Returns:
    A list with one averaged value per entry of ``nums``.
  """
  averages = []
  running_total = 0
  for idx, value in enumerate(nums):
    running_total += value
    # Position of the entry that has just slid out of the window, if any.
    expired = idx - num_hist
    if expired >= 0:
      running_total -= nums[expired]
    window = min(idx + 1, num_hist)
    averages.append(running_total / window)
  return averages

def threshold_results(mode, codes, labels, env_name, num, return_type, smooth=False, num_hist=None,
                      pass_num_expert=None, truncate=None,
                      log_root="/content/drive/MyDrive/IL/Clean-f-IRL/logs"):
  """Print, for each group of runs, the first iteration whose mean return exceeds ``num``.

  Every run is read from
  ``<log_root>/<env_name>/exp-<num_expert>/<mode>/<code>/progress.csv``,
  falling back to ``raw_progress.csv`` in the same directory when the first
  file does not exist. The selected return column is truncated, optionally
  smoothed with ``smooth_nums``, and averaged across the runs of a group.
  The first position whose mean is strictly greater than ``num`` is reported.

  Args:
    mode: Sub-directory name under the experiment directory; also used when
      choosing the default number of experts.
    codes: List of groups, each a list of run directory names.
    labels: One label per group in ``codes``, used in the printed output.
    env_name: Environment name, used as a path component.
    num: Threshold the mean return must strictly exceed.
    return_type: ``"det"`` selects the ``"Real Det Return"`` column and
      ``"sto"`` selects ``"Real Sto Return"``. Any other value, or a missing
      column, makes the function print no threshold results.
    smooth: When true, each run is smoothed with ``smooth_nums`` before
      averaging.
    num_hist: Window passed to ``smooth_nums``. When ``None``,
      ``smooth_nums`` is called with its own default window.
    pass_num_expert: Overrides the number of experts used in the path. By
      default it is 4 when ``env_name`` starts with ``"HopperFH"`` and
      ``mode`` is not ``"bc"``, and 16 otherwise.
    truncate: Number of leading rows to keep from every run. Must not be
      ``None``.
    log_root: Root directory that contains the per-environment log folders.

  Returns:
    None. Results are printed. Nothing is reported when ``codes`` is empty
    or its first group is empty.

  Raises:
    AssertionError: If ``codes`` and ``labels`` differ in length or
      ``truncate`` is ``None``.
    FileNotFoundError: If neither CSV file exists for a run.
    ValueError: From ``np.stack`` if the runs of a group differ in length
      after truncation, or a group has no runs.
  """
  assert len(codes) == len(labels)
  assert truncate is not None

  if pass_num_expert is not None:
    num_expert = pass_num_expert
  elif env_name.startswith("HopperFH") and mode != "bc":
    num_expert = 4
  else:
    num_expert = 16

  # The experiment directory is the same for every run, so build it once.
  base_path = os.path.join(log_root, env_name, f"exp-{num_expert}", mode)

  results_csvs = []
  for group_codes in codes:
    group_frames = []
    for code in group_codes:
      run_dir = os.path.join(base_path, code)
      try:
        frame = pandas.read_csv(os.path.join(run_dir, "progress.csv"))
      except FileNotFoundError:
        frame = pandas.read_csv(os.path.join(run_dir, "raw_progress.csv"))
      print(len(frame))
      group_frames.append(frame)
    results_csvs.append(group_frames)

  # Without a first run there is no reference frame to take iterations from.
  if not results_csvs or not results_csvs[0]:
    return

  reference = results_csvs[0][0]
  # Plain array so later lookups are positional rather than by index label.
  iteration_numbers = reference["Iteration"].to_numpy()[:truncate]

  metric = {"det": "Real Det Return", "sto": "Real Sto Return"}.get(return_type)
  if metric is None or metric not in reference.columns:
    return

  for label, group_frames in zip(labels, results_csvs):
    per_run_values = []
    for frame in group_frames:
      values = frame[metric].values[:truncate]
      if smooth:
        if num_hist is None:
          # Fall back to smooth_nums' default window instead of passing None.
          values = smooth_nums(values)
        else:
          values = smooth_nums(values, num_hist=num_hist)
      per_run_values.append(values)
    mean_values = np.mean(np.stack(per_run_values, axis=0), axis=0)

    above = np.flatnonzero(mean_values > num)
    if above.size:
      first = above[0]
      print(label, f"> {num}", iteration_numbers[first], mean_values[first])
    else:
      print(label, "did not surpass", num)
