
# %% Imports
import pickle

from time import time
import tensorflow as tf
import matplotlib.pyplot as plt

from pavi.dual.models import (
    CascadingFlows
)
from generative_hbms.HV import (
    N1_full,
    N2_full,
    seed,
    na_val_idx,
    cf_hbm_kwargs,
    cf_val_data,
    stack_data
)

# %% CascadingFlows kwargs

# Width shared by the auxiliary variables and the single hidden layer of the
# `rff_kwargs` sub-network.
d = 32

# One hidden layer of width `d`.
rff_kwargs = dict(units_per_layers=[d])

# Full CascadingFlows configuration: the HV generative-HBM settings
# (`cf_hbm_kwargs`) extended with a non-amortized setup, an identity
# auxiliary target and an empty `nf_kwargs` (presumably library defaults —
# confirm against CascadingFlows).
cf_kwargs = dict(
    **cf_hbm_kwargs,
    auxiliary_variables_size=d,
    rff_kwargs=rff_kwargs,
    nf_kwargs=dict(),
    amortized=False,
    auxiliary_target_type="identity",
)

# %% Fit a single validation example (non-amortized)
# Stacked posterior samples, keyed by the validation index they condition on.
samples = dict()
# Number of posterior draws taken once fitting is done.
n_draws = 256
# Non-amortized run: the model is fitted to this single validation example.
val_idx = na_val_idx

# The timer starts before the model is built, so the duration reported later
# covers construction, compilation and fitting.
time_1 = time()
cf = CascadingFlows(**cf_kwargs)

# %%
# Adam with a fixed learning rate of 1e-2.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)

# Train with the "reverse_KL" objective; `n_theta_draws_per_x` presumably sets
# how many parameter samples are drawn per observed data point per step.
cf.compile(
    optimizer=adam_optimizer,
    n_theta_draws_per_x=8,
    train_method="reverse_KL",
)

# %%
# Names of the observed data-points `c_{n1}_{n2}`. Built once as a set: the
# original inline list sat in the comprehension's `if` and was rebuilt
# (N1_full * N2_full strings) for every key of `cf_val_data`.
observed_keys = {
    f"c_{n1}_{n2}"
    for n1 in range(N1_full)
    for n2 in range(N2_full)
}

# Non-amortized fit on the single example at `val_idx`. The range slice keeps
# a leading axis of length 1 (assumes values are indexed by example along
# axis 0 — confirm against `cf_val_data`).
history = cf.fit(
    {
        key: value[val_idx:val_idx + 1]
        for key, value in cf_val_data.items()
        if key in observed_keys
    },
    batch_size=1,
    epochs=50_000,
    shuffle=True,
)
# `time_2 - time_1` spans model construction, compilation and fitting.
time_2 = time()

# %%

# Per-epoch training loss recorded under the "reverse_KL" key, drawn on a
# symmetric-log y-axis so both large early values and small late values show.
# NOTE(review): the y-label reads "ELBO" but the curve is the training loss,
# presumably the negative ELBO — confirm before relabelling.
reverse_kl_curve = history.history["reverse_KL"]
plt.plot(reverse_kl_curve)
plt.yscale("symlog")
plt.xlabel("epoch")
plt.ylabel("ELBO")

# %%

# Names of the observed data-points `c_{n1}_{n2}`, built once as a set
# (rebuilt here so this cell runs on its own) instead of re-creating the
# full list inside the comprehension condition for every key.
observed_keys = {
    f"c_{n1}_{n2}"
    for n1 in range(N1_full)
    for n2 in range(N2_full)
}

# Tile the single validation example `n_draws` times along axis 0, so that
# one call yields `n_draws` posterior draws conditioned on the same data.
repeated_observed_data = {
    key: tf.repeat(
        value[val_idx:val_idx + 1],
        repeats=(n_draws,),
        axis=0
    )
    for key, value in cf_val_data.items()
    if key in observed_keys
}
(
    parameters_sample,
    augmented_posterior_values,
    _,
    auxiliary_values
) = cf.sample_parameters_conditioned_to_data(
    data=repeated_observed_data,
    return_internals=True
)
samples[val_idx] = stack_data(parameters_sample)

# Observed values keyed by the model's observed RVs. Built once and shared by
# `p` and `q` below (the original rebuilt the identical dict for each).
observed_values = {
    observed_rv: repeated_observed_data[observed_rv]
    for observed_rv in cf.observed_rvs
}

# p: generative-HBM log-density of the sampled parameters with the data.
p = cf.generative_hbm.log_prob(
    **parameters_sample,
    **observed_values
)
# r: log-density from `MF_log_prob` evaluated at the sampled internals.
r = cf.MF_log_prob(
    augmented_posterior_values=augmented_posterior_values,
    auxiliary_values=auxiliary_values,
)
# q: CascadingFlows joint log-density of the same sample and internals.
q = cf.joint_log_prob_conditioned_to_data(
    data={
        **parameters_sample,
        **observed_values
    },
    augmented_posterior_values=augmented_posterior_values,
    auxiliary_values=auxiliary_values
)

# Monte-Carlo average of q - p - r over the `n_draws` samples.
# NOTE(review): presumably a negative-ELBO estimate (auxiliary-variable
# variational bound) — confirm against the CascadingFlows definitions.
loss = tf.reduce_mean(q - p - r).numpy()
dt = time_2 - time_1

print(
    f"Val idx: {val_idx} time: {dt} loss: {loss}"
)

# %%

# Output prefix: fitted validation index and seed. Uses `val_idx` (the index
# actually fitted above) so the file name cannot drift from the run.
# NOTE: "../data/" is resolved relative to the current working directory.
base_name = (
    "../data/"
    f"HV_CF_idx{val_idx}_seed{seed}_"
)

# Context managers flush and close the handles deterministically; passing a
# bare `open(...)` to `pickle.dump` leaked the file objects until GC.
with open(base_name + "history.p", "wb") as history_file:
    pickle.dump(history.history, history_file)
with open(base_name + "loss.p", "wb") as loss_file:
    pickle.dump(loss, loss_file)
