import bayesflow as bf
import numpy as np
import tensorflow as tf
from bayesflow.amortizers import AmortizedPosterior
from scipy.stats import multivariate_normal

from src.self_consistency_real.amortizers import (
    AmortizedPosteriorSC,
)

from .generative_model import get_generative_model


def get_inference_network(dimension=10):
    """Build the spline-coupling invertible network used as posterior approximator.

    Args:
        dimension: Number of model parameters the network models
            (passed as ``num_params``).

    Returns:
        A ``bf.networks.InvertibleNetwork`` with 5 spline coupling layers whose
        dense sub-networks have 128 units and an l2(1e-3) kernel regularizer.
    """
    return bf.networks.InvertibleNetwork(
        num_params=dimension,
        num_coupling_layers=5,
        coupling_design="spline",
        coupling_settings={
            # The regularizer has to live inside ``dense_args``: BayesFlow merges
            # these settings into its defaults and only forwards ``dense_args``
            # to the coupling layers' Dense layers, so a top-level
            # ``kernel_regularizer`` key is silently ignored (leaving BayesFlow's
            # own default regularizer in effect).
            "dense_args": dict(
                units=128,
                kernel_regularizer=tf.keras.regularizers.l2(1e-3),
            ),
        },
    )


def get_summary_network(input_dim=10, summary_dim=30):
    """Build the SetTransformer summary network.

    Args:
        input_dim: Size of the last axis of each set element. This should match
            the data ``dimension`` used by the generative model; the default of
            10 matches the module-wide default ``dimension=10``.
        summary_dim: Number of learned summary statistics produced per dataset.

    Returns:
        A ``bf.networks.SetTransformer`` instance.
    """
    return bf.networks.SetTransformer(input_dim=input_dim, summary_dim=summary_dim)


def configurator(forward_dict):
    """Translate a simulator forward dict into amortizer inputs (summary-network path).

    ``prior_draws`` becomes ``parameters`` and ``sim_data`` becomes
    ``summary_conditions``; both are cast to float32 via ``astype``.
    Any other keys in ``forward_dict`` are ignored.
    """
    # (output key, forward-dict key) pairs, processed in this order.
    routing = (
        ("parameters", "prior_draws"),
        ("summary_conditions", "sim_data"),
    )
    return {
        target: forward_dict[source].astype(np.float32)
        for target, source in routing
    }


def configurator_no_summary(forward_dict):
    """Translate a simulator forward dict into amortizer inputs (no summary network).

    ``prior_draws`` becomes ``parameters`` and ``sim_data`` is fed straight to
    the inference network as ``direct_conditions``; both are cast to float32.
    Any other keys in ``forward_dict`` are ignored.
    """
    # (output key, forward-dict key) pairs, processed in this order.
    routing = (
        ("parameters", "prior_draws"),
        ("direct_conditions", "sim_data"),
    )
    return {
        target: forward_dict[source].astype(np.float32)
        for target, source in routing
    }


def get_amortizer(
    with_summary_network=True,
    dimension=10,
    dataset_size=32,
    rng=None,
):
    """Build a self-consistency amortizer paired with a synthetic "real" dataset.

    Args:
        with_summary_network: If True, attach a SetTransformer summary network
            and sample real data with a per-dataset set axis; otherwise the real
            data has one observation vector per dataset.
        dimension: Data / parameter dimensionality.
        dataset_size: Number of real datasets to sample.
        rng: Anything ``np.random.default_rng`` accepts (``None``, an int seed,
            or a ``np.random.Generator``). ``None`` creates a fresh generator on
            each call instead of sharing one created at import time.

    Returns:
        An ``AmortizedPosteriorSC`` instance.
    """
    # A Generator passed in is returned unaltered, so callers that share one
    # generator across helpers keep sharing it.
    rng = np.random.default_rng(rng)

    model = get_generative_model(
        with_summary_network=with_summary_network, dimension=dimension, rng=rng
    )
    inference_net = get_inference_network(dimension=dimension)

    # The amortizer is handed the raw prior/simulator objects wrapped inside
    # the generative model.
    simulator = model.simulator.simulator
    prior = model.prior.prior

    real_data = _sample_real_data(
        dimension=dimension,
        dataset_size=dataset_size,
        with_summary_network=with_summary_network,
        rng=rng,
    )

    amortizer_kwargs = dict(
        prior=prior,
        simulator=simulator,
        real_data=real_data,
        inference_net=inference_net,
    )
    if with_summary_network:
        amortizer_kwargs["summary_net"] = get_summary_network()

    return AmortizedPosteriorSC(**amortizer_kwargs)


