import bayesflow as bf
import tensorflow as tf
import pickle
from pathlib import Path
from .generative_model import get_generative_model
from bayesflow.amortizers import AmortizedPosterior
from .amortized_posterior_sc import AmortizedPosteriorSC


def get_real_data():
    """Load the cached "real" dataset, simulating and caching it on first use.

    If ``data/real_data.pkl`` (next to this module) does not exist, 1024
    parameter vectors of dimension 7 are drawn from a normal distribution
    with mean 0 and standard deviation 2. They are passed through the
    generative model's simulator, and the resulting ``"sim_data"`` entry is
    pickled to that path. Every call then loads and returns the pickled
    object.

    Returns:
        The unpickled ``"sim_data"`` output of the simulator.
    """
    file_path = Path(__file__).parents[0] / "data" / "real_data.pkl"

    if not file_path.exists():
        model = get_generative_model()
        prior = tf.random.normal((1024, 7), 0, 2)
        # NOTE: an earlier variant added uniform noise in [-2, 2] to the
        # simulated data here (to emulate misspecified "real" data).
        real_data = model.simulator(prior)["sim_data"]

        # The data directory may be missing in a fresh checkout; open(..., "wb")
        # does not create parent directories.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file and move it into place, so an interrupted
        # dump cannot leave a truncated cache that breaks every later call.
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        with open(tmp_path, "wb") as file:
            pickle.dump(real_data, file)
        tmp_path.replace(file_path)

    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # written by this module.
    with open(file_path, "rb") as file:
        real_data = pickle.load(file)

    return real_data


def get_real_data_subset(n=32):
    """Return ``n`` randomly chosen rows of the real dataset.

    Rows are selected along the first axis without replacement by shuffling
    the row indices and keeping the first ``n``. If ``n`` exceeds the number
    of rows, all rows are returned in shuffled order.

    NOTE(review): this calls ``get_real_data()`` on every invocation, which
    reloads the pickle from disk each time.
    """
    real_data = get_real_data()
    num_rows = tf.shape(real_data)[0]
    shuffled_rows = tf.random.shuffle(tf.range(num_rows))
    return tf.gather(real_data, shuffled_rows[:n])


def get_training_data():
    """Load the cached offline training set, simulating and caching it on first use.

    If ``data/training_data.pkl`` (next to this module) does not exist, the
    generative model is called with ``2**15`` (presumably the number of
    simulations in the batch). Its full output dict is pickled to that path.
    Every call then loads and returns the pickled dict.

    Returns:
        The unpickled output of the generative model call.
    """
    file_path = Path(__file__).parents[0] / "data" / "training_data.pkl"

    if not file_path.exists():
        model = get_generative_model()
        forward_dict = model(2**15)

        # The data directory may be missing in a fresh checkout; open(..., "wb")
        # does not create parent directories.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file and move it into place, so an interrupted
        # dump cannot leave a truncated cache that breaks every later call.
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        with open(tmp_path, "wb") as file:
            pickle.dump(forward_dict, file)
        tmp_path.replace(file_path)

    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # written by this module.
    with open(file_path, "rb") as file:
        forward_dict = pickle.load(file)

    return forward_dict


def get_summary_network():
    """Build the LSTM-based summary network.

    The leading Lambda layer appends a trailing feature axis. The LSTM needs
    rank-3 input, so the raw input is presumably ``(batch, time)`` and becomes
    ``(batch, time, 1)``. The LSTM state (100 units) then goes through a
    ReLU MLP of widths 400 -> 200 -> 100 -> 50, which yields a 50-dimensional
    summary vector.
    """
    layers = [
        tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, -1)),
        tf.keras.layers.LSTM(100),
    ]
    layers.extend(
        tf.keras.layers.Dense(width, activation="relu") for width in (400, 200, 100, 50)
    )
    return tf.keras.Sequential(layers)


def get_inference_network():
    """Build the invertible (normalizing-flow) inference network.

    Returns:
        A ``bf.networks.InvertibleNetwork`` over 7 parameters with 10 spline
        coupling layers. Their dense sub-network layers have 256 units and an
        L2(1e-3) kernel regularizer.
    """
    # The regularizer belongs inside ``dense_args``: BayesFlow forwards
    # ``dense_args`` as keyword arguments to the coupling sub-networks' Dense
    # layers. A top-level ``kernel_regularizer`` key in ``coupling_settings``
    # appears to be merged into the settings but never read, so the penalty
    # was silently not applied. Verify against the installed BayesFlow version.
    dense_args = {
        "units": 256,
        "kernel_regularizer": tf.keras.regularizers.l2(1e-3),
    }
    return bf.networks.InvertibleNetwork(
        num_params=7,
        num_coupling_layers=10,
        coupling_design="spline",
        coupling_settings={"dense_args": dense_args},
    )


def configurator(forward_dict):
    """Map the generative model's output to the amortizer's input format.

    The simulated data is passed through unchanged. Any reshaping (the
    trailing feature axis the LSTM needs) is done inside the summary network,
    not here.

    Args:
        forward_dict: Output of the generative model. It must contain
            ``"prior_draws"`` and ``"sim_data"``; other keys are ignored.

    Returns:
        A new dict with ``"parameters"`` (the prior draws) and
        ``"summary_conditions"`` (the simulated data).

    Raises:
        KeyError: If either required key is missing.
    """
    return {
        "parameters": forward_dict["prior_draws"],
        "summary_conditions": forward_dict["sim_data"],
    }


def get_amortizer():
    """Build the self-consistency amortized posterior.

    A generative model is constructed only to extract its raw simulator and
    prior callables. ``get_real_data_subset`` is passed as a callable, not
    called here; ``AmortizedPosteriorSC`` presumably invokes it to draw
    real-data batches. Each call builds fresh summary and inference networks.
    """
    generative_model = get_generative_model()
    summary_network = get_summary_network()
    inference_network = get_inference_network()

    raw_simulator = generative_model.simulator.simulator
    raw_prior = generative_model.prior.prior

    return AmortizedPosteriorSC(
        prior=raw_prior,
        simulator=raw_simulator,
        real_data=get_real_data_subset,
        inference_net=inference_network,
        summary_net=summary_network,
        n_consistency_samples=8,
    )


def get_trainer(**kwargs):
    """Build a BayesFlow ``Trainer`` around the self-consistency amortizer.

    Any keyword arguments are forwarded unchanged to ``bf.trainers.Trainer``.
    """
    gen_model = get_generative_model()
    return bf.trainers.Trainer(
        amortizer=get_amortizer(),
        generative_model=gen_model,
        configurator=configurator,
        **kwargs,
    )


def get_trainer_no_sc(**kwargs):
    """Build a BayesFlow ``Trainer`` around a plain ``AmortizedPosterior``.

    This is the baseline without the self-consistency objective; it uses the
    same inference and summary network architectures. Any keyword arguments
    are forwarded unchanged to ``bf.trainers.Trainer``.
    """
    gen_model = get_generative_model()
    # Inference net is built before the summary net, matching the original
    # construction order (this affects Keras auto-generated layer names).
    inference_network = get_inference_network()
    summary_network = get_summary_network()
    baseline_amortizer = AmortizedPosterior(
        inference_net=inference_network,
        summary_net=summary_network,
    )

    return bf.trainers.Trainer(
        amortizer=baseline_amortizer,
        generative_model=gen_model,
        configurator=configurator,
        **kwargs,
    )
