# main.py
from vssm import *
from tests.data.data_loader import DataLoader


def main():
    """Run the VSSM demo once for every dataset configured below.

    Each spec holds exactly the keyword arguments ``run_model`` accepts:
    a display name, the id handed to the loader, the name of the target
    column, and the target value treated as the positive class.
    """
    tic_tac_toe = dict(
        name='TicTacToe: Classify whether x has won the game',
        uci_id=101,
        target_column='class',
        positive_class_value='positive',
    )
    mushroom = dict(
        name='Mushroom: Determine if mushrooms are poisonous or edible',
        uci_id=73,
        target_column='poisonous',
        positive_class_value='p',
    )

    # Order matters only for the console output: TicTacToe first, then Mushroom.
    for spec in (tic_tac_toe, mushroom):
        run_model(**spec)

def run_model(name, uci_id, target_column, positive_class_value):
    """Fit a VSSM model on one dataset and print what it reports.

    Args:
        name: Human-readable dataset description, used only in the
            progress message.
        uci_id: Dataset identifier forwarded to ``DataLoader.load_uci``.
        target_column: Name of the label column, forwarded to the loader.
        positive_class_value: Label value to treat as the positive class,
            forwarded to the loader.

    Returns:
        None. The progress message and the model's reports are written
        to stdout.
    """
    # The loader returns four values; going by the original variable names
    # they are the train/test features followed by the train/test labels
    # (NOTE(review): confirm against DataLoader.load_uci).
    splits = DataLoader().load_uci(
        uci_id=uci_id,
        target_column=target_column,
        positive_class_value=positive_class_value,
    )
    train_features, test_features, train_labels, test_labels = splits

    print(f'Training VSSM on the following dataset: {name}')
    classifier = VSSM(logging_enabled=False)
    classifier.fit(train_features, train_labels)

    # The return value of multi_predict is intentionally not captured here,
    # matching the existing behavior of this script.
    classifier.multi_predict(test_features, test_labels)

    htn_conditions = classifier.get_pyHTN_conds()
    print(htn_conditions)
    literal_priorities = classifier.get_lit_priorities()
    print(literal_priorities)


# Run the demo only when this file is executed as a script, so importing
# the module does not trigger dataset loading or model training.
if __name__ == "__main__":
    main()

