import os

import hydra
import numpy as np
from omegaconf import DictConfig

from constants.paths import _results_dir
from experiments.tabular_params import get_alphas, get_flip_set, get_n_retries
from get_dataset import get_tabular_dataset
from strategies import Random_Choice
from utils_training import collect_metrics, create_collective, train_tabular


@hydra.main(version_base=None, config_path="configs")
def main(cfg: DictConfig) -> None:
    """Run the ``Random_Choice`` sweep on one tabular dataset and save metrics.

    For every ``alpha`` from ``get_alphas()``, every flip count that does not
    exceed ``alpha * count_nonzero(in_group)``, and every seed in
    ``range(get_n_retries())``, a collective is created, ``train_tabular`` is
    called, and the returned train/test metrics are stored. The collected
    results are written to an ``.npz`` file below ``_results_dir``.

    Args:
        cfg: Hydra configuration. ``cfg.dataset`` names the dataset passed to
            ``get_tabular_dataset``. The optional ``cfg.use_hist_gbm`` flag is
            forwarded to ``train_tabular`` and defaults to ``True`` when the
            key is absent; it also selects the output sub-directory.

    Raises:
        RuntimeError: If not a single training run was executed, because the
            output file name is derived from the strategy object of the runs.
    """
    ds_name = cfg.dataset
    train_set, test_set = get_tabular_dataset(ds_name)
    features, in_group, labels = train_set
    features_test, in_group_test, labels_test = test_set

    use_hist_gbm = cfg.use_hist_gbm if "use_hist_gbm" in cfg else True

    num_retries = get_n_retries()
    alphas = get_alphas()
    nums_to_flip = np.concatenate([get_flip_set(i) for i in range(3)])

    # Axes: (alpha, flip count, seed). Combinations that are skipped below
    # keep the placeholder value of the freshly allocated object array.
    res_shape = (len(alphas), len(nums_to_flip), num_retries)
    result_randlip = {
        "train": np.empty(res_shape, dtype=object),
        "test": np.empty(res_shape, dtype=object),
    }

    # Loop-invariant: the group size does not depend on alpha.
    group_size = np.count_nonzero(in_group)
    action = None
    for a_i, alpha in enumerate(alphas):
        print(alpha)
        collective_size = alpha * group_size
        eligible = nums_to_flip <= collective_size
        # Enumerate the full array so r_i is the position of num2flip in
        # nums_to_flip. That is the axis the result arrays are sized by and
        # the array saved as "flip_nums"; enumerating a filtered copy would
        # misalign the two whenever nums_to_flip is not sorted ascending.
        for r_i, num2flip in enumerate(nums_to_flip):
            if not eligible[r_i]:
                continue
            for seed in range(num_retries):
                action = Random_Choice(num_to_flip=num2flip)
                in_collective = create_collective(in_group, alpha, seed)
                metrics_train, metrics_test = train_tabular(
                    features,
                    labels,
                    in_group,
                    features_test,
                    labels_test,
                    in_group_test,
                    in_collective=in_collective,
                    strategy=action,
                    seed=seed,
                    use_hist_gbm=use_hist_gbm,
                )
                result_randlip["train"][a_i, r_i, seed] = metrics_train
                result_randlip["test"][a_i, r_i, seed] = metrics_test

    if action is None:
        # Fail with a clear message instead of an UnboundLocalError further
        # down, where the strategy name is needed for the file name.
        raise RuntimeError(
            "No experiment was run; check alphas, flip counts and the "
            "number of retries."
        )

    result_randlip["train"] = collect_metrics(result_randlip["train"])
    result_randlip["test"] = collect_metrics(result_randlip["test"])

    # Runs with use_hist_gbm disabled are kept in a separate sub-directory.
    if use_hist_gbm:
        save_dir = os.path.join(_results_dir, ds_name)
    else:
        save_dir = os.path.join(_results_dir, "non_hist_gbm", ds_name)
    os.makedirs(save_dir, exist_ok=True)

    # The name comes from the strategy instance of the last executed run.
    act_name = str(action)
    save_file = f"{act_name}_{ds_name}.npz"
    np.savez(
        os.path.join(save_dir, save_file),
        results=result_randlip,
        act_name=act_name,
        ds_name=ds_name,
        alphas=alphas,
        flip_nums=nums_to_flip,
        num_retries=num_retries,
    )


# Script entry point. main is wrapped by @hydra.main, so it is invoked here
# without arguments even though its signature takes a cfg parameter.
if __name__ == "__main__":
    main()
