import numpy as np
import torch
import math
from colour import Color
from PIL import Image
import io
import spaces
import matplotlib.pyplot as plt
from matplotlib import cm


def fig_to_img():
    """Render the current pyplot figure as PNG and reopen it as an image.

    The PNG bytes live in an in-memory buffer; the returned object comes from
    ``Image.open`` on that buffer, rewound to the start.
    """
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png')
    png_buffer.seek(0)  # Image.open reads from the current position
    return Image.open(png_buffer)


def save_gif(img_list, file):
    """Write a sequence of images (presumably PIL) to ``file`` as an animated GIF.

    The first element is saved as the base frame and the remaining elements
    are appended as further frames, each shown for 100 ms, looping forever.
    """
    first_frame = img_list[0]
    other_frames = img_list[1:]
    first_frame.save(
        file,
        format='GIF',
        save_all=True,
        append_images=other_frames,
        duration=100,  # per-frame display time in milliseconds
        loop=0,  # 0 = repeat indefinitely
    )

    
def save_png(img, file):
    """Draw ``img`` with ``plt.imshow`` on the current figure and save it to ``file``.

    The current figure is closed afterwards so repeated calls do not
    accumulate open pyplot figures.
    """
    plt.imshow(img)
    plt.savefig(file)
    # Must actually *call* close: a bare ``plt.close`` expression is a no-op
    # and leaves the figure alive in pyplot's figure manager.
    plt.close()


def get_color(clr_scale):
    """Map each value of ``clr_scale`` to a colour on a 100-step blue-to-red ramp.

    Values are min-max normalised over the whole input, so the smallest value
    gets blue (bucket 0) and the largest gets red (bucket 99). When the values
    span 0.01 or less the input is treated as constant and every point is blue.

    Returns:
        numpy array with one ``Color.rgb`` triple per input value.
    """
    n_buckets = 100
    palette = [swatch.rgb for swatch in Color('blue').range_to(Color('red'), n_buckets)]

    lowest = min(clr_scale)
    highest = max(clr_scale)
    spread = highest - lowest

    if spread > 0.01:
        normalised = (clr_scale - lowest) / spread
        bucket_pos = (n_buckets - 1) * normalised
    else:
        bucket_pos = np.zeros_like(clr_scale)

    # int() truncates, so each value falls into the bucket at or below it.
    return np.array([palette[int(pos)] for pos in bucket_pos])


def get_color2(clr_scale):
    """Map values on a fixed [-1, 1] scale to a 100-step blue-to-red ramp.

    No per-call normalisation is applied: -1 maps to blue, +1 maps to red, and
    values outside [-1, 1] are clipped to the end colours.

    Returns:
        numpy array with one ``Color.rgb`` triple per input value.
    """
    n_buckets = 100
    palette = [swatch.rgb for swatch in Color('blue').range_to(Color('red'), n_buckets)]
    unit_interval = np.clip(clr_scale * 0.5 + 0.5, 0, 1)  # [-1, 1] -> [0, 1]
    bucket_pos = (n_buckets - 1) * unit_interval
    return np.array([palette[int(pos)] for pos in bucket_pos])


def set_limit(ax, limit):
    """Apply the same ``limit`` to the x, y and z axes of the 3-D axes ``ax``."""
    for setter_name in ('set_xlim', 'set_ylim', 'set_zlim'):
        getattr(ax, setter_name)(limit)


def set_labels(ax, dims):
    """Label the x/y/z axes of ``ax`` with latent-dimension names.

    ``dims`` holds 0-based dimension indices for the x, y and z axes (only the
    first three entries are used); labels are 1-based, e.g. index 0 becomes
    ``$z_{1}$``.
    """
    def _label(dim):
        # Braces keep multi-digit indices fully subscripted: mathtext renders
        # "$z_10$" as z-sub-1 followed by a normal "0".
        return '$z_{' + str(dim + 1) + '}$'

    ax.set_xlabel(_label(dims[0]))
    ax.set_ylabel(_label(dims[1]))
    ax.set_zlabel(_label(dims[2]))


def down_proj(features, dims):
    """Pick the columns of ``features`` listed in ``dims``.

    Any index at or beyond ``features.shape[1]`` yields a zero vector of
    length ``features.shape[0]`` instead, so callers can always unpack
    ``len(dims)`` coordinates.

    Returns:
        list with one 1-D column per entry of ``dims``.
    """
    return [
        features[:, d] if d < features.shape[1] else np.zeros([features.shape[0]])
        for d in dims
    ]


def fibonacci_sphere(samples=1000):
    """Return ``samples`` points on the unit sphere laid out on a Fibonacci spiral.

    Points go from the pole at y = 1 down to y = -1 in equal steps of y. Each
    successive point is rotated by the golden angle about the y axis.

    Args:
        samples: number of points to generate. With ``samples == 1`` the
            single point is placed at the pole (0, 1, 0).

    Returns:
        float32 tensor of shape (samples, 3) with (x, y, z) rows; an empty
        (0, 3) tensor when ``samples`` is 0 or negative.
    """
    phi = math.pi * (3. - math.sqrt(5.))  # golden angle in radians
    # Guard the y-spacing denominator: samples == 1 would otherwise divide by zero.
    denom = float(max(samples - 1, 1))

    points = []
    for i in range(samples):
        y = 1 - (i / denom) * 2  # y goes from 1 to -1
        radius = math.sqrt(1 - y * y)  # radius of the horizontal circle at height y
        theta = phi * i  # golden angle increment
        points.append((math.cos(theta) * radius, y, math.sin(theta) * radius))

    if not points:
        # torch.tensor([]) would have shape (0,); keep the (N, 3) contract.
        return torch.empty((0, 3), dtype=torch.float)
    return torch.tensor(points, dtype=torch.float)


def _meshgrid_points(rng, n):
    """Return the n-fold Cartesian product of 1-D tensor ``rng`` as a (len(rng)**n, n) tensor.

    Rows are ordered with the last coordinate varying fastest ('ij' indexing,
    row-major flattening).
    """
    axes = torch.meshgrid(*([rng] * n), indexing='ij')
    return torch.cat([a.reshape(-1, 1) for a in axes], dim=-1)


