# -*- coding: utf-8 -*-

# %%

import sys
from pathlib import Path

# project root = parent of "scripts"
# Two directories up from this file's resolved location; prepended to
# sys.path (once) so that project packages such as `source.*` can be imported.
ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

import numpy as np
from scipy.stats import norm

# %%


def A(m, s):
    """Return E|X| for X ~ N(m, s**2), the mean of a folded normal.

    Closed form: ``2*s*phi(m/s) + m*(2*Phi(m/s) - 1)``, with phi/Phi the
    standard normal pdf/cdf. Evaluated elementwise; ``m`` and ``s`` may be
    scalars or broadcastable arrays.

    Parameters
    ----------
    m : float or array_like
        Mean(s) of the Gaussian.
    s : float or array_like
        Standard deviation(s). Where ``s == 0`` the distribution is a point
        mass at ``m``, and the limit ``|m|`` is returned instead of dividing
        by zero.

    Returns
    -------
    numpy scalar or ndarray
        E|X| with the broadcast shape of ``m`` and ``s``; a scalar when both
        inputs are scalars.
    """
    m = np.asarray(m)
    s = np.asarray(s)
    degenerate = s == 0
    # Substitute 1 for zero scales so the division never produces 0/0 or a
    # ZeroDivisionError; those entries are replaced with |m| below.
    safe_s = np.where(degenerate, 1.0, s)
    z = m / safe_s
    value = 2 * safe_s * norm.pdf(z) + m * (2 * norm.cdf(z) - 1)
    # `[()]` turns a 0-d result back into a scalar and is a no-op view otherwise.
    return np.where(degenerate, np.abs(m), value)[()]


# %%


def get_mu_sigma_ens(mus, sigmas):
    """Return mean and std of the equally weighted Gaussian mixture.

    The ensemble is the mixture (1/M) * sum_j N(mus[j], sigmas[j]**2). Its
    first two moments define the moment-matched Gaussian N(mu_star, sigma_star**2).

    Parameters
    ----------
    mus, sigmas : array_like, length M
        Member means and standard deviations.

    Returns
    -------
    (mu_star, sigma_star)
        ``mu_star = mean(mus)`` and
        ``sigma_star = sqrt(mean(sigmas**2) + mean((mus - mu_star)**2))``.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_star = mus.mean()
    # E[sigma^2] + Var[mu] equals E[sigma^2 + mu^2] - mu_star^2 algebraically,
    # but it is a sum of non-negative terms. The subtraction form can cancel
    # to a tiny negative number and make the sqrt return NaN.
    sigma_star = np.sqrt((sigmas ** 2).mean() + ((mus - mu_star) ** 2).mean())

    return mu_star, sigma_star


# %%


def get_A_ij(mu_i, sigma_i, mu_j, sigma_j):
    """Return A at the moments of the difference of two Gaussians.

    For independent X_i ~ N(mu_i, sigma_i^2) and X_j ~ N(mu_j, sigma_j^2),
    X_i - X_j ~ N(mu_i - mu_j, sigma_i^2 + sigma_j^2). A of those moments is
    therefore E|X_i - X_j|. Works elementwise on broadcastable arrays.
    """
    return A(mu_i - mu_j, np.sqrt(sigma_i**2 + sigma_j**2))


# %%


def mean_Aij(mus, sigmas):
    """Return the mean of E|X_i - X_j| over all M*M ordered member pairs.

    X_i ~ N(mus[i], sigmas[i]**2) are treated as independent. The diagonal
    i == j is included; there A(0, sqrt(2)*sigma_i) = 2*sigma_i/sqrt(pi).

    Parameters
    ----------
    mus, sigmas : array_like, length M
        Member means and standard deviations.

    Returns
    -------
    float
        ``mean_{i,j} A(mus[i] - mus[j], sqrt(sigmas[i]**2 + sigmas[j]**2))``.
    """
    mus = np.asarray(mus, dtype=float)
    sigmas = np.asarray(sigmas, dtype=float)
    # Broadcast a column (index i) against a row (index j) to evaluate the
    # full M x M matrix in one call instead of a Python double loop.
    pairwise = get_A_ij(mus[:, None], sigmas[:, None], mus[None, :], sigmas[None, :])
    return pairwise.mean()


# %%


def r_excess_3a_2(mus, sigmas):
    """Excess term, variant 3a with the pairwise ensemble term.

    The reference is the moment-matched Gaussian Y ~ N(mu_star, sigma_star^2)
    from ``get_mu_sigma_ens``. The function returns

        mean_j E|X_j - Y|  -  sigma_star / sqrt(pi)  -  0.5 * mean_{i,j} E|X_i - X_j|

    with X_j ~ N(mus[j], sigmas[j]^2), all variables independent.
    ``sigma / sqrt(pi)`` equals 0.5 * E|Y - Y'| for an independent copy Y'.

    Parameters
    ----------
    mus, sigmas : array_like, length M
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_star, sigma_star = get_mu_sigma_ens(mus, sigmas)

    h_pstar = sigma_star / np.sqrt(np.pi)
    h_ens = 0.5 * mean_Aij(mus, sigmas)

    # Mean over members j of A(mus[j] - mu_star, sqrt(sigmas[j]^2 + sigma_star^2)),
    # evaluated for all members in one vectorized call.
    cross = get_A_ij(mus, sigmas, mu_star, sigma_star).mean()

    return cross - h_pstar - h_ens


