#!/usr/bin/env python
"""Evaluate the Gram matrix's smallest eigenvalue for a specified case."""

from argparse import ArgumentParser
from os import makedirs, path

from pandas import DataFrame, read_csv
from torch import cuda, device

from utils import gram_matrix_smallest_eigval, model_and_data

# All paths are anchored at this script's own location, so results land in the
# same place regardless of the current working directory.
HEREDIR = path.dirname(path.abspath(__file__))
DATADIR = path.join(HEREDIR, "data")
# Raw per-case results (one CSV per data/model/width/seed case) are stored here.
GRAMDIR = path.join(DATADIR, "gram_min_eigval", "raw")
# Create the output directory eagerly at import time so later CSV writes into
# it do not fail on a missing directory; ``exist_ok`` makes this idempotent.
makedirs(name=GRAMDIR, exist_ok=True)


def run_gram_condition(
    data_name: str,
    model_name: str,
    width: int,
    model_seed: int,
    skip_exists: bool = True,
    epsilon: float = 0.0,
    tol: float = 0.0,
    load_only: bool = False,
) -> float:
    """Evaluate the Gram matrix's smallest eigenvalue for a case. Store as csv.

    The result is cached in a single-cell CSV (column ``min_eigval``) inside
    ``GRAMDIR``, keyed by data set, model, width, and model seed.

    Args:
        data_name: Name of the data set.
        model_name: Name of the neural network.
        width: Width of the neural network.
        model_seed: Random seed used to initialize the model.
        skip_exists: Whether to skip the computation and load the cached result
            if the file already exists. Default: ``True``.
        epsilon: Small constant added to the Gram matrix diagonal to improve
            convergence of the eigensolver. Default: ``0``.
        tol: Relative tolerance for the eigensolver. Default: ``0``.
        load_only: If ``True``, don't run the computation but simply try loading
            the Gram matrix minimum eigenvalue. Default: ``False``.

    Returns:
        Gram matrix's smallest eigenvalue (non-negative) as a built-in ``float``.

    Raises:
        FileNotFoundError: If ``load_only`` is ``True`` but no cached result
            exists for this case.
    """
    savepath = path.join(
        GRAMDIR, f"{data_name}_{model_name}_width_{width}_model_seed_{model_seed}.csv"
    )
    skip = path.exists(savepath) and skip_exists

    # Tell the user what is actually going to happen for this case.
    if skip:
        status = " (skipping)"
    elif load_only:
        status = " (loading)"
    else:
        status = ""
    print(
        f"Computing Gram matrix minimum eigenvalue ({data_name}, {model_name})"
        + f" model_seed={model_seed}, width={width}"
        + status
    )

    if skip or load_only:
        if not path.exists(savepath):
            raise FileNotFoundError(
                f"load_only=True but no cached Gram matrix result at {savepath}"
            )
        # The CSV holds a single value: first row, first column.
        return float(read_csv(savepath).iloc[0, 0])

    model, data = model_and_data(data_name, model_name, width, model_seed)
    dev = device("cuda" if cuda.is_available() else "cpu")
    model = model.to(dev)

    # A Gram matrix is positive semi-definite, so a negative smallest eigenvalue
    # can only come from floating point round-off; take the absolute value.
    # Cast to a built-in float so the computed and cached paths return the same
    # type (the helper's return type is not visible here — may be a tensor).
    min_eigval = float(
        abs(gram_matrix_smallest_eigval(model, data, epsilon=epsilon, tol=tol))
    )
    DataFrame([[min_eigval]], columns=["min_eigval"]).to_csv(savepath, index=False)

    return min_eigval


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Evaluate the Gram matrix's smallest eigenvalue for a case."
    )
    # The case-identifying arguments are mandatory: omitting one would silently
    # pass ``None`` and produce a bogus cache file name such as ``width_None``.
    parser.add_argument("--data_name", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--width", type=int, required=True)
    parser.add_argument("--model_seed", type=int, required=True)
    parser.add_argument(
        "--skip_exists",
        action="store_true",
        default=False,
        help="Skip the computation if a cached result already exists.",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.0,
        help="Constant added to the Gram matrix diagonal.",
    )
    parser.add_argument(
        "--tol", type=float, default=0.0, help="Relative eigensolver tolerance."
    )
    parser.add_argument(
        "--load_only",
        action="store_true",
        default=False,
        help="Only load a cached result; never run the computation.",
    )

    args = parser.parse_args()

    run_gram_condition(
        args.data_name,
        args.model_name,
        args.width,
        args.model_seed,
        skip_exists=args.skip_exists,
        epsilon=args.epsilon,
        tol=args.tol,
        load_only=args.load_only,
    )
