"""
Optimisation experiments for the Bertsekas example (a 100-state Markov chain),
comparing Bertsekas-style and surrogate parameter updates.
"""

# System imports
import chex
import matplotlib.pyplot as plt
import pickle
import jax
import jax.numpy as jnp
from tqdm import tqdm

# Local imports
from bertsekas_Markov_chain import ExampleThree as Env
from bertsekas_Markov_chain import GameState


def _surrogate_update(params, C_t, D, d_t, eta, lr_surr, num_inner):
    """
    Run ``num_inner`` gradient-descent steps on the surrogate anchored at
    ``params`` and return the resulting parameter vector.

    With theta_k = ``params``, each inner step uses the gradient
        eta * (C_t @ theta_k - d_t) + D @ (theta - theta_k),
    i.e. a fixed linear term plus a quadratic term with curvature ``D``.
    """
    anchor = params
    # The linear term depends only on the anchor, so compute it once.
    linear_grad = eta * (C_t @ anchor - d_t)
    inner = anchor
    for _ in range(num_inner):
        inner = inner - lr_surr * (linear_grad + D @ (inner - anchor))
    return inner


def example_three_opt(
    rng_key: chex.PRNGKey,
    env: Env,
    state: GameState,
    *,
    num_steps: int = 1_000,
    burn_in: int = 50,
    eta: float = 1,
    lr_surr: float = 0.1,
    beta: float = 1,
):
    """
    Run ``num_steps`` environment steps, accumulate running averages of the
    D and C matrices and the d vector, and (from step ``burn_in`` onwards)
    update the parameters with both the original Bertsekas method and the
    surrogate method, in several variants.

    Args:
        rng_key: PRNG key; split into parameter-init, cost and step keys.
        env: Environment providing ``obs``, ``costs`` and ``step``.
        state: Initial environment state.
        num_steps: Total number of environment steps (default 1,000).
        burn_in: First step index at which parameters are updated/logged.
        eta: Learning rate of the Bertsekas updates and of the surrogate's
            linear term.
        lr_surr: Step size of the surrogate inner gradient-descent loops.
        beta: Regularisation strength; D_t_reg = D_t + beta / (t + 1) * I.

    Returns:
        Tuple of lists, in this order: pinv(C_t) @ d_t estimate; Bertsekas
        with pinv(D_t); with pinv(diag(D_t)); with D_t = I; with
        pinv(D_t_reg); surrogate with 1, 5, 10 and 50 inner steps; surrogate
        with 10 inner steps using D_t_reg; and the logged step indices.
        Each parameter list holds the first component (``params[0]``) of the
        corresponding parameter vector at every logged step.
    """
    # Split count kept at 4 so the derived keys match the original scheme.
    _, params_key, costs_key, step_key = jax.random.split(rng_key, 4)

    # Every method starts from the same random initial parameters.
    params_bert = jax.random.normal(params_key, shape=(3, 1))
    params_bert_diag = params_bert
    params_bert_identity = params_bert
    params_D_reg = params_bert
    params_surr_1 = params_bert
    params_surr_5 = params_bert
    params_surr_10 = params_bert
    params_surr_50 = params_bert
    params_surr_10_reg = params_bert

    identity = jnp.identity(3)
    # Running averages of the accumulated components.
    D_t = jnp.zeros(shape=(3, 3))
    D_t_diag = jnp.zeros(shape=(3, 3))
    C_t = jnp.zeros(shape=(3, 3))
    d_t = jnp.zeros(shape=(3, 1))

    param_est_C_d = []
    param_est_bert = []
    param_est_bert_diag = []
    param_est_bert_identity = []
    param_est_reg = []
    param_est_surr_1 = []
    param_est_surr_5 = []
    param_est_surr_10 = []
    param_est_surr_50 = []
    param_est_surr_10_reg = []
    iter_log = []
    for t in tqdm(range(num_steps)):
        # Observation and transition cost in the current state.
        obs_1 = env.obs(state)
        cost, costs_key = env.costs(state, costs_key)

        # Take a step and observe the next state.
        state, _, step_key = env.step(state, step_key)
        obs_2 = env.obs(state)

        # Per-step contributions (obs assumed to be a 3x1 column).
        add_D = obs_1 @ obs_1.T            # [3x3]
        add_C = obs_1 @ (obs_1 - obs_2).T  # [3x3]
        add_d = obs_1 * cost               # [3x1]

        # Incremental mean: X_{t+1} = (t * X_t + x_t) / (t + 1).
        D_t = (t * D_t + add_D) / (t + 1)
        # Element-wise product with I keeps only the diagonal of add_D.
        D_t_diag = (t * D_t_diag + add_D * identity) / (t + 1)
        # Regularise D_t with a decaying beta / (t + 1) * I term.
        D_t_reg = D_t + (beta / (t + 1)) * identity
        C_t = (t * C_t + add_C) / (t + 1)
        d_t = (t * d_t + add_d) / (t + 1)

        if t < burn_in:
            continue

        # Bertsekas-style updates with different scalings of the residual.
        residual_fn = lambda p: C_t @ p - d_t
        params_bert = (
            params_bert - eta * jnp.linalg.pinv(D_t) @ residual_fn(params_bert))
        params_bert_diag = (
            params_bert_diag
            - eta * jnp.linalg.pinv(D_t_diag) @ residual_fn(params_bert_diag))
        params_bert_identity = (
            params_bert_identity - eta * residual_fn(params_bert_identity))
        params_D_reg = (
            params_D_reg
            - eta * jnp.linalg.pinv(D_t_reg) @ residual_fn(params_D_reg))

        # Surrogate updates with varying numbers of inner GD steps.
        params_surr_1 = _surrogate_update(
            params_surr_1, C_t, D_t, d_t, eta, lr_surr, 1)
        params_surr_5 = _surrogate_update(
            params_surr_5, C_t, D_t, d_t, eta, lr_surr, 5)
        params_surr_10 = _surrogate_update(
            params_surr_10, C_t, D_t, d_t, eta, lr_surr, 10)
        params_surr_50 = _surrogate_update(
            params_surr_50, C_t, D_t, d_t, eta, lr_surr, 50)
        # Regularised surrogate: curvature is D_t_reg (previously D_t by
        # mistake, which made it identical to params_surr_10).
        params_surr_10_reg = _surrogate_update(
            params_surr_10_reg, C_t, D_t_reg, d_t, eta, lr_surr, 10)

        # Direct estimate params = C_t^+ d_t for comparison.
        param_est = jnp.linalg.pinv(C_t) @ d_t

        param_est_C_d.append(param_est[0])
        param_est_bert.append(params_bert[0])
        param_est_bert_diag.append(params_bert_diag[0])
        param_est_bert_identity.append(params_bert_identity[0])
        param_est_reg.append(params_D_reg[0])
        param_est_surr_1.append(params_surr_1[0])
        param_est_surr_5.append(params_surr_5[0])
        param_est_surr_10.append(params_surr_10[0])
        param_est_surr_50.append(params_surr_50[0])
        param_est_surr_10_reg.append(params_surr_10_reg[0])
        iter_log.append(t)
    return (
        param_est_C_d,
        param_est_bert,
        param_est_bert_diag,
        param_est_bert_identity,
        param_est_reg,
        param_est_surr_1,
        param_est_surr_5,
        param_est_surr_10,
        param_est_surr_50,
        param_est_surr_10_reg,
        iter_log,
    )


