import logging
import os
from functools import lru_cache
from multiprocessing import Pool
from typing import List

import numpy as np
import torch

from CoLM.option import TrainArg

from . import data_utils


class LMMemmapDataset(torch.utils.data.Dataset):
    """
    LM dataset that reads a binarized token stream through ``np.memmap``.

    The token stream is cut into consecutive, non-overlapping blocks of
    ``block_size`` tokens; item ``i`` is the ``i``-th block, returned as an
    int64 tensor of shape ``(1, block_size)``.

    Args:
        path: str, the directory that contains the training data
        vocab_size: int, the vocabulary size; selects the on-disk token dtype
            (via ``data_utils.best_fitting_int_dtype``) and bounds the decoded
            token ids
        split: str, the file prefix of the training data inside ``path``
        block_size: int, the number of tokens per sample
    """
    def __init__(
        self,
        path: str,
        vocab_size: int,
        split: str = 'train',
        block_size: int = 1024,
    ):
        self.idx_path = data_utils.idx_file_path(os.path.join(path, split))
        self.bin_path = data_utils.bin_file_path(os.path.join(path, split))
        self.vocab_size = vocab_size
        self.block_size = block_size
        # NOTE(review): assumes the binarizer wrote the binary file with this
        # same dtype helper -- confirm against the preprocessing script.
        self.dtype = data_utils.best_fitting_int_dtype(self.vocab_size)

        self.do_init()

    def do_init(self):
        """(Re)open the index array and the memory-mapped token buffer.

        Called from ``__init__`` and again after unpickling.

        Raises:
            FileNotFoundError: if the index file or the binary file is missing.
        """
        if not os.path.exists(self.idx_path):
            raise FileNotFoundError(f'Not found index data: {self.idx_path}')
        if not os.path.exists(self.bin_path):
            raise FileNotFoundError(f'Not found binarized data: {self.bin_path}')

        self.idx = np.load(self.idx_path)
        # The sum of the index is treated as the total token count; compute it
        # once (in C) instead of a Python-level ``sum`` on every ``len()``.
        self._token_num = int(np.sum(self.idx))
        # Map the file with its real token dtype so slicing is done in token
        # units and vocabularies that do not fit in uint16 decode correctly.
        self._bin_buffer_mmap = np.memmap(
            self.bin_path, dtype=self.dtype, mode='r', order='C'
        )

    @classmethod
    def build_dataset(
        cls,
        path: str,
        vocab_size: int,
        split: str = 'train',
        block_size: int = 1024,
    ):
        """Alternate constructor; equivalent to calling the class directly."""
        return cls(path, vocab_size, split, block_size)

    def __len__(self):
        """Number of complete blocks; a trailing partial block is dropped."""
        return self.token_num // self.block_size

    @property
    def token_num(self) -> int:
        """Total number of tokens according to the index file."""
        return self._token_num

    def __getitem__(self, i) -> torch.Tensor:
        """Return block ``i`` as an int64 tensor of shape ``(1, block_size)``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``; this
                also lets plain ``for``-iteration over the dataset terminate.
            ValueError: if the binary file holds fewer tokens than the index
                claims, or a token id is not below ``vocab_size``.
        """
        n_blocks = len(self)
        block = i + n_blocks if i < 0 else i
        if not 0 <= block < n_blocks:
            raise IndexError(f'index {i} out of range for {n_blocks} blocks')

        start = block * self.block_size
        # Slicing the memmap is a view; ``np.array`` copies it into a plain
        # int64 ndarray, so the tensor never references the mapped file.
        tokens = np.array(
            self._bin_buffer_mmap[start:start + self.block_size], dtype=np.int64
        )
        if tokens.size != self.block_size:
            raise ValueError(
                f'binarized data too short: block {block} has {tokens.size} '
                f'of {self.block_size} tokens'
            )
        if tokens.size and tokens.max() >= self.vocab_size:
            raise ValueError(
                f'block {block} contains a token id >= vocab_size ({self.vocab_size})'
            )

        return torch.from_numpy(tokens[np.newaxis, :])

    def __getstate__(self):
        # Pickling an np.memmap serializes the array data itself; ship only
        # paths/config and re-open the files on the receiving side instead.
        state = self.__dict__.copy()
        for key in ('_bin_buffer_mmap', 'idx', '_token_num'):
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.do_init()

    def close(self):
        """Unmap the token buffer and drop the index; safe to call repeatedly."""
        if hasattr(self, '_bin_buffer_mmap'):
            raw = getattr(self._bin_buffer_mmap, '_mmap', None)
            if raw is not None:
                raw.close()
            del self._bin_buffer_mmap
        if hasattr(self, 'idx'):
            del self.idx

    def __del__(self):
        self.close()


class LMCorpusMemmapDataset(torch.utils.data.Dataset):
    """
    Dataset that concatenates several LM datasets end to end.

    Global index ``k`` walks the splits in the order given by ``splits``: the
    first ``len(split_0)`` indices address split 0, the next ``len(split_1)``
    address split 1, and so on.

    Args:
        path: str, the directory that contains the training data
        vocab_size: int, the vocabulary size forwarded to every split
        splits: List[str], the file prefixes of the splits to concatenate
        block_size: int, the number of tokens per sample
    """
    def __init__(
        self,
        path: str,
        vocab_size: int,
        splits: List[str],
        block_size: int = 1024,
    ):
        logging.getLogger(__name__).info(
            'Loading LM corpus splits %s from %s', splits, path
        )
        self.datasets = [
            LMMemmapDataset.build_dataset(
                path=path,
                vocab_size=vocab_size,
                split=split,
                block_size=block_size,
            )
            for split in splits
        ]
        # Per-split sample counts, computed once; used for index routing.
        self.lens = [len(ds) for ds in self.datasets]
        self.block_size = block_size

    def __getitem__(self, index) -> torch.Tensor:
        """Return global sample ``index`` from the split that owns it.

        Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``;
                this also lets plain ``for``-iteration terminate.
        """
        total = len(self)
        local = index + total if index < 0 else index
        if not 0 <= local < total:
            raise IndexError(f'index {index} out of range for corpus of {total} samples')

        for dataset, n_samples in zip(self.datasets, self.lens):
            if local < n_samples:
                return dataset[local]
            local -= n_samples
        raise AssertionError('unreachable: index was range-checked above')

    @property
    def token_num(self) -> int:
        """Number of tokens covered by complete blocks across all splits."""
        return len(self) * self.block_size

    def __len__(self) -> int:
        return sum(self.lens)

    def __del__(self):
        # ``datasets`` is missing if __init__ raised before assigning it.
        for ds in getattr(self, 'datasets', ()):
            ds.__del__()

    @classmethod
    def build_dataset(cls, splits: List[str], args: 'TrainArg', vocab_size: int):
        """Build the corpus from ``args.data_dir`` using ``args.seq_len`` blocks."""
        return cls(
            path=args.data_dir,
            vocab_size=vocab_size,
            splits=splits,
            block_size=args.seq_len,
        )