import numpy as np
import string, os, torch
from torch.utils.data import DataLoader
import torch.nn as nn
import contrastive_sampling
from models_datasets import SingleDataset, generate_single_batch, ContrastiveDataset, generate_contrast_batch, ReferenceDataset, generate_reference_batch, ContrastiveModel
from bow_data import get_data
from torch.utils.tensorboard import SummaryWriter


class SynthArgs:
    """Bundle of settings describing a synthetic corpus.

    Attributes set by the constructor:
        topics: stored as given by the caller.
        lamb: stored as given by the caller.
        id2word: stored as given by the caller (presumably an index -> word
            mapping, judging by the name -- confirm against callers).
        num: fixed at 60000 for every instance.
    """

    def __init__(self, topics, lamb, id2word):
        # Caller-supplied values are kept verbatim.
        self.topics, self.lamb, self.id2word = topics, lamb, id2word
        # Not configurable through the constructor.
        self.num = 60000

# Mini-batch size used by every DataLoader created in this module.
BATCH_SIZE = 512

# Device selection: prefer the second GPU ("cuda:1") as before, but only when
# it actually exists. On a single-GPU machine "cuda:1" is an invalid ordinal,
# so fall back to "cuda:0"; without CUDA, run on the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
else:
    device = torch.device("cuda:0")

def choose_optimizer(model, lr, opt_type, use_scheduler=True, w_decay=0.0):
    """Build an optimizer (and optionally an LR scheduler) for ``model``.

    Only parameters with ``requires_grad=True`` are handed to the optimizer.

    Args:
        model: module whose trainable parameters are optimized.
        lr: learning rate passed to the optimizer.
        opt_type: one of ``"sgd"``, ``"amsgrad"``, ``"adagrad"``,
            ``"adadelta"`` or ``"rms"``. Any other value selects plain Adam.
        use_scheduler: when true, also create a ``ReduceLROnPlateau``
            scheduler (mode ``'min'``, patience 10, factor 0.5, min_lr 1e-7).
        w_decay: weight decay passed to the optimizer.

    Returns:
        ``(optimizer, scheduler)`` when ``use_scheduler`` is true, otherwise
        just ``optimizer``.
    """
    # Materialize once so a single parameter list feeds whichever optimizer
    # gets selected (a filter() iterator could only be consumed one time).
    params = [p for p in model.parameters() if p.requires_grad]

    if opt_type == "sgd":
        optimizer = torch.optim.SGD(params, lr=lr, momentum=0.009, weight_decay=w_decay)
    elif opt_type == "amsgrad":
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=w_decay, amsgrad=True)
    elif opt_type == "adagrad":
        optimizer = torch.optim.Adagrad(params, lr=lr, weight_decay=w_decay)
    elif opt_type == "adadelta":
        # Previously this branch built Adagrad by mistake.
        optimizer = torch.optim.Adadelta(params, lr=lr, weight_decay=w_decay)
    elif opt_type == "rms":
        optimizer = torch.optim.RMSprop(params, lr=lr, momentum=0.009, weight_decay=w_decay)
    else:
        # Default for "adam" and for any unrecognized name.
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=w_decay)

    if not use_scheduler:
        return optimizer

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=10, factor=0.5, min_lr=1e-7
    )
    return (optimizer, scheduler)


