import numpy as np
import csv, os, string, pickle
import train, posteriors, sim_data, train_resample, train_two_targets
import torch
from train import SynthArgs
import argparse


class Arguments():
    """Plain attribute bag holding the training settings read by ``generate_embeddings``.

    Every field starts at a default value; ``run`` overwrites the ones it
    needs before training.
    """

    def __init__(self):
        ## Input/output locations (filled in by the caller)
        self.data_path = self.results_folder = self.temp_folder = self.docs_folder = None
        self.pretrained_vectors = False
        self.prev_model_file = None

        ## Learning parameters
        self.opt_type, self.t_extra_words, self.lr = "rms", 1, 0.0001
        self.nepochs, self.dropout_p = 50, 0.5
        self.nfolds, self.resample = 1, 0

        ## Network structure params
        self.model_type, self.n_layers = 'default', 3
        self.embed_dim, self.c_dim, self.h_dim = 300, 100, 256
        self.optimize_embeddings, self.embedding_type = True, None

        ## Synthetic experiment arguments
        self.synth_args = None


def generate_embeddings(args, reuse_train, two_targets):
    """Train a predictive model with the trainer module selected by the flags.

    Trainer selection (``reuse_train`` takes precedence over ``two_targets``):
      * ``reuse_train`` true        -> ``train_resample``
      * else ``two_targets`` true   -> ``train_two_targets`` (prints "Two targets")
      * otherwise                   -> ``train``

    All three modules are called through the same ``get_predictive_model``
    keyword interface, so the keyword arguments are built in one place.
    This keeps the branches from drifting apart.

    Args:
        args: an ``Arguments``-like object; its data/result/temp paths,
            network-structure and learning hyperparameters, and
            ``synth_args`` are forwarded as keyword arguments.
        reuse_train: select the ``train_resample`` trainer.
        two_targets: select ``train_two_targets`` when ``reuse_train`` is false.

    Returns:
        Whatever the selected module's ``get_predictive_model`` returns
        (``run`` later passes it to ``torch.save`` and the posterior code).
    """
    if reuse_train:
        trainer = train_resample
    elif two_targets:
        print("Two targets")
        trainer = train_two_targets
    else:
        trainer = train

    model = trainer.get_predictive_model(data_path=args.data_path,
                                         results_folder=args.results_folder,
                                         model_type=args.model_type,
                                         c_dim=args.c_dim,
                                         h_dim=args.h_dim,
                                         nepochs=args.nepochs,
                                         lr=args.lr,
                                         embed_dim=args.embed_dim,
                                         opt_type=args.opt_type,
                                         t_extra_words=args.t_extra_words,
                                         dropout_p=args.dropout_p,
                                         n_layers=args.n_layers,
                                         resample=args.resample,
                                         pretrained_vectors=args.pretrained_vectors,
                                         temp_model_folder=args.temp_folder,
                                         prev_model_file=args.prev_model_file,
                                         presampled_docs_file=args.docs_folder,
                                         synthetic_args=args.synth_args,
                                         optimize_emb=args.optimize_embeddings,
                                         emb_type=args.embedding_type,
                                         )
    return model


