import numpy as np
import scipy.sparse
import pickle
import os
import torch

def _choose_file(folder, suffix):
    """Return the path of one randomly chosen file in ``folder`` whose name ends with ``suffix``.

    Candidates are sorted so that, under a fixed ``np.random`` seed, the same file is
    picked regardless of the (arbitrary) order in which ``os.listdir`` returns entries.

    Raises:
        ValueError: if no matching file exists in ``folder``.
    """
    candidates = sorted(name for name in os.listdir(folder) if name.endswith(suffix))
    if not candidates:
        raise ValueError(f"no files ending in {suffix!r} found in {folder!r}")
    return os.path.join(folder, np.random.choice(candidates))


def sample_preprocessed(doc_folder):
    """Load one randomly chosen pre-processed validation file and one contrast file.

    Looks for ``*_valid.pkl`` files in ``<doc_folder>/valid`` and ``*_contrast.pkl``
    files in ``<doc_folder>/contrast``. It picks one of each using the global
    ``np.random`` state, validation first, and unpickles both.

    Args:
        doc_folder: directory containing the ``valid`` and ``contrast`` sub-directories.

    Returns:
        ``(contrast_dataset, valid_dataset)`` -- note the contrast set comes first.

    Raises:
        ValueError: if either sub-directory holds no matching ``.pkl`` file.
        FileNotFoundError: if either sub-directory does not exist.
    """
    vfile = _choose_file(os.path.join(doc_folder, "valid"), "_valid.pkl")
    cfile = _choose_file(os.path.join(doc_folder, "contrast"), "_contrast.pkl")

    # SECURITY: pickle.load can execute arbitrary code; these files must come
    # from a trusted pre-processing step, never from external input.
    with open(vfile, 'rb') as fh:
        valid_dataset = pickle.load(fh)

    with open(cfile, 'rb') as fh:
        contrast_dataset = pickle.load(fh)

    return contrast_dataset, valid_dataset

def expand(tokens, counts):
    """Expand bag-of-words documents into randomly shuffled token sequences.

    ``tokens[i]`` and ``counts[i]`` hold document i's token ids and the matching
    occurrence counts. They are squeezed, so e.g. a ``(1, n)`` array works;
    presumably that is the stored shape -- TODO confirm against the pre-processing
    code. Each token is repeated by its count, and the result is permuted with
    the global ``np.random`` state.

    Returns:
        A lazy iterator (``map``) yielding one 1-D array per document. A document
        with no tokens yields an empty integer array instead of raising.
    """
    def _as_list(arr):
        arr = arr.squeeze()
        # A single-entry document squeezes to a 0-d array whose tolist() is a scalar.
        return [arr.tolist()] if arr.ndim == 0 else arr.tolist()

    def _expand_doc(doc_tokens, doc_counts):
        pieces = [np.repeat(tok, cnt)
                  for tok, cnt in zip(_as_list(doc_tokens), _as_list(doc_counts))]
        if not pieces:
            # np.concatenate([]) raises ValueError; an empty document stays empty.
            return np.array([], dtype=int)
        return np.random.permutation(np.concatenate(pieces))

    return map(_expand_doc, tokens, counts)

def even_split(tokens, counts, lam, nsize):
    """Split every document into two halves of a random token shuffle.

    Steps:
      1. If ``nsize > 0``, sub-sample ``min(nsize, len(tokens))`` documents
         without replacement.
      2. Expand each document into a shuffled token sequence (``expand``).
      3. Cut each sequence into a first and a second half at ``len // 2``.
      4. If ``lam > 0``, draw one Poisson(``lam``) length per document. Truncate
         both halves to that length, capped at the half's own length and floored
         at 2. A floor above the half's length is harmless: the slice simply
         stops at the end.
      5. Convert each half back to bag-of-words form with ``np.unique``.

    Returns:
        ``(tokens_1, counts_1, tokens_2, counts_2)``: tuples with one array per
        (sampled) document. All four are empty tuples when there are no documents.
    """
    if nsize > 0:
        nsize = min(nsize, len(tokens))
        inds = np.random.choice(len(tokens), size=nsize, replace=False)
        tokens = [tokens[i] for i in inds]
        counts = [counts[i] for i in inds]

    # Materialise now so all shuffles happen before the Poisson draw (same RNG order).
    docs = list(expand(tokens, counts))
    if not docs:
        # zip(*[]) would fail to unpack below; return empty results instead.
        return (), (), (), ()

    half_1 = [doc[:len(doc) // 2] for doc in docs]
    half_2 = [doc[len(doc) // 2:] for doc in docs]

    if lam > 0:
        sizes = np.random.poisson(lam=lam, size=len(half_1))
        sizes_1 = np.maximum(np.minimum(sizes, [len(h) for h in half_1]), 2)
        sizes_2 = np.maximum(np.minimum(sizes, [len(h) for h in half_2]), 2)
        half_1 = [h[:s] for h, s in zip(half_1, sizes_1)]
        half_2 = [h[:s] for h, s in zip(half_2, sizes_2)]

    tokens_1, counts_1 = zip(*(np.unique(h, return_counts=True) for h in half_1))
    tokens_2, counts_2 = zip(*(np.unique(h, return_counts=True) for h in half_2))
    return tokens_1, counts_1, tokens_2, counts_2

def even_contrastive_documents(tokens, counts, nfolds=1, lam=0, nsize=0):
    """Build labelled pairs of document halves for contrastive training.

    Each of the ``nfolds`` rounds does two things:
      * Positives (label 1): ``even_split`` is called, and half 1 of every
        (sub-sampled) document is paired with half 2 of the same document.
      * Negatives (label 0): ``even_split`` is called again, with fresh shuffling
        and sub-sampling. Half 1 of document ``i`` is paired with half 2 of
        document ``perm[i]``, where ``perm`` is a random permutation. Entries
        with ``perm[i] == i`` are dropped so that no document is paired with
        itself. The number of negatives can therefore be smaller than the
        number of positives.

    Args:
        tokens, counts: per-document token-id / count arrays, as accepted by ``even_split``.
        nfolds: number of positive+negative rounds to concatenate.
        lam, nsize: forwarded to ``even_split``.

    Returns:
        ``(tokens_1, counts_1, tokens_2, counts_2, Y)``. The first four are lists
        aligned with ``Y``, an integer array of 0/1 labels.
    """
    tokens_1, counts_1, tokens_2, counts_2 = [], [], [], []
    Y = []

    for _ in range(nfolds):
        pos_t1, pos_c1, pos_t2, pos_c2 = even_split(tokens, counts, lam=lam, nsize=nsize)
        tokens_1.extend(pos_t1)
        counts_1.extend(pos_c1)
        tokens_2.extend(pos_t2)
        counts_2.extend(pos_c2)

        neg_t1, neg_c1, neg_t2, neg_c2 = even_split(tokens, counts, lam=lam, nsize=nsize)

        # Pair document i with permutation[i]; drop fixed points so no self-pairs remain.
        permutation = np.random.permutation(len(neg_t1))
        indices = np.arange(len(neg_t1))
        keep = indices != permutation
        indices, permutation = indices[keep], permutation[keep]

        tokens_1.extend(neg_t1[i] for i in indices)
        counts_1.extend(neg_c1[i] for i in indices)
        tokens_2.extend(neg_t2[j] for j in permutation)
        counts_2.extend(neg_c2[j] for j in permutation)

        Y.extend([1] * len(pos_t1) + [0] * len(indices))

    # dtype=int keeps the label array integer even when it is empty (nfolds=0).
    return tokens_1, counts_1, tokens_2, counts_2, np.array(Y, dtype=int)