def _sample_real_data(dimension, dataset_size, with_summary_network, rng, num_obs=10):
    """Draw synthetic "real" data from N(3, I) with an explicit batch axis.

    Returns an array of shape ``(dataset_size, num_obs, dimension)`` when
    ``with_summary_network`` is True, else ``(dataset_size, dimension)``.
    """
    # NOTE(review): ``num_obs=10`` observations per dataset is hard-coded by the
    # callers; presumably it matches the generative model's simulator — confirm.
    batch_shape = (dataset_size, num_obs) if with_summary_network else (dataset_size,)
    # ``cov`` is left at scipy's default (1), i.e. the identity matrix.
    draws = multivariate_normal([3.0] * dimension).rvs(
        size=batch_shape, random_state=rng
    )
    # scipy squeezes singleton axes out of ``rvs`` output (e.g. dataset_size == 1
    # or dimension == 1); reshaping restores the full, explicit layout.
    return np.reshape(draws, batch_shape + (dimension,))


def get_trainer(
    with_summary_network=True,
    dimension=10,
    dataset_size=32,
    rng=None,
    **kwargs,
):
    """Build a BayesFlow trainer around the self-consistency amortizer.

    Args:
        with_summary_network: Whether to use a summary network (and the
            matching configurator) or feed simulations as direct conditions.
        dimension: Data / parameter dimensionality.
        dataset_size: Number of synthetic real datasets given to the amortizer.
        rng: Anything ``np.random.default_rng`` accepts (``None``, an int seed,
            or a ``np.random.Generator``). The normalized generator is shared by
            the generative model and the amortizer.
        **kwargs: Forwarded unchanged to ``bf.trainers.Trainer``.

    Returns:
        A ``bf.trainers.Trainer`` instance.
    """
    # Normalize once so both helpers below draw from the same generator, even
    # when an int seed is given (avoids two identically seeded streams).
    rng = np.random.default_rng(rng)

    generative_model = get_generative_model(
        with_summary_network=with_summary_network, dimension=dimension, rng=rng
    )
    # NOTE(review): get_amortizer builds its own generative model internally, so
    # the amortizer's prior/simulator come from a separate instance.
    amortizer = get_amortizer(
        with_summary_network=with_summary_network,
        dimension=dimension,
        dataset_size=dataset_size,
        rng=rng,
    )

    # The only difference between the two modes is where sim_data is routed.
    trainer_configurator = (
        configurator if with_summary_network else configurator_no_summary
    )
    return bf.trainers.Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=trainer_configurator,
        **kwargs,
    )


def get_trainer_no_sc(
    with_summary_network=True, dimension=10, rng=None, **kwargs
):
    """Build a baseline BayesFlow trainer with a plain ``AmortizedPosterior``.

    This is the variant without self-consistency.

    Args:
        with_summary_network: Whether to attach a SetTransformer summary network
            (and use the matching configurator) or feed simulations directly.
        dimension: Data / parameter dimensionality.
        rng: Anything ``np.random.default_rng`` accepts (``None``, an int seed,
            or a ``np.random.Generator``). ``None`` creates a fresh generator on
            each call instead of sharing one created at import time.
        **kwargs: Forwarded unchanged to ``bf.trainers.Trainer``.

    Returns:
        A ``bf.trainers.Trainer`` instance.
    """
    rng = np.random.default_rng(rng)

    generative_model = get_generative_model(
        with_summary_network=with_summary_network, dimension=dimension, rng=rng
    )
    inference_net = get_inference_network(dimension=dimension)

    if with_summary_network:
        summary_net = get_summary_network()
        amortizer = AmortizedPosterior(
            inference_net=inference_net, summary_net=summary_net
        )
        trainer_configurator = configurator
    else:
        amortizer = AmortizedPosterior(inference_net=inference_net)
        trainer_configurator = configurator_no_summary

    return bf.trainers.Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=trainer_configurator,
        **kwargs,
    )