if __name__ == "__main__":
    import os  # Local import: only needed when run as a script.

    # Fixed seed for reproducibility. For a random seed, use e.g.
    # jax.random.key(np.random.randint(low=0, high=np.iinfo(int).max))
    # after adding `import numpy as np`.
    rng_key = jax.random.key(123)
    # First pass is non-vectorised: a single environment instance.
    rng_key, env_key = jax.random.split(rng_key, 2)
    env = Env(env_key)
    state = env.init()

    output = example_three_opt(rng_key, env, state)

    # Create the output directory up front; open()/savefig() below would
    # otherwise raise FileNotFoundError when it is missing.
    os.makedirs("results", exist_ok=True)
    with open("results/exp3_outputs.pkl", "wb") as file:
        pickle.dump(output, file)

    (
        param_est_C_d,
        param_est_bert,
        param_est_bert_diag,
        param_est_bert_identity,
        param_est_reg,
        param_est_surr_1,
        param_est_surr_5,
        param_est_surr_10,
        param_est_surr_50,
        param_est_surr_10_reg,
        iter_log
    ) = output

    # Plot the methods from the Bertsekas paper.
    # plt.plot(iter_log, param_est_C_d, label=r"$\theta_{t} = C_t^{-1}d_t$")
    # NOTE: this rcParams change persists for all subsequent figures.
    plt.rcParams.update({'font.size': 14})
    plt.plot(iter_log, param_est_bert, label=r"$\theta_{t} - \hat{D}_t^{-1}(\hat{C}_t\theta_t - \hat{r}_t)$")
    # plt.plot(iter_log, param_est_bert_diag, label=r"Berts diag($D_t$)")
    # plt.plot(iter_log, param_est_bert_identity, label=r"Berts $D_t = Id$")
    plt.plot(iter_log, param_est_surr_1, label=r"Surr-GD, inner$=1$")
    plt.plot(iter_log, param_est_surr_10, label=r"Surr-GD, inner$=10$")
    plt.plot(iter_log, param_est_surr_50, label=r"Surr-GD, inner$=50$")
    # plt.title(r"First component of parameter vectors $\theta_t$, various estimations.")
    plt.xlabel("Steps")
    plt.ylabel(r"Parameter value: $\theta^1$")
    plt.xscale('log')
    plt.grid()
    plt.legend()
    plt.tight_layout()
    plt.savefig("results/ex3_param_values.pdf")
    plt.close()

    # Plot various surrogate inner loop numbers.
    plt.plot(iter_log, param_est_bert, label=r"Berts params")
    plt.plot(iter_log, param_est_bert_identity, label=r"Berts $D_t = Id$")
    plt.plot(iter_log, param_est_surr_1, label=r"Surrogate params, $m=1$")
    plt.plot(iter_log, param_est_surr_5, label=r"Surrogate params, $m=5$")
    plt.plot(iter_log, param_est_surr_10, label=r"Surrogate params, $m=10$")
    plt.plot(iter_log, param_est_surr_50, label=r"Surrogate params, $m=50$")
    plt.title(r"Surrogate first parameter, varying inner loop iterations.")
    plt.xlabel("Steps")
    plt.ylabel("Parameter value")
    plt.grid()
    plt.legend()
    plt.savefig("results/ex3_varying_m.pdf")
    plt.close()

    # Plot various regularisation methods.
    plt.plot(iter_log, param_est_C_d, label=r"$\theta_{t} = C_t^{-1}d_t$")
    plt.plot(iter_log, param_est_reg, label=r"$D_k + \beta Id$")
    plt.plot(iter_log, param_est_surr_10_reg, label=r"Surrogate $D_k + \beta Id$")
    plt.title(r"Regularised methods first parameter.")
    plt.xlabel("Steps")
    plt.ylabel("Parameter value")
    plt.grid()
    plt.legend()
    plt.savefig("results/ex3_varying_regularisation.pdf")
    plt.close()
