
# %% Imports
import argparse
import pickle
import random
from random import shuffle
from typing import Dict

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from pavi.utils import get_stops

tfd = tfp.distributions
tfb = tfp.bijectors

# %% Argument parsing

# Every command-line option is an optional integer; the table below pairs
# each flag with its default value.
#   --G              number of groups in the full model
#   --G-reduced      number of groups in the reduced model
#   --na-val-idx     index of the validation sample served by NADataset
#   --random-seed    seed applied to the random number generators
#   --encoding-size  exposed below as `encoding_size`
parser = argparse.ArgumentParser()
for _flag, _default in (
    ("--G", 20),
    ("--G-reduced", 5),
    ("--na-val-idx", 4),
    ("--random-seed", 0),
    ("--encoding-size", 16),
):
    parser.add_argument(
        _flag,
        type=int,
        default=_default,
        required=False,
    )

# parse_known_args: flags this parser does not declare are returned
# separately (and discarded here) instead of aborting the script.
args, _ = parser.parse_known_args()

# %% Setting up random seed

# Seed every random number generator this script draws from.
# The stdlib generator must be seeded explicitly: `random.shuffle` decides the
# group ordering in NADataset and is not affected by the TF / NumPy seeds.
seed = args.random_seed
random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)

# %% Encoding size

# Plain aliases of the parsed command-line values.
encoding_size = args.encoding_size
na_val_idx = args.na_val_idx

# %% Generative Hierarchical Bayesian Model

# Model dimensions: D features, G groups, N observations per group.
D, G, N = 2, args.G, 10

# Prior location of the population mean and the standard deviation used at
# each level of the hierarchy (mu -> mu_g -> x).
loc_mu = tf.zeros((D,))
scale_mu, scale_mu_g, scale_x = 1.0, 0.5, 0.1

# Three-level Gaussian hierarchy. JointDistributionNamed wires the levels
# together through the lambdas' argument names, so `mu` and `mu_g` must keep
# exactly these names.
generative_hbm = tfd.JointDistributionNamed(
    model={
        # Population mean: one D-dimensional event.
        "mu": tfd.Independent(
            tfd.Normal(loc=loc_mu, scale=scale_mu),
            reinterpreted_batch_ndims=1,
            name="mu",
        ),
        # Group means: G draws centred on mu.
        "mu_g": lambda mu: tfd.Sample(
            tfd.Independent(
                tfd.Normal(loc=mu, scale=scale_mu_g),
                reinterpreted_batch_ndims=1,
            ),
            sample_shape=(G,),
            name="mu_g",
        ),
        # Observations: N draws centred on the group means. Sample puts the
        # N axis first; the Transpose bijector swaps the first two event
        # axes so that the group axis comes before the N axis.
        "x": lambda mu_g: tfd.TransformedDistribution(
            tfd.Sample(
                tfd.Independent(
                    tfd.Normal(loc=mu_g, scale=scale_x),
                    reinterpreted_batch_ndims=2,
                ),
                sample_shape=(N,),
            ),
            bijector=tfb.Transpose(perm=[1, 0, 2]),
            name="x",
        ),
    }
)

# Every random variable uses an identity link (one bijector instance each).
link_functions = {
    rv_name: tfb.Identity()
    for rv_name in ("x", "mu_g", "mu")
}

# Integer level assigned to each random variable (x lowest, mu highest).
hierarchies = dict(x=0, mu_g=1, mu=2)

# Keyword-argument bundles sharing the same generative model and links.
hbm_kwargs = {
    "generative_hbm": generative_hbm,
    "hierarchies": hierarchies,
    "link_functions": link_functions,
}

total_hbm_kwargs = {
    "generative_hbm": generative_hbm,
    "link_functions": link_functions,
    "observed_rv": "x",
}

faithful_hbm_kwargs = {
    "generative_hbm": generative_hbm,
    "observed_rvs": ["x"],
    "plate_cardinalities": {"G": G, "N": N},
    "link_functions": link_functions,
    "observed_rv_reshapers": {"x": tfb.Identity()},
    # Plates each variable belongs to; mu belongs to none.
    "plates_per_rv": {
        "mu": (),
        "mu_g": ("G",),
        "x": ("G", "N"),
    },
}

# %% Dataset generation

