# %% Imports
from typing import Dict
import pickle
import argparse

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

from pavi.utils import (
    repeat_to_shape,
    one_hot_straight_through
)

# Short aliases for the TFP submodules used throughout this file.
tfd = tfp.distributions
tfb = tfp.bijectors
Root = tfd.JointDistributionCoroutine.Root

# %% Sigmoid function

# Sigmoid bijector rescaled so that its forward map sends the real line onto
# the open interval (-1, 1).
sig = tfb.Sigmoid(
    low=tf.constant(-1.0, dtype=tf.float32),
    high=tf.constant(1.0, dtype=tf.float32),
)

# %% Command-line arguments

# parse_known_args (rather than parse_args) returns unrecognised arguments in
# the discarded second value instead of raising, so extra flags passed by a
# launcher do not abort the script.
parser = argparse.ArgumentParser()
parser.add_argument("--na-val-idx", required=False, default=4, type=int)
args, _ = parser.parse_known_args()

# Integer option (default 4); presumably an index into the validation data —
# its consumer is not visible in this file, confirm there.
na_val_idx = args.na_val_idx

# %% Generative Hierarchical Bayesian Model

# Synthetic-data plate sizes: full model vs. reduced model.
N_full, N_reduced = 12, 3
T_full, T_reduced = 6, 3
S_full, S_reduced = 5, 3
D = 4  # dimensionality of each mean / observation vector
L = 7  # number of label categories (length of the Dirichlet concentration)

# Synthetic prior bounds. For epsilon, sigma and kappa these bound the
# log-value: get_hbm passes the Uniform draw through an Exp bijector.
mu_g_low, mu_g_high = -5, 5
kappa_low, kappa_high = -2, -1
sigma_low, sigma_high = -1, 0
epsilon_low, epsilon_high = 0, 1

# Flat Dirichlet prior over the L categories (float32).
concentration = tf.ones((L,), dtype=tf.float32)
# Temperature of the RelaxedOneHotCategorical labels.
temperature = 1.0


