import logging
from pathlib import Path

from bayesflow.amortizers import AmortizedLikelihood, AmortizedPosteriorLikelihood
import numpy as np
from cmdstanpy import CmdStanModel
from scipy.stats import multivariate_normal
from bayesflow.trainers import Trainer
from bayesflow.computational_utilities import maximum_mean_discrepancy

from src.self_consistency_real.amortizers.amortized_posterior_likelihood_sc import (
    AmortizedPosteriorLikelihoodSC,
)


def posterior_bias(
    trainer: Trainer, observed_data: list | np.ndarray, n_samples: int = 1000
):
    """Measure how far the amortized posterior is from the analytic posterior.

    Samples the posterior from ``trainer.amortizer`` conditioned on
    ``observed_data`` and subtracts the closed-form posterior mean and
    standard deviation returned by ``analytic_posterior`` for a prior with
    zero mean and unit standard deviation in every dimension.

    Parameters
    ----------
    trainer : Trainer
        Trainer providing the ``configurator`` and the ``amortizer``.
    observed_data : list or np.ndarray
        Observed data; the length of its last axis is used as the number of
        dimensions ``D``.
    n_samples : int, default 1000
        Number of posterior draws requested from the amortizer (also used as
        the number of likelihood draws for joint amortizers, whose likelihood
        draws are discarded here).

    Returns
    -------
    tuple
        ``(mean_bias, std_bias)``: sample mean / std of the amortized draws
        (taken over axis 0) minus the analytic posterior mean / std.
    """
    observed_data = np.array(observed_data)
    D = observed_data.shape[-1]

    forward_dict = {
        # Placeholder parameters; presumably only required so the configurator
        # receives a complete forward dict -- TODO confirm.
        "prior_draws": np.array([[0.0] * D], dtype=np.float32),
        # Leading axis of size 1: a single data set.
        "sim_data": np.array([observed_data], dtype=np.float32),
    }
    input_dict = trainer.configurator(forward_dict)

    # isinstance (rather than an exact type comparison) so that subclasses of
    # the joint amortizers also take the dict-returning branch.
    if isinstance(
        trainer.amortizer,
        (AmortizedPosteriorLikelihoodSC, AmortizedPosteriorLikelihood),
    ):
        posterior_draws = trainer.amortizer.sample(
            input_dict, n_post_samples=n_samples, n_lik_samples=n_samples
        )["posterior_samples"]
    else:
        posterior_draws = trainer.amortizer.sample(input_dict, n_samples=n_samples)

    # NOTE(review): assumes draws are laid out as (n_samples, D) -- confirm
    # against the amortizer's ``sample`` implementation.
    posterior_mean = np.mean(posterior_draws, axis=0)
    posterior_std = np.std(posterior_draws, axis=0)

    accurate_posterior = analytic_posterior(
        observed_data, prior_mean=[0.0] * D, prior_std=[1.0] * D
    )

    mean_bias = posterior_mean - accurate_posterior[0]
    std_bias = posterior_std - accurate_posterior[1]

    return mean_bias, std_bias


def posterior_mmd(
    trainer: Trainer, observed_data: list | np.ndarray, n_samples: int = 1000
):
    """Compute the MMD between amortized and analytic posterior draws.

    Samples the posterior from ``trainer.amortizer`` conditioned on
    ``observed_data``, draws the same number of samples from the closed-form
    posterior (zero-mean, unit-std prior in every dimension) via
    ``analytic_posterior_draws`` and passes both sets to
    ``maximum_mean_discrepancy``.

    Parameters
    ----------
    trainer : Trainer
        Trainer providing the ``configurator`` and the ``amortizer``.
    observed_data : list or np.ndarray
        Observed data; the length of its last axis is used as the number of
        dimensions ``D``.
    n_samples : int, default 1000
        Number of draws taken from both the amortized and the analytic
        posterior (also used as the number of likelihood draws for joint
        amortizers, whose likelihood draws are discarded here).

    Returns
    -------
    The value returned by ``maximum_mean_discrepancy`` for the two draw sets.
    """
    observed_data = np.array(observed_data)
    D = observed_data.shape[-1]

    forward_dict = {
        # Placeholder parameters; presumably only required so the configurator
        # receives a complete forward dict -- TODO confirm.
        "prior_draws": np.array([[0.0] * D], dtype=np.float32),
        # Leading axis of size 1: a single data set.
        "sim_data": np.array([observed_data], dtype=np.float32),
    }
    input_dict = trainer.configurator(forward_dict)

    # isinstance (rather than an exact type comparison) so that subclasses of
    # the joint amortizers also take the dict-returning branch.
    if isinstance(
        trainer.amortizer,
        (AmortizedPosteriorLikelihoodSC, AmortizedPosteriorLikelihood),
    ):
        posterior_draws = trainer.amortizer.sample(
            input_dict, n_post_samples=n_samples, n_lik_samples=n_samples
        )["posterior_samples"]
    else:
        posterior_draws = trainer.amortizer.sample(input_dict, n_samples=n_samples)

    # Reference draws are cast to float32, the dtype used for the inputs above.
    accurate_draws = analytic_posterior_draws(
        observed_data, prior_mean=[0.0] * D, prior_std=[1.0] * D, n_samples=n_samples
    ).astype(np.float32)

    mmd = maximum_mean_discrepancy(posterior_draws, accurate_draws)

    return mmd


