# -*- coding: utf-8 -*-

# %%

import sys
from pathlib import Path

# Project root: the grandparent of this file, i.e. the parent directory of the
# folder (expected to be "scripts") that contains it. It is put at the front of
# sys.path so project packages such as `source` import when run directly.
ROOT = Path(__file__).resolve().parents[1]
_ROOT_STR = str(ROOT)
if _ROOT_STR not in sys.path:
    sys.path.insert(0, _ROOT_STR)

import numpy as np
from scipy.stats import norm

# %%


def A(m, s):
    """Return E|X| for X ~ N(m, s**2), i.e. the mean of a folded normal.

    Closed form: 2*s*phi(m/s) + m*(2*Phi(m/s) - 1), where phi and Phi are the
    standard normal pdf and cdf. Not called by live code in this file.
    """
    z = m / s
    density_term = 2 * s * norm.pdf(z)
    location_term = m * (2 * norm.cdf(z) - 1)
    return density_term + location_term


# %%


def get_mu_sigma_ens(mus, sigmas):
    """Moment-match an equally weighted Gaussian mixture with one Gaussian.

    The mixture is (1/M) * sum_m N(mus[m], sigmas[m]**2).

    Parameters
    ----------
    mus, sigmas : array-like of shape (M,)
        Member means and standard deviations.

    Returns
    -------
    (mu_star, sigma_star)
        The mean and standard deviation of the mixture.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_star = mus.mean()
    # Law of total variance: mean of member variances + variance of member means.
    # This is algebraically equal to mean(sigma**2 + mu**2) - mu_star**2, but it
    # avoids catastrophic cancellation when |mu| >> sigma. In that regime the
    # subtraction form collapses to 0 or goes slightly negative, and then
    # np.sqrt returns NaN. This form is a sum of non-negative terms.
    var_star = (sigmas ** 2).mean() + ((mus - mu_star) ** 2).mean()
    sigma_star = np.sqrt(var_star)

    return mu_star, sigma_star


# %%

from scipy.stats import norm

def get_N_ij(mu_i, sigma_i, mu_j, sigma_j):
    """Overlap integral of two Gaussian densities.

    Computes  int N(x; mu_i, sigma_i**2) * N(x; mu_j, sigma_j**2) dx
    = N(mu_i; mu_j, sigma_i**2 + sigma_j**2). The result is symmetric in
    (i, j). All arguments broadcast like numpy arrays.
    """
    combined_var = sigma_i ** 2 + sigma_j ** 2
    combined_scale = np.sqrt(combined_var)
    return norm.pdf(mu_i, mu_j, combined_scale)


#%%


def get_quadratic_score(mu_i, sigma_i, mu_j, sigma_j):
    """Expected quadratic score of N(mu_i, sigma_i**2) under N(mu_j, sigma_j**2).

    Returns -2 * int p_i p_j + int p_i**2. The first term is get_N_ij and the
    second is 1 / (2 * sqrt(pi) * sigma_i). This is the negatively oriented
    quadratic score, so lower is better.
    """
    overlap = get_N_ij(mu_i, sigma_i, mu_j, sigma_j)
    squared_norm_i = 1 / (2 * np.sqrt(np.pi) * sigma_i)
    return -2 * overlap + squared_norm_i


def get_entropy(mu, sigma):
    """Quadratic-score generalized entropy of N(mu, sigma**2).

    Returns -1 / (2 * sqrt(pi) * sigma), i.e. minus the squared L2 norm of the
    density. `mu` is accepted for signature symmetry but does not affect the
    result.
    """
    denominator = 2 * np.sqrt(np.pi) * sigma
    return -1 / denominator


# %%


def mean_Nij(mus, sigmas):
    """Average pairwise Gaussian overlap over all ordered member pairs.

    Returns (1 / M**2) * sum_{i, j} get_N_ij(mus[i], sigmas[i], mus[j], sigmas[j]),
    including the diagonal terms i == j.

    Parameters
    ----------
    mus, sigmas : array-like of shape (M,)
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    # get_N_ij broadcasts, so (M, 1) against (1, M) yields the full M x M
    # overlap matrix in one call instead of M*M scalar scipy calls.
    overlaps = get_N_ij(mus[:, None], sigmas[:, None], mus[None, :], sigmas[None, :])
    return overlaps.mean()