def get_hbm(S: int, T: int, N: int) -> tfd.Distribution:
    """Build the plate-stacked hierarchical generative model.

    Args:
        S: size of the "S" plate (number of ``mu_s`` replicates).
        T: size of the "T" plate (``mu_s_t`` replicates per ``s``).
        N: size of the "N" plate (``probs``/``labels`` replicates per ``s``).

    Returns:
        A ``tfd.JointDistributionNamed`` over ``mu_g``, ``epsilon``, ``mu_s``,
        ``sigma``, ``mu_s_t``, ``kappa``, ``probs``, ``labels`` and
        ``X_s_t``. Parents are wired by the lambdas' parameter names. Event
        shapes noted below follow from the ``sample_shape`` and ``Transpose``
        arguments; shapes produced by ``repeat_to_shape`` (from
        ``pavi.utils``, not shown here) are assumptions to confirm.
    """
    return tfd.JointDistributionNamed(
        model=dict(
            # Global means: L vectors in R^D, each coordinate uniform in
            # [mu_g_low, mu_g_high]; event shape (1, L, D).
            mu_g=tfd.Sample(
                tfd.Independent(
                    tfd.Uniform(
                        low=mu_g_low * tf.ones((D,)),
                        high=mu_g_high * tf.ones((D,)),
                    ),
                    reinterpreted_batch_ndims=1
                ),
                sample_shape=(1, L,),
                name="mu_g"
            ),
            # Scale of mu_s around mu_g: exp of Uniform(epsilon_low,
            # epsilon_high) draws; event shape (1, L, D).
            epsilon=tfd.TransformedDistribution(
                tfd.Sample(
                    tfd.Uniform(
                        low=epsilon_low,
                        high=epsilon_high
                    ),
                    sample_shape=(1, L, D)
                ),
                bijector=tfb.Exp(),
                name="epsilon"
            ),
            # S draws of Normal(mu_g, epsilon), sampled as (S, 1, L, D) and
            # transposed to event shape (1, S, L, D).
            mu_s=lambda mu_g, epsilon: tfd.TransformedDistribution(
                tfd.Sample(
                    tfd.Independent(
                        tfd.Normal(
                            loc=mu_g,
                            scale=epsilon
                        ),
                        reinterpreted_batch_ndims=3
                    ),
                    sample_shape=(S,),
                ),
                bijector=tfb.Transpose([1, 0, 2, 3]),
                name="mu_s"
            ),
            # Scale of mu_s_t around mu_s: exp of Uniform(sigma_low,
            # sigma_high) draws; event shape (1, L, D).
            sigma=tfd.TransformedDistribution(
                tfd.Sample(
                    tfd.Uniform(
                        low=sigma_low,
                        high=sigma_high
                    ),
                    sample_shape=(1, L, D)
                ),
                bijector=tfb.Exp(),
                name="sigma"
            ),
            # T draws of Normal(mu_s, sigma). sigma is tiled along the S axis
            # by repeat_to_shape (assumed result: (1, S, L, D)). Sampled as
            # (T, 1, S, L, D), then transposed to event shape (1, S, T, L, D).
            mu_s_t=lambda mu_s, sigma: tfd.TransformedDistribution(
                tfd.Sample(
                    tfd.Independent(
                        tfd.Normal(
                            loc=mu_s,
                            scale=repeat_to_shape(
                                sigma,
                                target_shape=(S,),
                                axis=-3
                            ),
                        ),
                        reinterpreted_batch_ndims=4
                    ),
                    sample_shape=(T,)
                ),
                bijector=tfb.Transpose([1, 2, 0, 3, 4]),
                name="mu_s_t"
            ),
            # Observation scales: exp of Uniform(kappa_low, kappa_high)
            # draws; event shape (1, L, D).
            kappa=tfd.TransformedDistribution(
                tfd.Sample(
                    tfd.Uniform(
                        low=kappa_low,
                        high=kappa_high
                    ),
                    sample_shape=(1, L, D)
                ),
                bijector=tfb.Exp(),
                name="kappa"
            ),
            # Per-(s, n) label probabilities: Dirichlet(concentration) over
            # the L categories; event shape (1, S, N, L).
            probs=tfd.Sample(
                tfd.Dirichlet(
                    concentration=concentration,
                ),
                sample_shape=(1, S, N),
                name="probs"
            ),
            # Relaxed one-hot labels parameterised by probs; reinterpreting 3
            # batch dims gives event shape (1, S, N, L).
            labels=lambda probs: tfd.Independent(
                tfd.RelaxedOneHotCategorical(
                    temperature=temperature,
                    probs=probs
                ),
                reinterpreted_batch_ndims=3,
                name="labels"
            ),
            # Observations; 5 reinterpreted dims, event shape presumably
            # (1, S, T, N, D).
            # loc: mu_s_t tiled over N, times one_hot_straight_through(labels)
            # tiled over T and expanded on the last axis, summed over the L
            # axis (-2).
            # scale: kappa tiled to (S, T, N), weighted by the relaxed labels
            # (tiled over T), summed over the L axis.
            # NOTE(review): loc uses one_hot_straight_through(labels) while
            # scale uses the raw relaxed labels — confirm the asymmetry is
            # intended.
            X_s_t=lambda mu_s_t, kappa, labels: tfd.Independent(
                tfd.Normal(
                    loc=tf.reduce_sum(
                        repeat_to_shape(
                            mu_s_t,
                            target_shape=(N,),
                            axis=-3
                        )
                        *
                        tf.expand_dims(
                            repeat_to_shape(
                                one_hot_straight_through(labels),
                                target_shape=(T,),
                                axis=-3
                            ),
                            axis=-1
                        ),
                        axis=-2
                    ),
                    scale=tf.reduce_sum(
                        repeat_to_shape(
                            kappa,
                            target_shape=(S, T, N,),
                            axis=-3
                        )
                        *
                        tf.expand_dims(
                            repeat_to_shape(
                                labels,
                                target_shape=(T,),
                                axis=-3
                            ),
                            axis=-1
                        ),
                        axis=-2
                    )
                ),
                reinterpreted_batch_ndims=5,
                name="X_s_t"
            )
        )
    )


