import numpy as np
from scipy.stats import norm
import tensorflow_probability as tfp
import tensorflow as tf
from functools import partial
import bayesflow as bf

# Short alias for the TensorFlow Probability distributions namespace; used by
# the prior and likelihood log-density methods below.
tfd = tfp.distributions

def sample_predictors(rng, num_obs):
    """Draw a pair of correlated, standardized predictor series (debt, gdp).

    A correlation coefficient is drawn uniformly from [-1, 1). It defines a
    2x2 covariance matrix with unit variances. ``num_obs + 1`` bivariate
    normal draws are then taken, and each column is z-scored using the
    population standard deviation (ddof=0).

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of randomness for *every* draw in this function, so that
        seeding ``rng`` makes the predictors reproducible.
    num_obs : int
        Number of observed time steps. One extra value is produced because
        ``LikelihoodSampler.sample_likelihood`` reads predictor indices
        0..num_obs (index 0 feeds a warm-up step that is not stored).

    Returns
    -------
    debt, gdp : numpy.ndarray
        Two 1-D arrays of length ``num_obs + 1``.
    """
    # Draw from the supplied generator rather than the global np.random
    # state; otherwise the correlation is not controlled by rng's seed.
    correlation = rng.uniform(-1, 1)
    cov = np.array([[1, correlation], [correlation, 1]])
    X = rng.multivariate_normal([0, 0], cov, num_obs + 1)
    X_std = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
    debt = X_std[:, 0]
    gdp = X_std[:, 1]

    return debt, gdp

class PriorWithLogProb:
    """Independent Gaussian prior over theta = (b0, bt, b1, b2, logsigma).

    Sampling (``__call__``) and density evaluation (``log_prob``) both read
    the single ``HYPERPARAMS`` table, so the two cannot drift apart.
    """

    # (name, loc, scale) for each parameter, in theta's column order.
    HYPERPARAMS = (
        ("b0", 0, 0.5),
        ("bt", 0, 0.2),
        ("b1", 0, 0.5),
        ("b2", 0, 0.5),
        ("logsigma", -1, 0.5),
    )

    def __init__(self, rng):
        # Random generator used by __call__ (needs a ``normal(loc, scale)`` method).
        self.rng = rng

    def __call__(self):
        """Draw one parameter vector from the prior.

        Returns
        -------
        numpy.ndarray
            Array of shape (1, 5), ordered as ``HYPERPARAMS``.
        """
        # Draws happen in table order (b0, bt, b1, b2, logsigma).
        draws = [
            self.rng.normal(loc=loc, scale=scale)
            for _, loc, scale in self.HYPERPARAMS
        ]
        return np.asarray(draws, dtype=float).reshape(1, -1)

    def log_prob(self, theta):
        """Evaluate the prior log-density with TensorFlow Probability.

        Parameters
        ----------
        theta : array-like or tf.Tensor
            Parameters with the 5 components on the last axis
            (NOTE(review): dtype must be compatible with tfd.Normal's
            inferred float32 parameters — confirm at call sites).

        Returns
        -------
        tf.Tensor
            Sum of the five per-component Normal log-densities, with the
            last axis of ``theta`` reduced away.
        """
        logp = None
        for idx, (_, loc, scale) in enumerate(self.HYPERPARAMS):
            lp = tfd.Normal(loc=loc, scale=scale).log_prob(theta[..., idx])
            # Accumulate left to right in the same order as the components.
            logp = lp if logp is None else logp + lp

        return logp

