#!/usr/bin/env python3
import argparse
import time

from torch import cuda
from torch.nn.init import xavier_uniform_
from data import Dataset
from utils import *
from tsne_torch import TorchTSNE as TSNE
import matplotlib
from matplotlib import pyplot as plt
from cycler import cycler
# Fixed colour cycle applied to every axes created after this point.
# NOTE(review): these ten hex codes appear to match matplotlib's "tab10"
# palette — confirm before relying on that.
_CYCLE_COLOURS = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
]
matplotlib.rcParams['axes.prop_cycle'] = cycler('color', _CYCLE_COLOURS)

# Command-line interface: every option is a file path with a default, so the
# script can be run without arguments from the project root.
parser = argparse.ArgumentParser(
    description='Plot a t-SNE projection of latent word embeddings taken '
                'from a saved checkpoint.')
# Data path options
parser.add_argument('--val_file', default='data/ptb-val.pkl',
                    help='path to the validation data file')
parser.add_argument('--train_file', default='data/ptb-train.pkl',
                    help='path to the training data file')
# The script never saves a model: this path is read as the checkpoint, and the
# figure is written next to it with a ".png" suffix appended.
parser.add_argument('--save_path', default='latent-embed.pt',
                    help='checkpoint to load; the t-SNE plot is written to '
                         '<save_path>.png')


def main(args):
    """Plot a t-SNE projection of latent embeddings for a few probe words.

    Loads the checkpoint at ``args.save_path``, runs ``model.embed.encode`` on
    every vocabulary index repeated 200 times, prints a normalised spread
    statistic of the encodings, and saves a 2-D t-SNE scatter plot of the
    encodings of a fixed list of probe words to ``args.save_path + '.png'``.

    Args:
        args: Parsed command-line namespace providing ``val_file``,
            ``train_file`` and ``save_path``.

    Raises:
        ValueError: If none of the probe words is present in the checkpoint
            vocabulary.
    """
    # NOTE: no RNG seed is set here, so the output may differ between runs.
    # NOTE(review): val_data is loaded but never used below — confirm that
    # Dataset has no needed side effects before dropping this call.
    val_data = Dataset(args.val_file)
    train_data = Dataset(args.train_file)

    # Use the GPU when there is one, otherwise fall back to the CPU instead of
    # failing inside .cuda(); map_location lets GPU-saved tensors load on CPU.
    device = torch.device('cuda' if cuda.is_available() else 'cpu')
    # NOTE: the checkpoint holds a whole model object, so torch.load unpickles
    # it — only open checkpoint files from a trusted source.
    checkpoint = torch.load(args.save_path, map_location=device)
    model = checkpoint['model'].to(device)
    # The vocabulary stored with the checkpoint replaces the dataset's own.
    train_data.word2idx = checkpoint['word2idx']
    train_data.idx2word = checkpoint['idx2word']
    embed_weight = model.embed.embedding.weight
    # The final embedding row is left out (presumably a special entry —
    # TODO confirm against the model definition).
    vocab_size = embed_weight.size(0) - 1

    # Resolve the probe words before the expensive encode so that vocabulary
    # problems surface immediately; unknown words are skipped with a notice.
    probe_words = [
        'N', 'one', 'two', 'three',
        'canada', 'singapore',
        'walk', 'walked', 'fly', 'flew',
    ]
    idxs = []
    for word in probe_words:
        try:
            idxs.append(train_data.word2idx[word])
        except KeyError:
            print("Skipping word missing from vocabulary:", word)
    if not idxs:
        raise ValueError(
            "none of the probe words is in the checkpoint vocabulary")

    n_samples = 200  # number of times each vocabulary index is encoded
    with torch.no_grad():
        model.eval()
        token_ids = torch.arange(
            vocab_size, dtype=torch.long, device=embed_weight.device)
        # Input is (n_samples, vocab_size): the same index row repeated.
        z, _ = model.embed.encode(token_ids[None, :].repeat(n_samples, 1))

        # Standardise over the first two axes, then average the per-index
        # standard deviation across the repeat axis into a single scalar.
        z_std = z.std((0, 1))
        z_mean = z.mean((0, 1))
        z_norm = (z - z_mean) / z_std
        z_norm_cat_std = z_norm.std(0).mean(-1).mean(-1)
        print("Normalised category std:", z_norm_cat_std.item())

        # Keep only the probe words and flatten (repeat, word) into one axis
        # so each row handed to t-SNE is a single encoding.
        z = z[:, idxs]
        cat_samples, words, _ = z.size()
        z_flat = z.detach().flatten(0, 1)
        z_tsne = TSNE(n_components=2, perplexity=30., n_iter=1000,
                      verbose=True).fit_transform(z_flat)

    # Restore the (repeat, word, component) layout for per-word plotting.
    z_tsne = torch.tensor(z_tsne).view(cat_samples, words, -1)

    plt.rc('font', size=15)
    plt.rc('axes', linewidth=2)
    plt.figure(figsize=(10, 10))
    for col, word_idx in enumerate(idxs):
        plt.scatter(z_tsne[:, col, 0],
                    z_tsne[:, col, 1],
                    label=train_data.idx2word[word_idx])
    # Legend sits just below the axes, laid out in five columns.
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.01), fancybox=True,
               ncol=5)
    # Tick positions in t-SNE space carry no meaning, so hide them.
    ax = plt.gca()
    ax.axes.xaxis.set_ticks([])
    ax.axes.yaxis.set_ticks([])
    plt.tight_layout()
    plt.savefig(args.save_path + '.png')


if __name__ == '__main__':
    # Script entry point: parse the command line and run the plot.
    main(parser.parse_args())

