import os
import sys

import hydra
import numpy as np
from omegaconf import DictConfig
from sklearn.model_selection import train_test_split

import experiments.tabular_params as hyper_vals
from constants.paths import _results_dir, _temp_dir
from get_dataset import get_tabular_dataset
from strategies import Collective_Strategy
from utils_training import collect_metrics, create_collective, train_tabular


def get_temp_path(
    ds_name, act_name, alpha, flip_set, majority_samples=None, use_hist_gbm=True
):
    """Build the directory and file name of one intermediate result file.

    The directory has the layout
    ``_temp_dir[/non_hist_gbm]/<ds_name>[/majoritySamplesNNNN]``: the
    ``non_hist_gbm`` level is inserted only when ``use_hist_gbm`` is false and
    the ``majoritySamples`` level only when ``majority_samples`` is given.

    Returns:
        A ``(temp_dir, temp_file)`` tuple; the two parts are returned
        separately and are not joined here.
    """
    components = [_temp_dir]
    if not use_hist_gbm:
        components.append("non_hist_gbm")
    components.append(ds_name)
    if majority_samples is not None:
        # Zero-padded to four digits, e.g. 10 -> "majoritySamples0010".
        components.append(f"majoritySamples{majority_samples:04d}")
    temp_dir = os.path.join(*components)

    temp_file = f"{act_name}_{ds_name}_alpha{alpha:.3f}_flips{flip_set}.npz"
    return temp_dir, temp_file


def collect_all(ds_name=None, act_name=None, use_hist_gbm=True):
    """Aggregate results for every dataset / action / transform combination.

    Args:
        ds_name: Restrict the run to this dataset; ``None`` selects the
            built-in dataset list.
        act_name: Restrict the run to this base action; ``None`` selects the
            built-in action list.
        use_hist_gbm: Forwarded to ``collect_results``.

    Combinations whose temp files are missing are reported on stderr and
    skipped.
    """
    # The "kdp" action, the "lfr" transform and the plain "compas", "adult"
    # and "hsls" datasets are currently disabled.
    default_actions = ["ranked_labels", "ranked_distance", "ranked_proba"]
    default_datasets = [
        "acs_income",
        "tabular_waterbirds",
        "tabular_waterbirds-full",
        "imported_adult",
        "imported_compas",
        "imported_hsls",
    ]
    transforms = ["raw", "fare"]

    actions = default_actions if act_name is None else [act_name]
    datasets = default_datasets if ds_name is None else [ds_name]

    # Full action names have the form "<base action>-<transform>".
    action_variants = [
        f"{action}-{transform}" for action in actions for transform in transforms
    ]

    for dataset in datasets:
        for variant in action_variants:
            try:
                collect_results(dataset, variant, use_hist_gbm=use_hist_gbm)
            except FileNotFoundError as missing:
                print(missing, file=sys.stderr)


def collect_partial():
    """Aggregate results of the "-partial" actions for each majority-sample count.

    Every dataset / base action / transform combination is collected once per
    entry of the majority-sample list.  Combinations whose temp files are
    missing are reported on stderr and skipped.
    """
    # The plain "adult", "compas" and "hsls" datasets are currently disabled.
    datasets = [
        "acs_income",
        "tabular_waterbirds",
        "tabular_waterbirds-full",
        "imported_adult",
        "imported_compas",
        "imported_hsls",
    ]
    base_actions = ["ranked_labels", "ranked_proba", "ranked_distance"]
    transforms = ["raw", "fare"]
    majority_sample_counts = [10, 100]

    # Full action names have the form "<base action>-<transform>-partial".
    partial_actions = [
        f"{action}-{transform}-partial"
        for action in base_actions
        for transform in transforms
    ]

    for dataset in datasets:
        for partial_action in partial_actions:
            for sample_count in majority_sample_counts:
                try:
                    collect_results(dataset, partial_action, sample_count)
                except FileNotFoundError as missing:
                    print(missing, file=sys.stderr)


