

# %% Imports
import argparse
import pickle
import random
from random import shuffle
from typing import Dict

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from pavi.utils import get_stops

tfd = tfp.distributions
tfb = tfp.bijectors

# %% Argument parsing

# Two optional integer options. `parse_known_args` is used so that
# command-line arguments this script does not declare are ignored rather
# than rejected.
parser = argparse.ArgumentParser()
# Index of the validation example selected from `val_data`.
parser.add_argument("--na-val-idx", type=int, default=6, required=False)
# Seed applied to the random number generators below.
parser.add_argument("--random-seed", type=int, default=1234, required=False)
args, _ = parser.parse_known_args()

# %% Random seed

seed = args.random_seed
# Seed every generator this script draws from. The standard-library
# generator must be seeded as well: `shuffle` (imported from `random`) is
# what permutes the plate indices further down, and it is not affected by
# the TensorFlow or NumPy seeds.
random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)

# %% Validation index

# Index of the single validation example used by the script.
na_val_idx = args.na_val_idx

# %% Generative Hierarchical Bayesian Model

# Size of the last axis of every random variable (length of `loc_a`).
D = 2

# Plate cardinalities: the "full" values size the full model, the
# "reduced" values size the reduced model and the mini-batches.
N1_full = 15
N1_reduced = 3
N2_full = 15
N2_reduced = 3

# Parameters of the Normal distributions that are pushed through `Exp`.
loc_a = tf.zeros(shape=(D,))
scale_a = 1.0

loc_b = 0.0

loc_c = 0.0


def get_hbm(N1, N2) -> tfd.Distribution:
    """Build the three-level hierarchical model over "a", "b" and "c".

    Every level is the exponential of a Normal: "a" uses
    ``loc_a``/``scale_a``, "b" uses ``loc_b`` with scale "a", and "c" uses
    ``loc_c`` with scale "b".

    Args:
        N1: number of i.i.d. draws taken for "b" (``sample_shape=(N1,)``).
        N2: number of i.i.d. draws taken for "c" (``sample_shape=(N2,)``).

    Returns:
        A ``tfd.JointDistributionNamed`` with keys "a", "b" and "c".
    """

    def exp_of_normal(loc, scale, event_ndims):
        # Exp-transformed Normal whose last `event_ndims` batch axes are
        # reinterpreted as event axes.
        return tfd.TransformedDistribution(
            tfd.Independent(
                tfd.Normal(loc=loc, scale=scale),
                reinterpreted_batch_ndims=event_ndims
            ),
            bijector=tfb.Exp()
        )

    # NOTE: the parameter names `a` and `b` below are significant; they are
    # how the joint distribution wires each level to its parent.
    def conditional_b(a):
        draws = tfd.Sample(
            exp_of_normal(loc_b, a, 2),
            sample_shape=(N1,)
        )
        # The sample axis is created in front; swap it behind the first
        # axis (presumably giving a (1, N1, D) event -- confirm).
        return tfd.TransformedDistribution(
            draws,
            bijector=tfb.Transpose([1, 0, 2])
        )

    def conditional_c(b):
        draws = tfd.Sample(
            exp_of_normal(loc_c, b, 3),
            sample_shape=(N2,)
        )
        # Move the leading sample axis to the third position (presumably
        # giving a (1, N1, N2, D) event -- confirm).
        return tfd.TransformedDistribution(
            draws,
            bijector=tfb.Transpose([1, 2, 0, 3])
        )

    root_a = tfd.Sample(
        exp_of_normal(loc_a, scale_a, 1),
        sample_shape=(1,),
        name="a"
    )

    return tfd.JointDistributionNamed(
        model=dict(
            a=root_a,
            b=conditional_b,
            c=conditional_c,
        )
    )


# The same model instantiated at full and at reduced plate cardinalities.
full_hbm = get_hbm(N1=N1_full, N2=N2_full)
reduced_hbm = get_hbm(N1=N1_reduced, N2=N2_reduced)

# %%

# Plates each random variable belongs to, listed outermost first.
plates_per_rv = dict(
    a=['N0'],
    b=['N0', 'N1'],
    c=['N0', 'N1', 'N2'],
)

# One separate Exp bijector per random variable.
link_functions = {rv_name: tfb.Exp() for rv_name in ("a", "b", "c")}

# Keyword arguments bundled for later use.
pavi_kwargs = {
    "full_hbm": full_hbm,
    "reduced_hbm": reduced_hbm,
    "plates_per_rv": plates_per_rv,
    "link_functions": link_functions,
}


# %% Dataset generation

