from __future__ import annotations

import numpy as np
import tensorflow_probability as tfp
import tensorflow as tf
from bayesflow.simulation import GenerativeModel


def get_generative_model(
    prior: PriorWithLogProb | None = None,
    simulator: SimulatorWithLogProb | None = None,
    dimension: int = 10,
    with_summary_network=True,
    rng: np.random.Generator | None = None,
):
    """Build an unbatched ``GenerativeModel`` from the toy prior and simulator.

    Args:
        prior: Prior to use. Defaults to
            ``PriorWithLogProb(dimension=dimension, rng=rng)``.
        simulator: Simulator to use. Defaults to ``SimulatorWithLogProb``
            built with ``with_summary_network``, ``dimension`` and ``rng``.
        dimension: Number of parameters for the default prior/simulator.
        with_summary_network: Forwarded to the default simulator.
        rng: NumPy generator shared by the default prior and simulator.
            ``None`` creates a fresh generator on each call.

    Returns:
        A ``GenerativeModel`` built with ``prior_is_batched=False`` and
        ``simulator_is_batched=False``.
    """
    if rng is None:
        # A Generator in the signature would be created once at definition
        # time and silently shared by every call; build one per call instead.
        # It is still shared between the default prior and simulator.
        rng = np.random.default_rng()
    if prior is None:
        prior = PriorWithLogProb(dimension=dimension, rng=rng)
    if simulator is None:
        simulator = SimulatorWithLogProb(
            with_summary_network=with_summary_network, dimension=dimension, rng=rng
        )

    return GenerativeModel(
        prior=prior,
        simulator=simulator,
        prior_is_batched=False,
        simulator_is_batched=False,
    )


class SimulatorWithLogProb:
    """Toy Gaussian simulator ``x ~ Normal(theta, scale)`` with a log-density.

    With ``with_summary_network=True``, each call draws ``n_obs`` i.i.d.
    observations. Their per-observation scale is inflated by ``sqrt(n_obs)``,
    so the mean of the set has the same spread as one draw with ``scale``.
    Otherwise a single observation shaped like ``theta`` is drawn.

    Args:
        with_summary_network: Draw a set of ``n_obs`` observations (True) or a
            single observation (False).
        scale: Standard deviation of a single (non-set) observation.
        dimension: Stored on the instance; sampling uses ``len(theta)``.
        rng: NumPy generator used for sampling. ``None`` creates a fresh one
            for this instance.
        n_obs: Number of observations per call when ``with_summary_network``
            is True (default 10).
    """

    def __init__(
        self,
        with_summary_network=True,
        scale=1,
        dimension=10,
        rng: np.random.Generator | None = None,
        n_obs: int = 10,
    ):
        self.with_summary_network = with_summary_network
        self.scale = scale
        self.dimension = dimension
        # A Generator in the signature would be built once at definition time
        # and shared by every instance; create one per instance instead.
        self.rng = np.random.default_rng() if rng is None else rng
        self.n_obs = n_obs

    def _obs_scale(self):
        """Return the per-observation scale used when drawing observation sets."""
        # Plain Python float rather than tf.sqrt(...): NumPy consumes it
        # directly, and TFP lets it adopt theta's dtype (a float32 tensor can
        # clash with float64 theta).
        return self.scale * float(self.n_obs) ** 0.5

    def __call__(self, theta):
        """Draw synthetic data for the parameter vector ``theta``.

        Returns an array of shape ``(n_obs, len(theta))`` when
        ``with_summary_network`` is True, otherwise one sample shaped like
        ``theta``.
        """
        if self.with_summary_network:
            return self.rng.normal(
                loc=theta, scale=self._obs_scale(), size=(self.n_obs, len(theta))
            )
        return self.rng.normal(loc=theta, scale=self.scale)

    def log_prob(self, theta, x):
        """Return the Gaussian log-density of ``x`` given ``theta``.

        The last axis is the event dimension, and its log-densities are summed.
        NOTE(review): in the summary-network setting, with a single ``theta`` of
        shape ``(d,)`` and ``x`` of shape ``(n_obs, d)``, this yields one value
        per observation (shape ``(n_obs,)``), not the joint log-likelihood of
        the set. Confirm whether callers expect a sum over axis 0.
        """
        scale = self._obs_scale() if self.with_summary_network else self.scale
        dist = tfp.distributions.Independent(
            tfp.distributions.Normal(loc=theta, scale=scale),
            reinterpreted_batch_ndims=1,
        )
        return dist.log_prob(x)


class PriorWithLogProb:
    """Independent Gaussian prior over ``dimension`` parameters.

    Each component is drawn as ``Normal(loc, scale)``.

    Args:
        loc: Mean of every component.
        scale: Standard deviation of every component.
        dimension: Number of parameters per draw.
        rng: NumPy generator used for sampling. ``None`` creates a fresh one
            for this instance.
    """

    def __init__(
        self,
        loc=0,
        scale=1,
        dimension=10,
        rng: np.random.Generator | None = None,
    ):
        self.loc = loc
        self.scale = scale
        self.dimension = dimension
        # A Generator in the signature would be built once at definition time
        # and shared by every instance; create one per instance instead.
        self.rng = np.random.default_rng() if rng is None else rng

    def __call__(self):
        """Draw one parameter vector of shape ``(dimension,)``."""
        return self.rng.normal(loc=self.loc, scale=self.scale, size=(self.dimension,))

    def log_prob(self, theta):
        """Return the prior log-density of ``theta``, summed over its last axis."""
        dist = tfp.distributions.Independent(
            tfp.distributions.Normal(
                loc=[self.loc] * self.dimension, scale=[self.scale] * self.dimension
            ),
            reinterpreted_batch_ndims=1,
        )

        return dist.log_prob(theta)
