
# %% Imports
import argparse
import pickle
import random
from random import shuffle
from typing import Dict

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from pavi.utils import get_stops

tfd = tfp.distributions
tfb = tfp.bijectors


# %% Argument parsing

# Command-line options. `parse_known_args` is used so that arguments this
# script does not declare are ignored instead of aborting the run.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--na-val-idx",
    type=int,
    default=4,
    required=False,
    help="Index of the validation sample served by NADataset."
)
parser.add_argument(
    "--random-seed",
    type=int,
    default=4321,
    required=False,
    help="Seed applied to the TensorFlow, NumPy and Python random generators."
)
args, _ = parser.parse_known_args()

# %%

na_val_idx = args.na_val_idx
seed = args.random_seed
tf.random.set_seed(seed)
np.random.seed(seed)
# `random.shuffle` is used further down to permute the groups; seed the
# stdlib generator as well so that this permutation is reproducible too.
random.seed(seed)

# %% Generative Hierarchical Bayesian Model

# Model sizes: D is the dimension of each point, L the number of mixture
# components, G the number of groups and N the number of points per group.
D, L, G, N = 2, 3, 20, 10

# Prior location / scales and the Dirichlet concentration of the weights.
loc_mu = tf.zeros((D,))
scale_mu, scale_mu_g, scale_x = 1.0, 0.5, 0.1
dirichlet_concentration = tf.ones((L,)) * 1

generative_hbm = tfd.JointDistributionNamed(
    model={
        # Population-level component means, event shape (L, D).
        "mu": tfd.Sample(
            tfd.Independent(
                tfd.Normal(loc=loc_mu, scale=scale_mu),
                reinterpreted_batch_ndims=1
            ),
            sample_shape=(L,),
            name="mu"
        ),
        # Group-level component means centred on `mu`, event shape (G, L, D).
        "mu_g": lambda mu: tfd.Sample(
            tfd.Independent(
                tfd.Normal(loc=mu, scale=scale_mu_g),
                reinterpreted_batch_ndims=2
            ),
            sample_shape=(G,),
            name="mu_g"
        ),
        # Per-group mixture weights, event shape (G, L).
        "probs": tfd.Sample(
            tfd.Dirichlet(concentration=dirichlet_concentration),
            sample_shape=(G,),
            name="probs"
        ),
        # Points drawn from each group's mixture. `Sample` puts the N axis
        # first; the transpose moves it behind the group axis -> (G, N, D).
        "x": lambda mu_g, probs: tfd.TransformedDistribution(
            tfd.Sample(
                tfd.Independent(
                    tfd.Mixture(
                        cat=tfd.Categorical(probs=probs),
                        components=[
                            tfd.Independent(
                                tfd.Normal(
                                    loc=mu_g[..., i, :],
                                    scale=scale_x
                                ),
                                reinterpreted_batch_ndims=1
                            )
                            for i in range(L)
                        ]
                    ),
                    reinterpreted_batch_ndims=1
                ),
                sample_shape=(N,)
            ),
            bijector=tfb.Transpose(perm=[1, 0, 2]),
            name="x"
        )
    }
)

# Bijector attached to each random variable; only `probs` is not Identity.
link_functions = dict(
    x=tfb.Identity(),
    mu_g=tfb.Identity(),
    probs=tfb.SoftmaxCentered(),
    mu=tfb.Identity()
)

# Hierarchy level of each random variable (0 for `x`, 2 for `mu`).
hierarchies = dict(x=0, mu_g=1, probs=1, mu=2)

hbm_kwargs = {
    "generative_hbm": generative_hbm,
    "hierarchies": hierarchies,
    "link_functions": link_functions,
}

# %% Dataset generation

dataset_name = "../data/GM_dataset.p"
try:
    # NOTE: pickle executes arbitrary code on load; only point this at a
    # cache file written by this script.
    with open(dataset_name, "rb") as dataset_file:
        dataset = pickle.load(dataset_file)
