import bayesflow as bf
import numpy as np
import tensorflow as tf
from bayesflow.amortizers import AmortizedPosterior
from scipy.stats import multivariate_normal

from src.self_consistency_real.amortizers import (
    AmortizedPosteriorSC,
)

from .generative_model import get_generative_model


def get_inference_network(dimension=10):
    """Build the spline-coupling invertible network used as posterior approximator.

    Parameters
    ----------
    dimension : int, default 10
        Number of parameters modelled by the network (passed as ``num_params``).

    Returns
    -------
    bf.networks.InvertibleNetwork
        A network with 5 spline coupling layers whose dense sub-networks use
        128 units and an L2(1e-3) kernel regularizer.
    """
    network = bf.networks.InvertibleNetwork(
        num_params=dimension,
        num_coupling_layers=5,
        coupling_design="spline",
        coupling_settings={
            # The regularizer belongs inside ``dense_args``: BayesFlow builds the
            # Dense layers of each coupling sub-network from ``dense_args`` only,
            # so a top-level ``kernel_regularizer`` key would be silently ignored
            # (leaving the library's default regularizer in effect).
            "dense_args": dict(
                units=128,
                kernel_regularizer=tf.keras.regularizers.l2(1e-3),
            ),
        },
    )

    return network


def configurator(forward_dict):
    """Convert a simulator output dict into the inputs expected by the amortizer.

    Parameters
    ----------
    forward_dict : dict
        Must contain ``"prior_draws"`` and ``"sim_data"``, both objects with an
        ``astype`` method (e.g. NumPy arrays).

    Returns
    -------
    dict
        ``{"parameters": prior_draws, "direct_conditions": sim_data}`` with both
        values cast to ``np.float32``.
    """
    draws = forward_dict["prior_draws"]
    conditions = forward_dict["sim_data"]

    return {
        "parameters": draws.astype(np.float32),
        "direct_conditions": conditions.astype(np.float32),
    }


def get_amortizer(
    dimension=10,
    dataset_size=32,
    dataset_mean=3.0,
    rng=None,
):
    """Build a self-consistency amortized posterior with a synthetic "real" dataset.

    Parameters
    ----------
    dimension : int, default 10
        Number of model parameters / data dimensions.
    dataset_size : int, default 32
        Number of "real" observations drawn for the self-consistency term.
    dataset_mean : float, default 3.0
        Mean used for every coordinate of the "real" observations.
    rng : numpy.random.Generator or None, default None
        Random generator shared by the generative model and the "real" data
        draw. ``None`` creates a fresh ``np.random.default_rng()`` per call
        (instead of a single generator created once at import time).

    Returns
    -------
    AmortizedPosteriorSC
        Amortizer wrapping the generative model's prior and simulator, the
        inference network, and ``real_data`` of shape ``(dataset_size, dimension)``.
    """
    if rng is None:
        rng = np.random.default_rng()

    model = get_generative_model(dimension=dimension, rng=rng)
    inference_net = get_inference_network(dimension=dimension)

    simulator = model.simulator.simulator
    prior = model.prior.prior

    # "Real" observations ~ N(dataset_mean * 1, I) (scipy's default cov is 1).
    real_data = multivariate_normal([dataset_mean] * dimension).rvs(
        size=dataset_size, random_state=rng
    )
    # scipy squeezes singleton axes out of ``rvs`` output: (dimension,) when
    # dataset_size == 1, (dataset_size,) when dimension == 1, and a scalar when
    # both are 1. Restore the (dataset_size, dimension) layout explicitly.
    real_data = np.reshape(real_data, (dataset_size, dimension))

    amortizer = AmortizedPosteriorSC(
        prior=prior,
        simulator=simulator,
        real_data=real_data,
        inference_net=inference_net,
    )

    return amortizer


def get_trainer(
    dimension=10,
    dataset_size=32,
    dataset_mean=3.0,
    rng=None,
    **kwargs,
):
    """Build a BayesFlow trainer around the self-consistency amortizer.

    Parameters
    ----------
    dimension, dataset_size, dataset_mean
        Forwarded to ``get_amortizer`` (``dimension`` also to the generative model).
    rng : numpy.random.Generator or None, default None
        Random generator passed to both the generative model and the amortizer.
        ``None`` creates a fresh ``np.random.default_rng()`` per call.
    **kwargs
        Extra keyword arguments forwarded to ``bf.trainers.Trainer``.

    Returns
    -------
    bf.trainers.Trainer
    """
    # Resolve the generator once so both components below receive the same one.
    if rng is None:
        rng = np.random.default_rng()

    generative_model = get_generative_model(dimension=dimension, rng=rng)
    amortizer = get_amortizer(
        dimension=dimension,
        dataset_size=dataset_size,
        dataset_mean=dataset_mean,
        rng=rng,
    )

    trainer = bf.trainers.Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=configurator,
        **kwargs,
    )

    return trainer


def get_trainer_no_sc(dimension=10, rng=None, **kwargs):
    """Build a BayesFlow trainer with a plain (non-self-consistent) amortizer.

    Parameters
    ----------
    dimension : int, default 10
        Number of model parameters / data dimensions.
    rng : numpy.random.Generator or None, default None
        Random generator passed to the generative model. ``None`` creates a
        fresh ``np.random.default_rng()`` per call.
    **kwargs
        Extra keyword arguments forwarded to ``bf.trainers.Trainer``.

    Returns
    -------
    bf.trainers.Trainer
    """
    if rng is None:
        rng = np.random.default_rng()

    generative_model = get_generative_model(dimension=dimension, rng=rng)
    inference_net = get_inference_network(dimension=dimension)

    amortizer = AmortizedPosterior(inference_net=inference_net)

    trainer = bf.trainers.Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=configurator,
        **kwargs,
    )

    return trainer
