
# %% Imports

import pickle
from time import perf_counter
from time import time

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow_probability import distributions as tfd

from pavi.dual.models import (
    UIVFamily
)
from pavi.utils.callbacks import (TimerCallback,)
from generative_hbms.GRE import (
    seed,
    G, N,
    scale_mu_g, scale_mu, scale_x,
    generative_hbm,
    link_functions,
    val_data,
    na_val_idx,
)

# %% UIVFamily kwargs

# Conditional normalizing-flow chain: a single "affine" bijector with a
# diagonal scale, no permutation and no batch normalization. The layer sizes
# passed through `rff_kwargs` are picked from the number of groups G; only
# G in {2, 20, 200} is supported (any other value raises KeyError here).
nf_chain_kwargs = {
    "nf_type_kwargs_per_bijector": [
        (
            "affine",
            {
                "scale_type": "diag",
                "rff_kwargs": {
                    "units_per_layers": {
                        2: [32, 32],
                        20: [64, 64],
                        200: [128, 128],
                    }[G],
                },
            },
        ),
    ],
    "with_permute": False,
    "with_batch_norm": False,
}

# Keyword arguments for the UIVFamily constructor; "x" is the observed
# random variable of the generative HBM.
uivi_kwargs = {
    "generative_hbm": generative_hbm,
    "link_functions": link_functions,
    "observed_rv": "x",
    "conditional_nf_chain_kwargs": nf_chain_kwargs,
    # Embedding size also scales with G (same supported values as above).
    "embedding_RV_size": {2: 3, 20: 6, 200: 9}[G],
}


# %% We build our architecture
ui_family = UIVFamily(**uivi_kwargs)

# %%

# Observed data for the selected validation example, with a leading axis of
# size 1 inserted so it is fed to the model as a batch of one.
observed_values = dict(
    x=tf.expand_dims(val_data["x"][na_val_idx], axis=0),
)

# %% Start timing (the measured interval covers both compile and fit)
# `perf_counter` is a monotonic, high-resolution clock meant for measuring
# elapsed durations; unlike `time()` it is not affected by adjustments to the
# system clock while training runs.
time_1 = perf_counter()

# %% We select the loss used for training
# NOTE(review): the semantics of n_theta_draws / t_mcmc_burn_in /
# n_mcmc_samples are defined by UIVFamily.compile — confirm there.
# The learning rate follows Keras' ExponentialDecay (staircase=False):
# lr(step) = 1e-2 * 0.9 ** (step / 300).
ui_family.compile(
    n_theta_draws=8,
    t_mcmc_burn_in=5,
    n_mcmc_samples=5,
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=1e-2,
            decay_steps=300,
            decay_rate=0.9
        )
    )
)

# %% We fit the training data
# The single observed example is fitted with batch_size=1 for 3000 epochs;
# verbose=0 silences Keras' progress output.
hist_2 = ui_family.fit(
    observed_values,
    batch_size=1,
    epochs=3_000,
    verbose=0,
    callbacks=[
        TimerCallback()
    ]
)
time_2 = perf_counter()

# %% We compute the ELBO
# Monte Carlo estimate from `sample_size` draws of q. The single observed
# example (sliced so it keeps a leading axis of length 1) is tiled along that
# axis so each draw is scored against the same data.
sample_size = 256
repeated_x = tf.repeat(
    val_data["x"][na_val_idx:na_val_idx + 1],
    repeats=(sample_size,),
    axis=0,
)

z, epsilon = ui_family.sample((sample_size,), return_epsilon=True)
# NOTE(review): averaging `p - q` as the ELBO integrand assumes `q_z` returns
# a log-density, like `log_prob` does — confirm in UIVFamily.
q = ui_family.q_z(z, epsilon)
p = ui_family.generative_hbm.log_prob(**z, x=repeated_x)
ELBO = tf.reduce_mean(p - q).numpy()

