# Hydra script that solves the biquadratic problem and evaluates the resulting
# outer solution for multiple inner optimization times (numbers of inner steps).
from functools import partial
from pathlib import Path

import hydra
import jax
import jax.numpy as jnp
import pandas as pd

from id_in_practice.data import gen_biquadratic_data
from id_in_practice.objectives import outer_objective_biquadratic
from id_in_practice.resolution_biquadratic import solve_biquadratic
from id_in_practice.theory_biquadratic import lower_bound_loss_diff


# Enable 64-bit precision in JAX (it defaults to 32-bit). This has to happen
# before any JAX arrays are created, which is why it runs at import time.
jax.config.update("jax_enable_x64", True)


def _sample_initial_point(conf):
    """Return a random inner initialisation ``z_0``, or ``None`` if disabled.

    When ``conf.with_init`` is truthy, ``z_0`` is a standard-normal vector of
    length ``conf.data.dimension_x`` drawn with ``jax.random.PRNGKey(conf.seed)``.
    """
    if not conf.with_init:
        return None
    key = jax.random.PRNGKey(conf.seed)
    return jax.random.normal(key, (conf.data.dimension_x,))


def _append_to_csv(df, output_file):
    """Append ``df`` to ``output_file``, writing the header only for a new file."""
    # Create missing parent directories so that a fresh output location does
    # not make ``to_csv`` fail after the (possibly expensive) solve finished.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    write_header = not output_file.exists()
    # Mode "a" also creates the file when it is missing, so a single call
    # covers both the "new file" and "existing file" cases.
    # NOTE(review): concurrent jobs writing the same file (e.g. parallel
    # multiruns) could both see it missing and write two headers.
    df.to_csv(output_file, mode="a", header=write_header, index=False)


def solve_and_evaluate(conf):
    """Solve the biquadratic problem for one config and evaluate it at several inner times.

    Steps:

    1. Generate the data (and its factors) with ``gen_biquadratic_data``.
    2. Optionally sample an inner initialisation ``z_0`` (``conf.with_init``).
    3. Solve for the outer parameter ``theta_star_t`` with ``solve_biquadratic``
       using the ``conf.inner_opt`` and ``conf.outer_opt`` settings.
    4. Evaluate ``outer_objective_biquadratic`` at ``theta_star_t`` for
       ``conf.eval.nts`` integer inner times ``t`` spread over
       ``conf.eval.t_range``, and compute ``lower_bound_loss_diff`` with the
       ``conf.inner_opt`` settings.
    5. Append one row per ``t`` to ``conf.eval.output_file`` (resolved against
       Hydra's original working directory). The CSV header is written only
       when the file is created.

    Args:
        conf: Hydra/OmegaConf config with ``data``, ``inner_opt``,
            ``outer_opt`` and ``eval`` sections plus ``seed``, ``with_init``
            and ``verbose`` entries.

    Returns:
        pandas.DataFrame with one row per evaluated ``t``. Its columns are
        ``t``, ``outer_loss``, ``final_outer_loss``, ``theta_err``, every key
        of ``conf.data``, ``conf.outer_opt`` and ``conf.inner_opt`` (later
        sections win on key clashes), ``seed``, ``with_init`` and
        ``lower_bound``.
    """
    *data, data_factors = gen_biquadratic_data(
        **conf.data,
        seed=conf.seed,
        return_factors=True,
    )
    z_0 = _sample_initial_point(conf)
    theta_star_t, final_state = solve_biquadratic(
        data,
        **conf.inner_opt,
        **conf.outer_opt,
        z_0=z_0,
    )

    if conf.verbose:
        # Adjacent f-strings are concatenated with no separator, so the space
        # between the two sentence halves has to be written explicitly.
        print(
            f"For t: {conf.inner_opt.n_inner}, the outer error is {final_state.error} "
            f"and the outer loss is {final_state.value}."
        )

    # Integer numbers of inner steps at which theta_star_t is evaluated.
    # astype(int) truncates, so values can repeat when the range is narrower
    # than nts.
    ts = jnp.linspace(*conf.eval.t_range, conf.eval.nts).astype(int)
    outer_loss = partial(
        outer_objective_biquadratic,
        data=data,
        **conf.inner_opt,
        theta=theta_star_t,
        z_0=z_0,
    )
    # The call-time n_inner keyword overrides the one bound from
    # conf.inner_opt. jax.vmap maps keyword arguments over their leading axis,
    # i.e. over ts. The objective's second output is discarded.
    outer_losses, _ = jax.vmap(outer_loss)(n_inner=ts)

    lower_bound = lower_bound_loss_diff(
        data_factors=data_factors,
        **conf.inner_opt,
        z_0=z_0,
    )

    df_results = pd.DataFrame(
        {
            "t": ts,
            "outer_loss": outer_losses,
            "final_outer_loss": final_state.value,
            "theta_err": final_state.error,
            **conf.data,
            **conf.outer_opt,
            **conf.inner_opt,
            "seed": conf.seed,
            "with_init": conf.with_init,
            "lower_bound": lower_bound,
        }
    )
    # Hydra may switch the working directory to the run's output folder, so
    # resolve the CSV path against the directory the script was launched from.
    output_file = Path(hydra.utils.get_original_cwd()) / conf.eval.output_file
    _append_to_csv(df_results, output_file)

    return df_results


@hydra.main(config_path="../config", config_name="conf")
def main(conf):
    """Hydra entry point: run ``solve_and_evaluate`` on the composed config.

    The config is loaded by ``@hydra.main`` from ``../config`` (config name
    ``conf``). The resulting DataFrame is returned unchanged.
    """
    return solve_and_evaluate(conf)


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator builds `conf` from
    # ../config plus any command-line overrides and passes it to main.
    main()
