from argparser import get_args
from arl import ARL
from baseline import baseline
from dataloader import loadDataset
from metrics import FairnessMetrics
from VFair import VFair
import torch


def test(args, path):
    """Evaluate one saved checkpoint on its dataset's test split and print the metrics.

    The dataset and model type are parsed from ``path``. The path is expected to
    look like ``./final/<dataset>/<model>/checkpoints/<file>.pt``: after splitting
    on '/', component 2 is the dataset name and component 3 is the model name.

    Args:
        args: Parsed command-line arguments (from ``get_args()``). Mutated in place:
            ``args.dataset`` is overwritten with the dataset parsed from ``path``
            and ``args.lr_adversary`` is set to 1. ``args.embedding_size`` is read.
        path: Path to a checkpoint file holding a state dict for the model.

    Prints:
        One line with accuracy, worst, diff, sum and var metrics for iteration 0.

    Raises:
        ValueError: If the model name parsed from ``path`` is not "ARL",
            "baseline" or "VFair".
    """
    elements = path.split('/')
    args.dataset = elements[2]
    model_name = elements[3]
    # NOTE(review): the adversary learning rate is forced to 1 during evaluation;
    # its effect is not visible in this file — confirm it is still required.
    args.lr_adversary = 1
    train_dataset = loadDataset(
        dataset=args.dataset,
        train_or_test="train",
        embedding_size=args.embedding_size,
    )
    test_dataset = loadDataset(dataset=args.dataset, train_or_test="test")

    # In this function the training split is only used to recover the model's
    # input dimensions (categorical embedding sizes and the number of numeric
    # columns). These presumably have to match the ones the checkpoint was built with.
    model_params = {
        "learner_hidden_units": [64, 32],
        "embedding_size": train_dataset.categorical_embedding_sizes,
        "n_num_cols": len(train_dataset.mean_std.keys()),
    }

    if model_name == "ARL":
        model = ARL(embedding_size=model_params["embedding_size"],
                    n_num_cols=model_params["n_num_cols"])
    elif model_name == "baseline":
        model = baseline(**model_params)
    elif model_name == "VFair":
        model = VFair(**model_params)
    else:
        # Fail loudly: continuing would only crash later on an unbound `model`.
        raise ValueError(f"Unknown model: {model_name!r}")

    model.load_state_dict(torch.load(path))
    # Inference mode: disables training-only behaviour such as dropout or
    # batch-norm statistic updates, if the model has any.
    model.eval()

    test_cat, test_num, test_target = test_dataset[:]
    with torch.no_grad():
        # The second output (sigmoid probabilities) is not needed for these metrics.
        test_logits, _, test_pred = model.learner(test_cat, test_num)

    metrics = FairnessMetrics(1, test_dataset.subgroup_indexes, test_dataset.subgroup_minority, eval_every=1)
    # Calculate accuracy metrics, all recorded at iteration index 0.
    # NOTE(review): only the logits are moved to CUDA (hard-coded device), so this
    # requires a GPU; confirm what device FairnessMetrics.set_var expects.
    metrics.set_var(test_logits.to('cuda'), test_target, 0)
    metrics.set_acc_other(test_pred, test_target, 0)
    metrics.set_acc(test_pred, test_target, 0)
    print(f"{metrics.acc[0][0]} {metrics.worst[0][0]} {metrics.diff[0][0]} {metrics.sum[0][0]} {metrics.var[0][0]}")


if __name__ == '__main__':
    args = get_args()
    # Evaluate the 10 saved checkpoints of every (dataset, method) combination,
    # printing a "<dataset> <method>" header before each group of results.
    for dataset_name in ['compas', 'uci_adult', 'law_school']:
        for method_name in ['baseline', 'ARL', 'VFair']:
            print(f"{dataset_name} {method_name}")
            checkpoint_dir = f'./final/{dataset_name}/{method_name}/checkpoints'
            for run_idx in range(10):
                test(args, f'{checkpoint_dir}/model_checkpoint{run_idx}.pt')
