import pickle, csv, os
import word2vec, bow_data
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import argparse

def _str2bool(value):
    """Convert a command-line string such as "true"/"false" or "1"/"0" to a bool.

    ``type=bool`` does not work with argparse: it calls ``bool(str)``, which is
    True for every non-empty string, so ``--flag False`` would enable the flag.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        # Non-string defaults (e.g. default=True) are passed through unchanged.
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % (value,))


parser = argparse.ArgumentParser(description='Make plots')

parser.add_argument('--plot_folder', type=str, default="results/", help="Folder to put plots in.")

## Experiment files (pass an empty string to skip a file)
parser.add_argument('--lda_df', type=str, default="results/lda.pkl", help="LDA dataframe file.")
parser.add_argument('--bow_df', type=str, default="results/bow.pkl", help="BOW dataframe file.")
parser.add_argument('--bow_svd_df', type=str, default="results/bow_svd.pkl", help="BOW-SVD dataframe file.")
parser.add_argument('--word2vec_df', type=str, default="results/word2vec.pkl", help="Word2vec dataframe file.")
parser.add_argument('--contrast_df', type=str, default="results/contrast.pkl", help="Contrastive learning dataframe file.")
parser.add_argument('--landmarks_df', type=str, default="results/landmarks.pkl", help="Landmarks dataframe file.")

## Embeddings files
parser.add_argument('--data_path', type=str, default="data/", help="Data path.")
parser.add_argument('--word2vec_embeddings', type=_str2bool, default=True, help="Do we compute word2vec embeddings (true/false).")
parser.add_argument('--contrast_embedding_file', type=str, default="results/contrast/epoch_600/test.npy", help="Contrastive learning embeddings file.")

args = parser.parse_args()

# Load every requested results table and concatenate them in one call.
# Repeatedly concatenating onto an empty frame copies the data each time, and
# recent pandas warns about empty entries in concat.
result_files = [args.lda_df, args.bow_df, args.bow_svd_df, args.word2vec_df, args.contrast_df, args.landmarks_df]
# NOTE: read_pickle unpickles arbitrary objects; only point these flags at trusted files.
frames = [pd.read_pickle(fname) for fname in result_files if fname]
if frames:
    df = pd.concat(frames)
else:
    df = pd.DataFrame(columns=['Training examples', 'Representation', 'Test Accuracy', 'Train Accuracy'])

## Make comparison plots
if args.plot_folder:
    # savefig does not create missing directories.
    os.makedirs(args.plot_folder, exist_ok=True)
plt.clf()
# NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of `errorbar=('ci', 95)`;
# kept as `ci` for compatibility with older seaborn versions.
sns.lineplot(x="Training examples", y="Test Accuracy", hue="Representation", lw=2, ci=95, data=df)
plt.savefig(os.path.join(args.plot_folder, 'comparison.pdf'), bbox_inches='tight')

###### TSNE plot stuff #######
def load_labels(filename, label_column=1):
    """Read integer class labels from a CSV file that has a header row.

    Args:
        filename: path to the CSV file; its first row is skipped as a header.
        label_column: zero-based index of the column holding the integer
            label. Defaults to 1, the column this script has always read.

    Returns:
        numpy array with one integer label per non-blank data row.

    Raises:
        ValueError: if a label cell cannot be parsed as an integer.
        IndexError: if a non-blank row has too few fields.
    """
    # newline='' lets the csv module handle line endings itself, as the csv docs
    # require; otherwise quoted fields containing line breaks can be mis-parsed.
    with open(filename, newline='') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        next(csv_reader, None)  # skip the header
        # Blank lines (e.g. a trailing empty line) come back as [] and are skipped.
        labels = [int(row[label_column]) for row in csv_reader if row]
    return np.array(labels)

def build_df(X, Y, category_names=None):
    """Assemble a 2-D embedding and its class labels into a plotting DataFrame.

    Args:
        X: array-like of shape (n_samples, 2), e.g. the output of
            ``TSNE(n_components=2).fit_transform``.
        Y: array-like of n_samples integer class labels.
        category_names: optional mapping from label to display name. Defaults
            to {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}.

    Returns:
        DataFrame with float columns 'X1', 'X2' and a string column 'Category'.
        Labels missing from ``category_names`` are shown as ``str(label)``
        rather than being silently mislabelled as another class.

    Raises:
        ValueError: if X is not two-column 2-D data or its length differs from Y.
    """
    if category_names is None:
        category_names = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y)
    if X.ndim != 2 or X.shape[1] != 2:
        raise ValueError("X must have shape (n_samples, 2), got %s" % (X.shape,))
    if len(X) != len(Y):
        raise ValueError("X and Y must have the same length (%d != %d)" % (len(X), len(Y)))

    # Build the frame column by column. Stacking floats and strings into one
    # numpy array would convert the coordinates to text and back.
    categories = [category_names.get(label, str(label)) for label in Y.tolist()]
    return pd.DataFrame({'X1': X[:, 0], 'X2': X[:, 1], 'Category': categories})

## Make t-SNE plots (plot the test examples)
def _save_tsne_plot(X, Y, out_name):
    """Embed X in 2-D with t-SNE, scatter-plot it coloured by class, and save it.

    Args:
        X: document representations, one row per test example; passed as-is
            to ``TSNE.fit_transform``.
        Y: integer class labels aligned with the rows of X.
        out_name: file name of the figure, written inside ``args.plot_folder``.
    """
    X_embedded = TSNE(n_components=2).fit_transform(X)
    tsne_df = build_df(X_embedded, Y)

    plt.clf()
    ax = sns.scatterplot(x="X1", y="X2", hue="Category", linewidth=0, palette="muted",
                         s=3, data=tsne_df, legend='full')

    # Older seaborn releases add the hue name ("Category") as a pseudo-entry at
    # the top of the legend; newer ones use it as the legend title instead.
    # Dropping that entry by label rather than by position (handles[1:]) avoids
    # discarding a real category on newer seaborn.
    handles, legend_labels = ax.get_legend_handles_labels()
    kept = [(h, l) for h, l in zip(handles, legend_labels) if l != "Category"]
    # ax.legend() replaces the existing legend, so the font size must be passed
    # here; one set on the previous legend would be lost.
    ax.legend(handles=[h for h, _ in kept], labels=[l for _, l in kept], fontsize=7.5)

    # Hide the axis titles and tick labels of the t-SNE scatter plot.
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.tick_params(labelbottom=False, labelleft=False)
    plt.savefig(os.path.join(args.plot_folder, out_name), bbox_inches='tight')


if args.data_path is not None:
    test_file = os.path.join(args.data_path, 'test.csv')

    if args.word2vec_embeddings:
        embedding = np.load(os.path.join(args.data_path, "skipEmbeddings.npy"))
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(args.data_path, "word2ind.pkl"), 'rb') as f:
            word2ind = pickle.load(f)
        X_word2vec = word2vec.word2vec_document_embedding(test_file, embedding, word2ind)
        _save_tsne_plot(X_word2vec, load_labels(test_file), 'word2vec_tsne.pdf')

    # An empty --contrast_embedding_file skips this plot.
    if args.contrast_embedding_file:
        X_contrast = np.load(args.contrast_embedding_file)
        _save_tsne_plot(X_contrast, load_labels(test_file), 'contrast_tsne.pdf')




