import argparse
import numpy as np
import os
from utils import load_dict

# Command-line interface: where the per-run result pickles live and which
# metric keys to aggregate from them.
parser = argparse.ArgumentParser(
	description="Aggregate per-experiment results, keeping one run per "
				"experiment chosen by its selection criterion.")
parser.add_argument('--checkpoint_dir', type=str, default="checkpoints/ihdp100",
					help="directory containing the '*result.pkl' files")
parser.add_argument('--metrics', type=str, default="pehe,ate,In-pehe,In-ate",
					help="comma-separated metric keys to read from each result file")
args = parser.parse_args()

# os.listdir returns entries in arbitrary order; sorting makes the argmin
# tie-breaking further down and the printed output reproducible across runs.
dirs = sorted(x for x in os.listdir(args.checkpoint_dir) if x.endswith('result.pkl'))

# Tolerate spaces around the commas ("pehe, ate") and drop stray empty entries
# so they do not turn into bogus dictionary keys later on.
metrics = [name.strip() for name in args.metrics.split(',') if name.strip()]
if not metrics:
	parser.error("--metrics must name at least one metric")

# For every requested metric: the value taken from the selected run of each
# experiment, in experiment order.
all_exp_result_dict = {met: [] for met in metrics}

# Metric shown in the per-experiment progress line. 'rpol' is used when it was
# requested; otherwise fall back to the first metric. Indexing 'rpol'
# unconditionally raised KeyError whenever it was not part of --metrics
# (which is the case for the default metric list).
report_metric = 'rpol' if 'rpol' in metrics else metrics[0]

for exp_num in range(0, 100):
	# All result files of one experiment share the 'expNNNN' filename prefix.
	files = [x for x in dirs if x.startswith('exp%04d' % exp_num)]
	if not files:
		continue

	# Column-wise collection of what each run of this experiment reported.
	result_dict = {met: [] for met in metrics}
	selection_criterions = []
	configs = []

	for fn in files:
		filepath = os.path.join(args.checkpoint_dir, fn)
		out = load_dict(filepath)
		for met in metrics:
			result_dict[met].append(out[met])
		selection_criterions.append(out['selection_criterion'])
		configs.append(out['config'])

	# Run picked by the selection criterion (smallest value wins) ...
	min_row = np.argmin(selection_criterions)
	# ... versus the run that is actually best on the first requested metric.
	true_row = np.argmin(result_dict[metrics[0]])
	print(exp_num, configs[min_row], result_dict[report_metric][min_row],
		configs[true_row], result_dict[report_metric][true_row])

	# Only the run chosen by the selection criterion enters the aggregate.
	for met in metrics:
		all_exp_result_dict[met].append(result_dict[met][min_row])


# Final report: for each metric print the collected per-experiment values with
# their count, then the metric name, its mean, and its spread.
for met in metrics:
	collected = all_exp_result_dict[met]
	n_collected = len(collected)
	mean = np.mean(collected)
	# Note: despite its name this is the standard error of the mean
	# (population standard deviation divided by sqrt(n)).
	std = np.std(collected) / np.sqrt(n_collected)
	print(collected, n_collected)
	print(met, mean, std)