def reference_embeddings(tokens, counts, ref_tokens, ref_counts, model, tol=0.05, take_exp=True):
    """Score every document against every reference document.

    Args:
        tokens, counts: the documents to embed (fed to ``ReferenceDataset``).
        ref_tokens, ref_counts: the reference ("landmark") documents.
        model: callable as ``model(text_1, offset_1, text_2, offset_2)``;
            it is moved to ``device`` and switched to eval mode here.
        tol: when > 0, scores are capped from above at ``(1 - tol) / tol``.
        take_exp: when true, ``torch.exp`` is applied to the model output
            before capping.

    Returns:
        A CPU tensor viewed as ``(len(tokens), len(ref_tokens))``.
        NOTE(review): the final ``view`` assumes ``ReferenceDataset`` yields
        one item per (document, reference) pair with the reference index
        varying fastest -- confirm against the dataset implementation.
    """
    n_docs, n_refs = len(tokens), len(ref_tokens)
    loader = DataLoader(
        dataset=ReferenceDataset(tokens, counts, ref_tokens, ref_counts),
        batch_size=BATCH_SIZE,
        num_workers=12,
        pin_memory=True,
        shuffle=False,  # order must be preserved for the reshape below
        collate_fn=generate_reference_batch,
    )

    batch_scores = []
    with torch.no_grad():
        model.to(device)
        model.eval()
        print("Making predictions")
        for text_1, offset_1, text_2, offset_2 in loader:
            scores = model(
                text_1.to(device), offset_1.to(device), text_2.to(device), offset_2.to(device)
            ).squeeze(1).float()

            if take_exp:
                scores = torch.exp(scores)

            if tol > 0:
                # Element-wise upper bound so no single score dominates.
                cap = ((1.0 - tol) / tol) * torch.ones_like(scores)
                scores = torch.min(scores, cap)

            batch_scores.append(scores.cpu())

        print("Tensorizing result")
        representation = torch.cat(batch_scores).view(n_docs, n_refs)

    return representation


    
def model_embeddings(tokens, counts, model):
    """Embed each document with ``model.get_embedding``.

    Args:
        tokens, counts: the documents to embed (fed to ``SingleDataset``).
        model: object exposing ``get_embedding(text, offset)``; it is moved to
            ``device`` and switched to eval mode here.

    Returns:
        A CPU tensor: the per-batch embeddings concatenated in dataset order.
    """
    loader = DataLoader(
        dataset=SingleDataset(tokens, counts),
        batch_size=BATCH_SIZE,
        num_workers=12,
        pin_memory=True,
        shuffle=False,  # keep rows aligned with the input documents
        collate_fn=generate_single_batch,
    )

    with torch.no_grad():
        model.to(device)
        model.eval()
        print("Getting embeddings")
        chunks = [
            model.get_embedding(text.to(device), offset.to(device)).cpu()
            for text, offset in loader
        ]

        print("Tensorizing result")
        representation = torch.cat(chunks)

    return representation


