import sys
import torch
import numpy as np
import argparse
from munch import Munch
from utils import get_config, set_seed, random_sample_configs
from data import get_data
from trainer import ITETrainer, CFRTrainer, Trainer

# Command-line options, declared as (flag, add_argument keyword arguments)
# pairs and registered in this order.
_CLI_OPTIONS = (
	# Path handed to get_config() to load the experiment configuration.
	('--config', dict(type=str, default="configs/ihdp100.yaml")),
	# Selects which trainer class is instantiated for every sampled config.
	('--trainer', dict(type=str, default="ITE", choices=['ITE', 'CFR', 'Vanilla'])),
	# Copied onto the loaded config as `exp_num`.
	('--exp_num', dict(
		type=int,
		default=0,
		help="used for datasets which have many splits, ihdp1000, so it can be run parallely",
	)),
	# Upper bound on how many sampled configs are run.
	('--max_search', dict(
		type=int,
		default=50,
		help="how many random configs searches are done",
	)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
	parser.add_argument(_flag, **_options)
args = parser.parse_args()
# Choose the compute device: 'cuda' when torch reports CUDA as available,
# 'cpu' otherwise. The choice is stored on `args` for the trainers below.
if torch.cuda.is_available():
	args.device = 'cuda'
	# cuDNN benchmark mode is only switched on for the CUDA path.
	torch.backends.cudnn.benchmark = True
else:
	args.device = 'cpu'

# Load the config named on the command line and wrap it in a Munch so its
# entries can be read and written as attributes.
raw_config = get_config(args.config)
all_config = Munch.fromDict(raw_config)

# Seed from the value stored in the config itself.
set_seed(all_config.seed)

# Record which experiment/split index this process was asked to run.
all_config.exp_num = args.exp_num
# NOTE(review): this is a self-assignment and does not change the value.
# Presumably its only effect is to fail early if `checkpoint_dir` is missing
# from the config - confirm, then either drop it or make the check explicit.
all_config.checkpoint_dir = all_config.checkpoint_dir

dataset_dict = get_data(all_config)

# Maps each accepted --trainer value to the class that implements it.
#   ITE:     individual treatment effect estimation; testing factuals are
#            not accessed during training.
#   CFR:     testing factuals are used to help training, since only the
#            counterfactuals are of interest.
#   Vanilla: the samples of interest are known in advance and are placed in
#            the testing dataset.
trainer_classes = {
	'ITE': ITETrainer,
	'CFR': CFRTrainer,
	'Vanilla': Trainer,
}
# argparse restricts --trainer to exactly these keys, so the lookup cannot
# miss; resolving it once avoids repeating the dispatch on every iteration.
trainer_cls = trainer_classes[args.trainer]

# Keep at most --max_search of the randomly sampled configs.
sampled_configs = random_sample_configs(all_config)[:args.max_search]

# Run one experiment per sampled config; run_id is the config's position in
# the sampled list.
for run_id, config in enumerate(sampled_configs):
	print(config)
	# NOTE(review): only the constructor and .to() are invoked here;
	# presumably training runs inside the trainer's constructor - confirm.
	trainer = trainer_cls(dataset_dict, config, args.device, run_id).to(args.device)

# Report completion and exit with status 0. Passing the message string to
# sys.exit() would print it to stderr and exit with status 1, which makes a
# successful run look like a failure to shells and job schedulers.
print('Training finished')
sys.exit(0)