def collect_results(ds_name, act_name, majority_samples=None, use_hist_gbm=True):
    """Merge the per-(alpha, flip set) temp files of one experiment into one file.

    For every alpha and every flip set the matching temp file (located with
    ``get_temp_path``) is loaded and its per-flip-count, per-seed metric
    objects are copied into arrays indexed ``[alpha, flip count, seed]``.
    Both arrays are then passed through ``collect_metrics`` and saved as
    ``<act_name>_<ds_name>.npz`` under ``_results_dir``.

    Args:
        ds_name: Dataset name, used in the temp and result paths.
        act_name: Full action name, used in the temp and result file names.
        majority_samples: When given, a fixed short alpha list is used instead
            of ``hyper_vals.get_alphas()`` and a ``majoritySamplesNNNN``
            sub-directory is used for the output.
        use_hist_gbm: When false, results go below a ``non_hist_gbm``
            sub-directory.

    Raises:
        FileNotFoundError: Propagated from ``np.load`` when a temp file is
            missing; ``collect_all`` and ``collect_partial`` catch it.
    """
    if majority_samples is None:
        alphas = hyper_vals.get_alphas()
    else:
        alphas = [0.2, 0.25, 0.3]
    flip_sets = [0, 1, 2]

    num_retries = hyper_vals.get_n_retries()

    # Flip counts expected from the configuration; the ones actually found in
    # the temp files are gathered in ``flip_nums`` and compared afterwards.
    true_flip_nums = np.concatenate([hyper_vals.get_flip_set(i) for i in flip_sets])
    flip_nums = np.zeros_like(true_flip_nums)
    res_shape = (
        len(alphas),
        len(flip_nums),
        num_retries,
    )
    cur_results = {
        "train": np.empty(res_shape, dtype=object),
        "test": np.empty(res_shape, dtype=object),
    }

    best_ks = np.zeros([len(alphas), len(flip_nums)])
    for a_i, alpha in enumerate(alphas):
        # Running index over the concatenated flip counts of all flip sets.
        f_i = 0
        for flipset in flip_sets:
            temp_dir, temp_file = get_temp_path(
                ds_name,
                act_name,
                alpha,
                flipset,
                majority_samples=majority_samples,
                use_hist_gbm=use_hist_gbm,
            )
            # NOTE: allow_pickle=True unpickles arbitrary objects, so only load
            # temp files produced by this project.  The context manager closes
            # the underlying archive as soon as the needed entries are read.
            with np.load(
                os.path.join(temp_dir, temp_file), allow_pickle=True
            ) as loaded:
                cur_flips = loaded["nums_to_flip"]
                cur_best_ks = loaded["best_ks"]
                # "results" is a 0-d object array wrapping the saved dict.
                cur_res = loaded["results"].item()

            for r_i, n_flips in enumerate(cur_flips):
                best_ks[a_i, f_i] = cur_best_ks[r_i]
                flip_nums[f_i] = n_flips
                for s_i in range(num_retries):
                    cur_results["train"][a_i, f_i, s_i] = cur_res["train"][r_i, s_i]
                    cur_results["test"][a_i, f_i, s_i] = cur_res["test"][r_i, s_i]
                f_i += 1

    mismatch = flip_nums != true_flip_nums
    if np.any(mismatch):
        # All diagnostics go to stderr so the warning and its details stay
        # together.
        print("Mismatch in flip nums", file=sys.stderr)
        print(flip_nums[mismatch], file=sys.stderr)
        print(true_flip_nums[mismatch], file=sys.stderr)

    cur_results["train"] = collect_metrics(cur_results["train"])
    cur_results["test"] = collect_metrics(cur_results["test"])

    if use_hist_gbm:
        save_dir = os.path.join(_results_dir, ds_name)
    else:
        save_dir = os.path.join(_results_dir, "non_hist_gbm", ds_name)
    if majority_samples is not None:
        save_dir = os.path.join(save_dir, f"majoritySamples{majority_samples:04d}")
    os.makedirs(save_dir, exist_ok=True)
    np.savez(
        os.path.join(save_dir, f"{act_name}_{ds_name}.npz"),
        results=cur_results,
        act_name=act_name,
        ds_name=ds_name,
        alphas=alphas,
        flip_nums=flip_nums,
        num_retries=num_retries,
        best_ks=best_ks,
    )