# %%


def r_excess_3a_2(mus, sigmas):
    """Excess-risk variant "3a_2".

    Returns -2 * mean_j N(mu_j, sigma_j; mu_star, sigma_star)
            - r_bayes_2(mus, sigmas) - r_bayes_3a(mus, sigmas),
    where (mu_star, sigma_star) come from get_mu_sigma_ens, the moment-matched
    Gaussian of the ensemble.

    Parameters
    ----------
    mus, sigmas : array-like of shape (M,)
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_star, sigma_star = get_mu_sigma_ens(mus, sigmas)

    # One broadcast call evaluates the overlap of every member with p*.
    overlaps = get_N_ij(mus, sigmas, mu_star, sigma_star)

    return -2 * overlaps.mean() - r_bayes_2(mus, sigmas) - r_bayes_3a(mus, sigmas)


def r_excess_3b_2(mus, sigmas):
    """Excess-risk variant "3b_2".

    Returns -2 * mean_j N(mu_j, sigma_j; mu_b, sigma_b)
            - r_bayes_2(mus, sigmas) - r_bayes_3b(mus, sigmas),
    where mu_b = mean(mus) and sigma_b = sqrt(mean(sigmas**2)). Unlike
    get_mu_sigma_ens, sigma_b omits the spread of the member means.

    Parameters
    ----------
    mus, sigmas : array-like of shape (M,)
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_b = mus.mean()
    sigma_b = np.sqrt((sigmas ** 2).mean())

    # One broadcast call evaluates the overlap of every member with N(mu_b, sigma_b**2).
    overlaps = get_N_ij(mus, sigmas, mu_b, sigma_b)

    return -2 * overlaps.mean() - r_bayes_2(mus, sigmas) - r_bayes_3b(mus, sigmas)


# %%


def r_bayes_1(mus, sigmas):
    """Bayes-risk variant 1: the average member entropy under the quadratic score.

    Returns -(1 / (2 * sqrt(pi))) * mean(1 / sigmas), i.e. the mean over members
    of -int N(x; mu_m, sigma_m**2)**2 dx. `mus` does not affect the result.
    `sigmas` must support elementwise division (e.g. a numpy array).
    """
    coefficient = -1 / (2 * np.sqrt(np.pi))
    inverse_sigmas = 1 / sigmas
    return coefficient * inverse_sigmas.mean()


def r_bayes_2(mus, sigmas):
    """Bayes-risk variant 2: minus the mean pairwise overlap of the members.

    Returns -mean_{i,j} int p_i p_j = -int pbar**2, where pbar is the equally
    weighted mixture of the member Gaussians (see mean_Nij).
    """
    overlap = mean_Nij(mus, sigmas)
    return -overlap


