from __future__ import annotations

import numpy as np
import tensorflow_probability as tfp
from bayesflow.simulation import GenerativeModel


def get_generative_model(
    prior: PriorWithLogProb | None = None,
    simulator: SimulatorWithLogProb | None = None,
    dimension: int = 10,
    rng: np.random.Generator | None = None,
) -> GenerativeModel:
    """Build a BayesFlow ``GenerativeModel`` from a prior and a simulator.

    Any component that is not supplied is created with default settings and
    the given ``dimension``.

    Args:
        prior: Callable drawing parameters. Defaults to a ``PriorWithLogProb``
            of the given dimension.
        simulator: Callable mapping parameters to an observation. Defaults to a
            ``SimulatorWithLogProb`` of the given dimension.
        dimension: Dimension passed to the default prior/simulator; ignored for
            components that are supplied explicitly.
        rng: NumPy generator shared by the default prior and simulator. If
            ``None``, a fresh ``np.random.default_rng()`` is created for this
            call.

    Returns:
        A ``GenerativeModel`` with ``prior_is_batched=False`` and
        ``simulator_is_batched=False``.
    """
    # Create the generator at call time: a default argument of
    # np.random.default_rng() would be evaluated once at import and silently
    # shared by every call of this function.
    if rng is None:
        rng = np.random.default_rng()
    if prior is None:
        prior = PriorWithLogProb(dimension=dimension, rng=rng)
    if simulator is None:
        simulator = SimulatorWithLogProb(dimension=dimension, rng=rng)

    # The default prior draws a single parameter vector per call and the
    # default simulator maps a single vector to a single observation, hence
    # both are flagged as non-batched.
    return GenerativeModel(
        prior=prior,
        simulator=simulator,
        prior_is_batched=False,
        simulator_is_batched=False,
    )


class SimulatorWithLogProb:
    """Gaussian simulator ``x ~ Normal(theta, scale)`` with a log-density.

    Args:
        scale: Standard deviation of the observation noise.
        dimension: Stored on the instance; sampling and ``log_prob`` shapes
            follow ``theta`` rather than this value.
        rng: NumPy generator used for sampling. If ``None``, a fresh
            ``np.random.default_rng()`` is created for this instance.
    """

    def __init__(
        self,
        scale=1,
        dimension=10,
        rng: np.random.Generator | None = None,
    ):
        self.scale = scale
        self.dimension = dimension
        # Create the generator per instance; a default argument of
        # np.random.default_rng() is evaluated once at import time and would
        # be shared by every instance constructed without an explicit rng.
        self.rng = np.random.default_rng() if rng is None else rng

    def __call__(self, theta):
        """Draw one noisy observation centred on ``theta``.

        The result has the broadcast shape of ``theta`` and ``scale``.
        """
        return self.rng.normal(loc=theta, scale=self.scale)

    def log_prob(self, theta, x):
        """Return the log-density of ``x`` under ``Normal(theta, scale)``.

        The rightmost axis is reinterpreted as the event dimension, so
        per-component log-densities are summed over it; any leading axes are
        batch axes.
        """
        dist = tfp.distributions.Independent(
            tfp.distributions.Normal(loc=theta, scale=self.scale),
            reinterpreted_batch_ndims=1,
        )

        return dist.log_prob(x)


class PriorWithLogProb:
    """Diagonal Gaussian prior ``theta_i ~ Normal(loc_i, scale_i)`` with a log-density.

    Args:
        loc: Mean; a scalar or anything broadcastable to ``(dimension,)``.
        scale: Standard deviation; a scalar or anything broadcastable to
            ``(dimension,)``.
        dimension: Number of components in a parameter vector.
        rng: NumPy generator used for sampling. If ``None``, a fresh
            ``np.random.default_rng()`` is created for this instance.
    """

    def __init__(
        self,
        loc=0,
        scale=1,
        dimension=10,
        rng: np.random.Generator | None = None,
    ):
        self.loc = loc
        self.scale = scale
        self.dimension = dimension
        # Create the generator per instance; a default argument of
        # np.random.default_rng() is evaluated once at import time and would
        # be shared by every instance constructed without an explicit rng.
        self.rng = np.random.default_rng() if rng is None else rng

    def __call__(self):
        """Draw one parameter vector of shape ``(dimension,)``."""
        return self.rng.normal(loc=self.loc, scale=self.scale, size=(self.dimension,))

    def _loc_scale(self):
        """Return ``loc`` and ``scale`` broadcast to ``(dimension,)`` as Python lists."""
        # Broadcasting (rather than ``[self.loc] * self.dimension``) keeps the
        # parameters 1-D when loc/scale are per-dimension arrays, matching how
        # __call__ treats them. For scalar loc/scale this yields exactly the
        # same plain Python lists as the list-repetition form.
        shape = (self.dimension,)
        loc = np.broadcast_to(self.loc, shape).tolist()
        scale = np.broadcast_to(self.scale, shape).tolist()
        return loc, scale

    def log_prob(self, theta):
        """Return the prior log-density of ``theta``.

        The rightmost axis (of length ``dimension``) is the event dimension,
        so per-component log-densities are summed over it.
        """
        loc, scale = self._loc_scale()
        dist = tfp.distributions.Independent(
            tfp.distributions.Normal(loc=loc, scale=scale),
            reinterpreted_batch_ndims=1,
        )

        return dist.log_prob(theta)
