import torch
import sys
# sys.path.append('../loss')
# sys.path.append('../util')
# sys.path.append('../data')
# from loss.loss import *
# from util.utils import *
import time
# import copy


import hydra
import logging
import sys
sys.path.append('/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/CausalTransformer/')

import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from hydra.utils import instantiate
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
# from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from src.models.utils import AlphaRise, FilteringMlFlowLogger,  grad_reverse, BRTreatmentOutcomeHead, bce
from pytorch_lightning import loggers as pl_loggers

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
torch.set_default_dtype(torch.double)

import torch.nn.functional as F
import os, copy
import torch.multiprocessing as mp
from src.util import *



class Trainer_Causal_MICN:
    """Prepare data, build the model/trainer and run (parallel) training jobs.

    The configuration object handed to the methods is accessed both by
    attribute (``args.dataset.name``) and by key (``args['dataset']['name']``)
    and is passed to ``OmegaConf`` helpers.
    """

    # Maps a logical loader name to every instance attribute that must point
    # at it.  Both spellings (``*_dataloader`` and ``*_data_loader``) are kept
    # because methods of this class read either one.
    _LOADER_ATTRIBUTES = {
        'train': ('train_dataloader', 'train_data_loader'),
        'valid': ('valid_dataloader', 'valid_data_loader'),
        'test': ('test_dataloader',),
        'test_one_step': ('test_dataloader_one_step', 'test_data_loader_one_step'),
        'test_seq': ('test_dataloader_seq',),
    }

    def __init__(self, args):
        """Store the base configuration.

        Args:
            args: Base configuration; ``make_job_models`` deep-copies it for
                every job.
        """
        self.args = args
        # Defined up front so that init_model() can extend the list even
        # though it runs before create_trainer() in init_model_trainer().
        self.model_callbacks = []
        self.logger = None

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _collection_pickle_path(args):
        """Return the path of the cached dataset collection for ``args``."""
        return args.exp.data_path_full + '/dataset_collection.pickle'

    @staticmethod
    def _loader_keys(args):
        """Return the logical loader names, in return order, for the dataset in ``args``."""
        if args['dataset']['name'] == 'mimic3_real':
            return ('train', 'valid', 'test')
        return ('train', 'valid', 'test_one_step', 'test_seq')

    @staticmethod
    def _build_dataloaders(dataset_collection, args):
        """Create the train/valid/test loaders for ``dataset_collection``.

        The ``*_non`` attributes of the collection are used when
        ``args.dataset.autoregressive`` is false.

        Returns:
            dict: logical loader name -> ``DataLoader``, in the order given by
            ``_loader_keys``.
        """
        suffix = '' if args.dataset.autoregressive else '_non'

        def make(attribute, batch_size, shuffle):
            return DataLoader(getattr(dataset_collection, attribute + suffix),
                              batch_size=batch_size, shuffle=shuffle)

        loaders = {
            'train': make('train_f', args.dataset.train_batch_size, True),
            'valid': make('val_f', args.dataset.val_batch_size, False),
        }
        test_batch_size = args.dataset.test_batch_size
        if args['dataset']['name'] == 'mimic3_real':
            loaders['test'] = make('test_f', test_batch_size, False)
        else:
            loaders['test_one_step'] = make('test_cf_one_step', test_batch_size, False)
            loaders['test_seq'] = make('test_cf_treatment_seq', test_batch_size, False)
        return loaders

    def _assign_dataloaders(self, loaders):
        """Expose ``loaders`` on the instance under every attribute name in use."""
        for key, loader in loaders.items():
            for attribute in self._LOADER_ATTRIBUTES[key]:
                setattr(self, attribute, loader)

    # ------------------------------------------------------------------ #
    # Data
    # ------------------------------------------------------------------ #
    def init_data(self, args):
        """Load (or generate and cache) the dataset and create the data loaders.

        If the pickle under ``args.exp.data_path_full`` exists it is loaded,
        otherwise ``dataset_generation`` is called and its collection is
        saved.  Afterwards the model dimensions in ``args.model`` are filled
        in from the training data.

        Args:
            args: Job configuration; mutated in place (struct mode disabled,
                ``args.model.dim_*`` set).
        """
        seed_everything(args.exp.seed)
        pickle_path = self._collection_pickle_path(args)
        print("data path: ", pickle_path)
        if os.path.exists(pickle_path):
            # NOTE(review): load_pickle presumably unpickles the file; only
            # point data_path at trusted locations.
            self.dataset_collection = load_pickle(pickle_path)
            self.dataset_collection.processed_data_multi = True
            # Train, valid and test loaders all follow the same ``args`` so
            # that they cannot disagree on the autoregressive setting.
            self._assign_dataloaders(self._build_dataloaders(self.dataset_collection, args))
            if args.dataset.autoregressive:
                print("create auto-regressive data")
            else:
                print("create non-autoregressive data")
        else:
            # data generation
            generated = self.dataset_generation(args)
            self.dataset_collection = generated[0]
            self._assign_dataloaders(dict(zip(self._loader_keys(args), generated[1:])))
            save_pickle(self.dataset_collection, pickle_path)

        OmegaConf.set_struct(args, False)
        OmegaConf.register_new_resolver("sum", lambda x, y: x + y, replace=True)
        logger.info('\n' + OmegaConf.to_yaml(args, resolve=True))

        train_data = self.dataset_collection.train_f.data
        args.model.dim_outcomes = train_data['outputs'].shape[-1]
        args.model.dim_treatments = train_data['current_treatments'].shape[-1]
        args.model.dim_vitals = train_data['vitals'].shape[-1] if self.dataset_collection.has_vitals else 0
        args.model.dim_static_features = train_data['static_features'].shape[-1]

        # NOTE(review): this only triggers when the pickle is still missing
        # after the save above; init_model() would then fail because it needs
        # self.dataset_collection -- confirm the intended behaviour.
        if not os.path.exists(pickle_path):
            del self.dataset_collection

    def dataset_generation(self, args):
        """Instantiate and process the dataset described by ``args.dataset``.

        Returns:
            tuple: ``(dataset_collection, train, valid, test)`` for the
            ``'mimic3_real'`` dataset, otherwise
            ``(dataset_collection, train, valid, test_one_step, test_seq)``.
        """
        seed_everything(args.exp.seed)
        dataset_collection = instantiate(args.dataset, _recursive_=True)
        dataset_collection.process_data_multi()

        loaders = self._build_dataloaders(dataset_collection, args)
        return (dataset_collection, *(loaders[key] for key in self._loader_keys(args)))

    # ------------------------------------------------------------------ #
    # Model / trainer
    # ------------------------------------------------------------------ #
    def init_model_trainer(self, args):
        """Create ``self.model`` and ``self.model_trainer`` for ``args``."""
        # model and optimizer
        self.model = self.init_model(args)
        print("create model!!")

        self.model_trainer = self.create_trainer(args)
        print("create trainer!!")

    def init_model(self, args):
        """Instantiate the model from ``args.model.multi``.

        Requires ``self.dataset_collection`` (set by ``init_data``).  When
        ``args.exp.logging`` is true a Comet logger is created and a
        learning-rate monitor is appended to ``self.model_callbacks``.

        Returns:
            The instantiated model.
        """
        if args.exp.logging:
            self.logger = pl_loggers.CometLogger(api_key = None, save_dir="logs/")
            # The list may be missing on instances created without __init__.
            if not hasattr(self, 'model_callbacks'):
                self.model_callbacks = []
            self.model_callbacks += [LearningRateMonitor(logging_interval='epoch')]
        else:
            self.logger = None

        return instantiate(args.model.multi, args, dataset_collection=self.dataset_collection)

    def create_trainer(self, args):
        """Build the Lightning ``Trainer`` with early stopping and checkpointing.

        Both callbacks monitor ``"valid_loss"``; the best checkpoint is
        written to ``args.exp.checkpoint_path_full`` as ``best_model``.
        ``self.model_callbacks`` and ``self.logger`` are replaced.

        Returns:
            Trainer: configured for the GPUs listed in ``args.gpus``.
        """
        early_stop_callback = EarlyStopping(monitor="valid_loss", mode="min", patience=10, verbose=True)
        checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="valid_loss", mode="min", dirpath=args.exp.checkpoint_path_full, filename="best_model")
        self.model_callbacks = [early_stop_callback, checkpoint_callback]
        if args.exp.logging:
            # API_KEY is not defined in this class; it has to be provided at
            # module level (presumably by one of the star imports).
            self.logger = pl_loggers.CometLogger(api_key = API_KEY, save_dir="logs/")
            self.model_callbacks += [LearningRateMonitor(logging_interval='epoch')]
        else:
            self.logger = None

        return Trainer(accelerator='gpu', devices=args.gpus, logger=self.logger, max_epochs=args.exp.max_epochs,
                       callbacks=self.model_callbacks, detect_anomaly=True)

    def checkpoint_count(self, dir):
        """Return the number of entries in ``dir`` containing ``'best_model'``, minus one.

        The result is ``-1`` when no such entry exists.
        """
        return sum('best_model' in name for name in os.listdir(dir)) - 1

    # ------------------------------------------------------------------ #
    # Parallel hyper-parameter sweep
    # ------------------------------------------------------------------ #
    def make_job_models(self, list_coeff, list_len_past, list_cf_seq_mode, list_max_seq_length, list_lambda1, list_lambda2, list_gpu):
        """Fill ``self.queue_job`` with one configuration per parameter combination.

        For every combination a deep copy of ``self.args`` is specialised,
        its data/checkpoint directories are created, the dataset is generated
        if its pickle is missing, and the configuration is queued.  GPUs from
        ``list_gpu`` are assigned round-robin.

        Args:
            list_coeff: Values for ``args.dataset.coeff``.
            list_len_past: Values for ``args.dataset.len_past`` (also used as
                projection horizon, sequence length and label length).
            list_cf_seq_mode: Values for ``args.dataset.cf_seq_mode``.
            list_max_seq_length: Values for ``args.dataset.max_seq_length``
                (converted with ``int``).
            list_lambda1: Values for ``args.exp.param_lambda1``.
            list_lambda2: Values for ``args.exp.param_lambda2``.
            list_gpu: GPU ids to distribute the jobs over.
        """
        from itertools import product

        self.dict_path_model = mp.Manager().dict()
        self.queue_job = mp.Queue()
        # Same iteration order as nesting the loops in this argument order.
        grid = product(list_coeff, list_len_past, list_max_seq_length,
                       list_cf_seq_mode, list_lambda1, list_lambda2)
        for count, (coeff, len_past, max_seq_length, cf_seq_mode, cur_lambda1, cur_lambda2) in enumerate(grid):
            args = copy.deepcopy(self.args)
            with open_dict(args):
                args.dataset.coeff = coeff
                args.dataset.len_past = len_past
                args.dataset.projection_horizon = len_past
                args.dataset.max_seq_length = int(max_seq_length)
                args.dataset.cf_seq_mode = cf_seq_mode
                args.dataset.autoregressive = False
                args.model.multi.seq_len = args.dataset.len_past
                args.model.multi.label_len = args.dataset.len_past
                args.model.multi.pred_len = args.dataset.projection_horizon
                args.gpus = [list_gpu[count % len(list_gpu)]]
                args.exp.param_lambda1 = cur_lambda1
                args.exp.param_lambda2 = cur_lambda2

                # The lambdas are not part of the directory name, so jobs
                # that differ only in lambda share data and checkpoint dirs.
                sub_path = args.model.name + '_coeff_' + str(int(args.dataset.coeff)) + '_past_' + str(args.dataset.len_past) + '_maxseq_' + str(args.dataset.max_seq_length) + '_' + args.dataset.cf_seq_mode + '/'
                args.exp.checkpoint_path_full = args.exp.checkpoint_path + sub_path
                args.exp.data_path_full = args.exp.data_path + sub_path
                # exist_ok avoids the check-then-create race and also creates
                # missing parent directories.
                os.makedirs(args.exp.data_path_full, exist_ok=True)
                os.makedirs(args.exp.checkpoint_path_full, exist_ok=True)

            pickle_path = self._collection_pickle_path(args)
            if os.path.exists(pickle_path):
                print(pickle_path + " exist!!")
            else:
                self.init_data(args)

            self.queue_job.put(copy.deepcopy(args))

        print(f"End - {self.queue_job.qsize()} job made")

    def train_model_all(self, list_coeff, list_len_past, list_cf_seq_mode, list_max_seq_length, list_lambda1, list_lambda2, list_gpu, n_workers):
        """Queue all jobs and train them with ``n_workers`` worker processes.

        The list arguments are forwarded to ``make_job_models``.  Blocks until
        every worker has finished.
        """
        self.make_job_models(list_coeff, list_len_past, list_cf_seq_mode, list_max_seq_length, list_lambda1, list_lambda2, list_gpu)
        list_process = []
        for _ in range(n_workers):
            process = mp.Process(target=self.__train_from_queue)
            process.start()
            list_process.append(process)
            # Workers are started 8 seconds apart.
            time.sleep(8)
        for process in list_process:
            process.join()

    def get_device_max_free_memory(self, devices):
        """Return the position in ``devices`` of the device with the most free memory.

        Note that the return value is an index into ``devices``, not the
        device id itself; ties resolve to the first such position.

        Raises:
            ValueError: if ``devices`` is empty.
        """
        list_free_memory = []
        for device in devices:
            with torch.cuda.device(device):
                # mem_get_info() returns (free, total); only free is needed.
                list_free_memory.append(torch.cuda.mem_get_info()[0])
        return list_free_memory.index(max(list_free_memory))

    def __train_from_queue(self):
        """Worker loop: take configurations from ``self.queue_job`` and train them.

        Returns as soon as no job can be fetched.
        """
        from queue import Empty

        while True:
            # A separate empty() check followed by a blocking get() can hang
            # forever when another worker takes the last job in between, so
            # fetch without blocking and treat Empty as "done".
            try:
                args = self.queue_job.get_nowait()
            except Empty:
                print(f"-- Empty Job!!")
                return
            self.init_data(args)
            self.init_model_trainer(args)
            print("model, trainer are ready!!")

            print(f"train_gpu: {args['gpus']}, lambda1: {args.exp.param_lambda1}, lambda2: {args.exp.param_lambda2}, remain job size: {self.queue_job.qsize()}")
            self.model_trainer.fit(self.model, self.train_dataloader, self.valid_dataloader)

    # ------------------------------------------------------------------ #
    # Single-run training and evaluation helpers
    # ------------------------------------------------------------------ #
    def trainer_pl(self):
        """Fit ``self.model`` on the loaders created by ``init_data``."""
        self.model_trainer.fit(self.model, self.train_data_loader, self.valid_data_loader)

    def get_rmse_result(self, dataset):
        """Return the four values of ``model.get_normalised_masked_rmse`` for ``dataset``.

        The model is queried with ``one_step_counterfactual=True``.
        """
        rmse_orig, rmse_all, rmse_last, rmse_time_step = self.model.get_normalised_masked_rmse(dataset, one_step_counterfactual=True)
        return rmse_orig, rmse_all, rmse_last, rmse_time_step

    def get_n_step_rmse_result(self, dataset):
        """Return ``model.get_normalised_n_step_rmses(dataset)``."""
        return self.model.get_normalised_n_step_rmses(dataset)

    def get_prediction(self, dataset, flag_n_step=False):
        """Return model predictions for ``dataset``.

        Args:
            dataset: Dataset forwarded to the model.
            flag_n_step: When true use ``get_autoregressive_predictions``,
                otherwise ``get_predictions``.
        """
        if flag_n_step:
            return self.model.get_autoregressive_predictions(dataset)
        return self.model.get_predictions(dataset)

    def get_trend_seasonality(self, dataset):
        """Return ``model.get_trend_seasonality(dataset)``."""
        return self.model.get_trend_seasonality(dataset)

    def get_representation(self, dataset):
        """Return ``model.get_representations(dataset)``."""
        return self.model.get_representations(dataset)