def _build_cov_matrix(n_topics, cov, negative_cov, scale=15):
    """Return the scaled, block-structured topic covariance matrix.

    Topics are grouped in consecutive blocks of four (indices 4i..4i+3).
    Within each block:
      * pairs (4i, 4i+2) and (4i+1, 4i+3) get covariance ``cov``;
      * pairs (4i, 4i+3) and (4i+1, 4i+2) get ``negative_cov``.
    The diagonal is 1, and the whole matrix is multiplied by ``scale``.
    If ``n_topics`` is not a multiple of four, the trailing topics keep
    only their diagonal entry.
    """
    sigma = np.eye(n_topics)
    for block in range(n_topics // 4):
        a, b, c, d = (4 * block + offset for offset in range(4))
        sigma[a, c] = sigma[c, a] = cov
        sigma[b, d] = sigma[d, b] = cov
        sigma[a, d] = sigma[d, a] = negative_cov
        sigma[b, c] = sigma[c, b] = negative_cov
    return sigma * scale


def run(extra_words=1, epochs=50, hdim=512, layers=3, alphas=(5.0,), N_test=50, save_model=True, constrained_opt=False,
        gen_new_tests=False, reuse_train_data=False, two_targets=False, lamb=60):
    """Train and evaluate one model per value in ``alphas`` on synthetic topic data.

    The function first writes ``cov_matrix.npy``. Then, for each alpha, it:
      1. generates data with ``sim_data.make_data`` under
         ``data/alpha{alpha}_{lamb}wordsdoc``;
      2. trains a model via ``generate_embeddings``;
      3. optionally saves the model under ``savedmodels/``;
      4. builds posteriors with ``posteriors.construct_posteriors``;
      5. pickles ``{"TV", "accuracy", "N_test"}`` into
         ``results/alpha{alpha}_{lamb}wordsdoc``.

    Args:
        extra_words: stored as ``t_extra_words`` (number of times the t target
            word(s) are sampled per document).
        epochs: number of training epochs; the resampling frequency is set to
            ``int(epochs / 2)``.
        hdim: hidden layer width.
        layers: number of layers.
        alphas: iterable of alpha values; one full train/evaluate cycle each.
            The default is a tuple, which avoids a shared mutable default.
        N_test: number of test documents (forwarded to data generation and
            posterior construction).
        save_model: if true, ``torch.save`` the trained model.
        constrained_opt: forwarded to ``posteriors.construct_posteriors``.
        gen_new_tests: forwarded to ``sim_data.make_data`` and as
            ``re_calculate`` to ``posteriors.construct_posteriors``.
        reuse_train_data: train with ``train_resample`` instead of fresh data.
        two_targets: use t=2 target words instead of t=1.
        lamb: average document length; also used in the data/results folder
            names.
    """
    model_type = 'residual-softmax'

    # Only the number of topics K (the middle axis) is needed here.
    _, n_topics, _ = np.load('TopicMatrices.npy').shape
    sigma = _build_cov_matrix(n_topics, cov=0.99, negative_cov=0, scale=15)
    # NOTE(review): not read again in this function; presumably consumed by
    # sim_data when generating documents -- confirm.
    np.save('cov_matrix.npy', sigma)

    for alpha in alphas:
        run_name = "alpha" + str(alpha) + '_%dwordsdoc' % lamb

        ## Fit the model
        args = Arguments()
        args.temp_folder = "models"  ## Temporary folder to hold intermediate model weights
        os.makedirs(args.temp_folder, exist_ok=True)

        args.data_path = "data/" + run_name  ## Folder to hold synthesized data
        os.makedirs(args.data_path, exist_ok=True)

        ## Generate data
        topics, id2word = sim_data.make_data(alpha, args.data_path, N_test=N_test, lamb=lamb, gen_new_tests=gen_new_tests)

        ## Folder to hold results
        args.results_folder = "results/" + run_name
        os.makedirs(args.results_folder, exist_ok=True)

        ## NN parameters
        args.model_type = model_type
        args.n_layers = layers  ## number of layers
        args.c_dim = 512  ## final layer dimension
        args.embed_dim = 5000  ## embedding layer dimension
        args.h_dim = hdim  ## hidden layer dimension
        args.t_extra_words = extra_words  ## the number of times we sample t target word(s) per document
        args.optimize_embeddings = True  ## optimize word embeddings or not
        args.embedding_type = 'word2vec'  ## initial embedding type

        args.nepochs = epochs  ## number of epochs
        args.lr = 0.0001  ## learning rate
        args.dropout_p = 0.0  ## dropout
        args.resample = int(epochs / 2)  ## frequency of resampling data
        args.opt_type = "amsgrad"  ## optimizer

        ## Arguments for synthetic topic modeling
        args.synth_args = SynthArgs(topics, lamb, id2word)

        ## Train neural network model
        model = generate_embeddings(args, reuse_train_data, two_targets=two_targets)

        # Random 10-letter tag that keeps output filenames of repeated runs apart.
        # NOTE(review): drawn from NumPy's global RNG; if anything upstream seeds
        # it with a fixed value, tags repeat across runs and results get
        # overwritten -- confirm.
        rand_tag = ''.join(np.random.choice(list(string.ascii_lowercase), 10))
        input_name = f't={2 if two_targets else 1}_width={hdim}_depth={layers}_{epochs}epochs'
        if save_model:
            # NOTE(review): the 'finite' marker only reaches the results filename
            # when save_model is set; kept as-is for filename compatibility.
            if reuse_train_data:
                rand_tag = rand_tag + 'finite'
            os.makedirs('savedmodels', exist_ok=True)
            torch.save(model, f'savedmodels/alpha{alpha}_{input_name}_{rand_tag}.pt')

        ## Construct the posterior
        Eta, L_data, TV, accuracy = posteriors.construct_posteriors(args.data_path, model, N_test, constrained_opt=constrained_opt,
                                                                    two_targets=two_targets, re_calculate=gen_new_tests)

        results = {
            "TV": TV,  ## Total variation distance between model-based posteriors and true posteriors
            "accuracy": accuracy,  ## Accuracy of model-based MAP estimate on the generating topic
            "N_test": N_test,
        }

        ## Save out results
        fname = input_name + '_' + rand_tag + "_posterior.pkl"
        print(fname)
        print('TV: ', TV)
        with open(os.path.join(args.results_folder, fname), 'wb') as f:
            pickle.dump(results, f)


if __name__ == '__main__':
    # Command-line entry point: parse the hyperparameters, echo them, then
    # sweep `run` over a fixed list of alpha values.
    parser = argparse.ArgumentParser()

    # Integer-valued options, registered in their original order so that the
    # --help listing is unchanged.
    int_options = (
        ('--tests', 'number of test documents, default=200', 200),
        ('--times', 'number of times we sample t target word(s) per document, default=6', 6),
        ('--hdim', 'number of neurons per hidden layer in neural network, default=768', 768),
        ('--layers', 'number of hidden layers in neural network, default=8', 8),
        ('--epochs', 'number of epochs during training, default=200', 200),
        ('--doclength', 'average length of synthesized documents, default=60', 60),
    )
    for flag, help_text, default in int_options:
        parser.add_argument(flag, type=int, help=help_text, default=default)

    # Boolean switches: False unless the flag is passed.
    bool_options = (
        ('--save_model', 'add this flag to save neural network weight'),
        ('--constrained_opt', 'add this flag to use constrained optimization to find topic posterior'),
        ('--reuse_train_data', 'add this flag to reuse training data (e.g. when the number of training data is limited)'),
        ('--gen_new_tests', 'add this flag to generate new test documents that overwrite previously-generated test documents'),
    )
    for flag, help_text in bool_options:
        parser.add_argument(flag, action='store_true', help=help_text)

    cli = parser.parse_args()
    test, times, num_epochs = cli.tests, cli.times, cli.epochs
    hdim, num_layers, lamb = cli.hdim, cli.layers, cli.doclength
    save_model, constrained_opt = cli.save_model, cli.constrained_opt
    reuse_train_data, gen_new_tests = cli.reuse_train_data, cli.gen_new_tests
    # t is fixed at 2 (two target words); there is no command-line switch for
    # it, since the former --two_targets option is disabled.
    two_targets = True

    print(
        f'''
        Training Hyperparameters: 

        number of test documents = {test}
        t = {2 if two_targets else 1}
        number of times t target words are sampled for each training document = {times}
        neural network width = {hdim}
        neural network depth = {num_layers}
        number of training epochs = {num_epochs} (total number of training documents = {60000 * (num_epochs//2)})
        average document length = {lamb}
        save neural network model at the end? {save_model}
        using constrained optimization? {constrained_opt}
        reuse training data? {reuse_train_data}
        generate new test documents? {gen_new_tests}
        '''
    )

    alpha_sweep = [1.0, 3.0, 5.0, 7.0, 9.0]
    run(extra_words=times, epochs=num_epochs, hdim=hdim, layers=num_layers,
        alphas=alpha_sweep, N_test=test, save_model=save_model,
        constrained_opt=constrained_opt, gen_new_tests=gen_new_tests,
        reuse_train_data=reuse_train_data, two_targets=two_targets, lamb=lamb)