import argparse
import json
import os
import warnings
from logging import getLogger

import numpy as np

from causally.start.quick_start import run_casual
from causally.start.quick_start_for_pretrain import run_pretrain
from causally.config.configurator import Config
from causally.utils.output import save_result, save_result_exp1_hidden
warnings.filterwarnings("ignore")  

# Models swept when no explicit --model is given on the command line.
MODEL_LIST = ['SNet', 'XNet', 'TNet', 'RNet', 'DRNet', 'CFR_WASS', 'TARNet', 'DragonNet']
# Directory that receives the per-model and aggregated JSON result files.
OUTPUT_DIR = './exp_30'
# File-name suffix per trainer; trainers not listed here produce no JSON output.
TRAINER_SUFFIX = {'CARD_trainer': 'RL', 'standard_trainer': 'IL'}


def str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean flags need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}.'.format(value))


def build_parser():
    """Build the command-line parser for the experiment sweep.

    ``--model`` defaults to None, which means "run every model in MODEL_LIST".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_info', type=str, default='new_param', help='experiment tag used in output file names')
    parser.add_argument('--only_pretrain', type=str2bool, default=False, help='pretrain ae and diffusion', choices=[True, False])
    parser.add_argument('--pretrain_then_optimize', type=str2bool, default=False, help='pretrain and optimizer', choices=[True, False])
    parser.add_argument('--model', type=str, default=None, help='The model for estimating causal effect; omit to run all models')
    parser.add_argument('--dataset', type=str, default='ACIC-NewParam-StandardScalerX', help='The dataset.')
    parser.add_argument('--trainer', type=str, default='CARD_trainer', help='training method.', choices=['CARD_trainer', 'standard_trainer'])
    parser.add_argument('--device', type=str, default='cpu', help='device to run on', choices=['cpu', 'cuda', 'mps'])
    parser.add_argument('--v_reinforce', type=str, default='rl', help='is use reinforce finetune')
    parser.add_argument('--v_rl_algo', type=str, default='pg', help='reinforce algorithm')
    parser.add_argument('--v_batch_type', type=str, default='out_batch', help='training data use type, in_batch or out_batch', choices=['in_batch', 'out_batch'])
    return parser


def summarize_results(prefix, group_results, model_results, header):
    """Aggregate one group of per-run metrics into ``model_results`` and build a log message.

    For every metric ``key`` in ``group_results`` this stores, under
    ``'{prefix}_{key}...'``, the mean, the sample standard deviation
    (``ddof=1``), a ``'mean±std'`` string, and the raw values. Mean and std
    are rounded to 6 decimals.

    Args:
        prefix: key prefix, e.g. 'normal', 'measure', 'missing', 'hidden'.
        group_results: mapping of metric name -> sequence of per-run values.
        model_results: dict that is updated in place.
        header: text placed before the per-metric lines.

    Returns:
        The log message: ``header``, one line per metric, then ``'done!'``.
    """
    info = header
    for key in group_results:
        values = group_results[key]
        mean = float(round(np.mean(values), 6))
        # NOTE(review): with a single run, ddof=1 yields NaN, which json.dump
        # writes as the non-standard token NaN.
        std = float(round(np.std(values, ddof=1), 6))
        model_results['{}_{}_mean'.format(prefix, key)] = mean
        model_results['{}_{}_std'.format(prefix, key)] = std
        model_results['{}_{}_res'.format(prefix, key)] = '{}±{}'.format(mean, std)
        # Plain Python values so json.dump also works for numpy arrays/scalars.
        model_results['{}_{}'.format(prefix, key)] = np.asarray(values).tolist()
        info += 'The {} metric: {}±{}\n'.format(key.upper(), mean, std)
    info += 'done!'
    return info


def _dump_json(obj, file_name):
    """Write ``obj`` as UTF-8 JSON to ``OUTPUT_DIR/file_name``, creating the directory if needed."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # utf-8 is required because ensure_ascii=False writes the '±' character verbatim.
    with open(os.path.join(OUTPUT_DIR, file_name), 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=4, ensure_ascii=False)


if __name__ == '__main__':
    # Unknown arguments are tolerated, as before; they are left for other consumers.
    base_args, _ = build_parser().parse_known_args()
    # An explicit --model runs just that model; otherwise sweep the full list.
    model_list = [base_args.model] if base_args.model is not None else MODEL_LIST
    suffix = TRAINER_SUFFIX.get(base_args.trainer)
    logger = getLogger()
    all_results = {}

    for model in model_list:
        model_results = all_results[model] = {}
        # A fresh Namespace per model, so no state from a previous Config(args=...) call carries over.
        args = argparse.Namespace(**vars(base_args))
        args.model = model
        config = Config(args=args)

        results, measure_results, missing_results, hidden_results = run_casual(config=config)

        logger.info('\n{}'.format(results))
        normal_header = args.exp_info + '\n' + '\n[{},{},{}]\n'.format(args.model, args.dataset, args.trainer)
        groups = (
            ('normal', results, normal_header),
            ('measure', measure_results, 'measure result:\n'),
            ('missing', missing_results, 'missing result:\n'),
            ('hidden', hidden_results, 'hidden result:\n'),
        )
        for prefix, group_results, header in groups:
            info = summarize_results(prefix, group_results, model_results, header)
            logger.info('\n' + info)

        if suffix is not None:
            _dump_json(model_results, 'part_{}_{}_{}.json'.format(args.exp_info, args.model, suffix))

    if suffix is not None:
        _dump_json(all_results, 'full_{}_{}.json'.format(base_args.exp_info, suffix))