# Cache file for the full-size dataset, keyed by the model dimensions.
dataset_name = (
    f"../data/GRE_dataset_G{G}_N{N}_D{D}.p"
)
try:
    # NOTE: unpickling can execute arbitrary code; only load cache files
    # that were produced by this script.
    # The context manager closes the file handle even if loading fails.
    with open(dataset_name, "rb") as dataset_file:
        dataset = pickle.load(dataset_file)
    train_data = dataset["train"]
    val_data = dataset["val"]
except FileNotFoundError:
    # No cache yet: sample the splits with a fixed seed and store them.
    # NOTE: this rebinds the module-level `seed` and re-seeds TensorFlow,
    # overriding --random-seed for the remainder of this run.
    seed = 1234
    tf.random.set_seed(seed)
    train_size, val_size = 20_000, 2000
    train_data, val_data = (
        generative_hbm.sample(size)
        for size in [train_size, val_size]
    )
    # NOTE: in this branch train_data / val_data stay TensorFlow tensors,
    # whereas the cached branch above yields the NumPy arrays stored below.
    dataset = {
        data_key: {
            key: value.numpy()
            for key, value in data.items()
        }
        for data_key, data in [
            ("train", train_data),
            ("val", val_data)
        ]
    }
    with open(dataset_name, "wb") as dataset_file:
        pickle.dump(dataset, dataset_file)

# %% Ground graph, used by CF

# "Ground" version of the hierarchical model: instead of stacked tensors,
# every group mean (`mu_{g}`) and every observation (`mu_{g}_{n}`) is its own
# named random variable, for g in range(G) and n in range(N).
ground_hbm = tfd.JointDistributionNamed(
    model={
        # Population mean, same prior as in `generative_hbm`.
        "mu": tfd.Independent(
            tfd.Normal(
                loc=loc_mu,
                scale=scale_mu
            ),
            reinterpreted_batch_ndims=1
        ),
        # One group mean per g. The lambda does not reference `g`; its
        # argument name `mu` is what declares the dependency on "mu".
        **{
            f"mu_{g}": lambda mu: tfd.Independent(
                tfd.Normal(
                    loc=mu,
                    scale=scale_mu_g
                ),
                reinterpreted_batch_ndims=1
            )
            for g in range(G)
        },
        # One observation per (g, n). The parent is selected through the
        # lambda's argument *name* (`mu_{g}`), which differs per group, so the
        # lambda source is generated as a string and evaluated. The evaluated
        # text is built only from the loop integers and refers to the
        # module-level names `tfd` and `scale_x`.
        **{
            f"mu_{g}_{n}": eval(
                f"""lambda mu_{g}: tfd.Independent(
                    tfd.Normal(
                        loc=mu_{g},
                        scale=scale_x
                    ),
                    reinterpreted_batch_ndims=1
                )"""
            )
            for g in range(G)
            for n in range(N)
        }
    }
)

# Names of the ground-graph variables, in the order group-major then n.
_cf_group_rvs = [f"mu_{g}" for g in range(G)]
_cf_observed_rvs = [
    f"mu_{g}_{n}"
    for g in range(G)
    for n in range(N)
]

# Keyword arguments for the ground graph: every variable gets its own
# identity link, and every observed variable its own identity reshaper.
cf_hbm_kwargs = {
    "generative_hbm": ground_hbm,
    "observed_rvs": list(_cf_observed_rvs),
    "link_functions": {
        rv_name: tfb.Identity()
        for rv_name in ["mu", *_cf_group_rvs, *_cf_observed_rvs]
    },
    "observed_rv_reshapers": {
        rv_name: tfb.Identity()
        for rv_name in _cf_observed_rvs
    },
}

# %% Data reshaping


