"""Implementation of the misspecified MA(1) model."""

import jax.numpy as jnp
import numpyro.distributions as dist  # type: ignore
from jax import random
from jax._src.prng import PRNGKeyArray  # for typing


def assumed_dgp(rng_key: PRNGKeyArray,
                t1: float,
                n_obs: int = 100
                ) -> jnp.ndarray:
    """Simulate a moving-average process of order 1 (MA(1)).

    Each observation is ``x_t = e_t + t1 * e_{t-1}``, where the ``e_t`` are
    i.i.d. standard-normal noise terms.

    Args:
        rng_key (PRNGKeyArray): Random key controlling the noise draws.
        t1 (float): Coefficient on the lagged noise term.
        n_obs (int): Length of the simulated series. Defaults to 100.

    Returns:
        jnp.ndarray: Array of shape ``(1, n_obs)`` holding one simulated
                     series (assuming ``t1`` is a scalar or shape ``(1,)``).

    """
    # n_obs + 2 noise terms are drawn: column 0 is never used and column 1
    # only serves as the lag of the first observation.
    noise = dist.Normal(0, 1).sample(key=rng_key, sample_shape=(1, n_obs + 2))
    innovations = noise[:, 2:]
    lagged = noise[:, 1:-1]
    return innovations + t1 * lagged


def autocov(x: jnp.ndarray,
            lag: int = 1
            ) -> jnp.ndarray:
    """Compute the autocovariance of a given sequence at a specified lag.

    The estimator is not mean-centred (it treats the series as zero-mean).
    It averages the products of the ``n - |lag|`` overlapping pairs::

        C = mean_t(x[t + |lag|] * x[t])

    Autocovariance is symmetric in the lag, so a negative ``lag`` gives the
    same result as ``abs(lag)``.

    Args:
        x (jnp.ndarray): A 1-D sequence, or a 2-D array whose rows are
                         independent sequences (time runs along axis 1).
        lag (int, optional): The lag at which to compute the autocovariance.
                             Defaults to 1.

    Returns:
        jnp.ndarray: A JAX array with one autocovariance per row (shape
                     ``(1,)`` for 1-D input). If ``abs(lag)`` is at least
                     the sequence length there are no pairs, and the
                     result is NaN.
    """
    # Previously a negative lag paired the wrong elements
    # (e.g. x[:, -1:] * x[:, :1]); use the symmetry gamma(-h) == gamma(h).
    lag = abs(lag)
    x = jnp.atleast_2d(x)
    n_obs = x.shape[1]
    # Clamp at 0 so that an over-long lag produces two empty slices (-> NaN)
    # instead of slices with mismatched lengths.
    return jnp.mean(x[:, lag:] * x[:, :max(n_obs - lag, 0)], axis=1)


def calculate_summary_statistics(x):
    """Compute the summary statistics for the misspecified MA(1) example.

    The statistics are the lag-0 and lag-1 values returned by ``autocov``,
    stacked along a new leading axis. Singleton dimensions are then squeezed
    away, so a single series yields a length-2 vector.
    """
    per_lag = [autocov(x, lag=lag) for lag in (0, 1)]
    stacked = jnp.stack(per_lag)
    return jnp.squeeze(stacked)


def get_prior():
    """Build the prior used for inference on the misspecified MA(1) model.

    Returns:
        A numpyro ``Uniform`` distribution between -1 and 1 whose bounds are
        length-1 arrays (one entry per parameter; presumably ``t1`` of
        ``assumed_dgp`` — confirm against callers).
    """
    lower = jnp.array([-1.0])
    upper = jnp.array([1.0])
    return dist.Uniform(low=lower, high=upper)


def true_dgp(key: PRNGKeyArray,
             w: float = -0.736,
             rho: float = 0.9,
             sigma_v: float = 0.36,
             batch_size: int = 1,
             n_obs: int = 100
             ) -> jnp.ndarray:
    """Sample from a stochastic volatility model using a normally distributed shock term.

    For each of the ``batch_size`` independent series::

        h_0 = w + sigma_v * v_0
        h_t = w + rho * h_{t-1} + sigma_v * v_t,   t = 1, ..., n_obs - 1
        y_t = exp(h_t / 2) * e_t

    where ``v_t`` and ``e_t`` are independent standard-normal draws. Only
    ``y`` is returned.

    Args:
        key (PRNGKeyArray): a PRNGKeyArray for reproducibility.
        w (float, optional): Intercept of the log-volatility recursion.
            Defaults to -0.736.
        rho (float, optional): Coefficient on the previous log-volatility.
            Defaults to 0.9.
        sigma_v (float, optional): Standard deviation of the log-volatility
            shocks. Defaults to 0.36.
        batch_size (int, optional): Number of independent series.
            Defaults to 1.
        n_obs (int, optional): Length of each series. Defaults to 100.

    Returns:
        jnp.ndarray: samples ``y`` with shape ``(batch_size, n_obs)``.

    Raises:
        ValueError: If ``n_obs`` is less than 1.
    """
    if n_obs < 1:
        raise ValueError(f"n_obs must be at least 1, got {n_obs}")

    w_vec = jnp.repeat(w, batch_size)
    rho_vec = jnp.repeat(rho, batch_size)
    sigma_v_vec = jnp.repeat(sigma_v, batch_size)

    def _std_normal(subkey):
        # One standard-normal draw per batch element.
        return dist.Normal(0, 1).sample(key=subkey, sample_shape=(batch_size,))

    # Collect one column per time step and stack once at the end. Updating a
    # preallocated matrix with ``.at[].set`` outside ``jit`` copies the whole
    # matrix on every step, which is quadratic in n_obs.
    # The key is split twice per step (log-volatility shock first, then the
    # observation shock); this ordering fixes which samples a given key
    # produces, so keep it stable for reproducibility.
    h_prev = None
    y_cols = []
    for _ in range(n_obs):
        key, subkey = random.split(key)
        shock = _std_normal(subkey) * sigma_v_vec
        if h_prev is None:
            h_t = w_vec + shock
        else:
            h_t = w_vec + rho_vec * h_prev + shock
        key, subkey = random.split(key)
        y_cols.append(jnp.exp(h_t / 2) * _std_normal(subkey))
        h_prev = h_t

    return jnp.stack(y_cols, axis=1)


# def true_posterior(x_obs: jnp.ndarray,
#                    prior: dist.Distribution) -> jnp.ndarray:
#     """Return true posterior for misspecified MA(1) example."""
#     pass