# Generative model at full plate sizes (also used below to simulate the
# dataset) and at reduced plate sizes (both are handed to PAVI).
full_hbm = get_hbm(S=S_full, T=T_full, N=N_full)
reduced_hbm = get_hbm(S=S_reduced, T=T_reduced, N=N_reduced)

# %% Plates, link functions and PAVI arguments

# Plate membership of every RV of get_hbm. "S", "T" and "N" match the plate
# sizes passed to get_hbm; "P" is shared by every RV (presumably the
# top-level plate — confirm against pavi).
plates_per_rv = dict(
    mu_g=["P"],
    epsilon=["P"],
    mu_s=["P", "S"],
    sigma=["P"],
    mu_s_t=["P", "S", "T"],
    probs=["P", "S", "N"],
    labels=["P", "S", "N"],
    kappa=["P"],
    X_s_t=["P", "S", "T", "N"],
)

def _exp_of_softclip(low, high):
    """Return a bijector whose forward map is ``exp(SoftClip(low, high)(x))``.

    ``tfb.Chain`` applies its list right to left, so the real line is first
    softly squeezed into (low, high) and then exponentiated, landing in
    (exp(low), exp(high)).
    """
    return tfb.Chain([tfb.Exp(), tfb.SoftClip(low=low, high=high)])


# Bijectors mapping unconstrained reals onto each RV's support in get_hbm.
link_functions = {
    "mu_g": tfb.SoftClip(low=mu_g_low, high=mu_g_high),
    "epsilon": _exp_of_softclip(epsilon_low, epsilon_high),
    "mu_s": tfb.Identity(),
    "sigma": _exp_of_softclip(sigma_low, sigma_high),
    "mu_s_t": tfb.Identity(),
    "kappa": _exp_of_softclip(kappa_low, kappa_high),
    # Simplex-valued RVs (L categories).
    "probs": tfb.SoftmaxCentered(),
    "labels": tfb.SoftmaxCentered(),
    "X_s_t": tfb.Identity(),
}

# Keyword arguments bundling both model sizes with their plate and
# link-function metadata (presumably unpacked into a PAVI constructor
# elsewhere — not visible in this file).
pavi_kwargs = {
    "full_hbm": full_hbm,
    "reduced_hbm": reduced_hbm,
    "plates_per_rv": plates_per_rv,
    "link_functions": link_functions,
}


# %% Dataset generation

_dataset_path = "../data/HCPL_dataset.p"

try:
    # NOTE: pickle can execute arbitrary code on load; only load a cache
    # file that this script produced itself.
    with open(_dataset_path, "rb") as dataset_file:
        dataset = pickle.load(dataset_file)
except FileNotFoundError:
    # No cache yet: simulate train/val splits from the full generative model
    # under a fixed seed, convert every tensor to NumPy, and cache the result.
    seed = 1234
    tf.random.set_seed(seed)
    train_size, val_size = 20_000, 2000
    # Draw order (train, then val) is kept fixed for reproducibility.
    train_samples, val_samples = (
        full_hbm.sample(size)
        for size in [train_size, val_size]
    )
    dataset = {
        data_key: {
            key: value.numpy()
            for key, value in samples.items()
        }
        for data_key, samples in [
            ("train", train_samples),
            ("val", val_samples)
        ]
    }
    with open(_dataset_path, "wb") as dataset_file:
        pickle.dump(dataset, dataset_file)

# Read the splits from `dataset` on both paths so downstream code always
# receives dicts of NumPy arrays, whether the cache was just written or loaded.
train_data = dataset["train"]
val_data = dataset["val"]


# %% Ground HBM, used by CF

