import copy
import math
import os

from utils.helper import set_one_thread
from utils.plotter import Plotter
from utils.sweeper import unfinished_index, time_info, memory_info


def get_process_result_dict(result, config_idx, mode='Train'):
  """Summarize one run's result for a single configuration.

  Args:
    result: per-run result table indexable by column name ('Env', 'Agent',
      'Return'); each column must support `[0]`, negative slicing and
      `.mean()` (presumably a pandas DataFrame -- confirm against Plotter).
    config_idx: index of the configuration this run belongs to.
    mode: 'Train' averages the last 100 'Return' entries; any other value
      averages the last 20.

  Returns:
    dict with keys 'Env', 'Agent', 'Config Index' and 'Return (mean)'.
  """
  window = 100 if mode == 'Train' else 20
  summary = {
    'Env': result['Env'][0],
    'Agent': result['Agent'][0],
    'Config Index': config_idx,
  }
  # Mean over the trailing `window` returns (fewer if the run is shorter).
  summary['Return (mean)'] = result['Return'][-window:].mean()
  return summary

def get_csv_result_dict(result, config_idx, mode='Train'):
  """Aggregate the per-run summaries of one configuration into a single row.

  Args:
    result: table of per-run summaries with columns 'Env', 'Agent' and
      'Return (mean)'; the 'Return (mean)' column must provide `.mean()` and
      `.sem(ddof=0)` (presumably a pandas DataFrame -- confirm against Plotter).
    config_idx: index of the configuration being aggregated.
    mode: accepted for call-signature compatibility; not used here.

  Returns:
    dict with keys 'Env', 'Agent', 'Config Index', 'Return (mean)' (mean of
    the per-run means) and 'Return (se)' (their standard error with ddof=0,
    which is 0 rather than NaN when there is only one run).
  """
  env, agent = result['Env'][0], result['Agent'][0]
  run_means = result['Return (mean)']
  return {
    'Env': env,
    'Agent': agent,
    'Config Index': config_idx,
    'Return (mean)': run_means.mean(),
    'Return (se)': run_means.sem(ddof=0),
  }

# Base Plotter configuration for this analysis script; `analyze` sets the
# 'exp' and 'runs' entries for the experiment being analyzed.
cfg = dict(
  exp='exp_name',  # placeholder, replaced by `analyze`
  merged=True,
  x_label='Step',
  y_label='Average Return',
  hue_label='Agent',
  show=False,
  imgType='png',
  ci='se',  # presumably the error band type (standard error) -- confirm in Plotter
  x_format=None,
  y_format=None,
  xlim=dict(min=None, max=None),
  ylim=dict(min=None, max=None),
  EMA=True,
  loc='lower right',
  # Hyperparameter names presumably used to label/group the sweep -- confirm in Plotter.
  sweep_keys=[
    'lr', 'consod_end', 'consod_epoch', 'consod_size', 'memory_size',
    'network_update_steps', 'target_network_update_steps',
  ],
  # `sort_by` and `ascending` are paired element-wise: higher mean return
  # first, then lower standard error (assumed ranking semantics).
  sort_by=['Return (mean)', 'Return (se)'],
  ascending=[False, True],
  runs=1,
)

def analyze(exp, runs=1):
  """Summarize the 'Train' results of one experiment via `Plotter.csv_results`.

  Args:
    exp: experiment name, passed to Plotter as the config's 'exp' entry.
    runs: number of runs per configuration, passed as the 'runs' entry.
  """
  set_one_thread()
  # Build a private per-call config instead of mutating the module-level
  # `cfg` template, so repeated calls cannot leak state into one another.
  # A deep copy is used because the template holds nested dicts and lists.
  exp_cfg = copy.deepcopy(cfg)
  exp_cfg['exp'] = exp
  exp_cfg['runs'] = runs
  plotter = Plotter(exp_cfg)

  plotter.csv_results('Train', get_csv_result_dict, get_process_result_dict)
  # Optional extra outputs; uncomment to enable.
  # plotter.plot_results(mode='Train', indexes='all')

  # plotter.plot_indexList([5,4,3,2,1], 'Train', 'top')


if __name__ == "__main__":
  # Experiment to analyze and the number of runs per configuration.
  exp = 'catcher_dqn'
  runs = 10
  # To analyze a different experiment instead, e.g.:
  # exp, runs = 'catcher_medqn', 10

  # Sweeper diagnostics, in the same order as before, then the analysis.
  for report in (unfinished_index, memory_info, time_info):
    report(exp, runs=runs)
  analyze(exp, runs=runs)