import numpy as np
import scipy.sparse
import pickle
import os
import torch


def expand(tokens, counts):
    """Expand bag-of-words documents into randomly shuffled token sequences.

    Document ``i`` is described by ``tokens[i]`` (token ids) and ``counts[i]``
    (number of occurrences of each id).  Every id is repeated ``count`` times,
    the pieces are concatenated, and the result is shuffled with
    ``np.random.permutation``.

    Args:
        tokens: Iterable of per-document token-id arrays.  Singleton
            dimensions are squeezed away (e.g. ``(1, k)`` is accepted).
        counts: Iterable of per-document count arrays, aligned element-wise
            with ``tokens``.

    Returns:
        A lazy, single-use iterator yielding one 1-D numpy array per
        document.  The permutation for a document is drawn only when that
        document is consumed.  A document with no tokens yields an empty
        integer array.
    """

    def _as_list(x):
        x = x.squeeze()
        # A document with a single distinct token squeezes to a 0-d array,
        # which is not iterable; wrap its scalar in a list.
        return [x.tolist()] if x.ndim == 0 else x.tolist()

    def _expand_one(doc_tokens, doc_counts):
        pieces = [
            np.repeat(tok, cnt)
            for tok, cnt in zip(_as_list(doc_tokens), _as_list(doc_counts))
        ]
        if not pieces:
            # np.concatenate rejects an empty list; an empty document simply
            # expands to an empty sequence (no RNG draw needed).
            return np.array([], dtype=int)
        return np.random.permutation(np.concatenate(pieces))

    return (_expand_one(tok, cnt) for tok, cnt in zip(tokens, counts))

