"""Code builds on https://github.com/lollcat/fab-jax"""
from typing import Optional, Tuple
import itertools

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import chex
import pandas as pd


def plot_contours_2D(log_prob_func,
                     ax: Optional[plt.Axes] = None,
                     bound: float = 3,
                     levels: int = 20,
                     clip_threshold = -1000,
                     dim0 = 0,
                     dim1 = 1):
    """Plot the contours of a 2D log prob function.

    The function is evaluated on a regular ``n_points x n_points`` grid
    spanning ``[-bound, bound]`` in both coordinates, and the (clipped)
    values are drawn with ``ax.contour``.

    Args:
        log_prob_func: Callable applied to an array of shape
            ``(n_points**2, 2)`` holding the grid points; its result is
            reshaped to ``(n_points, n_points)``.
        ax: Axes to draw on. A new figure/axes pair is created when ``None``.
        bound: Half-width of the square region that is evaluated.
        levels: Forwarded to ``ax.contour`` as ``levels``.
        clip_threshold: Lower bound applied to the log probs (by clipping)
            before they are plotted.
        dim0: Currently unused; kept for interface compatibility.
        dim1: Currently unused; kept for interface compatibility.
    """
    if ax is None:
        _, ax = plt.subplots(1)
    n_points = 100
    grid_1d = np.linspace(-bound, bound, n_points)
    # "ij" indexing makes the first coordinate vary slowest, so the flattened
    # points are ordered like itertools.product(grid_1d, grid_1d).
    x1, x2 = np.meshgrid(grid_1d, grid_1d, indexing="ij")
    x_points = np.stack([x1.ravel(), x2.ravel()], axis=-1)
    log_probs = log_prob_func(x_points)
    # Bounds are passed positionally: the `a_min`/`a_max` keyword names are
    # deprecated in newer JAX releases, while positional bounds work in both.
    log_probs = jnp.clip(log_probs, clip_threshold, None)
    z = log_probs.reshape(n_points, n_points)
    ax.contour(x1, x2, z, levels=levels)



def plot_marginal_pair(samples: chex.Array,
                  ax: Optional[plt.Axes] = None,
                  marginal_dims: Tuple[int, int] = (0, 1),
                  bounds: Tuple[float, float] = (-5, 5),
                  alpha: float = 0.5,
                  temperatures = None,
                  fig = None):
    """Plot samples from marginal of distribution for a given pair of dimensions.

    Args:
        samples: Array indexed as ``samples[:, dim]``; every entry is clipped
            to ``[bounds[0], bounds[1]]`` before plotting.
        ax: Axes to draw on. A new figure/axes pair is created when ``None``
            (in which case any ``fig`` passed in is replaced by the new one).
        marginal_dims: The two column indices plotted against each other
            (first on the x-axis, second on the y-axis).
        bounds: Lower and upper clipping values applied to ``samples``.
        alpha: Marker opacity.
        temperatures: Optional per-sample values. When given, the samples are
            drawn with ``ax.scatter`` coloured by these values ('viridis'
            colormap) and a colorbar is added; otherwise plain "o" markers
            are drawn with ``ax.plot``.
        fig: Figure that receives the colorbar. When omitted while ``ax`` is
            supplied, the figure owning ``ax`` is used.
    """
    if ax is None:
        fig, ax = plt.subplots(1)
    elif fig is None:
        # A caller that only hands over `ax` would otherwise hit
        # `None.colorbar(...)` in the temperatures branch below.
        fig = ax.figure
    samples = jnp.clip(samples, bounds[0], bounds[1])
    xs = samples[:, marginal_dims[0]]
    ys = samples[:, marginal_dims[1]]

    if temperatures is None:
        ax.plot(xs, ys, "o", alpha=alpha)
    else:
        scatter = ax.scatter(xs, ys, c=temperatures, cmap='viridis', alpha=alpha)
        fig.colorbar(scatter, ax=ax)
def plot_history(history):
    """Agnostic history plotter for quickly plotting a dictionary of logging info.

    One subplot is drawn per key, stacked vertically, each titled with its
    key. Infinite values are treated as missing; if a series contains any
    missing values they are dropped and a notice is printed.

    Args:
        history: Mapping from a name to a sequence of logged values (anything
            ``pd.Series`` accepts). An empty mapping is a no-op.
    """
    n_plots = len(history)
    if n_plots == 0:
        # Return before touching matplotlib: a subplot grid with zero rows
        # is rejected by plt.subplots, so this check has to come first.
        return
    # squeeze=False always yields a 2D array of axes, so the single-key case
    # needs no special handling.
    _, axs = plt.subplots(n_plots, 1, figsize=(7, 3 * n_plots), squeeze=False)
    for ax, (key, values) in zip(axs[:, 0], history.items()):
        data = pd.Series(values).replace([np.inf, -np.inf], np.nan)
        if data.isna().any():
            data = data.dropna()
            print(f"NaN encountered in {key} history")
        ax.plot(data)
        ax.set_title(key)
    plt.tight_layout()



# code from setup_training
