import numpy as np
import pickle, string, os, torch
from torch.utils.data import DataLoader
import torch.nn as nn
import sim_data, sampling
from models_datasets import PredictWordModel, PredictWordDataset, generate_batch
from torch.utils.tensorboard import SummaryWriter
from gensim.models.word2vec import Word2Vec

class SynthArgs:
    """Plain container for the settings used to sample synthetic documents.

    Attributes:
        topics: topic structure handed to the ``sampling`` module.
        lamb: value passed to the samplers as their ``lam`` argument.
        id2word: id-to-word mapping, stored for callers.
        num: number of training documents drawn per sampling round.
    """

    def __init__(self, topics, lamb, id2word):
        self.topics, self.lamb, self.id2word = topics, lamb, id2word
        # Fixed training-set size; the validation set size is chosen elsewhere.
        self.num = 60000

# Mini-batch size for the train/validation DataLoaders.
BATCH_SIZE = 512
# Use CUDA when PyTorch reports it is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def make_embeddings(documents, vocab_size, emb_size, window_size=10):
    """Pretrain word2vec on ``documents`` and return an embedding matrix indexed by word id.

    Args:
        documents: iterable of documents, each an iterable of word ids whose
            elements expose ``.item()`` (presumably tensor elements — confirm
            against the caller).
        vocab_size: number of rows in the returned matrix; row ``i`` is word id ``i``.
        emb_size: embedding dimensionality, passed to word2vec as ``vector_size``.
        window_size: word2vec context window.

    Returns:
        ``np.ndarray`` of shape ``(vocab_size, emb_size)``. Rows for ids absent
        from the trained word2vec vocabulary keep their uniform-random
        ``[0, 1)`` initialisation.
    """
    # word2vec operates on string tokens, so ids are stringified.
    corpus = [[str(word.item()) for word in doc] for doc in documents]
    print('Using word2vec to pretrain word embeddings...')
    w2v_model = Word2Vec(corpus, vector_size=emb_size, window=window_size)

    embeddings = np.random.rand(vocab_size, emb_size)
    # key_to_index is a dict, so membership is O(1); scanning the index_to_key
    # list for every id made this loop quadratic in the vocabulary size.
    key_to_index = w2v_model.wv.key_to_index
    missing = 0
    for i in range(vocab_size):
        key = str(i)
        if key in key_to_index:
            embeddings[i] = w2v_model.wv[key]
        else:
            missing += 1

    # Only warn when more than 1% of the vocabulary lacks a pretrained vector.
    if missing > int(vocab_size / 100):
        print("Warning: %d words don't have pretrained embeddings" % missing)
    return embeddings

def choose_optimizer(model, lr, opt_type, use_scheduler=True, w_decay=0.0):
    """Build an optimizer (and optionally a plateau LR scheduler) for ``model``.

    Only parameters with ``requires_grad`` set are optimised.

    Args:
        model: module whose trainable parameters are optimised.
        lr: learning rate.
        opt_type: one of "sgd", "amsgrad", "adagrad", "adadelta", "rms";
            any other value falls back to plain Adam.
        use_scheduler: when true, also build a ``ReduceLROnPlateau`` scheduler
            (mode "min", patience 10, factor 0.5, min_lr 1e-7).
        w_decay: weight decay passed to the optimizer.

    Returns:
        ``(optimizer, scheduler)`` when ``use_scheduler`` is true, otherwise
        just the optimizer.
    """
    params = [p for p in model.parameters() if p.requires_grad]

    # NOTE(review): momentum=0.009 is unusually small (0.9 is the common
    # choice); kept as-is, confirm it is intentional.
    if opt_type == "sgd":
        optimizer = torch.optim.SGD(params, lr=lr, momentum=0.009, weight_decay=w_decay)
    elif opt_type == "amsgrad":
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=w_decay, amsgrad=True)
    elif opt_type == "adagrad":
        optimizer = torch.optim.Adagrad(params, lr=lr, weight_decay=w_decay)
    elif opt_type == "adadelta":
        # Previously built Adagrad here by mistake.
        optimizer = torch.optim.Adadelta(params, lr=lr, weight_decay=w_decay)
    elif opt_type == "rms":
        optimizer = torch.optim.RMSprop(params, lr=lr, momentum=0.009, weight_decay=w_decay)
    else:
        # Default (including "adam" and unrecognised names): plain Adam.
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=w_decay)

    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.5,
                                                               min_lr=1e-7)
        return (optimizer, scheduler)
    return optimizer

