from hydra import initialize, compose
from trainer_FEDformer import Trainer_FEDformer_Causal
from omegaconf import OmegaConf, open_dict
import os
import numpy as np


import sys
sys.path.append('/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/')
# os.environ["CUDA_VISIBLE_DEVICES"]= "1"

def run_dir_name(coeff, len_past, max_seq_length, cf_seq_mode):
    """Return the per-run sub-directory name, including a trailing slash.

    The numeric settings are truncated with ``int`` before formatting, so a
    coefficient of ``4.0`` appears as ``4`` in the name.
    """
    return (
        f"FEDformer_else_coeff_{int(coeff)}"
        f"_past_{int(len_past)}"
        f"_maxseq_{int(max_seq_length)}"
        f"_{cf_seq_mode}/"
    )


def latest_best_checkpoint(filenames, stem='best_model', suffix='.ckpt'):
    """Return the highest-versioned ``best_model`` checkpoint file name.

    Recognised names are ``<stem><suffix>`` (treated as version 0) and
    ``<stem>-v<N><suffix>`` (version N). Every other entry is ignored, so
    unrelated files in the checkpoint directory cannot shift the selection.

    Raises:
        FileNotFoundError: if no entry in ``filenames`` matches the pattern.
    """
    newest_version = -1
    newest_name = None
    for name in filenames:
        if len(name) < len(stem) + len(suffix):
            continue
        if not (name.startswith(stem) and name.endswith(suffix)):
            continue
        tag = name[len(stem):len(name) - len(suffix)]
        if tag == '':
            version = 0
        elif tag.startswith('-v') and tag[2:].isdecimal():
            version = int(tag[2:])
        else:
            continue
        if version > newest_version:
            newest_version, newest_name = version, name
    if newest_name is None:
        raise FileNotFoundError(
            "no '{}*{}' checkpoint found".format(stem, suffix))
    return newest_name


def main():
    """Train the FEDformer causal model, reload its best checkpoint and print the test result."""
    # Sweep values handed to ``train_model_all``; each list currently holds a
    # single setting (other candidate values are listed in the comments).
    # NOTE(review): list_coeff is [3.0] while args.dataset.coeff below is 4.0;
    # the printed "gamma" and the checkpoint directory use the latter --
    # confirm this mismatch is intended.
    list_coeff = [3.0]  # [0.0, 1.0, 2.0, 3.0, 4.0]
    list_len_past = [15]  # [15, 30, 45]
    list_projection_horizon = [15]  # [5, 10, 15]  (not passed to the trainer)
    list_max_seq_length = [60]
    list_lambda1 = [1]
    list_lambda2 = [1]
    list_gpus = [1]  # [3, 2, 1, 0]
    list_cf_seq_mode = ['random_trajectories']  # 'sliding_treatment' or 'random_trajectories'
    n_workers = 1

    with initialize(config_path="./config"):
        args = compose(config_name="config_FEDformer.yaml")

        # open_dict lets keys that are absent from the YAML be added.
        with open_dict(args):
            # Dataset settings.
            args.dataset.train_batch_size = 1024
            args.dataset.val_batch_size = 1024
            args.dataset.test_batch_size = 1024
            args.dataset.coeff = 4.0
            args.dataset.len_past = 15
            args.dataset.num_patients.train = 10000
            args.dataset.num_patients.val = 1000
            args.dataset.num_patients.test = 1000
            args.dataset.autoregressive = False
            args.dataset.min_no_samples = args.dataset.len_past
            args.dataset.projection_horizon = 15
            args.dataset.label_len = args.dataset.len_past
            args.dataset.max_seq_length = 60
            args.dataset.cf_seq_mode = 'random_trajectories'  # 'sliding_treatment' or 'random_trajectories'

            # Model settings derived from the dataset settings above.
            args.model.multi.pred_len = args.dataset.projection_horizon
            args.model.multi.seq_len = args.dataset.len_past - 1
            args.model.multi.mode_select = 'else'

            # Experiment settings.
            args.exp.lr = 1e-3
            args.exp.max_epochs = 150
            args.exp.logging = False
            args.exp.patience = 10
            args.exp.unscale_rmse = True
            args.exp.percentage_rmse = True
            args.gpus = 1
            # A fresh seed in [0, 100) is drawn on every launch.
            args.exp.seed = np.random.randint(0, 100)

            # Checkpoint and data directories share one naming scheme so the
            # two cannot drift apart when cf_seq_mode changes.
            run_dir = run_dir_name(args.dataset.coeff,
                                   args.dataset.len_past,
                                   args.dataset.max_seq_length,
                                   args.dataset.cf_seq_mode)
            args.exp.checkpoint_path = '/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/checkpoints/230935_synthetic/'
            args.exp.checkpoint_path_full = args.exp.checkpoint_path + run_dir
            args.exp.data_path = '/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CT_CRN/checkpoints/230935_synthetic/'
            args.exp.data_path_full = args.exp.data_path + run_dir

            # Creates the base directory and the run directory in one call and
            # tolerates both already existing.
            os.makedirs(args.exp.checkpoint_path_full, exist_ok=True)

    trainer = Trainer_FEDformer_Causal(args)
    trainer.train_model_all(list_coeff,
                            list_len_past,
                            list_cf_seq_mode,
                            list_max_seq_length,
                            list_lambda1,
                            list_lambda2,
                            list_gpus,
                            n_workers)

    # Reload the newest best-model checkpoint found in the run directory.
    checkpoint_name = latest_best_checkpoint(
        os.listdir(args.exp.checkpoint_path_full))

    from src.models.FEDformer_Causality import Model
    trainer.model = Model.load_from_checkpoint(
        os.path.join(args.exp.checkpoint_path_full, checkpoint_name))
    trainer.model.trainer = trainer.create_trainer(args)
    trainer.model.args = args
    trainer.init_data(args)

    mse_result = trainer.get_rmse_result(trainer.dataset_collection.test_cf_treatment_seq_non)
    print(args.model.name)
    print(mse_result)
    print("gamma: ", args.dataset.coeff)
    print("len of input: ", args.dataset.len_past)
    print("len of prediction: ", args.dataset.projection_horizon)


if __name__ == '__main__':
    main()