# Path of the cached dataset: a pickled dict with "train" and "val" splits,
# each mapping a random-variable name to a NumPy array of samples.
dataset_name = "../data/HV_dataset.p"

try:
    # NOTE(security): unpickling executes arbitrary code. Only point
    # `dataset_name` at a file written by this script.
    with open(dataset_name, "rb") as dataset_file:
        dataset = pickle.load(dataset_file)
except FileNotFoundError:
    # No cache yet: sample the dataset under a fixed seed so that it does
    # not depend on --random-seed. A dedicated name is used so the
    # run-level `seed` is not overwritten.
    dataset_seed = 1234
    tf.random.set_seed(dataset_seed)
    train_size, val_size = 2_000, 200
    # The train split is sampled first, then the validation split.
    dataset = {
        split_name: {
            rv_name: samples.numpy()
            for rv_name, samples in full_hbm.sample(split_size).items()
        }
        for split_name, split_size in [
            ("train", train_size),
            ("val", val_size)
        ]
    }
    with open(dataset_name, "wb") as dataset_file:
        pickle.dump(dataset, dataset_file)
    # Re-apply the run-level seed so that everything after this point is
    # governed by --random-seed, exactly as when the cache already exists.
    tf.random.set_seed(seed)

# Both branches expose the splits the same way (dicts of NumPy arrays).
train_data = dataset["train"]
val_data = dataset["val"]


# %%

class NADataset(tf.data.Dataset):
    """Mini-batches of one validation example of "c", cut along N1 and N2.

    Instantiating the class does not produce an ``NADataset`` object:
    ``__new__`` returns the dataset built by
    ``tf.data.Dataset.from_generator`` around ``_generator``.

    Each element is a pair ``(data, plate_indices)``:

    * ``data["c"]``: float32 block of shape
      ``(1, 1, N1_reduced, N2_reduced, D)``;
    * ``plate_indices``: int32 vectors "N0" (always ``[0]``), "N1" and "N2"
      holding the original positions of the selected plate entries.
    """

    @staticmethod
    def _generator():
        """Yield every (N1 batch, N2 batch) block of the chosen example.

        The plate orders are reshuffled on each call, so each pass over the
        dataset groups the plate entries differently.
        """
        n1_list = list(range(N1_full))
        n2_list = list(range(N2_full))
        shuffle(n1_list)
        shuffle(n2_list)

        # Keep a leading axis of size one for the selected example.
        # `np.asarray` makes the list-based indexing below valid even when
        # `val_data` holds tf.Tensors, which do not accept list indices.
        c = np.asarray(val_data["c"])[na_val_idx:na_val_idx + 1]
        # Apply the shuffled order along the N1 axis, then the N2 axis.
        c = c[:, :, n1_list]
        c = c[:, :, :, n2_list]

        # `get_stops` is assumed to yield the exclusive end index of each
        # batch of size `batch_size` -- confirm against pavi.utils.
        for n1 in get_stops(
            full_size=N1_full,
            batch_size=N1_reduced
        ):
            c_n1 = c[:, :, n1 - N1_reduced:n1]
            for n2 in get_stops(
                full_size=N2_full,
                batch_size=N2_reduced
            ):
                c_n2 = c_n1[:, :, :, n2 - N2_reduced:n2]
                yield (
                    {
                        "c": c_n2
                    },
                    {
                        # Must be int32 to match `output_signature`;
                        # `tf.zeros` defaults to float32.
                        "N0": tf.zeros((1,), dtype=tf.int32),
                        "N1": n1_list[n1 - N1_reduced:n1],
                        "N2": n2_list[n2 - N2_reduced:n2]
                    }
                )

    def __new__(cls):
        """Return the generator-backed dataset with its element signature."""
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=(
                {
                    "c": tf.TensorSpec(
                        shape=(1, 1, N1_reduced, N2_reduced, D),
                        dtype=tf.float32
                    )
                },
                {
                    "N0": tf.TensorSpec(
                        shape=(1,),
                        dtype=tf.int32
                    ),
                    "N1": tf.TensorSpec(
                        shape=(N1_reduced,),
                        dtype=tf.int32
                    ),
                    "N2": tf.TensorSpec(
                        shape=(N2_reduced,),
                        dtype=tf.int32
                    )
                }
            )
        )


# Generator-backed dataset with prefetching; the buffer size is left to
# tf.data's autotuning.
na_dataset = NADataset().prefetch(buffer_size=tf.data.AUTOTUNE)

# %%

