from hydra import initialize, compose
from trainer_MICN import Trainer_Causal_MICN
from omegaconf import OmegaConf, open_dict
import os
import numpy as np
import sys

# Pin this process to CUDA device index 1 via the CUDA_VISIBLE_DEVICES variable.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def find_latest_checkpoint(checkpoint_dir, stem='best_model', suffix='.ckpt'):
    """Return the path of the highest-versioned checkpoint in *checkpoint_dir*.

    Recognised file names are ``<stem><suffix>`` (treated as version 0) and
    ``<stem>-v<N><suffix>`` (version ``N``); every other directory entry is
    ignored, so unrelated files stored next to the checkpoints cannot change
    which checkpoint is selected.

    Args:
        checkpoint_dir: Directory to scan.
        stem: File-name prefix shared by all checkpoints.
        suffix: Checkpoint file extension, including the leading dot.

    Returns:
        ``os.path.join(checkpoint_dir, <name of the newest checkpoint>)``.

    Raises:
        FileNotFoundError: If the directory holds no matching checkpoint.
    """
    newest_version = -1
    newest_name = None
    for entry in os.listdir(checkpoint_dir):
        if len(entry) < len(stem) + len(suffix):
            continue
        if not (entry.startswith(stem) and entry.endswith(suffix)):
            continue
        # Whatever sits between the stem and the extension is the version tag.
        tag = entry[len(stem):len(entry) - len(suffix)]
        if tag == '':
            version = 0
        elif tag.startswith('-v') and tag[2:].isdecimal():
            version = int(tag[2:])
        else:
            continue
        if version > newest_version:
            newest_version, newest_name = version, entry
    if newest_name is None:
        raise FileNotFoundError(
            "no '" + stem + "*" + suffix + "' checkpoint found in " + str(checkpoint_dir)
        )
    return os.path.join(checkpoint_dir, newest_name)


if __name__ == '__main__':
    # Sweep grids consumed only by the (currently disabled) train_model_all call below.
    list_coeff = [4.0]  # [0.0, 1.0, 2.0, 3.0, 4.0]
    list_len_past = [15]  # [10, 15, 20]
    list_projection_horizon = [15]  # [10, 15, 20]
    list_max_seq_length = [60]
    list_lambda1 = [1]
    list_lambda2 = [1]
    list_gpus = [1]  # [3, 2, 1, 0]
    list_cf_seq_mode = ['random_trajectories']  # 'sliding_treatment' or 'random_trajectories'
    n_workers = 1

    with initialize(config_path="./config"):
        args = compose(config_name="config_MICN.yaml")

        # open_dict lets us add/override keys on the composed config.
        with open_dict(args):
            # --- dataset ---
            args.dataset.train_batch_size = 1024
            args.dataset.val_batch_size = 1024
            args.dataset.test_batch_size = 1024
            args.dataset.coeff = 4.0
            args.dataset.len_past = 15
            args.dataset.num_patients.train = 10000
            args.dataset.num_patients.val = 1000
            args.dataset.num_patients.test = 1000
            args.dataset.autoregressive = False
            args.dataset.min_no_samples = args.dataset.len_past
            args.dataset.projection_horizon = 15
            args.dataset.label_len = args.dataset.len_past // 2
            args.dataset.max_seq_length = 60
            args.dataset.cf_seq_mode = 'random_trajectories'  # 'random_trajectories' or 'sliding_treatment'

            # --- model (window sizes are tied to the dataset settings above) ---
            args.model.multi.pred_len = args.dataset.projection_horizon
            # args.model.multi.c_out = args.dataset.projection_horizon
            args.model.multi.seq_len = args.dataset.len_past
            args.model.multi.d_model = 256
            args.model.multi.c_out = args.model.multi.outcome_dims

            # --- experiment ---
            args.gpus = 1
            args.exp.param_lambda1 = 0
            args.exp.param_lambda2 = 0
            args.exp.lr = 1e-3
            args.exp.max_epochs = 150
            args.exp.logging = False
            args.exp.patience = 10
            args.exp.unscale_rmse = True
            args.exp.percentage_rmse = True
            # args.exp.weights_ema = False
            # int(...) guarantees a plain Python int is stored in the config.
            args.exp.seed = int(np.random.randint(0, 1000))

            # Folder-name prefix shared by the checkpoint and data directories.
            run_prefix = (
                'MICN_coeff_' + str(int(args.dataset.coeff))
                + '_past_' + str(int(args.dataset.len_past))
                + '_maxseq_' + str(int(args.dataset.max_seq_length))
            )
            args.exp.checkpoint_path = '/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/checkpoints/230931_synthetic/'
            # NOTE(review): the checkpoint folder is always the 'random_trajectories'
            # one, while the data folder follows cf_seq_mode - confirm this is intended.
            args.exp.checkpoint_path_full = args.exp.checkpoint_path + run_prefix + '_random_trajectories/'
            args.exp.data_path = '/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/checkpoints/230931_synthetic/'
            args.exp.data_path_full = args.exp.data_path + run_prefix + '_' + args.dataset.cf_seq_mode + '/'

            # makedirs also creates missing parents (so the base paths are covered)
            # and does not fail when the directory already exists.
            for run_dir in (args.exp.checkpoint_path_full, args.exp.data_path_full):
                os.makedirs(run_dir, exist_ok=True)

    trainer = Trainer_Causal_MICN(args)
    # trainer.train_model_all(list_coeff,
    #                         list_len_past,
    #                         list_cf_seq_mode,
    #                         list_max_seq_length,
    #                         list_lambda1,
    #                         list_lambda2,
    #                         list_gpus,
    #                         n_workers)

    # Imported here, as in the original script, right before it is needed.
    from src.models.Causal_MICN import Model

    # Restore the most recent 'best_model' checkpoint by parsing file names rather
    # than by counting directory entries, which breaks on any unrelated file.
    checkpoint_file = find_latest_checkpoint(args.exp.checkpoint_path_full)
    trainer.model = Model.load_from_checkpoint(checkpoint_file)
    trainer.model.trainer = trainer.create_trainer(args)
    trainer.model.args = args
    trainer.init_data(args)

    # Evaluate on the test_cf_treatment_seq_non split of the dataset collection.
    rmse_result = trainer.get_rmse_result(trainer.dataset_collection.test_cf_treatment_seq_non)
    print(args.model.name)
    print(rmse_result)
    print("gamma: ", args.dataset.coeff)
    # print("len of input: ", args.dataset.past_len)
    print("len of prediction: ", args.dataset.projection_horizon)