# Per-index ("ground") version of the model, built at the full plate sizes:
# every s, (s, t), (s, n) and (s, t, n) replicate is its own named RV
# (mu_{s}, mu_{s}_{t}, probs_{s}_{n}, labels_{s}_{n}, X_{s}_{t}_{n}) instead
# of one plate-stacked tensor as in get_hbm.
# NOTE(review): epsilon and sigma use sample_shape (1, L) and kappa uses (1,)
# here, whereas get_hbm uses (1, L, D) for all three. slice_data passes those
# three RVs through unchanged from data simulated with full_hbm, so their
# shapes there differ from this model's — confirm the CF code only relies on
# the observed X_{s}_{t}_{n} values.
ground_hbm = tfd.JointDistributionNamed(
    model=dict(
        # Global means: L vectors in R^D, each coordinate uniform in
        # [mu_g_low, mu_g_high]; event shape (1, L, D).
        mu_g=tfd.Sample(
            tfd.Independent(
                tfd.Uniform(
                    low=mu_g_low * tf.ones((D,)),
                    high=mu_g_high * tf.ones((D,)),
                ),
                reinterpreted_batch_ndims=1
            ),
            sample_shape=(1, L,),
            name="mu_g"
        ),
        # exp of Uniform(epsilon_low, epsilon_high) draws; event shape (1, L).
        epsilon=tfd.TransformedDistribution(
            tfd.Sample(
                tfd.Uniform(
                    low=epsilon_low,
                    high=epsilon_high
                ),
                sample_shape=(1, L,)
            ),
            bijector=tfb.Exp(),
            name="epsilon"
        ),
        # One RV per s: Normal(mu_g, epsilon), with epsilon expanded on the
        # last axis so one scale is shared across the D coordinates. The
        # lambda does not reference `s`, so creating it in a comprehension is
        # safe (no late-binding issue).
        **{
            f"mu_{s}": lambda mu_g, epsilon: tfd.Independent(
                tfd.Normal(
                    loc=mu_g,
                    scale=tf.expand_dims(epsilon, axis=-1)
                ),
                reinterpreted_batch_ndims=3
            )
            for s in range(S_full)
        },
        # exp of Uniform(sigma_low, sigma_high) draws; event shape (1, L).
        sigma=tfd.TransformedDistribution(
            tfd.Sample(
                tfd.Uniform(
                    low=sigma_low,
                    high=sigma_high
                ),
                sample_shape=(1, L,)
            ),
            bijector=tfb.Exp(),
            name="sigma"
        ),
        # One RV per (s, t): Normal(mu_{s}, sigma). eval is used because
        # JointDistributionNamed picks each RV's parents from the lambda's
        # parameter names, and the parent name mu_{s} varies with s; the
        # evaluated source is built only from loop integers.
        **{
            f"mu_{s}_{t}": eval(
                f"""lambda mu_{s}, sigma: tfd.Independent(
                    tfd.Normal(
                        loc=mu_{s},
                        scale=tf.expand_dims(sigma, axis=-1)
                    ),
                    reinterpreted_batch_ndims=3
                )"""
            )
            for s in range(S_full)
            for t in range(T_full)
        },
        # exp of Uniform(kappa_low, kappa_high) draws; event shape (1,).
        kappa=tfd.TransformedDistribution(
            tfd.Sample(
                tfd.Uniform(
                    low=kappa_low,
                    high=kappa_high
                ),
                sample_shape=(1,)
            ),
            bijector=tfb.Exp(),
            name="kappa"
        ),
        # One label-probability RV per (s, n): Dirichlet over the L
        # categories with a leading singleton axis; event shape (1, L).
        **{
            f"probs_{s}_{n}": tfd.Independent(
                tfd.Dirichlet(
                    concentration=tf.expand_dims(
                        concentration,
                        axis=0
                    )
                ),
                reinterpreted_batch_ndims=1
            )
            for s in range(S_full)
            for n in range(N_full)
        },
        # One relaxed one-hot label per (s, n), parented on probs_{s}_{n}
        # (eval for the same parameter-name reason as mu_{s}_{t}).
        **{
            f"labels_{s}_{n}": eval(
                f"""lambda probs_{s}_{n}: tfd.Independent(
                    tfd.RelaxedOneHotCategorical(
                        probs=probs_{s}_{n},
                        temperature={temperature}
                    ),
                    reinterpreted_batch_ndims=1
                )"""
            )
            for s in range(S_full)
            for n in range(N_full)
        },
        # One observation per (s, t, n). loc: sum over the L axis of
        # mu_{s}_{t} weighted by one_hot_straight_through(labels_{s}_{n})
        # (presumably a hard one-hot selecting one of the L rows — confirm in
        # pavi.utils). scale: kappa expanded so it broadcasts over D.
        # Two reinterpreted dims (presumably event shape (1, D)).
        **{
            f"X_{s}_{t}_{n}": eval(
                f"""lambda mu_{s}_{t}, kappa, labels_{s}_{n}: tfd.Independent(
                    tfd.Normal(
                        loc=tf.reduce_sum(
                            tf.expand_dims(
                                one_hot_straight_through(labels_{s}_{n}),
                                axis=-1
                            )
                            * mu_{s}_{t},
                            axis=-2
                        ),
                        scale=tf.expand_dims(kappa, axis=-1)
                    ),
                    reinterpreted_batch_ndims=2
                )"""
            )
            for s in range(S_full)
            for t in range(T_full)
            for n in range(N_full)
        },
    )
)

