import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import word_tokenize
from collections import Counter
import glob
from multiprocessing import Pool

def subsample_probability(frequency, hyper):
    """Return, for every vocabulary entry, the probability of dropping it.

    Subsampling procedure from Mikolov et al., "Distributed Representations
    of Words and Phrases and their Compositionality".  The formula printed in
    the paper, ``1 - sqrt(t / f)``, differs from the one in the reference C
    code (e.g. https://github.com/svn2github/word2vec , word2vec.c, L396).
    This function follows the C code.

    :param frequency: array-like of word occurrence counts
    :param hyper: mapping; ``hyper["subsampling_threshold"]`` is the
       threshold ``t`` of the formula
    :return idx_to_prob:
       array of probabilities of omitting each vocabulary entry from the
       training sample.  The result is not clipped, so it can be negative;
       a negative value never exceeds a uniform [0, 1) draw, i.e. such a
       word is never omitted.
    """
    t = hyper["subsampling_threshold"]
    # asarray lets plain lists through; ndarray.sum() runs inside NumPy
    # instead of iterating the array element by element in Python.
    counts = np.asarray(frequency)
    f = counts / counts.sum()
    # the formula from the C code.
    return 1 - (np.sqrt(f / t) + 1) * (t / f)


def vocab_to_dict(vocab, hyper):
    """Turn a word -> count mapping into index lookup tables.

    Index 0 is reserved for ``<unk>``.  A word receives its own index only
    when its count is strictly greater than ``hyper["min_occurrence"]``;
    every other word is mapped to index 0.  A literal ``<unk>`` entry in
    ``vocab`` is skipped and keeps index 0.

    :param vocab: mapping from word to occurrence count
    :param hyper: mapping providing ``min_occurrence`` (and whatever
       ``subsample_probability`` reads)
    :return: ``(word_to_idx, idx_to_word, idx_to_freq, idx_to_prob)`` --
       a dict followed by three arrays indexed by word index
    """
    word_index = {"<unk>": 0}
    words = ["<unk>"]
    counts = [1]          # placeholder count for <unk>: avoid division by zero
    for token, count in tqdm(vocab.items()):
        # note: dataset may already contain <unk>
        if token == "<unk>":
            continue
        if count > hyper["min_occurrence"]:
            # the next free index is always the current length of ``words``
            word_index[token] = len(words)
            words.append(token)
            counts.append(count)
        else:
            word_index[token] = 0  # <unk>
    word_array = np.array(words)
    count_array = np.array(counts)
    drop_prob = subsample_probability(count_array, hyper)
    return word_index, word_array, count_array, drop_prob


import re
# wikitext-103 is already tokenized, so every whitespace-delimited run is a token.
WORD = re.compile(r'\S+')


def re_tokenize(line):
    """Return the maximal runs of non-whitespace characters in *line*."""
    tokens = WORD.findall(line)
    return tokens

# Drops punctuation: a token is either a run of word characters
# (letters, digits, underscore) or the literal marker <unk>.
WORD2 = re.compile(r'\w+|<unk>')


def re2_tokenize(line):
    """Return the word-character runs and ``<unk>`` markers found in *line*."""
    tokens = WORD2.findall(line)
    return tokens


def tokenize_line(line):
    """Lower-case *line* and tokenize it with ``re2_tokenize``."""
    lowered = line.lower()
    return re2_tokenize(lowered)

