from hydra import initialize, compose
from trainer_DLinear import Trainer_DLinear_Causal
from omegaconf import OmegaConf, open_dict
import os
from src.util import *
import numpy as np



import sys
sys.path.append('/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/')
os.environ["CUDA_VISIBLE_DEVICES"]= "3"


def checkpoint_count(dir):
    """Return the newest ``best_model`` checkpoint version stored in *dir*.

    Checkpoints are expected to be named ``best_model.ckpt`` (version 0) and
    ``best_model-v<N>.ckpt`` (version N), matching the file names this script
    loads. Taking the maximum parsed version (instead of counting files whose
    name merely contains ``best_model``) stays correct when intermediate
    versions were deleted and ignores unrelated files.

    Args:
        dir: Directory to scan. The name shadows the ``dir`` builtin but is
            kept for backward compatibility with keyword callers.

    Returns:
        The highest checkpoint version found, or -1 if there is none.
    """
    import re

    pattern = re.compile(r'best_model(?:-v(\d+))?\.ckpt')
    versions = [
        # group(1) is None for the un-suffixed best_model.ckpt -> version 0
        int(match.group(1) or 0)
        for match in map(pattern.fullmatch, os.listdir(dir))
        if match
    ]
    return max(versions, default=-1)



if __name__ == '__main__':
    # Hyper-parameter grids for the multi-run sweep (currently commented out
    # below). The single-run configuration is set explicitly on `args`.
    list_coeff = [3.0] # [0.0, 1.0, 2.0, 3.0, 4.0]
    list_len_past = [15]           # [10, 15, 20]
    list_projection_horizon = [15]  # [5, 10, 15]
    list_cf_seq_mode = ['random_trajectories'] # 'sliding_treatment' or 'random_trajectories'
    list_max_seq_length = [60]  # [20, 30]
    list_lambda1 = [1]
    list_lambda2 = [1]
    list_gpus = [2]   # [3, 2, 1, 0]
    n_workers = 1

    # Build the run configuration from ./config/config_DLinear.yaml and
    # override the fields used for this run.
    with initialize(config_path="./config"):
        args = compose(config_name="config_DLinear.yaml")

        # open_dict allows adding keys that are not in the base config.
        with open_dict(args):
            args.dataset.train_batch_size = 1024
            args.dataset.val_batch_size = 1024
            args.dataset.test_batch_size = 1024
            args.dataset.coeff = 4.0             # 0, 1, 2, 3, 4
            args.dataset.len_past = 15            # 15, 30, 45
            args.dataset.autoregressive = False
            args.dataset.flag_include_init = False
            args.dataset.projection_horizon = 15  # 15, 30, 45
            args.dataset.num_patients.train = 10000
            args.dataset.num_patients.val = 1000
            args.dataset.num_patients.test = 1000
            args.dataset.cf_seq_mode = 'random_trajectories'     # 'sliding_treatment' or 'random_trajectories'
            # The model's input/label/prediction windows follow the dataset setup.
            args.model.multi.seq_len = args.dataset.len_past
            args.model.multi.label_len = args.dataset.len_past
            args.model.multi.pred_len = args.dataset.projection_horizon
            args.dataset.max_seq_length = 60      # 60, 90, 120
            args.exp.lr = 1e-3
            args.exp.max_epochs = 150
            args.exp.logging = False
            args.exp.patience = 10
            args.exp.unscale_rmse = True
            args.exp.percentage_rmse = True
            args.exp.param_lambda1 = 1
            args.exp.param_lambda2 = 1
            args.gpus = 1
            # args.exp.weights_ema = False
            args.exp.seed = np.random.randint(0, 100)
            args.exp.checkpoint_path = '/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/checkpoints/230932_synthetic/'
            args.exp.data_path = '/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/checkpoints/230932_synthetic/'
            # args.exp.checkpoint_path_full = args.exp.checkpoint_path + 'DLinear_coeff_' + str(int(args.dataset.coeff)) + '_past_' + str(args.dataset.len_past) + '_maxseq_' + str(args.dataset.max_seq_length) + '_lambda1_' + str(args.exp.param_lambda1) + '_lambda2_' + str(args.exp.param_lambda2) + '_' + args.dataset.cf_seq_mode + '/'
            # One sub-directory per (coeff, len_past, max_seq_length, cf_seq_mode) setting.
            sub_path = 'DLinear_coeff_' + str(int(args.dataset.coeff)) + '_past_' + str(args.dataset.len_past) + '_maxseq_' + str(args.dataset.max_seq_length) + '_' + args.dataset.cf_seq_mode + '/'
            args.exp.checkpoint_path_full = args.exp.checkpoint_path + sub_path
            args.exp.data_path_full = args.exp.data_path + sub_path
            # makedirs creates any missing parents and tolerates an existing
            # directory (the old exists()+mkdir() pair did neither).
            os.makedirs(args.exp.checkpoint_path_full, exist_ok=True)

    trainer = Trainer_DLinear_Causal(args)
    # trainer.trainer_pl(args)
    # trainer.train_model_all(list_coeff = list_coeff,
    #                         list_len_past = list_len_past,
    #                         list_cf_seq_mode = list_cf_seq_mode,
    #                         list_max_seq_length = list_max_seq_length,
    #                         list_lambda1 = list_lambda1,
    #                         list_lambda2 = list_lambda2,
    #                         list_gpu = list_gpus,
    #                         n_workers = n_workers)

    # Select the newest checkpoint. Files are named best_model.ckpt
    # (version 0) and best_model-v<N>.ckpt. The version is derived from
    # those names rather than from the total directory file count, which
    # broke whenever unrelated files were present.
    version = checkpoint_count(args.exp.checkpoint_path_full)
    if version < 0:
        raise FileNotFoundError(
            'No best_model checkpoint found in ' + args.exp.checkpoint_path_full
        )

    # load the best model
    from src.models.DLinear_Causality import Model
    if version == 0:
        ckpt_name = 'best_model.ckpt'
    else:
        ckpt_name = 'best_model-v' + str(int(version)) + '.ckpt'
    trainer.model = Model.load_from_checkpoint(args.exp.checkpoint_path_full + '/' + ckpt_name)
    trainer.model.trainer = trainer.create_trainer(args)
    trainer.model.args = args
    trainer.init_data(args)

    # evaluation
    mse_result = trainer.get_rmse_result(trainer.dataset_collection.test_cf_treatment_seq_non)

    # # save_pickle(mse_result, args.exp.checkpoint_path_full + '/mse_result.pickle')
    # # print("saved pickle")

    print(args.model.name)
    print(mse_result)
    print("gamma: ", args.dataset.coeff)
    print("len of input: ", args.dataset.len_past)
    print("len of prediction: ", args.dataset.projection_horizon)
    print(args.dataset.flag_include_init)
    print(args.exp.seed)
    
    