import os

import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
from sklearn import metrics

import torch
torch.manual_seed(0)

from algos_DR import *
from entropic_affinity import *
from self_sinkhorn import *

from PIL import Image
from torchvision import transforms

def COIL_dataset(dir=None):
    """Load the processed COIL-20 images as a flattened data matrix.

    Every file in *dir* whose name has the form ``obj<K>_...`` (e.g.
    ``obj7__12.png``) is read. The first channel of the image is flattened
    into one row of ``X``, and the integer ``K`` becomes that row's label.
    Files are visited in sorted order, so the row order (and anything seeded
    downstream that depends on it) is reproducible across filesystems.

    Args:
        dir: Directory containing the images. Defaults to ``'/coil-20-proc'``.

    Returns:
        X: ``(n, p)`` float64 tensor with one flattened image per row. ``p`` is
            height * width of the images (16384 for 128x128 images).
        Y: ``(n,)`` tensor of integer class labels, stored in torch's default
            floating dtype.
        scatter_kwargs: keyword arguments for ``Axes.scatter`` that colour
            points by ``Y`` using the ``tab20`` colormap.

    Raises:
        FileNotFoundError: if *dir* contains no ``obj<K>_*`` image files.
    """
    if dir is None:
        dir = '/coil-20-proc'
    # The transform is stateless, so build it once rather than per file.
    to_tensor = transforms.ToTensor()
    rows, labels = [], []
    for filename in sorted(os.listdir(dir)):
        # The class id (one or more digits) sits between the 'obj' prefix and
        # the first '_'. Skip anything else, e.g. hidden OS files.
        class_id = filename[3:].partition('_')[0]
        if not (filename.startswith('obj') and class_id.isdigit()):
            continue
        # The context manager releases the file handle once the pixels are read.
        with Image.open(os.path.join(dir, filename)) as img:
            # ToTensor yields (C, H, W); keep channel 0 and flatten it.
            rows.append(to_tensor(img)[0].view(-1).double())
        labels.append(int(class_id))
    if not rows:
        raise FileNotFoundError(f'no COIL-20 images (obj<K>_*) found in {dir!r}')
    X = torch.stack(rows)
    Y = torch.tensor(labels, dtype=torch.get_default_dtype())
    scatter_kwargs = {'s': 5, 'alpha': 0.8, 'c': Y, 'cmap': plt.get_cmap('tab20')}
    return X, Y, scatter_kwargs


# --- Data: COIL-20 images, reduced to 50 components with PCA.
X_coil, Y_coil, scatter_kwargs_coil = COIL_dataset('coil-20-proc')
X_coil = PCA(X_coil, q=50)
n = len(X_coil)
# Pairwise squared Euclidean distances, used as the cost matrix below.
C_coil = torch.cdist(X_coil, X_coil, p=2).pow(2)

# Optimiser settings shared by both embedding runs.
optim_params = dict(max_iter=5000, tol=1e-4, lr=1.0)

# --- Affinity matrices at the same target perplexity.
perp = 30
P_sne = SNE_affinity(C_coil, perp=perp)
P_sym = se_affinity_dual_ascent2(C_coil, perp=perp, lr=10.0)

# Row-wise exp(H - 1) of each affinity matrix (plotted in the bottom row).
H_sne = entropy(P_sne, log=False, ax=-1)
Perp_sne = (H_sne - 1).exp()

H_sym = entropy(P_sym, log=False, ax=-1)
Perp_sym = (H_sym - 1).exp()

# --- Embeddings: both runs start from the same random 2-D initialisation,
# each from its own copy so neither run can alter the other's starting point.
Z0 = torch.normal(0, 1, size=(n, 2), dtype=torch.double)

embed_sne, _ = affinity_coupling(
    P_sne, Z0.clone(), kernel='gaussian', **optim_params
)
embed_sym, _ = affinity_coupling(
    P_sym, Z0.clone(), kernel='sinkhorn', **optim_params
)

# Apply the LaTeX/font settings *before* creating the figure. rcParams are
# read when artists are created, so anything built by plt.subplots would
# otherwise miss them. `text.latex.preamble` must be a single string: current
# Matplotlib releases reject the older list form.
params = {'text.usetex': True,
          'text.latex.preamble': r'\usepackage{cmbright}\usepackage{amsmath}'}
plt.rcParams.update(params)
plt.rc('font', family='Times New Roman')

# Top row: 2-D embeddings (tall panels). Bottom row: per-sample exp(H - 1) (short panels).
fig, axs = plt.subplots(2, 2, figsize=(10, 6), gridspec_kw={'height_ratios': [3, 1]})

# Silhouette score of each embedding with respect to the ground-truth labels.
score_sne = float(metrics.silhouette_score(embed_sne, Y_coil))
score_sym = float(metrics.silhouette_score(embed_sym, Y_coil))

axs[0, 0].scatter(embed_sne[:, 0], embed_sne[:, 1], **scatter_kwargs_coil)
axs[0, 0].set_title(f'Symmetric-SNE (score:{score_sne : .2f})', font='Times New Roman', fontsize=25)
axs[0, 0].set_xticks([-5, 5])

axs[0, 1].scatter(embed_sym[:, 0], embed_sym[:, 1], **scatter_kwargs_coil)
axs[0, 1].set_title(f'SNEkhorn (score:{score_sym : .2f})', font='Times New Roman', fontsize=25)
axs[0, 1].set_xticks([-10, 10])

# Bottom row: samples sorted by label so each class is a contiguous run on
# the x-axis, coloured by their actual sorted labels. Unlike a fixed
# i // (n // 20) pattern, this stays correct for any number of classes and
# for unequal class sizes.
perm = torch.argsort(Y_coil)
scatter_kwargs_coil_perm = scatter_kwargs_coil.copy()
scatter_kwargs_coil_perm['c'] = Y_coil[perm]
sample_idx = torch.arange(n)

axs[1, 0].set_ylabel(r'$e^{\operatorname{H}(\mathbf{P}_i)-1}$', fontsize=20)
axs[1, 0].set_xlabel('Sample ' + r'$i$', font='Times New Roman', fontsize=20)
axs[1, 0].scatter(sample_idx, Perp_sne[perm], **scatter_kwargs_coil_perm)
axs[1, 0].set_title(r'$\overline{\mathbf{P}^{\mathrm{e}}}$', fontsize=20)
axs[1, 0].set_yticks([30, 120, 260])

axs[1, 1].scatter(sample_idx, Perp_sym[perm], **scatter_kwargs_coil_perm)
axs[1, 1].set_title(r'$\mathbf{P}^{\mathrm{se}}$', fontsize=20)
axs[1, 1].set_ylim(29, 31)
axs[1, 1].set_xlabel('Sample ' + r'$i$', font='Times New Roman', fontsize=20)

plt.savefig('fig_coil.pdf', bbox_inches='tight')
plt.show()