def contrastive_representation(data_path, results_folder, c_dim, h_dim, nepochs, lr, embed_dim, opt_type, landmarks=0, dropout_p=0.5, n_layers=3, nfolds=1, resample=0, save_freq=25, use_scheduler=True, pretrained_vectors=False, temp_model_folder=None, prev_model_file=None, presampled_docs_file=None, synthetic_args=None):
    """Train a ``ContrastiveModel`` and periodically save document representations.

    The model with the lowest validation loss seen so far is checkpointed to a
    randomly named temporary file. Every ``save_freq`` epochs (and once more at
    the end if ``nepochs`` is not a multiple of ``save_freq``) that best model
    is used to embed the train/test documents, and the embeddings, the
    validation-loss history since the previous snapshot, and the model itself
    are written under ``results_folder/epoch_<n>/``.

    Args:
        data_path: passed to ``get_data``; also the folder searched for
            ``skipEmbeddings.npy`` when ``pretrained_vectors`` is set.
        results_folder: output directory, created if missing.
        c_dim, h_dim, dropout_p, n_layers: forwarded to ``ContrastiveModel``.
        nepochs: number of training epochs.
        lr, opt_type, use_scheduler: forwarded to ``choose_optimizer``.
        embed_dim: forwarded to ``ContrastiveModel``; overridden by the shape
            of the pretrained vectors when those are loaded.
        landmarks: when > 0, that many unsupervised documents are drawn
            (without replacement) as references and snapshots use
            ``reference_embeddings``; otherwise ``model_embeddings`` is used.
        nfolds: accepted for interface compatibility but not used -- the
            samplers below are always called with ``nfolds=1``.
        resample: when > 0, contrastive pairs are re-drawn every ``resample``
            epochs (except after the final epoch).
        save_freq: snapshot period in epochs.
        pretrained_vectors: when true, load ``skipEmbeddings.npy`` from
            ``data_path`` and pass it to the model as ``vectors``.
        temp_model_folder: optional directory for the temporary checkpoint.
        prev_model_file: optional path of a saved model to resume from; it
            replaces the freshly built model.
        presampled_docs_file: optional file handed to
            ``contrastive_sampling.sample_preprocessed`` instead of sampling.
        synthetic_args: accepted but not used by this function.

    Returns:
        The best model (lowest validation loss), reloaded from the temporary
        checkpoint, which is deleted afterwards.

    NOTE(review): whole-model ``torch.save``/``torch.load`` is used throughout;
    recent PyTorch releases default ``torch.load`` to ``weights_only=True``,
    which rejects such pickles -- confirm the targeted PyTorch version.
    """
    # makedirs also copes with a nested, not-yet-existing parent path.
    os.makedirs(results_folder, exist_ok=True)

    vocab, train, test, unsup, valid = get_data(data_path)
    vocab_size = len(vocab)

    ## Get data
    train_tokens = train['tokens']
    train_counts = train['counts']

    test_tokens = test['tokens']
    test_counts = test['counts']

    unsup_tokens = unsup['tokens']
    unsup_counts = unsup['counts']

    valid_tokens = valid['tokens']
    valid_counts = valid['counts']

    if landmarks > 0:
        ## Create landmark documents
        inds = np.random.choice(len(unsup_tokens), landmarks, replace=False)
        ref_tokens = [unsup_tokens[i] for i in inds]
        ref_counts = [unsup_counts[i] for i in inds]

    ## Load up pretrained vectors
    vectors = None
    if pretrained_vectors:
        pretrained_vectors_file = os.path.join(data_path, 'skipEmbeddings.npy')
        vectors = torch.from_numpy(np.load(pretrained_vectors_file)).float()
        # The pretrained matrix dictates both sizes.
        vocab_size, embed_dim = vectors.shape

    ## Build contrastive model
    print("building model...")
    # Random 10-letter prefix so concurrent runs do not clobber each other's checkpoint.
    temp_file = "".join(list(np.random.choice(list(string.ascii_lowercase), 10)) + list("_model_temp.pt"))
    if temp_model_folder:
        temp_file = os.path.join(temp_model_folder, temp_file)

    ## Define sampling function
    def sample_documents():
        """Draw contrastive pairs and wrap them in train/valid DataLoaders."""
        print("Sampling contrastive documents...")
        if presampled_docs_file is not None:
            raw_train_dataset, raw_valid_dataset = contrastive_sampling.sample_preprocessed(presampled_docs_file)
        else:
            raw_train_dataset = contrastive_sampling.even_contrastive_documents(unsup_tokens, unsup_counts, nfolds=1)
            raw_valid_dataset = contrastive_sampling.even_contrastive_documents(valid_tokens, valid_counts, nfolds=1)

        train_dataset = ContrastiveDataset(raw_train_dataset[0], raw_train_dataset[1], raw_train_dataset[2], raw_train_dataset[3], raw_train_dataset[4])
        valid_dataset = ContrastiveDataset(raw_valid_dataset[0], raw_valid_dataset[1], raw_valid_dataset[2], raw_valid_dataset[3], raw_valid_dataset[4])

        valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, collate_fn=generate_contrast_batch)
        train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True, collate_fn=generate_contrast_batch, shuffle=True)

        return (train_loader, valid_loader)

    train_loader, valid_loader = sample_documents()

    ## Get model
    model = ContrastiveModel(vocab_size=vocab_size, embed_dim=embed_dim, c_dim=c_dim, h_dim=h_dim, dropout_p=dropout_p, n_layers=n_layers, vectors=vectors)

    ## Load model if it exists
    if prev_model_file is not None:
        print(prev_model_file)
        # map_location lets a checkpoint saved on another device be resumed here.
        model = torch.load(prev_model_file, map_location=device)

    model = model.to(device)
    print(model)
    # Seed the checkpoint so save_stats / the final reload always find a file.
    torch.save(model, temp_file)

    ## Loss function
    loss_fn = nn.BCEWithLogitsLoss(reduction='sum')

    ## Scheduler/optimizer
    scheduler = None
    if use_scheduler:
        optimizer, scheduler = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler)
    else:
        optimizer = choose_optimizer(model, lr, opt_type, use_scheduler=use_scheduler)

    ## Validation loss + accuracy
    def validation_loss():
        """Return (mean loss per example, accuracy) over the current valid_loader."""
        model.eval()
        total_loss = 0.0
        total_examples = 0.0
        total_correct = 0.0
        with torch.no_grad():
            for text_1, offset_1, text_2, offset_2, y in valid_loader:
                n_examples = len(offset_1)
                outputs = model(text_1.to(device), offset_1.to(device), text_2.to(device), offset_2.to(device)).squeeze(1).cpu()
                # Same target shape for the loss and the accuracy, so the
                # comparison below can never broadcast to an N x N matrix.
                targets = y.squeeze()

                ## Compute loss
                curr_loss = loss_fn(outputs, targets)
                total_loss += curr_loss.item()

                ## Compute accuracy
                ## The outputs are logits (the loss is BCEWithLogitsLoss), so
                ## probability > 0.5 corresponds to logit > 0, not > 0.5.
                predictions = torch.where(outputs > 0, torch.ones_like(outputs), torch.zeros_like(outputs))
                total_correct += (predictions == targets).sum().item()

                total_examples += n_examples
        return (total_loss/total_examples, total_correct/total_examples)

    ## Training step
    def train_step():
        """Run one epoch over the current train_loader; return mean loss per example."""
        CLIP_NORM = 25.0
        model.train()
        train_loss = 0.0
        total_examples = 0.0
        for text_1, offset_1, text_2, offset_2, y in train_loader:
            optimizer.zero_grad()
            model.zero_grad()

            n_examples = len(offset_1)
            predictions = model(text_1.to(device), offset_1.to(device), text_2.to(device), offset_2.to(device)).squeeze(1)
            loss = loss_fn(predictions, y.squeeze().to(device))

            train_loss += loss.item()
            total_examples += n_examples

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            optimizer.step()

        return (train_loss/total_examples)

    valid_loss_list = []
    valid_loss, valid_acc = validation_loss()
    best_valid_loss = valid_loss

    ## Define save function
    def save_stats(epoch):
        """Embed train/test docs with the best checkpoint and write a snapshot."""
        best_model = torch.load(temp_file)
        best_model.to(device)

        # Embeddings are computed with best_model (not the live model) so that
        # the saved .npy files correspond to the model.pt saved next to them.
        if landmarks > 0:
            print("Creating landmark representation for testing documents.")
            X_test = reference_embeddings(test_tokens, test_counts, ref_tokens, ref_counts, best_model, tol=0.05)
            print("Creating landmark representation for training documents.")
            X_train = reference_embeddings(train_tokens, train_counts, ref_tokens, ref_counts, best_model, tol=0.05)
        else:
            print("Creating contrastive representation for testing documents.")
            X_test = model_embeddings(test_tokens, test_counts, best_model)
            print("Creating contrastive representation for training documents.")
            X_train = model_embeddings(train_tokens, train_counts, best_model)

        folder = os.path.join(results_folder, "_".join(["epoch", str(epoch)]))
        os.makedirs(folder, exist_ok=True)

        # Random tag keeps repeated runs into the same folder from overwriting each other.
        rand_tag = "".join(list(np.random.choice(list(string.ascii_lowercase), 10)))
        test_name = os.path.join(folder, "_".join([rand_tag, "test.npy"]))
        train_name = os.path.join(folder, "_".join([rand_tag, "train.npy"]))
        loss_name = os.path.join(folder, "_".join([rand_tag, "valid_loss.npy"]))
        model_name = os.path.join(folder, "_".join([rand_tag, "model.pt"]))

        np.save(test_name, X_test.numpy())
        np.save(train_name, X_train.numpy())
        np.save(loss_name, np.array(valid_loss_list))
        torch.save(best_model, model_name)

        # Each snapshot only stores the losses recorded since the previous one.
        del valid_loss_list[:]

    model_param_string = "n_layers_" + str(n_layers) + "_emb_dim_" + str(embed_dim) + "_c_dim_" + str(c_dim) + "_h_dim_" + str(h_dim)
    writer = SummaryWriter(comment=model_param_string)
    writer.add_scalar('Loss/valid', valid_loss, 0)
    writer.add_scalar('Accuracy/valid', valid_acc, 0)

    print("Validation loss: ", valid_loss)
    print("Validation accuracy: ", valid_acc)
    valid_loss_list.append(valid_loss)
    for epoch in range(1, nepochs+1):
        train_loss = train_step()
        valid_loss, valid_acc = validation_loss()

        valid_loss_list.append(valid_loss)

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/valid', valid_loss, epoch)
        writer.add_scalar('Accuracy/valid', valid_acc, epoch)
        if use_scheduler:
            scheduler.step(valid_loss)
            writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], epoch)

        if valid_loss < best_valid_loss and temp_file:
            best_valid_loss = valid_loss
            ## Save out the current model
            torch.save(model, temp_file)

        if (resample > 0) and (epoch < nepochs) and ((epoch % resample) == 0):
            ## Resample documents
            train_loader, valid_loader = sample_documents()

        ## Compute embeddings
        if (epoch % save_freq) == 0:
            save_stats(epoch)

    if nepochs % save_freq != 0:
        save_stats(nepochs) ## One last time

    # Flush pending TensorBoard events and release the event file.
    writer.close()

    model = torch.load(temp_file)
    os.remove(temp_file)
    return (model)

