import numpy as np
import pandas as pd
from sklearn.decomposition import TruncatedSVD
import csv, os, string, pickle
import contrastive, word2vec, lda, bow_data, linear
import argparse


## Command-line interface for the experiment driver.
## Every option has a default, so the script can be launched without flags.
parser = argparse.ArgumentParser(description='Contrastive topic modeling')

## Data arguments
parser.add_argument('--data_path', type=str, default="data/", help="Data path.")

## Results folder
parser.add_argument('--results_folder', type=str, default="results/", help='Folder to save out results')
parser.add_argument('--temp_folder', type=str, default="temp/", help='Folder to hold temporary files')
parser.add_argument('--docs_folder', type=str, default=None, help='Folder of precomputed contrastive datasets')

## Experiment type argument
## No `choices=` restriction on purpose: the script treats every unrecognised
## value as the bag-of-words experiment, and that fallback is kept intact.
parser.add_argument('--exp_type', type=str, default='contrast',
                    help='Type of experiment to run. Choose one of contrast, landmarks, lda, word2vec, bow '
                         '(any other value is treated as bow)')

### Contrastive learning arguments
parser.add_argument('--lr', type=float, default=0.0001, help='Starting learning rate for contrastive learning.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate.')

parser.add_argument('--c_dim', type=int, default=300, help='Dimension of contrastive representation')
parser.add_argument('--nlandmarks', type=int, default=8000, help='Number of landmarks to use')
## NOTE(review): help text previously duplicated --c_dim; "hidden layers" is
## inferred from the flag name and --n_layers -- confirm against contrastive.py.
parser.add_argument('--h_dim', type=int, default=256, help='Dimension of hidden layers')
parser.add_argument('--nepochs', type=int, default=600, help='Number of epochs to run for contrastive learning')
parser.add_argument('--embed_dim', type=int, default=300, help='Size of embedding layer')
parser.add_argument('--n_layers', type=int, default=3, help='Number of hidden layers')
parser.add_argument('--resample', type=int, default=3, help='Resampling rate for contrastive representation.')

parser.add_argument('--opt_type', type=str, default="rms", help='Optimization method for contrastive learning (rms, adam, sgd, adagrad)')
parser.add_argument('--prev_model_file', type=str, default=None, help='File name for previous model.')


### word2vec arguments
parser.add_argument('--e_dim', type=int, default=300, help='Dimension of word embeddings')
parser.add_argument('--negative', type=int, default=10, help='Number of negative samples')
parser.add_argument('--window', type=int, default=4, help='Size of window')
parser.add_argument('--iter', type=int, default=10, help='Number of iterations over unsupervised dataset')


### LDA arguments
parser.add_argument('--n_topics', type=int, default=300, help='Number of topics to use')
parser.add_argument('--passes', type=int, default=30, help='Number of passes over dataset')


### BOW arguments
parser.add_argument('--svd_dim', type=int, default=0, help='Number of dimensions to project BOW data (0 disables the SVD projection)')

## Logistic regression arguments
parser.add_argument('--nfracs', type=int, default=25, help='Number of intervals to consider.')
parser.add_argument('--nreps', type=int, default=10, help='Number of repetitions.')





## Parse the command line once; everything below reads its settings from `args`.
args = parser.parse_args()

## Label stored with the results. It starts as the requested experiment type
## and may be overwritten further down (the BOW path renames it when SVD is used).
exp_type = args.exp_type

## Check to see datasets are correct.
## Explicit raises are used instead of `assert` so the checks still run under
## `python -O`, which strips assert statements.
have_train_test = all(os.path.isfile(os.path.join(args.data_path, f)) for f in ("train.csv", "test.csv"))
if not have_train_test:
    raise FileNotFoundError("Need train.csv and test.csv files.")

if args.exp_type == "contrast":
    ## Precomputed bag-of-words files that must be present for the contrast experiment.
    bow_files = ["bow_test_counts", "bow_test_tokens", "bow_train_counts", "bow_train_tokens",
                 "bow_valid_counts", "bow_valid_tokens", "bow_unsup_counts", "bow_unsup_tokens"]
    have_all_files = all(os.path.isfile(os.path.join(args.data_path, f)) for f in bow_files)
    if not have_all_files:
        raise FileNotFoundError("Need to compute bag-of-word representations for contrast.")