def r_bayes_2_alt(mus, sigmas):
    """Alternative evaluation of r_bayes_2 via member-vs-member quadratic scores.

    With s1 = r_bayes_1 and s2 = sum over i != j of get_quadratic_score(i, j),
    returns s1 * (1/(2M) + 1/2) + s2 / (2 M**2).

    Substituting get_quadratic_score = -2 N_ij + 1/(2 sqrt(pi) sigma_i) gives
    s1/M - (1/M**2) * sum_{i != j} N_ij. The diagonal overlap
    N_ii = 1/(2 sqrt(pi) sigma_i), so this equals -mean_Nij, i.e. r_bayes_2, up
    to floating-point rounding. It is kept as a cross-check.

    Parameters
    ----------
    mus, sigmas : array-like of shape (M,)
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    M = len(mus)

    s1 = r_bayes_1(mus, sigmas)
    # Entry [i, j] scores member i (predictive) against member j (outcome
    # distribution). get_quadratic_score broadcasts (M, 1) against (1, M).
    scores = get_quadratic_score(mus[:, None], sigmas[:, None], mus[None, :], sigmas[None, :])
    # Only off-diagonal pairs (i != j) contribute to s2.
    s2 = scores[~np.eye(M, dtype=bool)].sum()

    return s1 * (1 / (2 * M) + 0.5) + s2 / (2 * M * M)


def r_bayes_3a(mus, sigmas):
    """Bayes-risk variant 3a: entropy of the moment-matched Gaussian.

    Returns -1 / (2 * sqrt(pi) * sigma_star), with sigma_star from
    get_mu_sigma_ens. The matched mean is not needed.
    """
    _, sigma_star = get_mu_sigma_ens(mus, sigmas)
    return -1 / (2 * np.sqrt(np.pi) * sigma_star)


def r_bayes_3b(mus, sigmas):
    """Bayes-risk variant 3b: entropy of a Gaussian with the average member variance.

    Returns -1 / (2 * sqrt(pi * mean(sigmas**2))). `mus` does not affect the
    result.
    """
    mean_variance = np.mean(np.square(sigmas))
    return -1 / (2 * np.sqrt(np.pi * mean_variance))


#%%


def r_excess_1_1(mus, sigmas):
    """Excess-risk variant "1_1".

    Returns -2 * mean_Nij - 2 * r_bayes_1. This is the mean over all ordered
    member pairs (i, j) of get_quadratic_score(i, j), minus r_bayes_1.
    """
    overlap = mean_Nij(mus, sigmas)
    expected_entropy = r_bayes_1(mus, sigmas)
    return -2 * overlap - 2 * expected_entropy


def r_excess_2_1(mus, sigmas):
    """Excess-risk variant "2_1".

    Returns -2 * mean_Nij - r_bayes_1 - r_bayes_2_alt. Because r_bayes_2_alt
    equals -mean_Nij algebraically, this equals 0.5 * r_excess_1_1 up to
    floating-point rounding.
    """
    overlap = mean_Nij(mus, sigmas)
    bayes_1 = r_bayes_1(mus, sigmas)
    bayes_2 = r_bayes_2_alt(mus, sigmas)
    return -2 * overlap - bayes_1 - bayes_2


def r_excess_3a_1(mus, sigmas):
    """Excess-risk variant "3a_1".

    Returns -2 * mean_j N(mu_star, sigma_star; mu_j, sigma_j)
            - r_bayes_1(mus, sigmas) - r_bayes_3a(mus, sigmas),
    where (mu_star, sigma_star) come from get_mu_sigma_ens.

    Parameters
    ----------
    mus, sigmas : array-like of shape (M,)
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_star, sigma_star = get_mu_sigma_ens(mus, sigmas)

    # Scalar (mu_star, sigma_star) broadcasts against the member arrays.
    overlaps = get_N_ij(mu_star, sigma_star, mus, sigmas)

    return -2 * overlaps.mean() - r_bayes_1(mus, sigmas) - r_bayes_3a(mus, sigmas)


