import json

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F
import random


def convert_dict_to_json_types(d):
    """Return a shallow copy of ``d`` whose values are plain JSON-serialisable scalars.

    Conversion rules, applied per value:

    - Python ``bool`` / ``np.bool_``      -> ``bool``
    - Python ``float`` / any ``np.floating`` (float16/32/64) -> ``float``
    - Python ``int`` / any ``np.integer`` -> ``int``
    - ``str`` and ``None``               -> kept unchanged
    - ``torch.Tensor``                   -> ``v.item()`` (the tensor must hold
      exactly one element, otherwise ``.item()`` raises)

    Entries whose value has any other type (lists, arrays, objects, ...) are
    omitted from the result, so ``json.dump`` on the output cannot fail on them.
    """
    d_copy = {}
    for k, v in d.items():
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(v, (bool, np.bool_)):
            d_copy[k] = bool(v)
        elif isinstance(v, (float, np.floating)):
            d_copy[k] = float(v)
        elif isinstance(v, (int, np.integer)):
            # NumPy integers are not JSON-serialisable; cast to Python int.
            d_copy[k] = int(v)
        elif v is None or isinstance(v, str):
            d_copy[k] = v
        elif isinstance(v, torch.Tensor):
            d_copy[k] = v.item()
    return d_copy


def dump_json(file: dict, path: str) -> None:
    """Serialise the mapping ``file`` as JSON (4-space indent) into ``path``.

    The target file is opened in text write mode, so any existing content is
    replaced.
    """
    with open(path, mode="w") as out_stream:
        json.dump(obj=file, fp=out_stream, indent=4)