def choose_model(vocab_size, embed_dim, h_dim, dropout_p, n_layers, vectors, optimize_emb, emb_type, w2v_window, pretrain_docs):
    """Build a PredictWordModel whose embedding initialisation depends on ``emb_type``.

    Args:
        vocab_size, embed_dim, h_dim, dropout_p, n_layers: model hyper-parameters.
        vectors: embedding matrix used when ``emb_type`` is neither 'one-hot'
            nor 'word2vec' (may be None).
        optimize_emb: forwarded as ``optimize_embedding``.
        emb_type:
            'one-hot'  -> identity embedding of size vocab_size x vocab_size;
                          ``embed_dim`` and ``vectors`` are ignored.
            'word2vec' -> embeddings pretrained by ``make_embeddings`` on
                          ``pretrain_docs`` with window ``w2v_window``;
                          ``vectors`` is ignored.
            otherwise  -> ``vectors`` as supplied.
        w2v_window: word2vec context window (only for 'word2vec').
        pretrain_docs: documents for word2vec pretraining (only for 'word2vec').

    Returns:
        A freshly constructed ``PredictWordModel``.
    """
    # Only one model is constructed; previously a default model was always
    # built first and then thrown away for the 'one-hot'/'word2vec' cases.
    if emb_type == 'one-hot':
        # Identity rows give each word its own one-hot embedding vector.
        one_hot_mat = torch.eye(vocab_size, dtype=torch.float32)
        return PredictWordModel(vocab_size=vocab_size, embed_dim=vocab_size, h_dim=h_dim, dropout_p=dropout_p,
                                n_layers=n_layers, vectors=one_hot_mat, optimize_embedding=optimize_emb)
    if emb_type == 'word2vec':
        w2v_emb = make_embeddings(pretrain_docs, vocab_size, embed_dim, w2v_window)
        w2v_emb = torch.tensor(w2v_emb, dtype=torch.float32)
        return PredictWordModel(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim, dropout_p=dropout_p,
                                n_layers=n_layers, vectors=w2v_emb, optimize_embedding=optimize_emb)
    return PredictWordModel(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim, dropout_p=dropout_p,
                            n_layers=n_layers, vectors=vectors, optimize_embedding=optimize_emb)

def _build_loaders(raw_train_dataset, raw_valid_dataset):
    """Wrap raw (train, valid) samples in PredictWordDatasets and DataLoaders.

    Each raw dataset is indexed positionally (``[0]``, ``[1]``, ``[2]``) into
    ``PredictWordDataset``. Returns ``(train_loader, valid_loader)``; only the
    training loader shuffles.
    """
    train_dataset = PredictWordDataset(raw_train_dataset[0], raw_train_dataset[1], raw_train_dataset[2])
    valid_dataset = PredictWordDataset(raw_valid_dataset[0], raw_valid_dataset[1], raw_valid_dataset[2])

    valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                              collate_fn=generate_batch)
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                              collate_fn=generate_batch, shuffle=True)
    return (train_loader, valid_loader)


