import copy

from torch import nn

class Config:
    """Experiment settings plus the list of networks used by the model.

    Class attributes hold the default training hyperparameters (``config``)
    and experiment-level settings. Each instance gets its own deep copy of
    ``config`` and a freshly constructed list of ``nn.Sequential`` networks.
    """

    # Default training hyperparameters. This dict is the shared class-level
    # default; ``__init__`` gives every instance an independent deep copy so
    # that mutating ``instance.config`` does not leak into other instances or
    # into ``Config()`` objects created later.
    config = {
        "train_params": {
            "distance_dim": 1,
            "lam1": 0.1,
            "lam2": 0.1,
            "lam3": 0.1,
            "stage1_iter": 1,
            "mi_iter": 1,
            "stage2_iter": 1,
            "treatment_weight_decay": 0.0,
            "instrumental_weight_decay": 0.0,
            "covariate_weight_decay": 0.1,
            # NOTE(review): "s1_weight_decay" and "S1_weight_decay" differ
            # only in case; both are kept as-is -- confirm which one the
            # training code actually reads.
            "s1_weight_decay": 0.0,
            "odds_weight_decay": 0.0,
            "selection_weight_decay": 0.0,
            "r0_weight_decay": 0.0,
            "r1_weight_decay": 0.0,
            "S1_weight_decay": 0.0,
            "S0_weight_decay": 0.0,
            "y_weight_decay": 0.0,
            "y1_weight_decay": 0.0,
            "lam_y": 0.1,
            "stage1_S1_iter": 20,
            "covariate_iter": 20,
            "odds_iter": 100,
            "n_epoch": 100,
            "epoch": 20,
            "z_dim": 10,
            "z_ratio": 0.0,
            "lam4": 0.1,
        }
    }

    # NOTE(review): the original note "1.0 15;" was left here unexplained --
    # possibly an earlier c_strength / experiment setting; confirm or remove.
    experiment_num = 10
    c_strength = 1.0
    u_strength = 1.0
    sample_num = 5000

    def __init__(self):
        """Copy the default config for this instance and build the networks."""
        # Deep copy because ``config`` is a nested dict: a shallow copy would
        # still alias the inner "train_params" dict shared with the class.
        self.config = copy.deepcopy(self.config)
        self.networks = self.initialize_model_structure()

    def initialize_model_structure(self) -> list:
        """Build and return a fresh list of the model's sub-networks.

        The list is positional; index -> network (input width of first layer):

        ====  ==============  =====  ==========================
        idx   name            in     head
        ====  ==============  =====  ==========================
        0     treatment_net   1      linear
        1     instrumental    3      linear
        2     selection_net   3      sigmoid
        3     covariate_net   2      linear (hidden 8, ReLU)
        4     phis_net        10     sigmoid (hidden 4, ReLU)
        5     phit_net        10     sigmoid (hidden 4, ReLU)
        6     odds_net        3      linear
        7     S_net           5      sigmoid
        8     s_net           3      sigmoid
        9     y_net           3      linear
        10    y1_net          3      linear
        11    h2_net          2      linear (hidden 4, ReLU)
        ====  ==============  =====  ==========================

        Every network has a single output unit. Networks ending in
        ``nn.Sigmoid`` produce outputs in (0, 1).

        Returns:
            A new list of ``nn.Sequential`` modules. Callers presumably index
            it by position, so the order must not change.
        """
        networks = [
            # treatment_net 0
            nn.Sequential(nn.Linear(1, 1)),
            # instrumental_net 1
            nn.Sequential(nn.Linear(3, 1)),
            # selection_net 2
            nn.Sequential(
                nn.Linear(3, 1),
                nn.Sigmoid(),
            ),
            # covariate_net 3
            nn.Sequential(
                nn.Linear(2, 8),
                nn.ReLU(),
                nn.Linear(8, 1),
            ),
            # phis_net 4
            nn.Sequential(
                nn.Linear(10, 4),
                nn.ReLU(),
                nn.Linear(4, 1),
                nn.Sigmoid(),
            ),
            # phit_net 5
            nn.Sequential(
                nn.Linear(10, 4),
                nn.ReLU(),
                nn.Linear(4, 1),
                nn.Sigmoid(),
            ),
            # odds_net 6
            nn.Sequential(nn.Linear(3, 1)),
            # S_net 7
            nn.Sequential(
                nn.Linear(5, 1),
                nn.Sigmoid(),
            ),
            # s_net 8
            nn.Sequential(
                nn.Linear(3, 1),
                nn.Sigmoid(),
            ),
            # y_net 9
            nn.Sequential(nn.Linear(3, 1)),
            # y1_net 10
            nn.Sequential(nn.Linear(3, 1)),
            # h2_net 11
            nn.Sequential(
                nn.Linear(2, 4),
                nn.ReLU(),
                nn.Linear(4, 1),
            ),
        ]
        return networks