def build_vocab(input):
    """Count the tokens of one UTF-8 text file.

    :param input: path of the file to read
    :return: ``(vocab, total)`` -- a ``Counter`` of token occurrences and
       the total number of tokens in the file
    """
    counts = Counter()
    n_tokens = 0
    with open(input, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            words = tokenize_line(raw_line)
            n_tokens += len(words)
            counts.update(words)
    return counts, n_tokens

def collect_tokens(args):
    """Unpack one argument tuple for ``_collect_tokens``.

    Workaround for ``imap``, which hands each work item to the function
    as a single argument.
    """
    encoded = _collect_tokens(*args)
    return encoded
def _collect_tokens(input,total,word_to_idx):
    """Encode one UTF-8 text file as an array of word indices.

    :param input: path of the file to read
    :param total: length of the returned array; positions past the last
       token read stay 0
    :param word_to_idx: dict from word to index; unknown words become 0
    :return: int32 array of length ``total``
    """
    with open(input, 'r', encoding='utf-8') as handle:
        encoded = np.zeros(total, dtype=np.int32)
        cursor = 0
        for raw_line in handle:
            words = tokenize_line(raw_line)
            for offset, word in enumerate(words):
                # words missing from the table fall back to 0 (<unk>)
                encoded[cursor + offset] = word_to_idx.get(word, 0)
            cursor += len(words)
    return encoded


class WikiText2DataSet:
    """Text corpus encoded as an array of vocabulary indices, cached on disk.

    Construction either loads the preprocessed corpus from the cache file
    (see ``_cache_path``) or rebuilds it from the raw text and writes the
    cache.  Afterwards the instance provides:

    * ``text``           -- array of word indices, 0 meaning ``<unk>``
    * ``vocab``          -- mapping from word to occurrence count
    * ``vocab_size``     -- number of distinct indices, ``<unk>`` included
    * ``context_before`` / ``context_after`` -- window sizes from ``hyper``

    Keys read from ``hyper``: ``force_reload``, ``subsampling``,
    ``context_before``, ``context_after``; when the cache is rebuilt also
    ``threads`` plus the keys consumed by ``vocab_to_dict``.
    """

    def __init__(self, input=None, vocab=None, **hyper):
        """Load or build the encoded corpus, then optionally subsample it.

        :param input: path or glob pattern of the raw text file(s)
        :param vocab: word -> count mapping to reuse instead of the one
           counted from ``input``; ``None`` builds the vocabulary from
           ``input``
        :param hyper: hyperparameters, see the class docstring
        """
        import os.path
        # cache the input preprocessing.
        if os.path.exists(self._cache_path(input)) and not hyper["force_reload"]:
            self.load_cache(input,vocab,**hyper)
        else:
            self.save_cache(input,vocab,**hyper)

        # subsample training data by replacing them with <unk>
        # (skipped when a vocabulary was handed in from another dataset)
        if hyper["subsampling"] and vocab is None:
            r = np.random.rand(self.text.shape[0])
            p = self._idx_to_prob[self.text]
            replace_idx = (r < p)
            print(f"marking {replace_idx.sum()} / {len(replace_idx)} words ({replace_idx.mean()*100:.2f}%) as <unk> ")
            print("for example,")
            print(self._idx_to_word[self.text[replace_idx][:30]])
            print("whose probabilities are:")
            print(p[replace_idx][:30])
            self.text[replace_idx] = 0

        # _idx_to_word always holds one extra slot for <unk>, so the number
        # of words with their own index is one less than its length.  A
        # literal "<unk>" in the vocabulary is not counted as folded.
        kept = len(self._idx_to_word) - 1
        folded = len(self.vocab) - kept - int("<unk>" in self.vocab)
        print("the input text has",len(self.text),"words")
        print("in which",len(self._idx_to_word),"words are unique")
        print("and ",folded,"words are treated as <unk>")
        self.vocab_size = len(self._idx_to_word)
        self.context_before = hyper["context_before"]
        self.context_after  = hyper["context_after"]

    @staticmethod
    def _cache_path(input):
        """Return the cache file name: ``input`` with '*' -> '_' plus '.npz'."""
        return input.replace("*","_")+".npz"

    def load_cache(self, input=None, vocab=None, **hyper):
        """Restore vocabulary tables and encoded text from the cache file.

        ``vocab`` and ``hyper`` are accepted for signature symmetry with
        ``save_cache`` and are not used.
        """
        cache = self._cache_path(input)
        print("loading the cached dataset... :",cache)
        # NOTE: allow_pickle=True is needed for the pickled dicts, and
        # unpickling can execute arbitrary code.  Only load cache files
        # written by save_cache.
        with np.load(cache,allow_pickle=True) as data:
            self.vocab = data["vocab"].item()
            self._word_to_idx = data["w2i"].item()
            self._idx_to_word = data["i2w"]
            self._idx_to_freq = data["i2f"]
            self._idx_to_prob = data["i2p"]
            self.text = data["text"]

    def save_cache(self, input=None, vocab=None, **hyper):
        """Tokenize the raw input, build the tables and write the cache.

        :param input: path or glob pattern of the raw text file(s); a
           single matching file is first cut into 300000-line parts with
           the external ``split`` command so the parts can be processed
           in parallel
        :param vocab: vocabulary to use instead of the one counted here
        :param hyper: must provide ``force_reload``, ``threads`` and the
           keys read by ``vocab_to_dict``
        :raises FileNotFoundError: if ``input`` matches no file
        :raises subprocess.CalledProcessError: if ``split`` fails
        """
        cache = self._cache_path(input)
        print("caching the raw dataset... :",input,
              "(forced)" if hyper["force_reload"] else "")
        # when vocab is specified (from the training dataset),
        # use the same vocabulary. However, the file still needs
        # to be scanned for setting up the total length.
        self.vocab = vocab

        # glob returns matches in arbitrary order; sort so that the token
        # order of self.text is reproducible.
        inputs = sorted(glob.glob(input))
        if not inputs:
            # without this an empty dataset would be built and cached
            raise FileNotFoundError(f"no input file matches {input!r}")
        if len(inputs)==1:      # split the file into smaller chunks
            import subprocess
            cmd = ["split", "-l", "300000", "-d", inputs[0], inputs[0]+".part."]
            print(cmd)
            # check=True: a failed split must not end in an empty cache
            subprocess.run(cmd, check=True)
            # match the parts of exactly this file, in numeric-suffix order.
            # NOTE(review): parts left over from an earlier, longer split of
            # the same file would match too -- confirm they are cleaned up.
            inputs = sorted(glob.glob(glob.escape(inputs[0])+".part.*"))
        print(inputs)

        # first pass: count words and tokens per file
        counted = Counter()
        total = 0
        sub_totals = []
        with Pool(hyper["threads"]) as pool:
            print("running in",hyper["threads"],"threads")
            for sub_vocab, sub_total in tqdm(pool.imap(build_vocab,inputs),
                                             unit="Files",
                                             total=len(inputs),):
                # update() only touches the keys of sub_vocab, whereas
                # "+=" rescans the whole accumulated counter every time.
                counted.update(sub_vocab)
                total += sub_total
                sub_totals.append(sub_total)

        if self.vocab is None:
            self.vocab = counted

        self._word_to_idx, self._idx_to_word, self._idx_to_freq, self._idx_to_prob = \
            vocab_to_dict(self.vocab,hyper)

        # second pass: encode each file and stitch the pieces together in
        # the same order as ``inputs``
        with Pool(hyper["threads"]) as pool:
            from itertools import repeat
            self.text = np.zeros(total,dtype=np.int32)
            current = 0
            jobs = zip(inputs,
                       sub_totals,
                       repeat(self._word_to_idx, len(inputs)))
            for sub_total, sub_text in tqdm(zip(sub_totals,
                                                pool.imap(collect_tokens, jobs)),
                                            total=len(inputs),
                                            unit="Files"):
                self.text[current:current+sub_total] = sub_text
                current += sub_total

        print("saving a cache file... :",cache)
        np.savez_compressed(cache,
                            vocab=self.vocab,
                            w2i=self._word_to_idx,
                            i2w=self._idx_to_word,
                            i2f=self._idx_to_freq,
                            i2p=self._idx_to_prob,
                            text=self.text)

    def word_to_idx(self,word):
        """Return the index of *word*, or 0 (``<unk>``) if it is unknown."""
        return self._word_to_idx.get(word, 0)

    def idx_to_word(self,idx):
        """Return the word stored at index *idx*."""
        return self._idx_to_word[idx]



def remove_unk(data,target,context):
    """Drop samples whose target is <unk> or whose context has 2+ <unk>.

    :param data: array of samples; axis 0 enumerates the samples
    :param target: 1-D array with the target word index of each sample
    :param context: 2-D array (sample, context position) of context word
       indices
    :return: the entries of ``data`` (along axis 0) that pass the filter
    """
    unk_target_idx  = (target==0)
    # remove data with more than 1 <unk> in context
    unk_context_idx = (np.sum((context==0), axis=1)>=2)
    delete_idx = np.bitwise_or(unk_target_idx, unk_context_idx)
    # report what was removed (the mask itself), not what was kept
    print(f"removed {delete_idx.sum()} data points ({delete_idx.mean()*100:.2f}%) with <unk>")
    return data[np.bitwise_not(delete_idx)]


# Changed the iteration order, significantly faster
# because the inner iteration happens in numpy/C++.
# Same memory requirement, 10 sec to load wikitext-103.
class WikiText2DataSetSkipGram_v3(WikiText2DataSet):
    """SkipGram samples built one context offset at a time.

    ``self.data`` is a 2-column array of ``[target, context]`` index pairs.
    Windows are filtered with ``remove_unk`` before they are flattened
    into pairs.  ``self.text`` is deleted once the pairs exist.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        print("converting the data for SkipGram...")
        tokens   = self.text
        n_before = self.context_before
        n_after  = self.context_after
        window   = n_before + n_after

        # one window per position that has a full context on both sides
        n_windows = len(tokens) - window
        flat  = np.zeros((n_windows * window, 2), dtype=np.int32)
        pairs = flat.reshape(n_windows, window, 2)

        progress = tqdm(total=window + 1)
        progress.update()
        # column 0: the target, repeated across its whole window
        targets = tokens[n_before:] if n_after == 0 else tokens[n_before:-n_after]
        pairs[:, :, 0] = targets[:, None]

        # column 1: the context word.  Each context slot is the token array
        # shifted by a fixed start; the start n_before (the target itself)
        # is skipped.
        starts = list(range(n_before)) + [n_before + 1 + k for k in range(n_after)]
        for slot, start in enumerate(starts):
            progress.update()
            pairs[:, slot, 1] = tokens[start:start + n_windows]
        progress.close()

        pairs = remove_unk(pairs, pairs[:, 0, 0], pairs[:, :, 1])

        self.data = pairs.reshape(-1, 2)
        # release reference
        del tokens
        del self.text
        print(self.data.shape)


class WikiText2DataSetCBOW_v3(WikiText2DataSet):
    """CBOW samples built one context offset at a time.

    Each row of ``self.data`` holds the ``context_before`` preceding words,
    then the ``context_after`` following words, then the target word in the
    last column.  Rows are filtered with ``remove_unk``.  ``self.text`` is
    deleted once the rows exist.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        print("converting the data for CBOW...")
        tokens   = self.text
        n_before = self.context_before
        n_after  = self.context_after
        window   = n_before + n_after

        # one row per position that has a full context on both sides
        n_rows = len(tokens) - window
        rows = np.zeros((n_rows, window + 1), dtype=np.int32)

        # Each context column is the token array shifted by a fixed start;
        # the start n_before (the target itself) is skipped.
        starts = list(range(n_before)) + [n_before + 1 + k for k in range(n_after)]
        progress = tqdm(total=window + 1)
        for column, start in enumerate(starts):
            progress.update()
            rows[:, column] = tokens[start:start + n_rows]
        progress.update()
        # last column: the target word
        stop = None if n_after == 0 else -n_after
        rows[:, -1] = tokens[n_before:stop]
        progress.close()

        rows = remove_unk(rows, rows[:, -1], rows[:, :-1])

        self.data = rows
        # release reference
        del tokens
        del self.text
        print(self.data.shape)


# Slightly optimized variants. I know this does not fully resolve
# the problem of having the entire dataset in-memory,
# but better than nothing.
# CBOW requires around 2GB memory, 4min to load wikitext-103.
class WikiText2DataSetSkipGram_v2(WikiText2DataSet):
    """SkipGram samples filled window by window in a Python loop.

    ``self.data`` is a 2-column array of ``[target, context]`` index pairs,
    ``context_before + context_after`` consecutive rows per target.
    No <unk> filtering is applied.  ``self.text`` is deleted afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        print("converting the data for SkipGram...")
        text   = self.text
        before = self.context_before
        after  = self.context_after
        r = range(before, len(text)-after)
        data = np.zeros(( len(r) * (before+after) , 2), dtype=np.int32)
        # 3-D view on the same memory: (window, context slot, [target, context])
        data2 = data.reshape(len(r), (before+after) , 2)

        for i in tqdm(r):
            # i is a position in text; windows are numbered from 0.
            # Indexing data2 with i itself skipped the first `before`
            # windows and ran past the end of the array.
            row = i - before
            data2[row,:,0] = text[i]
            data2[row,:before,1] = text[i-before:i]
            data2[row,before:,1] = text[i+1:i+1+after]

        self.data = data
        # release reference
        del text
        del self.text
        print(self.data.shape)


class WikiText2DataSetCBOW_v2(WikiText2DataSet):
    """CBOW samples filled row by row in a Python loop.

    Each row of ``self.data`` holds the ``context_before`` preceding words,
    then the ``context_after`` following words, then the target word in the
    last column.  No <unk> filtering is applied.  ``self.text`` is deleted
    afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        print("converting the data for CBOW...")
        text   = self.text
        before = self.context_before
        after  = self.context_after
        r = range(before, len(text)-after)
        data = np.zeros(( len(r), before+after+1), dtype=np.int32)
        for i in tqdm(r):
            row = i - before
            data[row,:before]   = text[i-before:i]
            # the following context is the `after` words text[i+1 .. i+after];
            # the slice end is exclusive, hence i+1+after.
            data[row,before:-1] = text[i+1:i+1+after]
            data[row,-1]        = text[i]

        self.data = data
        # release reference
        del text
        del self.text
        print(self.data.shape)



# Initial, slow but correct versions. DO NOT USE.
# It has to create a huge intermediate python array (not numpy arrays),
# making wikitext-103 unloadable in 16GB memory.
class WikiText2DataSetSkipGram_v1(WikiText2DataSet):
    """SkipGram samples collected in a Python list of pairs.

    ``self.data`` holds ``[target, context]`` pairs grouped by context
    offset: all pairs for the first offset come first, then all pairs for
    the next one, and so on.  ``self.text`` is deleted afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        before = self.context_before
        after = self.context_after
        pairs = []
        for shift in range(-before, after + 1):
            if shift == 0:
                # offset 0 is the target itself, not a context word
                continue
            for center in range(before, len(self.text) - after):
                pairs.append([self.text[center], self.text[center + shift]])
        del self.text
        self.data = np.asarray(pairs)
        print(self.data.shape)


class WikiText2DataSetCBOW_v1(WikiText2DataSet):
    """CBOW samples collected in a Python list of rows.

    Each row lists the preceding context words from nearest to farthest,
    then the following context words from nearest to farthest, then the
    target word.  ``self.text`` is deleted afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        before = self.context_before
        after = self.context_after
        rows = []
        for center in range(before, len(self.text) - after):
            preceding = [self.text[center - dist] for dist in range(1, before + 1)]
            following = [self.text[center + dist] for dist in range(1, after + 1)]
            rows.append(preceding + following + [self.text[center]])
        del self.text
        self.data = np.asarray(rows)
        print(self.data.shape)


# Unversioned public names: both point at the v3 implementations.
WikiText2DataSetSkipGram = WikiText2DataSetSkipGram_v3
WikiText2DataSetCBOW     = WikiText2DataSetCBOW_v3