def stack_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Convert ground-graph data into the stacked layout.

    Args:
        data: mapping with keys "mu", "mu_{g}" for every g in range(G) and,
            optionally, "mu_{g}_{n}" for every g in range(G), n in range(N).

    Returns:
        A dict with "mu" (passed through), "mu_g" (the group entries stacked
        on a new second-to-last axis) and, when all observation keys are
        present, "x" (observations stacked over n on axis -2, then over g on
        axis -3).

    Raises:
        KeyError: if "mu" or any "mu_{g}" entry is missing.
    """
    stacked = {
        "mu": data["mu"],
        "mu_g": tf.stack(
            [data[f"mu_{g}"] for g in range(G)],
            axis=-2
        ),
    }

    # Observations are optional: a missing key simply leaves "x" out.
    try:
        per_group = []
        for g in range(G):
            members = [data[f"mu_{g}_{n}"] for n in range(N)]
            per_group.append(tf.stack(members, axis=-2))
        stacked["x"] = tf.stack(per_group, axis=-3)
    except KeyError:
        pass

    return stacked


def slice_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Convert stacked data into the ground-graph (per-variable) layout.

    Args:
        data: mapping with keys "mu", "mu_g" and, optionally, "x".

    Returns:
        A dict with "mu" (passed through), one "mu_{g}" entry per group
        (index g on the second-to-last axis of "mu_g") and, when "x" is
        present, one "mu_{g}_{n}" entry per observation (indices g, n on the
        third- and second-to-last axes of "x").

    Raises:
        KeyError: if "mu" or "mu_g" is missing.
    """
    sliced = {"mu": data["mu"]}
    sliced.update(
        (f"mu_{g}", data["mu_g"][..., g, :])
        for g in range(G)
    )

    # Observations are optional: without "x" only the latent entries remain.
    try:
        observations = data["x"]
        for g in range(G):
            for n in range(N):
                sliced[f"mu_{g}_{n}"] = observations[..., g, n, :]
    except KeyError:
        pass

    return sliced


# %% CF Data

# Per-variable (ground-graph) views of the training and validation splits.
cf_train_data, cf_val_data = (
    slice_data(split)
    for split in (train_data, val_data)
)

# %% Batch-Shape HBM, used for template training

def get_batched_hbm(
    G: int,
    N: int,
    only_event_shape: bool = False
) -> tfd.Distribution:
    """Build the hierarchical model with plates laid out as tensor axes.

    The prior location is wrapped in a length-1 leading axis; each child
    level stacks its parent `G` (resp. `N`) times on a new second-to-last
    axis instead of using `tfd.Sample`.

    Args:
        G: number of copies of `mu` stacked to form the location of `mu_g`.
        N: number of copies of `mu_g` stacked to form the location of `x`.
        only_event_shape: when True, `mu`, `mu_g` and `x` reinterpret 2, 3
            and 4 batch dimensions as event dimensions respectively; when
            False, each reinterprets only the last one.

    Returns:
        A `tfd.JointDistributionNamed` over "mu", "mu_g" and "x".
    """

    def _event_ndims(full_rank: int) -> int:
        # Number of batch dims each level turns into event dims.
        return full_rank if only_event_shape else 1

    def _tile(parent, count: int):
        # Repeat the parent `count` times on a new second-to-last axis.
        return tf.stack([parent] * count, axis=-2)

    # The lambdas' argument names (`mu`, `mu_g`) declare the dependencies
    # and therefore must not be renamed.
    model = {
        "mu": tfd.Independent(
            tfd.Normal(loc=[loc_mu], scale=scale_mu),
            reinterpreted_batch_ndims=_event_ndims(2),
            name="mu",
        ),
        "mu_g": lambda mu: tfd.Independent(
            tfd.Normal(loc=_tile(mu, G), scale=scale_mu_g),
            reinterpreted_batch_ndims=_event_ndims(3),
            name="mu_g",
        ),
        "x": lambda mu_g: tfd.Independent(
            tfd.Normal(loc=_tile(mu_g, N), scale=scale_x),
            reinterpreted_batch_ndims=_event_ndims(4),
            name="x",
        ),
    }
    return tfd.JointDistributionNamed(model=model)

# Reduced model: G_reduced groups, same number of observations per group as
# the full model.
G_reduced, N_reduced = args.G_reduced, N

reduced_hbm = get_batched_hbm(G_reduced, N_reduced)

# %% PAVI

# Keyword arguments pairing the full-size and reduced batched models, both
# built with every axis treated as part of the event shape.
pavi_kwargs = {
    "full_hbm": get_batched_hbm(G, N, only_event_shape=True),
    "reduced_hbm": get_batched_hbm(
        G_reduced,
        N_reduced,
        only_event_shape=True,
    ),
    # Ordered plate names per variable; 'P' is the leading plate shared by
    # all three variables.
    "plates_per_rv": {
        "mu": ["P"],
        "mu_g": ["P", "G"],
        "x": ["P", "G", "N"],
    },
    "link_functions": link_functions,
}