def r_excess_3b_2(mus, sigmas):
    """Excess term, variant 3b with the pairwise ensemble term.

    The reference is Y ~ N(mu_b, sigma_b^2), with mu_b = mean(mus) and
    sigma_b = sqrt(mean(sigmas^2)). Unlike variant 3a, the spread of the
    member means does not enter sigma_b. The function returns

        mean_j E|X_j - Y|  -  sigma_b / sqrt(pi)  -  0.5 * mean_{i,j} E|X_i - X_j|

    with X_j ~ N(mus[j], sigmas[j]^2), all variables independent.

    Parameters
    ----------
    mus, sigmas : array_like, length M
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_b = mus.mean()
    sigma_b = np.sqrt((sigmas ** 2).mean())

    h_pstar = sigma_b / np.sqrt(np.pi)
    h_ens = 0.5 * mean_Aij(mus, sigmas)

    # Mean over members j of A(mus[j] - mu_b, sqrt(sigmas[j]^2 + sigma_b^2)),
    # evaluated for all members in one vectorized call.
    cross = get_A_ij(mus, sigmas, mu_b, sigma_b).mean()

    return cross - h_pstar - h_ens


# %%


def r_bayes_1(mus, sigmas):
    """Return mean(sigmas) / sqrt(pi); ``mus`` is accepted but unused."""
    root_pi = np.sqrt(np.pi)
    avg_sigma = sigmas.mean()
    return avg_sigma / root_pi


def r_bayes_2(mus, sigmas):
    """Return half the mean pairwise E|X_i - X_j| over the ensemble (see ``mean_Aij``)."""
    pairwise_mean = mean_Aij(mus, sigmas)
    return 0.5 * pairwise_mean


def r_bayes_3a(mus, sigmas):
    """Return sigma_star / sqrt(pi) for the moment-matched ensemble Gaussian."""
    _, sigma_surrogate = get_mu_sigma_ens(mus, sigmas)
    return sigma_surrogate / np.sqrt(np.pi)


def r_bayes_3b(mus, sigmas):
    """Return sqrt(mean(sigmas**2)) / sqrt(pi); ``mus`` is accepted but unused."""
    rms_sigma = np.sqrt((sigmas ** 2).mean())
    return rms_sigma / np.sqrt(np.pi)


#%%


def r_excess_1_1(mus, sigmas):
    """Return mean_{i,j} E|X_i - X_j| minus (2/sqrt(pi)) * mean(sigmas)."""
    pairwise_mean = mean_Aij(mus, sigmas)
    coeff = 2 / np.sqrt(np.pi)
    return pairwise_mean - coeff * sigmas.mean()


def r_excess_2_1(mus, sigmas):
    """Return 0.5 * mean_{i,j} E|X_i - X_j| minus mean(sigmas) / sqrt(pi)."""
    half_pairwise = 0.5 * mean_Aij(mus, sigmas)
    coeff = 1 / np.sqrt(np.pi)
    return half_pairwise - coeff * sigmas.mean()


def r_excess_3a_1(mus, sigmas):
    """Excess term, variant 3a with the per-member self term.

    The reference is the moment-matched Gaussian Y ~ N(mu_star, sigma_star^2)
    from ``get_mu_sigma_ens``. The function returns

        mean_j E|Y - X_j|  -  (sigma_star + mean_j sigmas[j]) / sqrt(pi)

    with X_j ~ N(mus[j], sigmas[j]^2), all variables independent.

    Parameters
    ----------
    mus, sigmas : array_like, length M
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_star, sigma_star = get_mu_sigma_ens(mus, sigmas)

    # Mean over members j of A(mu_star - mus[j], sqrt(sigma_star^2 + sigmas[j]^2)),
    # evaluated for all members in one vectorized call.
    cross = get_A_ij(mu_star, sigma_star, mus, sigmas).mean()

    return cross - (sigma_star + sigmas.mean()) / np.sqrt(np.pi)