except FileNotFoundError:
    # No cached dataset yet: sample one from the generative model with a
    # fixed seed, then cache it as plain NumPy arrays.
    # NOTE(review): this rebinds the module-level `seed` and re-seeds
    # TensorFlow, so --random-seed no longer applies to the rest of a run
    # that had to generate the dataset -- confirm this is intended.
    seed = 1234
    tf.random.set_seed(seed)
    train_size, val_size = 20_000, 2000
    train_data, val_data = (
        generative_hbm.sample(size)
        for size in [train_size, val_size]
    )
    dataset = {
        data_key: {
            key: value.numpy()
            for key, value in data.items()
        }
        for data_key, data in [
            ("train", train_data),
            ("val", val_data)
        ]
    }
    with open(dataset_name, "wb") as dataset_file:
        pickle.dump(dataset, dataset_file)

# Always expose the NumPy version, so that a freshly generated dataset and a
# cached one give the rest of the script the same types.
train_data = dataset["train"]
val_data = dataset["val"]

# %% Ground graph, used by CF

# Unrolled ("ground") version of the generative model: instead of stacked
# tensors there is one named node per group (`mu_{g}`, `probs_{g}`) and one
# per observed point (`mu_{g}_{n}`), for G groups of N points each.
ground_hbm = tfd.JointDistributionNamed(
    model={
        # Population-level component means.
        "mu": tfd.Sample(
            tfd.Independent(
                tfd.Normal(
                    loc=loc_mu,
                    scale=scale_mu
                ),
                reinterpreted_batch_ndims=1
            ),
            sample_shape=(L,)
        ),
        # One node of component means per group, each centred on `mu`.
        **{
            f"mu_{g}": lambda mu: tfd.Independent(
                tfd.Normal(
                    loc=mu,
                    scale=scale_mu_g
                ),
                reinterpreted_batch_ndims=2
            )
            for g in range(G)
        },
        # One node of mixture weights per group.
        **{
            f"probs_{g}": tfd.Dirichlet(
                concentration=dirichlet_concentration
            )
            for g in range(G)
        },
        # One mixture node per (group, point) pair. JointDistributionNamed
        # identifies a node's parents by the lambda's argument names, so the
        # argument names must contain the group index; the lambdas are
        # therefore generated from source text with `eval`.
        # NOTE: the evaluated string is built only from the integer loop
        # index `g`, never from external input.
        **{
            f"mu_{g}_{n}": eval(
                f"""lambda mu_{g}, probs_{g}: tfd.Mixture(
                    cat=tfd.Categorical(probs=probs_{g}),
                    components=[
                        tfd.Independent(
                            tfd.Normal(
                                loc=mu_{g}[..., i, :],
                                scale=scale_x
                            ),
                            reinterpreted_batch_ndims=1
                        )
                        for i in range(L)
                    ]
                    )"""
            )
            for g in range(G)
            for n in range(N)
        }
    }
)

# Keyword arguments describing the ground model: the model itself, the names
# of its observed nodes, and one bijector per node.
cf_hbm_kwargs = {
    "generative_hbm": ground_hbm,
    # Every per-point node `mu_{g}_{n}` is observed.
    "observed_rvs": [
        f"mu_{g}_{n}" for g in range(G) for n in range(N)
    ],
    # Identity everywhere except for the per-group mixture weights.
    "link_functions": {
        "mu": tfb.Identity(),
        **{f"mu_{g}": tfb.Identity() for g in range(G)},
        **{f"probs_{g}": tfb.SoftmaxCentered() for g in range(G)},
        **{
            f"mu_{g}_{n}": tfb.Identity()
            for g in range(G)
            for n in range(N)
        },
    },
    "observed_rv_reshapers": {
        f"mu_{g}_{n}": tfb.Identity() for g in range(G) for n in range(N)
    },
}

# %% Data reshaping


