import json

import numpy as np
from scipy.stats import entropy
from tqdm import tqdm


def _summarize_divergence(values, percentiles_to_compute, n_simulations):
    """Summarize the per-simulation values of one divergence.

    Non-finite entries (e.g. ``inf`` for KL(P||Q) when Q has an empty bin
    where P has mass) are kept in ``"values"`` but excluded from the
    percentile / mean / std statistics. If no finite value exists, those
    statistics are NaN.

    Args:
        values: List of per-simulation divergence values.
        percentiles_to_compute: Percentiles (0-100) to report.
        n_simulations: Total number of simulations, used in the info message.

    Returns:
        Dict with keys ``"values"``, ``"percentiles"`` (percentile -> value),
        ``"std"``, ``"mean"``, and ``"info"`` only when non-finite values were
        excluded.
    """
    finite_values = [x for x in values if np.isfinite(x)]
    if finite_values:
        perc = np.percentile(finite_values, percentiles_to_compute)
        std = np.std(finite_values)
        mean = np.mean(finite_values)
    else:
        perc = [np.nan] * len(percentiles_to_compute)
        std = np.nan
        mean = np.nan

    summary = {
        "values": values,  # keep all values (including inf)
        "percentiles": dict(zip(percentiles_to_compute, perc)),
        "std": std,
        "mean": mean,
    }
    num_non_finite = len(values) - len(finite_values)
    if num_non_finite:
        summary["info"] = (
            f"{num_non_finite}/{n_simulations} occurrences of inf encountered "
            "and excluded from percentile calculations."
        )
    return summary


def simulate_divergences_with_percentiles(true_probs, n_trials, n_simulations):
    """Simulate divergences between a distribution and its empirical estimate.

    For each of ``n_simulations`` runs, draw ``n_trials`` samples from the
    categorical distribution P = ``true_probs`` with a multinomial draw.
    Form the empirical distribution Q = counts / n_trials and record:

    * ``kl_pq``: KL(P || Q). This is ``inf`` whenever Q is zero on a category
      where P is positive.
    * ``kl_qp``: KL(Q || P).
    * ``js_pq``: the Jensen-Shannon divergence 0.5*KL(P||M) + 0.5*KL(Q||M)
      with M = (P + Q) / 2.

    Divergences use ``scipy.stats.entropy`` with its default base (natural
    log). Sampling uses the global ``np.random`` state, so seed it with
    ``np.random.seed`` for reproducible results.

    Args:
        true_probs: Sequence of category probabilities. Its sum may deviate
            from 1 only within a small tolerance. Accepted inputs are
            renormalized so that sampling and the divergences use the same P.
        n_trials: Number of multinomial draws per simulation (must be >= 1).
        n_simulations: Number of independent simulations.

    Returns:
        Dict keyed by ``"kl_pq"``, ``"kl_qp"``, and ``"js_pq"``. Each value is
        a dict with ``"values"`` (all per-simulation values, including inf),
        ``"percentiles"`` (10, 20, ..., 90, 95 -> value), ``"std"``, and
        ``"mean"``, all computed over finite values only. The optional
        ``"info"`` entry is present when non-finite values were excluded.

    Raises:
        ValueError: If ``true_probs`` contains negative values or values
            greater than 1, if its sum is not close to 1, or if
            ``n_trials < 1``.
    """
    true_probs_p = np.asarray(true_probs, dtype=float)
    if np.any(true_probs_p < 0):
        raise ValueError("The true probability distribution contains negative values.")
    if np.any(true_probs_p > 1):
        raise ValueError("The true probability distribution contains values greater than 1.")

    total = np.sum(true_probs_p)
    # Reject clearly invalid sums first (this also catches NaN, since
    # isclose(nan, 1) is False), so the "slightly" warning below is only
    # printed for inputs that are actually accepted.
    if not np.isclose(total, 1.0, atol=1e-5):  # slightly looser tolerance
        raise ValueError(
            f"The sum of the true probability distribution is not 1: {total}"
        )
    if not np.isclose(total, 1.0):
        print(f"Warning: Sum of probabilities deviates slightly from 1: {total}")
    # np.random.multinomial assigns any leftover mass to the last category,
    # while scipy.stats.entropy normalizes its inputs. Renormalizing here makes
    # both see the same P. The division is a no-op when the sum is exactly 1.
    true_probs_p = true_probs_p / total

    if n_trials < 1:
        # With n_trials == 0, Q = 0 / 0 would be NaN for every simulation.
        raise ValueError(f"n_trials must be a positive integer, got {n_trials}")

    kl_pq_all_values = []
    kl_qp_all_values = []
    js_pq_all_values = []

    for _ in range(n_simulations):
        observed_counts = np.random.multinomial(n_trials, true_probs_p)
        empirical_probs_q = observed_counts / n_trials

        kl_pq_all_values.append(entropy(pk=true_probs_p, qk=empirical_probs_q))
        kl_qp_all_values.append(entropy(pk=empirical_probs_q, qk=true_probs_p))

        m_probs = 0.5 * (true_probs_p + empirical_probs_q)
        js_part1 = entropy(pk=true_probs_p, qk=m_probs)
        js_part2 = entropy(pk=empirical_probs_q, qk=m_probs)
        js_pq_all_values.append(0.5 * (js_part1 + js_part2))

    percentiles_to_compute = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95]
    return {
        "kl_pq": _summarize_divergence(
            kl_pq_all_values, percentiles_to_compute, n_simulations
        ),
        "kl_qp": _summarize_divergence(
            kl_qp_all_values, percentiles_to_compute, n_simulations
        ),
        "js_pq": _summarize_divergence(
            js_pq_all_values, percentiles_to_compute, n_simulations
        ),
    }


