# Parent config. Presumably loaded first, with the keys in this file layered
# on top of it — confirm against the config loader.
base_config: configs/base/base.yaml
# The three keys below are left null so they can be derived from data-level
# settings. The legacy Python LFADSConfig (commented out further down) assigned
# them as: encod_data_dim = input_dims, encod_seq_len = subseq_size,
# recon_seq_len = subseq_size. Presumably the loader fills them the same way — verify.
encod_data_dim: null # input_dims
encod_seq_len: null # subseq_size
recon_seq_len: null # subseq_size
# ext_input_dim is not derived from the data; it is set explicitly further down.
# ext_input_dim: null # ext_input_dim

# Encoder / posterior / dropout settings. Exact semantics live in the LFADS
# model code, which is not part of this file.
ic_post_var_min: 0.0001  # presumably a floor on the initial-condition posterior variance — confirm
ic_enc_seq_len: 0
cd_pass_rate: 0.0  # "cd" presumably = coordinated dropout; 0.0 likely leaves it off — confirm
cd_rate: 0.0

# Model dimensions. Abbreviations presumably follow LFADS naming (confirm in
# the model code): ic = initial condition, ci = controller input,
# con = controller, co = controller output, gen = generator, fac = factors.
ext_input_dim: 0  # 0 presumably means no external inputs — confirm
ic_enc_dim: 64
ci_enc_dim: 64
ci_lag: 2
con_dim: 64
co_dim: 64
ic_dim: 64
gen_dim: 64
fac_dim: 128

# Training / loss settings.
dropout_rate: 0.0
cell_clip: 1.0  # presumably a clipping bound on RNN cell state — confirm
loss_scale: 0.1
recon_reduce_mean: true
variational: true

# Regularization schedule. Each *_start_epoch / *_increase_epoch pair presumably
# ramps the matching penalty weight up starting at start_epoch over
# increase_epoch epochs — confirm in the training loop.
weight_decay: 0.001
l2_start_epoch: 0
l2_increase_epoch: 10
l2_ic_enc_scale: 0.0
l2_ci_enc_scale: 0.0
l2_gen_scale: 0.0
l2_con_scale: 0.0
kl_start_epoch: 0
kl_increase_epoch: 10
# Keep the "1.0e-5" spelling. YAML 1.1 loaders such as PyYAML resolve
# exponent notation as a float only when it has a dot and a signed exponent.
# Plain "1e-5" would load as the string "1e-5".
kl_ic_scale: 1.0e-5
kl_co_scale: 1.0e-5

# Alternative regularization settings (disabled), kept for reference:
# weight_decay: 0.01
# l2_start_epoch: 0
# l2_increase_epoch: 1
# l2_ic_enc_scale: 0.0
# l2_ci_enc_scale: 0.0
# l2_gen_scale: 0.0
# l2_con_scale: 0.0
# kl_start_epoch: 0
# kl_increase_epoch: 1
# kl_ic_scale: 1.0e-6
# kl_co_scale: 1.0e-6

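# NOTE: the legacy LFADSConfig below also took an `encode_mode` argument
# (default "ts_encoder", stored as self.encode_mode). No encode_mode key is
# set among the YAML keys above.
#                 encode_mode = "ts_encoder",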

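# Legacy Python definition of this config, kept for reference. Its keyword
# parameters correspond to the YAML keys above. The parameter list below has
# been partially elided; the full set is visible in the self.* assignments.
# from experiments.configs.base_config import BaseConfig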

# class LFADSConfig(BaseConfig):
#     def __init__(self, 
#                 #  input_dims = 6,
#                 # subseq_size = 128,

                
#                 

#                 **kwargs
#                  ):
#         super().__init__(model_type="lfads", **kwargs)

        


#         self.encod_data_dim = self.input_dims
#         self.encod_seq_len = self.subseq_size
#         self.recon_seq_len = self.subseq_size
#         self.ext_input_dim = ext_input_dim
#         self.ic_enc_seq_len = ic_enc_seq_len

#         self.ic_post_var_min = ic_post_var_min
#         self.cd_pass_rate = cd_pass_rate
#         self.cd_rate = cd_rate

#         self.ic_enc_dim = ic_enc_dim
#         self.ci_enc_dim = ci_enc_dim
#         self.ci_lag = ci_lag
#         self.con_dim = con_dim
#         self.co_dim = co_dim
#         self.ic_dim = ic_dim
#         self.gen_dim = gen_dim
#         self.fac_dim = fac_dim


#         self.dropout_rate = dropout_rate
#         self.cell_clip = cell_clip
#         self.loss_scale = loss_scale
#         self.recon_reduce_mean = recon_reduce_mean

#         self.weight_decay = weight_decay
#         self.l2_start_epoch = l2_start_epoch
#         self.l2_increase_epoch = l2_increase_epoch
#         self.l2_ic_enc_scale = l2_ic_enc_scale
#         self.l2_ci_enc_scale = l2_ci_enc_scale
#         self.l2_gen_scale = l2_gen_scale
#         self.l2_con_scale = l2_con_scale
#         self.kl_start_epoch = kl_start_epoch
#         self.kl_increase_epoch = kl_increase_epoch
#         self.kl_ic_scale = kl_ic_scale
#         self.kl_co_scale = kl_co_scale

#         self.variational = variational
#         self.encode_mode = encode_mode


# Legacy per-dataset presets (har, ecg, ppg), kept for reference.
# lfads_configs = {}


# lfads_configs["lfads_har"] = LFADSConfig(data_name="har",
#                                         data_type="subseq", # fullts is much more expensive, but it works
#                                         device="cuda:0",
#                                         input_dims=6,
#                                         batch_size=128,
#                                         subseq_size=128,
#                                         epochs=10,
#                                         eval_every_n = 20,
#                                         p_train=1,
#                                         p_val=0.5
#                                         )


# lfads_configs["lfads_ecg"] = LFADSConfig(data_name="ecg",
#                                         data_type="subseq", # fullts is much more expensive, but it works
#                                         device="cuda:0",
#                                         input_dims=2,
#                                         batch_size=128,
#                                         eval_batch_size=128,
#                                         subseq_size=500,
#                                         epochs=100,
#                                         eval_every_n = 10,
#                                         downsample_factor=5,
#                                         p_train=0.2,
#                                         p_val=0.1
#                                         )


# lfads_configs["lfads_ppg"] = LFADSConfig(data_name="ppg",
#                                         data_type="subseq", # fullts is much more expensive, but it works
#                                         device="cuda:0",
#                                         input_dims=1,
#                                         batch_size=64,
#                                         eval_batch_size=64,
#                                         epochs=100,
#                                         eval_every_n = 10,
#                                         subseq_size=960,
#                                         downsample_factor=4,  # downsamples to 8 samples per second         
#                                         p_train=0.2,
#                                         p_val=0.1
#                                         )
