import logging
from FPVE_config import *
import utils.FPVE_utils as FPVE_utils
import FPVE


def main():
    """Build the model and data loaders from CLI arguments, then run ``FPVE.FPVE_fairness``.

    ``parser`` is not defined in this file; presumably it is supplied by
    ``from FPVE_config import *`` (TODO confirm).

    When ``args.FPVE_fitness_data_ratio > 0``, ``FPVE_utils.get_fairness_data_FPVE``
    is unpacked into six values. The sixth value is forwarded to ``FPVE.FPVE_fairness``
    as the ``FPVE_fitness_loader`` keyword. Otherwise five values are unpacked and
    that keyword is not passed at all, so the callee's own default applies.
    """
    args = parser.parse_args()
    FPVE_utils.set_logger(args)
    model = FPVE_utils.get_model(args)

    use_fitness_loader = args.FPVE_fitness_data_ratio > 0
    fairness_data = FPVE_utils.get_fairness_data_FPVE(args)

    # Unpack strictly (6 vs 5 values) so a mismatch with the ratio setting still fails loudly.
    if use_fitness_loader:
        (
            train_loader,
            valid_loader,
            test_loader,
            target_idx,
            sensitive_idx,
            fitness_loader,
        ) = fairness_data
        optional_kwargs = {"FPVE_fitness_loader": fitness_loader}
    else:
        (
            train_loader,
            valid_loader,
            test_loader,
            target_idx,
            sensitive_idx,
        ) = fairness_data
        # Omit the keyword entirely rather than passing None: the callee's default is not visible here.
        optional_kwargs = {}

    train_log = logging.getLogger("train_logger")
    train_log.info("=> Model : {}".format(model))

    train_log.info("START PRUNING:")
    pruner = FPVE.FPVE_fairness(
        model,
        train_loader,
        valid_loader,
        test_loader,
        args,
        target_idx,
        sensitive_idx,
        **optional_kwargs,
    )
    pruner.run(args.iterative_steps)


if __name__ == "__main__":
    # Run main() only when executed as a script, not when this module is imported.
    main()
