from counterfactual.import_essentials import *
from counterfactual.utils import *
from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.evaluate import *
from counterfactual.baseline import *


if __name__ == "__main__":
    # Train a TransCounterNet on every dataset listed in `configs`, using the
    # shared trainer settings from trainer.json. Each run logs through a
    # TestTubeLogger rooted at log/<dataset name>/transformer.

    # Datasets to train on. Uncomment entries to add them to the sweep; the
    # trailing commas keep the list valid whichever entries are enabled.
    configs = [
        {'name': 'adult', 'path': 'counterfactual/configs/adult.json'},
        # {'name': 'home', 'path': 'counterfactual/configs/home.json'},
        # {'name': 'student', 'path': 'counterfactual/configs/student.json'},
    ]

    # Values written into every dataset's model config before training,
    # replacing whatever the JSON file specifies for these keys.
    # NOTE(review): "lambda_3" is presumably a loss-term weight consumed by
    # TransCounterNet — confirm against its implementation.
    model_overrides = {"lr": 0.03, "lambda_3": 0.4}

    # `with` guarantees the config file handle is closed after parsing.
    with open("counterfactual/configs/trainer.json") as f:
        t_config = json.load(f)

    for config in configs:
        with open(config['path']) as f:
            m_config = json.load(f)
        # A fresh dict is loaded each iteration, so the update never leaks
        # between datasets.
        m_config.update(model_overrides)

        train(
            module=TransCounterNet(m_config),
            t_configs=t_config,
            # NOTE(review): TestTubeLogger was deprecated in PyTorch Lightning
            # 1.5 and removed in 1.7 — confirm the pinned Lightning version.
            logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{config['name']}/transformer"),
            # NOTE(review): "cnn" looks inconsistent with the transformer
            # module and logger name above — confirm the intended label.
            description="cnn"
        )

