import bayesflow as bf
import numpy as np
import tensorflow as tf
from bayesflow.amortizers import (
    AmortizedPosterior,
    AmortizedLikelihood,
    AmortizedPosteriorLikelihood,
)
from scipy.stats import multivariate_normal

from src.self_consistency_real.amortizers import (
    AmortizedPosteriorLikelihoodSC,
)

from experiments.multivariate_normal.generative_model import get_generative_model


def get_inference_network(dimension=10):
    """Build the invertible network used on the posterior (inference) side.

    Parameters
    ----------
    dimension : int, optional
        Forwarded to ``InvertibleNetwork`` as ``num_params``. Defaults to 10.

    Returns
    -------
    bf.networks.InvertibleNetwork
        Network configured with five coupling layers of the ``"spline"``
        design.
    """
    # NOTE(review): ``kernel_regularizer`` is a sibling of ``dense_args``
    # rather than an entry inside it -- confirm BayesFlow applies it at this
    # level and does not silently drop it.
    spline_settings = dict(
        dense_args={"units": 128},
        kernel_regularizer=tf.keras.regularizers.l2(1e-3),
    )

    return bf.networks.InvertibleNetwork(
        num_params=dimension,
        num_coupling_layers=5,
        coupling_design="spline",
        coupling_settings=spline_settings,
    )


def get_likelihood_network(dimension=10):
    """Build the invertible network used on the likelihood side.

    Parameters
    ----------
    dimension : int, optional
        Forwarded to ``InvertibleNetwork`` as ``num_params``. Defaults to 10.

    Returns
    -------
    bf.networks.InvertibleNetwork
        Network configured with five coupling layers of the ``"spline"``
        design.
    """
    # NOTE(review): ``kernel_regularizer`` is a sibling of ``dense_args``
    # rather than an entry inside it -- confirm BayesFlow applies it at this
    # level and does not silently drop it.
    spline_settings = dict(
        dense_args={"units": 128},
        kernel_regularizer=tf.keras.regularizers.l2(1e-3),
    )

    return bf.networks.InvertibleNetwork(
        num_params=dimension,
        num_coupling_layers=5,
        coupling_design="spline",
        coupling_settings=spline_settings,
    )


def configurator(forward_dict):
    """Rearrange a forward-simulation dict into posterior/likelihood inputs.

    Parameters
    ----------
    forward_dict : dict
        Must provide ``"prior_draws"`` and ``"sim_data"``; both values need
        an ``astype`` method (e.g. NumPy arrays).

    Returns
    -------
    dict
        ``{"posterior_inputs": {...}, "likelihood_inputs": {...}}`` where

        * ``posterior_inputs`` holds ``"parameters"`` (prior draws) and
          ``"direct_conditions"`` (simulated data),
        * ``likelihood_inputs`` holds ``"conditions"`` (prior draws) and
          ``"observables"`` (simulated data),

        every value cast to ``float32``.

    Raises
    ------
    KeyError
        If ``"prior_draws"`` or ``"sim_data"`` is missing.
    """

    def _as_float32(key):
        # Cast separately for every entry instead of reusing one result, so
        # the posterior and likelihood entries never refer to the same object.
        return forward_dict[key].astype(np.float32)

    return {
        "posterior_inputs": {
            "parameters": _as_float32("prior_draws"),
            "direct_conditions": _as_float32("sim_data"),
        },
        "likelihood_inputs": {
            "conditions": _as_float32("prior_draws"),
            "observables": _as_float32("sim_data"),
        },
    }


def get_amortized_likelihood(dimension=10):
    """Wrap a freshly built likelihood network in an ``AmortizedLikelihood``.

    Parameters
    ----------
    dimension : int, optional
        Forwarded to ``get_likelihood_network``. Defaults to 10.

    Returns
    -------
    AmortizedLikelihood
        Amortizer built around the new likelihood network.
    """
    return AmortizedLikelihood(get_likelihood_network(dimension=dimension))


def get_amortized_posterior(dimension=10):
    """Wrap a freshly built inference network in an ``AmortizedPosterior``.

    Parameters
    ----------
    dimension : int, optional
        Forwarded to ``get_inference_network``. Defaults to 10.

    Returns
    -------
    AmortizedPosterior
        Amortizer built around the new inference network.
    """
    return AmortizedPosterior(get_inference_network(dimension=dimension))


