import sys
from omegaconf import OmegaConf
from copy import deepcopy

sys.path.append('./model/')
sys.path.append('./model/base')
sys.path.append('./model/MSM')
sys.path.append('./model/RMSN')
sys.path.append('./model/BR')
sys.path.append('./model/GNET')
sys.path.append('./model/GAN')

sys.path.append('./yaml')

from makeSimData import MakeSimData
from utils import numpy2torch, set_multi_label

# MSM
from MSMPropensity import MSMPropensityTreatment, MSMPropensityHistory
from MSMRegressor import MSMRegressor

# RMSN
from RMSNPropensityNetwork import RMSNPropensityNetworkTreatment, RMSNPropensityNetworkHistory
from RMSNEncoderDecoder    import RMSNEncoder, RMSNDecoder

# CRN
from CRN import CRNEncoder, CRNDecoder

# GNET
from GNet import GNet

# CT, EDCT
from CT import CT
from EDCT import EDCTEncoder, EDCTDecoder

# EDTS
from EDTS import EDTSEncoder, EDTSDecoder

# -----------------------------------------------------------------
# MSM
# -----------------------------------------------------------------
def run_msm(coeff, seed):
    """Run the MSM baseline on the cancer simulation and collect test RMSEs.

    The config is ``./yaml/msm.yaml`` merged with ``./yaml/cancer_sim.yaml``.
    The simulation coefficient and seed are then overridden. Data is simulated,
    the treatment and history propensity models are fitted, and the MSM
    regressor is fitted on top of them.

    Args:
        coeff: value written to ``config.dataset.coeff`` before simulating.
        seed: value written to ``config.exp.seed``.

    Returns:
        dict with two entries:
          * ``"valid_one"``: always ``0.0``; no validation score is computed here.
          * ``"test_multi"``: a new dict ordered by ascending key. Key ``1``
            holds the one-step factual RMSE on the ``"test_cf"`` split. The
            other entries come from ``get_multi_step_counterfactual_rmse()``
            (presumably keyed by prediction horizon; verify in MSMRegressor).
    """
    config = OmegaConf.merge(
        OmegaConf.load('./yaml/msm.yaml'),
        OmegaConf.load('./yaml/cancer_sim.yaml'),
    )
    config.dataset.coeff = coeff
    config.exp.seed = seed

    # make() yields (processed, raw); only the processed dataset is used here.
    dataset, _ = MakeSimData(config).make()

    treatment_model = MSMPropensityTreatment(config, dataset)
    treatment_model.fit()

    history_model = MSMPropensityHistory(config, dataset)
    history_model.fit()

    regressor = MSMRegressor(config, treatment_model, history_model, dataset)
    regressor.fit()

    one_step = regressor.get_one_step_factual_rmse("test_cf")
    per_horizon = regressor.get_multi_step_counterfactual_rmse()

    # The one-step factual score is folded into the multi-step table under key 1
    # (this mutates the dict returned by the regressor in place).
    per_horizon[1] = one_step

    return {
        "valid_one": 0.0,
        "test_multi": dict(sorted(per_horizon.items())),
    }

# -----------------------------------------------------------------
# RMSN
# -----------------------------------------------------------------
def run_rmsn(gpu, no, coeff, seed):
    """Train and evaluate the RMSN propensity networks, encoder and decoder.

    The config is three YAML files merged in order:
      1. ``./yaml/rmsn_base.yaml``
      2. ``./yaml/cancer_sim.yaml``
      3. ``./yaml/RMSN/rmsn_hype_<no>.yaml``, with ``no`` formatted via ``:02``

    Args:
        gpu: written to ``config.exp.gpu``.
        no: index selecting the hyperparameter YAML file.
        coeff: written to ``config.dataset.coeff``.
        seed: written to ``config.exp.seed``.

    Returns:
        ``(encoder_rmse, decoder_rmse, encoder_loss, decoder_loss)``, where:
          * ``encoder_rmse`` has keys ``"valid_one"`` and ``"test_multi"``.
          * ``decoder_rmse`` has keys ``"valid_multi"`` and ``"test_multi"``.
          * the losses are whatever the respective ``train_RMSN()`` calls return.
    """
    config = OmegaConf.merge(
        OmegaConf.load('./yaml/rmsn_base.yaml'),
        OmegaConf.load('./yaml/cancer_sim.yaml'),
        OmegaConf.load(f'./yaml/RMSN/rmsn_hype_{no:02}.yaml'),
    )
    config.exp.gpu = gpu
    config.dataset.coeff = coeff
    config.exp.seed = seed

    # Only the processed dataset is used. RMSN additionally needs the
    # multi-step labels attached before conversion to torch.
    dataset, _ = MakeSimData(config).make()
    tensors = numpy2torch(set_multi_label(dataset))

    propensity_treatment = RMSNPropensityNetworkTreatment(config, tensors)
    propensity_treatment.train_RMSN()

    propensity_history = RMSNPropensityNetworkHistory(config, tensors)
    propensity_history.train_RMSN()

    encoder = RMSNEncoder(config, propensity_treatment, propensity_history, tensors)
    encoder_loss = encoder.train_RMSN()
    # Dict values are evaluated left to right: validation first, then test.
    encoder_rmse = {
        "valid_one": encoder.validation_step(),
        "test_multi": encoder.test_step(),
    }

    decoder = RMSNDecoder(config, encoder, tensors)
    decoder_loss = decoder.train_RMSN()
    decoder_rmse = {
        "valid_multi": decoder.validation_step(),
        "test_multi": decoder.test_step(),
    }

    return encoder_rmse, decoder_rmse, encoder_loss, decoder_loss

