# Simulation-based calibration 
# - SBC histogram [✓] 
# - SBC shade 
# - ECDF difference plot [✓] 

import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import wasserstein_distance, gaussian_kde, binom


def getRanks(targets: torch.Tensor, proposed: dict, absolute=False, mask_0=False) -> torch.Tensor:
    """Return the normalized rank of every target among its posterior draws.

    The rank is the fraction of draws that are <= the target. Entries whose
    target is exactly 0 are treated as absent and receive rank -1.

    Args:
        targets: (b, d) tensor of true parameter values.
        proposed: dict with key ``'samples'`` holding the draws and,
            optionally, ``'weights'``. The draws are compared against
            ``targets.unsqueeze(-1)`` and reduced over the last axis, so the
            draw axis must be last: shape (b, d, s). Weights must broadcast
            against the draws.
        absolute: if True, compare absolute values of targets and draws.
        mask_0: if True, the validity of column 0 is replaced by "any of
            columns 1 .. d-2 has a non-zero target". The last column is
            excluded from that test. NOTE(review): the original comment reads
            "ignore d=1 models"; confirm the last column is meant to be
            excluded.

    Returns:
        (b, d) float tensor of ranks, with -1 at masked entries. With weights,
        the weighted count is divided by the number of draws s. This is a
        proper weighted fraction only if the weights average 1 over the draw
        axis (assumption, confirm with the producer of ``proposed``).
    """
    draws = proposed['samples']
    weights = proposed.get('weights', None)
    truth = targets.unsqueeze(-1)
    if absolute:
        draws, truth = draws.abs(), truth.abs()

    # Zero-valued targets mark parameters that are not present.
    keep = targets != 0
    if mask_0:
        keep[:, 0] = keep[:, 1:-1].any(1)

    at_or_below = (draws <= truth).float()
    ranks = (
        at_or_below.mean(-1)
        if weights is None
        else (at_or_below * weights).sum(-1) / draws.shape[-1]
    )
    ranks[~keep] = -1
    return ranks



def plotSBC(ranks: torch.Tensor,
            mask: torch.Tensor,
            names: list,
            color: str = 'darkgreen') -> None:
    """Draw one SBC rank histogram per parameter on a square grid of panels.

    Each panel shows a density histogram (20 bins on [0, 1]) and KDE of the
    valid ranks of one column. It also shows the 95% band and the mean height
    that uniformly distributed ranks would produce for the same number of
    ranks.

    Args:
        ranks: (b, d) tensor of normalized ranks in [0, 1]. Negative entries
            (e.g. the -1 written by ``getRanks``) are treated as missing.
        mask: optional (b, d) boolean (or 0/1) tensor; only entries where it
            is true are plotted. ``None`` keeps every non-negative rank.
        names: one label per column of ``ranks``; its length sets the number
            of panels.
        color: colour of the histogram bars and KDE line.
    """
    if not names:
        return
    eps = 0.02
    n_bins = 20
    binwidth = 1 / n_bins
    n = len(names)
    w = int(torch.tensor(n).sqrt().ceil())
    # squeeze=False keeps `axs` a 2-D array even when w == 1 (a single name),
    # so .flatten() always works.
    fig, axs = plt.subplots(figsize=(8 * w, 6 * w), ncols=w, nrows=w,
                            squeeze=False)
    axs = axs.flatten()

    for i, name in enumerate(names):
        ax = axs[i]
        ax.set_axisbelow(True)
        ax.grid(True)
        mask_i = ranks[:, i] >= 0
        if mask is not None:
            mask_i = mask_i & mask[:, i].bool()
        n_valid = int(mask_i.sum())
        if n_valid == 0:
            ax.set_visible(False)
            continue
        xx = ranks[mask_i, i].detach().cpu().numpy()
        # Under uniformity each bin count is Binomial(n_valid, 1/n_bins). The
        # histogram uses stat="density" (height = count / (n_valid*binwidth)),
        # so the band is rescaled to those units and the expected height is 1.
        scale = n_valid * binwidth
        low, high = binom.interval(0.95, n_valid, binwidth)
        ax.axhspan(low / scale, high / scale, facecolor="gray", alpha=0.1)
        ax.axhline(1.0, color="gray", zorder=0, alpha=0.9, linestyle='--')
        sns.histplot(xx, kde=True, ax=ax, binwidth=binwidth, binrange=(0, 1),
                     color=color, alpha=0.5, stat="density", label=name)
        ax.set_xlim(0 - eps, 1 + eps)
        ax.set_xlabel('U', fontsize=20)
        ax.set_ylabel('')
        # Y tick labels are drawn in white, i.e. hidden on a white background;
        # only relative heights against the band matter.
        ax.tick_params(axis='y', labelcolor='w')
        ax.legend()

    # Hide the unused cells of the w x w grid.
    for ax in axs[n:]:
        ax.set_visible(False)
    fig.tight_layout()
        
        
