
# %% Imports

import pickle
import os
from time import time
import matplotlib.pyplot as plt

import tensorflow as tf
import numpy as np
from tensorflow_probability import distributions as tfd

from pavi.dual.models import (
    PAVFFamily
)
from pavi.utils.callbacks import (TimerCallback, ELBOCallback)
from generative_hbms.GRE import (
    seed,
    encoding_size,
    G, G_reduced, N, scale_mu, scale_mu_g, scale_x,
    pavi_kwargs,
    val_data,
    na_val_idx,
)

# %% PAVFFamily kwargs

# Size shared by every encoding and by the regressor layer below.
d = encoding_size

# Regressor settings: a single layer made of `d` units.
rff_kwargs = {"units_per_layers": [d]}

# Normalizing-flow chain description: a list of (type name, kwargs) pairs,
# here "MAF" followed by "affine", without permutation nor batch norm.
conditional_nf_chain_kwargs = {
    "nf_type_kwargs_per_bijector": [
        ("MAF", {"hidden_units": [32, 32]}),
        ("affine", {"scale_type": "tril", "rff_kwargs": rff_kwargs}),
    ],
    "with_permute": False,
    "with_batch_norm": False,
}

# (scheme name, scheme kwargs) pair reused for both latent variables.
flow_posterior_scheme_kwargs = (
    "flow",
    {"conditional_nf_chain_kwargs": conditional_nf_chain_kwargs},
)

# "mu" and "mu_g" share the very same flow scheme object; "x" is observed.
posterior_schemes_kwargs = {
    "mu": flow_posterior_scheme_kwargs,
    "mu_g": flow_posterior_scheme_kwargs,
    "x": ("observed", {}),
}

# `dict(...)` is kept on purpose for the final merge: a key of `pavi_kwargs`
# clashing with an explicit keyword raises instead of being overridden.
family_kwargs = dict(
    posterior_schemes_kwargs=posterior_schemes_kwargs,
    encoding_sizes={("P",): d, ("P", "G"): d},
    **pavi_kwargs
)

# Forwarded to `compile` as `n_theta_draws`.
n_theta_draws = 8


# %% We build our architecture
# Instantiate the variational family from the kwargs assembled above.
pavf_family = PAVFFamily(**family_kwargs)

# %%

# Validation example `na_val_idx`, with a new leading axis of size 1 added
# in front of it.
observed_x = tf.expand_dims(val_data["x"][na_val_idx], axis=0)
observed_values = {"x": observed_x}

# %% We select the loss and the optimizer used for training

# Adam optimizer driven by an exponentially decaying learning rate
# (starts at 1e-3, decay rate 0.9 with decay_steps=50).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=50,
    decay_rate=0.9,
)
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

pavf_family.compile(
    train_method="reverse_KL",
    n_theta_draws=n_theta_draws,
    optimizer=adam_optimizer,
)

# %% We fit the variational family
# Wall-clock timestamps taken right before and right after the fit; their
# difference is reported later as the training time.
time_1 = time()

# Callbacks are created after the first timestamp, so their construction is
# part of the measured interval.
fit_callbacks = [
    TimerCallback(),
    ELBOCallback(
        elbo_epochs=20,
        p_model=pavi_kwargs["full_hbm"],
        observed_values=observed_values,
    ),
]
hist_2 = pavf_family.fit(
    observed_values,
    batch_size=1,
    epochs=10_000,
    verbose=0,
    callbacks=fit_callbacks,
)

time_2 = time()

# %% We compute the ELBO
# Monte Carlo estimate of the ELBO: mean over the drawn samples of
# log p(sample) - log q(sample).
sample_size = 256
q_sample = pavf_family.sample(
    sample_shape=(sample_size,),
    observed_values=observed_values,
    return_observed_values=True,
)

log_q = pavf_family.log_prob(q_sample)
log_p = pavi_kwargs["full_hbm"].log_prob(q_sample)
ELBO = tf.reduce_mean(log_p - log_q).numpy()

# Report the example index, the training duration and the loss (-ELBO).
summary = f"""
    Idx:  {na_val_idx}
    Time: {time_2 - time_1}
    Loss: {- ELBO}
"""
print(summary)

# %% Plotting