def get_amortizer(
    dimension=10,
    rng=None,
):
    """Build the joint posterior/likelihood amortizer with real-data input.

    Parameters
    ----------
    dimension : int, optional
        Dimension handed to the generative model, to both networks, and used
        as the length of the mean vector of the "real" data. Defaults to 10.
    rng : numpy.random.Generator, optional
        Random generator passed to the generative model and used to draw the
        real data. When ``None`` a new ``np.random.default_rng()`` is created
        on each call.

    Returns
    -------
    AmortizedPosteriorLikelihoodSC
        Amortizer constructed from the model's prior, the drawn real data,
        and newly built posterior and likelihood amortizers.
    """
    # Resolve the default at call time: a generator created in the signature
    # is built once at import and its state is then shared by every call.
    if rng is None:
        rng = np.random.default_rng()

    model = get_generative_model(dimension=dimension, rng=rng)
    prior = model.prior.prior

    # 32 draws from a normal whose mean is 3.0 in every coordinate.
    real_data = multivariate_normal([3.0] * dimension).rvs(size=32, random_state=rng)
    # ``rvs`` squeezes singleton axes, so for ``dimension == 1`` it returns
    # shape (32,). Reshape restores the (n_draws, dimension) layout; adding a
    # leading axis would instead give one observation of width 32.
    real_data = np.reshape(real_data, (-1, dimension))

    amortizer = AmortizedPosteriorLikelihoodSC(
        prior=prior,
        real_data=real_data,
        amortized_posterior=get_amortized_posterior(dimension=dimension),
        amortized_likelihood=get_amortized_likelihood(dimension=dimension),
    )

    return amortizer


def get_trainer(dimension=10, rng=None, **kwargs):
    """Build a BayesFlow trainer around the amortizer from ``get_amortizer``.

    Parameters
    ----------
    dimension : int, optional
        Dimension forwarded to the generative model and to ``get_amortizer``.
        Defaults to 10.
    rng : numpy.random.Generator, optional
        Random generator shared by the generative model and the amortizer
        construction. When ``None`` a new ``np.random.default_rng()`` is
        created on each call.
    **kwargs
        Forwarded unchanged to ``bf.trainers.Trainer``.

    Returns
    -------
    bf.trainers.Trainer
        Trainer wired with the amortizer, the generative model, and
        ``configurator``.
    """
    # Resolve the default at call time: a generator created in the signature
    # is built once at import and its state is then shared by every call.
    if rng is None:
        rng = np.random.default_rng()

    generative_model = get_generative_model(dimension=dimension, rng=rng)
    amortizer = get_amortizer(dimension=dimension, rng=rng)

    trainer = bf.trainers.Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=configurator,
        **kwargs,
    )

    return trainer


def get_trainer_no_sc(dimension=10, rng=None, **kwargs):
    """Build a BayesFlow trainer around a plain ``AmortizedPosteriorLikelihood``.

    Unlike ``get_trainer``, the amortizer here is constructed without a prior
    or real data.

    Parameters
    ----------
    dimension : int, optional
        Dimension used for both networks and for the generative model.
        Defaults to 10.
    rng : numpy.random.Generator, optional
        Random generator handed to the generative model. When ``None`` a new
        ``np.random.default_rng()`` is created on each call.
    **kwargs
        Forwarded unchanged to ``bf.trainers.Trainer``.

    Returns
    -------
    bf.trainers.Trainer
        Trainer wired with the amortizer, the generative model, and
        ``configurator``.
    """
    # Resolve the default at call time: a generator created in the signature
    # is built once at import and its state is then shared by every call.
    if rng is None:
        rng = np.random.default_rng()

    amortizer = AmortizedPosteriorLikelihood(
        amortized_posterior=get_amortized_posterior(dimension=dimension),
        amortized_likelihood=get_amortized_likelihood(dimension=dimension),
    )

    trainer = bf.trainers.Trainer(
        amortizer=amortizer,
        # ``dimension`` must reach the simulator as well; otherwise it falls
        # back to its own default and can disagree with the networks above.
        generative_model=get_generative_model(dimension=dimension, rng=rng),
        configurator=configurator,
        **kwargs,
    )

    return trainer
