"""Accumulate individual runs for the Gram matrix condition into a single file."""

from itertools import product
from os import makedirs, path
from typing import Any, Dict

from pandas import DataFrame

from cases import CASES, CASES_VARY_DATA
from experiment_gram_condition import run_gram_condition

# Absolute path of this script and the directory that contains it.
HERE = path.abspath(__file__)
HEREDIR = path.dirname(HERE)

# Output locations; the Gram-matrix result directory is created at import time.
DATADIR = path.join(HEREDIR, "data")
GRAMDIR = path.join(HEREDIR, "data", "gram_min_eigval")
makedirs(name=GRAMDIR, exist_ok=True)


def accumulate_gram_matrix_condition(
    case: Dict[str, Any], ignore_missing: bool = False
) -> DataFrame:
    """Evaluate the Gram matrix's smallest eigenvalue for a case. Store as csv.

    One row is produced for every combination of model seed (``0`` up to
    ``case["num_initializations"] - 1``) and width in ``case["widths"]``. The
    resulting table is written to ``GRAMDIR/<data_name>_<model_name>.csv``.

    Args:
        case: Dictionary describing the hyper-parameters. Must contain the keys
            ``"data_name"``, ``"model_name"``, ``"widths"`` and
            ``"num_initializations"``.
        ignore_missing: If True, only load existing results and record ``nan``
            for runs whose results cannot be found. Otherwise, carry out the
            missing computations (may take a long time).

    Returns:
        DataFrame with accumulated data, with columns ``model_seed``, ``width``
        and ``min_eigval``.

    Raises:
        FileNotFoundError: If ``ignore_missing`` is False and
            ``run_gram_condition`` raises it while computing a run.
    """
    data_name = case["data_name"]
    model_name = case["model_name"]
    widths = [int(w) for w in case["widths"]]
    num_initializations = case["num_initializations"]

    savepath = path.join(GRAMDIR, f"{data_name}_{model_name}.csv")
    columns = ["model_seed", "width", "min_eigval"]
    data = []

    for model_seed, width in product(range(num_initializations), widths):
        try:
            # With ignore_missing=False this call already computes missing
            # results (load_only=False); with True it only loads them.
            min_eigval = run_gram_condition(
                data_name, model_name, width, model_seed, load_only=ignore_missing
            )
        except FileNotFoundError:
            print(
                "Could not load Gram matrix minimum eigenvalue for "
                + f"({data_name}, {model_name}) model_seed={model_seed}, width={width}"
            )
            if not ignore_missing:
                # The failing call already ran with load_only=False; repeating
                # the identical (expensive) call would only redo the failure.
                raise
            min_eigval = float("nan")
        data.append([model_seed, width, min_eigval])

    df = DataFrame(data=data, columns=columns)
    df.to_csv(savepath, index=False)

    return df


if __name__ == "__main__":
    # Accumulate results for every case in CASES, then for every case in
    # CASES_VARY_DATA, in that order.
    for case_group in (CASES, CASES_VARY_DATA):
        for case in case_group:
            accumulate_gram_matrix_condition(case)
