"""Ablation study: convergence of the generator, measured by comparability ratios, on the Adult dataset."""

import json
from pathlib import Path
from typing import Callable

import numpy as np
import pandas as pd

from dataset import fetch_dataset, Dataset, FeatIndex
from model import Synthesizer, NeuralNets, SenSeI, LCIFR, DRO
from eval import Evaluator


def _json_default(obj):
    """Convert NumPy values into JSON-native types for ``json.dump(default=...)``.

    ``json`` only calls this for objects it cannot serialize itself, so plain
    Python floats/lists are unaffected.
    """
    # NOTE(review): the return type of ``Synthesizer.fit`` is not visible here;
    # this covers the case where the ratios come back as NumPy arrays/scalars.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main(n_repeat=5, out_path="./ablation/adult_convergence.json", epochs=1000):
    """Fit the synthesizer repeatedly on Adult and save its comparability ratios.

    Each of the ``n_repeat`` runs trains a fresh ``Synthesizer`` on the Adult
    training features (``scale="num"``) with ``cond=True`` and the dataset's
    comparability data/function, then records the three ratios returned by
    ``Synthesizer.fit`` under the keys ``"cat"``, ``"num"`` and ``"sen"``.

    Args:
        n_repeat: Number of independent training runs.
        out_path: Destination JSON file; missing parent directories are created.
        epochs: Number of training epochs passed to each ``Synthesizer``.
    """
    dataset = fetch_dataset("adult")
    # Only the features are passed to ``fit``; the labels are not used here.
    train_X, _ = dataset.train_data(scale="num")
    train_comp_data = dataset.comp_data(batch_size=1024, train=True, scale="num")

    save = {"cat": [], "num": [], "sen": []}
    for _ in range(n_repeat):
        syn = Synthesizer(epochs=epochs)
        cat_ratio, num_ratio, sen_ratio = syn.fit(
            train_X, dataset.feat_idx, comp_data=train_comp_data, cond=True, comp_func=dataset.is_comparable
        )
        save["cat"].append(cat_ratio)
        save["num"].append(num_ratio)
        save["sen"].append(sen_ratio)

    out_file = Path(out_path)
    # Create the output directory up front; otherwise open() fails with
    # FileNotFoundError only after all the (expensive) training runs finished.
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with out_file.open("w", encoding="utf-8") as f:
        json.dump(save, f, default=_json_default)


if __name__ == "__main__":
    main()
