import os
import itertools
import argparse

import numpy as np
import matplotlib.pyplot as plt
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and 3.8
# removed the old names; prefer the new name and fall back for older releases.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

import json

parser = argparse.ArgumentParser(
    description="Plot final validation risk against novel-task noise for each run.")
parser.add_argument(
    "expdir", help="experiment directory containing one sub-directory per run")
parser.add_argument(
    "outdir", help="directory to write plot.png and plot.pdf into (created if missing)")
# Parse only when executed as a script, so importing this module (e.g. to reuse
# get_run_stats) does not consume -- or exit on -- the host process's sys.argv.
args = parser.parse_args() if __name__ == '__main__' else None

def make_plots(noise, risks, outdir, *, legend=False):
  """Plot risk against novel-task noise (log-log), one line per configuration.

  Writes ``plot.png`` and ``plot.pdf`` into ``outdir``.

  Args:
    noise: dict mapping a configuration label to the noise levels (x values).
    risks: dict with the same keys as ``noise``, mapping each label to the
      risk values (y values) aligned with ``noise[label]``.
    outdir: existing directory the figures are written to.
    legend: if True, draw a legend of the configuration labels. Off by
      default, matching the original behaviour.
  """
  fig = plt.figure(figsize=(8, 4))
  try:
    # Sort labels so line colours do not depend on the order runs were listed.
    for conf_string in sorted(noise):
      plt.loglog(noise[conf_string], risks[conf_string], label=conf_string)
    if legend:
      plt.legend(frameon=True, framealpha=0.8)
    plt.title("Varying novel task difficulty with MAML")
    plt.xlabel("Novel task noise")
    plt.ylabel("Risk")
    plt.xlim(right=10)
    # The y axis is logarithmic, so a lower limit of 0 is invalid (matplotlib
    # warns and ignores it); only the upper limit is meaningful.
    plt.ylim(top=100)
    # Lay out after the limits are final so the final tick labels are used.
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, 'plot.png'))
    plt.savefig(os.path.join(outdir, 'plot.pdf'))
  finally:
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def get_noise_and_risk(all_configs, all_metrics):
  """Group runs by (n_tasks, n_supp, k_supp) and collect noise/risk pairs.

  For every run, the x value is ``config["val_obs_noise"]`` and the y value is
  the last entry of ``metrics["val.loss"]["values"]``. Runs are grouped under
  a label of the form ``"(M,n,k)=(<n_tasks>,<n_supp>,<k_supp>)"``.

  Returns:
    ``(target_noise, exp_risk)``: two dicts keyed by label, in first-seen
    order. Each value is a numpy array, and both arrays for a label are
    reordered together by ascending noise.
  """
  noise_groups = {}
  risk_groups = {}
  for conf, metrics in zip(all_configs, all_metrics):
    label = "(M,n,k)=({0},{1},{2})".format(
        conf["n_tasks"], conf["n_supp"], conf["k_supp"])
    noise_groups.setdefault(label, []).append(conf["val_obs_noise"])
    risk_groups.setdefault(label, []).append(metrics["val.loss"]["values"][-1])

  target_noise = {}
  exp_risk = {}
  for label, noise_values in noise_groups.items():
    # One permutation reorders both series so each (noise, risk) pair stays paired.
    order = np.argsort(noise_values)
    target_noise[label] = np.array(noise_values)[order]
    exp_risk[label] = np.array(risk_groups[label])[order]
  return target_noise, exp_risk


def get_run_stats(expdir):
  """Load ``config.json`` and ``metrics.json`` from every run under ``expdir``.

  Each run is a sub-directory of ``expdir`` holding both JSON files.
  Sub-directories are visited in sorted name order, so the result does not
  depend on the filesystem's listing order.

  Skipped entries:
    * names starting with ``_`` (Sacred's bookkeeping directories such as
      ``_sources``, which holds the saved source code, and ``_resources``),
    * plain files,
    * directories missing either JSON file (e.g. a run that never wrote metrics).

  Args:
    expdir: path to the experiment directory.

  Returns:
    ``(configs, metrics)``: two lists of parsed JSON objects. Index ``i`` of
    both lists comes from the same run directory.
  """
  configs = []
  metrics = []
  for rundir in sorted(os.listdir(expdir)):
    if rundir.startswith('_'):
      continue
    dirpath = os.path.join(expdir, rundir)
    if not os.path.isdir(dirpath):
      continue
    config_f = os.path.join(dirpath, 'config.json')
    metrics_f = os.path.join(dirpath, 'metrics.json')
    # Require both files before loading either, so the two lists stay aligned.
    if not (os.path.isfile(config_f) and os.path.isfile(metrics_f)):
      continue
    with open(config_f, 'r', encoding='utf-8') as f:
      configs.append(json.load(f))
    with open(metrics_f, 'r', encoding='utf-8') as f:
      metrics.append(json.load(f))
  return configs, metrics

def process_experiments(expdir, outdir):
  """Load every run under ``expdir`` and write the noise/risk plot to ``outdir``.

  ``outdir`` (and any missing parents) is created if it does not exist yet.
  """
  # exist_ok avoids the check-then-create race of a separate exists() test.
  os.makedirs(outdir, exist_ok=True)
  all_configs, all_metrics = get_run_stats(expdir)
  noise, risk = get_noise_and_risk(all_configs, all_metrics)
  make_plots(noise, risk, outdir)



if __name__ == '__main__':
  # Script entry point: read the runs in EXPDIR and write the plots to OUTDIR.
  process_experiments(expdir=args.expdir, outdir=args.outdir)