import os
import torch
import wandb

from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def imshow_density(
    log_prob,
    bins: int,
    scale: float,
    ax=None,
    device: torch.device = torch.device('cpu'),
    **kwargs
):
    """Evaluate ``log_prob`` on a square grid and draw it with ``ax.imshow``.

    Args:
        log_prob: Callable that receives a ``(bins * bins, 2)`` tensor of
            (x, y) points and returns ``bins * bins`` values to be shown.
        bins: Number of grid points along each axis.
        scale: The grid covers ``[-scale, scale]`` on both axes.
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
        device: Device on which the grid points are created.
        **kwargs: Forwarded unchanged to ``ax.imshow``.

    Returns:
        Whatever ``ax.imshow`` returns (an ``AxesImage`` for real axes), so
        callers can e.g. attach a colorbar.
    """
    if ax is None:
        ax = plt.gca()
    x = torch.linspace(-scale, scale, bins).to(device)
    y = torch.linspace(-scale, scale, bins).to(device)
    # 'ij' matches torch's historical default layout; passing it explicitly
    # silences the meshgrid deprecation warning without changing results.
    X, Y = torch.meshgrid(x, y, indexing='ij')
    xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    # With 'ij' indexing, row i of the reshaped grid holds x[i]; transposing
    # puts y on the row axis, which is what imshow (rows = vertical) expects.
    # detach() so a graph-attached log_prob output can be converted to NumPy
    # by matplotlib (a tensor that requires grad cannot).
    density = log_prob(xy).reshape(bins, bins).T.detach()
    return ax.imshow(
        density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)

def contour_density(
    log_prob,
    bins: int,
    scale: float,
    ax=None,
    device: torch.device = torch.device('cpu'),
    **kwargs
):
    """Evaluate ``log_prob`` on a square grid and draw contour lines of it.

    Args:
        log_prob: Callable that receives a ``(bins * bins, 2)`` tensor of
            (x, y) points and returns ``bins * bins`` values to contour.
        bins: Number of grid points along each axis.
        scale: The grid covers ``[-scale, scale]`` on both axes.
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
        device: Device on which the grid points are created.
        **kwargs: Forwarded unchanged to ``ax.contour`` (e.g. ``levels``).

    Returns:
        Whatever ``ax.contour`` returns (a ``QuadContourSet`` for real axes).
    """
    if ax is None:
        ax = plt.gca()
    x = torch.linspace(-scale, scale, bins).to(device)
    y = torch.linspace(-scale, scale, bins).to(device)
    # Explicit 'ij' == torch's historical default; avoids the meshgrid warning.
    X, Y = torch.meshgrid(x, y, indexing='ij')
    xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    # Transpose so rows follow y (vertical axis); detach so matplotlib can
    # convert a graph-attached result to NumPy.
    density = log_prob(xy).reshape(bins, bins).T.detach()
    return ax.contour(
        density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)


def scatter_samples(
    x,
    scale: float, 
    ax=None, 
    **kwargs
):
    """Scatter-plot 2-D samples after clamping them to ``±1.5 * scale``.

    Args:
        x: Tensor of samples; column 0 is plotted as x and column 1 as y
            (assumes shape ``(N, 2)`` or wider — only the first two columns
            are used).
        scale: Clamp bound is ``1.5 * scale`` in each coordinate.
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``, matching
            the other helpers in this module.
        **kwargs: Forwarded unchanged to ``ax.scatter``.

    Returns:
        Whatever ``ax.scatter`` returns (a ``PathCollection`` for real axes).
    """
    if ax is None:
        ax = plt.gca()
    # NOTE(review): clamping presumably keeps far outliers from blowing up
    # autoscaled limits; with callers' fixed limits they land off-screen.
    # detach() so samples that carry an autograd graph can be handed to
    # matplotlib (which converts them to NumPy).
    x = torch.clamp(x.detach(), -scale * 1.5, scale * 1.5)
    return ax.scatter(x[:, 0].cpu(), x[:, 1].cpu(), **kwargs)