# -----------------------------------------------------------------
# CRN
# -----------------------------------------------------------------
def run_crn(gpu, no, coeff, seed):
    """Train and evaluate the CRN encoder/decoder on the cancer simulation.

    The config is ``./yaml/crn_base.yaml``, ``./yaml/cancer_sim.yaml`` and
    ``./yaml/CRN/crn_hype_<no>.yaml`` (``no`` formatted via ``:02``), merged
    in that order.

    Args:
        gpu: written to ``config.exp.gpu``.
        no: index selecting the hyperparameter YAML file.
        coeff: written to ``config.dataset.coeff``.
        seed: written to ``config.exp.seed``.

    Returns:
        ``(encoder_rmse, decoder_rmse, encoder_loss, decoder_loss)``, unpacked
        from the encoder's and decoder's ``train_BR()`` results.
    """
    config = OmegaConf.merge(
        OmegaConf.load('./yaml/crn_base.yaml'),
        OmegaConf.load('./yaml/cancer_sim.yaml'),
        OmegaConf.load(f'./yaml/CRN/crn_hype_{no:02}.yaml'),
    )
    config.exp.gpu = gpu
    config.dataset.coeff = coeff
    config.exp.seed = seed

    dataset, _ = MakeSimData(config).make()
    tensors = numpy2torch(dataset)

    # The encoder gets its own deep copy of the tensors (presumably so that any
    # in-place changes made during training stay local to it).
    encoder = CRNEncoder(config, deepcopy(tensors))
    encoder_rmse, encoder_loss = encoder.train_BR()

    # The decoder is built from the trained encoder, not from the dataset.
    decoder = CRNDecoder(config, encoder)
    decoder_rmse, decoder_loss = decoder.train_BR()

    return encoder_rmse, decoder_rmse, encoder_loss, decoder_loss

# -----------------------------------------------------------------
# GNET
# -----------------------------------------------------------------
def run_gnet(gpu, no, coeff, seed):
    """Train G-Net on the cancer simulation and evaluate it.

    The config is ``./yaml/gnet_base.yaml``, ``./yaml/cancer_sim.yaml`` and
    ``./yaml/GNET/gnet_hype_<no>.yaml`` (``no`` formatted via ``:02``), merged
    in that order.

    Args:
        gpu: written to ``config.exp.gpu``.
        no: index selecting the hyperparameter YAML file.
        coeff: written to ``config.dataset.coeff``.
        seed: written to ``config.exp.seed``.

    Returns:
        ``(encoder_rmse, encoder_loss)``, where:
          * ``encoder_rmse["valid_one"]`` is the result of ``validation_step()``.
          * ``encoder_rmse["test_multi"]`` is the result of
            ``get_multi_step_counterfactual_gnet()``.
          * ``encoder_loss`` is what ``train_GNet()`` returns.
    """
    config = OmegaConf.merge(
        OmegaConf.load('./yaml/gnet_base.yaml'),
        OmegaConf.load('./yaml/cancer_sim.yaml'),
        OmegaConf.load(f'./yaml/GNET/gnet_hype_{no:02}.yaml'),
    )
    config.exp.gpu = gpu
    config.dataset.coeff = coeff
    config.exp.seed = seed

    dataset, _ = MakeSimData(config).make()
    gnet = GNet(config, numpy2torch(dataset))
    encoder_loss = gnet.train_GNet()

    # Order matters: validation runs before set_resid(), and set_resid() runs
    # before the multi-step counterfactual evaluation (presumably that
    # evaluation relies on the residuals being set; confirm in GNet).
    valid_rmse = gnet.validation_step()
    gnet.set_resid()
    test_rmse = gnet.get_multi_step_counterfactual_gnet()

    return {"valid_one": valid_rmse, "test_multi": test_rmse}, encoder_loss