def generate_grid(args, num_steps=8):
    """Build a tensor of evaluation points for the latent space described by ``args``.

    Attributes read from ``args``: ``space_type`` always; ``box_min`` and
    ``box_max`` for "box"; ``m_p`` (and ``m_param`` when ``m_p != 0``) for
    "sphere"; ``n`` for every case except the ``m_p == 0`` sphere.

    Space types:
        "box":         regular grid with ``num_steps`` points per axis over
                       [box_min, box_max]^n -> shape (num_steps**n, n).
        "sphere":      512 points. When ``m_p == 0`` they lie on a Fibonacci
                       spiral over the 3-D unit sphere (``args.n`` is not
                       consulted). Otherwise they come from
                       ``spaces.NSphereSpace(n).non_uniform(1 / m_param, 512)``.
        "hollow_ball": for n == 2, a polar grid with radii in [0.5, 2.0] and
                       angles in [0, 2*pi], both endpoints included (so each
                       ring's first and last points coincide)
                       -> (num_steps**2, 2). Otherwise
                       ``spaces.NHollowBallSpace(n).uniform(512)``.
        "cube_grid":   grid over [-1, 1]^n with coordinates >= 0 shifted by
                       +0.25 and negative ones by -0.25, so no coordinate lies
                       in (-0.25, 0.25).
        anything else: regular grid over [-1, 1]^n.

    Args:
        args: namespace-like object carrying the attributes above.
        num_steps: points per axis for the grid-based cases.
    """
    space = args.space_type
    if space == "box":
        rng = torch.linspace(args.box_min, args.box_max, steps=num_steps)
        grid = _meshgrid_points(rng, args.n)
    elif space == "sphere":
        if args.m_p == 0:
            grid = fibonacci_sphere(512)
        else:
            grid = spaces.NSphereSpace(args.n).non_uniform(1.0 / args.m_param, 512)
    elif space == "hollow_ball":
        if args.n == 2:
            r_min = 0.5
            r_max = 2.0
            radius = torch.linspace(r_min, r_max, steps=num_steps)
            # 2*pi is included on purpose, so each ring's first and last points coincide.
            theta = torch.linspace(0, 2 * torch.pi, steps=num_steps)
            R, P = torch.meshgrid(radius, theta, indexing='ij')
            X = R * torch.cos(P)
            Y = R * torch.sin(P)
            grid = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], dim=-1)
        else:
            grid = spaces.NHollowBallSpace(args.n).uniform(512)
    elif space == "cube_grid":
        gap_size = 0.5
        rng = torch.linspace(-1, 1, steps=num_steps)
        grid = _meshgrid_points(rng, args.n)
        # Open a gap of width gap_size around every axis plane. The order
        # matters: the first shift only makes non-negative values larger, so
        # the second line still sees exactly the original negatives.
        grid[grid >= 0] += 0.5 * gap_size
        grid[grid < 0] -= 0.5 * gap_size
    else:
        rng = torch.linspace(-1, 1, steps=num_steps)
        grid = _meshgrid_points(rng, args.n)
    return grid


def plot_latents(sources, features, iter=0):
    """Render 3-D scatter plots of ``features``, one panel per column of ``sources``.

    The only projection used is feature dimensions (0, 1, 2); ``down_proj``
    substitutes zeros for dimensions that ``features`` does not have. Points
    in column ``j``'s panel are coloured by ``get_color(sources[:, j])``.

    When ``features`` has fewer than 3 columns the camera looks straight down
    (elev=90, azim=90). Otherwise it sits at elev=30, azim=45 + 3 * iter.

    Returns:
        The rendered figure as produced by ``fig_to_img``.
    """
    projections = [[0, 1, 2]]
    n_rows = len(projections)
    n_cols = sources.shape[1]

    fig = plt.figure(figsize=(16, 8), dpi=100)
    for row, proj in enumerate(projections):
        for col in range(n_cols):
            panel = row * n_cols + col + 1
            ax = fig.add_subplot(n_rows, n_cols, panel, projection='3d')

            colours = get_color(sources[:, col])
            xs, ys, zs = down_proj(features, proj)

            ax.scatter(xs, ys, zs, c=colours, marker='o', s=14)
            set_labels(ax, proj)
            ax.grid(False)

            is_flat = features.shape[1] < 3
            elev, azim = (90, 90) if is_flat else (30, 45 + 3 * iter)
            ax.view_init(elev=elev, azim=azim)

    plt.tight_layout(h_pad=-3, w_pad=-3, rect=[0, 0, 1, 1])
    image = fig_to_img()
    plt.close()
    return image


def ln_norm_const_box_lp1(mu, b):
    """Sum over the last axis of log(2b - b * (exp(-mu/b) + exp((mu - 1)/b))).

    For 0 <= mu <= 1, each per-coordinate term equals the integral of
    exp(-|x - mu| / b) over x in [0, 1]. Summing the logs over the last axis
    therefore gives the log of the product of these per-coordinate integrals.
    """
    tails = np.exp(-mu / b) + np.exp((mu - 1) / b)
    per_coordinate = 2 * b - b * tails
    return np.sum(np.log(per_coordinate), axis=-1)