# Set to False to skip the figure entirely.
PLOT = True
if PLOT:
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
    data_ax, q_ax = axs  # left panel: data, right panel: q samples

    x_val = val_data["x"]
    # Radius shared by every per-group reference circle.
    group_radius = 2 * scale_x / N**0.5

    group_means = []
    circles = []
    for g in range(G):
        # Points of group g for the selected example, one color per group.
        data_ax.scatter(
            x_val[na_val_idx, g, :, 0],
            x_val[na_val_idx, g, :, 1],
            color=f"C{g}",
            alpha=0.5,
        )

        # Empirical mean of group g, used as the circle center.
        mean = tf.reduce_mean(x_val[na_val_idx, g], axis=-2)
        group_circle = plt.Circle(
            (mean[0], mean[1]),
            group_radius,
            fill=False,
            color="black",
        )
        circles.append(group_circle)
        group_means.append(mean)

    # Mean of the group means, then a shrunk center and a scale for the
    # population-level circle.
    # NOTE(review): presumably the closed-form Gaussian posterior of "mu"
    # given the group means -- confirm against the generative model.
    stacked_means = tf.stack(group_means, axis=-2)
    population_mean = tf.reduce_mean(stacked_means, axis=-2)
    posterior_mean = population_mean / (1 + scale_mu_g**2/(G * scale_mu**2))
    posterior_scale = (1/(1/scale_mu**2 + G/scale_mu_g**2))**0.5

    circle = plt.Circle(
        (posterior_mean[0], posterior_mean[1]),
        2 * posterior_scale,
        fill=False,
        color="black",
    )

    data_ax.axis("equal")
    data_ax.set_ylabel(
        f"Example {na_val_idx}",
        fontsize=30,
        rotation=0,
        ha="right",
        va="center",
    )
    # Draw once so that the limits read below are the rendered ones.
    plt.draw()
    x_lim = data_ax.get_xlim()
    y_lim = data_ax.get_ylim()

    # Samples of "mu" together with the population-level circle.
    mu_samples = q_sample["mu"]
    q_ax.scatter(
        mu_samples[:, 0, 0],
        mu_samples[:, 0, 1],
        color="black",
        s=20,
        alpha=0.5,
    )
    q_ax.add_patch(circle)

    # Samples of "mu_g" for each group, each with its reference circle.
    mu_g_samples = q_sample["mu_g"]
    for g in range(G):
        q_ax.scatter(
            mu_g_samples[:, 0, g, 0],
            mu_g_samples[:, 0, g, 1],
            color=f"C{g}",
            s=20,
            alpha=0.5,
        )
        q_ax.add_patch(circles[g])

    # Reuse the limits of the data panel so both panels are comparable.
    q_ax.set_xlim(x_lim)
    q_ax.set_ylim(y_lim)
    q_ax.tick_params(which="major", labelsize=30)

    data_ax.set_title("Data", fontsize=30)
    q_ax.set_title("q samples", fontsize=30)
    plt.show()

# %%

# Total number of scalar weights: sum over weight arrays of the product of
# their shape entries.
weight_sizes = [tf.reduce_prod(w.shape) for w in pavf_family.get_weights()]
n_weights = tf.reduce_sum(weight_sizes).numpy()
print(f"G {G} number of weights: {n_weights}")

# %%

# Common prefix of every output file: directory, experiment tag, number of
# groups, validation index and seed.
base_name = "../data/" + f"GRE_PAVI-F_G{G}_idx{na_val_idx}_seed{seed}_"

# %%

# Persist the results of the run as three pickle files sharing `base_name`:
#   * "sample.p":    the q samples converted to NumPy, keyed by `na_val_idx`
#   * "history_2.p": the `history` attribute of the fit result
#   * "loss.p":      the final loss, i.e. the negated ELBO
#
# Each file is written inside a `with` block so that its handle is flushed
# and closed deterministically, even if `pickle.dump` raises; the previous
# bare `open(...)` calls left the handles to the garbage collector.

# Payloads are built before opening the files, so that a failure while
# building one does not leave an empty file behind.
sample_payload = {
    na_val_idx: {
        rv: value.numpy()
        for rv, value in q_sample.items()
    }
}
with open(base_name + "sample.p", "wb") as sample_file:
    pickle.dump(sample_payload, sample_file)

history_payload = hist_2.history
with open(base_name + "history_2.p", "wb") as history_file:
    pickle.dump(history_payload, history_file)

loss_payload = - ELBO
with open(base_name + "loss.p", "wb") as loss_file:
    pickle.dump(loss_payload, loss_file)

# %%