# -----------------------------------------------------------------
# CT
# -----------------------------------------------------------------
def run_ct(gpu, no, coeff, seed):
    """Train and evaluate the CT model on the cancer simulation.

    The config is ``./yaml/ct_base.yaml``, ``./yaml/cancer_sim.yaml`` and
    ``./yaml/CT/ct_hype_<no>.yaml`` (``no`` formatted via ``:02``), merged in
    that order.

    Args:
        gpu: written to ``config.exp.gpu``.
        no: index selecting the hyperparameter YAML file.
        coeff: written to ``config.dataset.coeff``.
        seed: written to ``config.exp.seed``.

    Returns:
        ``(encoder_rmse, encoder_loss)``, unpacked from ``CT.train_BR()``.
    """
    config = OmegaConf.merge(
        OmegaConf.load('./yaml/ct_base.yaml'),
        OmegaConf.load('./yaml/cancer_sim.yaml'),
        OmegaConf.load(f'./yaml/CT/ct_hype_{no:02}.yaml'),
    )
    config.exp.gpu = gpu
    config.dataset.coeff = coeff
    config.exp.seed = seed

    dataset, _ = MakeSimData(config).make()
    tensors = numpy2torch(dataset)

    # CT is a single model here (no separate decoder); it trains on a deep copy.
    model = CT(config, deepcopy(tensors))
    rmse, loss = model.train_BR()

    return rmse, loss

# -----------------------------------------------------------------
# EDCT
# -----------------------------------------------------------------
def run_edct(gpu, no, coeff, seed):
    """Train and evaluate the EDCT encoder/decoder on the cancer simulation.

    The config is ``./yaml/edct_base.yaml``, ``./yaml/cancer_sim.yaml`` and
    ``./yaml/EDCT/EDCT_decoder_<no>.yaml`` (``no`` formatted via ``:02``),
    merged in that order.

    Args:
        gpu: written to ``config.exp.gpu``.
        no: index selecting the hyperparameter YAML file.
        coeff: written to ``config.dataset.coeff``.
        seed: written to ``config.exp.seed``.

    Returns:
        ``(encoder_rmse, decoder_rmse, encoder_loss, decoder_loss)``, unpacked
        from the encoder's and decoder's ``train_BR()`` results.
    """
    config = OmegaConf.merge(
        OmegaConf.load('./yaml/edct_base.yaml'),
        OmegaConf.load('./yaml/cancer_sim.yaml'),
        OmegaConf.load(f'./yaml/EDCT/EDCT_decoder_{no:02}.yaml'),
    )
    config.exp.gpu = gpu
    config.dataset.coeff = coeff
    config.exp.seed = seed

    dataset, _ = MakeSimData(config).make()
    tensors = numpy2torch(dataset)

    # The encoder trains on its own deep copy of the tensors.
    encoder = EDCTEncoder(config, deepcopy(tensors))
    encoder_rmse, encoder_loss = encoder.train_BR()

    # The decoder is constructed from the trained encoder.
    decoder = EDCTDecoder(config, encoder)
    decoder_rmse, decoder_loss = decoder.train_BR()

    return encoder_rmse, decoder_rmse, encoder_loss, decoder_loss

# -----------------------------------------------------------------
# EDTS
# -----------------------------------------------------------------
def run_edts(gpu, no, coeff, seed):
    """Train and evaluate the EDTS encoder/decoder on the cancer simulation.

    The config is ``./yaml/edts_base.yaml``, ``./yaml/cancer_sim.yaml`` and
    ``./yaml/EDTS/EDTS_decoder_<no>.yaml`` (``no`` formatted via ``:02``),
    merged in that order.

    Args:
        gpu: written to ``config.exp.gpu``.
        no: index selecting the hyperparameter YAML file.
        coeff: written to ``config.dataset.coeff``.
        seed: written to ``config.exp.seed``.

    Returns:
        ``(encoder_rmse, decoder_rmse, encoder_loss, decoder_loss)``, unpacked
        from the encoder's and decoder's ``train_GAN()`` results.
    """
    config = OmegaConf.merge(
        OmegaConf.load('./yaml/edts_base.yaml'),
        OmegaConf.load('./yaml/cancer_sim.yaml'),
        OmegaConf.load(f'./yaml/EDTS/EDTS_decoder_{no:02}.yaml'),
    )
    config.exp.gpu = gpu
    config.dataset.coeff = coeff
    config.exp.seed = seed

    dataset, _ = MakeSimData(config).make()
    tensors = numpy2torch(dataset)

    # The encoder trains on its own deep copy of the tensors.
    encoder = EDTSEncoder(config, deepcopy(tensors))
    encoder_rmse, encoder_loss = encoder.train_GAN()

    # The decoder is constructed from the trained encoder.
    decoder = EDTSDecoder(config, encoder)
    decoder_rmse, decoder_loss = decoder.train_GAN()

    return encoder_rmse, decoder_rmse, encoder_loss, decoder_loss
      
