import argparse
import numpy as np
from causally.start.quick_start import run_casual
from causally.config.configurator import Config
from logging import getLogger
from causally.utils.output import save_result,save_uite_result
def format_results_summary(model, dataset, results, precision=2):
    """Build the end-of-run summary text for a set of evaluation results.

    Args:
        model: Model name shown in the header line.
        dataset: Dataset name shown in the header line.
        results: Mapping from metric name to the value(s) collected for that
            metric; each value must be something ``np.asarray`` accepts
            (a scalar or a sequence of numbers).
        precision: Number of decimal places used when rounding the mean and
            the standard deviation.

    Returns:
        A string consisting of ``'\\n[model,dataset]\\n'``, then one
        ``'The METRIC metric: mean ± std\\n'`` line per metric (metric name
        upper-cased), then ``'done!'``. The std is the sample standard
        deviation (``ddof=1``). It is undefined for fewer than two values,
        so ``n/a`` is reported instead of ``nan``. A metric with no values
        reports ``n/a`` for the mean as well.
    """
    lines = ['', '[{},{}]'.format(model, dataset)]
    for key in results:
        values = np.asarray(results[key])
        # Guard the degenerate cases explicitly: numpy would otherwise return
        # nan and emit a RuntimeWarning (empty mean / ddof >= sample size).
        if values.size == 0:
            mean = 'n/a'
        else:
            mean = round(float(np.mean(values)), precision)
        if values.size < 2:
            std = 'n/a'
        else:
            std = round(float(np.std(values, ddof=1)), precision)
        lines.append('The {} metric: {} ± {}'.format(key.upper(), mean, std))
    lines.append('done!')
    return '\n'.join(lines)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', type=str, default='UITE_MMD',
                        help='The model for estimating causal effect')
    parser.add_argument('-d', '--dataset', type=str, default='ACIC',
                        help='The dataset.')

    # parse_known_args tolerates extra flags instead of rejecting them
    # (presumably so other options can be passed through — TODO confirm).
    args, _ = parser.parse_known_args()

    config = Config(args=args)
    # Per-model / per-dataset batch-size overrides.
    # NOTE(review): the first check reads config['model'] while the second
    # reads args.model; confirm Config never rewrites the model name.
    if config['model'] == 'BCAUSS':
        config['train_batch_size'] = 128
    if args.model == 'UITE_MMD' and args.dataset == 'ACIC':
        config['train_batch_size'] = 64
    results = run_casual(config=config)
    logger = getLogger()
    logger.info('\n{}'.format(results))
    logger.info('\n' + format_results_summary(args.model, args.dataset, results))