import torch 
torch.manual_seed(0)

import numpy as np
import matplotlib.pyplot as plt

from algos_DR import *
from entropic_affinity import *

from matplotlib.colors import PowerNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import ticker as mticker

# --- Synthetic heteroscedastic-noise dataset ---------------------------------
# Three groups of row-normalized multinomial count vectors over m categories:
#   X1: n//2 samples drawn from p1 with 1000 trials each         -> label 0
#   X2: n//4 samples drawn from p2 with 1000 trials each         -> label 1
#   X3: remaining samples drawn from p2 with 2000 trials each    -> label 1
# X2 and X3 share the same underlying distribution p2 but differ in the number
# of trials, i.e. in their sampling-noise level.
n = 1000
m = 10000

# torch.manual_seed only seeds torch; the data below is drawn with NumPy, so it
# needs its own seeded generator to be reproducible.
rng = np.random.default_rng(0)

n1 = n // 2
n2 = n // 4
n3 = n - n1 - n2  # absorbs any remainder so that X has exactly n rows

z1 = rng.uniform(size=m)
p1 = z1 / z1.sum()
z2 = rng.uniform(size=m)
p2 = z2 / z2.sum()

X1 = rng.multinomial(1000, pvals=p1, size=n1)
X1 = X1 / X1.sum(-1, keepdims=True)
X2 = rng.multinomial(1000, pvals=p2, size=n2)
X2 = X2 / X2.sum(-1, keepdims=True)
X3 = rng.multinomial(2000, pvals=p2, size=n3)
X3 = X3 / X3.sum(-1, keepdims=True)

# Standardize with a single global mean/std (not per-feature).
X = torch.from_numpy(np.concatenate((X1, X2, X3), 0))
X = (X - X.mean()) / X.std()
# Labels follow the actual group sizes, so len(Y) == X.shape[0] for any n.
Y = [0] * n1 + [1] * (n2 + n3)
# Shared random initialization of the 2-D embedding.
Z = torch.normal(0, 1, size=(n, 2), dtype=torch.double)

# Perplexity shared by both affinity constructions.
perp = 30

# Pairwise squared Euclidean distances between the input samples.
C = torch.cdist(X, X).pow(2)

# Affinity used by the SNE baseline.
P_SNE = SNE_affinity(C, perp)
# Symmetric entropic affinity obtained by dual ascent; only the affinity
# matrix (first returned element) is kept.
P_sym, _ = se_affinity_dual_ascent(
    C,
    perp=perp,
    tolog=True,
    tol=1e-5,
    lr=1e2,
)

# Both embeddings start from their own copy of the same initialization Z.
# SNE: gaussian kernel in the embedding space.
Z_SNE, log_SNE = affinity_coupling(P_SNE, Z.clone(), kernel='gaussian')
# SNEkhorn: symmetric entropic affinity + Sinkhorn kernel in the embedding.
Z_snekhorn, log_snekhorn = affinity_coupling(
    P_sym,
    Z.clone(),
    kernel='sinkhorn',
)


# --- Figure: affinity matrices (top row) and 2-D embeddings (bottom row) -----
# Matplotlib >= 3.3 requires text.latex.preamble to be a single string; the
# former list-of-strings form was deprecated and later rejected.
params = {
    'text.usetex': True,
    'text.latex.preamble': '\n'.join([r'\usepackage{cmbright}',
                                      r'\usepackage{amsmath}']),
}
plt.rcParams.update(params)
vmin = 1e-6
vmax = 1e-3
plt.rc('font', family='Times New Roman')
fs = 20
# A single shared norm, so both affinity heatmaps use the same color scale.
imshow_kwargs = {'cmap': 'Blues', 'norm': PowerNorm(1.5, vmin=vmin, vmax=vmax)}

# Class labels as an array, computed once for boolean masking below.
labels = np.asarray(Y)


def _to_numpy(a):
    """Return *a* as a NumPy array.

    Torch tensors are detached and moved to CPU first. This is a no-op in
    effect for plain arrays and CPU tensors, and avoids failures if the
    inputs happen to track gradients or live on a device.
    """
    return torch.as_tensor(a).detach().cpu().numpy()


def _show_affinity(ax, P, title):
    """Draw affinity matrix *P* as a tick-free heatmap on *ax*; return the image."""
    im = ax.imshow(_to_numpy(P), aspect="auto", **imshow_kwargs)
    ax.set_title(title, fontsize=fs)
    ax.set_xticks([])
    ax.set_yticks([])
    return im


def _scatter_by_class(ax, emb, title):
    """Scatter the 2-D embedding *emb* on *ax*, colored by the label in Y."""
    emb = _to_numpy(emb)
    # Raw strings: '\m' is an invalid escape in a normal string literal.
    styles = ((0, 'blue', r'$\mathbf{p}_1$'), (1, 'red', r'$\mathbf{p}_2$'))
    for label, color, legend in styles:
        mask = labels == label
        ax.scatter(emb[mask, 0], emb[mask, 1], alpha=0.7, c=color,
                   edgecolor='k', label=legend)
    ax.set_title(title, fontsize=fs)


fig, axs = plt.subplots(2, 2, figsize=(7, 7), constrained_layout=True)

im0 = _show_affinity(axs[0, 0], P_SNE, r'$\overline{\mathbf{P}^{\mathrm{e}}}$')
im1 = _show_affinity(axs[0, 1], P_sym, r'$\mathbf{P}^{\mathrm{se}}$')

# A single colorbar, attached to the right of the top-right heatmap.
divider = make_axes_locatable(axs[0, 1])
cax = divider.append_axes("right", size="3%", pad=0.03)
cb = fig.colorbar(im1, cax=cax, orientation='vertical')
# Freeze the automatically chosen tick positions before relabeling them in
# scientific notation, so the custom labels stay aligned with their ticks.
ticks_loc = cb.ax.get_yticks().tolist()
cb.ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
cb.ax.set_yticklabels([f'{t:.0e}' for t in cb.get_ticks()])
cb.ax.tick_params(labelsize=fs - 3)

_scatter_by_class(axs[1, 0], Z_SNE, 'Symmetric-SNE')
_scatter_by_class(axs[1, 1], Z_snekhorn, 'SNEkhorn')
axs[1, 1].legend(loc='lower right', fontsize=fs - 3)

plt.savefig('heteroscedastic_noise.pdf', bbox_inches='tight')