def even_split(tokens, counts, lam, nsize):
    """Split each bag-of-words document into two random, disjoint halves.

    Each document is expanded into a shuffled token sequence with
    ``expand``, cut at its midpoint, and each half is converted back to
    (unique tokens, counts) form with ``np.unique``.

    Args:
        tokens: Sequence of per-document token-id arrays.
        counts: Sequence of per-document count arrays aligned with ``tokens``.
        lam: If > 0, each half is further truncated to a length drawn from
            Poisson(lam).  One draw is made per document and shared by both
            halves.  The length is clipped to the half's size but raised to
            at least 2, so a half keeps at least ``min(2, len(half))`` tokens.
        nsize: If > 0, only a random subset of ``min(nsize, len(tokens))``
            documents, sampled without replacement, is split.

    Returns:
        ``(tokens_1, counts_1, tokens_2, counts_2)``: four tuples with one
        entry per (sampled) document, holding the unique tokens and their
        counts for the first and second halves.  All four are empty tuples
        when there are no documents to split.
    """
    if nsize > 0:
        nsize = min(nsize, len(tokens))
        inds = np.random.choice(len(tokens), size=nsize, replace=False)
        tokens = [tokens[i] for i in inds]
        counts = [counts[i] for i in inds]

    expanded = list(expand(tokens, counts))
    if not expanded:
        # Nothing to split; unpacking zip(*...) below would fail.
        return (), (), (), ()

    half_1 = [doc[:len(doc) // 2] for doc in expanded]
    half_2 = [doc[len(doc) // 2:] for doc in expanded]

    if lam > 0:
        sizes = np.random.poisson(lam=lam, size=len(half_1))
        sizes_1 = np.maximum(np.minimum(sizes, [len(h) for h in half_1]), 2)
        sizes_2 = np.maximum(np.minimum(sizes, [len(h) for h in half_2]), 2)
        half_1 = [h[:s] for h, s in zip(half_1, sizes_1)]
        half_2 = [h[:s] for h, s in zip(half_2, sizes_2)]

    tokens_1, counts_1 = zip(*(np.unique(h, return_counts=True) for h in half_1))
    tokens_2, counts_2 = zip(*(np.unique(h, return_counts=True) for h in half_2))
    return tokens_1, counts_1, tokens_2, counts_2

def even_contrastive_documents(tokens, counts, nfolds=1, lam=0, nsize=0):
    """Build labelled pairs of document halves for contrastive training.

    Each of the ``nfolds`` rounds adds two kinds of pairs:

    * Positives (label 1): ``even_split`` is applied, and the two halves of
      each (sampled) document are paired.
    * Negatives (label 0): ``even_split`` is applied again with fresh
      randomness.  The first half of document ``i`` is paired with the
      second half of document ``perm[i]`` for a random permutation ``perm``.
      Positions where ``perm[i] == i`` are dropped, so a negative never
      pairs a document with itself.

    Because fixed points are dropped, a round usually yields slightly fewer
    negatives than positives, and none at all when only one document is
    split.

    Args:
        tokens: Sequence of per-document token-id arrays.
        counts: Sequence of per-document count arrays aligned with ``tokens``.
        nfolds: Number of sampling rounds.
        lam: Forwarded to ``even_split`` (Poisson truncation of halves).
        nsize: Forwarded to ``even_split`` (document subsample size).

    Returns:
        ``(tokens_1, counts_1, tokens_2, counts_2, Y)``.  The first four are
        lists aligned element-wise (first half / second half of each pair).
        ``Y`` is a numpy array of 1/0 labels in the same order; within each
        round, positives precede negatives.
    """
    tokens_1, counts_1, tokens_2, counts_2 = [], [], [], []
    Y = []

    for _ in range(nfolds):
        pos_tok_1, pos_cnt_1, pos_tok_2, pos_cnt_2 = even_split(
            tokens, counts, lam=lam, nsize=nsize)
        tokens_1.extend(pos_tok_1)
        counts_1.extend(pos_cnt_1)
        tokens_2.extend(pos_tok_2)
        counts_2.extend(pos_cnt_2)

        neg_tok_1, neg_cnt_1, neg_tok_2, neg_cnt_2 = even_split(
            tokens, counts, lam=lam, nsize=nsize)

        permutation = np.random.permutation(len(neg_tok_1))
        indices = np.arange(len(neg_tok_1))
        # Drop fixed points so a first half is never paired with the second
        # half of the same document.
        valid = indices != permutation
        indices = indices[valid]
        permutation = permutation[valid]

        tokens_1.extend(neg_tok_1[i] for i in indices)
        counts_1.extend(neg_cnt_1[i] for i in indices)
        tokens_2.extend(neg_tok_2[j] for j in permutation)
        counts_2.extend(neg_cnt_2[j] for j in permutation)

        Y.extend([1] * len(pos_tok_1) + [0] * len(indices))

    return tokens_1, counts_1, tokens_2, counts_2, np.array(Y)







# Build self-supervised "predict the masked words" examples from an unlabelled corpus:
# in each document, t randomly chosen words are masked out and used as that document's labels.
def real_predict_document(tokens, counts, t=1):
    """Create masked-word prediction examples from bag-of-words documents.

    Each document is expanded into a token sequence: each id is repeated by
    its count, in input order, without shuffling.  Then ``t`` distinct
    positions are chosen uniformly at random with ``np.random.choice``.
    The tokens at those positions become the labels, and the remaining
    tokens form the input document.

    Args:
        tokens: Iterable of per-document token-id arrays.  Singleton
            dimensions are squeezed away.
        counts: Iterable of per-document count arrays aligned with ``tokens``.
        t: Number of positions to mask in every document.

    Returns:
        ``(docs, Y)``:

        * ``docs`` is a list of 1-D arrays holding the unmasked tokens.
        * ``Y`` is a numpy array with one row of ``t`` masked tokens per
          document.

        Both keep the tokens' original positional order.

    Raises:
        ValueError: From ``np.random.choice`` when a document has fewer
            than ``t`` tokens.
    """
    docs, Y = [], []
    for tok, count in zip(tokens, counts):
        # np.atleast_1d keeps documents with a single distinct token (which
        # squeeze to 0-d arrays) usable; np.repeat expands all ids at once.
        doc = np.repeat(np.atleast_1d(tok.squeeze()),
                        np.atleast_1d(count.squeeze()))
        mask_inds = np.random.choice(len(doc), size=t, replace=False)
        keep = np.ones(len(doc), dtype=bool)
        keep[mask_inds] = False
        docs.append(doc[keep])
        Y.append(doc[~keep])
    return (docs, np.array(Y))

# def real_predict_document_attn(tokens,counts,t=1):
#     docs,Y=[],[]
#     expanded_docs=[np.concatenate([np.repeat(t, c) for t,c in zip(tok.squeeze(), count.squeeze())]) for tok, count in zip(tokens, counts)]
#     for doc in expanded_docs:
#         mask_inds = np.random.choice(len(doc), size=t, replace=False)
#         mask_doc=[doc[i] for i in range(len(doc)) if i not in mask_inds]
#         words_to_predict=[doc[i] for i in range(len(doc)) if i in mask_inds]
#         docs.append(np.array(mask_doc))
#         Y.append(words_to_predict)
#     return (docs,np.array(Y))

  


