from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
import hydra
import pandas as pd

from id_in_practice.data import gen_l2_data_diag_correl
from id_in_practice.objectives import outer_objective_l2
from id_in_practice.resolution_l2 import solve_l2


@partial(jax.jit, static_argnames=(
    "n_inner",
    "m_inner",
    "inner_lr",
    "original_objective_fun",
))
def gradient_correlations_differences(
    m_inner,
    n_inner,
    alpha,
    original_objective_fun,
    train_data,
    inner_lr,
):
    """Compare implicit-diff and unrolled gradients of an objective at ``alpha``.

    The first output of ``original_objective_fun`` is differentiated with
    respect to ``alpha`` twice: once with ``implicit_diff=True`` and once with
    ``implicit_diff=False``. The two flattened gradients are then compared.

    Args:
        m_inner: forwarded as ``m_inner=`` to ``original_objective_fun``.
            Static under ``jit``, so it must be hashable and every new value
            recompiles.
        n_inner: forwarded as ``n_inner=``. Static under ``jit``.
        alpha: point at which both gradients are evaluated (traced).
        original_objective_fun: callable invoked as
            ``f(alpha, train_data, m_inner=..., n_inner=..., inner_lr=...,
            implicit_diff=...)``. Its ``[0]`` element is the scalar that is
            differentiated. Static under ``jit``.
        train_data: forwarded positionally to ``original_objective_fun``.
            Traced, not static, because arrays are not hashable and cannot be
            static ``jit`` arguments. NOTE(review): any non-array leaves in
            it are therefore traced as well.
        inner_lr: forwarded as ``inner_lr=``. Static under ``jit``.

    Returns:
        A pair ``(gradient_correlation, gradient_difference)``.
        ``gradient_correlation`` is the cosine similarity between the
        implicit-diff and unrolled gradients. ``gradient_difference`` is
        ``||g_id - g_unrolled|| / ||g_unrolled||``. Zero norms are not
        guarded against, so the results are then inf/NaN.
    """
    def objective_fun(alpha, implicit_diff=True):
        return original_objective_fun(
            alpha,
            train_data,
            m_inner=m_inner,
            n_inner=n_inner,
            inner_lr=inner_lr,
            implicit_diff=implicit_diff,
        )[0]

    objective_fun_id = partial(objective_fun, implicit_diff=True)
    objective_fun_unrolled = partial(objective_fun, implicit_diff=False)

    grad_id = jax.grad(objective_fun_id)(alpha).flatten()
    grad_unrolled = jax.grad(objective_fun_unrolled)(alpha).flatten()
    # Both metrics are normalized by the unrolled gradient's norm.
    norm_unrolled = jnp.linalg.norm(grad_unrolled)
    gradient_correlation = (
        jnp.dot(grad_id, grad_unrolled) / jnp.linalg.norm(grad_id) / norm_unrolled
    )
    gradient_difference = jnp.linalg.norm(grad_id - grad_unrolled) / norm_unrolled
    return gradient_correlation, gradient_difference


def gradient_comparison_evolution_l2_diag_correl(cfg):
    """Log how implicit-diff and unrolled gradients compare along an L2 solve.

    Training data comes from ``gen_l2_data_diag_correl(**cfg.data)``. The run
    is ``solve_l2(..., diag=True, implicit_diff=False, **cfg.outer,
    **cfg.inner_optim)``.

    Every ``cfg.log_period`` iterations, a callback differentiates
    ``outer_objective_l2(...)[0]`` with respect to the current ``alpha``,
    once with ``implicit_diff=True`` and once with ``implicit_diff=False``.
    It records the cosine similarity of the two gradients and their
    difference norm relative to the unrolled gradient's norm.

    One CSV row per logged iteration is appended to
    ``<cfg.output_dir>/gradient_comparison.csv``. Each row contains the
    config values followed by ``iteration``, ``gradient_correlation`` and
    ``gradient_difference``.

    Args:
        cfg: config with sub-mappings ``data``, ``inner``, ``inner_optim``
            and ``outer``, plus ``log_period`` and ``output_dir``.
    """
    # data generation (test split is not used here)
    train_data, _ = gen_l2_data_diag_correl(**cfg.data)

    def objective_fun(alpha, implicit_diff=True):
        return outer_objective_l2(
            alpha,
            train_data,
            implicit_diff=implicit_diff,
            **cfg.inner,
        )[0]

    # Built once rather than on every callback invocation.
    grad_fun_id = jax.grad(partial(objective_fun, implicit_diff=True))
    grad_fun_unrolled = jax.grad(partial(objective_fun, implicit_diff=False))

    # callback configuration
    logged_iterations = []
    gradient_correlations = []
    gradient_differences = []

    def correlations_differences_callback(alpha, state):
        iter_num = state.iter_num
        if iter_num % cfg.log_period != 0:
            return

        grad_id = grad_fun_id(alpha).flatten()
        grad_unrolled = grad_fun_unrolled(alpha).flatten()
        norm_unrolled = jnp.linalg.norm(grad_unrolled)
        gradient_correlation = (
            jnp.dot(grad_id, grad_unrolled) / jnp.linalg.norm(grad_id) / norm_unrolled
        )
        gradient_difference = jnp.linalg.norm(grad_id - grad_unrolled) / norm_unrolled
        # Store the real iteration number instead of reconstructing it later
        # from the log index, which assumes logging started at iteration 0.
        logged_iterations.append(int(iter_num))
        gradient_correlations.append(gradient_correlation.item())
        gradient_differences.append(gradient_difference.item())

    solve_l2(
        train_data,
        callback=correlations_differences_callback,
        diag=True,
        implicit_diff=False,
        **cfg.outer,
        **cfg.inner_optim,
    )

    # writing results in a common csv file
    columns = (
        list(cfg.outer.keys())
        + list(cfg.inner.keys())
        # inner-optimizer settings are written under an "optim_" prefix
        + [f"optim_{key}" for key in cfg.inner_optim.keys()]
        + list(cfg.data.keys())
        + ["iteration", "gradient_correlation", "gradient_difference"]
    )
    # The config part is identical for every row, so compute it once.
    config_values = [
        section[key]
        for section in (cfg.outer, cfg.inner, cfg.inner_optim, cfg.data)
        for key in section.keys()
    ]
    rows = [
        config_values + [iteration, grad_correl, grad_diff]
        for iteration, grad_correl, grad_diff in zip(
            logged_iterations, gradient_correlations, gradient_differences
        )
    ]
    df = pd.DataFrame(rows, columns=columns)

    output_file = Path(cfg.output_dir) / "gradient_comparison.csv"
    output_file.parent.mkdir(parents=True, exist_ok=True)
    # Append so that successive runs accumulate in one file. The header is
    # written only when the file is first created.
    # NOTE(review): runs with different config keys will produce rows that
    # do not line up with that header.
    df.to_csv(output_file, mode="a", header=not output_file.exists(), index=False)


@hydra.main(config_path="config", config_name="gradient_comparison")
def gradient_comparison_evolution_l2_diag_correl_main(cfg):
    """Hydra entry point for the gradient comparison experiment.

    Hydra composes the ``gradient_comparison`` config from the ``config``
    directory and passes it in as ``cfg``. All of the work is delegated to
    ``gradient_comparison_evolution_l2_diag_correl``.
    """
    gradient_comparison_evolution_l2_diag_correl(cfg=cfg)


if __name__ == "__main__":
    # No arguments: the hydra.main wrapper builds ``cfg`` itself.
    gradient_comparison_evolution_l2_diag_correl_main()