print(f"""
    Idx:  {na_val_idx}
    Time: {time_2 - time_1}
    Loss: {- ELBO}
""")

# %%

# Two side-by-side panels: observed data on the left, q samples on the right.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))

# For every group: scatter its points in the group's color, take the
# empirical mean over the points axis, and build (but do not yet add) a black
# circle of radius 2 * scale_x / sqrt(N) around that mean.
group_means, circles = [], []
for g in range(G):
    axs[0].scatter(
        val_data["x"][na_val_idx, g, :, 0],
        val_data["x"][na_val_idx, g, :, 1],
        color=f"C{g}",
        alpha=0.5,
    )
    mean = tf.reduce_mean(val_data["x"][na_val_idx, g], axis=-2)
    circles.append(
        plt.Circle(
            (mean[0], mean[1]),
            2 * scale_x / N**0.5,
            fill=False,
            color="black",
        )
    )
    group_means.append(mean)

# Analytic reference for the population mean: the average of the group means
# shrunk towards 0, with a matching scale.
# NOTE(review): this is the Gaussian-conjugate posterior obtained when the
# group means are treated as exact draws of mu_g (the scale_x / sqrt(N)
# observation noise is not included) — confirm this approximation is intended.
population_mean = tf.reduce_mean(tf.stack(group_means, axis=-2), axis=-2)
posterior_mean = population_mean / (1 + scale_mu_g**2/(G * scale_mu**2))
posterior_scale = (1/(1/scale_mu**2 + G/scale_mu_g**2))**0.5

# Circle of radius 2 * posterior_scale around the analytic posterior mean.
circle = plt.Circle(
    (posterior_mean[0], posterior_mean[1]),
    2 * posterior_scale,
    fill=False,
    color="black",
)

# Left panel: equal aspect and a horizontal row label naming the example.
axs[0].axis("equal")
axs[0].set_ylabel(
    f"Example {na_val_idx}",
    fontsize=30,
    rotation=0,
    ha="right",
    va="center"
)
# Draw once so the "equal" aspect adjustment is applied before the data
# panel's limits are read; they are copied to the q panel below.
plt.draw()
x_lim = axs[0].get_xlim()
y_lim = axs[0].get_ylim()

# Right panel: q draws of `mu` in black, plus the reference circle `circle`
# centered on the analytic posterior mean.
axs[1].scatter(
    z["mu"][:, 0],
    z["mu"][:, 1],
    color="black",
    s=20,
    alpha=0.5
)
axs[1].add_patch(circle)

# q draws of each group's `mu_g`, colored like that group's data, plus the
# circle built around the group's empirical mean.
for g in range(G):
    axs[1].scatter(
        z["mu_g"][:, g, 0],
        z["mu_g"][:, g, 1],
        color=f"C{g}",
        s=20,
        alpha=0.5
    )
    axs[1].add_patch(circles[g])

# Reuse the data panel's limits so both panels share the same view.
axs[1].set_xlim(x_lim)
axs[1].set_ylim(y_lim)
# NOTE(review): only the right panel's tick labels are enlarged; confirm the
# left panel is meant to keep the default tick label size.
axs[1].tick_params(
    which="major",
    labelsize=30
)

axs[0].set_title(
    "Data",
    fontsize=30
)
axs[1].set_title(
    "q samples",
    fontsize=30
)
plt.show()

# %%
# Persist the Keras training history and the final loss (-ELBO) as pickles
# under ../data/, tagged with G, the validation index and the seed.
base_name = (
    "../data/"
    f"GRE_UIVI_G{G}_idx{na_val_idx}_seed{seed}_"
)

# Context managers guarantee each file is flushed and closed as soon as it is
# written, rather than whenever the handle happens to be garbage-collected.
with open(base_name + "history.p", "wb") as history_file:
    pickle.dump(hist_2.history, history_file)

with open(base_name + "loss.p", "wb") as loss_file:
    pickle.dump(- ELBO, loss_file)