
# %% Imports

import pickle
import os
from time import time

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow_probability import distributions as tfd

from pavi.dual.models import (
    PAVFFamily
)
from pavi.utils.callbacks import (
    TimerCallback,
    ELBOCallback,
)
from generative_hbms.HCPL import (
    pavi_kwargs,
    val_data,
    na_val_idx
)


# %% PAVFFamily kwargs

# Two widths are used throughout this configuration: d_1 for the "light"
# variants and d_2 for the "heavy" ones.
d_1 = 8
d_2 = 32


def _conditional_nf_chain_kwargs(width):
    """Return the chain kwargs for a two-bijector ("MAF" then "affine") chain.

    Both the MAF ``hidden_units`` and the affine ``rff_kwargs`` use two layers
    of ``width`` units. A fresh set of containers is built on every call, so
    the returned configurations never share mutable state.
    """
    maf_bijector = ("MAF", {"hidden_units": [width, width]})
    affine_bijector = (
        "affine",
        {
            "scale_type": "diag",
            "rff_kwargs": {"units_per_layers": [width, width]},
        },
    )
    return {
        "nf_type_kwargs_per_bijector": [maf_bijector, affine_bijector],
        "with_permute": False,
        "with_batch_norm": False,
    }


conditional_nf_chain_kwargs_light = _conditional_nf_chain_kwargs(d_1)
conditional_nf_chain_kwargs_heavy = _conditional_nf_chain_kwargs(d_2)

# A posterior scheme is a (scheme name, scheme kwargs) pair.
flow_posterior_scheme_kwargs_light = (
    "flow",
    {"conditional_nf_chain_kwargs": conditional_nf_chain_kwargs_light},
)
flow_posterior_scheme_kwargs_heavy = (
    "flow",
    {"conditional_nf_chain_kwargs": conditional_nf_chain_kwargs_heavy},
)

# Keys that receive the heavy flow; every other flow key gets the light one.
_heavy_flow_keys = frozenset({"mu_g", "mu_s", "mu_s_t"})
# Listed in the order in which they must appear in the final mapping.
_flow_keys = (
    "mu_g",
    "epsilon",
    "mu_s",
    "sigma",
    "mu_s_t",
    "probs",
    "labels",
    "kappa",
)

posterior_schemes_kwargs = {
    key: (
        flow_posterior_scheme_kwargs_heavy
        if key in _heavy_flow_keys
        else flow_posterior_scheme_kwargs_light
    )
    for key in _flow_keys
}
# "X_s_t" is the only entry that uses the "observed" scheme instead of a flow.
posterior_schemes_kwargs["X_s_t"] = ("observed", {})

# Encoding size per tuple key: d_2 for all keys except ('P', 'S', 'N'),
# which uses the smaller d_1.
_encoding_sizes = {
    ("P",): d_2,
    ("P", "S"): d_2,
    ("P", "S", "T"): d_2,
    ("P", "S", "N"): d_1,
}

# Full constructor kwargs for the family: the two entries defined in this
# script, followed by everything imported in `pavi_kwargs`. `dict(...)` is
# used on purpose so that a clash between `pavi_kwargs` and the explicit
# keywords raises instead of being silently overridden.
family_kwargs = dict(
    posterior_schemes_kwargs=posterior_schemes_kwargs,
    encoding_sizes=_encoding_sizes,
    **pavi_kwargs
)

# Training settings consumed by the compile / fit calls further down.
epochs = 1_500
train_method = "reverse_KL"
n_theta_draws = 8


# %% We build our architecture
pavf_family = PAVFFamily(**family_kwargs)

# %% We select the loss used for training
pavf_family.compile(
    train_method=train_method, n_theta_draws=n_theta_draws, optimizer="adam"
)

# The only observed entry is "X_s_t": the validation array indexed by
# `na_val_idx`.
observed_values = dict(X_s_t=val_data["X_s_t"][na_val_idx])

# %% We fit the training data
# time_1 / time_2 are wall-clock stamps taken right before and right after
# the fit; the callbacks are created after time_1, as part of the timed span.
time_1 = time()
_fit_callbacks = [
    TimerCallback(),
    # NOTE(review): elbo_epochs=50 presumably sets how often the ELBO is
    # evaluated against the full model — confirm in ELBOCallback.
    ELBOCallback(
        elbo_epochs=50,
        p_model=pavi_kwargs["full_hbm"],
        observed_values=observed_values,
    ),
]
history = pavf_family.fit(
    observed_values,
    batch_size=1,
    epochs=epochs,
    verbose=2,
    callbacks=_fit_callbacks
)
time_2 = time()

# %%
