# This module implements the closed-form formulas from the paper:
# the lower bound on the loss difference and the optimal loss.
from pathlib import Path

import hydra
import jax
import jax.numpy as jnp
import pandas as pd

from id_in_practice.data import gen_biquadratic_data


def lower_bound_loss_diff(
    data_factors,
    n_inner,
    inner_lr=1e-2,
    z_0=None,
):
    r"""Evaluate the closed-form lower bound on the loss difference.

    In the paper's LaTeX notation the bound reads

        - \frac12 \| (P(C) - P(C E_t U)) (C E_t \kappa - C w^\star + \Gamma P_t(H) z_0) \|_2^2

    Args:
        data_factors: 5-tuple ``(sqrt_H, U, kappa, sqrt_G, omega)``.
            ``sqrt_H`` must be 2-D; its shape is read as
            ``(dimension_latent_inner, dimension_x)``.
        n_inner: exponent ``t`` of the matrix powers (the name suggests the
            number of inner iterations).
        inner_lr: scalar multiplying the Gram matrices inside the powers.
        z_0: optional vector; when given, ``sqrt_G @ P_t(H) @ z_0`` is added
            to the vector that gets projected.

    Returns:
        ``-0.5 * ||proj_diff @ vector||^2`` as a JAX scalar.
    """
    sqrt_H, U, kappa, sqrt_G, omega = data_factors
    n_latent, n_x = sqrt_H.shape
    eye_latent = jnp.eye(n_latent)

    # C = sqrt_G sqrt_H^T, and the projector C C^+ built from its pseudo-inverse.
    cross = sqrt_G @ sqrt_H.T
    cross_pinv = jnp.linalg.pinv(cross)
    range_proj = cross @ cross_pinv

    # Gram matrices in the latent space (H_bar) and in x-space (H).
    gram_latent = sqrt_H @ sqrt_H.T
    gram_x = sqrt_H.T @ sqrt_H

    # t-th powers of (I - lr * Gram) in both spaces.
    decay_latent = jnp.linalg.matrix_power(
        eye_latent - inner_lr * gram_latent, n_inner
    )
    decay_x = jnp.linalg.matrix_power(jnp.eye(n_x) - inner_lr * gram_x, n_inner)

    # E_t = (P_t - I) H_bar^{-1}; uses a plain inverse of H_bar.
    E_t = (decay_latent - eye_latent) @ jnp.linalg.inv(gram_latent)

    # Difference between the projector built from C and the one built from C E_t U.
    cross_EU = cross @ E_t @ U
    proj_gap = range_proj - cross_EU @ jnp.linalg.pinv(cross_EU)

    residual = cross @ (E_t @ kappa + cross_pinv @ range_proj @ omega)
    if z_0 is not None:
        residual = residual + sqrt_G @ decay_x @ z_0

    return -0.5 * jnp.linalg.norm(proj_gap @ residual) ** 2


def optimal_loss(
    data_factors,
    n_inner,
    inner_lr=1e-2,
    z_0=None,
):
    r"""Evaluate the closed-form optimal loss.

    In the paper's LaTeX notation the optimal loss reads

        \frac12 \| (P(C) - P(C E_t U)) (C E_t \kappa - C w^\star + \Gamma P_t(H) z_0) \|_2^2
        + \frac12 \| z^\star_{C^\perp} + P(C^\perp) \Gamma P_t(H) z_0 \|_2^2

    The first term is the negated value of ``lower_bound_loss_diff`` called
    with the same arguments; the second term is computed here.

    Args:
        data_factors: 5-tuple ``(sqrt_H, U, kappa, sqrt_G, omega)``; only
            ``sqrt_H``, ``sqrt_G`` and ``omega`` are read directly here.
        n_inner: exponent ``t`` of the matrix power ``P_t(H)``.
        inner_lr: scalar multiplying ``H`` inside the matrix power.
        z_0: optional vector; when given, its contribution
            ``P(C^\perp) sqrt_G P_t(H) z_0`` is added to the second term.

    Returns:
        The optimal loss as a JAX scalar.
    """
    sqrt_H, _, _, sqrt_G, omega = data_factors
    lower_bound = lower_bound_loss_diff(
        data_factors=data_factors,
        n_inner=n_inner,
        inner_lr=inner_lr,
        z_0=z_0,
    )

    dimension_x = sqrt_H.shape[1]
    dimension_latent_outer = sqrt_G.shape[0]

    # Projector I - C C^+ built from C = sqrt_G sqrt_H^T.
    C = sqrt_G @ sqrt_H.T
    proj_C_perp = jnp.eye(dimension_latent_outer) - C @ jnp.linalg.pinv(C)

    vector = -proj_C_perp @ omega
    if z_0 is not None:
        # The matrix power is only needed for the z_0 term, so it is not
        # computed when z_0 is absent.
        P_t_H = jnp.linalg.matrix_power(
            jnp.eye(dimension_x) - inner_lr * sqrt_H.T @ sqrt_H,
            n_inner,
        )
        vector = vector + proj_C_perp @ sqrt_G @ P_t_H @ z_0
    return -lower_bound + 0.5 * jnp.linalg.norm(vector) ** 2