def _build_cf_hbm_kwargs():
    """Assemble the keyword arguments describing ground_hbm for CF.

    Every ``X_{s}_{t}_{n}`` leaf of ground_hbm is observed. Each RV gets a
    link function mapping unconstrained reals onto its support, and every
    observed leaf gets a reshaper from event shape (D,) to (1, D).
    """
    s_range, t_range, n_range = range(S_full), range(T_full), range(N_full)
    st_pairs = [(s, t) for s in s_range for t in t_range]
    sn_pairs = [(s, n) for s in s_range for n in n_range]
    observed = [f"X_{s}_{t}_{n}" for s, t in st_pairs for n in n_range]

    def exp_softclip(low, high):
        # tfb.Chain applies right to left: SoftClip into (low, high), then Exp.
        return tfb.Chain([tfb.Exp(), tfb.SoftClip(low=low, high=high)])

    # Insertion order mirrors ground_hbm's RV order.
    links = {"mu_g": tfb.SoftClip(low=mu_g_low, high=mu_g_high)}
    links["epsilon"] = exp_softclip(epsilon_low, epsilon_high)
    links.update({f"mu_{s}": tfb.Identity() for s in s_range})
    links["sigma"] = exp_softclip(sigma_low, sigma_high)
    links.update({f"mu_{s}_{t}": tfb.Identity() for s, t in st_pairs})
    links["kappa"] = exp_softclip(kappa_low, kappa_high)
    links.update(
        {f"probs_{s}_{n}": tfb.SoftmaxCentered() for s, n in sn_pairs}
    )
    links.update(
        {f"labels_{s}_{n}": tfb.SoftmaxCentered() for s, n in sn_pairs}
    )
    links.update({name: tfb.Identity() for name in observed})

    # One Reshape instance per observed leaf.
    reshapers = {
        name: tfb.Reshape(event_shape_in=(D,), event_shape_out=(1, D))
        for name in observed
    }

    return dict(
        generative_hbm=ground_hbm,
        observed_rvs=observed,
        link_functions=links,
        observed_rv_reshapers=reshapers,
    )


cf_hbm_kwargs = _build_cf_hbm_kwargs()


# %% Data reshaping