def extract_embedding(data_path, model_file, reference_embed=True, dim=100):
    """Load a saved model and embed the train and test documents.

    Args:
        data_path: passed to ``get_data`` to obtain the train/test/unsup splits.
        model_file: path of a model saved with ``torch.save``.
        reference_embed: when true, ``dim`` unsupervised documents are drawn
            at random (without replacement) as references and each document is
            represented by its ``reference_embeddings`` scores; otherwise
            ``model_embeddings`` is used.
        dim: number of reference documents; only used when
            ``reference_embed`` is true.

    Returns:
        ``(X_train, X_test)`` as NumPy arrays.
    """
    # map_location makes a checkpoint saved on a GPU loadable on whatever
    # device this process selected (including the CPU fallback).
    model = torch.load(model_file, map_location=device)

    _, train, test, unsup, _ = get_data(data_path)

    # 1. training data
    train_tokens = train['tokens']
    train_counts = train['counts']

    # 2. testing set
    test_tokens = test['tokens']
    test_counts = test['counts']

    X_test, X_train = None, None
    if reference_embed:
        unsup_tokens = unsup['tokens']
        unsup_counts = unsup['counts']

        ## Create representative documents
        inds = np.random.choice(len(unsup_tokens), dim, replace=False)
        ref_tokens = [unsup_tokens[i] for i in inds]
        ref_counts = [unsup_counts[i] for i in inds]

        print("Creating contrastive representation for testing documents.")
        X_test = reference_embeddings(tokens=test_tokens,
                                      counts=test_counts,
                                      ref_tokens=ref_tokens,
                                      ref_counts=ref_counts,
                                      model=model)

        print("Creating contrastive representation for training documents.")
        X_train = reference_embeddings(tokens=train_tokens,
                                       counts=train_counts,
                                       ref_tokens=ref_tokens,
                                       ref_counts=ref_counts,
                                       model=model)
    else:
        print("Creating contrastive representation for testing documents.")
        X_test = model_embeddings(test_tokens, test_counts, model)

        print("Creating contrastive representation for training documents.")
        X_train = model_embeddings(train_tokens, train_counts, model)

    return (X_train.numpy(), X_test.numpy())