
# %% Imports

import os
import pickle
from time import perf_counter
from time import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd

from pavi.dual.models import (
    PAVEFamily
)
from pavi.utils.callbacks import (TimerCallback, BatchELBOCallback)
from generative_hbms.HV import (
    seed,
    pavi_kwargs,
    val_data,
    na_val_idx,
    na_dataset,
    N1_full
)

# %% PAVEFamily architecture kwargs

# Width shared by every sub-network below.
d = 128

# Attention configuration: the width is split evenly across the heads.
num_heads = 4
key_dim = d // num_heads
# `k` is handed to the PMA pooling block of the decoder, `m` to every ISAB
# layer of the encoder, and `n_sabs` is the number of stacked ISAB layers.
k = 1
m = 128
n_sabs = 2

# Two-layer feed-forward configuration, reused by every block that needs one.
rff_kwargs = {"units_per_layers": [d, d]}

# Multihead attention block settings. This single dict object is shared by the
# ISAB layers, the PMA block and the decoder's SAB.
mab_kwargs = {
    "multi_head_attention_kwargs": {"num_heads": num_heads, "key_dim": key_dim},
    "rff_kwargs": rff_kwargs,
    "layer_normalization_h_kwargs": {},
    "layer_normalization_out_kwargs": {},
}

# One ISAB layer: both inner MABs use the same settings.
isab_kwargs = {
    "m": m,
    "d": d,
    "mab_h_kwargs": mab_kwargs,
    "mab_out_kwargs": mab_kwargs,
}

# Set-transformer: `n_sabs` ISAB encoder layers (all referencing the same
# `isab_kwargs` dict), followed by a PMA + SAB + RFF decoder.
set_transformer_kwargs = {
    "embedding_size": d,
    "encoder_kwargs": {
        "type": "ISAB",
        "kwargs_per_layer": [isab_kwargs for _ in range(n_sabs)],
    },
    "decoder_kwargs": {
        "pma_kwargs": {
            "k": k,
            "d": d,
            "rff_kwargs": rff_kwargs,
            "mab_kwargs": mab_kwargs,
        },
        "sab_kwargs": mab_kwargs,
        "rff_kwargs": rff_kwargs,
    },
}

# Conditional normalizing-flow chain: a MAF bijector followed by a
# diagonal-scale affine bijector, with no permutation or batch-norm layers.
conditional_nf_chain_kwargs = {
    "nf_type_kwargs_per_bijector": [
        ("MAF", {"hidden_units": [d, d]}),
        ("affine", {"scale_type": "diag", "rff_kwargs": rff_kwargs}),
    ],
    "with_permute": False,
    "with_batch_norm": False,
}

# (scheme name, scheme kwargs) pair for a flow-based posterior.
flow_posterior_scheme_kwargs = (
    "flow",
    {"conditional_nf_chain_kwargs": conditional_nf_chain_kwargs},
)

# Per-variable posterior scheme: "a" and "b" share the flow scheme, while "c"
# is the observed variable (see `observed_values` later in the script).
posterior_schemes_kwargs = {
    "a": flow_posterior_scheme_kwargs,
    "b": flow_posterior_scheme_kwargs,
    "c": ("observed", {}),
}

# Complete keyword set for PAVEFamily: the architecture choices above plus
# the model-specific entries exported by generative_hbms.HV as `pavi_kwargs`.
# `dict(**...)` is used deliberately: a key present in both sources raises a
# TypeError instead of being silently overridden.
pavi_family_kwargs = dict(
    posterior_schemes_kwargs=posterior_schemes_kwargs,
    encoding_sizes={("N0",): d, ("N0", "N1"): d},
    embedder_rff_kwargs=rff_kwargs,
    set_transformer_kwargs=set_transformer_kwargs,
    **pavi_kwargs
)


# %% We build our architecture
pavi_family = PAVEFamily(**pavi_family_kwargs)

# %% We select the loss used for training
# Reverse-KL objective with `n_theta_draws=8`. Adam's learning rate starts at
# 1e-3 and decays exponentially by a factor 0.9 per 200 optimizer steps.
pavi_family.compile(
    train_method="reverse_KL",
    n_theta_draws=8,
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=1e-3, decay_steps=200, decay_rate=0.9
        )
    ),
)

# Observed data for variable "c": the validation example `na_val_idx`, with a
# leading axis of size 1 prepended.
observed_values = {
    "c": tf.expand_dims(val_data["c"][na_val_idx], axis=0)
}

# %% We fit the training data

