import os
import logging
import torch
import numpy as np

from multiprocessing import Pool
from transformers.tokenization_utils import PreTrainedTokenizerBase
from tqdm import tqdm

from . import data_utils

# Module-level logger named "Preprocess"; DatasetBinarizer reports progress through it.
logger = logging.getLogger("Preprocess")


class Writer:
    """
    Writer that writes tokenized id arrays into a data file and an index file.

    Each call to :meth:`add_items` appends the raw C-order bytes of one array
    to the data file and records ``len(items)``. :meth:`finalize` stores the
    recorded lengths with ``np.save`` in the index file and closes both file
    handles. The writer can also be used as a context manager: it finalizes on
    normal exit and only closes the files when an exception escapes.
    """
    def __init__(self, outfile_pref: str):
        """Open the data and index files derived from ``outfile_pref`` for binary writing.

        Args:
            outfile_pref: Path prefix; the concrete file names come from
                ``data_utils.bin_file_path`` / ``data_utils.idx_file_path``.
        """
        self.out_file = open(data_utils.bin_file_path(outfile_pref), 'wb')
        self.sizes = []  # length of every array passed to add_items, in order
        try:
            self.idx_file = open(data_utils.idx_file_path(outfile_pref), 'wb')
        except BaseException:
            # Do not leak the already-opened data file if the index file cannot be opened.
            self.out_file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is None:
            self.finalize()
        else:
            # On failure, don't write an index for partial data; just release the files.
            self._close()
        return False

    def add_items(self, items):
        """Write items to the output file.

        Args:
            items: numpy array (assumed 1-D — TODO confirm for other callers);
                its raw bytes are appended and ``len(items)`` (first-axis
                length) is recorded as its size.
        """
        self.out_file.write(items.tobytes(order="C"))
        self.sizes.append(len(items))

    def finalize(self):
        """Write the index file (array of recorded sizes) and close both file handles."""
        try:
            # Explicit integer dtype: np.array([]) would otherwise be float64
            # when no items were written.
            np.save(self.idx_file, np.array(self.sizes, dtype=np.int64))
        finally:
            self._close()

    def _close(self):
        """Close both file handles, closing the index file even if the data file close fails."""
        try:
            self.out_file.close()
        finally:
            self.idx_file.close()


class Binarizer:
    """
    Tokenize string lines into id tensors using a tokenizer / vocabulary,
    inspired by fairseq.
    """
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        append_eos: bool = True,
        already_numberized: bool = False,
    ):
        """
        Args:
            tokenizer: Tokenizer used by ``encode`` and whose ``eos_token_id``
                is appended when ``append_eos`` is set.
            append_eos: Append the tokenizer's EOS id to every tokenized line.
            already_numberized: Treat each line as whitespace-separated integer
                ids instead of raw text.
        """
        self.tokenizer = tokenizer
        self.append_eos = append_eos
        self.already_numberized = already_numberized
        # Integer dtype chosen by data_utils for this vocabulary size
        # (presumably the narrowest one that fits every id — confirm in data_utils).
        self.dtype = data_utils.best_fitting_int_dtype(len(tokenizer))

    def tokenize_line_to_ids(self, line: str) -> torch.IntTensor:
        """Tokenize a line into a 1-D int32 tensor of ids.

        Raises:
            ValueError: if ``append_eos`` is set but the tokenizer has no
                ``eos_token_id``.
        """
        if self.already_numberized:
            ids = [int(token) for token in line.strip().split()]
        else:
            # Copy so appending EOS never mutates a list owned by the tokenizer.
            ids = list(self.tokenizer.encode(line))

        if self.append_eos:
            eos_id = self.tokenizer.eos_token_id
            if eos_id is None:
                # Without this check, None ends up in the list and torch fails
                # with an unrelated-looking type error.
                raise ValueError(
                    "append_eos=True but the tokenizer defines no eos_token_id"
                )
            ids.append(eos_id)
        return torch.IntTensor(ids)

    def save_tokenizer(self, save_dir: str):
        """Save the tokenizer to ``save_dir``, creating the folder if needed."""
        os.makedirs(save_dir, exist_ok=True)
        self.tokenizer.save_pretrained(save_dir)


class DatasetBinarizer:
    """
    Dataset builder to binarize raw datasets.

    Files (or offset ranges of files) are tokenized in worker processes and
    the resulting id arrays are written, in submission order, into one
    data/index file pair through :class:`Writer`.
    """
    @classmethod
    def process_file_or_chunk(
        cls,
        binarizer: Binarizer,
        filename: str,
        start_offset: int,
        end_offset: int
    ) -> np.ndarray:
        """Tokenize a file or a chunk of it and return all ids as one flat array.

        Args:
            binarizer: Binarizer used for every line; its ``dtype`` is the
                dtype of the returned array.
            filename: File to read.
            start_offset: Start offset passed to ``data_utils.get_chunk_iterator``.
            end_offset: End offset passed to ``data_utils.get_chunk_iterator``.

        Returns:
            The per-line ids concatenated in line order. When the file or
            chunk yields no lines, an empty array of ``binarizer.dtype``.
        """
        dtype = binarizer.dtype

        per_line = [
            np.array(binarizer.tokenize_line_to_ids(line), dtype=dtype)
            for line in data_utils.get_chunk_iterator(filename, start_offset, end_offset)
        ]
        if not per_line:
            # np.concatenate raises on an empty list; an empty file/chunk
            # should contribute zero tokens instead of crashing the worker.
            return np.empty(0, dtype=dtype)
        return np.concatenate(per_line)

    @classmethod
    def binarize_dataset_with_multiprocess(
        cls,
        binarizer: Binarizer,
        data_dir: str,
        prefix_name: str,
        dest_dir: str,
        save_name: str,
        chunk_load: bool = False,
        num_worker: int = 1,
    ) -> Writer:
        """Binarize the dataset by using multiple processes.

        Args:
            binarizer: Binarizer used to tokenize every line.
            data_dir: Folder searched by ``data_utils.load_files_from_folder``.
            prefix_name: File-name prefix passed to ``load_files_from_folder``.
            dest_dir: Output folder (created if missing); also receives the
                saved tokenizer.
            save_name: Output file prefix inside ``dest_dir``.
            chunk_load: Passed to ``data_utils.get_file_offsets`` to decide
                how files are split into tasks.
            num_worker: Number of worker processes.

        Returns:
            The finalized :class:`Writer` (its file handles are already closed).
        """
        files = data_utils.load_files_from_folder(data_dir, prefix_name)
        file_offsets = data_utils.get_file_offsets(files, chunk_load, num_worker)

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(dest_dir, exist_ok=True)

        writer = Writer(os.path.join(dest_dir, save_name))

        pbar = tqdm(total=len(file_offsets))
        try:
            # The context manager terminates the workers if anything below raises;
            # close() + join() first lets every submitted task run to completion.
            with Pool(num_worker) as pool:
                processed_results = [
                    pool.apply_async(
                        cls.process_file_or_chunk,
                        args=(
                            binarizer,
                            file_offset[0],
                            file_offset[1][0],
                            file_offset[1][1],
                        ),
                        # Invoked in the parent process when a task succeeds.
                        callback=lambda _: pbar.update(1),
                    )
                    for file_offset in file_offsets
                ]
                pool.close()
                pool.join()
        finally:
            pbar.close()

        logger.info(f"Binarized dataset completed. Write them to {dest_dir}/{save_name}.bin ...")
        binarizer.save_tokenizer(dest_dir)
        # Results are written in submission order, not completion order.
        for result in tqdm(processed_results):
            writer.add_items(result.get())

        writer.finalize()
        return writer
    