from itertools import product
from os import makedirs, path
from typing import Any, Dict

from pandas import DataFrame

from cases import CASES, CASES_VARY_DATA
from experiment_linearity_condition import run_linearity_condition

# Absolute path of this script and the directory that contains it.
HERE = path.abspath(__file__)
HEREDIR = path.split(HERE)[0]

# Output locations: <script dir>/data and <script dir>/data/close_to_linear.
DATADIR = path.join(HEREDIR, "data")
LINEARDIR = path.join(HEREDIR, "data", "close_to_linear")

# Create the output directory (and missing parents) at import time.
makedirs(LINEARDIR, exist_ok=True)


def accumulate_linearity_condition(
    case: Dict[str, Any], ignore_missing: bool = False
) -> DataFrame:
    """Collect the linearity condition over all runs of a case and save it as csv.

    One row is produced per combination of model seed, perturbation seed and
    width. The table is written to ``<LINEARDIR>/<data_name>_<model_name>.csv``.

    Args:
        case: Dictionary describing the hyper-parameters. The keys ``data_name``,
            ``model_name``, ``widths``, ``num_initializations`` and
            ``num_perturbations`` are read.
        ignore_missing: Forwarded to ``run_linearity_condition`` as ``load_only``.
            If True, a run that raises ``FileNotFoundError`` is recorded as NaN.
            Otherwise, the run is requested again with ``load_only=False``
            (may take a long time).

    Returns:
        DataFrame with columns ``model_seed``, ``perturbation_seed``, ``width``
        and ``C'``.
    """
    data_name = case["data_name"]
    model_name = case["model_name"]
    widths = [int(w) for w in case["widths"]]
    num_initializations = case["num_initializations"]
    num_perturbations = case["num_perturbations"]

    rows = []
    # product() varies the last iterable fastest: widths, then perturbation
    # seeds, then model seeds.
    seed_width_grid = product(
        range(num_initializations), range(num_perturbations), widths
    )
    for model_seed, perturbation_seed, width in seed_width_grid:
        run_args = (data_name, model_name, width, model_seed, perturbation_seed)
        try:
            value = run_linearity_condition(*run_args, load_only=ignore_missing)
        except FileNotFoundError:
            print(
                f"Could not load close-to-linear condition ({data_name}, {model_name})"
                + f" model_seed={model_seed}, perturbation_seed={perturbation_seed}, width={width}"
            )
            # NOTE(review): when ignore_missing is False, the first attempt already
            # used load_only=False, so the fallback repeats the same call.
            value = (
                float("nan")
                if ignore_missing
                else run_linearity_condition(*run_args, load_only=False)
            )
        rows.append([model_seed, perturbation_seed, width, value])

    table = DataFrame(
        data=rows,
        columns=["model_seed", "perturbation_seed", "width", "C'"],
    )
    table.to_csv(path.join(LINEARDIR, f"{data_name}_{model_name}.csv"), index=False)

    return table


if __name__ == "__main__":
    # Accumulate results for the default cases first, then for the cases that
    # vary the data set.
    for case_collection in (CASES, CASES_VARY_DATA):
        for case in case_collection:
            accumulate_linearity_condition(case)
