import bisect
import glob
import os

import numpy as np
import tensorflow as tf
import tqdm


def load_dataset(enc, path, combine):
    """Load every file under *path* and return a list of token arrays.

    Args:
        enc: Encoder object exposing ``encode(text)``; its result is passed
            to ``np.stack`` to build one token array.
        path: A single file, a directory (walked recursively), or otherwise
            a glob pattern.
        combine: Minimum number of buffered characters before the buffered
            plain text is encoded into one chunk.

    Returns:
        A list of arrays: every array stored in each ``.npz`` file, plus one
        encoded array per flushed batch of plain text.
    """
    if os.path.isfile(path):
        # Simple file
        file_paths = [path]
    elif os.path.isdir(path):
        # Directory: collect every file below it
        file_paths = [
            os.path.join(root, name)
            for (root, _, names) in os.walk(path)
            for name in names
        ]
    else:
        # Anything else is treated as a glob pattern
        file_paths = glob.glob(path)

    chunks = []
    pending = []        # text pieces waiting to be encoded together
    pending_chars = 0   # total length of the pieces in `pending`
    for file_path in tqdm.tqdm(file_paths):
        if file_path.endswith('.npz'):
            # Pre-encoded: take the stored arrays as they are
            with np.load(file_path) as archive:
                for key in archive.files:
                    chunks.append(archive[key])
            continue

        # Plain text
        with open(file_path, 'r', encoding='UTF-8') as fp:
            text = fp.read()
        pending.append(text)
        pending_chars += len(text)

        if pending_chars >= combine:
            # Enough text buffered: encode it as one chunk and start over
            chunks.append(np.stack(enc.encode(''.join(pending))))
            pending = []
            pending_chars = 0
        else:
            # Keep buffering; mark the document boundary
            separator = '<|endoftext|>'
            pending.append(separator)
            pending_chars += len(separator)

    # Encode whatever is still buffered after the last file
    leftover = ''.join(pending)
    if leftover:
        chunks.append(np.stack(enc.encode(leftover)))
    return chunks

def load_raw_dataset(enc, path, combine):
    """Measure the text found under *path* in words, characters and tokens.

    Args:
        enc: Encoder object exposing ``encode(text)``; its result is passed
            to ``np.stack`` and the length of that array is the token count.
        path: A single file, a directory (walked recursively), or otherwise
            a glob pattern. Every file is read as UTF-8 text.
        combine: Not used by this function.

    Returns:
        A tuple ``(word_count, character_count, token_count)`` computed over
        the concatenation of all files.
    """
    if os.path.isfile(path):
        # Simple file
        file_paths = [path]
    elif os.path.isdir(path):
        # Directory: collect every file below it
        file_paths = [
            os.path.join(root, name)
            for (root, _, names) in os.walk(path)
            for name in names
        ]
    else:
        # Anything else is treated as a glob pattern
        file_paths = glob.glob(path)

    pieces = []
    for file_path in tqdm.tqdm(file_paths):
        with open(file_path, 'r', encoding='UTF-8') as fp:
            pieces.append(fp.read())
    text = ''.join(pieces)

    encoded = np.stack(enc.encode(text))
    word_count = len(text.strip().split())  # whitespace-separated words
    char_count = len(text)
    token_count = len(encoded)
    return word_count, char_count, token_count


def binary_search(f, lo, hi):
    """Bisect between *lo* and *hi* for the point where *f* turns true.

    The search keeps ``f(low)`` false and ``f(high)`` true while halving the
    interval, so the returned value ``r`` satisfies ``f(r)`` with ``r - 1``
    being a point where *f* was false (or *lo* itself). For a predicate that
    is false up to some point and true afterwards, this is the first true
    point.

    Returns:
        The located point, or ``None`` when ``f(lo)`` is already true or
        ``f(hi)`` is false.
    """
    # The bracket is only valid when f is false at lo and true at hi.
    if f(lo):
        return None
    if not f(hi):
        return None

    low, high = lo, hi
    while high > low + 1:
        middle = (low + high) // 2
        # Move whichever end keeps the false/true bracket intact.
        low, high = (low, middle) if f(middle) else (middle, high)
    return high