# Periodic ELBO monitor, scored against the full HBM on the validation
# example. `elbo_batches` and `sample_size` are forwarded unchanged to
# BatchELBOCallback (their exact semantics live in pavi.utils.callbacks).
call_me_maybe = BatchELBOCallback(
    elbo_batches=25,
    p_model=pavi_kwargs["full_hbm"],
    observed_values=observed_values,
    sample_size=32
)

# perf_counter is a monotonic, high-resolution clock, so the measured
# training duration cannot be skewed by system clock adjustments during the
# run (time.time() is wall-clock time and offers no such guarantee).
time_1 = perf_counter()
history = pavi_family.fit(
    na_dataset,
    batch_size=1,
    epochs=1_000,
    callbacks=[
        TimerCallback(),
        call_me_maybe
    ]
)
time_2 = perf_counter()

# %% We compute the ELBO
# Monte-Carlo estimate of the ELBO, mean(log p - log q), over 256 draws from
# the trained family. The observed "c" values are returned alongside each
# draw, presumably so the full HBM can score the complete sample.
sample_size = 256
q_sample = pavi_family.sample(
    sample_shape=(sample_size,),
    observed_values=observed_values,
    return_observed_values=True,
)
q = pavi_family.log_prob(q_sample)
p = pavi_kwargs["full_hbm"].log_prob(q_sample)
ELBO = tf.reduce_mean(p - q).numpy()

# Report the example index, the training wall time and the loss (= -ELBO).
print(
    "\n"
    f"    Idx:  {na_val_idx}\n"
    f"    Time: {time_2 - time_1}\n"
    f"    Loss: {-ELBO}\n"
)

# %%
# Side-by-side log-scale comparison: the observed "c" data of the selected
# validation example (left) versus samples from the trained family (right),
# coloured per N1 group.

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))

# Left panel: one scatter per N1 group, using columns 0 and 1 of the last
# axis as x and y (presumably a 2-D observation; TODO confirm in
# generative_hbms.HV).
for n1 in range(N1_full):
    axs[0].scatter(
        np.log(val_data["c"][na_val_idx, 0, n1, :, 0]),
        np.log(val_data["c"][na_val_idx, 0, n1, :, 1]),
        color=f"C{n1}",
        alpha=0.5,
    )

axs[0].axis("equal")
axs[0].set_ylabel(
    f"Example {na_val_idx}", fontsize=30, rotation=0, ha="right", va="center"
)
plt.draw()

# Fixed viewing window in log space. Autoscaled limits could instead be read
# from axs[0].get_xlim() / axs[0].get_ylim() after the draw above.
x_lim = [-5, 5]
y_lim = [-5, 5]

# Right panel: "a" samples (no N1 axis) in black, then the "b" samples of
# each N1 group in the same colour as the corresponding data group.
axs[1].scatter(
    np.log(q_sample["a"][:, 0, 0]),
    np.log(q_sample["a"][:, 0, 1]),
    color="black",
    s=20,
    alpha=0.5,
)

for n1 in range(N1_full):
    axs[1].scatter(
        np.log(q_sample["b"][:, 0, n1, 0]),
        np.log(q_sample["b"][:, 0, n1, 1]),
        color=f"C{n1}",
        s=20,
        alpha=0.5,
    )

# Identical limits and tick styling on both panels.
for ax in axs:
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
    ax.tick_params(which="major", labelsize=30)

axs[0].set_title("Data", fontsize=30)
axs[1].set_title("q samples", fontsize=30)
plt.show()

# %%
# Persist the Keras training history (plus the callback's ELBO trace) and the
# final loss (-ELBO) as pickles under ../data/.

# Output prefix; the relative path is resolved against the current working
# directory.
base_name = (
    "../data/"
    f"HV_PAVI-E_idx{na_val_idx}_seed{seed}_"
)

# The payload is built before the file is opened, so a failure while
# collecting it does not leave an empty pickle behind.
history_payload = {
    **history.history,
    # NOTE(review): this stride of 25 appears to mirror `elbo_batches=25`
    # given to BatchELBOCallback; confirm they stay in sync if either changes.
    "ELBO": call_me_maybe.ELBOs[::25]
}

# `with` flushes and closes each file deterministically instead of relying on
# garbage collection of an unreferenced file object.
with open(base_name + "history.p", "wb") as history_file:
    pickle.dump(history_payload, history_file)

with open(base_name + "loss.p", "wb") as loss_file:
    pickle.dump(-ELBO, loss_file)