def r_excess_3b_1(mus, sigmas):
    """Excess-risk variant "3b_1".

    Returns -2 * mean_j N(mu_b, sigma_b; mu_j, sigma_j)
            - r_bayes_1(mus, sigmas) - r_bayes_3b(mus, sigmas),
    where mu_b = mean(mus) and sigma_b = sqrt(mean(sigmas**2)).

    Parameters
    ----------
    mus, sigmas : array-like of shape (M,)
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_b = mus.mean()
    sigma_b = np.sqrt((sigmas ** 2).mean())

    # Scalar (mu_b, sigma_b) broadcasts against the member arrays.
    overlaps = get_N_ij(mu_b, sigma_b, mus, sigmas)

    return -2 * overlaps.mean() - r_bayes_1(mus, sigmas) - r_bayes_3b(mus, sigmas)


# %%

# Random test ensemble of M Gaussian members N(mus[m], sigmas[m]**2).
# A fixed seed makes runs reproducible, so any disagreement with the library
# implementation (compared in the next cell) can be re-run and debugged.
M = 5
SEED = 0
rng = np.random.default_rng(SEED)
mus = rng.standard_normal(M) * 2 + 1
# The +0.1 offset keeps every sigma strictly positive (the measures divide by sigma).
sigmas = np.abs(rng.standard_normal(M)) * 0.5 + 0.1

mu_star, sigma_star = get_mu_sigma_ens(mus, sigmas)

print(f"{mu_star=}, {sigma_star=}")

# Reference values from the re-implementation above. The keys must match those
# of the dict returned by calculate_uncertainties_quadratic, because the
# comparison loop looks values up by those keys.
reimp = {
    "bayes_1": r_bayes_1(mus, sigmas),
    "bayes_2": r_bayes_2(mus, sigmas),
    "bayes_3a": r_bayes_3a(mus, sigmas),
    "bayes_3b": r_bayes_3b(mus, sigmas),
    "excess_1_1": r_excess_1_1(mus, sigmas),
    "excess_2_1": r_excess_2_1(mus, sigmas),
    "excess_3a_1": r_excess_3a_1(mus, sigmas),
    "excess_3b_1": r_excess_3b_1(mus, sigmas),
    "excess_3a_2": r_excess_3a_2(mus, sigmas),
    "excess_3b_2": r_excess_3b_2(mus, sigmas),
}

# total = excess + the Bayes term named by the excess variant's first index
# (e.g. excess_3a_2 pairs with bayes_3a). Reuse the values computed above
# instead of re-evaluating every measure.
reimp["total_1_1"] = reimp["excess_1_1"] + reimp["bayes_1"]
reimp["total_2_1"] = reimp["excess_2_1"] + reimp["bayes_2"]
reimp["total_3a_1"] = reimp["excess_3a_1"] + reimp["bayes_3a"]
reimp["total_3b_1"] = reimp["excess_3b_1"] + reimp["bayes_3b"]
reimp["total_3a_2"] = reimp["excess_3a_2"] + reimp["bayes_3a"]
reimp["total_3b_2"] = reimp["excess_3b_2"] + reimp["bayes_3b"]

# %%

import torch
from source.utils.uncertainty_measures import calculate_uncertainties_quadratic, _calc_surrogate_variance

# %%

print("Current code:                Re-implemented:")

# reshape(1, -1) gives shape (1, M). This is presumably a batch of one ensemble
# with M members, as expected by the library functions; confirm against
# calculate_uncertainties_quadratic.
means = torch.tensor(mus).reshape(1, -1)
variances = torch.tensor(sigmas ** 2).reshape(1, -1)

var_star = _calc_surrogate_variance(means, variances)
print(f"mu_star = {means.mean(dim=-1)}, sigma_star = {torch.sqrt(var_star)}")

unc = calculate_uncertainties_quadratic(means, variances)

for k, v in unc.items():
    # v[0] is the entry for the single batch element; it must hold one value.
    current = v[0].item()
    line = f"R_{k:12} = {current:+.4f}  vs  "
    if k in reimp:
        expected = float(reimp[k])
        line += f"{expected:+.4f}"
        # Flag disagreements explicitly instead of relying on eyeballing 4 decimals.
        if not np.isclose(current, expected, rtol=1e-6, atol=1e-9):
            line += "  <-- MISMATCH"
    else:
        # No re-implementation exists for this key. Show a placeholder rather
        # than a fake 0.0 that looks like a real reference value.
        line += "    n/a"
    if k == "bayes_2":
        line += f" , alt: {r_bayes_2_alt(mus, sigmas):+.4f}"
    print(line)

#print(unc)
#print(f"{r_bayes_2_alt(mus, sigmas):.4f}")

# %%
