import torch
import os
import logging
import numpy as np

from typing import Any, List, Union, Optional
from tqdm import tqdm
from transformers import AutoTokenizer
from multiprocessing import Pool
from LLMProxy.option import PreprocessArg

logger = logging.getLogger(name="Preprocess")  # shared logger for all preprocessing steps in this module


def best_fitting_int_dtype(
    max_int_to_represent: Optional[int],
) -> type:
    """
    Return the smallest NumPy integer type for values below ``max_int_to_represent``.

    Borrowed from fairseq (thresholds copied verbatim):

    * ``None``            -> ``np.uint32`` (safe guess when the bound is unknown)
    * ``< 65500``         -> ``np.uint16``
    * ``< 4294967295``    -> ``np.uint32``
    * otherwise           -> ``np.int64``

    The NumPy scalar *type* itself is returned (usable as ``dtype=``), not an
    instance of it, hence the ``type`` return annotation.
    """
    if max_int_to_represent is None:
        return np.uint32  # Safe guess
    if max_int_to_represent < 65500:
        return np.uint16
    if max_int_to_represent < 4294967295:
        return np.uint32
    return np.int64


def load_files_from_data_folder(data_dir):
    """
    Return the paths of the regular files directly inside ``data_dir``.

    Paths are sorted by file name so that the order in which files are
    binarized (and therefore the layout of the output ``.bin``/``.idx``) does
    not depend on the arbitrary order returned by ``os.listdir``.
    Sub-directories are skipped, since the binarizers open every returned
    path as a text file.

    Raises:
        FileNotFoundError: if ``data_dir`` does not exist.
    """
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f'Not found data folder: {data_dir}')

    candidates = (os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir)))
    return [path for path in candidates if os.path.isfile(path)]


def index_file_path(prefix_path):
    """
    Return the index-file path ``<prefix_path>.idx`` as a ``str``.

    ``prefix_path`` may be a ``str`` or any ``os.PathLike`` such as
    ``pathlib.Path``; ``str`` input gives exactly ``prefix_path + ".idx"``.
    """
    return os.fspath(prefix_path) + ".idx"


def data_file_path(prefix_path):
    """
    Return the data-file path ``<prefix_path>.bin`` as a ``str``.

    ``prefix_path`` may be a ``str`` or any ``os.PathLike`` such as
    ``pathlib.Path``; ``str`` input gives exactly ``prefix_path + ".bin"``.
    """
    return os.fspath(prefix_path) + ".bin"


def get_tokenizer(tokenizer, token=""):
    """
    Load a Hugging Face tokenizer via ``AutoTokenizer.from_pretrained``.

    Args:
        tokenizer: model id or local path forwarded as the first argument.
        token: Hugging Face access token. An empty string (the default) is
            forwarded as ``None``, i.e. the library's own default, instead of
            being sent as an explicit (empty) token.
    """
    return AutoTokenizer.from_pretrained(tokenizer, legacy=False, token=token or None)


def _safe_readline(fd) -> str:
    # Borrow from fairseq, this part is to guarantee the boundary reading is safe
    pos = fd.tell()
    while True:
        try:
            return fd.readline()
        except UnicodeDecodeError:
            pos -= 1
            fd.seek(pos)  # search where this character begins


def find_offsets(filename: str, num_chunks: int) -> List[int]:
    """
    Split ``filename`` into ``num_chunks`` byte ranges aligned to line ends.

    Returns ``num_chunks + 1`` positions. The first is ``0`` and the last is
    the file size in bytes. Each interior boundary is the position just after
    the line that contains byte ``i * (size // num_chunks)``.
    """
    with open(filename, "r", encoding="utf-8") as handle:
        total_bytes = os.fstat(handle.fileno()).st_size
        step = total_bytes // num_chunks
        boundaries = [0] * (num_chunks + 1)
        for chunk_idx in range(1, num_chunks):
            handle.seek(step * chunk_idx)
            # Discard the (possibly partial) line so the boundary lands on a line start.
            _safe_readline(handle)
            boundaries[chunk_idx] = handle.tell()
        boundaries[-1] = total_bytes
    return boundaries