def stack_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Regroup per-node ("ground") values into stacked tensors.

    Args:
        data: mapping holding `mu`, `mu_{g}` and `probs_{g}` for every group
            `g` in `range(G)`, and optionally `mu_{g}_{n}` for every point
            `n` in `range(N)`.

    Returns:
        A dict with `mu`, `mu_g` (groups stacked on axis -3) and `probs`
        (groups stacked on axis -2). `x` is added, with groups on axis -3
        and points on axis -2, only when all `mu_{g}_{n}` keys are present.

    Raises:
        KeyError: if `mu`, a `mu_{g}` or a `probs_{g}` entry is missing.
    """
    stacked = {
        "mu": data["mu"],
        "mu_g": tf.stack(
            [data[f"mu_{g}"] for g in range(G)],
            axis=-3
        ),
        "probs": tf.stack(
            [data[f"probs_{g}"] for g in range(G)],
            axis=-2
        ),
    }

    try:
        per_group = [
            tf.stack(
                [data[f"mu_{g}_{n}"] for n in range(N)],
                axis=-2
            )
            for g in range(G)
        ]
    except KeyError:
        # Observed nodes are absent: return the latent variables only.
        return stacked

    stacked["x"] = tf.stack(per_group, axis=-3)
    return stacked


def slice_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Split stacked values into one entry per ground-model node.

    Args:
        data: mapping holding `mu`, `mu_g` and `probs`, and optionally `x`.

    Returns:
        A dict with `mu`, then `mu_{g}` / `probs_{g}` for every group `g` in
        `range(G)`, and, only when `x` is present, `mu_{g}_{n}` for every
        point `n` in `range(N)`.

    Raises:
        KeyError: if `mu`, `mu_g` or `probs` is missing.
    """
    sliced = {"mu": data["mu"]}
    for group in range(G):
        sliced[f"mu_{group}"] = data["mu_g"][..., group, :, :]
        sliced[f"probs_{group}"] = data["probs"][..., group, :]

    try:
        sliced.update({
            f"mu_{group}_{point}": data["x"][..., group, point, :]
            for group in range(G)
            for point in range(N)
        })
    except KeyError:
        # No observed data in `data`: only the latent variables are sliced.
        pass

    return sliced


# %% CF Data

# Per-node views of both splits, keyed like the nodes of the ground model.
cf_train_data, cf_val_data = (
    slice_data(split)
    for split in (train_data, val_data)
)

# %% Batch_shape HBM, used for template training

def get_batched_hbm(
    G: int,
    N: int
):
    """Build the hierarchical model with an extra leading axis of size 1.

    The variables are the same as in the module-level generative model
    (`mu`, `mu_g`, `probs`, `x`), but each one carries an additional leading
    axis of size 1 and the group / point counts are arguments.

    Args:
        G: number of groups.
        N: number of points drawn per group.

    Returns:
        A `tfd.JointDistributionNamed` over `mu`, `mu_g`, `probs` and `x`.
    """
    # Population-level component means, sampled with shape (1, L) on top of
    # the event of the underlying Normal.
    population_prior = tfd.Sample(
        tfd.Independent(
            tfd.Normal(loc=loc_mu, scale=scale_mu),
            reinterpreted_batch_ndims=1,
        ),
        sample_shape=(1, L),
        name="mu"
    )

    # Per-group mixture weights, sampled with shape (1, G).
    weights_prior = tfd.Sample(
        tfd.Dirichlet(concentration=dirichlet_concentration),
        sample_shape=(1, G),
        name="probs"
    )

    def group_means(mu):
        # Repeat the population means G times along a new axis at -3, then
        # treat the last four axes as a single event.
        tiled_mu = tf.stack([mu] * G, axis=-3)
        return tfd.Independent(
            tfd.Normal(loc=tiled_mu, scale=scale_mu_g),
            reinterpreted_batch_ndims=4,
            name="mu_g"
        )

    def observations(mu_g, probs):
        # One Gaussian component per mixture index, read from `mu_g`.
        components = [
            tfd.Independent(
                tfd.Normal(loc=mu_g[..., i, :], scale=scale_x),
                reinterpreted_batch_ndims=1
            )
            for i in range(L)
        ]
        mixture = tfd.Independent(
            tfd.Mixture(
                cat=tfd.Categorical(probs=probs),
                components=components
            ),
            name="x",
            reinterpreted_batch_ndims=2
        )
        # `Sample` puts the N axis first; the transpose moves it to the
        # third position of the event.
        return tfd.TransformedDistribution(
            tfd.Sample(mixture, sample_shape=(N,)),
            bijector=tfb.Transpose([1, 2, 0, 3])
        )

    # The argument names of the callables (`mu`, `mu_g`, `probs`) are what
    # ties each node to its parents.
    return tfd.JointDistributionNamed(
        model={
            "mu": population_prior,
            "mu_g": group_means,
            "probs": weights_prior,
            "x": observations,
        }
    )

