# This script illustrates the following observation:
# the optimal hyperparameter depends on the number of inner iterations.
# In particular, choosing a higher number of inner iterations at test time
# leads to decreased performance, because the hyperparameter selected for the
# original number of iterations is no longer optimal.
from pathlib import Path

import hydra
import jax.numpy as jnp
import pandas as pd

from id_in_practice.data import gen_l2_data
from id_in_practice.grid_search import grid_search_1d
from id_in_practice.objectives import outer_objective_l2


def different_n_iter_lead_to_different_hp(cfg):
    """Select alpha by grid search, then re-evaluate it with more inner iterations.

    The hyperparameter ``alpha`` is chosen by a 1D grid search on the training
    data using the inner-solver settings in ``cfg.inner``. The selected value
    is then evaluated on the test data several times, each time with
    ``n_inner`` increased by a fixed delta, and the resulting errors are
    recorded.

    Args:
        cfg: Configuration object exposing ``data`` (keyword arguments for
            ``gen_l2_data``), ``inner`` (keyword arguments for
            ``outer_objective_l2``, including ``n_inner``) and ``output_dir``.

    Returns:
        A ``pandas.DataFrame`` with one row per tested ``n_inner_delta``. Each
        row holds the data and inner settings plus ``alpha_gs``, ``error_gs``,
        ``new_error_for_n_inner`` and ``n_inner_delta``.

    Side effects:
        Appends the rows to ``<cfg.output_dir>/different_n_iter.csv``. The
        header is written only when the file does not exist yet, and the
        output directory is created if it is missing.
    """
    train_data, test_data = gen_l2_data(**cfg.data)
    columns = [
        *cfg.data.keys(),
        *cfg.inner.keys(),
        "alpha_gs",
        "error_gs",
        "new_error_for_n_inner",
        "n_inner_delta",
    ]

    def objective_fun(alpha, data):
        # Only the first returned element is used as the grid-search objective.
        return outer_objective_l2(
            alpha,
            data,
            **cfg.inner,
        )[0]

    alpha_gs, error_gs = grid_search_1d(
        objective_fun,
        lookup_range=jnp.linspace(-10, 0, 200),
        data=train_data,
    )

    # Collect plain dict rows and build the frame once at the end:
    # ``DataFrame.append`` was removed in pandas 2.0 and copied the whole
    # frame on every call.
    rows = []
    for n_inner_delta in [10, 50, 100, 500, 1000]:
        # Work on a copy so the original configuration is left untouched.
        test_inner = cfg.inner.copy()
        test_inner["n_inner"] += n_inner_delta
        _, new_error = outer_objective_l2(
            alpha_gs,
            test_data,
            **test_inner,
        )
        rows.append(
            {
                **cfg.data,
                **cfg.inner,
                "alpha_gs": alpha_gs,
                "error_gs": error_gs,
                "new_error_for_n_inner": new_error,
                "n_inner_delta": n_inner_delta,
            }
        )
    df_results = pd.DataFrame(rows, columns=columns)

    output_file = Path(cfg.output_dir) / "different_n_iter.csv"
    # Make sure the target directory exists before appending to the CSV.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    # Write the header only for a fresh file so repeated runs append cleanly.
    write_header = not output_file.exists()
    df_results.to_csv(output_file, mode="a", header=write_header, index=False)
    return df_results


@hydra.main(config_path="conf", config_name="different_n_iter")
def different_n_iter_lead_to_different_hp_main(cfg):
    """Hydra entry point: run the experiment with the composed configuration.

    The decorator points Hydra at the ``conf`` directory and the
    ``different_n_iter`` config; the resulting ``cfg`` is forwarded unchanged
    and the experiment's result is returned.
    """
    results = different_n_iter_lead_to_different_hp(cfg)
    return results


# Launch the Hydra-decorated entry point only when this file is executed as a
# script, so importing the module does not start an experiment.
if __name__ == "__main__":
    different_n_iter_lead_to_different_hp_main()
