import os
import pandas
from .plot_results import smooth_nums

def _read_progress_csv(run_dir):
  """Load a run's progress log from ``run_dir``.

  Reads ``progress.csv``; if that file does not exist, falls back to
  ``progress_logs.csv``. If neither exists, the FileNotFoundError from the
  fallback read propagates.
  """
  try:
    return pandas.read_csv(os.path.join(run_dir, "progress.csv"))
  except FileNotFoundError:
    return pandas.read_csv(os.path.join(run_dir, "progress_logs.csv"))


def threshold(mode, codes, labels, env_name, num, return_type, smooth=False, num_hist=10, pass_num_expert=None,
              log_root="/content/drive/MyDrive/IL/f-IRL/logs"):
  """Print, for each run, the first logged iteration whose return exceeds ``num``.

  Each run's progress CSV is read from
  ``<log_root>/<env_name>/exp-<num_expert>/<mode>/<code>/``. The selected
  return column is scanned in logged order. The first value strictly greater
  than ``num`` is printed together with the matching ``"Itration"`` entry.
  Runs that never exceed ``num`` are reported as not surpassing it. Nothing
  is returned.

  Args:
    mode: Sub-directory below the ``exp-<num_expert>`` directory.
    codes: Run directory names, one per run.
    labels: Display labels, paired positionally with ``codes``.
    env_name: Environment name. Names starting with ``"HopperFH"`` default to
      ``num_expert = 4``; all others default to 16.
    num: Threshold; the comparison is strict (``>``).
    return_type: ``"det"`` selects ``"Real Det Return"``, ``"sto"`` selects
      ``"Real Sto Return"``.
    smooth: If True, the returns are passed through ``smooth_nums`` before scanning.
    num_hist: Forwarded to ``smooth_nums`` as ``num_hist``.
    pass_num_expert: If not None, overrides the expert count derived from ``env_name``.
    log_root: Root directory of the log tree. Defaults to the original Colab
      Google Drive location.

  Raises:
    ValueError: If ``codes`` and ``labels`` differ in length, or
      ``return_type`` is not ``"det"``/``"sto"``.
    FileNotFoundError: If a run directory has neither ``progress.csv`` nor
      ``progress_logs.csv``.
  """
  # Explicit checks rather than `assert`, which is stripped under `python -O`.
  # Both are validated before any file I/O so bad arguments fail fast.
  if len(codes) != len(labels):
    raise ValueError(f"codes and labels must have the same length ({len(codes)} != {len(labels)}).")
  if return_type == "det":
    metric_name = "Real Det Return"
  elif return_type == "sto":
    metric_name = "Real Sto Return"
  else:
    raise ValueError(f"return_type {return_type} not supported.")

  if pass_num_expert is not None:
    num_expert = pass_num_expert
  else:
    num_expert = 4 if env_name.startswith("HopperFH") else 16

  base_path = os.path.join(log_root, env_name, f"exp-{num_expert}", mode)
  # Load every run up front so a missing log fails before anything is printed.
  results_csvs = [_read_progress_csv(os.path.join(base_path, code)) for code in codes]

  for label, results_csv in zip(labels, results_csvs):
    returns = results_csv[metric_name].values
    if smooth:
      # NOTE(review): index j below is mapped back onto the raw "Itration"
      # column; this assumes smooth_nums preserves length/alignment — confirm.
      returns = smooth_nums(returns, num_hist=num_hist)
    for j, ret in enumerate(returns):
      if ret > num:
        # "Itration" (sic) must match the column name as written in the logs.
        print(label, f"> {num}", results_csv["Itration"].values[j], ret)
        break
    else:
      print(label, "did not surpass", num)