def train_validation(
    ks,
    features_train_full,
    sensitive_train_full,
    labels_train_full,
    strategy: Collective_Strategy,
    alpha: float,
    n_seeds: int,
    majority_samples=None,
    use_hist_gbm=True,
):
    """Evaluate every candidate ``k`` on a held-out validation split.

    15% of the training data is split off (fixed ``random_state=0``) as the
    validation set.  For each ``k`` the strategy is configured through
    ``strategy.set_k`` and then trained and evaluated once per seed.

    Returns:
        ``(train_results, val_results)``: the outputs of ``collect_metrics``
        applied to the ``(len(ks), n_seeds)`` grids of per-run metric objects.

    Raises:
        AttributeError: If ``strategy`` has no ``set_k`` method.
    """
    split = train_test_split(
        features_train_full,
        sensitive_train_full,
        labels_train_full,
        test_size=0.15,
        random_state=0,
    )
    feats_fit, feats_val, sens_fit, sens_val, labels_fit, labels_val = split

    grid_shape = (len(ks), n_seeds)
    train_grid = np.empty(grid_shape, dtype=object)
    val_grid = np.empty(grid_shape, dtype=object)

    for k_idx, k in enumerate(ks):
        strategy.set_k(k)
        for seed in range(n_seeds):
            # The same seed drives both the collective selection and training.
            members = create_collective(
                sens_fit,
                alpha,
                majority_samples=majority_samples,
                seed=seed,
            )
            train_grid[k_idx, seed], val_grid[k_idx, seed] = train_tabular(
                feats_fit,
                labels_fit,
                sens_fit,
                feats_val,
                labels_val,
                sens_val,
                members,
                strategy=strategy,
                seed=seed,
                use_hist_gbm=use_hist_gbm,
            )

    return collect_metrics(train_grid), collect_metrics(val_grid)


def train_all(
    features_train,
    features_test,
    sensitive_train,
    sensitive_test,
    labels_train,
    labels_test,
    strategy,
    alpha,
    n_seeds,
    majority_samples=None,
    use_hist_gbm=True,
):
    """Train on the full training set and evaluate on the test set, once per seed.

    Returns:
        ``(train_results, test_results)``: two object arrays of length
        ``n_seeds`` holding the raw per-seed outputs of ``train_tabular``
        (they are not passed through ``collect_metrics`` here).
    """
    per_seed_train = np.empty([n_seeds], dtype=object)
    per_seed_test = np.empty([n_seeds], dtype=object)

    for seed in range(n_seeds):
        # The same seed drives both the collective selection and training.
        members = create_collective(
            sensitive_train,
            alpha,
            majority_samples=majority_samples,
            seed=seed,
        )
        per_seed_train[seed], per_seed_test[seed] = train_tabular(
            features_train,
            labels_train,
            sensitive_train,
            features_test,
            labels_test,
            sensitive_test,
            members,
            strategy=strategy,
            seed=seed,
            use_hist_gbm=use_hist_gbm,
        )

    return per_seed_train, per_seed_test


def find_best_k(
    val_results,
    ks,
    fairness_fields=("Equal odds difference", "Statistical Parity"),
):
    """Pick the candidate ``k`` with the best combined rank over fairness metrics.

    For every metric in ``fairness_fields`` the candidates are ordered by the
    absolute value of their seed-averaged validation score (smaller is
    better).  Each candidate's positions in those orderings are summed and
    the candidate with the smallest sum wins; ties go to the candidate that
    comes first in ``ks``.

    Args:
        val_results: Mapping from metric name to a 2-D array whose first axis
            is aligned with ``ks`` and whose second axis is averaged over.
        ks: Candidate ``k`` values.
        fairness_fields: Names of the metrics in ``val_results`` to rank by.

    Returns:
        The element of ``ks`` with the lowest summed rank.

    Raises:
        ValueError: If ``ks`` is empty.
    """
    n_candidates = len(ks)
    if n_candidates == 0:
        raise ValueError("find_best_k needs at least one candidate k")

    rank_sums = np.zeros(n_candidates)
    positions = np.arange(n_candidates)
    for field in fairness_fields:
        # Average over the second axis first, then take the magnitude.
        scores = np.abs(np.asarray(val_results[field].mean(axis=1)))
        # order[p] is the index of the candidate at position p, so the
        # candidate order[p] receives rank p.
        order = np.argsort(scores)
        rank_sums[order] += positions

    return ks[np.argmin(rank_sums)]