class Sampler(object):
    """Fairly samples a slice from a set of variable sized chunks.

    'Fairly' means that the distribution is the same as sampling from one concatenated chunk,
    but without crossing chunk boundaries."""

    def __init__(self, chunks, seed=None):
        """Store the chunks and precompute their offsets.

        Args:
            chunks: Sequence of arrays; only ``shape[0]`` and slicing are used.
            seed: Seed for the private ``np.random.RandomState``.
        """
        self.chunks = chunks
        self.total_size = sum(chunk.shape[0] for chunk in chunks)
        # boundaries[i] is the offset of chunks[i] inside the virtual
        # concatenation of all chunks; the final entry equals total_size.
        self.boundaries = [0]
        for chunk in chunks:
            self.boundaries.append(self.boundaries[-1] + chunk.shape[0])
        self.rs = np.random.RandomState(seed=seed)

    def _check_length(self, length):
        """Assert that *length* is below the average chunk size."""
        assert length < self.total_size // len(
            self.chunks
        ), "Dataset files are too small to sample {} tokens at a time".format(
            length)

    def _all_tokens(self):
        """Return every token as one array (the only chunk, or a concatenation)."""
        if len(self.chunks) == 1:
            # Avoid a copy in the common single-file case.
            return self.chunks[0]
        return np.concatenate(self.chunks)

    def _draw_window(self, span):
        """Draw a random window of *span* tokens lying inside a single chunk.

        Returns:
            ``(chunk_index, offset_within_chunk)`` of the window start.
        """
        while True:
            index = self.rs.randint(0, self.total_size - span - 1)
            # Chunk containing `index`: the last boundary that is <= index.
            i = bisect.bisect_right(self.boundaries, index) - 1
            # Reject windows that would run past the end of that chunk.
            if self.boundaries[i + 1] > index + span:
                return i, index - self.boundaries[i]

    def sample(self, length):
        """Return a random slice of *length* tokens taken from one chunk.

        Raises:
            AssertionError: if *length* is not below the average chunk size.
        """
        self._check_length(length)
        i, within_chunk = self._draw_window(length)
        return self.chunks[i][within_chunk:within_chunk + length]

    def sample_all(self, length):
        """Split the whole dataset into consecutive windows of *length* tokens.

        All chunks are treated as one concatenated sequence. The final window
        holds whatever remains and may be shorter than *length*.

        Returns:
            A list of arrays covering every token exactly once, in order.

        Raises:
            AssertionError: if *length* is not below the average chunk size.
        """
        self._check_length(length)
        tokens = self._all_tokens()
        return [
            tokens[start:start + length]
            for start in range(0, self.total_size, length)
        ]

    def sample_rgbn(self, length, sent_J):
        """Return *sent_J* consecutive slices of *length* tokens each.

        The whole run of ``length * sent_J`` tokens is drawn at a random
        position and is guaranteed to lie inside a single chunk.

        Returns:
            A list of *sent_J* arrays, each of *length* tokens, contiguous
            in the source chunk.

        Raises:
            AssertionError: if *length* is not below the average chunk size,
                or if no chunk is long enough to hold the whole run.
        """
        self._check_length(length)
        span = length * sent_J
        # Without a chunk that can hold the full run the retry loop in
        # _draw_window could never succeed, so fail loudly instead.
        assert any(
            chunk.shape[0] > span + 1 for chunk in self.chunks
        ), "Dataset files are too small to sample {} tokens at a time".format(
            span)
        i, within_chunk = self._draw_window(span)
        chunk = self.chunks[i]
        return [
            chunk[within_chunk + k * length:within_chunk + (k + 1) * length]
            for k in range(sent_J)
        ]

    def sample_all_rgbn(self, length, sent_J):
        """Build, for every full window, the group of *sent_J* windows ending at it.

        All chunks are treated as one concatenated sequence, cut into
        windows of *length* tokens (the final window of the cut is left out
        of the per-window groups). Group ``k`` holds windows
        ``k - sent_J + 1 .. k``; positions before the start of the data are
        filled with zero arrays of the token dtype. One extra entry is
        appended at the end: the last ``length * sent_J`` tokens reshaped to
        ``[sent_J, length]``.

        Returns:
            A list whose entries are lists of *sent_J* arrays, followed by
            one 2-D array for the tail of the data.

        Raises:
            AssertionError: if *length* is not below the average chunk size.
            ValueError: from ``np.reshape`` when the data holds fewer than
                ``length * sent_J`` tokens.
        """
        self._check_length(length)
        tokens = self._all_tokens()
        starts = range(0, self.total_size, length)
        windows = [tokens[start:start + length] for start in starts[:-1]]

        groups = []
        for k in range(len(windows)):
            if k < sent_J:
                # Not enough history yet: left-pad with zero windows that
                # share the token dtype so stacking does not upcast.
                group = [
                    np.zeros(length, dtype=tokens.dtype)
                    for _ in range(sent_J - k - 1)
                ]
                group.extend(windows[0:k + 1])
            else:
                group = windows[(k - sent_J + 1):k + 1]
            groups.append(group)

        # Tail group aligned to the end of the data.
        tail = np.reshape(tokens[-length * sent_J:], [sent_J, length])
        groups.append(tail)
        return groups
