
# %% Imports

import pickle
import os
from time import time
import numpy as np

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow_probability import distributions as tfd

from pavi.utils.callbacks import (
    TimerCallback,
    BatchELBOCallback
)
from pavi.dual.models import (
    PAVEFamily
)
from generative_hbms.GRE import (
    seed,
    G, N, D, scale_mu, scale_mu_g, scale_x,
    G_reduced, N_reduced,
    pavi_kwargs,
    reduced_train_data,
    val_data,
)

# %% PAVEFamily kwargs

# Common width used for embeddings, attention blocks and encodings.
d = 16

num_heads = 2
# Per-head key size, chosen so that num_heads * key_dim == d.
key_dim = d // num_heads
# Passed to the PMA decoder as `k`.
k = 1
# Passed to each ISAB layer as `m` (presumably the number of inducing points).
m = 64
# Number of ISAB layers in the set-transformer encoder.
n_sabs = 2

# Row-wise feed-forward config: one layer of `d` units.
rff_kwargs = {"units_per_layers": [d]}

# Multi-head attention block config, reused by the ISAB, PMA and SAB configs.
mab_kwargs = {
    "multi_head_attention_kwargs": {
        "num_heads": num_heads,
        "key_dim": key_dim,
    },
    "rff_kwargs": rff_kwargs,
    "layer_normalization_h_kwargs": {},
    "layer_normalization_out_kwargs": {},
}

isab_kwargs = {
    "m": m,
    "d": d,
    "mab_h_kwargs": mab_kwargs,
    "mab_out_kwargs": mab_kwargs,
}

# Encoder: `n_sabs` ISAB layers that all share the same config object.
set_transformer_kwargs = {
    "embedding_size": d,
    "encoder_kwargs": {
        "type": "ISAB",
        "kwargs_per_layer": [isab_kwargs] * n_sabs,
    },
    "decoder_kwargs": {
        "pma_kwargs": {
            "k": k,
            "d": d,
            "rff_kwargs": rff_kwargs,
            "mab_kwargs": mab_kwargs,
        },
        "sab_kwargs": mab_kwargs,
        "rff_kwargs": rff_kwargs,
    },
}

# Conditional normalizing-flow chain, bijectors listed in order: a MAF with
# two hidden layers of 32 units, then an affine bijector with
# scale_type="tril" (presumably a lower-triangular scale).
conditional_nf_chain_kwargs = {
    "nf_type_kwargs_per_bijector": [
        ("MAF", {"hidden_units": [32, 32]}),
        ("affine", {"scale_type": "tril", "rff_kwargs": rff_kwargs}),
    ],
    "with_permute": False,
    "with_batch_norm": False,
}

flow_posterior_scheme_kwargs = (
    "flow",
    {"conditional_nf_chain_kwargs": conditional_nf_chain_kwargs},
)

# mu and mu_g get the flow scheme; x is marked as observed.
posterior_schemes_kwargs = {
    "mu": flow_posterior_scheme_kwargs,
    "mu_g": flow_posterior_scheme_kwargs,
    "x": ("observed", {}),
}

# Kept as a dict() call on purpose: dict(..., **pavi_kwargs) raises on a
# duplicated key instead of silently overriding it.
family_kwargs = dict(
    posterior_schemes_kwargs=posterior_schemes_kwargs,
    encoding_sizes={
        ("P",): d,
        ("P", "G"): d,
    },
    embedder_rff_kwargs=rff_kwargs,
    set_transformer_kwargs=set_transformer_kwargs,
    **pavi_kwargs
)

# Number of theta draws passed to both `compile` calls.
n_theta_draws = 8


# %% We build our architecture
# Instantiate the variational family from the configuration above.
pave_family = PAVEFamily(**family_kwargs)

# %% Unregularized ELBO
# First training stage: unregularized ELBO, optimizer given by the string
# identifier "adam".
first_stage_compile_kwargs = dict(
    train_method="unregularized_ELBO",
    n_theta_draws=n_theta_draws,
    optimizer="adam",
)
pave_family.compile(**first_stage_compile_kwargs)

# %% We fit the training data