def get_predictive_model(data_path, results_folder, c_dim, h_dim, nepochs, lr, embed_dim, opt_type, t_extra_words=1,
                         dropout_p=0.5, n_layers=3, nfolds=1, resample=0, save_freq=25, use_scheduler=True,
                         pretrained_vectors=False, temp_model_folder=None, prev_model_file=None,
                         synthetic_args=None, optimize_emb=True, emb_type=None, w2vwindow=10, pretrain_docs=None):
    """Train a PredictWordModel on synthetic documents and return the best checkpoint.

    Args:
        data_path: directory containing ``vocab.pkl`` (its length sets the
            vocabulary size) and, if ``pretrained_vectors``, ``skipEmbeddings.npy``.
        results_folder: created if it does not exist (not otherwise used here).
        c_dim: only used in the TensorBoard run comment.
        h_dim, embed_dim, dropout_p, n_layers, optimize_emb, emb_type, w2vwindow,
            pretrain_docs: forwarded to ``choose_model``.
        nepochs: number of training epochs.
        lr, opt_type, use_scheduler: forwarded to ``choose_optimizer``; with a
            scheduler, it is stepped on the validation loss every epoch.
        t_extra_words: when > 1, documents come from
            ``sampling.synthetic_lda_docs`` with ``twords=t_extra_words``;
            otherwise from ``sampling.synthetic_documents``.
        nfolds, save_freq: currently unused.
        resample: if > 0, draw fresh train/valid documents every ``resample``
            epochs (never after the final epoch).
        pretrained_vectors: load ``skipEmbeddings.npy``; its shape overrides
            ``vocab_size`` and ``embed_dim``.
        temp_model_folder: directory for the temporary checkpoint (current
            directory when falsy).
        prev_model_file: if given, the freshly built model is replaced by
            ``torch.load(prev_model_file)``.
        synthetic_args: object with ``topics``, ``lamb`` and ``num`` attributes
            (e.g. ``SynthArgs``); required.

    Returns:
        The model reloaded from the checkpoint with the lowest validation loss
        seen (the initial model if no epoch improved on it).

    Raises:
        ValueError: if ``synthetic_args`` is None.
    """
    os.makedirs(results_folder, exist_ok=True)

    # NOTE(review): pickle.load can execute arbitrary code; only use trusted data_path.
    with open(os.path.join(data_path, 'vocab.pkl'), 'rb') as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)

    ## Load up pretrained vectors
    vectors = None
    if pretrained_vectors:
        pretrained_vectors_file = os.path.join(data_path, 'skipEmbeddings.npy')
        vectors = torch.from_numpy(np.load(pretrained_vectors_file)).float()
        # The saved matrix dictates both sizes, overriding the arguments.
        vocab_size, embed_dim = vectors.shape

    print("building model...")
    # Random 10-letter name, presumably so concurrent runs don't share a checkpoint.
    temp_file = "".join(list(np.random.choice(list(string.ascii_lowercase), 10)) + list("_model_temp.pt"))
    if temp_model_folder:
        temp_file = os.path.join(temp_model_folder, temp_file)

    def draw_loaders():
        """Sample fresh train/valid documents and return (train_loader, valid_loader)."""
        full_lda = t_extra_words > 1
        print("Sampling full LDA documents..." if full_lda else "Sampling documents...")
        if synthetic_args is None:
            # Synthetic sampling is the only implemented source; fail clearly
            # instead of with an UnboundLocalError on the raw datasets.
            raise ValueError("synthetic_args must be provided: synthetic sampling is the only supported "
                             "document source")
        if full_lda:
            raw_train_dataset = sampling.synthetic_lda_docs(synthetic_args.topics, twords=t_extra_words,
                                                            lam=synthetic_args.lamb,
                                                            n_docs=synthetic_args.num)
            raw_valid_dataset = sampling.synthetic_lda_docs(synthetic_args.topics, twords=t_extra_words,
                                                            lam=synthetic_args.lamb,
                                                            n_docs=3000)
        else:
            raw_train_dataset = sampling.synthetic_documents(synthetic_args.topics, lam=synthetic_args.lamb,
                                                             n_docs=synthetic_args.num)
            raw_valid_dataset = sampling.synthetic_documents(synthetic_args.topics, lam=synthetic_args.lamb,
                                                             n_docs=3000)
        return _build_loaders(raw_train_dataset, raw_valid_dataset)

    train_loader, valid_loader = draw_loaders()

    ## Get model
    model = choose_model(vocab_size, embed_dim, h_dim, dropout_p, n_layers, vectors,
                         optimize_emb, emb_type, w2vwindow, pretrain_docs)

    ## Load model if it exists
    if prev_model_file is not None:
        print(prev_model_file)
        # NOTE(review): this loads a whole pickled module; on PyTorch versions
        # where torch.load defaults to weights_only=True it may need
        # weights_only=False — confirm against the installed version.
        model = torch.load(prev_model_file)

    model = model.to(device)
    print(model)

    ## Loss function: summed cross-entropy, averaged over the t target columns.
    ce_loss = nn.CrossEntropyLoss(reduction='sum')

    def multiword_loss_fn(outputs, y):
        # Assumes y is 2-D (batch, t) — TODO confirm; the callers' y.squeeze()
        # would drop a dimension when t == 1 or the batch has one example.
        loss = 0
        t = y.shape[1]
        for i in range(t):
            loss += ce_loss(outputs, y[:, i])
        return loss / t

    loss_fn = multiword_loss_fn

    ## Scheduler/optimizer
    if use_scheduler:
        optimizer, scheduler = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler)
    else:
        optimizer = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler)

    def validation_loss():
        """Return the summed validation loss divided by the number of validation examples."""
        model.eval()
        total_loss = 0.0
        total_examples = 0.0
        with torch.no_grad():
            for text, offset, y in valid_loader:
                n_examples = len(offset)
                outputs = model(text.to(device), offset.to(device)).squeeze(1).cpu()
                curr_loss = loss_fn(outputs, y.squeeze())
                total_loss += curr_loss.item()
                total_examples += n_examples
        return total_loss / total_examples

    def train_step():
        """Run one epoch over train_loader; return summed loss per training example."""
        CLIP_NORM = 25.0
        model.train()
        train_loss = 0.0
        total_examples = 0.0
        for text, offset, y in train_loader:
            optimizer.zero_grad()
            model.zero_grad()

            n_examples = len(offset)
            predictions = model(text.to(device), offset.to(device)).squeeze(1)
            loss = loss_fn(predictions, y.squeeze().to(device))

            train_loss += loss.item()
            total_examples += n_examples

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            optimizer.step()

        return (train_loss / total_examples)

    writer = None
    # Everything that touches the temp checkpoint or the TensorBoard writer is
    # guarded so both are cleaned up even if training raises.
    try:
        # Seed the checkpoint with the initial model so there is always one to reload.
        torch.save(model, temp_file)

        valid_loss_list = []
        valid_loss = validation_loss()
        best_valid_loss = valid_loss

        model_param_string = "n_layers_" + str(n_layers) + "_emb_dim_" + str(embed_dim) + "_c_dim_" + str(
            c_dim) + "_h_dim_" + str(h_dim)
        writer = SummaryWriter(comment=model_param_string)
        writer.add_scalar('Loss/valid', valid_loss, 0)

        print("Validation loss: ", valid_loss)
        valid_loss_list.append(valid_loss)

        for epoch in range(1, nepochs + 1):
            train_loss = train_step()

            valid_loss = validation_loss()
            valid_loss_list.append(valid_loss)
            print('epoch', epoch, 'valid loss:', round(valid_loss, 6))

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/valid', valid_loss, epoch)
            if use_scheduler:
                scheduler.step(valid_loss)
                writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], epoch)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                ## Save out the current model
                torch.save(model, temp_file)

            if (resample > 0) and (epoch < nepochs) and ((epoch % resample) == 0):
                # train_step/validation_loss read these names from this scope,
                # so rebinding them switches both to the fresh samples.
                train_loader, valid_loader = draw_loaders()

        print('final lr:', optimizer.param_groups[0]['lr'])

        # NOTE(review): see the torch.load note above regarding weights_only.
        model = torch.load(temp_file)
    finally:
        if writer is not None:
            writer.close()
        if os.path.exists(temp_file):
            os.remove(temp_file)
    return (model)