# Cache file for the reduced-size dataset.
# NOTE(review): the path pattern is the same as for the full dataset, so a
# cache written by a run with --G equal to this G_reduced would be loaded
# here even though it was sampled from `generative_hbm`, not `reduced_hbm`.
# Confirm whether the two caches should use distinct names.
dataset_name = (
    f"../data/GRE_dataset_G{G_reduced}_N{N_reduced}_D{D}.p"
)
try:
    # NOTE: unpickling can execute arbitrary code; only load cache files
    # that were produced by this script.
    # The context manager closes the file handle even if loading fails.
    with open(dataset_name, "rb") as dataset_file:
        dataset = pickle.load(dataset_file)
    reduced_train_data = dataset["train"]
    reduced_val_data = dataset["val"]
except FileNotFoundError:
    # No cache yet: sample the splits with a fixed seed and store them.
    # NOTE: this rebinds the module-level `seed` and re-seeds TensorFlow,
    # overriding --random-seed for the remainder of this run.
    seed = 1234
    tf.random.set_seed(seed)
    train_size, val_size = 20_000, 2000
    reduced_train_data, reduced_val_data = (
        reduced_hbm.sample(size)
        for size in [train_size, val_size]
    )
    # NOTE: in this branch the reduced_* variables stay TensorFlow tensors,
    # whereas the cached branch above yields the NumPy arrays stored below.
    dataset = {
        data_key: {
            key: value.numpy()
            for key, value in data.items()
        }
        for data_key, data in [
            ("train", reduced_train_data),
            ("val", reduced_val_data)
        ]
    }
    with open(dataset_name, "wb") as dataset_file:
        pickle.dump(dataset, dataset_file)


# %%

class NADataset(tf.data.Dataset):
    """Batches of groups taken from a single validation sample.

    Calling `NADataset()` does not create an instance of this class:
    `__new__` returns the dataset built by `tf.data.Dataset.from_generator`
    around `_generator`. Each element is a pair
    `({"x": observations}, {"P": ..., "G": ..., "N": ...})` holding the
    observations of `G_reduced` groups and the plate indices they belong to.
    """

    @staticmethod
    def _generator():
        """Yield the selected validation sample in chunks of G_reduced groups.

        The group order is shuffled with `random.shuffle` on every call.
        """
        g_list = list(range(G))
        shuffle(g_list)

        # val_data["x"] is a NumPy array when loaded from the cache but a
        # TensorFlow tensor when freshly sampled; convert so that the list
        # indexing below works in both cases.
        X = np.asarray(val_data["x"][na_val_idx:na_val_idx + 1])
        # Add a leading axis and reorder the group axis by the shuffle.
        X = X[None, :, g_list]

        # presumably get_stops yields the end index of each consecutive
        # batch of G_reduced groups -- confirm against pavi.utils.
        for g in get_stops(
            full_size=G,
            batch_size=G_reduced
        ):
            yield (
                {
                    "x": X[:, :, g - G_reduced:g]
                },
                {
                    # Explicit int32 to match the declared output signature
                    # (tf.zeros defaults to float32).
                    "P": tf.zeros((1,), dtype=tf.int32),
                    "G": g_list[g - G_reduced:g],
                    "N": tf.range(0, N)
                }
            )

    def __new__(cls):
        """Return the generator-backed dataset with its element signature."""
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=(
                {
                    "x": tf.TensorSpec(
                        shape=(1, 1, G_reduced, N, D),
                        dtype=tf.float32
                    )
                },
                {
                    "P": tf.TensorSpec(
                        shape=(1,),
                        dtype=tf.int32
                    ),
                    "G": tf.TensorSpec(
                        shape=(G_reduced,),
                        dtype=tf.int32
                    ),
                    "N": tf.TensorSpec(
                        shape=(N,),
                        dtype=tf.int32
                    )
                }
            )
        )


# Generator-backed dataset with prefetching; the buffer size is left to
# tf.data's autotuning.
na_dataset = NADataset().prefetch(buffer_size=tf.data.AUTOTUNE)