def analytic_posterior(
    y: list | np.ndarray,
    prior_mean: list | np.ndarray,
    prior_std: list | np.ndarray,
    y_std=None,
):
    if y_std is None:
        y_std = np.full(len(y), 1.0)

    y = np.array(y)

    prior_mean = np.array(prior_mean)
    prior_std = np.array(prior_std)
    y_std = np.array(y_std)

    sample_mean = np.mean(y, axis=0)

    N = 1 if y.ndim == 1 else y.shape[0]

    post_mean = (
        prior_std**2 / ((y_std**2 / N) + prior_std**2) * sample_mean
        + y_std**2 / ((y_std**2) / N + prior_std**2) * prior_mean
    )
    post_var = 1 / (1 / prior_std**2 + N / y_std**2)

    return post_mean, np.sqrt(post_var)


def analytic_posterior_draws(
    y: list | np.ndarray,
    prior_mean: list | np.ndarray,
    prior_std: list | np.ndarray,
    y_std=None,
    n_samples: int = 200,
):
    """Draw samples from the closed-form posterior given by ``analytic_posterior``.

    Parameters
    ----------
    y : list or np.ndarray
        Observations; the last axis indexes the ``D`` dimensions.
    prior_mean, prior_std : list or np.ndarray
        Prior mean and standard deviation, one entry per dimension.
    y_std : array-like, optional
        Known observation standard deviation per dimension. Defaults to 1.0
        for each of the ``D`` dimensions.
    n_samples : int, default 200
        Number of draws.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_samples, D)`` with draws from a multivariate
        normal whose covariance is diagonal (independent dimensions).
    """
    if y_std is None:
        # One noise scale per dimension (last axis), not per observation.
        y_std = np.ones(np.asarray(y).shape[-1])

    post_mean, post_std = analytic_posterior(y, prior_mean, prior_std, y_std)
    post_var = post_std**2
    draws = multivariate_normal.rvs(post_mean, np.diag(post_var), size=n_samples)

    # scipy squeezes singleton axes (D == 1 or n_samples == 1); restore a
    # consistent 2-D layout so callers can always rely on (n_samples, D).
    return np.reshape(draws, (n_samples, np.size(post_mean)))


def stan_posterior_draws(
    y: list | np.ndarray,
    prior_mean: list | np.ndarray,
    prior_std: list | np.ndarray,
    y_std=None,
    n_samples: int = 200,
):
    """Sample ``theta`` with CmdStan using the Stan file next to this module.

    Parameters
    ----------
    y : list or np.ndarray
        A single observation; ``len(y)`` is passed to Stan as ``D`` and the
        number of observations ``N`` is fixed to 1.
    prior_mean, prior_std : list or np.ndarray
        Prior mean and standard deviation, forwarded to Stan as arrays.
    y_std : array-like, optional
        Observation standard deviation; defaults to 1.0 for each entry of ``y``.
    n_samples : int, default 200
        Target number of draws; each chain runs ``n_samples // 4`` sampling
        iterations (presumably relying on CmdStanPy's default of four
        chains -- TODO confirm).

    Returns
    -------
    np.ndarray
        The ``theta`` draws reported by ``fit.draws_pd("theta")`` as an array.
    """
    noise_std = np.full(len(y), 1.0) if y_std is None else y_std

    stan_file = str(Path(__file__).parent / "normal_dim_2.stan")
    stan_data = {
        "N": 1,
        "D": len(y),
        "y": np.array(y),
        "prior_mean": np.array(prior_mean),
        "prior_std": np.array(prior_std),
        "y_std": np.array(noise_std),
    }

    # Silence cmdstanpy while compiling/sampling, then restore whatever level
    # its logger had before, even if compilation or sampling fails.
    cmdstan_logger = logging.getLogger("cmdstanpy")
    previous_level = cmdstan_logger.level
    try:
        cmdstan_logger.setLevel(logging.CRITICAL)
        fit = CmdStanModel(stan_file=stan_file).sample(
            stan_data, show_progress=False, iter_sampling=n_samples // 4
        )
    finally:
        cmdstan_logger.setLevel(previous_level)

    return np.array(fit.draws_pd("theta"))