def uniform_laplace_marginal(x, b=0.05):
    """Evaluate a tabulated 1-D function at every entry of ``x`` by linear interpolation.

    The function is stored as 1001 values ``fp`` on an evenly spaced grid
    ``xp`` over [0, 1]. Each column of ``x`` is interpolated with ``np.interp``
    independently. Tables exist only for ``b == 0.05`` and ``b == 0.1``.
    Judging by the name, the table is presumably the marginal of a uniform
    variable on [0, 1] perturbed by Laplace noise of scale ``b``.
    NOTE(review): confirm against the script that generated ``fp``.

    Args:
        x: 2-D array indexed as ``x[:, i]``. Entries outside [0, 1] receive
            the table's endpoint value, because ``np.interp`` clamps.
        b: table selector; must compare exactly equal to 0.05 or 0.1.

    Returns:
        Array of the same shape as ``x`` holding the interpolated values.

    Raises:
        AssertionError: if ``b`` is neither 0.05 nor 0.1 (only while
            assertions are enabled).
    """
    # NOTE(review): `assert` is stripped under `python -O`. An unsupported `b`
    # would then skip both branches and fail later with UnboundLocalError on
    # `fp`. The exact float comparison also rejects values that are only
    # approximately 0.05 / 0.1.
    assert b == 0.05 or b == 0.1  # generate fp externally for other b values

    # Abscissae of the tables: 1001 points, step 0.001. Each fp must match this length.
    xp = np.linspace(0, 1, 1001)

    if b == 0.05:
        fp = np.array([0.69314719, 0.70675231, 0.71985557, 0.73247687, 0.74463511, 0.75634825, 0.76763338, 0.77850677, 0.78898393, 0.79907969, 0.80880817, 0.81818289, 0.82721676, 0.83592216, 0.84431093, 0.85239440, 0.86018346, 0.86768854, 0.87491966, 0.88188643, 0.88859808, 0.89506350, 0.90129121, 0.90728942, 0.91306604, 0.91862866, 0.92398460, 0.92914092, 0.93410440, 0.93888158, 0.94347879, 0.94790209, 0.95215737, 0.95625027, 0.96018626, 0.96397062, 0.96760843, 0.97110460, 0.97446389, 0.97769086, 0.98078996, 0.98376545, 0.98662148, 0.98936204, 0.99199098, 0.99451205, 0.99692885, 0.99924488, 1.0014635, 1.0035880, 1.0056215, 1.0075671, 1.0094277, 1.0112063, 1.0129055, 1.0145280, 1.0160765, 1.0175534, 1.0189612, 1.0203023, 1.0215788, 1.0227930, 1.0239471, 1.0250431, 1.0260830, 1.0270687, 1.0280022, 1.0288851, 1.0297194, 1.0305067, 1.0312486, 1.0319467, 1.0326027, 1.0332179, 1.0337939, 1.0343319, 1.0348335, 1.0352999, 1.0357324, 1.0361323, 1.0365006, 1.0368387, 1.0371475, 1.0374282, 1.0376819, 1.0379096, 1.0381122, 1.0382907, 1.0384460, 1.0385791, 1.0386908, 1.0387819, 1.0388532, 1.0389056, 1.0389398, 1.0389565, 1.0389565, 1.0389404, 1.0389089, 1.0388626, 1.0388023, 1.0387284, 1.0386417, 1.0385425, 1.0384316, 1.0383093, 1.0381763, 1.0380330, 1.0378800, 1.0377176, 1.0375463, 1.0373666, 1.0371789, 1.0369836, 1.0367810, 1.0365716, 1.0363557, 1.0361336, 1.0359058, 1.0356725, 1.0354341, 1.0351908, 1.0349430, 1.0346910, 1.0344350, 1.0341753, 1.0339121, 1.0336457, 1.0333764, 1.0331043, 1.0328297, 1.0325529, 1.0322739, 1.0319930, 1.0317105, 1.0314264, 1.0311410, 1.0308544, 1.0305668, 1.0302784, 1.0299893, 1.0296996, 1.0294095, 1.0291192, 1.0288287, 1.0285381, 1.0282477, 1.0279575, 1.0276676, 1.0273782, 1.0270892, 1.0268009, 1.0265133, 1.0262266, 1.0259407, 1.0256557, 1.0253718, 1.0250891, 1.0248075, 1.0245272, 1.0242482, 1.0239705, 1.0236943, 1.0234196, 1.0231464, 1.0228749, 1.0226049, 1.0223366, 1.0220701, 1.0218053, 1.0215423, 1.0212811, 1.0210218, 1.0207644, 1.0205089, 
1.0202553, 1.0200038, 1.0197542, 1.0195066, 1.0192610, 1.0190176, 1.0187761, 1.0185368, 1.0182995, 1.0180644, 1.0178313, 1.0176004, 1.0173717, 1.0171450, 1.0169206, 1.0166982, 1.0164781, 1.0162600, 1.0160442, 1.0158305, 1.0156189, 1.0154095, 1.0152023, 1.0149972, 1.0147943, 1.0145935, 1.0143949, 1.0141984, 1.0140040, 1.0138117, 1.0136216, 1.0134335, 1.0132476, 1.0130637, 1.0128820, 1.0127023, 1.0125246, 1.0123490, 1.0121755, 1.0120039, 1.0118344, 1.0116669, 1.0115014, 1.0113378, 1.0111762, 1.0110166, 1.0108589, 1.0107031, 1.0105492, 1.0103973, 1.0102472, 1.0100989, 1.0099525, 1.0098080, 1.0096653, 1.0095243, 1.0093852, 1.0092478, 1.0091122, 1.0089784, 1.0088462, 1.0087158, 1.0085871, 1.0084600, 1.0083346, 1.0082109, 1.0080888, 1.0079683, 1.0078494, 1.0077321, 1.0076164, 1.0075022, 1.0073895, 1.0072784, 1.0071687, 1.0070606, 1.0069539, 1.0068487, 1.0067449, 1.0066426, 1.0065416, 1.0064421, 1.0063439, 1.0062471, 1.0061517, 1.0060575, 1.0059647, 1.0058732, 1.0057830, 1.0056940, 1.0056063, 1.0055199, 1.0054347, 1.0053507, 1.0052678, 1.0051862, 1.0051057, 1.0050264, 1.0049483, 1.0048712, 1.0047953, 1.0047205, 1.0046467, 1.0045740, 1.0045024, 1.0044319, 1.0043623, 1.0042938, 1.0042263, 1.0041598, 1.0040942, 1.0040296, 1.0039660, 1.0039033, 1.0038416, 1.0037808, 1.0037208, 1.0036618, 1.0036037, 1.0035464, 1.0034900, 1.0034344, 1.0033797, 1.0033258, 1.0032727, 1.0032204, 1.0031689, 1.0031181, 1.0030682, 1.0030190, 1.0029705, 1.0029228, 1.0028759, 1.0028296, 1.0027841, 1.0027392, 1.0026951, 1.0026516, 1.0026088, 1.0025666, 1.0025251, 1.0024843, 1.0024440, 1.0024044, 1.0023654, 1.0023271, 1.0022893, 1.0022521, 1.0022155, 1.0021795, 1.0021440, 1.0021091, 1.0020747, 1.0020409, 1.0020076, 1.0019748, 1.0019426, 1.0019108, 1.0018796, 1.0018488, 1.0018186, 1.0017888, 1.0017595, 1.0017307, 1.0017023, 1.0016744, 1.0016469, 1.0016199, 1.0015933, 1.0015671, 1.0015413, 1.0015160, 1.0014911, 1.0014665, 1.0014424, 1.0014187, 1.0013953, 1.0013723, 1.0013497, 1.0013275, 1.0013056, 
1.0012841, 1.0012629, 1.0012421, 1.0012216, 1.0012015, 1.0011817, 1.0011622, 1.0011430, 1.0011242, 1.0011056, 1.0010874, 1.0010695, 1.0010518, 1.0010345, 1.0010174, 1.0010007, 1.0009842, 1.0009680, 1.0009520, 1.0009363, 1.0009209, 1.0009058, 1.0008909, 1.0008762, 1.0008618, 1.0008477, 1.0008338, 1.0008201, 1.0008066, 1.0007934, 1.0007804, 1.0007677, 1.0007551, 1.0007428, 1.0007306, 1.0007187, 1.0007070, 1.0006955, 1.0006842, 1.0006731, 1.0006622, 1.0006515, 1.0006410, 1.0006306, 1.0006204, 1.0006105, 1.0006007, 1.0005910, 1.0005816, 1.0005723, 1.0005631, 1.0005542, 1.0005454, 1.0005367, 1.0005283, 1.0005199, 1.0005118, 1.0005037, 1.0004959, 1.0004881, 1.0004805, 1.0004731, 1.0004658, 1.0004586, 1.0004516, 1.0004447, 1.0004379, 1.0004313, 1.0004248, 1.0004184, 1.0004121, 1.0004060, 1.0004000, 1.0003941, 1.0003883, 1.0003827, 1.0003771, 1.0003717, 1.0003664, 1.0003612, 1.0003561, 1.0003511, 1.0003462, 1.0003414, 1.0003367, 1.0003322, 1.0003277, 1.0003233, 1.0003190, 1.0003148, 1.0003108, 1.0003068, 1.0003029, 1.0002991, 1.0002953, 1.0002917, 1.0002882, 1.0002847, 1.0002814, 1.0002781, 1.0002749, 1.0002718, 1.0002688, 1.0002658, 1.0002630, 1.0002602, 1.0002575, 1.0002549, 1.0002523, 1.0002499, 1.0002475, 1.0002452, 1.0002430, 1.0002408, 1.0002387, 1.0002367, 1.0002348, 1.0002329, 1.0002311, 1.0002294, 1.0002278, 1.0002262, 1.0002247, 1.0002232, 1.0002219, 1.0002206, 1.0002193, 1.0002182, 1.0002171, 1.0002161, 1.0002151, 1.0002142, 1.0002134, 1.0002126, 1.0002120, 1.0002113, 1.0002108, 1.0002103, 1.0002099, 1.0002095, 1.0002092, 1.0002090, 1.0002088, 1.0002087, 1.0002087, 1.0002087, 1.0002088, 1.0002090, 1.0002092, 1.0002095, 1.0002099, 1.0002103, 1.0002108, 1.0002113, 1.0002120, 1.0002126, 1.0002134, 1.0002142, 1.0002151, 1.0002161, 1.0002171, 1.0002182, 1.0002193, 1.0002206, 1.0002219, 1.0002232, 1.0002247, 1.0002262, 1.0002278, 1.0002294, 1.0002311, 1.0002329, 1.0002348, 1.0002367, 1.0002387, 1.0002408, 1.0002430, 1.0002452, 1.0002475, 1.0002499, 1.0002523, 
1.0002549, 1.0002575, 1.0002602, 1.0002630, 1.0002658, 1.0002688, 1.0002718, 1.0002749, 1.0002781, 1.0002814, 1.0002847, 1.0002882, 1.0002917, 1.0002953, 1.0002991, 1.0003029, 1.0003068, 1.0003108, 1.0003148, 1.0003190, 1.0003233, 1.0003277, 1.0003322, 1.0003367, 1.0003414, 1.0003462, 1.0003511, 1.0003561, 1.0003612, 1.0003664, 1.0003717, 1.0003771, 1.0003827, 1.0003883, 1.0003941, 1.0004000, 1.0004060, 1.0004121, 1.0004184, 1.0004248, 1.0004313, 1.0004379, 1.0004447, 1.0004516, 1.0004586, 1.0004658, 1.0004731, 1.0004805, 1.0004881, 1.0004959, 1.0005037, 1.0005118, 1.0005199, 1.0005283, 1.0005367, 1.0005454, 1.0005542, 1.0005631, 1.0005723, 1.0005816, 1.0005910, 1.0006007, 1.0006105, 1.0006204, 1.0006306, 1.0006410, 1.0006515, 1.0006622, 1.0006731, 1.0006842, 1.0006955, 1.0007070, 1.0007187, 1.0007306, 1.0007428, 1.0007551, 1.0007677, 1.0007804, 1.0007934, 1.0008066, 1.0008201, 1.0008338, 1.0008477, 1.0008618, 1.0008762, 1.0008909, 1.0009058, 1.0009209, 1.0009363, 1.0009520, 1.0009680, 1.0009842, 1.0010007, 1.0010174, 1.0010345, 1.0010518, 1.0010695, 1.0010874, 1.0011056, 1.0011242, 1.0011430, 1.0011622, 1.0011817, 1.0012015, 1.0012216, 1.0012421, 1.0012629, 1.0012841, 1.0013056, 1.0013275, 1.0013497, 1.0013723, 1.0013953, 1.0014187, 1.0014424, 1.0014665, 1.0014911, 1.0015160, 1.0015413, 1.0015671, 1.0015933, 1.0016199, 1.0016469, 1.0016744, 1.0017023, 1.0017307, 1.0017595, 1.0017888, 1.0018186, 1.0018488, 1.0018796, 1.0019108, 1.0019426, 1.0019748, 1.0020076, 1.0020409, 1.0020747, 1.0021091, 1.0021440, 1.0021795, 1.0022155, 1.0022521, 1.0022893, 1.0023271, 1.0023654, 1.0024044, 1.0024440, 1.0024843, 1.0025251, 1.0025666, 1.0026088, 1.0026516, 1.0026951, 1.0027392, 1.0027841, 1.0028296, 1.0028759, 1.0029228, 1.0029705, 1.0030190, 1.0030682, 1.0031181, 1.0031689, 1.0032204, 1.0032727, 1.0033258, 1.0033797, 1.0034344, 1.0034900, 1.0035464, 1.0036037, 1.0036618, 1.0037208, 1.0037808, 1.0038416, 1.0039033, 1.0039660, 1.0040296, 1.0040942, 1.0041598, 1.0042263, 
1.0042938, 1.0043623, 1.0044319, 1.0045024, 1.0045740, 1.0046467, 1.0047205, 1.0047953, 1.0048712, 1.0049483, 1.0050264, 1.0051057, 1.0051862, 1.0052678, 1.0053507, 1.0054347, 1.0055199, 1.0056063, 1.0056940, 1.0057830, 1.0058732, 1.0059647, 1.0060575, 1.0061517, 1.0062471, 1.0063439, 1.0064421, 1.0065416, 1.0066426, 1.0067449, 1.0068487, 1.0069539, 1.0070606, 1.0071687, 1.0072784, 1.0073895, 1.0075022, 1.0076164, 1.0077321, 1.0078494, 1.0079683, 1.0080888, 1.0082109, 1.0083346, 1.0084600, 1.0085871, 1.0087158, 1.0088462, 1.0089784, 1.0091122, 1.0092478, 1.0093852, 1.0095243, 1.0096653, 1.0098080, 1.0099525, 1.0100989, 1.0102472, 1.0103973, 1.0105492, 1.0107031, 1.0108589, 1.0110166, 1.0111762, 1.0113378, 1.0115014, 1.0116669, 1.0118344, 1.0120039, 1.0121755, 1.0123490, 1.0125246, 1.0127023, 1.0128820, 1.0130637, 1.0132476, 1.0134335, 1.0136216, 1.0138117, 1.0140040, 1.0141984, 1.0143949, 1.0145935, 1.0147943, 1.0149972, 1.0152023, 1.0154095, 1.0156189, 1.0158305, 1.0160442, 1.0162600, 1.0164781, 1.0166982, 1.0169206, 1.0171450, 1.0173717, 1.0176004, 1.0178313, 1.0180644, 1.0182995, 1.0185368, 1.0187761, 1.0190176, 1.0192610, 1.0195066, 1.0197542, 1.0200038, 1.0202553, 1.0205089, 1.0207644, 1.0210218, 1.0212811, 1.0215423, 1.0218053, 1.0220701, 1.0223366, 1.0226049, 1.0228749, 1.0231464, 1.0234196, 1.0236943, 1.0239705, 1.0242482, 1.0245272, 1.0248075, 1.0250891, 1.0253718, 1.0256557, 1.0259407, 1.0262266, 1.0265133, 1.0268009, 1.0270892, 1.0273782, 1.0276676, 1.0279575, 1.0282477, 1.0285381, 1.0288287, 1.0291192, 1.0294095, 1.0296996, 1.0299893, 1.0302784, 1.0305668, 1.0308544, 1.0311410, 1.0314264, 1.0317105, 1.0319930, 1.0322739, 1.0325529, 1.0328297, 1.0331043, 1.0333764, 1.0336457, 1.0339121, 1.0341753, 1.0344350, 1.0346910, 1.0349430, 1.0351908, 1.0354341, 1.0356725, 1.0359058, 1.0361336, 1.0363557, 1.0365716, 1.0367810, 1.0369836, 1.0371789, 1.0373666, 1.0375463, 1.0377176, 1.0378800, 1.0380330, 1.0381763, 1.0383093, 1.0384316, 1.0385425, 1.0386417, 
1.0387284, 1.0388023, 1.0388626, 1.0389089, 1.0389404, 1.0389565, 1.0389565, 1.0389398, 1.0389056, 1.0388532, 1.0387819, 1.0386908, 1.0385791, 1.0384460, 1.0382907, 1.0381122, 1.0379096, 1.0376819, 1.0374282, 1.0371475, 1.0368387, 1.0365006, 1.0361323, 1.0357324, 1.0352999, 1.0348335, 1.0343319, 1.0337939, 1.0332179, 1.0326027, 1.0319467, 1.0312486, 1.0305067, 1.0297194, 1.0288851, 1.0280022, 1.0270687, 1.0260830, 1.0250431, 1.0239471, 1.0227930, 1.0215788, 1.0203023, 1.0189612, 1.0175534, 1.0160765, 1.0145280, 1.0129055, 1.0112063, 1.0094277, 1.0075671, 1.0056215, 1.0035880, 1.0014635, 0.99924488, 0.99692885, 0.99451205, 0.99199098, 0.98936204, 0.98662148, 0.98376545, 0.98078996, 0.97769086, 0.97446389, 0.97110460, 0.96760843, 0.96397062, 0.96018626, 0.95625027, 0.95215737, 0.94790209, 0.94347879, 0.93888158, 0.93410440, 0.92914092, 0.92398460, 0.91862866, 0.91306604, 0.90728942, 0.90129121, 0.89506350, 0.88859808, 0.88188643, 0.87491966, 0.86768854, 0.86018346, 0.85239440, 0.84431093, 0.83592216, 0.82721676, 0.81818289, 0.80880817, 0.79907969, 0.78898393, 0.77850677, 0.76763338, 0.75634825, 0.74463511, 0.73247687, 0.71985557, 0.70675231, 0.69314719])
    elif b == 0.1:
        fp = np.array([0.69326507, 0.70013282, 0.70687255, 0.71348683, 0.71997818, 0.72634907, 0.73260188, 0.73873894, 0.74476252, 0.75067485, 0.75647809, 0.76217433, 0.76776565, 0.77325405, 0.77864150, 0.78392991, 0.78912115, 0.79421706, 0.79921941, 0.80412996, 0.80895042, 0.81368245, 0.81832769, 0.82288774, 0.82736415, 0.83175845, 0.83607215, 0.84030671, 0.84446355, 0.84854409, 0.85254969, 0.85648171, 0.86034145, 0.86413021, 0.86784926, 0.87149983, 0.87508314, 0.87860038, 0.88205271, 0.88544127, 0.88876719, 0.89203157, 0.89523548, 0.89837998, 0.90146610, 0.90449486, 0.90746727, 0.91038428, 0.91324687, 0.91605598, 0.91881252, 0.92151741, 0.92417154, 0.92677577, 0.92933097, 0.93183798, 0.93429761, 0.93671069, 0.93907800, 0.94140034, 0.94367846, 0.94591313, 0.94810507, 0.95025502, 0.95236370, 0.95443180, 0.95646001, 0.95844900, 0.96039946, 0.96231201, 0.96418732, 0.96602601, 0.96782869, 0.96959599, 0.97132849, 0.97302678, 0.97469145, 0.97632305, 0.97792216, 0.97948931, 0.98102504, 0.98252990, 0.98400439, 0.98544904, 0.98686434, 0.98825078, 0.98960887, 0.99093907, 0.99224186, 0.99351769, 0.99476703, 0.99599032, 0.99718801, 0.99836052, 0.99950828, 1.0006317, 1.0017312, 1.0028072, 1.0038601, 1.0048902, 1.0058981, 1.0068839, 1.0078482, 1.0087912, 1.0097134, 1.0106151, 1.0114966, 1.0123584, 1.0132006, 1.0140237, 1.0148280, 1.0156138, 1.0163814, 1.0171311, 1.0178633, 1.0185782, 1.0192762, 1.0199575, 1.0206224, 1.0212712, 1.0219042, 1.0225216, 1.0231238, 1.0237109, 1.0242833, 1.0248411, 1.0253848, 1.0259144, 1.0264303, 1.0269326, 1.0274217, 1.0278978, 1.0283610, 1.0288116, 1.0292499, 1.0296760, 1.0300902, 1.0304926, 1.0308835, 1.0312632, 1.0316317, 1.0319893, 1.0323361, 1.0326725, 1.0329985, 1.0333143, 1.0336202, 1.0339163, 1.0342028, 1.0344799, 1.0347477, 1.0350064, 1.0352562, 1.0354972, 1.0357296, 1.0359535, 1.0361692, 1.0363767, 1.0365762, 1.0367679, 1.0369519, 1.0371284, 1.0372974, 1.0374592, 1.0376139, 1.0377616, 1.0379024, 1.0380364, 1.0381639, 1.0382849, 1.0383996, 
1.0385080, 1.0386103, 1.0387066, 1.0387971, 1.0388818, 1.0389608, 1.0390343, 1.0391024, 1.0391652, 1.0392228, 1.0392753, 1.0393228, 1.0393653, 1.0394031, 1.0394362, 1.0394646, 1.0394886, 1.0395081, 1.0395233, 1.0395343, 1.0395411, 1.0395439, 1.0395427, 1.0395376, 1.0395287, 1.0395161, 1.0394998, 1.0394799, 1.0394566, 1.0394299, 1.0393998, 1.0393665, 1.0393300, 1.0392904, 1.0392477, 1.0392020, 1.0391535, 1.0391021, 1.0390479, 1.0389910, 1.0389315, 1.0388694, 1.0388048, 1.0387377, 1.0386682, 1.0385964, 1.0385223, 1.0384460, 1.0383675, 1.0382869, 1.0382042, 1.0381195, 1.0380329, 1.0379444, 1.0378540, 1.0377618, 1.0376679, 1.0375723, 1.0374750, 1.0373760, 1.0372756, 1.0371736, 1.0370701, 1.0369652, 1.0368589, 1.0367513, 1.0366423, 1.0365321, 1.0364207, 1.0363081, 1.0361943, 1.0360794, 1.0359635, 1.0358465, 1.0357285, 1.0356095, 1.0354896, 1.0353689, 1.0352472, 1.0351247, 1.0350015, 1.0348774, 1.0347526, 1.0346272, 1.0345010, 1.0343742, 1.0342468, 1.0341188, 1.0339903, 1.0338612, 1.0337316, 1.0336015, 1.0334710, 1.0333400, 1.0332087, 1.0330770, 1.0329449, 1.0328125, 1.0326798, 1.0325468, 1.0324136, 1.0322801, 1.0321464, 1.0320125, 1.0318785, 1.0317442, 1.0316099, 1.0314754, 1.0313409, 1.0312063, 1.0310716, 1.0309369, 1.0308021, 1.0306674, 1.0305327, 1.0303980, 1.0302634, 1.0301288, 1.0299943, 1.0298599, 1.0297257, 1.0295915, 1.0294575, 1.0293237, 1.0291900, 1.0290565, 1.0289232, 1.0287901, 1.0286573, 1.0285247, 1.0283923, 1.0282602, 1.0281284, 1.0279969, 1.0278656, 1.0277347, 1.0276041, 1.0274739, 1.0273439, 1.0272144, 1.0270852, 1.0269563, 1.0268279, 1.0266998, 1.0265722, 1.0264450, 1.0263181, 1.0261918, 1.0260658, 1.0259403, 1.0258153, 1.0256907, 1.0255666, 1.0254429, 1.0253198, 1.0251972, 1.0250750, 1.0249534, 1.0248322, 1.0247116, 1.0245916, 1.0244720, 1.0243530, 1.0242346, 1.0241167, 1.0239993, 1.0238826, 1.0237664, 1.0236507, 1.0235357, 1.0234212, 1.0233074, 1.0231941, 1.0230814, 1.0229694, 1.0228579, 1.0227471, 1.0226368, 1.0225272, 1.0224183, 1.0223099, 
1.0222022, 1.0220952, 1.0219887, 1.0218829, 1.0217778, 1.0216733, 1.0215695, 1.0214664, 1.0213639, 1.0212620, 1.0211609, 1.0210604, 1.0209606, 1.0208614, 1.0207630, 1.0206652, 1.0205681, 1.0204717, 1.0203760, 1.0202810, 1.0201866, 1.0200930, 1.0200001, 1.0199078, 1.0198163, 1.0197255, 1.0196353, 1.0195459, 1.0194572, 1.0193692, 1.0192819, 1.0191954, 1.0191095, 1.0190244, 1.0189399, 1.0188562, 1.0187733, 1.0186910, 1.0186095, 1.0185286, 1.0184486, 1.0183692, 1.0182906, 1.0182127, 1.0181355, 1.0180590, 1.0179833, 1.0179083, 1.0178341, 1.0177605, 1.0176878, 1.0176157, 1.0175444, 1.0174738, 1.0174039, 1.0173348, 1.0172665, 1.0171988, 1.0171319, 1.0170658, 1.0170003, 1.0169356, 1.0168717, 1.0168085, 1.0167460, 1.0166843, 1.0166233, 1.0165631, 1.0165036, 1.0164448, 1.0163868, 1.0163295, 1.0162730, 1.0162172, 1.0161621, 1.0161078, 1.0160542, 1.0160014, 1.0159493, 1.0158979, 1.0158473, 1.0157975, 1.0157484, 1.0157000, 1.0156523, 1.0156054, 1.0155593, 1.0155139, 1.0154692, 1.0154253, 1.0153821, 1.0153397, 1.0152980, 1.0152570, 1.0152168, 1.0151773, 1.0151386, 1.0151006, 1.0150633, 1.0150268, 1.0149910, 1.0149560, 1.0149217, 1.0148881, 1.0148553, 1.0148232, 1.0147919, 1.0147613, 1.0147315, 1.0147023, 1.0146740, 1.0146463, 1.0146194, 1.0145933, 1.0145679, 1.0145432, 1.0145192, 1.0144960, 1.0144736, 1.0144518, 1.0144308, 1.0144106, 1.0143911, 1.0143723, 1.0143542, 1.0143369, 1.0143204, 1.0143045, 1.0142895, 1.0142751, 1.0142615, 1.0142486, 1.0142365, 1.0142250, 1.0142144, 1.0142044, 1.0141952, 1.0141868, 1.0141790, 1.0141720, 1.0141658, 1.0141603, 1.0141555, 1.0141514, 1.0141481, 1.0141456, 1.0141437, 1.0141426, 1.0141422, 1.0141426, 1.0141437, 1.0141456, 1.0141481, 1.0141514, 1.0141555, 1.0141603, 1.0141658, 1.0141720, 1.0141790, 1.0141868, 1.0141952, 1.0142044, 1.0142144, 1.0142250, 1.0142365, 1.0142486, 1.0142615, 1.0142751, 1.0142895, 1.0143045, 1.0143204, 1.0143369, 1.0143542, 1.0143723, 1.0143911, 1.0144106, 1.0144308, 1.0144518, 1.0144736, 1.0144960, 1.0145192, 
1.0145432, 1.0145679, 1.0145933, 1.0146194, 1.0146463, 1.0146740, 1.0147023, 1.0147315, 1.0147613, 1.0147919, 1.0148232, 1.0148553, 1.0148881, 1.0149217, 1.0149560, 1.0149910, 1.0150268, 1.0150633, 1.0151006, 1.0151386, 1.0151773, 1.0152168, 1.0152570, 1.0152980, 1.0153397, 1.0153821, 1.0154253, 1.0154692, 1.0155139, 1.0155593, 1.0156054, 1.0156523, 1.0157000, 1.0157484, 1.0157975, 1.0158473, 1.0158979, 1.0159493, 1.0160014, 1.0160542, 1.0161078, 1.0161621, 1.0162172, 1.0162730, 1.0163295, 1.0163868, 1.0164448, 1.0165036, 1.0165631, 1.0166233, 1.0166843, 1.0167460, 1.0168085, 1.0168717, 1.0169356, 1.0170003, 1.0170658, 1.0171319, 1.0171988, 1.0172665, 1.0173348, 1.0174039, 1.0174738, 1.0175444, 1.0176157, 1.0176878, 1.0177605, 1.0178341, 1.0179083, 1.0179833, 1.0180590, 1.0181355, 1.0182127, 1.0182906, 1.0183692, 1.0184486, 1.0185286, 1.0186095, 1.0186910, 1.0187733, 1.0188562, 1.0189399, 1.0190244, 1.0191095, 1.0191954, 1.0192819, 1.0193692, 1.0194572, 1.0195459, 1.0196353, 1.0197255, 1.0198163, 1.0199078, 1.0200001, 1.0200930, 1.0201866, 1.0202810, 1.0203760, 1.0204717, 1.0205681, 1.0206652, 1.0207630, 1.0208614, 1.0209606, 1.0210604, 1.0211609, 1.0212620, 1.0213639, 1.0214664, 1.0215695, 1.0216733, 1.0217778, 1.0218829, 1.0219887, 1.0220952, 1.0222022, 1.0223099, 1.0224183, 1.0225272, 1.0226368, 1.0227471, 1.0228579, 1.0229694, 1.0230814, 1.0231941, 1.0233074, 1.0234212, 1.0235357, 1.0236507, 1.0237664, 1.0238826, 1.0239993, 1.0241167, 1.0242346, 1.0243530, 1.0244720, 1.0245916, 1.0247116, 1.0248322, 1.0249534, 1.0250750, 1.0251972, 1.0253198, 1.0254429, 1.0255666, 1.0256907, 1.0258153, 1.0259403, 1.0260658, 1.0261918, 1.0263181, 1.0264450, 1.0265722, 1.0266998, 1.0268279, 1.0269563, 1.0270852, 1.0272144, 1.0273439, 1.0274739, 1.0276041, 1.0277347, 1.0278656, 1.0279969, 1.0281284, 1.0282602, 1.0283923, 1.0285247, 1.0286573, 1.0287901, 1.0289232, 1.0290565, 1.0291900, 1.0293237, 1.0294575, 1.0295915, 1.0297257, 1.0298599, 1.0299943, 1.0301288, 1.0302634, 
1.0303980, 1.0305327, 1.0306674, 1.0308021, 1.0309369, 1.0310716, 1.0312063, 1.0313409, 1.0314754, 1.0316099, 1.0317442, 1.0318785, 1.0320125, 1.0321464, 1.0322801, 1.0324136, 1.0325468, 1.0326798, 1.0328125, 1.0329449, 1.0330770, 1.0332087, 1.0333400, 1.0334710, 1.0336015, 1.0337316, 1.0338612, 1.0339903, 1.0341188, 1.0342468, 1.0343742, 1.0345010, 1.0346272, 1.0347526, 1.0348774, 1.0350015, 1.0351247, 1.0352472, 1.0353689, 1.0354896, 1.0356095, 1.0357285, 1.0358465, 1.0359635, 1.0360794, 1.0361943, 1.0363081, 1.0364207, 1.0365321, 1.0366423, 1.0367513, 1.0368589, 1.0369652, 1.0370701, 1.0371736, 1.0372756, 1.0373760, 1.0374750, 1.0375723, 1.0376679, 1.0377618, 1.0378540, 1.0379444, 1.0380329, 1.0381195, 1.0382042, 1.0382869, 1.0383675, 1.0384460, 1.0385223, 1.0385964, 1.0386682, 1.0387377, 1.0388048, 1.0388694, 1.0389315, 1.0389910, 1.0390479, 1.0391021, 1.0391535, 1.0392020, 1.0392477, 1.0392904, 1.0393300, 1.0393665, 1.0393998, 1.0394299, 1.0394566, 1.0394799, 1.0394998, 1.0395161, 1.0395287, 1.0395376, 1.0395427, 1.0395439, 1.0395411, 1.0395343, 1.0395233, 1.0395081, 1.0394886, 1.0394646, 1.0394362, 1.0394031, 1.0393653, 1.0393228, 1.0392753, 1.0392228, 1.0391652, 1.0391024, 1.0390343, 1.0389608, 1.0388818, 1.0387971, 1.0387066, 1.0386103, 1.0385080, 1.0383996, 1.0382849, 1.0381639, 1.0380364, 1.0379024, 1.0377616, 1.0376139, 1.0374592, 1.0372974, 1.0371284, 1.0369519, 1.0367679, 1.0365762, 1.0363767, 1.0361692, 1.0359535, 1.0357296, 1.0354972, 1.0352562, 1.0350064, 1.0347477, 1.0344799, 1.0342028, 1.0339163, 1.0336202, 1.0333143, 1.0329985, 1.0326725, 1.0323361, 1.0319893, 1.0316317, 1.0312632, 1.0308835, 1.0304926, 1.0300902, 1.0296760, 1.0292499, 1.0288116, 1.0283610, 1.0278978, 1.0274217, 1.0269326, 1.0264303, 1.0259144, 1.0253848, 1.0248411, 1.0242833, 1.0237109, 1.0231238, 1.0225216, 1.0219042, 1.0212712, 1.0206224, 1.0199575, 1.0192762, 1.0185782, 1.0178633, 1.0171311, 1.0163814, 1.0156138, 1.0148280, 1.0140237, 1.0132006, 1.0123584, 1.0114966, 
1.0106151, 1.0097134, 1.0087912, 1.0078482, 1.0068839, 1.0058981, 1.0048902, 1.0038601, 1.0028072, 1.0017312, 1.0006317, 0.99950828, 0.99836052, 0.99718801, 0.99599032, 0.99476703, 0.99351769, 0.99224186, 0.99093907, 0.98960887, 0.98825078, 0.98686434, 0.98544904, 0.98400439, 0.98252990, 0.98102504, 0.97948931, 0.97792216, 0.97632305, 0.97469145, 0.97302678, 0.97132849, 0.96959599, 0.96782869, 0.96602601, 0.96418732, 0.96231201, 0.96039946, 0.95844900, 0.95646001, 0.95443180, 0.95236370, 0.95025502, 0.94810507, 0.94591313, 0.94367846, 0.94140034, 0.93907800, 0.93671069, 0.93429761, 0.93183798, 0.92933097, 0.92677577, 0.92417154, 0.92151741, 0.91881252, 0.91605598, 0.91324687, 0.91038428, 0.90746727, 0.90449486, 0.90146610, 0.89837998, 0.89523548, 0.89203157, 0.88876719, 0.88544127, 0.88205271, 0.87860038, 0.87508314, 0.87149983, 0.86784926, 0.86413021, 0.86034145, 0.85648171, 0.85254969, 0.84854409, 0.84446355, 0.84030671, 0.83607215, 0.83175845, 0.82736415, 0.82288774, 0.81832769, 0.81368245, 0.80895042, 0.80412996, 0.79921941, 0.79421706, 0.78912115, 0.78392991, 0.77864150, 0.77325405, 0.76776565, 0.76217433, 0.75647809, 0.75067485, 0.74476252, 0.73873894, 0.73260188, 0.72634907, 0.71997818, 0.71348683, 0.70687255, 0.70013282, 0.69326507])

    # Interpolate each column of x separately, then reassemble the columns along the last axis.
    return np.stack([np.interp(x[:, i], xp, fp) for i in range(x.shape[-1])], axis=-1)