# Number of groups in the reduced model.
G_reduced = 5

# Full-size and reduced models, the plates each random variable belongs to,
# and the bijector attached to each variable.
pavi_kwargs = {
    "full_hbm": get_batched_hbm(G=G, N=N),
    "reduced_hbm": get_batched_hbm(G=G_reduced, N=N),
    "plates_per_rv": {
        "mu": ["P"],
        "mu_g": ["P", "G"],
        "probs": ["P", "G"],
        "x": ["P", "G", "N"],
    },
    "link_functions": link_functions,
}

# %%


class NADataset(tf.data.Dataset):
    """Serve one validation sample in chunks of `G_reduced` groups.

    Instantiating the class does not create an `NADataset` object:
    `__new__` returns the dataset that `tf.data.Dataset.from_generator`
    builds around `_generator`. Each element is a pair
    `(data, plate_indices)`:

    * `data["x"]`: float32 tensor of shape (1, 1, G_reduced, N, D);
    * `plate_indices`: int32 index vectors for the plates "P" (shape (1,)),
      "G" (shape (G_reduced,)) and "N" (shape (N,)).
    """

    @staticmethod
    def _generator():
        """Yield the chunks of validation sample `na_val_idx`.

        The groups are visited in an order drawn with `random.shuffle`.
        """
        g_list = list(range(G))
        shuffle(g_list)

        # `np.asarray` makes the list-based indexing below work whether
        # `val_data` holds NumPy arrays or tensors (tensors cannot be
        # indexed with a Python list).
        X = np.asarray(val_data["x"])[na_val_idx:na_val_idx + 1]
        # Add a leading axis and reorder the groups: (1, 1, G, N, D).
        X = X[None, :, g_list]

        # Assumes `get_stops` yields the exclusive end index of each chunk
        # of `G_reduced` groups -- TODO confirm against pavi.utils.
        for g in get_stops(
            full_size=G,
            batch_size=G_reduced
        ):
            yield (
                {
                    "x": X[:, :, g - G_reduced:g]
                },
                {
                    # Explicit int32 so the value matches the dtype declared
                    # in `output_signature` (tf.zeros defaults to float32).
                    "P": tf.zeros((1,), dtype=tf.int32),
                    # Original (pre-shuffle) indices of the groups served.
                    "G": g_list[g - G_reduced:g],
                    "N": tf.range(0, N)
                }
            )

    def __new__(cls):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=(
                {
                    "x": tf.TensorSpec(
                        shape=(1, 1, G_reduced, N, D),
                        dtype=tf.float32
                    )
                },
                {
                    "P": tf.TensorSpec(
                        shape=(1,),
                        dtype=tf.int32
                    ),
                    "G": tf.TensorSpec(
                        shape=(G_reduced,),
                        dtype=tf.int32
                    ),
                    "N": tf.TensorSpec(
                        shape=(N,),
                        dtype=tf.int32
                    )
                }
            )
        )


# Let tf.data choose how many elements to prefetch.
na_dataset = NADataset().prefetch(buffer_size=tf.data.AUTOTUNE)
