import sys
import torch
import numpy as np
import argparse
from munch import Munch
from utils import get_config, set_seed, random_sample_configs
from data import get_data
from trainer import ITETrainer, CFRTrainer, Trainer, ToyTrainer

# Command-line options for running a batch of randomly sampled experiment configs.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default="configs/ihdp100.yaml")
parser.add_argument('--trainer', type=str, default="ITE", choices=['ITE', 'CFR', 'Vanilla', 'Toy'])
parser.add_argument('--exp_num', type=int, default=0, help="used for datasets which have many splits, ihdp1000, so it can be run parallely")
parser.add_argument('--max_search', type=int, default=50, help="how many random configs searches are done")
# NOTE: argparse applies `type` only to string defaults, so when this flag is
# omitted args.max_samples is the float 1e99 (effectively "no limit"); it is an
# int only when passed explicitly on the command line.
parser.add_argument('--max_samples', type=int, default=1e99, help="limit on the number of data samples (passed to get_data as config.data.max_samples)")
args = parser.parse_args()
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if args.device == 'cuda':
	# Let cuDNN benchmark candidate algorithms and pick the fastest one.
	# This can make GPU results non-deterministic even with a fixed seed.
	torch.backends.cudnn.benchmark = True
# Load the YAML config with attribute-style access (Munch), then keep at most
# --max_search of the randomly sampled configurations.
all_config = Munch.fromDict(get_config(args.config))
sampled_configs = random_sample_configs(all_config)[:args.max_search]

# Map each --trainer choice to its trainer class. All four are constructed
# with the same (dataset_dict, config, device, run_id) arguments.
TRAINER_CLASSES = {
	# individual treatment effect estimation; testing factuals are not
	# accessed during training
	'ITE': ITETrainer,
	# testing factuals are used to help training, since only the
	# counterfactuals are of interest
	'CFR': CFRTrainer,
	# the samples of interest are known and are put into the testing dataset
	'Vanilla': Trainer,
	# the samples of interest are known and are put into the testing dataset
	'Toy': ToyTrainer,
}

# we run experiments with all sampled configs
# argparse `choices` guarantees args.trainer is a key of TRAINER_CLASSES.
trainer_cls = TRAINER_CLASSES[args.trainer]
for run_id, config in enumerate(sampled_configs):
	set_seed(config.seed)
	config.exp_num = args.exp_num
	config.data.max_samples = args.max_samples
	dataset_dict = get_data(config)
	# Tag the checkpoint dir with the seed and the number of training samples.
	config.checkpoint_dir = config.checkpoint_dir + '_seed%s_mn%d' % (config.seed, len(dataset_dict['x_train']))
	test_metric = config.test_metric.split(',')
	print(config)
	trainer = trainer_cls(dataset_dict, config, args.device, run_id).to(args.device)
	# NOTE(review): the training step is disabled, so each run only calls
	# evaluate_and_record. Confirm whether trainer.run() should be re-enabled.
	#trainer.run()
	trainer.evaluate_and_record(config, dataset_dict, test_metric, run_id)
# sys.exit() with a string argument prints it to stderr and exits with
# status 1, which reported every successful run as a failure. Emit the same
# message on stderr and exit with status 0 instead.
print('Training finished', file=sys.stderr)
sys.exit(0)