elif args.exp_type in ("lda", "word2vec"):
    have_unsupervised_csv = os.path.isfile(os.path.join(args.data_path, "unsupervised.csv"))
    if not have_unsupervised_csv:
        raise FileNotFoundError("Need unsupervised.csv for lda/word2vec.")

## Build (or load from an earlier run) the train/test document representations
## X_train / X_test for the requested experiment type.
if args.exp_type in ("contrast", "landmarks"):
    ## Both experiments call the same contrastive routine; they differ only in
    ## their results sub-folder and in whether --nlandmarks is forwarded.
    results_folder = os.path.join(args.results_folder, args.exp_type)
    ## makedirs also creates a missing parent results folder and tolerates an
    ## existing one, unlike the isdir()/mkdir() pair it replaces.
    os.makedirs(results_folder, exist_ok=True)

    folder = os.path.join(results_folder, "_".join(["epoch", str(args.nepochs)]))
    train_file = os.path.join(folder, "train.npy")
    test_file = os.path.join(folder, "test.npy")
    already_computed = all(os.path.isfile(f) for f in (train_file, test_file))
    if not already_computed:
        contrastive_kwargs = dict(data_path=args.data_path,
                                  results_folder=results_folder,
                                  c_dim=args.c_dim,
                                  h_dim=args.h_dim,
                                  nepochs=args.nepochs,
                                  lr=args.lr,
                                  embed_dim=args.embed_dim,
                                  opt_type=args.opt_type,
                                  dropout_p=args.dropout,
                                  n_layers=args.n_layers,
                                  resample=args.resample,
                                  temp_model_folder=args.temp_folder,
                                  prev_model_file=args.prev_model_file,
                                  presampled_docs_file=args.docs_folder)
        ## NOTE(review): only "contrast" passes `landmarks`; the "landmarks"
        ## experiment has never passed it -- confirm this is intended.
        if args.exp_type == "contrast":
            contrastive_kwargs["landmarks"] = args.nlandmarks
        _ = contrastive.contrastive_representation(**contrastive_kwargs)
    ## Retrieve representation (presumably written by the call above on a first run).
    X_train = np.load(train_file)
    X_test = np.load(test_file)

elif args.exp_type == "word2vec":
    embedding_file = os.path.join(args.data_path, "skipEmbeddings.npy")
    word2ind_file = os.path.join(args.data_path, "word2ind.pkl")

    already_computed = all(os.path.isfile(f) for f in (embedding_file, word2ind_file))
    if not already_computed:
        word2vec.build_word2vec(data_path=args.data_path,
                                dim=args.e_dim,
                                negative=args.negative,
                                window=args.window,
                                iter=args.iter)

    ## Load up files
    embedding = np.load(embedding_file)
    ## NOTE: pickle is only safe here because word2ind.pkl is produced locally
    ## by this pipeline; do not point --data_path at untrusted data.
    with open(word2ind_file, 'rb') as f:
        word2ind = pickle.load(f)

    X_train = word2vec.word2vec_document_embedding(filename=os.path.join(args.data_path, "train.csv"),
                                                   embedding=embedding,
                                                   word2ind=word2ind)

    X_test = word2vec.word2vec_document_embedding(filename=os.path.join(args.data_path, "test.csv"),
                                                  embedding=embedding,
                                                  word2ind=word2ind)

elif args.exp_type == "lda":
    results_folder = os.path.join(args.results_folder, "lda")
    os.makedirs(results_folder, exist_ok=True)

    train_file = os.path.join(results_folder, "train.npy")
    test_file = os.path.join(results_folder, "test.npy")
    already_computed = all(os.path.isfile(f) for f in (train_file, test_file))
    if not already_computed:
        lda.build_lda_representation(data_path=args.data_path,
                                     results_folder=results_folder,
                                     n_topics=args.n_topics,
                                     passes=args.passes,
                                     tok=str.split)
    ## Retrieve representation.
    X_train = np.load(train_file)
    X_test = np.load(test_file)