def evaluate_lower_bound(
    nts=100,
    inner_lr=1e-2,
    dimension_x=10,
    output_file=None,
    seed=0,
    t_range=(0, 1000),
    **dims,
):
    """Evaluate the lower bound over a grid of ``t`` values and store the results.

    Data factors are obtained from ``gen_biquadratic_data`` and ``z_0`` is
    drawn from a standard normal using ``seed``. The bound is evaluated with
    ``lower_bound_loss_diff`` for each ``t`` in the grid.

    Args:
        nts: controls the grid spacing; the step is ``int(t_range[1] / nts)``,
            falling back to 1 when that would be 0.
        inner_lr: forwarded to ``lower_bound_loss_diff``.
        dimension_x: forwarded to ``gen_biquadratic_data``; also the length
            of ``z_0``.
        output_file: CSV path, resolved against Hydra's original working
            directory. Rows are appended if the file exists, otherwise the
            file is created with a header. ``None`` skips writing entirely.
        seed: forwarded to ``gen_biquadratic_data`` and used for the PRNG key.
        t_range: ``(start, stop)`` passed to ``range``.
        **dims: extra keyword arguments forwarded to ``gen_biquadratic_data``
            and recorded as columns of the result.

    Returns:
        A ``pandas.DataFrame`` with one row per evaluated ``t``.
    """
    *_, data_factors = gen_biquadratic_data(
        dimension_x=dimension_x,
        **dims,
        seed=seed,
        return_factors=True,
    )
    key = jax.random.PRNGKey(seed)
    z_0 = jax.random.normal(key, (dimension_x,))

    # A zero step (t_range[1] < nts) would make range() raise ValueError,
    # so fall back to evaluating every t.
    # NOTE(review): the step is derived from t_range[1], not from the span
    # t_range[1] - t_range[0]; with a non-zero start this yields fewer than
    # `nts` points. Confirm whether that is intended.
    step = int(t_range[1] / nts) or 1
    ts = range(*t_range, step)

    lower_bounds = [
        lower_bound_loss_diff(
            n_inner=t,
            data_factors=data_factors,
            inner_lr=inner_lr,
            z_0=z_0,
        )
        for t in ts
    ]
    df_results = pd.DataFrame(
        {
            "t": ts,
            **dims,
            "dimension_x": dimension_x,
            "inner_lr": inner_lr,
            "seed": seed,
            "lower_bounds": lower_bounds,
        }
    )

    # With the default output_file=None there is nowhere to write; return the
    # results without touching Hydra or the filesystem.
    if output_file is None:
        return df_results

    root_dir = Path(hydra.utils.get_original_cwd())
    output_path = root_dir / output_file
    # Append mode creates the file when missing; the header is written only
    # in that case so repeated runs accumulate rows under a single header.
    df_results.to_csv(
        output_path,
        mode="a",
        header=not output_path.exists(),
        index=False,
    )

    return df_results


@hydra.main(config_path="../config", config_name="strongly_convex_large")
def evaluate_lower_bound_main(conf):
    """Hydra entry point: unpack the config and run ``evaluate_lower_bound``.

    The ``data`` and ``eval`` config groups are expanded into keyword
    arguments; ``seed`` and ``inner_opt.inner_lr`` are passed explicitly.
    """
    seed = conf.seed
    inner_lr = conf.inner_opt.inner_lr
    results = evaluate_lower_bound(
        seed=seed,
        inner_lr=inner_lr,
        **conf.data,
        **conf.eval,
    )
    return results


# Script entry point: run the Hydra-decorated evaluation when executed directly.
if __name__ == "__main__":
    evaluate_lower_bound_main()
