import numpy as np
import pickle, string, os, torch
from torch.utils.data import DataLoader
import torch.nn as nn
import sim_data, sampling
from models_datasets import PredictWordModel, PredictWordDataset, generate_batch, BaselineModel, SoftmaxBlockModel
from models_datasets import ResBlockModel, ResSoftmaxModel, AttentionModel, AttnCTMDataset, generate_attn_batch
from torch.utils.tensorboard import SummaryWriter
from gensim.models.word2vec import Word2Vec


class SynthArgs:
    """Settings bundle for sampling synthetic CTM documents.

    Attributes:
        topics: Topic definition forwarded to the ``sampling`` helpers.
        lamb: Value forwarded to the ``sampling`` helpers as ``lam``.
        id2word: Id-to-word mapping; stored on the instance for callers.
        num: Number of training documents to sample (passed as ``n_docs``).
    """

    def __init__(self, topics, lamb, id2word, num=60000):
        self.topics = topics
        self.lamb = lamb
        self.id2word = id2word
        # Defaults to the previously hard-coded corpus size.
        self.num = num


# Mini-batch size used for both the training and validation DataLoaders below.
BATCH_SIZE = 128  # 512 is noted by the author as an alternative value
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def choose_optimizer(model, lr, opt_type, use_scheduler=True, w_decay=0.0):
    """Build an optimizer (and optionally an LR scheduler) for ``model``.

    Only parameters with ``requires_grad=True`` are handed to the optimizer.

    Args:
        model: Module whose trainable parameters are optimized.
        lr: Learning rate passed to the optimizer.
        opt_type: One of ``"sgd"``, ``"amsgrad"``, ``"adagrad"``,
            ``"adadelta"`` or ``"rms"``. Any other value selects plain Adam.
        use_scheduler: When true, also create a ``ReduceLROnPlateau``
            scheduler (mode ``'min'``, patience 10, factor 0.5, min_lr 1e-7).
        w_decay: Weight decay passed to the optimizer.

    Returns:
        ``(optimizer, scheduler)`` when ``use_scheduler`` is true, otherwise
        just the optimizer.
    """
    # Materialize once so a single parameter list feeds the chosen optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]

    if opt_type == "sgd":
        optimizer = torch.optim.SGD(trainable, lr=lr, momentum=0.009, weight_decay=w_decay)
    elif opt_type == "amsgrad":
        optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=w_decay, amsgrad=True)
    elif opt_type == 'adagrad':
        optimizer = torch.optim.Adagrad(trainable, lr=lr, weight_decay=w_decay)
    elif opt_type == 'adadelta':
        # Must be Adadelta: this branch used to silently build Adagrad.
        optimizer = torch.optim.Adadelta(trainable, lr=lr, weight_decay=w_decay)
    elif opt_type == 'rms':
        optimizer = torch.optim.RMSprop(trainable, lr=lr, momentum=0.009, weight_decay=w_decay)
    else:
        # Unrecognized names fall back to plain Adam.
        optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=w_decay)

    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.5,
                                                               min_lr=1e-7)
        return (optimizer, scheduler)
    return optimizer