# Flattened version of the hierarchical model: one named random variable
# per plate entry ("a", "b_<n1>", "c_<n1>_<n2>") instead of stacked tensors.
ground_hbm = tfd.JointDistributionNamed(
    model=dict(
        # Root variable: Exp of a Normal(loc_a, scale_a).
        a=tfd.TransformedDistribution(
            tfd.Independent(
                tfd.Normal(
                    loc=loc_a,
                    scale=scale_a
                ),
                reinterpreted_batch_ndims=1
            ),
            bijector=tfb.Exp(),
            name="a"
        ),
        # One "b_<n1>" per N1 entry; each depends on "a" through the lambda's
        # parameter name. The lambda body does not reference `n1`.
        **{
            f"b_{n1}": lambda a: tfd.TransformedDistribution(
                tfd.Independent(
                    tfd.Normal(
                        loc=loc_b,
                        scale=a
                    ),
                    reinterpreted_batch_ndims=1
                ),
                bijector=tfb.Exp()
            )
            for n1 in range(N1_full)
        },
        # One "c_<n1>_<n2>" per (N1, N2) entry. The lambda is built with
        # `eval` so that its parameter is literally named "b_<n1>", which is
        # how the joint distribution identifies the parent variable.
        # The evaluated string is assembled from loop integers only, never
        # from external input.
        **{
            f"c_{n1}_{n2}": eval(
                f"""lambda b_{n1}: tfd.TransformedDistribution(
                    tfd.Independent(
                        tfd.Normal(
                            loc=loc_c,
                            scale=b_{n1}
                        ),
                        reinterpreted_batch_ndims=1
                    ),
                    bijector=tfb.Exp()
                )"""
            )
            for n1 in range(N1_full)
            for n2 in range(N2_full)
        }
    )
)

# Keyword arguments describing the flattened model: the model itself, the
# names of its observed variables (all "c_<n1>_<n2>"), an Exp link function
# for every variable and an Identity reshaper for every observed variable.
cf_hbm_kwargs = {
    "generative_hbm": ground_hbm,
    "observed_rvs": [
        f"c_{i}_{j}"
        for i in range(N1_full)
        for j in range(N2_full)
    ],
    # Entry order: "a", then every "b_<n1>", then every "c_<n1>_<n2>".
    "link_functions": dict(
        [("a", tfb.Exp())]
        + [
            (f"b_{i}", tfb.Exp())
            for i in range(N1_full)
        ]
        + [
            (f"c_{i}_{j}", tfb.Exp())
            for i in range(N1_full)
            for j in range(N2_full)
        ]
    ),
    "observed_rv_reshapers": dict(
        (f"c_{i}_{j}", tfb.Identity())
        for i in range(N1_full)
        for j in range(N2_full)
    ),
}

# %%


def slice_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Split stacked "a"/"b"/"c" arrays into one entry per plate element.

    Args:
        data: mapping with keys "a", "b" and "c" in the stacked layout.

    Returns:
        A dict with "a", every "b_<n1>" and every "c_<n1>_<n2>", each taken
        at index 0 of the axis that precedes the plate axes.
    """
    sliced = {"a": data["a"][..., 0, :]}

    for i in range(N1_full):
        sliced[f"b_{i}"] = data["b"][..., 0, i, :]
        # All N2 entries that belong to this N1 entry.
        sliced.update(
            (f"c_{i}_{j}", data["c"][..., 0, i, j, :])
            for j in range(N2_full)
        )

    return sliced


def stack_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Reassemble per-element entries into stacked "a"/"b"/"c" tensors.

    Args:
        data: mapping with "a" and every "b_<n1>"; the "c_<n1>_<n2>"
            entries are optional.

    Returns:
        A dict with "a" and "b", plus "c" when every "c_<n1>_<n2>" entry
        is present in ``data``.
    """
    stacked = {}

    # Re-insert the axis that precedes the plate axes.
    stacked["a"] = data["a"][..., None, :]

    b_entries = [
        data[f"b_{i}"][..., None, :]
        for i in range(N1_full)
    ]
    stacked["b"] = tf.stack(b_entries, axis=-2)

    try:
        c_rows = []
        for i in range(N1_full):
            row_entries = [
                data[f"c_{i}_{j}"][..., None, :]
                for j in range(N2_full)
            ]
            # Stack along N2 first ...
            c_rows.append(tf.stack(row_entries, axis=-2))
        # ... then along N1.
        stacked["c"] = tf.stack(c_rows, axis=-3)
    except KeyError:
        # A missing "c_<n1>_<n2>" entry means "c" is left out entirely.
        pass

    return stacked


# Per-element (flattened) views of the train and validation splits.
cf_train_data, cf_val_data = map(slice_data, (train_data, val_data))