class Writer:
    """
    Writer that streams tokenized arrays into ``<prefix>.bin`` / ``<prefix>.idx``.

    Every array given to :meth:`add_items` is appended to the ``.bin`` file as
    raw C-order bytes, and its length is recorded. :meth:`finalize` saves those
    lengths to the ``.idx`` file with ``np.save`` as an int64 array, then
    closes both files.

    The writer can also be used as a context manager. A clean exit finalizes
    it; an exception only closes the files, without writing the index.
    """
    def __init__(
        self,
        outfile_pref: str
    ):
        self.out_file = open(data_file_path(outfile_pref), 'wb')
        self.sizes = []
        try:
            self.idx_file = open(index_file_path(outfile_pref), 'wb')
        except BaseException:
            # Don't leak the already-opened data file if the index can't be created.
            self.out_file.close()
            raise

    def add_items(self, items):
        """Append one numpy array of ids to the data file and record its length."""
        self.out_file.write(items.tobytes(order="C"))
        self.sizes.append(len(items))

    def finalize(self):
        """Write the index file and close both files. Calling it again is a no-op."""
        if self.idx_file.closed:
            return
        # Explicit int64: np.array([]) would otherwise be float64, and a list of
        # Python ints maps to the platform C long (int32 on Windows before NumPy 2).
        np.save(self.idx_file, np.array(self.sizes, dtype=np.int64))
        self.close()

    def close(self):
        """Close both files without writing the index."""
        self.out_file.close()
        self.idx_file.close()

    def summary(self):
        """Placeholder; currently does nothing."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is None:
            self.finalize()
        else:
            self.close()
        return False


class Binarizer(object):
    """
    Convert text lines into token-id tensors (design inspired by fairseq).

    Args:
        tokenizer: name or path handed to ``get_tokenizer``.
        append_eos: if True, the tokenizer's ``eos_token_id`` is appended to
            every encoded line.
        already_numberized: if True, each line is parsed as whitespace-separated
            integer ids instead of being run through the tokenizer.
        token: access token forwarded to ``get_tokenizer``.
    """
    def __init__(
        self,
        tokenizer: str,
        append_eos: bool = True,
        already_numberized: bool = False,
        token: str = "",
    ):
        self.tokenizer = get_tokenizer(tokenizer=tokenizer, token=token)
        self.append_eos = append_eos
        self.already_numberized = already_numberized

        logger.info(f"Using {tokenizer} to tokenize corpus ...")

    def tokenize_line_to_ids(
        self,
        line: str,
        **kwargs,
    ):
        """Encode ``line`` and return the ids as a 1-D ``torch.IntTensor`` (int32)."""
        if not self.already_numberized:
            token_ids = self.tokenizer.encode(line)
        else:
            token_ids = [int(piece) for piece in line.strip().split()]

        if self.append_eos:
            token_ids.append(self.tokenizer.eos_token_id)
        return torch.IntTensor(token_ids)


class FileBinarizer(object):
    """
    Tokenize (small) files whole, one pool task per file.

    ``process_file`` returns all ids of one file as a 1-D ``np.ndarray``.
    ``multiprocess_files`` runs it for every file in ``args.data_dir`` and
    writes the results with a ``Writer``, one index entry per input file.
    """
    @classmethod
    def process_file(
        cls,
        binarizer: Binarizer,
        filename: str,
        **kwargs
    ):
        """
        Tokenize every line of ``filename`` and concatenate the ids.

        Returns:
            1-D ``np.ndarray`` with dtype picked by ``best_fitting_int_dtype``
            from the tokenizer's length. The array is empty when the file has
            no lines.
        """
        vocab_size = len(binarizer.tokenizer)
        dtype = best_fitting_int_dtype(vocab_size)

        ids = []
        with open(filename, 'r', encoding='utf-8') as input_file:
            for line in input_file:
                ids.append(np.array(binarizer.tokenize_line_to_ids(line), dtype=dtype))
        if not ids:
            # np.concatenate([]) raises ValueError; an empty file yields an empty array.
            return np.array([], dtype=dtype)
        return np.concatenate(ids)

    # Use a process pool to tokenize a dataset made of many files.
    @classmethod
    def multiprocess_files(
        cls,
        binarizer: Binarizer,
        args: PreprocessArg,
        **kwargs
    ):
        """
        Binarize every file in ``args.data_dir`` with ``args.worker`` processes.

        Output goes to ``<args.dest_dir>/<args.save_name>.bin`` / ``.idx``, in
        the order returned by ``load_files_from_data_folder``.

        Returns:
            The finalized ``Writer``.
        """
        files = load_files_from_data_folder(args.data_dir)

        os.makedirs(args.dest_dir, exist_ok=True)
        save_pref = os.path.join(args.dest_dir, args.save_name)
        logger.info(f"Binarize datasets into tensors")

        writer = Writer(save_pref)

        # Context managers ensure the progress bar is closed and the pool's
        # workers are torn down even if submission or joining fails.
        with tqdm(total=len(files)) as pbar, Pool(args.worker) as pool:
            processed_results = [
                pool.apply_async(
                    cls.process_file,
                    args=(
                        binarizer,
                        file,
                    ),
                    callback=lambda _: pbar.update(1),
                )
                for file in files
            ]
            pool.close()
            pool.join()

        logger.info(f"Binarized dataset completed. Write them to {args.dest_dir}/{args.save_name}.bin ...")
        # Every result is ready after join(); get() re-raises a worker's exception.
        for result in tqdm(processed_results):
            writer.add_items(result.get())

        writer.finalize()
        return writer


# Designed for processing large files: each file is split into line-aligned chunks that are tokenized in parallel.
class LargeFileBinarizer(object):
    """
    Tokenize large files by splitting each one into byte-range chunks.

    For every file, ``find_offsets`` produces ``args.worker`` line-aligned
    chunks. Each chunk is tokenized as a separate pool task and written as
    one index entry. A chunk may contain no lines, in which case its entry
    has length 0.
    """

    @classmethod
    def get_chunk_iterator(cls, filename: str, start_offset: int, end_offset: int):
        """
        Yield the lines of ``filename`` from ``start_offset`` up to ``end_offset``.

        A line is yielded unless the file position after reading it is past
        ``end_offset``. A non-positive ``end_offset`` disables the bound.
        """
        with open(filename, 'r', encoding='utf-8') as input_file:
            input_file.seek(start_offset)
            line = _safe_readline(input_file)

            while line:
                pos = input_file.tell()
                # Heuristic kept from fairseq: text-mode tell() returns an opaque
                # cookie that can occasionally be a huge number, so only positions
                # within 2**32 of end_offset are treated as "past the chunk".
                if (
                    end_offset > 0
                    and pos > end_offset
                    and pos < end_offset + 2**32
                ):
                    break
                yield line
                line = input_file.readline()

    @classmethod
    def process_file_chunk(
        cls,
        binarizer: Binarizer,
        filename: str,
        start_offset: int,
        end_offset: int,
        **kwargs
    ):
        """
        Tokenize the lines of one chunk and concatenate their ids.

        Returns:
            1-D ``np.ndarray`` with dtype picked by ``best_fitting_int_dtype``.
            The array is empty when the chunk contains no lines, which happens
            e.g. when a file has fewer lines than chunks.
        """
        vocab_size = len(binarizer.tokenizer)
        dtype = best_fitting_int_dtype(vocab_size)

        ids = [
            np.array(binarizer.tokenize_line_to_ids(line), dtype=dtype)
            for line in cls.get_chunk_iterator(filename, start_offset, end_offset)
        ]
        if not ids:
            # np.concatenate([]) raises ValueError; an empty chunk yields an empty array.
            return np.array([], dtype=dtype)
        return np.concatenate(ids)

    @classmethod
    def process_large_file(
        cls,
        binarizer: Binarizer,
        args: PreprocessArg,
        **kwargs
    ):
        """
        Binarize every file in ``args.data_dir``, each split into ``args.worker`` chunks.

        Output goes to ``<args.dest_dir>/<args.save_name>.bin`` / ``.idx``,
        in file order and then chunk order.

        Returns:
            The finalized ``Writer``.
        """
        files = load_files_from_data_folder(args.data_dir)

        os.makedirs(args.dest_dir, exist_ok=True)
        save_pref = os.path.join(args.dest_dir, args.save_name)
        logger.info(f"Binarize datasets into tensors")

        writer = Writer(save_pref)

        for file in files:
            file_offset = find_offsets(file, args.worker)
            offsets = list(zip(file_offset[:-1], file_offset[1:]))
            # Context managers ensure the progress bar is closed and the pool's
            # workers are torn down even if submission or joining fails.
            with tqdm(total=len(offsets)) as pbar, Pool(args.worker) as pool:
                processed_results = [
                    pool.apply_async(
                        cls.process_file_chunk,
                        args=(
                            binarizer,
                            file,
                            start,
                            end,
                        ),
                        callback=lambda _: pbar.update(1),
                    )
                    for start, end in offsets
                ]
                pool.close()
                pool.join()

            logger.info(f"Binarized {file} completed. Write them to {args.dest_dir}/{args.save_name}.bin ...")

            # Every result is ready after join(); get() re-raises a worker's exception.
            for result in tqdm(processed_results):
                writer.add_items(result.get())

        writer.finalize()
        return writer