else:  ## bag-of-word representation (also the fallback for any unrecognised exp_type)
    vocab, train, test, unsup, _ = bow_data.get_data(args.data_path)

    X_train = bow_data.sparse_matrix_format(tokens=train['tokens'],
                                            counts=train['counts'],
                                            vocab_size=len(vocab))
    X_test = bow_data.sparse_matrix_format(tokens=test['tokens'],
                                           counts=test['counts'],
                                           vocab_size=len(vocab))

    ## Optional projection: fit the SVD on the unsupervised split only, then
    ## apply it to train and test. Skipped when svd_dim is 0 or >= vocabulary size.
    if 0 < args.svd_dim < len(vocab):
        X_unsup = bow_data.sparse_matrix_format(tokens=unsup['tokens'],
                                                counts=unsup['counts'],
                                                vocab_size=len(vocab))

        svd = TruncatedSVD(n_components=args.svd_dim)
        svd.fit(X_unsup)
        X_train = svd.transform(X_train)
        X_test = svd.transform(X_test)

        ## Results are labelled (and saved) under a distinct name when projected.
        exp_type = "bow_svd"
    

def load_labels(filename):
    """Read the integer class labels from a CSV file.

    The first row is treated as a header and skipped. For every remaining
    non-blank row the label is taken from the second column (index 1).

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A 1-D integer numpy array with one label per data row (empty when
        the file holds no data rows).

    Raises:
        IndexError: If a non-blank row has fewer than two columns.
        ValueError: If a label cannot be converted to ``int``.
    """
    ## newline='' hands line-ending handling to the csv module, as its
    ## documentation requires for correct parsing of quoted multi-line fields.
    with open(filename, newline='') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        next(csv_reader, None)  ## skip the header
        ## csv.reader yields [] for blank lines (e.g. a trailing empty line);
        ## those carry no label and are skipped instead of raising IndexError.
        labels = [int(row[1]) for row in csv_reader if row]
    ## Explicit dtype keeps an empty result integer-typed like a non-empty one.
    return np.array(labels, dtype=int)

## Labels are read from the same CSV files that the representations were built from.
Y_train = load_labels(os.path.join(args.data_path, 'train.csv'))
Y_test = load_labels(os.path.join(args.data_path, 'test.csv'))

## Do we use scaler with mean?
## NOTE(review): this is keyed on args.exp_type rather than exp_type, so the
## SVD-projected BOW run ("bow_svd") also gets with_mean=False -- confirm intended.
with_mean = (args.exp_type != 'bow')

## Training-set sizes reported next to each accuracy value.
## NOTE(review): assumed to mirror the fractions linear.compute_accuracies
## evaluates for the same nfracs -- TODO confirm against linear.py.
num_examples = np.array([int(frac * len(Y_train)) for frac in np.linspace(0.025, 1, args.nfracs)])

result_columns = ['Training examples', 'Representation', 'Test Accuracy', 'Train Accuracy']
numeric_columns = ['Training examples', 'Test Accuracy', 'Train Accuracy']

## One frame per repetition. Building typed frames avoids pushing the numbers
## through a single string-typed ndarray (which np.hstack would produce when
## mixing them with the representation name).
frames = []
for _ in range(args.nreps):
    test_accuracies, train_accuracies = linear.compute_accuracies(X_train, Y_train, X_test, Y_test,
                                                                  nfracs=args.nfracs,
                                                                  with_mean=with_mean,
                                                                  CV_=True)
    frames.append(pd.DataFrame({'Training examples': num_examples,
                                'Representation': exp_type,
                                'Test Accuracy': test_accuracies,
                                'Train Accuracy': train_accuracies},
                               columns=result_columns))

## pd.concat rejects an empty list, so --nreps 0 falls back to an empty table.
if frames:
    df = pd.concat(frames, ignore_index=True)
else:
    df = pd.DataFrame(columns=result_columns)
for column in numeric_columns:
    df[column] = pd.to_numeric(df[column])

## Save out results
## The word2vec and BOW paths never create the results folder themselves, so
## make sure it exists before writing (skipped for an empty path, i.e. the cwd).
if args.results_folder:
    os.makedirs(args.results_folder, exist_ok=True)
filename = os.path.join(args.results_folder, ".".join([exp_type, "pkl"]))
df.to_pickle(filename)