import torch

from ..ng_utils import compute_infraction, compute_infraction_differentiable
from .vis import visualize_hist, visualize_scatter


def infraction(samples, max_batch_size=8196, device="cuda"):
    """Summarise how often, and by how much, ``samples`` commit an infraction.

    ``samples`` is split along dim 0 into chunks of at most ``max_batch_size``
    rows. Each chunk is scored with ``compute_infraction`` (used here as a
    per-sample infraction flag) and with
    ``compute_infraction_differentiable(..., norm_p=1)`` (a per-sample
    infraction measure). The per-chunk results are concatenated.

    Args:
        samples: Tensor or array-like of samples, batched along dim 0.
            Non-tensor inputs are converted with ``torch.as_tensor`` and moved
            to ``device``. Tensor inputs are used on whatever device they
            already live on.
        max_batch_size: Maximum number of rows handed to the helpers at once.
        device: Target device used only when converting non-tensor inputs.

    Returns:
        dict with keys:
            ``"infraction"``: mean of the per-sample flags (as floats), i.e.
            the fraction of flagged samples.
            ``"infraction_dist_max"`` / ``"infraction_dist_mean"``: max / mean
            of the differentiable measure over the flagged samples only, or
            ``0`` when no sample is flagged (including when there are no
            samples at all).
    """
    if not isinstance(samples, torch.Tensor):
        samples = torch.as_tensor(samples).to(device)
    # torch.split returns a tuple, so the same chunks can be reused for both passes.
    batches = torch.split(samples, max_batch_size, dim=0)

    # detach() before numpy(): numpy() raises on tensors that require grad,
    # which the "differentiable" helper may return for grad-tracking inputs.
    flags = torch.cat([compute_infraction(b) for b in batches]).detach().cpu().numpy()
    dists = (
        torch.cat([compute_infraction_differentiable(b, norm_p=1) for b in batches])
        .detach()
        .cpu()
        .numpy()
    )

    # NOTE(review): assumes compute_infraction yields one bool / 0-1 value per
    # sample. Casting to bool keeps the filter below a boolean mask even if it
    # returns ints or floats (integer arrays would index by position instead).
    mask = flags.astype(bool)

    res = {"infraction": flags.astype("float").mean()}
    if not mask.any():
        # Nothing flagged (or empty input): there is nothing to aggregate, and
        # .max() on an empty selection would raise.
        res["infraction_dist_max"] = 0
        res["infraction_dist_mean"] = 0
    else:
        flagged = dists[mask]  # keep only the infracting samples
        res["infraction_dist_max"] = flagged.max()
        res["infraction_dist_mean"] = flagged.mean()
    return res


def accepted_cross_entropy(samples, max_batch_size=8196, device="cuda"):
    r"""Compute the cross entropy of the accepted samples with respect to the true
    data generating distribution, i.e., - E_{x~\bar{q}} \log p(x), where p is the
    true data distribution and \bar{q} is the distribution of accepted samples.

    The expectation is estimated by the Monte-Carlo average of -log p(x) over
    ``samples``. p is taken to be a standard normal independently in every
    coordinate of the last dimension (per-coordinate log-densities are summed
    over dim -1).

    Args:
        samples: Tensor or array-like of accepted samples; the last dimension
            holds the coordinates of one sample. Non-tensor inputs are
            converted with ``torch.as_tensor`` and moved to ``device``.
        max_batch_size: Maximum number of rows (dim 0) evaluated at once.
        device: Target device used only when converting non-tensor inputs.

    Returns:
        ``{"cross_entropy": float}``, the estimate of -E_{x~\bar{q}} \log p(x).
    """
    if not isinstance(samples, torch.Tensor):
        samples = torch.as_tensor(samples).to(device)
    # Built once rather than once per chunk.
    std_normal = torch.distributions.Normal(0.0, 1.0)
    log_p = torch.cat(
        [
            std_normal.log_prob(batch).sum(dim=-1)
            for batch in torch.split(samples, max_batch_size)
        ]
    )
    # Cross entropy is the *negative* expected log-likelihood (see formula above).
    return {"cross_entropy": -log_p.mean().item()}