import numpy as np
import csv, os, string, pickle
import train, posteriors, sim_data, train
import torch
from train import SynthArgs
import argparse


class Arguments:
    """Mutable bag of settings for one training/evaluation run.

    Every field is a plain instance attribute initialised to the default
    listed below; callers overwrite the ones they need after construction.
    """

    def __init__(self):
        defaults = {
            # Paths and previously trained artefacts
            "data_path": None,
            "results_folder": None,
            "temp_folder": None,
            "docs_folder": None,
            "pretrained_vectors": False,
            "prev_model_file": None,

            # Learning parameters
            "opt_type": "rms",
            "t_extra_words": 1,
            "lr": 0.0001,
            "nepochs": 50,
            "dropout_p": 0.5,
            "nfolds": 1,
            "resample": 0,

            # Network structure parameters
            "n_layers": 3,
            "embed_dim": 300,
            "c_dim": 100,
            "h_dim": 256,
            "optimize_embeddings": True,
            "embedding_type": None,

            # Synthetic experiment arguments
            "synth_args": None,
        }
        for field_name, default_value in defaults.items():
            setattr(self, field_name, default_value)


def generate_embeddings(args, window, pretrain_docs):
    """Forward the run settings to ``train.get_predictive_model``.

    Args:
        args: settings object exposing the attributes read below
            (e.g. an ``Arguments`` instance).
        window: passed through as ``w2vwindow``.
        pretrain_docs: passed through as ``pretrain_docs``.

    Returns:
        Whatever ``train.get_predictive_model`` returns.
    """
    # Keyword name expected by train.get_predictive_model -> value to pass.
    model_kwargs = {
        "data_path": args.data_path,
        "results_folder": args.results_folder,
        "c_dim": args.c_dim,
        "h_dim": args.h_dim,
        "nepochs": args.nepochs,
        "lr": args.lr,
        "embed_dim": args.embed_dim,
        "opt_type": args.opt_type,
        "t_extra_words": args.t_extra_words,
        "dropout_p": args.dropout_p,
        "n_layers": args.n_layers,
        "nfolds": args.nfolds,
        "resample": args.resample,
        "pretrained_vectors": args.pretrained_vectors,
        "temp_model_folder": args.temp_folder,
        "prev_model_file": args.prev_model_file,
        "synthetic_args": args.synth_args,
        "optimize_emb": args.optimize_embeddings,
        "emb_type": args.embedding_type,
        "w2vwindow": window,
        "pretrain_docs": pretrain_docs,
    }
    return train.get_predictive_model(**model_kwargs)