def r_excess_3b_1(mus, sigmas):
    """Excess term, variant 3b with the per-member self term.

    The reference is Y ~ N(mu_b, sigma_b^2), with mu_b = mean(mus) and
    sigma_b = sqrt(mean(sigmas^2)). The function returns

        mean_j E|Y - X_j|  -  (sigma_b + mean_j sigmas[j]) / sqrt(pi)

    with X_j ~ N(mus[j], sigmas[j]^2), all variables independent.

    Parameters
    ----------
    mus, sigmas : array_like, length M
        Member means and standard deviations.
    """
    mus = np.asarray(mus)
    sigmas = np.asarray(sigmas)
    mu_b = mus.mean()
    sigma_b = np.sqrt((sigmas ** 2).mean())

    # Mean over members j of A(mu_b - mus[j], sqrt(sigma_b^2 + sigmas[j]^2)),
    # evaluated for all members in one vectorized call.
    cross = get_A_ij(mu_b, sigma_b, mus, sigmas).mean()

    return cross - (sigma_b + sigmas.mean()) / np.sqrt(np.pi)


# %%

# Toy ensemble of M Gaussian members N(mus[j], sigmas[j]**2).
# A fixed seed makes any mismatch reported by the comparison reproducible.
M = 5
SEED = 0
rng = np.random.default_rng(SEED)
mus = rng.standard_normal(M) * 2 + 1
sigmas = np.abs(rng.standard_normal(M)) * 0.5 + 0.1  # always >= 0.1

mu_star, sigma_star = get_mu_sigma_ens(mus, sigmas)

print(f"{mu_star=}, {sigma_star=}")

# Re-implemented measures, keyed the way the comparison cell looks them up.
reimp = {}

reimp["bayes_1"] = r_bayes_1(mus, sigmas)
reimp["bayes_2"] = r_bayes_2(mus, sigmas)
reimp["bayes_3a"] = r_bayes_3a(mus, sigmas)
reimp["bayes_3b"] = r_bayes_3b(mus, sigmas)

reimp["excess_1_1"] = r_excess_1_1(mus, sigmas)
reimp["excess_2_1"] = r_excess_2_1(mus, sigmas)
reimp["excess_3a_1"] = r_excess_3a_1(mus, sigmas)
reimp["excess_3b_1"] = r_excess_3b_1(mus, sigmas)

reimp["excess_3a_2"] = r_excess_3a_2(mus, sigmas)
reimp["excess_3b_2"] = r_excess_3b_2(mus, sigmas)

# total = excess + bayes. Reuse the values computed above instead of
# re-evaluating every measure a second time.
for variant, bayes_key in (
    ("1_1", "bayes_1"),
    ("2_1", "bayes_2"),
    ("3a_1", "bayes_3a"),
    ("3b_1", "bayes_3b"),
    ("3a_2", "bayes_3a"),
    ("3b_2", "bayes_3b"),
):
    reimp[f"total_{variant}"] = reimp[f"excess_{variant}"] + reimp[bayes_key]

# %%

import torch
from source.utils.uncertainty_measures import calculate_uncertainties_crps, _calc_surrogate_variance

# %%

print("Current code:                Re-implemented:")

# Batch of one ensemble: shape (1, M), dtype taken from the numpy arrays.
means = torch.tensor(mus).reshape(1, -1)
variances = torch.tensor(sigmas ** 2).reshape(1, -1)

var_star = _calc_surrogate_variance(means, variances)
print(f"mu_star = {means.mean(dim=-1)}, sigma_star = {torch.sqrt(var_star)}")

unc = calculate_uncertainties_crps(means, variances)

# NOTE(review): assumes each value in `unc` is indexed by batch on its first
# dimension, so v[0] is this single ensemble's scalar; confirm.
for k, v in unc.items():
    # float() also works on grad-tracking tensors, where .numpy() would raise.
    current = float(v[0])
    if k not in reimp:
        # No re-implementation exists: say so instead of printing a fake 0.0000.
        print(f"R_{k:12} = {current:.4f}  vs  {'n/a':>6}")
        continue
    expected = float(reimp[k])
    flag = "" if np.isclose(current, expected) else "  <-- MISMATCH"
    print(f"R_{k:12} = {current:.4f}  vs  {expected:.4f}{flag}")

# Re-implemented measures that the current code did not return at all.
for k in [key for key in reimp if key not in unc]:
    print(f"R_{k:12} missing from current code (re-implemented: {float(reimp[k]):.4f})")

#print(unc)

# %%