def stack_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Merge per-index RVs (ground_hbm layout) into plate-stacked tensors.

    ``mu_{s}`` values are stacked on axis -3 into ``mu_s``. ``mu_{s}_{t}``
    values are stacked on axis -3 (over t), then -4 (over s), into
    ``mu_s_t``. ``probs_{s}_{n}`` and ``labels_{s}_{n}`` are stacked on -2
    (n), then -3 (s). ``X_{s}_{t}_{n}`` values are stacked on -2 (n), -3 (t),
    then -4 (s), into ``X_s_t``. ``mu_g``, ``epsilon``, ``sigma`` and
    ``kappa`` are copied through unchanged.

    Args:
        data: mapping from per-index RV names to tensors.

    Returns:
        A new dict keyed by get_hbm's RV names. ``X_s_t`` is present only if
        every ``X_{s}_{t}_{n}`` key was found; otherwise it is omitted.

    Raises:
        KeyError: if any key other than the ``X_{s}_{t}_{n}`` ones is missing.
    """
    s_indices, t_indices, n_indices = range(S_full), range(T_full), range(N_full)

    output_data = {
        name: data[name] for name in ("mu_g", "epsilon", "sigma", "kappa")
    }

    output_data["mu_s"] = tf.stack(
        [data[f"mu_{s}"] for s in s_indices], axis=-3
    )

    per_s_means = []
    for s in s_indices:
        per_s_means.append(
            tf.stack([data[f"mu_{s}_{t}"] for t in t_indices], axis=-3)
        )
    output_data["mu_s_t"] = tf.stack(per_s_means, axis=-4)

    def stack_over_s_and_n(key_of):
        # Stack the (s, n) grid: n on axis -2, then s on axis -3.
        return tf.stack(
            [
                tf.stack([data[key_of(s, n)] for n in n_indices], axis=-2)
                for s in s_indices
            ],
            axis=-3,
        )

    output_data["probs"] = stack_over_s_and_n(lambda s, n: f"probs_{s}_{n}")
    output_data["labels"] = stack_over_s_and_n(lambda s, n: f"labels_{s}_{n}")

    try:
        per_s_observations = [
            tf.stack(
                [
                    tf.stack(
                        [data[f"X_{s}_{t}_{n}"] for n in n_indices], axis=-2
                    )
                    for t in t_indices
                ],
                axis=-3,
            )
            for s in s_indices
        ]
    except KeyError:
        # Observed leaves are optional: without them, X_s_t is left out.
        pass
    else:
        output_data["X_s_t"] = tf.stack(per_s_observations, axis=-4)

    return output_data


def slice_data(
    data: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
    """Split plate-stacked RVs (get_hbm layout) into per-index RVs.

    Counterpart of ``stack_data``: ``data["mu_s"][..., s, :, :]`` becomes
    ``mu_{s}``, ``data["mu_s_t"][..., s, t, :, :]`` becomes ``mu_{s}_{t}``,
    ``data["X_s_t"][..., s, t, n, :]`` becomes ``X_{s}_{t}_{n}``, and
    ``data["probs"]`` / ``data["labels"]`` ``[..., s, n, :]`` become
    ``probs_{s}_{n}`` / ``labels_{s}_{n}``. ``mu_g``, ``epsilon``, ``sigma``
    and ``kappa`` are copied through unchanged.

    Args:
        data: mapping with keys ``mu_g``, ``epsilon``, ``sigma``, ``kappa``,
            ``mu_s``, ``mu_s_t``, ``probs``, ``labels`` and, optionally,
            ``X_s_t``. Only indexing is applied, so tensors or NumPy arrays
            both work.

    Returns:
        A new dict keyed by the per-index RV names. ``X_{s}_{t}_{n}`` entries
        are produced only when ``data`` contains ``X_s_t``.

    Raises:
        KeyError: if any of the non-optional keys is missing.
    """
    output_data = {
        rv: data[rv] for rv in ["mu_g", "epsilon", "sigma", "kappa"]
    }

    # Look the stacked values up once instead of inside the nested loops.
    mu_s = data["mu_s"]
    mu_s_t = data["mu_s_t"]
    probs = data["probs"]
    labels = data["labels"]
    # Observed values are optional, mirroring stack_data, which omits
    # "X_s_t" when the X_{s}_{t}_{n} leaves are absent.
    x_s_t = data["X_s_t"] if "X_s_t" in data else None

    for s in range(S_full):
        output_data[f"mu_{s}"] = mu_s[..., s, :, :]
        for t in range(T_full):
            output_data[f"mu_{s}_{t}"] = mu_s_t[..., s, t, :, :]
            if x_s_t is not None:
                for n in range(N_full):
                    output_data[f"X_{s}_{t}_{n}"] = x_s_t[..., s, t, n, :]
        for n in range(N_full):
            output_data[f"probs_{s}_{n}"] = probs[..., s, n, :]
            output_data[f"labels_{s}_{n}"] = labels[..., s, n, :]

    return output_data


# %% CF Data

# Per-index (ground_hbm layout) views of the train and validation splits.
cf_train_data, cf_val_data = (
    slice_data(split) for split in (train_data, val_data)
)
