import torch
import numpy as np
import matplotlib.pyplot as plt
from metabeta.models.approximators import Approximator
# -----------------------------------------------------------------------------
# compare posterior intervals with mcmc
def plotIntervals(ax,
                  quantiles1: torch.Tensor,
                  quantiles2: torch.Tensor,
                  target: torch.Tensor,
                  name: str, n: int = 12, show_y: bool=False):
    """Draw paired 50% / 95% interval bars for two sets of quantiles on ``ax``.

    Rows are ordered by ascending ``target`` and ``n`` of them are picked at
    evenly spaced ranks between the 5% and 95% positions of that ordering.
    For every picked row two bar groups are drawn next to each other:
    ``quantiles1`` (labelled 'MB', green) and ``quantiles2`` (labelled 'HMC',
    goldenrod).

    Args:
        ax: object with the matplotlib Axes drawing interface
            (``bar``, ``set_title``, ``set_ylim``, ``set_ylabel``, ``legend``,
            ``set_xticks``).
        quantiles1: per-row interval bounds; columns 0/1 are drawn as the
            50% bar and columns 2/3 as the 95% bar (lower, upper).
        quantiles2: same layout as ``quantiles1``.
        target: 1-d values used only to order and subsample the rows.
        name: panel title.
        n: number of rows to display.
        show_y: if True, add the y-axis label and the legend.
    """
    # sort on the original device so the row order is unchanged, then compose
    # the sort permutation with the evenly spaced rank subset into one gather
    _, order = torch.sort(target)
    ranks = torch.linspace(0.05, 0.95, n)
    picks = torch.round(ranks * (len(target) - 1)).long()
    rows = order.cpu()[picks]

    # matplotlib converts inputs via numpy, which fails for CUDA tensors and
    # tensors that require grad -> hand it plain CPU numpy arrays instead
    q1 = quantiles1.detach().cpu()[rows].numpy()
    q2 = quantiles2.detach().cpu()[rows].numpy()

    # horizontal layout: the two groups sit left/right of each integer slot
    x = np.arange(n)
    bar_gap = 0.05
    offset = 0.20 + bar_gap / 2
    groups = (
        (x - offset, q1, 'darkgreen', 'MB'),
        (x + offset, q2, 'darkgoldenrod', 'HMC'),
    )

    # wide-and-faint 95% bar first, then the narrower-interval 50% bar on top
    for xpos, q, color, tag in groups:
        ax.bar(xpos, bottom=q[:, 2], height=q[:, 3] - q[:, 2],
               width=0.35, color=color, alpha=0.3, label=f'{tag} (95%)')
        ax.bar(xpos, bottom=q[:, 0], height=q[:, 1] - q[:, 0],
               width=0.40, color=color, alpha=0.8, label=f'{tag} (50%)')

    ax.set_title(name, pad=10, size=24)
    # only the lower limit is set: one unit below the smallest plotted bound
    ax.set_ylim(float(min(q1.min(), q2.min())) - 1)
    if show_y:
        ax.set_ylabel('credible intervals', size=24, labelpad=10)
        ax.legend()
    else:
        ax.set_ylabel('')
    ax.set_xticks([])
    
    
    
    
def plotAllIntervals(model: Approximator,
                     proposed: torch.Tensor,
                     mcmc: torch.Tensor,
                     targets: torch.Tensor,
                     names: list):
    """Create one interval-comparison panel per entry of ``names``.

    For both sample sets the 50% and 95% interval bounds are obtained from
    ``model.quantiles`` and concatenated along the last axis, then each
    column ``i`` is handed to ``plotIntervals`` together with
    ``targets[:, i]`` and ``names[i]``. Panels whose target column is
    entirely zero are hidden.

    Args:
        model: approximator providing ``quantiles(samples, probs, calibrate=...)``.
        proposed: samples summarised with ``calibrate=True`` (plotted as 'MB').
        mcmc: samples summarised with ``calibrate=False`` (plotted as 'HMC').
        targets: values indexed as ``targets[:, i]`` for panel ``i``.
        names: panel titles; its length sets the number of panels.
    """
    # NOTE(review): assumes model.quantiles returns the requested quantiles
    # along the last axis, so the concatenated layout is
    # [25%, 75%, 2.5%, 97.5%] -- confirm against Approximator.quantiles.

    # MB
    q50 = model.quantiles(proposed, [.25, .75], calibrate=True)
    q95 = model.quantiles(proposed, [.025, .975], calibrate=True)
    qmb = torch.cat([q50, q95], dim=-1)

    # HMC
    q50 = model.quantiles(mcmc, [.25, .75], calibrate=False)
    q95 = model.quantiles(mcmc, [.025, .975], calibrate=False)
    qmc = torch.cat([q50, q95], dim=-1)

    # figure
    n = len(names)
    # squeeze=False: always get a 2-d array of Axes; with the default a single
    # panel (n == 1) comes back as a bare Axes that has no .flatten()
    fig, axs = plt.subplots(figsize=(6 * n, 5), ncols=n, dpi=400,
                            squeeze=False)
    axs = axs.flatten()

    # presumably zeros mark unused parameters -- verify against the caller
    mask = (targets != 0.)
    labelled = False
    for i, (ax, name) in enumerate(zip(axs, names)):
        if not mask[:, i].any():
            ax.set_visible(False)
            continue
        # y-label and legend go on the first panel that is actually drawn,
        # so they do not vanish when the leading panel is hidden
        plotIntervals(ax, qmb[:, i], qmc[:, i], targets[:, i], name,
                      show_y=not labelled)
        labelled = True
    fig.tight_layout()
    