# Callback that tracks the ELBO on the first 20 validation examples.
# Each example gets a singleton axis at position 1, the same layout the
# training data is reshaped to below.
c1 = BatchELBOCallback(
    elbo_batches=20,
    p_model=pavi_kwargs["full_hbm"],
    observed_values={
        "x": tf.expand_dims(
            val_data["x"][:20],
            axis=1
        )
    },
    sample_size=256
)

time_1 = time()
# Insert a singleton axis after the example axis. The number of training
# examples is inferred (-1) rather than hard-coded, so the cell works for
# any training-set size.
hist_1 = pave_family.fit(
    {
        "x": tf.reshape(
            reduced_train_data["x"],
            (-1, 1, G_reduced, N_reduced, D)
        )
    },
    batch_size=32,
    epochs=4,
    shuffle=True,
    callbacks=[
        TimerCallback(),
        c1
    ]
)

# %% reverse KL
# Second training stage: reverse KL. The learning rate follows
# 1e-3 * 0.9 ** (step / 100), using ExponentialDecay with its default
# staircase=False.
reverse_kl_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=100,
    decay_rate=0.9,
)
reverse_kl_optimizer = tf.keras.optimizers.Adam(
    learning_rate=reverse_kl_lr_schedule
)
pave_family.compile(
    train_method="reverse_KL",
    n_theta_draws=n_theta_draws,
    optimizer=reverse_kl_optimizer,
)

# %% We fit the training data
# Continue training the same family under the reverse-KL objective.
# Insert a singleton axis at position 1; the number of training examples is
# inferred (-1) rather than hard-coded.
hist_2 = pave_family.fit(
    {
        "x": tf.reshape(
            reduced_train_data["x"],
            (-1, 1, G_reduced, N_reduced, D)
        )
    },
    batch_size=32,
    epochs=10,
    shuffle=True,
    callbacks=[
        TimerCallback(),
        c1
    ]
)
time_2 = time()
# time_1 was taken just before the first fit, so this spans both stages.
print(f"Training: {time_2 - time_1}")

# %% We compute the negative ELBO (mean of log q - log p) on validation examples
# We store away samples for comparison
n_draws = 256
samples = {}
losses = []
for val_idx in range(20):
    # Condition on one validation example: keep a leading axis of size 1 and
    # insert a singleton axis at position 1.
    observed = {
        "x": tf.expand_dims(val_data["x"][val_idx:val_idx + 1], axis=1)
    }
    draws = pave_family.sample(
        sample_shape=(n_draws,),
        observed_values=observed,
        return_observed_values=True,
    )
    # Store NumPy copies of every sampled (and observed) value.
    samples[val_idx] = {name: tensor.numpy() for name, tensor in draws.items()}

    # Average of log q - log p over the stored draws (negative-ELBO estimate).
    q = pave_family.log_prob(values=samples[val_idx])
    p = pave_family.hbms["full"].log_prob(**samples[val_idx])
    loss = tf.reduce_mean(q - p)
    losses.append(loss.numpy())

# Elapsed time covers both sampling and the log-prob evaluations above.
time_3 = time()
print(f"Sampling: {time_3-time_2}")
print(
    "Losses:\n",
    losses,
    "\nMean:\n",
    tf.reduce_mean(losses),
    "\nStd:\n",
    tf.math.reduce_std(losses)
)

# %%

