import numpy as np
import string, os, torch
from torch.utils.data import DataLoader
import torch.nn as nn
import sim_data, sampling
from models_datasets import PredictWordModel, PredictWordDataset, generate_batch
from bow_data import get_data
from torch.utils.tensorboard import SummaryWriter

class SynthArgs:
    """Settings bundle describing how synthetic documents are generated.

    Attributes:
        topics: forwarded as the first argument to the ``sampling`` document
            generators.
        lamb: forwarded to the generators as their ``lam`` keyword.
        id2word: stored for callers; not read by the code in this module.
        num: number of training documents drawn per sampling pass
            (fixed at 60000).
    """

    def __init__(self, topics, lamb, id2word):
        # Assigned left-to-right, so attribute order matches construction order.
        self.topics, self.lamb, self.id2word = topics, lamb, id2word
        self.num = 60000

# Mini-batch size used by every DataLoader built in this module.
BATCH_SIZE = 512
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def choose_optimizer(model, lr, opt_type, use_scheduler=True, w_decay=0.0):
    """Build an optimizer (and optionally an LR scheduler) for ``model``.

    Only parameters with ``requires_grad`` set are handed to the optimizer.

    Args:
        model: module whose trainable parameters are optimized.
        lr: learning rate passed to the optimizer.
        opt_type: one of ``"sgd"``, ``"amsgrad"``, ``"adagrad"``,
            ``"adadelta"``, ``"rms"``; any other value selects plain Adam.
        use_scheduler: if True, also build a ``ReduceLROnPlateau`` scheduler
            (mode='min', patience=10, factor=0.5, min_lr=1e-7).
        w_decay: weight decay passed to every optimizer.

    Returns:
        ``(optimizer, scheduler)`` when ``use_scheduler`` is True, otherwise
        just ``optimizer``.
    """
    # Materialize once: a bare ``filter`` iterator can only be consumed once.
    params = [p for p in model.parameters() if p.requires_grad]

    if opt_type == "sgd":
        # NOTE(review): momentum=0.009 is unusually small (0.9 is typical) — confirm intent.
        optimizer = torch.optim.SGD(params, lr=lr, momentum=0.009, weight_decay=w_decay)
    elif opt_type == "amsgrad":
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=w_decay, amsgrad=True)
    elif opt_type == "adagrad":
        optimizer = torch.optim.Adagrad(params, lr=lr, weight_decay=w_decay)
    elif opt_type == "adadelta":
        # Must be Adadelta; this branch used to build Adagrad by mistake.
        optimizer = torch.optim.Adadelta(params, lr=lr, weight_decay=w_decay)
    elif opt_type == "rms":
        optimizer = torch.optim.RMSprop(params, lr=lr, momentum=0.009, weight_decay=w_decay)
    else:
        # Default / fallback for unrecognized names: plain Adam.
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=w_decay)

    if not use_scheduler:
        return optimizer
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=10, factor=0.5, min_lr=1e-7)
    return optimizer, scheduler