def set_seed(seed):
    """Seed every RNG used by the project and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's global
    generator and the current CUDA device's generator with ``seed``, then
    turns on cuDNN deterministic mode and turns off its benchmark autotuning.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Display helpers. visualise_output takes as input the images to reconstruct
# and the model with which the reconstructions are performed; to_img clamps
# pixel values to [0, 1] before plotting.
def to_img(x):
    """Return ``x`` with every element clamped into the displayable range [0, 1]."""
    return torch.clamp(x, min=0, max=1)


def show_image(img):
    """Draw a single channel-first image tensor with ``plt.imshow``.

    ``img`` is expected to be rank 3 in (C, H, W) order (it is transposed with
    axes (1, 2, 0) to (H, W, C) for matplotlib). Values are clamped to [0, 1]
    via ``to_img``.

    The tensor is detached and moved to the CPU first, so tensors on a GPU or
    tensors that require grad are accepted (``.numpy()`` raises on both).
    A single-channel image is drawn as a 2-D grayscale array, because
    ``imshow`` accepts (H, W), (H, W, 3) and (H, W, 4) arrays but not
    (H, W, 1).
    """
    img = to_img(img).detach().cpu()
    npimg = np.transpose(img.numpy(), (1, 2, 0))
    if npimg.shape[-1] == 1:
        plt.imshow(npimg[..., 0], cmap="gray")
    else:
        plt.imshow(npimg)


def visualise_output(o_images, model, device, n_images=50):
    """Reconstruct a batch with ``model``, plot the reconstructions and return their PSNR.

    Args:
        o_images: batch of original images, batch-first (each sample is
            flattened from dim 1 for the error). Values are presumably in
            [0, 1], since the PSNR uses a peak value of 1.0 — confirm against
            the data pipeline.
        model: object exposing ``reconstruction(images)``. It is assumed to
            return a tensor of the same shape as ``o_images``.
        device: device the originals are moved to before reconstruction.
        n_images: number of reconstructions (the first ``n_images`` of the
            batch) laid out in the grid, 10 per row.

    Returns:
        NumPy array with one PSNR value (dB) per image:
        ``20 * log10(1 / sqrt(MSE))``. A perfect reconstruction gives ``inf``.
    """
    max_v = 1.0
    with torch.no_grad():
        o_images = o_images.to(device)
        images = model.reconstruction(o_images)
        mse = torch.mean(
            (o_images.flatten(start_dim=1) - images.flatten(start_dim=1)) ** 2, dim=-1
        )
        psnr = 20 * torch.log10(max_v / torch.sqrt(mse))
        psnr = psnr.cpu().numpy()
        images = to_img(images.cpu())
        # First n_images reconstructions (the old [1:50] slice skipped image 0),
        # 10 per row with 5 px padding.
        grid = torchvision.utils.make_grid(images[:n_images], nrow=10, padding=5)
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        plt.show()
    return psnr


def sample_from_discretized_mix_logistic(l, grayscale=False):
    """
    Draw one sample per pixel from a mixture of discretized logistics.

    Args:
        l: distribution parameters in PyTorch (B, P, H, W) ordering, with
            P = 3 * nr_mix when ``grayscale`` (logit, mean, log-scale per
            component) or P = 10 * nr_mix otherwise (logit plus 3 means,
            3 log-scales and 3 colour coefficients per component).
        grayscale: use the single-channel parameterization.

    Returns:
        Tensor of shape (B, 1, H, W) if ``grayscale`` else (B, 3, H, W),
        with values clipped to [-1, 1]. Samples are not rounded to the
        nearest 8-bit value.

    The uniform noise is created on ``l``'s device with ``l``'s dtype. It is
    a plain tensor, not an autograd leaf, so the result only requires grad
    when ``l`` does.

    Code adapted from the pytorch adaptation of the original PixelCNN++ tf
    implementation https://github.com/pclucas14/pixel-cnn-pp
    """
    # Pytorch ordering (B, P, H, W) -> channels last (B, H, W, P)
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    n_channels = 1 if grayscale else 3
    xs = ls[:-1] + [n_channels]

    # Parameters per mixture component: grayscale = logit + mean + log-scale
    # (3); RGB = logit + 3 means + 3 log-scales + 3 colour coefficients (10).
    nr_mix = ls[-1] // (3 if grayscale else 10)
    # Per (channel, component): mean, log-scale [, colour coefficient].
    params_per_channel = 2 if grayscale else 3

    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * params_per_channel])

    # Sample the mixture indicator with the Gumbel-max trick:
    # argmax(logits - log(-log(u))), u ~ U(1e-5, 1 - 1e-5).
    gumbel_u = torch.empty(logit_probs.shape, dtype=l.dtype, device=l.device)
    gumbel_u.uniform_(1e-5, 1.0 - 1e-5)
    _, argmax = (logit_probs.detach() - torch.log(-torch.log(gumbel_u))).max(dim=3)
    sel = F.one_hot(argmax, nr_mix).to(l.dtype).view(xs[:-1] + [1, nr_mix])

    # select logistic parameters of the chosen component
    means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(
        torch.sum(l[:, :, :, :, nr_mix : 2 * nr_mix] * sel, dim=4), min=-7.0
    )

    # Inverse-CDF sample from the logistic: mean + scale * logit(u).
    u = torch.empty(means.shape, dtype=means.dtype, device=means.device)
    u.uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))

    x0 = x[:, :, :, 0].clamp(-1.0, 1.0)
    if grayscale:
        out = x0.unsqueeze(3)
    else:
        # Later colour channels are shifted linearly by the already-sampled
        # earlier channels (coefficients squashed to (-1, 1) by tanh).
        coeffs = torch.sum(
            torch.tanh(l[:, :, :, :, 2 * nr_mix : 3 * nr_mix]) * sel, dim=4
        )
        x1 = (x[:, :, :, 1] + coeffs[:, :, :, 0] * x0).clamp(-1.0, 1.0)
        x2 = (
            x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1
        ).clamp(-1.0, 1.0)
        out = torch.stack([x0, x1, x2], dim=3)

    # put back in Pytorch ordering (B, C, H, W)
    return out.permute(0, 3, 1, 2)


def discretized_mix_logistic_loss(x: torch.Tensor, l: torch.Tensor, grayscale: bool = False) -> torch.Tensor:
    """
    Negative log-likelihood of ``x`` under a mixture of discretized logistics
    parameterized by ``l``.

    Assumes the data has been rescaled to the [-1, 1] interval from 256
    discrete levels. The bin half-width 1/255 and the constant log(127.5)
    = -log(2/255) used below both encode that 8-bit binning.

    Args:
        x: target images in PyTorch (B, C, H, W) ordering. The reshapes below
            require C == 1 when ``grayscale`` is True and C == 3 otherwise.
        l: predicted parameters in (B, P, H, W) ordering, with
            P = 3 * nr_mix for grayscale (logit, mean, log-scale per
            component) or P = 10 * nr_mix for RGB (logit + 3 means,
            3 log-scales and 3 colour coefficients per component).
        grayscale: select the single-channel parameterization.

    Returns:
        Tensor of shape (B,): the negative log-likelihood (natural log)
        summed over all pixels of each batch element.

    Code taken from pytorch adaptation of original PixelCNN++ tf implementation
    https://github.com/pclucas14/pixel-cnn-pp
    """

    # channels last
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)

    # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
    xs = [int(y) for y in x.size()]
    # predicted distribution, e.g. (B,32,32,100)
    ls = [int(y) for y in l.size()]

    # here and below: unpacking the params of the mixture of logistics
    # (3 params per component for grayscale, 10 for RGB)
    if grayscale:
        nr_mix = int(ls[-1] / 3)
    else:
        nr_mix = int(ls[-1] / 10)
    logit_probs = l[:, :, :, :nr_mix]
    if grayscale:
        l = (
            l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2])
        )  # 2 for mean, scale
    else:
        l = (
            l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
        )  # 3 for mean, scale, coef
    # l is now (B, H, W, C, params_per_channel * nr_mix)
    means = l[:, :, :, :, :nr_mix]
    # log_scales = torch.max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.)
    # lower bound on the log-scale keeps inv_stdv = exp(-log_scales) finite
    log_scales = torch.clamp(l[:, :, :, :, nr_mix : 2 * nr_mix], min=-7.0)

    if not grayscale:
        coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix : 3 * nr_mix])
        # here and below: getting the means and adjusting them based on preceding
        # sub-pixels
        x = x.contiguous()
        # Adding a zero tensor broadcasts x to (B, H, W, C, nr_mix), one copy
        # per mixture component; requires_grad=False so it adds no gradient.
        x = x.unsqueeze(-1) + nn.Parameter(
            torch.zeros(xs + [nr_mix]).to(x.device), requires_grad=False
        )
        # channel 1 mean is shifted by coeff * observed channel 0
        m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :]).view(
            xs[0], xs[1], xs[2], 1, nr_mix
        )

        # channel 2 mean is shifted by observed channels 0 and 1
        m3 = (
            means[:, :, :, 2, :]
            + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :]
            + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]
        ).view(xs[0], xs[1], xs[2], 1, nr_mix)

        means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
    else:
        x = x.contiguous()
        # same broadcast of x across the nr_mix components as above
        x = x.unsqueeze(-1) + nn.Parameter(
            torch.zeros(xs + [nr_mix]).to(x.device), requires_grad=False
        )
    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    # logistic CDF evaluated at the upper / lower edge of the bin around x
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = torch.sigmoid(min_in)
    # log probability for edge case of 0 (before scaling)
    # (log sigmoid(plus_in), written stably via softplus)
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log probability for edge case of 255 (before scaling)
    # (log(1 - sigmoid(min_in)), written stably via softplus)
    log_one_minus_cdf_min = -F.softplus(min_in)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    # log density of the logistic at the bin centre; used below as the
    # fallback when the bin mass cdf_delta is <= 1e-5
    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

    # Select the per-sub-pixel log-probability with float masks:
    #   x < -0.999         -> log CDF at the upper edge of the lowest bin
    #   x >  0.999         -> log(1 - CDF) at the lower edge of the highest bin
    #   cdf_delta > 1e-5   -> log of the bin's probability mass
    #   otherwise          -> log_pdf_mid + log(bin width 2/255), i.e. the
    #                         density assumed constant across the bin
    # The mask-multiply form comes from the TF original, which used it
    # instead of tf.select() to avoid NaN gradients. The clamp to 1e-12 only
    # keeps torch.log finite; any value it changes is below 1e-5, where the
    # mask selects the fallback branch instead.

    inner_inner_cond = (cdf_delta > 1e-5).float()
    inner_inner_out = inner_inner_cond * torch.log(
        torch.clamp(cdf_delta, min=1e-12)
    ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log(127.5))
    inner_cond = (x > 0.999).float()
    inner_out = (
        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
    )
    cond = (x < -0.999).float()
    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
    # sum over channels, add log mixture weights, then marginalise the
    # components with logsumexp -> (B, H, W)
    log_probs = torch.sum(log_probs, dim=3) + torch.log_softmax(logit_probs, dim=-1)
    log_probs = torch.logsumexp(log_probs, dim=-1)

    # return -torch.sum(log_probs)
    loss_sep = -log_probs.sum((1, 2))  # keep batch dimension
    return loss_sep


def gaus_kl(q_mu, q_logsigmasq, p_mu, p_logsigmasq, dim=1):
    """
    Compute KL-divergence KL(q || p) between n pairs of Gaussians
    with diagonal covariance matrices.
    The result is not divided by the dimensionality of the latent space.

    Input: q_mu, p_mu, Tensor of shape n x d - mean vectors for n Gaussians.
    Input: q_logsigmasq, p_logsigmasq, Tensor of shape n x d - log-variance
           (log sigma^2) vectors for n Gaussians.
    Input: dim - axis the per-dimension terms are summed over (default 1,
           the d axis). Pass None to get the unsummed per-dimension terms.
    Return: Tensor of shape n (for the default dim=1) - each component is the
            KL-divergence between a corresponding pair of Gaussians.

    Per dimension this is
    0.5 * (log sp^2 - log sq^2 - 1 + sq^2 / sp^2 + (mq - mp)^2 / sp^2).
    """
    res = p_logsigmasq - q_logsigmasq - 1 + torch.exp(q_logsigmasq - p_logsigmasq)
    res = res + (q_mu - p_mu).pow(2) * torch.exp(-p_logsigmasq)
    if dim is None:
        return 0.5 * res
    return 0.5 * res.sum(dim=dim)


def gaus_skl(q_mu, q_logsigmasq, p_mu, p_logsigmasq, dim=1):
    """
    Compute the symmetric KL-divergence 0.5*KL(q || p) + 0.5*KL(p || q)
    between n pairs of Gaussians with diagonal covariance matrices.

    Inputs are mean vectors (q_mu, p_mu) and log-variance vectors
    (q_logsigmasq, p_logsigmasq), e.g. each of shape n x d. The
    per-dimension terms are summed over ``dim`` (default 1). Pass
    ``dim=None`` to get them unsummed.

    Per dimension this equals
    0.25 * (sp^2/sq^2 + sq^2/sp^2 + (mq - mp)^2 * (1/sp^2 + 1/sq^2)) - 0.5.
    """
    log_var_gap = p_logsigmasq - q_logsigmasq
    sq_mean_gap = (q_mu - p_mu).pow(2)
    inv_var_sum = torch.exp(-p_logsigmasq) + torch.exp(-q_logsigmasq)
    per_dim = (
        0.25
        * (torch.exp(log_var_gap) + torch.exp(-log_var_gap) + sq_mean_gap * inv_var_sum)
        - 0.5
    )
    return per_dim if dim is None else per_dim.sum(dim=dim)
