
# %% Imports

import os
from time import perf_counter
from time import time

import tensorflow as tf
from tensorflow_probability import distributions as tfd

from pavi.dual.models import (
    CascadingFlows
)
from generative_hbms.HCPL import (
    S_full, T_full, N_full,
    ground_hbm,
    cf_hbm_kwargs,
    cf_val_data,
    na_val_idx
)


# %% ADAVFamily kwargs

# Shared width: size of the auxiliary variables and the single entry of the
# RFF per-layer unit list.
d = 32

rff_kwargs = {"units_per_layers": [d]}

# Start from the model-specific kwargs provided by the HBM module and add
# the architecture choices. Using `dict(**...)` (rather than a dict literal)
# means a key present in both sources raises a TypeError instead of being
# silently overridden.
cf_kwargs = dict(
    **cf_hbm_kwargs,
    auxiliary_variables_size=d,
    rff_kwargs=rff_kwargs,
    nf_kwargs=dict(),
    # presumably a per-dataset (non-amortized) fit, consistent with fitting
    # a single observed example below — TODO confirm against CascadingFlows
    amortized=False,
    auxiliary_target_type="identity",
)

# %% We build our architecture
# Instantiate the variational family from the configuration assembled above.
cf = CascadingFlows(**cf_kwargs)

# %% We select the loss used for training
# Exponentially decaying learning rate: starts at 1e-3 and decays by a
# factor 0.9 per 200-step period.
lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=200, decay_rate=0.9
)
optimizer = tf.optimizers.Adam(learning_rate=lr_scheduler)

# Train by minimizing the reverse KL divergence; the argument name suggests
# 8 parameter draws per observed example at each step.
cf.compile(
    train_method="reverse_KL",
    n_theta_draws_per_x=8,
    optimizer=optimizer,
)

# %% We fit the training data

# Names of every observed leaf variable X_{s}_{t}_{n}. This is built once as
# a set, so each membership test below is O(1) instead of rebuilding and
# linearly scanning an S*T*N-element list for every key of `cf_val_data`.
observed_keys = {
    f"X_{s}_{t}_{n}"
    for s in range(S_full)
    for t in range(T_full)
    for n in range(N_full)
}

# Keep only the observed variables of validation example `na_val_idx`.
# Slicing `idx:idx + 1` (instead of indexing `idx`) keeps a leading axis of
# size 1. This assumes the values are array-like and sliceable along axis 0.
observed_data = {
    key: value[na_val_idx:na_val_idx + 1]
    for key, value in cf_val_data.items()
    if key in observed_keys
}

# %%

# Fit on the single observed example and time the run. perf_counter is a
# monotonic, high-resolution clock intended for measuring durations.
# time.time() follows the adjustable system wall clock, so the difference
# between two readings can be skewed by clock changes.
time_1 = perf_counter()
history = cf.fit(
    observed_data,
    batch_size=1,
    epochs=4_000,
)
time_2 = perf_counter()

# %% We compute the ELBO
sample_size = 256

# Tile every observed tensor `sample_size` times along axis 0, so that the
# posterior is sampled `sample_size` times for the same data.
repeated_observed_data = {
    rv_name: tf.repeat(rv_value, sample_size, axis=0)
    for rv_name, rv_value in observed_data.items()
}

# Sample the posterior conditioned on the tiled data, also requesting the
# internal quantities needed for the augmented ELBO. The third returned
# value is not used here.
parameters_sample, augmented_posterior_values, _, auxiliary_values = (
    cf.sample_parameters_conditioned_to_data(
        data=repeated_observed_data,
        return_internals=True,
    )
)

# Terms of the augmented ELBO, presumably one value per sample:
#   p: log-density of the generative HBM at (sampled parameters, data);
#   r: `cf.MF_log_prob` evaluated on the auxiliary variables;
#   q: `cf`'s joint log-density of the sampled quantities given the data.
# NOTE(review): the exact meaning of r and q depends on CascadingFlows;
# confirm it against its implementation.
p = ground_hbm.log_prob(**parameters_sample, **repeated_observed_data)
r = cf.MF_log_prob(
    augmented_posterior_values=augmented_posterior_values,
    auxiliary_values=auxiliary_values,
)
q = cf.joint_log_prob_conditioned_to_data(
    data={**parameters_sample, **repeated_observed_data},
    augmented_posterior_values=augmented_posterior_values,
    auxiliary_values=auxiliary_values,
)

# Monte Carlo estimate of the augmented ELBO: mean of p + r - q over all
# samples (reduce_mean without an axis yields a scalar).
augmented_ELBO = tf.reduce_mean(p + r - q)

# Report the loss as a plain number: `float(...)` converts the scalar tensor,
# so the output shows the value instead of the `tf.Tensor(...)` repr.
# This requires eager execution, which this top-level script runs under.
print(f"""
    Idx:  {na_val_idx}
    Time: {time_2 - time_1}
    Loss: {float(-augmented_ELBO)}
""")

# %%