# Visual check of the first 3 validation examples.
# Left column: observed points, coloured by group.
# Right column: draws of mu (black) and mu_g (group colours) taken from
# pave_family in the previous cell, overlaid with black reference circles.
PLOT = True
if PLOT:
    fig, axs = plt.subplots(
        nrows=3,
        ncols=2,
        figsize=(20, 30)
    )
    for val_idx in range(0, 3):
        group_means = []
        circles = []
        for g in range(G):
            # Scatter the first two coordinates of group g's points.
            axs[val_idx, 0].scatter(
                val_data["x"][val_idx, g, :, 0],
                val_data["x"][val_idx, g, :, 1],
                color=f"C{g}",
                alpha=0.5
            )

            # Empirical mean of group g's points (reduced over the point axis).
            mean = tf.reduce_mean(
                val_data["x"][val_idx, g],
                axis=-2
            )

            # Circle around the group mean with radius 2 * scale_x / sqrt(N).
            # NOTE(review): presumably two standard errors of the group mean,
            # assuming per-point noise std scale_x and N points per group;
            # confirm against the GRE model definition.
            circles.append(
                plt.Circle(
                    (mean[0], mean[1]),
                    2 * scale_x / N**0.5,
                    fill=False,
                    color="black",
                )
            )

            group_means.append(mean)

        # Average of the G empirical group means.
        population_mean = tf.reduce_mean(
            tf.stack(
                group_means,
                axis=-2
            ),
            axis=-2
        )
        # Presumably a conjugate-Gaussian reference posterior for mu.
        # NOTE(review): only scale_mu and scale_mu_g appear here, i.e. the
        # empirical group means are treated as exact mu_g and their
        # scale_x**2 / N noise is ignored; confirm this approximation is
        # intended.
        posterior_mean = population_mean / (1 + scale_mu_g**2/(G * scale_mu**2))
        posterior_scale = (1/(1/scale_mu**2 + G/scale_mu_g**2))**0.5

        # Reference circle for mu, with radius 2 * posterior_scale.
        circle = plt.Circle(
            (posterior_mean[0], posterior_mean[1]),
            2 * posterior_scale,
            fill=False,
            color="black"
        )

        axs[val_idx, 0].axis("equal")
        axs[val_idx, 0].set_ylabel(
            f"Example {val_idx}",
            fontsize=30,
            rotation=0,
            ha="right",
            va="center"
        )
        # Render first so the data panel's autoscaled limits are final, then
        # reuse them on the posterior panel so both columns share one view.
        plt.draw()
        x_lim = axs[val_idx, 0].get_xlim()
        y_lim = axs[val_idx, 0].get_ylim()

        # Draws of mu: first two coordinates.
        axs[val_idx, 1].scatter(
            samples[val_idx]["mu"][:, 0, 0],
            samples[val_idx]["mu"][:, 0, 1],
            color="black",
            s=20,
            alpha=0.5
        )
        axs[val_idx, 1].add_patch(circle)

        # Draws of mu_g per group, each with its group-mean reference circle.
        for g in range(G):
            axs[val_idx, 1].scatter(
                samples[val_idx]["mu_g"][:, 0, g, 0],
                samples[val_idx]["mu_g"][:, 0, g, 1],
                color=f"C{g}",
                s=20,
                alpha=0.5
            )
            axs[val_idx, 1].add_patch(circles[g])

        axs[val_idx, 1].set_xlim(x_lim)
        axs[val_idx, 1].set_ylim(y_lim)
        axs[val_idx, 1].tick_params(
            which="major",
            labelsize=30
        )

    axs[0, 0].set_title(
        "Data",
        fontsize=30
    )
    axs[0, 1].set_title(
        "posterior samples",
        fontsize=30
    )
    plt.show()

# %%

# Total number of scalar entries across all weights of the family.
# Counted with NumPy: tf.reduce_prod(()) on a scalar weight's empty shape
# yields a float32 tensor, and tf.reduce_sum over a mixed int32/float32 list
# fails. np.prod(()) == 1.0, so int(...) counts a scalar weight as 1.
n_weights = sum(int(np.prod(w.shape)) for w in pave_family.get_weights())
print(
    f"G {G} number of weights: {n_weights}"
)

# %%

# Common prefix for every artifact written by this run.
base_name = (
    "../data/"
    f"GRE_PAVI-E_sa_G{G}_seed{seed}_"
)
# Create the output directory if it is missing, so the dumps do not fail.
os.makedirs(os.path.dirname(base_name), exist_ok=True)

# Each object goes to `<base_name><suffix>.p`. The `with` block guarantees
# the file is flushed and closed; the original left every handle open.
for suffix, artifact in (
    ("sample", samples),
    ("history_1", hist_1.history),
    ("history_2", hist_2.history),
    ("elbos", c1.ELBOs),
    ("losses", losses),
):
    with open(base_name + suffix + ".p", "wb") as artifact_file:
        pickle.dump(artifact, artifact_file)

# %%