def run(t=1, epochs=50, hdim=512, layers=3, alphas=(5.0,), N_test=50):
    """Run the full synthetic experiment once for every value in ``alphas``.

    For each alpha this generates data with ``sim_data.make_data``, fits a
    model via ``generate_embeddings``, evaluates it with
    ``posteriors.construct_posteriors`` and pickles a small results dict
    into ``results/alpha<alpha>_<lamb>wordsdoc/``.

    Args:
        t: number of target words per document (``args.t_extra_words``).
        epochs: number of training epochs.
        hdim: hidden layer dimension.
        layers: number of layers.
        alphas: iterable of alpha values forwarded one at a time to
            ``sim_data.make_data``. It is only iterated, never mutated.
        N_test: number of test documents, forwarded to ``make_data`` and
            ``construct_posteriors``.

    Returns:
        None. Results are printed and written to disk.
    """
    lamb = 60  # forwarded to make_data/SynthArgs; presumably words per document (see folder name) -- TODO confirm
    embed_dim = 4096
    embed_type = 'word2vec'  # hard-coded, so only the word2vec naming branch below is currently taken
    window = 10
    for alpha in alphas:
        args = Arguments()
        run_tag = "alpha" + str(alpha) + '_%dwordsdoc' % lamb

        args.temp_folder = "models"  ## Temporary folder to hold intermediate models
        args.data_path = "data/" + run_tag  ## Folder to hold generated data
        args.results_folder = "results/" + run_tag  ## Folder to hold results

        # makedirs also creates the missing "data"/"results" parents (plain
        # os.mkdir fails on them) and tolerates already-existing folders.
        # Creating everything up front fails fast, before the expensive steps.
        for folder in (args.temp_folder, args.data_path, args.results_folder):
            os.makedirs(folder, exist_ok=True)

        ## Generate data
        topics, id2word, pretrain_docs = sim_data.make_data(
            alpha, args.data_path, emb_type=embed_type,
            N_train=60000, N_test=N_test, lamb=lamb)

        ## NN parameters
        args.n_layers = layers  ## number of layers
        args.c_dim = 512  ## final layer dimension
        args.embed_dim = embed_dim  ## embedding layer dimension
        args.h_dim = hdim  ## hidden layer dimension
        args.t_extra_words = t  ## the number of target words per document
        args.optimize_embeddings = True  ## optimize word embeddings or not
        args.embedding_type = embed_type  ## embedding type

        args.nepochs = epochs  ## number of epochs
        args.lr = 0.0002  ## learning rate
        args.dropout_p = 0.0  ## dropout
        args.resample = 2  ## frequency of resampling data
        args.opt_type = "amsgrad"  ## optimizer

        ## Arguments for synthetic topic modeling
        args.synth_args = SynthArgs(topics, lamb, id2word)

        model = generate_embeddings(args, window, pretrain_docs)

        ## Construct the posterior. The second returned value is not used here.
        Eta, _, TV, accuracy = posteriors.construct_posteriors(args.data_path, model, N_test)

        results = {}
        results["TV"] = TV  ## Total variation distance between model-based posteriors and true posteriors
        results["accuracy"] = accuracy  ## Accuracy of model-based MAP estimate on the generating topic
        results["N_test"] = N_test

        ## Build the output file name; the embedding label depends on the embedding type
        rand_tag = "".join(np.random.choice(list(string.ascii_lowercase), 10))
        if args.embedding_type == 'one-hot':
            embed_label = 'onehotembed'
        elif args.embedding_type == 'word2vec':
            print('word2vec params: embedding_size=%d, window_size=%d'%(embed_dim, window))
            results['word2vec'] = {'embedding_size': embed_dim, 'window_size': window}
            embed_label = 'w2vembed'
        else:
            embed_label = 'emb' + str(args.embed_dim)
        input_names = "-".join([str(t) + 'words', 'hdim' + str(hdim), embed_label,
                                'ly' + str(layers), str(epochs) + 'epochs'])
        fname = input_names + '_' + rand_tag + "_posterior.pkl"
        print(fname)
        print(N_test, 'test documents used')
        print('TV: ', TV)

        # NOTE: pickle must only be used on trusted files; this one is expected
        # to come from the data-generation step above -- TODO confirm.
        with open(os.path.join(args.data_path, 'test_topics_dist.pkl'), 'rb') as f:
            prior = pickle.load(f)

        ## Fraction of documents whose top-1 entry agrees between prior and Eta
        topic_recovery = top_k_ovelap(prior, Eta, 1)
        results["recovery"] = topic_recovery
        print('Topic recovery rate:', topic_recovery)

        ## Save out results
        with open(os.path.join(args.results_folder, fname), 'wb') as f:
            pickle.dump(results, f)

def top_k_ovelap(y_true, y_pred, k):
    """Return the mean per-row overlap of the top-``k`` indices of two matrices.

    For each row, the indices of the ``k`` largest entries of ``y_pred`` are
    compared with those of ``y_true``; the row score is the fraction of
    indices the two sets share, and the scores are averaged over all rows.
    Ties are resolved by whatever order ``np.argsort`` produces.

    Args:
        y_true: 2-D array-like of shape (N, C).
        y_pred: 2-D array-like with the same shape as ``y_true``.
        k: number of top entries to compare per row; must be >= 1.

    Returns:
        The mean overlap in [0, 1], or ``-1`` (after printing ``'error'``)
        when ``k`` exceeds the number of columns. The sentinel is kept so
        existing callers see the same behaviour.

    Raises:
        AssertionError: if the two shapes differ.
        ValueError: if ``k`` is smaller than 1.
        ZeroDivisionError: if the inputs have no rows.
    """
    assert y_true.shape == y_pred.shape
    if k < 1:
        # With k == 0 the slice [-0:] selects *every* index and the division
        # by k blows up; with k < 0 the slice silently means something else.
        raise ValueError('k must be a positive integer, got %r' % (k,))
    N, C = y_true.shape
    if k > C:
        print('error')
        return -1
    total = 0
    for i in range(N):
        top_pred = set(np.argsort(y_pred[i])[-k:])
        top_true = set(np.argsort(y_true[i])[-k:])
        total += len(top_pred & top_true) / k
    return total / N

if __name__ == '__main__':
    # Positional integer arguments, listed in the order they must be given
    # on the command line: (name, help text).
    cli_spec = (
        ('tests', 'number of test documents'),
        ('t', 'number of extra words per document'),
        ('hdim', 'number of nodes per hidden layer'),
        ('layers', 'number of hidden layers'),
        ('epochs', 'number of epochs during training'),
    )
    parser = argparse.ArgumentParser()
    for arg_name, arg_help in cli_spec:
        parser.add_argument(arg_name, type=int, help=arg_help)
    cli = parser.parse_args()

    # Sweep a fixed set of alpha values with the user-supplied settings.
    run(t=cli.t,
        epochs=cli.epochs,
        hdim=cli.hdim,
        layers=cli.layers,
        alphas=[1.0, 3.0, 5.0, 7.0, 9.0],
        N_test=cli.tests)
