import numpy as np
import csv, os, string, pickle
import train, posteriors, sim_data, train
import torch
from train import SynthArgs
import argparse


class Arguments:
    """Mutable bag of experiment settings consumed by ``generate_embeddings``.

    Every attribute is initialised to a default below; callers overwrite the
    ones they care about on the instance before handing it on.
    """

    def __init__(self):
        defaults = {
            # Locations (expected to be filled in by the caller)
            "data_path": None,
            "results_folder": None,
            "temp_folder": None,        # folder for intermediate models
            "docs_folder": None,        # forwarded as presampled_docs_file
            "pretrained_vectors": False,
            "prev_model_file": None,
            # Learning parameters
            "opt_type": "rms",
            "t_extra_words": 1,         # number of target words per document
            "lr": 0.0001,
            "nepochs": 50,
            "dropout_p": 0.5,
            "nfolds": 1,
            "resample": 0,              # frequency of resampling data
            # Network structure parameters
            "n_layers": 3,
            "embed_dim": 300,
            "c_dim": 100,
            "h_dim": 256,
            # Synthetic experiment arguments
            "synth_args": None,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)


def generate_embeddings(args):
    """Train the predictive model described by ``args`` and return it.

    Thin adapter that unpacks an ``Arguments`` instance into keyword arguments
    for ``train.get_predictive_model``. Note the renamed keys: ``temp_folder``
    is passed as ``temp_model_folder``, ``docs_folder`` as
    ``presampled_docs_file`` and ``synth_args`` as ``synthetic_args``.

    Args:
        args: an ``Arguments`` instance with all settings populated.

    Returns:
        Whatever ``train.get_predictive_model`` returns.
    """
    options = dict(
        data_path=args.data_path,
        results_folder=args.results_folder,
        c_dim=args.c_dim,
        h_dim=args.h_dim,
        nepochs=args.nepochs,
        lr=args.lr,
        embed_dim=args.embed_dim,
        opt_type=args.opt_type,
        t_extra_words=args.t_extra_words,
        dropout_p=args.dropout_p,
        n_layers=args.n_layers,
        nfolds=args.nfolds,
        resample=args.resample,
        pretrained_vectors=args.pretrained_vectors,
        temp_model_folder=args.temp_folder,
        prev_model_file=args.prev_model_file,
        presampled_docs_file=args.docs_folder,
        synthetic_args=args.synth_args,
    )
    return train.get_predictive_model(**options)


def _ensure_dir(path):
    """Create ``path`` and any missing parent directories if absent."""
    # os.makedirs also creates parents such as "data/" or "results/";
    # a bare os.mkdir raises FileNotFoundError when those are missing.
    os.makedirs(path, exist_ok=True)


def _random_tag(length=10):
    """Return ``length`` random lowercase ASCII letters for file names.

    A fresh ``numpy.random.Generator`` (seeded from OS entropy) is used so the
    tag neither depends on nor advances the global ``np.random`` state, which
    other code in the experiment may rely on for reproducibility.
    """
    letters = list(string.ascii_lowercase)
    return "".join(np.random.default_rng().choice(letters, length))


def run(t=1, epochs=50, hdim=512, layers=3, alphas=(5.0,)):
    """Run the synthetic posterior experiment once per value in ``alphas``.

    For each ``alpha`` this generates data with ``sim_data.make_data`` into
    ``data/alpha<alpha>``, trains a model through ``generate_embeddings``,
    evaluates it with ``posteriors.construct_posteriors`` and pickles a dict
    holding ``"TV"`` and ``"accuracy"`` to
    ``results/alpha<alpha>/<settings>_<random tag>_posterior.pkl``.

    Args:
        t: number of target words per document (``args.t_extra_words``).
        epochs: number of training epochs.
        hdim: hidden layer dimension.
        layers: number of network layers.
        alphas: iterable of alpha values, each forwarded to
            ``sim_data.make_data``. The default is a tuple to avoid a shared
            mutable default; lists are still accepted.
    """
    for alpha in alphas:
        ## Fit the model
        args = Arguments()
        args.temp_folder = "models"  ## Temporary folder to hold intermediate models
        _ensure_dir(args.temp_folder)

        args.data_path = "data/alpha" + str(alpha)  ## Folder to hold generated data
        _ensure_dir(args.data_path)

        ## Generate data
        topics, id2word = sim_data.make_data(alpha, args.data_path, n_docs=10)

        ## Folder to hold results
        args.results_folder = "results/alpha" + str(alpha)
        _ensure_dir(args.results_folder)

        ## NN parameters
        args.n_layers = layers  ## number of layers
        args.c_dim = 512  ## final layer dimension
        args.embed_dim = 512  ## embedding layer dimension
        args.h_dim = hdim  ## hidden layer dimension
        args.t_extra_words = t  ## the number of target words per document

        args.nepochs = epochs  ## number of epochs
        args.lr = 0.0002  ## learning rate
        args.dropout_p = 0.0  ## dropout
        args.resample = 2  ## frequency of resampling data
        args.opt_type = "rms"  ## optimizer

        ## Arguments for synthetic topic modeling
        # NOTE(review): meaning of the constant 36 is defined by SynthArgs — confirm there.
        args.synth_args = SynthArgs(topics, 36, id2word)

        model = generate_embeddings(args)

        ## Construct the posterior; only TV and accuracy are kept.
        _Eta, _L_data, TV, accuracy = posteriors.construct_posteriors(args.data_path, model)

        results = {
            ## Total variation distance between model-based posteriors and true posteriors
            "TV": TV,
            ## Accuracy of model-based MAP estimate on the generating topic
            "accuracy": accuracy,
        }

        ## Save out results
        input_names = "-".join([str(t) + 'words', 'hdim' + str(hdim), 'ly' + str(layers), str(epochs) + 'epochs'])
        fname = input_names + '_' + _random_tag() + "_posterior.pkl"
        print(fname)
        print('TV: ', TV)
        print('Accuracy', accuracy)
        with open(os.path.join(args.results_folder, fname), 'wb') as f:
            pickle.dump(results, f)


if __name__ == '__main__':
    # Command-line entry point: positional t/hdim/layers/epochs, plus an
    # optional alpha sweep whose default matches the previously hard-coded list.
    parser = argparse.ArgumentParser(
        description='Run the synthetic topic-model posterior experiment.')
    parser.add_argument('t', type=int, help='number of extra words per document')
    parser.add_argument('hdim', type=int, help='number of nodes per hidden layer')
    parser.add_argument('layers', type=int, help='number of hidden layers')
    parser.add_argument('epochs', type=int, help='number of epochs during training')
    parser.add_argument('--alphas', type=float, nargs='+',
                        default=[1.0, 3.0, 5.0, 7.0, 9.0],
                        help='alpha values to sweep (default: 1.0 3.0 5.0 7.0 9.0)')

    input_args = parser.parse_args()
    run(t=input_args.t,
        epochs=input_args.epochs,
        hdim=input_args.hdim,
        layers=input_args.layers,
        alphas=input_args.alphas)