def choose_model(model_type, vocab_size, embed_dim, h_dim, dropout_p, n_layers, vectors, optimize_emb, emb_type):
    """Instantiate the model architecture selected by ``model_type``.

    Args:
        model_type: ``'default'``, ``'baseline'``, ``'block'``, ``'residual'``
            or ``'residual-softmax'``.
        vocab_size: Vocabulary size passed to every model.
        embed_dim: Embedding width (``'default'`` and ``'baseline'`` only).
        h_dim: Hidden width (all models except ``'baseline'``).
        dropout_p: Dropout probability (all models except ``'baseline'``).
        n_layers: Layer count (all models except ``'baseline'``).
        vectors: Embedding matrix for the ``'default'`` model; replaced by an
            identity matrix when ``emb_type == 'one-hot'``.
        optimize_emb: Forwarded as ``optimize_embedding`` to the ``'default'``
            model.
        emb_type: ``'one-hot'`` switches the ``'default'`` model to identity
            embeddings of width ``vocab_size``; other values are ignored.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == 'default':
        if emb_type == 'one-hot':
            # One-hot embeddings: identity matrix, so the width equals the vocabulary size.
            embed_dim = vocab_size
            vectors = torch.eye(vocab_size, dtype=torch.float32)
        return PredictWordModel(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim, dropout_p=dropout_p,
                                n_layers=n_layers, vectors=vectors, optimize_embedding=optimize_emb)
    if model_type == 'baseline':
        return BaselineModel(vocab_size, embed_dim)
    if model_type == 'block':
        return SoftmaxBlockModel(vocab_size, h_dim, dropout_p, n_layers)
    if model_type == 'residual':
        return ResBlockModel(vocab_size, h_dim, dropout_p, n_layers)
    if model_type == 'residual-softmax':
        return ResSoftmaxModel(vocab_size, h_dim, dropout_p, n_layers)

    raise ValueError(
        "Unknown model_type %r; expected one of 'default', 'baseline', 'block', "
        "'residual', 'residual-softmax'" % (model_type,))


def get_predictive_model(data_path, results_folder, model_type, c_dim, h_dim, nepochs, lr, embed_dim, opt_type,
                         t_extra_words=1, dropout_p=0.5, n_layers=3, resample=0, save_freq=25, use_scheduler=True,
                         pretrained_vectors=False, temp_model_folder=None, prev_model_file=None,
                         presampled_docs_file=None, synthetic_args=None, optimize_emb=True, emb_type=None,
                         save_model_weights=False):
    """Train an ``AttentionModel`` on synthetic CTM documents and return the best checkpoint.

    Documents are sampled with ``sampling.synthetic_ctm_two_targets`` using the
    covariance matrix loaded from ``cov_matrix.npy`` in the current working
    directory. After every epoch the model is written to a temporary file
    whenever the validation loss improves; that file is reloaded at the end.

    Args:
        data_path: Directory containing ``vocab.pkl``.
        results_folder: Directory created if missing (not otherwise used here).
        model_type: Currently unused; an ``AttentionModel`` is always built.
        c_dim: Only used to label the TensorBoard run.
        h_dim: Hidden size of the model and of the word2vec vectors.
        nepochs: Number of training epochs.
        lr: Learning rate.
        embed_dim: Only used to label the TensorBoard run.
        opt_type: Optimizer name understood by ``choose_optimizer``.
        t_extra_words: Passed to the samplers as ``twords``.
        dropout_p: Dropout probability of the model.
        n_layers: Number of layers of the model.
        resample: When > 0, fresh train/validation documents are sampled every
            ``resample`` epochs (except after the last epoch).
        save_freq: Currently unused.
        use_scheduler: Use ``ReduceLROnPlateau`` stepped on the validation loss.
        pretrained_vectors: Currently unused.
        temp_model_folder: Optional directory for the temporary checkpoint.
        prev_model_file: Optional pickled model to resume from.
        presampled_docs_file: Currently unused.
        synthetic_args: Required object with ``topics``, ``lamb`` and ``num``
            attributes (for example a ``SynthArgs``).
        optimize_emb: Currently unused.
        emb_type: ``'word2vec'`` pretrains the embeddings on sampled documents.
        save_model_weights: Currently unused.

    Returns:
        The model reloaded from the checkpoint with the lowest validation loss
        seen (the initial model if no epoch improved on it).

    Raises:
        ValueError: If ``synthetic_args`` is ``None``.
    """
    # Every data path below samples from the synthetic CTM, so fail fast with a
    # clear message instead of a late AttributeError / UnboundLocalError.
    if synthetic_args is None:
        raise ValueError("synthetic_args is required: training documents are sampled from the synthetic CTM")

    # makedirs also creates missing parent directories and tolerates an existing one.
    os.makedirs(results_folder, exist_ok=True)

    with open(os.path.join(data_path, 'vocab.pkl'), 'rb') as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)

    # NOTE: resolved relative to the current working directory, not data_path.
    sigma = np.load('cov_matrix.npy')

    ## Load up pretrained vectors
    w2v_embeddings = None
    if emb_type == 'word2vec':
        pretrain_doc, _ = sampling.synthetic_ctm_docs_attn(synthetic_args.topics, sigma, twords=t_extra_words,
                                                           lam=synthetic_args.lamb,
                                                           n_docs=synthetic_args.num)
        w2v_embeddings = get_pretrained_embeddings(pretrain_doc, vocab_size, h_dim, window_size=10)

    ## Build contrastive model
    print("building model...")
    # Random 10-letter prefix so concurrent runs do not clobber each other's checkpoint.
    temp_file = "".join(np.random.choice(list(string.ascii_lowercase), 10)) + "_model_temp.pt"
    if temp_model_folder:
        temp_file = os.path.join(temp_model_folder, temp_file)

    ## Define sampling function
    def sample_CTM_documents(t):
        """Sample fresh train/validation documents and wrap them in DataLoaders."""
        print("Sampling CTM documents...")
        raw_train_dataset = sampling.synthetic_ctm_two_targets(synthetic_args.topics, sigma, twords=t,
                                                               lam=synthetic_args.lamb,
                                                               n_docs=synthetic_args.num)
        raw_valid_dataset = sampling.synthetic_ctm_two_targets(synthetic_args.topics, sigma, twords=t,
                                                               lam=synthetic_args.lamb,
                                                               n_docs=3000)

        train_dataset = AttnCTMDataset(raw_train_dataset[0], raw_train_dataset[1])
        valid_dataset = AttnCTMDataset(raw_valid_dataset[0], raw_valid_dataset[1])

        valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                                  collate_fn=generate_attn_batch)
        train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                                  collate_fn=generate_attn_batch, shuffle=True)

        return (train_loader, valid_loader)

    train_loader, valid_loader = sample_CTM_documents(t_extra_words)

    ## Get model
    model = AttentionModel(vocab_size, h_dim, dropout_p, n_layers, w2v_embeddings, two_targets=True)

    ## Load model if it exists
    if prev_model_file is not None:
        print(prev_model_file)
        model = torch.load(prev_model_file)

    model = model.to(device)
    print(model)
    # Seed the checkpoint so there is always something to reload at the end.
    torch.save(model, temp_file)

    ## Loss function
    ce_loss = nn.CrossEntropyLoss(reduction='sum')

    def multiword_loss_fn(outputs, y):
        """Sum-reduced cross-entropy against each target column, averaged over the columns."""
        loss = 0
        t = y.size(1)
        for i in range(t):
            loss += ce_loss(outputs, y[:, i])
        return loss / t

    loss_fn = multiword_loss_fn

    ## Scheduler/optimizer
    # NOTE(review): weight decay (0.001) is only applied on the scheduler path - confirm this is intended.
    if use_scheduler:
        optimizer, scheduler = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler, w_decay=0.001)
    else:
        optimizer = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler)

    ## Validation loss
    def validation_loss():
        """Return the per-example loss over the current validation loader."""
        model.eval()
        total_loss = 0.0
        total_examples = 0.0
        with torch.no_grad():
            for text, y in valid_loader:
                outputs = model(text.to(device)).squeeze(1).cpu()

                ## Compute loss
                curr_loss = loss_fn(outputs, y.squeeze())

                total_loss += curr_loss.item()
                total_examples += len(text)
        return total_loss / total_examples

    ## Training step
    def train_step():
        """Run one epoch over the current training loader; return the per-example loss."""
        CLIP_NORM = 25.0
        model.train()
        train_loss = 0.0
        total_examples = 0.0
        for text, y in train_loader:
            optimizer.zero_grad()
            model.zero_grad()

            predictions = model(text.to(device)).squeeze(1)
            loss = loss_fn(predictions, y.squeeze().to(device))

            train_loss += loss.item()
            total_examples += len(text)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            optimizer.step()

        return (train_loss / total_examples)

    valid_loss = validation_loss()
    best_valid_loss = valid_loss

    model_param_string = "n_layers_" + str(n_layers) + "_emb_dim_" + str(embed_dim) + "_c_dim_" + str(
        c_dim) + "_h_dim_" + str(h_dim)
    writer = SummaryWriter(comment=model_param_string)
    try:
        writer.add_scalar('Loss/valid', valid_loss, 0)

        print("Validation loss: ", valid_loss)

        for epoch in range(1, nepochs + 1):
            train_loss = train_step()

            valid_loss = validation_loss()
            print('epoch', epoch, 'valid loss:', round(valid_loss, 6))

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/valid', valid_loss, epoch)
            if use_scheduler:
                scheduler.step(valid_loss)
                writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], epoch)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                ## Save out the current model
                torch.save(model, temp_file)

            if (resample > 0) and (epoch < nepochs) and ((epoch % resample) == 0):
                ## Resample documents
                train_loader, valid_loader = sample_CTM_documents(t_extra_words)
    finally:
        # Flush and release the TensorBoard event file even if training is interrupted.
        writer.close()

    print('final lr:', optimizer.param_groups[0]['lr'])

    # The checkpoint is only removed on success, so an interrupted run keeps
    # its best model on disk.
    model = torch.load(temp_file)
    os.remove(temp_file)
    return model


def get_pretrained_embeddings(documents, vocab_size, emb_size, window_size=10):
    """Pretrain word embeddings with word2vec and return them as a float32 tensor.

    Args:
        documents: Iterable of documents, each an iterable of word ids. Ids are
            converted to strings before training.
        vocab_size: Number of rows of the returned matrix; row ``i`` belongs to
            word id ``i``.
        emb_size: Embedding dimensionality.
        window_size: Word2vec context window.

    Returns:
        A ``(vocab_size, emb_size)`` ``torch.float32`` tensor. Rows for ids the
        trained model has no vector for (for example, words word2vec filtered
        out as too rare) keep their uniform-random initialization.
    """
    corpus = [[str(word) for word in doc] for doc in documents]
    print('Using word2vec to pretrain word embeddings...')
    w2v_model = Word2Vec(corpus, vector_size=emb_size, window=window_size)
    word_vectors = w2v_model.wv
    # Dict lookup is O(1) per word; scanning the index_to_key list would be O(vocabulary).
    known_words = word_vectors.key_to_index

    embeddings = np.random.rand(vocab_size, emb_size)
    missing = 0
    for i in range(vocab_size):
        key = str(i)
        if key in known_words:
            embeddings[i] = word_vectors[key]
        else:
            missing += 1

    # Only warn when more than 1% of the vocabulary has no pretrained vector.
    if missing > int(vocab_size / 100):
        print("Warning: %d words don't have pretrained embeddings" % missing)
    return torch.tensor(embeddings, dtype=torch.float32)