import math

import tensorflow as tf
import tensorflow_probability as tfp
from bayesflow.simulation import GenerativeModel

from .ode import HodgkinHuxleyODE


def get_generative_model():
    """Assemble the BayesFlow ``GenerativeModel`` for the Hodgkin-Huxley task.

    Wires a fresh ``PriorWithLogProb`` and ``SimulatorWithLogProb`` into a
    ``GenerativeModel``. It passes ``prior_is_batched=False`` and
    ``simulator_is_batched=True``; the exact meaning of those flags is defined
    by BayesFlow.

    Returns:
        The configured ``GenerativeModel`` instance.
    """
    # Keyword arguments are evaluated left to right, so the prior is still
    # constructed before the simulator.
    return GenerativeModel(
        prior=PriorWithLogProb(),
        simulator=SimulatorWithLogProb(),
        prior_is_batched=False,
        simulator_is_batched=True,
    )


class SimulatorWithLogProb:
    """Hodgkin-Huxley simulator with elementwise Student-t observation noise.

    A latent ``z`` is mapped to ODE parameters with ``z_to_theta``. The ODE is
    solved, and its solution is used as the location of a Student-t
    distribution. Sampling (``__call__``) and scoring (``log_prob``) share one
    noise model, so the two cannot disagree.
    """

    def __init__(self, noise_scale=0.1, noise_df=10):
        """Create the simulator.

        Args:
            noise_scale: Scale of the Student-t observation noise.
            noise_df: Degrees of freedom of the Student-t observation noise.
        """
        self.ode = HodgkinHuxleyODE()
        self.noise_scale = noise_scale
        self.noise_df = noise_df

    def _noise_distribution(self, z):
        # Single source of truth for the z -> theta -> ODE -> noise pipeline.
        theta = z_to_theta(z)
        trace = self.ode.solve_ode(theta)
        return tfp.distributions.StudentT(
            loc=trace, scale=self.noise_scale, df=self.noise_df
        )

    def __call__(self, z):
        """Simulate one noisy observation per element of the ODE solution.

        Args:
            z: Latent parameters; the last axis is read by ``z_to_theta``.

        Returns:
            A sample from the Student-t noise model, with the same shape as the
            ODE solution.
        """
        return self._noise_distribution(z).sample()

    def log_prob(self, z, x):
        """Score observations ``x`` under the noise model for latent ``z``.

        ``x`` is reshaped to ``(-1, T)``, where ``T`` is its last dimension.
        The pointwise log-densities are averaged over that last axis, and the
        result is reshaped to ``x.shape[:-1]``.

        Args:
            z: Latent parameters passed through ``z_to_theta``.
            x: Observations whose last axis indexes the series points.
                Assumes the ODE solution broadcasts against ``x`` flattened to
                ``(-1, T)`` — TODO confirm against ``HodgkinHuxleyODE``.

        Returns:
            A tensor of shape ``x.shape[:-1]``.
        """
        dist = self._noise_distribution(z)

        x_flat = tf.reshape(x, (-1, tf.shape(x)[-1]))
        pointwise_log_prob = dist.log_prob(x_flat)
        # NOTE(review): this is a *mean* over the series points, not the sum
        # that would give the joint log-likelihood of the independent noise
        # draws produced by ``__call__``. It is therefore on a different scale
        # than ``PriorWithLogProb.log_prob``, which sums. Kept for backward
        # compatibility; confirm this tempering is intended.
        per_series = tf.reduce_mean(pointwise_log_prob, axis=-1)
        return tf.reshape(per_series, tf.shape(x)[:-1])


class PriorWithLogProb:
    """Independent standard-normal prior over a 7-component latent vector."""

    @staticmethod
    def _base():
        # A new scalar N(0, 1) is built on each call, as in the per-method
        # construction this replaces.
        return tfp.distributions.Normal(loc=0, scale=1)

    def __call__(self):
        """Draw a single latent vector of shape ``(7,)``."""
        return self._base().sample(7)

    def log_prob(self, z):
        """Return the N(0, 1) log-densities of ``z`` summed over its last axis."""
        per_component = self._base().log_prob(z)
        return tf.math.reduce_sum(per_component, axis=-1)


def theta_to_z(theta):
    """Map physical Hodgkin-Huxley parameters ``theta`` to the latent ``z``.

    This is the inverse of ``z_to_theta``. Components 0-2 are standardized on
    the log scale (centres 110, 36, 0.2; widths 0.1, 0.1, 0.5). Components 3-6
    are standardized linearly (centres 1, -55, 50, -77; widths 0.05, 5, 5, 5).

    The offsets are Python floats, so the result follows the dtype of
    ``theta``. The previous ``tf.math.log(110.0)`` form produced float32
    tensors, which raise a dtype mismatch when ``theta`` is float64 (e.g. a
    NumPy array).

    Args:
        theta: Tensor or array whose last axis holds (at least) the 7
            parameters; components 0-2 must be positive.

    Returns:
        Tensor with a trailing axis of length 7.
    """
    z_1 = (tf.math.log(theta[..., 0]) - math.log(110.0)) / 0.1
    z_2 = (tf.math.log(theta[..., 1]) - math.log(36.0)) / 0.1
    z_3 = (tf.math.log(theta[..., 2]) - math.log(0.2)) / 0.5
    z_4 = (theta[..., 3] - 1.0) / 0.05
    z_5 = (theta[..., 4] + 55.0) / 5.0
    z_6 = (theta[..., 5] - 50.0) / 5.0
    z_7 = (theta[..., 6] + 77.0) / 5.0

    return tf.stack([z_1, z_2, z_3, z_4, z_5, z_6, z_7], axis=-1)


def z_to_theta(z):
    """Map a latent ``z`` to physical Hodgkin-Huxley parameters ``theta``.

    This is the inverse of ``theta_to_z``. Components 0-2 are exponentiated
    from the log scale (centres 110, 36, 0.2; widths 0.1, 0.1, 0.5), so they
    are always positive. Components 3-6 are affine (centres 1, -55, 50, -77;
    widths 0.05, 5, 5, 5).

    The log offsets are Python floats, so the result follows the dtype of
    ``z``. The previous ``tf.math.log(110.0)`` form produced float32 tensors,
    which raise a dtype mismatch when ``z`` is float64.

    Args:
        z: Tensor or array whose last axis holds (at least) 7 latent values.

    Returns:
        Tensor with a trailing axis of length 7.
    """
    theta_1 = tf.exp(math.log(110.0) + 0.1 * z[..., 0])
    theta_2 = tf.exp(math.log(36.0) + 0.1 * z[..., 1])
    theta_3 = tf.exp(math.log(0.2) + 0.5 * z[..., 2])
    theta_4 = z[..., 3] * 0.05 + 1.0
    theta_5 = z[..., 4] * 5.0 - 55.0
    theta_6 = z[..., 5] * 5.0 + 50.0
    theta_7 = z[..., 6] * 5.0 - 77.0

    return tf.stack(
        [theta_1, theta_2, theta_3, theta_4, theta_5, theta_6, theta_7], axis=-1
    )
