from __future__ import print_function
import time

import numpy as np
import pandas as pd

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt

import seaborn as sns


"""
Utilities for visualizing embeddings: correlation heatmaps and PCA / t-SNE scatter plots.
"""

class Heatmap:
    """Plot a heatmap of the column-wise correlation of a 2-D dataset."""

    def __init__(self, data, text=None):
        """Wrap ``data`` in a DataFrame for plotting.

        Args:
            data: a ``pd.DataFrame`` (stored as-is, including subclasses) or
                anything ``pd.DataFrame`` accepts, e.g. a 2-D array.
            text: optional labels used as BOTH the row index and the column
                names; ignored when ``data`` is already a DataFrame. Because
                the same labels are used for both axes, ``data`` must then be
                square with ``len(text)`` rows and columns.
        """
        # isinstance (not ``type(...) ==``) so DataFrame subclasses are kept
        # as-is instead of being re-wrapped / re-indexed.
        if isinstance(data, pd.DataFrame):
            self.data = data
        elif text is not None:
            # annotate rows and columns with the same labels
            self.data = pd.DataFrame(data, index=text, columns=text)
        else:
            self.data = pd.DataFrame(data)

    def plot(self, fig_size=(16, 10), cmap='coolwarm', annot=True, fmt=".2f"):
        """Show a seaborn heatmap of ``self.data.corr()``.

        Args:
            fig_size: matplotlib figure size in inches.
            cmap: colormap passed to ``sns.heatmap``.
            annot: whether to write each correlation value in its cell.
            fmt: format string used for the cell annotations.
        """
        plt.figure(figsize=fig_size)
        sns.heatmap(self.data.corr(), annot=annot, cmap=cmap, fmt=fmt)
        plt.show()


class PCA_TSNE:
    """Reduce data with PCA or t-SNE and scatter-plot two of the resulting axes.

    Both reducers are configured with the same ``n_components``; reduced
    outputs are returned as DataFrames with columns ``pca_<i>`` / ``tsne_<i>``.
    """

    def __init__(self, n_components, perplexity=40, n_iter=300):
        """Configure the PCA and t-SNE reducers.

        Args:
            n_components: number of output dimensions for both PCA and t-SNE.
            perplexity: t-SNE perplexity.
            n_iter: maximum number of t-SNE optimisation iterations.
        """
        self.n_components = n_components
        self.pca = PCA(n_components=n_components)
        tsne_kwargs = dict(n_components=n_components, verbose=1, perplexity=perplexity)
        try:
            # scikit-learn >= 1.5 names this ``max_iter`` (``n_iter`` was
            # deprecated in 1.5 and removed in 1.7).
            self.tsne = TSNE(max_iter=n_iter, **tsne_kwargs)
        except TypeError:
            # Older scikit-learn only accepts ``n_iter``.
            self.tsne = TSNE(n_iter=n_iter, **tsne_kwargs)

    def fit_transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """Fit PCA on ``X`` and return the projection.

        Returns:
            DataFrame with columns ``pca_0 .. pca_{n_components-1}`` and a
            default RangeIndex (the index of ``X`` is not carried over).
        """
        fit = self.pca.fit_transform(X)
        return pd.DataFrame(fit, columns=[f"pca_{i}" for i in range(self.n_components)])

    def fit_transform_tsne(self, X: pd.DataFrame, font_size=8) -> pd.DataFrame:
        """Fit t-SNE on ``X`` and return the embedding.

        ``font_size`` is unused; it is kept only for backward compatibility.

        Returns:
            DataFrame with columns ``tsne_0 .. tsne_{n_components-1}`` and a
            default RangeIndex.
        """
        fit = self.tsne.fit_transform(X)
        return pd.DataFrame(fit, columns=[f"tsne_{i}" for i in range(self.n_components)])

    def plot(self, data: pd.DataFrame, axis_1, axis_2, fig_size=(40, 40), labels=None, tags=None,
             font_size=8, save_path=None):
        """Scatter-plot ``data[axis_1]`` against ``data[axis_2]``, save, then show.

        Args:
            data: DataFrame containing the columns ``axis_1`` and ``axis_2``.
            axis_1, axis_2: column names for the x and y axes.
            fig_size: matplotlib figure size in inches.
            labels: optional sequence of per-row text labels, matched to rows
                of ``data`` by position; no text is drawn when ``None``.
            tags: passed to seaborn as ``hue`` to colour points.
            font_size: font size of the per-point labels.
            save_path: output image path; defaults to
                ``f"tsne_{axis_1}_{axis_2}.png"``.
        """
        if save_path is None:
            save_path = f"tsne_{axis_1}_{axis_2}.png"
        # Scope the large base font to this plot instead of mutating the
        # global rcParams for every later figure.
        with plt.rc_context({'font.size': 32}):
            plt.figure(figsize=fig_size)
            sns.scatterplot(
                x=axis_1, y=axis_2,
                palette=sns.color_palette(),
                data=data,
                legend="full",
                alpha=0.6,
                hue=tags
            )
            if labels is not None:
                # Positional access, so a non-range index on ``data`` works.
                xs = data[axis_1].to_numpy()
                ys = data[axis_2].to_numpy()
                for i, (x, y) in enumerate(zip(xs, ys)):
                    plt.text(x, y, labels[i], fontsize=font_size, horizontalalignment='center')
            # Save BEFORE show(): show() may close the figure, which would
            # leave savefig writing an empty image.
            plt.savefig(save_path)
            plt.show()