@hydra.main(version_base=None, config_path="configs")
def main(cfg: DictConfig):
    """Run one (dataset, alpha, flip set) experiment and save its raw metrics.

    Reads ``dataset``, ``alpha``, ``flip_set``, ``transform`` and ``strategy``
    from ``cfg``; ``use_hist_gbm`` (default ``True``) and ``majority_samples``
    (default ``None``) are optional.  For every flip count in the selected
    flip set a strategy is instantiated, its ``k`` is tuned on a validation
    split when the strategy exposes ``set_k``, and the strategy is then
    evaluated on the test set for ``hyper_vals.get_n_retries()`` seeds.  The
    per-seed metrics, the chosen ``k`` values and the run parameters are
    written to the ``.npz`` file named by ``get_temp_path``.
    """
    ds_name = cfg.dataset
    ds = get_tabular_dataset(ds_name)
    alpha = cfg.alpha
    flip_set = cfg.flip_set
    n_seeds = hyper_vals.get_n_retries()
    use_hist_gbm = cfg.use_hist_gbm if "use_hist_gbm" in cfg else True
    majority_samples = cfg.majority_samples if "majority_samples" in cfg else None

    train_set, test_set = ds
    features_train, sensitive_train, labels_train = train_set
    features_test, sensitive_test, labels_test = test_set

    feature_transform = hydra.utils.instantiate(cfg.transform)
    if feature_transform.requires_fit:
        feature_transform.fit(
            features=features_train, labels=labels_train, in_group=sensitive_train
        )

    nums_to_flip = hyper_vals.get_flip_set(flip_set)
    # Candidate k values: 1, 10, 20, ..., 100 (the leading 0 is replaced by 1).
    ks = np.arange(0, 101, 10)
    ks[0] = 1
    # Number of seeds used for the validation runs that select k.
    n_val_seeds = 5

    best_ks = np.zeros(len(nums_to_flip), dtype=int)
    results_shape = (len(nums_to_flip), n_seeds)
    result_metrics = {
        "train": np.empty(results_shape, dtype=object),
        "test": np.empty(results_shape, dtype=object),
    }

    for r_i, n_flips in enumerate(nums_to_flip):
        print(f"Running for {n_flips} flips", file=sys.stderr)

        strategy = hydra.utils.instantiate(
            cfg.strategy,
            data_transform=feature_transform,
            num_to_flip=n_flips,
            use_partial=majority_samples is not None,
        )

        # Only strategies with a ``set_k`` method have a k to tune.  Checking
        # for the method explicitly (instead of catching AttributeError around
        # the whole validation run) keeps genuine AttributeErrors raised by
        # the training or metric code from being silently swallowed.
        if getattr(strategy, "set_k", None) is None:
            print("No k, no need for validation", file=sys.stderr)
        else:
            _, val_results = train_validation(
                ks,
                features_train,
                sensitive_train,
                labels_train,
                strategy,
                alpha,
                n_seeds=n_val_seeds,
                majority_samples=majority_samples,
                use_hist_gbm=use_hist_gbm,
            )

            best_k = find_best_k(val_results, ks)
            best_ks[r_i] = best_k

            strategy.set_k(best_k)

        result_metrics["train"][r_i, :], result_metrics["test"][r_i, :] = train_all(
            features_train,
            features_test,
            sensitive_train,
            sensitive_test,
            labels_train,
            labels_test,
            strategy,
            alpha,
            n_seeds,
            majority_samples=majority_samples,
            use_hist_gbm=use_hist_gbm,
        )

    # The temp file is named after the string form of the last strategy.
    temp_dir, temp_file = get_temp_path(
        cfg.dataset,
        str(strategy),
        alpha,
        flip_set,
        majority_samples=majority_samples,
        use_hist_gbm=use_hist_gbm,
    )
    os.makedirs(temp_dir, exist_ok=True)
    np.savez(
        os.path.join(temp_dir, temp_file),
        results=result_metrics,
        best_ks=best_ks,
        ks=ks,
        nums_to_flip=nums_to_flip,
        alpha=alpha,
        num_retries=n_seeds,
    )
    print(f"Saved to {os.path.join(temp_dir, temp_file)}", file=sys.stderr)


# Script entry point: `main` is wrapped by `@hydra.main`, so it is called
# without arguments here and receives its `cfg` from Hydra.
if __name__ == "__main__":
    main()
