import hypothesis as h
import matplotlib.pyplot as plt
import numpy as np
import torch

from hypothesis.plot import make_square
from hypothesis.stat import highest_density_level
from matplotlib.colors import LogNorm
from ratio_estimation import compute_log_posterior
from ratio_estimation import extent



@torch.no_grad()
def plot_contours(ax, r, observable, cls=(0.95,), labels=(r"95\%",), resolution=250, color='C0'):
    """Draw highest-density credible-region contours of the posterior on ``ax``.

    The posterior density is obtained by exponentiating
    ``compute_log_posterior(r, observable, resolution=resolution)`` and is
    drawn over a ``resolution x resolution`` grid spanning ``extent``. For
    every credibility level in ``cls`` a single contour is drawn at the
    density threshold returned by ``highest_density_level``.

    Args:
        ax: Matplotlib axes to draw on.
        r: Model forwarded to ``compute_log_posterior`` (presumably the
            ratio estimator -- confirm against ``ratio_estimation``).
        observable: Observation to condition on; squeezed before use.
        cls: Credibility levels, e.g. ``0.95`` for a 95% region.
        labels: Inline contour labels, paired with ``cls`` via ``zip``
            (surplus entries of either sequence are ignored). ``None``
            disables labelling.
        resolution: Number of grid points per parameter axis.
        color: Any Matplotlib color spec. A single color given as a NumPy
            RGB/RGBA array is reshaped into a one-row color list.
    """
    if labels is None:
        labels = [None] * len(cls)
    observable = observable.squeeze()

    # The uniform prior's support is half-open [low, high): pull the upper
    # bound in slightly so the grid never lands outside the support.
    epsilon = 0.00001
    p1 = torch.linspace(extent[0], extent[1] - epsilon, resolution)
    p2 = torch.linspace(extent[2], extent[3] - epsilon, resolution)
    # NOTE(review): relies on torch.meshgrid's default 'ij' indexing (newer
    # torch versions warn unless indexing= is passed); confirm the pdf from
    # compute_log_posterior uses the same orientation.
    g1, g2 = torch.meshgrid(p1, p2)
    g1 = g1.cpu().numpy()
    g2 = g2.cpu().numpy()
    # Move to host before converting, matching the grid above, so this also
    # works when the posterior is computed on an accelerator.
    pdf = compute_log_posterior(r, observable, resolution=resolution).exp().cpu().numpy()

    # Loop-invariant: a bare 1-D color array would be read by ax.contour as
    # several scalar colors, so wrap it as a single row.
    if isinstance(color, np.ndarray):
        color = np.asarray(color).reshape(1, -1)

    # Shared level -> label mapping handed to clabel.
    fmt = {}
    for cl, label in zip(cls, labels):
        level = highest_density_level(pdf, alpha=1.0 - cl)
        c = ax.contour(g1, g2, pdf, [level], colors=color)
        if label is not None:
            fmt[c.levels[0]] = label
            ax.clabel(c, c.levels, inline=True, fontsize=20, fmt=fmt)
    make_square(ax)