def change_axis(ax):
    """Apply the fixed styling used for unit-square 3-D plots to ``ax``.

    Paints the x, y and z panes opaque white, limits x and y to [0, 1] with
    ticks every 0.2, and sets the camera to elevation 25, azimuth -125. The
    z-axis limits and ticks are left unchanged.
    """
    xyticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    ticklabels = ['0', '0.2', '0.4', '0.6', '0.8', '1']

    # Opaque white panes (alpha = 1.0). Use alpha 0.0 for truly transparent panes.
    pane_rgba = (1.0, 1.0, 1.0, 1.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(pane_rgba)

    # Set limits and ticks. The labels are passed by keyword because, before
    # matplotlib 3.5, the second positional parameter of set_xticks/set_yticks
    # was `minor`; a positional label list would then silently set minor ticks.
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_xticks(xyticks, labels=ticklabels)
    ax.set_yticks(xyticks, labels=ticklabels)

    ax.view_init(elev=25, azim=-125)


def plot_probs(sources, f1, f2, c, args, iter):
    """Plot ``f1`` and ``f2`` as 3-D surfaces over a square grid of 2-D points.

    ``sources`` is reshaped to (n, n, 2) with n = int(sqrt(len(sources))), so
    it is assumed to hold an n*n grid of 2-D points. ``f1`` and ``f2`` are
    reshaped to (n, n) to match. The two surfaces occupy the top row of a 2x2
    layout. Both cameras sit at elev=30, azim=45 + 3 * iter.

    ``c`` and ``args`` are currently unused (see the disabled panels below).

    Returns:
        The rendered figure as produced by ``fig_to_img``.
    """
    # Disabled ground-truth panels (bottom row), kept for reference:
    #   ln_nc = ln_norm_const_box_lp1(sources, b=args.c_param)
    #   ln_marginal = np.sum(np.log(uniform_laplace_marginal(sources, b=args.c_param)), axis=-1)
    #   gt = ln_nc + ln_marginal
    #   if args.loss == 'nce': c -= np.log(args.batch_size - 1)
    # with ln_nc, ln_marginal and gt reshaped to (n, n); subplot 3 plotted
    # ln_nc and subplot 4 plotted ln_marginal.
    side = int(np.sqrt(len(sources)))
    grid = sources.reshape(side, side, 2)
    heights = (f1.reshape(side, side), f2.reshape(side, side))

    fig = plt.figure(figsize=(16, 8), dpi=100)
    for slot, surface in enumerate(heights, start=1):
        ax = fig.add_subplot(2, 2, slot, projection='3d')
        ax.plot_surface(grid[:, :, 0], grid[:, :, 1], surface,
                        cmap=cm.jet, linewidth=0, antialiased=False)
        ax.view_init(elev=30, azim=45 + 3 * iter)

    plt.tight_layout(h_pad=-3, w_pad=-3, rect=[0, 0, 1, 1])
    image = fig_to_img()
    plt.close()
    return image


def plot_critic_sphere(sources, features, f1, f2, c, args, iter):
    """Scatter ``sources`` and ``features`` in 3-D, coloured by ``f1`` and ``f2``.

    Layout (2x2): top row shows ``sources``, bottom row shows ``features``;
    the left column is coloured by ``get_color(f1)`` and the right by
    ``get_color(f2)``. Only the first three columns of each point set are
    plotted. Every camera sits at elev=30, azim=45 + 3 * iter.
    ``c`` and ``args`` are unused.

    Returns:
        The rendered figure as produced by ``fig_to_img``.
    """
    # get_color depends only on its input, so each colour array is computed once and reused per row.
    colourings = (get_color(f1), get_color(f2))

    fig = plt.figure(figsize=(16, 8), dpi=100)
    slot = 0
    for points in (sources, features):
        for colours in colourings:
            slot += 1
            ax = fig.add_subplot(2, 2, slot, projection='3d')
            ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=colours, marker='o', s=6)
            ax.view_init(elev=30, azim=45 + 3 * iter)

    plt.tight_layout(h_pad=-3, w_pad=-3, rect=[0, 0, 1, 1])
    image = fig_to_img()
    plt.close()
    return image