class LikelihoodSampler:
    """Simulate and score an AR(1) regression with two exogenous predictors.

    Observation model for each stored time step::

        y_t ~ Normal(b0 + bt * y_{t-1} + b1 * debt_t + b2 * gdp_t, exp(logsigma))
    """

    def __init__(self, rng=None):
        # Fall back to a fresh, unseeded Generator when none is supplied.
        self.rng = np.random.default_rng() if rng is None else rng

    def sample_likelihood(self, params, predictors, max_num_obs):
        """Simulate one dataset from the AR(1) regression.

        Two warm-up responses are drawn first and not stored:
        ``y0 ~ N(0.5, 1)`` and ``y1`` from the model using predictors at
        index 0. Output row ``i`` (0 <= i < max_num_obs) then uses predictors
        at index ``i + 1`` and the previous response as its lag.

        Parameters
        ----------
        params : array-like
            Parameter vector (b0, bt, b1, b2, logsigma). Either a flat
            ``(5,)`` vector or a 2-D array whose first row is used
            (e.g. the ``(1, 5)`` output of ``PriorWithLogProb``).
        predictors : tuple
            ``(debt, gdp)`` sequences, each of length >= ``max_num_obs + 1``.
        max_num_obs : int
            Number of rows to generate (>= 0).

        Returns
        -------
        numpy.ndarray
            Array of shape ``(max_num_obs, 4)`` with columns
            ``[y, debt, gdp, y_lag]``.

        Raises
        ------
        ValueError
            If ``max_num_obs`` is negative or the predictors are too short.
        """
        debt, gdp = predictors
        if max_num_obs < 0:
            raise ValueError(f"max_num_obs must be >= 0, got {max_num_obs}")
        needed = max_num_obs + 1
        if len(debt) < needed or len(gdp) < needed:
            raise ValueError(
                f"predictors must have at least {needed} entries "
                f"(got debt={len(debt)}, gdp={len(gdp)})"
            )

        b0, bt, b1, b2, logsigma = np.atleast_2d(params)[0][:5]
        sigma = np.exp(logsigma)
        y = np.empty(shape=(max_num_obs, 4))

        # Warm-up ("pre-lag") draws, not stored in y: 2003-2004 and 2004-2005.
        y0 = self.rng.normal(loc=0.5, scale=1.0)
        y_prev = self.rng.normal(
            loc=b0 + bt * y0 + b1 * debt[0] + b2 * gdp[0],
            scale=sigma,
        )

        # Row 0 corresponds to 2005-2006; each row lags on the previous response.
        for i in range(max_num_obs):
            t = i + 1  # predictor index is offset by the warm-up step
            y[i, 0] = self.rng.normal(
                loc=b0 + bt * y_prev + b1 * debt[t] + b2 * gdp[t],
                scale=sigma,
            )
            y[i, 1] = debt[t]
            y[i, 2] = gdp[t]
            y[i, 3] = y_prev
            y_prev = y[i, 0]

        return y

    def log_prob(self, theta, x):
        """Conditional log-likelihood of data ``x`` given parameters ``theta``.

        Parameters
        ----------
        theta : tf.Tensor or array-like
            Parameters with (b0, bt, b1, b2, logsigma) on the last axis.
        x : tf.Tensor or array-like
            Data with columns ``[y, debt, gdp, y_lag]`` on the last axis, as
            produced by ``sample_likelihood``; presumably laid out as
            ``(..., T, 4)`` with time on the second-to-last axis — confirm
            against callers.

        Returns
        -------
        tf.Tensor
            Per-step Normal log-densities summed over the time axis.
        """
        b0       = theta[..., 0]
        bt       = theta[..., 1]
        b1       = theta[..., 2]
        b2       = theta[..., 3]
        logsigma = theta[..., 4]
        sigma    = tf.exp(logsigma)

        y_val = x[..., 0]
        debt  = x[..., 1]
        gdp   = x[..., 2]
        y_lag = x[..., 3]

        # Add a trailing axis so per-theta scalars broadcast over time steps.
        b0    = b0   [..., tf.newaxis]
        bt    = bt   [..., tf.newaxis]
        b1    = b1   [..., tf.newaxis]
        b2    = b2   [..., tf.newaxis]
        sigma = sigma[..., tf.newaxis]

        # Broadcast of theta's batch shape (+1 axis) against x's (..., T).
        mean = b0 + bt * y_lag + b1 * debt + b2 * gdp
        dist = tfd.Normal(loc=mean, scale=sigma)
        log_pdf = dist.log_prob(y_val)
        log_lik = tf.reduce_sum(log_pdf, axis=-1)

        return log_lik
    
def build_generative_model(rng, num_obs, simulation_budget=1024):
    """Assemble the BayesFlow generative model and simulate training data.

    Parameters
    ----------
    rng : numpy.random.Generator
        Shared random generator for the prior, the predictors and the likelihood.
    num_obs : int
        Number of observed time steps per simulated dataset.
    simulation_budget : int, optional
        Number of datasets simulated for ``train_data`` (default 1024).

    Returns
    -------
    tuple
        ``(generative_model, train_data, prior, simulator)``. ``simulator``
        carries a ``log_prob(theta, x)`` method bound to the likelihood.
    """
    prior = PriorWithLogProb(rng=rng)
    likelihood = LikelihoodSampler(rng=rng)
    sample_predictors_with_defaults = partial(sample_predictors, rng=rng, num_obs=num_obs)

    simulator = bf.simulation.Simulator(
        simulator_fun=partial(likelihood.sample_likelihood, max_num_obs=num_obs),
        context_generator=bf.simulation.ContextGenerator(
            batchable_context_fun=sample_predictors_with_defaults,
        ),
    )

    # Attach the analytic likelihood to this simulator instance only.
    # Assigning to bf.simulation.Simulator (the class) would rebind log_prob
    # for every Simulator in the process, and each call would overwrite it.
    simulator.log_prob = likelihood.log_prob

    generative_model = bf.simulation.GenerativeModel(
        prior=prior,
        simulator=simulator,
        # NOTE(review): this label looks copied from another model; kept as-is
        # in case downstream code relies on it.
        name="air_traffic",
    )

    train_data = generative_model(simulation_budget)

    return generative_model, train_data, prior, simulator