def plotSBCsingle(
        ax,
        ranks: torch.Tensor,
        n: int,
        upper: bool = True,
        color: str = 'darkgreen') -> None:
    """Draw an SBC rank histogram with its uniformity reference band on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        ranks: tensor of normalized ranks in [0, 1]. Negative entries (the -1
            written by ``getRanks``) are dropped.
        n: number of ranks the reference band is computed for. This is
            normally the number of non-negative entries of ``ranks``.
        upper: if True, add the title and legend and blank the x label (top
            panel of a figure); otherwise label the x axis.
        color: colour of the histogram bars and KDE line.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    eps = 0.02
    n_bins = 20
    binwidth = 1 / n_bins
    # Under uniformity each bin count is Binomial(n, 1/n_bins). The histogram
    # below uses stat="density" (height = count / (n * binwidth)), so the band
    # is rescaled to the same units and the expected height becomes 1.
    scale = n * binwidth
    low, high = binom.interval(0.95, n, binwidth)
    ax.axhspan(low / scale, high / scale, facecolor="gray", alpha=0.1)
    ax.axhline(1.0, color="gray", zorder=0, alpha=0.9,
               label='theoretical', lw=2, linestyle='--')
    ax.set_axisbelow(True)
    ax.grid(True)
    xx = ranks[ranks >= 0].detach().cpu().numpy()
    sns.histplot(xx, kde=True, ax=ax, binwidth=binwidth, binrange=(0, 1), lw=2,
                 color=color, alpha=0.5, stat="density", label='estimated')
    ax.set_xlim(0 - eps, 1 + eps)
    ax.set_ylabel('')
    # Y tick labels are drawn in white, i.e. hidden on a white background.
    ax.tick_params(axis='y', labelcolor='w')
    if upper:
        ax.set_title('Calibration', fontsize=30, pad=15)
        ax.legend(fontsize=16, loc='upper right')
        ax.set_xlabel('')
    else:
        ax.set_xlabel('U', labelpad=10, size=26)
    
    
    
def getWasserstein(ranks: torch.Tensor, mask: torch.Tensor, n_points=1000):
    """Average 1-Wasserstein distance between each column's ranks and U(0, 1).

    For every column with at least one valid rank, a Gaussian KDE (Scott
    bandwidth) of its ranks is evaluated on a grid over [0, 1]. The KDE is
    integrated to a CDF and compared with the uniform CDF on the same grid.

    Args:
        ranks: (b, d) tensor of normalized ranks. Negative entries (the -1
            written by ``getRanks``) are treated as missing.
        mask: optional (b, d) boolean (or 0/1) tensor selecting which ranks to
            use; ``None`` keeps every non-negative rank.
        n_points: number of grid points on [0, 1]; must be at least 2.

    Returns:
        Mean of the per-column distances as a Python float, or ``nan`` when no
        column has any valid rank.

    Raises:
        ValueError: if ``n_points`` < 2.
    """
    if n_points < 2:
        raise ValueError(f"n_points must be at least 2, got {n_points}")
    d = ranks.shape[1]

    x = np.linspace(0, 1, n_points)
    dx = x[1] - x[0]
    # Uniform CDF on the grid, built by the same left-Riemann cumsum as the
    # KDE CDFs below so both carry the same discretisation offset.
    ucdf = np.cumsum(np.ones_like(x)) * dx

    wds = []
    for i in range(d):
        valid = ranks[:, i] >= 0
        if mask is not None:
            valid = valid & mask[:, i].bool()
        if not valid.any():
            continue
        ranks_i = ranks[valid, i].detach().cpu().numpy()
        density = gaussian_kde(ranks_i, bw_method='scott')(x)
        ecdf = np.cumsum(density) * dx
        # ecdf and ucdf are both non-decreasing and equally long. Treating them
        # as equally weighted samples, wasserstein_distance pairs them in order
        # and returns mean(|ecdf - ucdf|), a Riemann sum of the W1 distance
        # integral |F(u) - u| du over [0, 1].
        wds.append(wasserstein_distance(ecdf, ucdf))

    if not wds:
        return float('nan')
    return float(np.mean(wds))


def boundECDF(n_ranks: int, n_sim: int = 1000, alpha: float = 0.01):
    """Pointwise (1 - alpha) band for the ECDF-minus-identity curve.

    At each grid point p, the ECDF of ``n_sim`` i.i.d. U(0, 1) draws is
    distributed as Binomial(n_sim, p) / n_sim. The returned bounds are its
    alpha/2 and 1 - alpha/2 quantiles, shifted by -p so they can be drawn
    around a Delta-ECDF plot. The band holds per grid point, not
    simultaneously over the whole curve.

    Args:
        n_ranks: number of grid points on [0, 1].
        n_sim: number of draws the binomial is taken over.
        alpha: total tail probability outside the band.

    Returns:
        (grid, lower, upper): three arrays of length ``n_ranks``.
    """
    grid = np.linspace(0, 1, n_ranks)
    tail = alpha / 2
    lower, upper = (
        binom.ppf(q, n_sim, grid) / n_sim - grid
        for q in (tail, 1 - tail)
    )
    return grid, lower, upper


def plotECDF(ranks: torch.Tensor, mask: torch.Tensor, names: list, s: int, color='darkgreen') -> None:
    """Plot, per parameter, the ECDF of its ranks minus the uniform CDF.

    One panel per name is laid out on a square grid. Each panel overlays the
    sample curve on the band from ``boundECDF``.

    Args:
        ranks: (b, d) tensor of normalized ranks in [0, 1]. Negative entries
            (the -1 written by ``getRanks``) are treated as missing.
        mask: optional (b, d) boolean (or 0/1) tensor selecting which ranks to
            plot; ``None`` keeps every non-negative rank.
        names: one label per column; its length sets the number of panels.
        s: passed to ``boundECDF`` as ``n_sim``. NOTE(review): an ECDF of N
            ranks has Binomial(N, u)/N marginals, so this presumably should be
            the number of ranks per parameter. Confirm what callers pass.
        color: colour of the sample curve and the band.
    """
    if not names:
        return
    eps = 0.02
    xx_theo, lower, upper = boundECDF(n_ranks=len(ranks), n_sim=s)

    n = len(names)
    w = int(torch.tensor(n).sqrt().ceil())
    # squeeze=False keeps `axs` a 2-D array even when w == 1 (a single name).
    fig, axs = plt.subplots(figsize=(8 * w, 6 * w), ncols=w, nrows=w,
                            squeeze=False)
    axs = axs.flatten()

    for i, name in enumerate(names):
        ax = axs[i]
        ax.set_axisbelow(True)
        ax.grid(True)
        mask_i = ranks[:, i] >= 0
        if mask is not None:
            mask_i = mask_i & mask[:, i].bool()
        if not mask_i.any():
            ax.set_visible(False)
            continue
        xx = ranks[mask_i, i].detach().cpu().sort().values.numpy()
        # Anchor the curve at (0, 0) and (1, 0). The m sorted ranks get evenly
        # spaced heights k / (m + 1), a piecewise-linear stand-in for the
        # step ECDF.
        xx = np.pad(xx, (1, 1), constant_values=(0, 1))
        yy = np.linspace(0, 1, num=xx.shape[-1]) - xx

        ax.plot(xx, yy, color=color, label='sample')
        ax.fill_between(xx_theo, lower, upper, color=color, alpha=0.1, label='theoretical')
        ax.set_xlim(0 - eps, 1 + eps)
        ax.set_xlabel('U', fontsize=20)
        ax.set_ylabel(r'$\Delta$ ECDF')
        ax.legend()

    # Hide the unused cells of the w x w grid.
    for ax in axs[n:]:
        ax.set_visible(False)
    fig.tight_layout()
    
    
def plotECDFsingle(ax, ranks: torch.Tensor, mask: torch.Tensor, names: list, s: int, color='darkgreen') -> None:
    """Overlay the Delta-ECDF curves of all parameters on a single Axes.

    Args:
        ax: matplotlib Axes to draw on.
        ranks: (b, d) tensor of normalized ranks in [0, 1]. Negative entries
            (the -1 written by ``getRanks``) are treated as missing.
        mask: optional (b, d) boolean (or 0/1) tensor selecting which ranks to
            plot; ``None`` keeps every non-negative rank.
        names: one entry per column to draw; only its length is used.
        s: passed to ``boundECDF`` as ``n_sim``. NOTE(review): presumably this
            should be the number of ranks per parameter. Confirm with callers.
        color: colour shared by all curves and the band.
    """
    eps = 0.02
    xx_theo, lower, upper = boundECDF(n_ranks=len(ranks), n_sim=s)
    ax.fill_between(xx_theo, lower, upper, color=color, alpha=0.1, label='theoretical')
    ax.set_axisbelow(True)
    ax.grid(True)
    ax.set_xlim(0 - eps, 1 + eps)
    ax.set_xlabel('U', fontsize=20)
    ax.set_ylabel(r'$\Delta$ ECDF')

    labelled = False
    for i, _ in enumerate(names):
        valid = ranks[:, i] >= 0
        if mask is not None:
            valid = valid & mask[:, i].bool()
        if not valid.any():
            # Nothing to draw; skipping avoids a spurious flat line at 0.
            continue
        xx = ranks[valid, i].detach().cpu().sort().values.numpy()
        # Anchor at (0, 0) and (1, 0); interior heights are k / (m + 1).
        xx = np.pad(xx, (1, 1), constant_values=(0, 1))
        yy = np.linspace(0, 1, num=xx.shape[-1]) - xx
        # Label exactly one curve so the legend shows a single 'sample' entry,
        # even when names repeat or some columns are empty.
        ax.plot(xx, yy, color=color, label=None if labelled else 'sample')
        labelled = True
    ax.legend()
        
        