from counterfactual.import_essentials import *
from counterfactual.utils import *
from counterfactual.train import *
from counterfactual.training_module import *
from counterfactual.net import *
from counterfactual.evaluate import *
from counterfactual.baseline import *


def load_json_config(path) -> dict:
    """Read and return the JSON document stored at *path*.

    The file is opened in a ``with`` block so the handle is closed even if
    parsing fails. (The previous ``json.load(open(...))`` leaked it.)
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def lambda_2_schedule(num_points: int = 20) -> list:
    """Return the ``lambda_2`` values swept by the trade-off experiment.

    The values are ``k / num_points`` for ``k = 1 .. num_points``. They are
    evenly spaced over (0, 1]. The default of 20 points gives
    0.05, 0.10, ..., 1.00, the same grid as the original
    ``lm * 0.01 * 5`` loop. A single division yields the correctly rounded
    float for each grid value, whereas chained multiplications can leave
    representation noise in the logged hyper-parameter.

    Raises:
        ValueError: if ``num_points`` is smaller than 1.
    """
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")
    return [k / num_points for k in range(1, num_points + 1)]


def run_tradeoff_sweep(dataset: str = "adult", max_epochs: int = 200,
                       num_points: int = 20) -> None:
    """Train one model per ``lambda_2`` value and log every run together.

    Args:
        dataset: Name of the model config under ``counterfactual/configs/``.
            Each run is logged under ``log/<dataset>/tradeoff``. Earlier
            revisions also referenced ``home`` and ``student`` configs;
            confirm those files exist before using them.
        max_epochs: Written into the trainer config as ``max_epochs``.
        num_points: Number of ``lambda_2`` values; see ``lambda_2_schedule``.
    """
    model_config = load_json_config(f"counterfactual/configs/{dataset}.json")
    trainer_config = load_json_config("counterfactual/configs/trainer.json")
    trainer_config['max_epochs'] = max_epochs
    trainer_config['deterministic'] = True

    for lambda_2 in lambda_2_schedule(num_points):
        # Pass a fresh copy per run instead of mutating one shared dict.
        # A module that stores a reference to its config then cannot see a
        # later iteration's lambda_2. Whether CounterfactualModel2Optimizers
        # keeps such a reference is not visible from this file.
        run_config = {**model_config, 'lambda_2': lambda_2}
        train(
            module=CounterfactualModel2Optimizers(run_config),
            t_configs=trainer_config,
            # NOTE(review): TestTubeLogger was deprecated and later removed
            # from PyTorch Lightning. Confirm the pinned PL version still
            # ships it.
            logger=pl_loggers.TestTubeLogger(Path('log/'), name=f"{dataset}/tradeoff"),
            description="cross entropy",
        )


if __name__ == "__main__":
    run_tradeoff_sweep()
