"""Evaluate the close-to-linear condition for a case."""

from argparse import ArgumentParser
from os import makedirs, path

from pandas import DataFrame, read_csv
from torch import cuda, device

from experiment_gram_condition import run_gram_condition
from utils import close_to_linear_condition, model_and_data

# Resolve all locations relative to this script so it works from any cwd.
HEREDIR = path.dirname(path.abspath(__file__))
DATADIR = path.join(HEREDIR, "data")
# Per-case CSV results of the close-to-linear condition are stored here.
LINEARDIR = path.join(HEREDIR, "data", "close_to_linear", "raw")
# Created eagerly at import time; a no-op when the directory already exists.
makedirs(LINEARDIR, exist_ok=True)


def run_linearity_condition(
    data_name: str,
    model_name: str,
    width: int,
    model_seed: int,
    perturbation_seed: int,
    skip_exists: bool = True,
    load_only: bool = False,
) -> float:
    """Compute (or load) one sample of the close-to-linear condition for a case.

    The result is cached as a one-column CSV (column ``C'``) in ``LINEARDIR``.
    The file name encodes the data set, model, width, and both seeds.

    Args:
        data_name: Name of the data set.
        model_name: Name of the neural network.
        width: Width of the neural network.
        model_seed: Random seed used to initialize the model.
        perturbation_seed: Random seed used to perturb the parameters to evaluate a
            single sample of the condition.
        skip_exists: If ``True`` and the CSV already exists, read it instead of
            recomputing. Default: ``True``.
        load_only: If ``True``, never compute; only read the stored CSV (this
            raises if the file is missing). Default: ``False``.

    Returns:
        Value of the close-to-linear condition.
    """
    filename = (
        f"{data_name}_{model_name}_width_{width}_model_seed_{model_seed}"
        f"_perturbation_seed_{perturbation_seed}.csv"
    )
    savepath = path.join(LINEARDIR, filename)

    already_stored = path.exists(savepath) and skip_exists
    suffix = " (skipping)" if already_stored else ""
    print(
        f"Computing close-to-linear condition ({data_name}, {model_name})"
        f" model_seed={model_seed},"
        f" perturbation_seed={perturbation_seed}, width={width}{suffix}"
    )

    # Guard clause: serve the cached value instead of recomputing.
    if already_stored or load_only:
        return read_csv(savepath).values.tolist()[0][0]

    model, data = model_and_data(data_name, model_name, width, model_seed)
    target = device("cuda" if cuda.is_available() else "cpu")
    model = model.to(target)

    # skip_exists=True lets the Gram-condition experiment reuse its stored
    # minimum eigenvalue when available, rather than recomputing it.
    min_eigval = run_gram_condition(
        data_name, model_name, width, model_seed, skip_exists=True
    )
    value = close_to_linear_condition(model, data, min_eigval, perturbation_seed)

    DataFrame([[value]], columns=["C'"]).to_csv(savepath, index=False)
    return value


if __name__ == "__main__":
    # Command-line entry point: evaluate a single case of the condition.
    parser = ArgumentParser(description=__doc__)
    # The identifying arguments are required: omitting one would otherwise make
    # argparse supply ``None``, producing cache files named e.g. ``None_None_...``
    # and passing ``None`` into the model/data construction.
    parser.add_argument(
        "--data_name", type=str, required=True, help="Name of the data set."
    )
    parser.add_argument(
        "--model_name", type=str, required=True, help="Name of the neural network."
    )
    parser.add_argument(
        "--width", type=int, required=True, help="Width of the neural network."
    )
    parser.add_argument(
        "--model_seed",
        type=int,
        required=True,
        help="Random seed used to initialize the model.",
    )
    parser.add_argument(
        "--perturbation_seed",
        type=int,
        required=True,
        help="Random seed used to perturb the parameters.",
    )
    parser.add_argument(
        "--skip_exists",
        action="store_true",
        default=False,
        help="Reuse an existing result file instead of recomputing.",
    )

    args = parser.parse_args()

    run_linearity_condition(
        args.data_name,
        args.model_name,
        args.width,
        args.model_seed,
        args.perturbation_seed,
        skip_exists=args.skip_exists,
    )