def contourf_density(
    log_prob,
    bins: int,
    scale: float,
    ax=None,
    device: torch.device = torch.device('cpu'),
    **kwargs
):
    """Evaluate ``log_prob`` on a square grid and draw filled contours of it.

    Args:
        log_prob: Callable that receives a ``(bins * bins, 2)`` tensor of
            (x, y) points and returns ``bins * bins`` values to contour.
        bins: Number of grid points along each axis.
        scale: The grid covers ``[-scale, scale]`` on both axes.
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
        device: Device on which the grid points are created.
        **kwargs: Forwarded unchanged to ``ax.contourf`` (e.g. ``levels``).

    Returns:
        Whatever ``ax.contourf`` returns (a ``QuadContourSet`` for real axes).
    """
    if ax is None:
        ax = plt.gca()
    x = torch.linspace(-scale, scale, bins).to(device)
    y = torch.linspace(-scale, scale, bins).to(device)
    # Explicit 'ij' == torch's historical default; avoids the meshgrid warning.
    X, Y = torch.meshgrid(x, y, indexing='ij')
    xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    # Transpose so rows follow y (vertical axis); detach so matplotlib can
    # convert a graph-attached result to NumPy.
    density = log_prob(xy).reshape(bins, bins).T.detach()
    return ax.contourf(
        density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)

def plot_density_and_samples(
    target_log_prob, 
    samples: torch.Tensor,
    bins: int = 200,
    scale: float = 10.0,
    contour_levels: int = 20,
    ax = None,
    device: torch.device = torch.device('cpu')
):
    """Overlay samples on a shaded, contoured view of a 2-D target.

    The target is drawn twice on a window 1.2x wider than ``scale``: first as
    a translucent ``Blues`` image, then as grey contour lines. The samples
    are then drawn as black points. The aspect ratio is fixed to 1, the
    limits are set to the padded window, and the ticks are removed.

    Args:
        target_log_prob: Callable evaluated on a grid of (x, y) points.
        samples: Samples forwarded to ``scatter_samples``.
        bins: Grid resolution per axis for the image and the contours.
        scale: Nominal half-width of the sample region.
        contour_levels: Forwarded as ``levels`` to the contour plot.
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
        device: Device on which the evaluation grid is created.
    """
    ax = plt.gca() if ax is None else ax

    # Pad the density window so mass just outside the sample range shows.
    half_width = 1.2 * scale
    layer_opts = {"ax": ax, "device": device, "zorder": 1}

    # NOTE(review): the colour floor is tied to -scale (a spatial half-width,
    # not a log-density level) — looks like a heuristic; confirm intent.
    imshow_density(
        target_log_prob, bins, half_width,
        vmin=-scale, cmap="Blues", alpha=0.5, **layer_opts,
    )
    contour_density(
        target_log_prob, bins, half_width,
        colors='grey', linestyles='solid', levels=contour_levels,
        alpha=0.25, **layer_opts,
    )
    scatter_samples(samples, scale, ax=ax, color='black', alpha=0.5)

    ax.set_aspect('equal', adjustable='box')
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    ax.set_xticks([])
    ax.set_yticks([])

def plot_contour_and_samples(
    target_log_prob, 
    samples: torch.Tensor,
    bins: int = 100,
    scale: float = 10.0,
    contour_levels: int = 10,
    ax = None,
    device: torch.device = torch.device('cpu')
):
    """Overlay samples on coloured contour lines of a 2-D target.

    The contours use a navy-to-aquamarine colour ramp over ``[-scale, scale]``
    on both axes. The samples are drawn as black points. The aspect ratio is
    fixed to 1, the limits are set to that window, and the ticks are removed.

    Args:
        target_log_prob: Callable evaluated on a grid of (x, y) points.
        samples: Samples forwarded to ``scatter_samples``.
        bins: Grid resolution per axis for the contours.
        scale: Half-width of the plotted window.
        contour_levels: Forwarded as ``levels`` to the contour plot.
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
        device: Device on which the evaluation grid is created.
    """
    ax = plt.gca() if ax is None else ax

    ramp = LinearSegmentedColormap.from_list("", ["navy", "aquamarine"])
    contour_density(
        target_log_prob, bins, scale,
        ax=ax, cmap=ramp, levels=contour_levels, device=device, zorder=1,
    )
    scatter_samples(samples, scale, ax=ax, color='black', alpha=0.5)

    ax.set_aspect('equal', adjustable='box')
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(-scale, scale)
    ax.set_xticks([])
    ax.set_yticks([])