def _load_full_model(path):
    """Load a whole pickled ``nn.Module`` written by ``torch.save(model, path)``.

    torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses
    to unpickle module objects, so ``weights_only=False`` is requested
    explicitly. torch < 1.13 does not know that argument and raises
    ``TypeError``; fall back to the plain call there.

    SECURITY: this unpickles arbitrary objects — only pass trusted files.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        return torch.load(path)


def _make_loaders(raw_train_dataset, raw_valid_dataset):
    """Wrap raw train/valid samples in ``PredictWordDataset``s and ``DataLoader``s.

    Each raw dataset is indexed as a 3-item sequence whose items are passed
    positionally to ``PredictWordDataset``. Returns ``(train_loader,
    valid_loader)``; only the training loader shuffles.
    """
    train_dataset = PredictWordDataset(raw_train_dataset[0], raw_train_dataset[1], raw_train_dataset[2])
    valid_dataset = PredictWordDataset(raw_valid_dataset[0], raw_valid_dataset[1], raw_valid_dataset[2])

    valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                              collate_fn=generate_batch)
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                              collate_fn=generate_batch, shuffle=True)
    return train_loader, valid_loader


def _sample_loaders(synthetic_args, t_extra_words):
    """Draw fresh synthetic train/valid documents and return their loaders.

    Uses ``sampling.synthetic_multitarget_docs`` (with ``twords=t_extra_words``)
    when ``t_extra_words > 1`` and ``sampling.synthetic_documents`` otherwise.
    The training set has ``synthetic_args.num`` documents and the validation
    set has 3000. ``synthetic_args`` must not be None.
    """
    if t_extra_words > 1:
        print("Sampling multiple extra word documents...")
        raw_train_dataset = sampling.synthetic_multitarget_docs(synthetic_args.topics, twords=t_extra_words,
                                                                lam=synthetic_args.lamb, n_docs=synthetic_args.num)
        raw_valid_dataset = sampling.synthetic_multitarget_docs(synthetic_args.topics, twords=t_extra_words,
                                                                lam=synthetic_args.lamb, n_docs=3000)
    else:
        print("Sampling documents...")
        raw_train_dataset = sampling.synthetic_documents(synthetic_args.topics, lam=synthetic_args.lamb,
                                                         n_docs=synthetic_args.num)
        raw_valid_dataset = sampling.synthetic_documents(synthetic_args.topics, lam=synthetic_args.lamb,
                                                         n_docs=3000)
    return _make_loaders(raw_train_dataset, raw_valid_dataset)


def get_predictive_model(data_path, results_folder, c_dim, h_dim, nepochs, lr, embed_dim, opt_type, t_extra_words=1,
                               dropout_p=0.5, n_layers=3, nfolds=1, resample=0, save_freq=25, use_scheduler=True,
                               pretrained_vectors=False, temp_model_folder=None, prev_model_file=None,
                               presampled_docs_file=None, synthetic_args=None):
    """Train a ``PredictWordModel`` on synthetic documents and return the best model.

    The model is checkpointed to a temporary file whenever the validation loss
    improves. The best checkpoint (or the initial weights if validation never
    improved) is restored before returning. Losses and the learning rate are
    logged to TensorBoard.

    Args:
        data_path: directory passed to ``get_data``. It must also contain
            ``skipEmbeddings.npy`` when ``pretrained_vectors`` is True.
        results_folder: created if missing; not otherwise used here.
        c_dim: only used in the TensorBoard run comment.
        h_dim, embed_dim, dropout_p, n_layers: forwarded to ``PredictWordModel``.
            ``embed_dim`` is overridden by the pretrained vectors' width.
        nepochs: number of training epochs.
        lr, opt_type: forwarded to ``choose_optimizer``.
        t_extra_words: target words per document. Values > 1 use the
            multi-target sampler.
        nfolds, save_freq, presampled_docs_file: accepted for compatibility;
            unused.
        resample: if > 0, redraw the datasets every ``resample`` epochs
            (never after the final epoch).
        use_scheduler: step a ``ReduceLROnPlateau`` scheduler on validation loss.
        pretrained_vectors: initialize embeddings from ``skipEmbeddings.npy``.
        temp_model_folder: directory for the temporary checkpoint. Without it,
            the checkpoint is written to the current working directory.
        prev_model_file: a whole pickled model (``torch.save(model, ...)``) to
            start from instead of a freshly built one.
        synthetic_args: object with ``topics``, ``lamb`` and ``num`` attributes
            (e.g. ``SynthArgs``).

    Returns:
        The trained model on ``device``, holding the best-validation weights.

    Raises:
        ValueError: if ``synthetic_args`` is None. Only synthetic sampling is
            supported.
    """
    if synthetic_args is None:
        # Fail fast instead of an UnboundLocalError deep inside the sampler.
        raise ValueError("synthetic_args is required: only synthetic document sampling is supported")

    os.makedirs(results_folder, exist_ok=True)

    vocab, _train, _test, _unsup, _valid = get_data(data_path)
    vocab_size = len(vocab)

    ## Load up pretrained vectors
    vectors = None
    if pretrained_vectors:
        pretrained_vectors_file = os.path.join(data_path, 'skipEmbeddings.npy')
        vectors = torch.from_numpy(np.load(pretrained_vectors_file)).float()
        vocab_size, embed_dim = vectors.shape

    print("building model...")
    # Random checkpoint name. It is drawn from np.random on purpose, so the
    # global NumPy RNG stream seen by the samplers is unchanged.
    temp_file = "".join(list(np.random.choice(list(string.ascii_lowercase), 10)) + list("_model_temp.pt"))
    if temp_model_folder:
        temp_file = os.path.join(temp_model_folder, temp_file)

    train_loader, valid_loader = _sample_loaders(synthetic_args, t_extra_words)

    ## Get model
    model = PredictWordModel(vocab_size=vocab_size, embed_dim=embed_dim, h_dim=h_dim, dropout_p=dropout_p,
                             n_layers=n_layers, vectors=vectors)
    if prev_model_file is not None:
        print(prev_model_file)
        model = _load_full_model(prev_model_file)
    model = model.to(device)
    print(model)

    ## Loss function
    ce_loss = nn.CrossEntropyLoss()

    def loss_fn(outputs, y):
        """Mean cross-entropy of ``outputs`` against every target column of ``y``.

        ``y`` is reshaped to ``(batch, t)`` using the batch size of ``outputs``.
        This makes 1-D targets, ``(batch, 1)`` targets and size-1 batches work;
        ``y.squeeze()`` + ``y.shape[1]`` crashed on all of them.
        """
        y = y.reshape(outputs.shape[0], -1)
        t = y.shape[1]
        loss = 0
        for i in range(t):
            loss += ce_loss(outputs, y[:, i])
        return loss / t

    ## Scheduler/optimizer
    scheduler = None
    if use_scheduler:
        optimizer, scheduler = choose_optimizer(model, lr, opt_type, use_scheduler=True)
    else:
        optimizer = choose_optimizer(model, lr, opt_type, use_scheduler=False)

    def validation_loss(loader):
        """Per-example mean loss of ``model`` over ``loader`` (eval mode, no grads)."""
        model.eval()
        total_loss = 0.0
        total_examples = 0
        with torch.no_grad():
            for text, offset, y in loader:
                outputs = model(text.to(device), offset.to(device)).squeeze(1).cpu()
                n_examples = outputs.shape[0]
                # loss_fn averages over the batch; re-weight by batch size so the
                # final division yields a true per-example mean.
                total_loss += loss_fn(outputs, y).item() * n_examples
                total_examples += n_examples
        return total_loss / total_examples if total_examples else float("nan")

    def train_step(loader):
        """Run one training epoch over ``loader``; return the per-example mean loss."""
        CLIP_NORM = 25.0
        model.train()
        train_loss = 0.0
        total_examples = 0
        for text, offset, y in loader:
            optimizer.zero_grad()
            model.zero_grad()

            predictions = model(text.to(device), offset.to(device)).squeeze(1)
            loss = loss_fn(predictions, y.to(device))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            optimizer.step()

            n_examples = predictions.shape[0]
            train_loss += loss.item() * n_examples
            total_examples += n_examples
        return train_loss / total_examples if total_examples else float("nan")

    writer = None
    try:
        # A state_dict checkpoint (unlike a pickled module) reloads under
        # torch's weights_only loading. The initial snapshot guarantees there
        # is always something to restore.
        torch.save(model.state_dict(), temp_file)

        valid_loss = validation_loss(valid_loader)
        best_valid_loss = valid_loss

        model_param_string = "n_layers_" + str(n_layers) + "_emb_dim_" + str(embed_dim) + "_c_dim_" + str(
            c_dim) + "_h_dim_" + str(h_dim)
        writer = SummaryWriter(comment=model_param_string)
        writer.add_scalar('Loss/valid', valid_loss, 0)
        print("Validation loss: ", valid_loss)

        for epoch in range(1, nepochs + 1):
            train_loss = train_step(train_loader)
            valid_loss = validation_loss(valid_loader)

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/valid', valid_loss, epoch)
            if scheduler is not None:
                scheduler.step(valid_loss)
                writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], epoch)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), temp_file)

            if resample > 0 and epoch < nepochs and epoch % resample == 0:
                train_loader, valid_loader = _sample_loaders(synthetic_args, t_extra_words)

        # Restore the best-validation weights into the live model.
        model.load_state_dict(torch.load(temp_file, map_location=device))
    finally:
        # Always release the TensorBoard writer and the temporary checkpoint,
        # even if training raised.
        if writer is not None:
            writer.close()
        if os.path.exists(temp_file):
            os.remove(temp_file)

    return model