if __name__ == "__main__":
    # --- Common simulation parameters ---
    # The output file name is derived from num_trials_per_experiment, so runs
    # with different trial counts (e.g. 100 vs. 1000) do not overwrite each other.
    num_trials_per_experiment = 100
    num_simulations = 100000

    # --- Uniform distributions over 2, 4, ..., 2**13 categories ---
    distributions_to_test = [
        {
            "name": f"Distribution {i} ({2**i} categories, uniform)",
            # 1 / 2**i is exact in binary floating point, so the sum is exactly 1.
            "probabilities": [1.0 / (2**i) for _ in range(2**i)],
            "num_moves": 2**i,
            "num_trials_per_experiment": num_trials_per_experiment,
            "num_simulations": num_simulations,
        }
        for i in range(1, 14)
    ]

    # --- Run simulations for each distribution ---
    results_ls = {}
    for dist_info in tqdm(distributions_to_test):
        dist_name = dist_info["name"]
        true_distribution = dist_info["probabilities"]

        print(f"===== Simulation start: {dist_name} =====")
        print(f"True probability distribution P: {np.round(true_distribution, 4)}")  # rounded for display
        print(f"Number of trials (n_trials): {num_trials_per_experiment}")
        print(f"Number of simulations (n_simulations): {num_simulations}\n")

        results = simulate_divergences_with_percentiles(
            true_distribution, num_trials_per_experiment, num_simulations
        )
        print(f"--- Simulation results ({dist_name}) ---\n")

        for div_name, data in results.items():
            print(f"  --- {div_name} ---")
            if "info" in data:
                print(f"    Info: {data['info']}")
            print("    Statistics:")
            if np.isnan(data["mean"]):
                print("      Mean: NaN")
            else:
                print(f"      Mean: {data['mean']:.8f}")
            if np.isnan(data["std"]):
                print("      Standard deviation: NaN")
            else:
                print(f"      Standard deviation: {data['std']:.8f}")
            print("    Percentiles:")
            for p, val in data["percentiles"].items():
                if np.isnan(val):
                    print(f"      {p}%: NaN")
                else:
                    print(f"      {p}%: {val:.8f}")  # display with higher precision

            # Convert numpy scalars to plain Python types for JSON serialization.
            # NaN becomes None (JSON null) everywhere. Otherwise json.dumps emits
            # the bare token NaN, which is not valid JSON.
            data["percentiles"] = {
                p: None if np.isnan(val) else float(val)
                for p, val in data["percentiles"].items()
            }
            data["mean"] = None if np.isnan(data["mean"]) else float(data["mean"])
            data["std"] = None if np.isnan(data["std"]) else float(data["std"])
            # Drop the raw per-simulation values (n_simulations entries per
            # divergence) from the saved summary.
            data.pop("values", None)
            print("-" * 30)
        print(f"===== Simulation end: {dist_name} =====\n\n")

        results_ls[dist_info["num_moves"]] = {
            "dist_info": dist_info,
            "results": results,
        }

    print(results_ls)
    output_path = f"prng_n_{num_trials_per_experiment}.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results_ls, f, indent=2)
