import torch
import os
import numpy as np

from .binarizer import (
    index_file_path,
    data_file_path,
    best_fitting_int_dtype,
    get_tokenizer
)
from functools import lru_cache
from typing import Optional, List, Dict, Any


class LMMemmapDataset(torch.utils.data.Dataset):
    """
    LM Dataset that maps binarized token data from disk via np.memmap.

    The binarized file is treated as a flat sequence of token ids stored
    with ``self.dtype`` and is served as consecutive, non-overlapping
    blocks of ``block_size`` tokens.

    Args:
        path: str, the directory that contains the training data
        tokenizer: str, the name of tokenizer
        split: str, the prefix for training data (joined with ``path`` to
            locate the index and data files)
        block_size: int, the maximum length of training data
    """
    def __init__(
        self,
        path: str,
        tokenizer: str,
        split: str = 'train',
        block_size: int = 1024,
    ):
        self.idx_path = index_file_path(os.path.join(path, split))
        self.bin_path = data_file_path(os.path.join(path, split))
        self.tokenizer = get_tokenizer(tokenizer)
        self.block_size = block_size
        # The on-disk integer width is derived from the vocabulary size.
        self.dtype = best_fitting_int_dtype(len(self.tokenizer))

        self.do_init()

    def do_init(self):
        """
        Load the index array and memory-map the binarized data file.

        Raises:
            FileNotFoundError: if the index file or the data file is missing.
        """
        if not os.path.exists(self.idx_path):
            raise FileNotFoundError(f'Not found index data: {self.idx_path}')
        if not os.path.exists(self.bin_path):
            raise FileNotFoundError(f'Not found binarized data: {self.bin_path}')

        self.idx = np.load(self.idx_path)
        # Sum once, vectorized and in 64 bits: ``len()`` is called for every
        # bounds check, so re-summing the index per call would be wasteful.
        self._token_num = int(np.sum(self.idx, dtype=np.int64))
        self._bin_buffer_mmap = np.memmap(
            self.bin_path, mode='r', order='C'
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __len__(self) -> int:
        """Return the number of complete ``block_size``-token blocks."""
        return self.token_num // self.block_size

    @property
    def token_num(self) -> int:
        """Total token count, i.e. the sum of the entries of the index array."""
        return self._token_num

    def __getitem__(self, i):
        """
        Return block ``i`` as an int64 tensor of shape ``(1, block_size)``.

        Negative indices count from the end. No caching is done here on
        purpose: a method-level ``lru_cache`` would keep the dataset (and
        its memory map) alive forever and share one mutable tensor between
        callers.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        num_blocks = len(self)
        block = i + num_blocks if i < 0 else i
        if not 0 <= block < num_blocks:
            raise IndexError(
                f'Block index {i} out of range for {num_blocks} blocks'
            )

        # ``offset`` is expressed in bytes, so scale by the real item size
        # instead of assuming 2-byte tokens (wrong once vocab needs > 16 bits).
        itemsize = np.dtype(self.dtype).itemsize
        ptr = block * self.block_size * itemsize
        # ``astype`` copies, so the result does not alias the read-only map.
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self.dtype, count=self.block_size, offset=ptr
        ).astype(np.int64)
        # Sanity check: every stored id must be a valid vocabulary id.
        assert int(np_array.max()) < len(self.tokenizer)

        return torch.from_numpy(np_array[np.newaxis, :])

    def __del__(self):
        """Close the memory map; safe to call more than once."""
        if hasattr(self, "_bin_buffer_mmap"):
            # Drop the view first so nothing points into the closed mapping.
            self.__dict__.pop("_bin_buffer", None)
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap
            self.__dict__.pop("idx", None)


class MultipleCorpusMemmapDataset(torch.utils.data.Dataset):
    """
    Dataset that integrates multiple LM datasets into one index space.

    Sub-datasets are concatenated in the order of ``splits``: global index
    ``0`` is the first block of the first split, and so on.

    Args:
        path: str, the path that include training data
        tokenizer: str, the name of tokenizer
        splits: List[str], the prefix for the collection of training data
        block_size: int, the maximum length of training data
    """
    def __init__(
        self,
        path: str,
        tokenizer: str,
        splits: List[str],
        block_size: int = 1024,
    ):
        self.datasets = [
            LMMemmapDataset(
                path=path,
                tokenizer=tokenizer,
                split=split,
                block_size=block_size,
            )
            for split in splits
        ]
        # Per-dataset block counts, captured once; used to route indices.
        self.lens = [len(ds) for ds in self.datasets]
        self.block_size = block_size
        self.tokenizer = get_tokenizer(tokenizer)

    def __getitem__(self, index):
        """
        Return the block at global position ``index``.

        Negative indices count from the end. No method-level ``lru_cache``
        is used: it would pin this object (and every memory map it owns)
        in a class-level cache.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                Raising (instead of silently returning the first sample)
                is also what lets plain ``for`` iteration terminate.
        """
        total = len(self)
        position = index + total if index < 0 else index
        if not 0 <= position < total:
            raise IndexError(
                f'Index {index} out of range for {total} blocks'
            )

        # Walk the sub-datasets, converting the global position into a
        # local one by subtracting the sizes of the datasets skipped.
        for dataset, size in zip(self.datasets, self.lens):
            if position < size:
                return dataset[position]
            position = position - size

        # Only reachable if ``lens`` and ``datasets`` are out of sync.
        raise IndexError(f'Index {index} out of range for {total} blocks')

    @property
    def token_num(self) -> int:
        """Number of tokens served, counted in whole blocks only."""
        return len(self) * self.block_size

    def __len__(self) -> int:
        """Return the total number of blocks across all sub-datasets."""
        return sum(self.lens)

    def __del__(self):
        """Release every sub-dataset's memory map."""
        # ``datasets`` is missing if ``__init__`` failed while building it.
        for ds in getattr(self, "datasets", ()):
            ds.__del__()