from pathlib import Path
from typing import List
from ctypes import c_uint64
import threading
import concurrent.futures
from .cfg_tokenizer import CFGTokenizer
import numpy as np
import polars as pl
import time
import gc


def replace_window_values(a, b, window_size, custom_value, offset_0=0, offset_1=0, inplace=True):
    """
    Overwrite one run of consecutive columns in every row of a 2D array.

    For row ``i`` the run starts at column ``b[i] - window_size + offset_0`` and covers
    ``window_size + offset_0 + offset_1`` columns, so it only ends at ``b[i] + offset_1``
    when ``offset_0`` is 0.

    :param a: 2D array to modify, one row per batch element
    :param b: 1D integer array with one reference column per row of ``a``
    :param window_size: base length of the run
    :param custom_value: value written into the run
    :param offset_0: shift applied to the start of the run, also added to its length
    :param offset_1: amount added to the length of the run
    :param inplace: when False a copy of ``a`` is modified and ``a`` is left untouched
    :return: the modified array (``a`` itself when ``inplace`` is True)

    :note: columns are used as plain NumPy indices, so negative columns address the row from its end.
    """
    target = a if inplace else a.copy()

    # Length and per-row first column of the run
    span = window_size + offset_0 + offset_1
    first_col = b - window_size + offset_0

    # Row selector of shape (rows, 1) broadcast against column selector of shape (rows, span)
    rows = np.arange(target.shape[0]).reshape(-1, 1)
    cols = first_col[:, None] + np.arange(span)[None, :]

    target[rows, cols] = custom_value
    return target


class RingBuffer:
    def __init__(self, mem: int, width: int, column: int, dtype=np.int32):
        """ Ring Buffer
        :param mem: maximum allocated memory for each column in bytes
        :param width: width of buffer, height is calculated using ``mem``
        :param column: number of columns in the buffer, the read and write indicator of each column are kept the same
        :param dtype: dtype for the returned tensors (integer or floating point)
        :raises ValueError: if ``mem`` cannot hold a single row of ``width`` items
        """
        self.read_idx = c_uint64(0)      # Total rows consumed so far; only updated while holding ``lock``
        self.write_idx = c_uint64(0)     # Total rows produced so far; wraps at 2**64
        self.pause = False               # Reported (inverted) by ``writable``; ``write`` itself does not check it

        self.lock = threading.RLock()  # Reentrant lock for thread safety
        self.data_available = threading.Condition(self.lock)  # Condition variable for reader
        self.space_available = threading.Condition(self.lock)  # Condition variable for writer
        self.shutdown_flag = False

        self.mem = mem
        self.width = width
        self.column = column
        # ``itemsize`` works for every NumPy dtype, unlike ``np.iinfo`` which rejects floats
        self.height = mem // (width * np.dtype(dtype).itemsize)
        if self.height <= 0:
            raise ValueError(f'RingBuffer: {mem} bytes cannot hold a single row of width {width}.')
        self.dtype = dtype
        self.buff = np.zeros((column, self.height, width), dtype=dtype)

    def read(self, length: int):
        """ Gate for buffer reader, blocks until ``length`` rows are available
        :param length: number of rows to read
        :return: a new array shaped (column, length, width); it does not alias the ring storage
        :raises ValueError: if ``length`` exceeds the buffer capacity
        :raises RuntimeError: if the buffer is shutting down
        """
        if length > self.height:
            raise ValueError(f'read: Requested length {length} exceeds buffer capacity {self.height}.')

        with self.data_available:
            while not self.shutdown_flag and self.read_idx.value + length > self.write_idx.value:
                self.data_available.wait()

            if self.shutdown_flag:
                raise RuntimeError('Buffer is shutting down')

            result = self.__read_block(length)
            self.read_idx.value += length
            # Wake every waiting writer: with several writers, the one that was woken by a
            # single notify may still lack space while another one could proceed.
            self.space_available.notify_all()

            return result

    def write(self, data: np.ndarray):
        """ Gate for buffer writer, blocks until there is room for ``data``
        :param data: data to write, shaped (column, rows, width)
        :raises ValueError: if the column count does not match or ``data`` has more rows than the buffer
        :note: the current data between ``write_idx`` and ``write_idx + len(data)``
               will be overwritten if it has already been read.
        :note: returns without writing when the buffer is shutting down.
        """
        if data.shape[0] != self.column:
            raise ValueError(f'write: Column count not match. Expect {self.column} columns but got {data.shape[0]} instead.')

        rows = data.shape[1]
        if rows > self.height:
            # Such a write could never be satisfied and would block forever
            raise ValueError(f'write: Data length {rows} exceeds buffer capacity {self.height}.')

        with self.space_available:
            while not self.shutdown_flag and self.write_idx.value + rows > self.read_idx.value + self.height:
                self.space_available.wait()

            if self.shutdown_flag:
                return  # Don't write if shutting down

            self.__write_block(data)
            self.write_idx.value += rows
            self.data_available.notify_all()

    def __read_block(self, length: int):
        """ Read from buffer, the caller must hold the lock
        :param length: number of rows to read
        :return: a copy of the next ``length`` unread rows
        """
        start_idx = self.read_idx.value % self.height
        # Rows available before the physical end of the ring; comparing lengths (instead of
        # wrapped end indices) keeps a full-capacity read from looking like an empty one.
        head = min(length, self.height - start_idx)

        if head == length:
            # Copy so that a later write cannot modify the rows handed to the reader
            return self.buff[:, start_idx:start_idx + length].copy()

        buff = np.empty((self.column, length, self.width), dtype=self.dtype)
        buff[:, :head] = self.buff[:, start_idx:]
        buff[:, head:] = self.buff[:, :length - head]
        return buff

    def __write_block(self, data: np.ndarray):
        """ Write to buffer, the caller must hold the lock
        :param data: data to write
        :note: the current data between ``write_idx`` and ``write_idx + len(data)``
               will be overwritten if it has already been read.
        """
        rows = data.shape[1]
        start_idx = self.write_idx.value % self.height
        head = min(rows, self.height - start_idx)

        self.buff[:, start_idx:start_idx + head] = data[:, :head]
        if head < rows:
            # Remaining rows wrap around to the beginning of the ring
            self.buff[:, :rows - head] = data[:, head:]

    def writable(self):
        """ Get current writing state """
        return not self.pause

    def shutdown(self):
        """ Shutdown the buffer and wake every blocked reader and writer """
        with self.lock:
            self.shutdown_flag = True
            self.data_available.notify_all()
            self.space_available.notify_all()


class DataLoader:
    def __init__(self, path: List[str], pattern: str, mem: int, batch_size: int, context_len: int, tokenizer: str,
                 pad_token: int, threads: int, ratio: List[float] = None, dtype=np.int32):
        """ Data Loader
        :param path: root directories of the data sources, one entry per source
        :param pattern: file name pattern to filter dataset
        :param mem: memory budget in MB; a third of it is given to each buffer column of every source
        :param batch_size: batch size
        :param context_len: context length of each sample
        :param tokenizer: not used by this class, kept for interface compatibility
        :param pad_token: padding token
        :param threads: number of threads for processing each data source
        :param ratio: share of the batch for every source except the last one, may be omitted for a single source
        :param dtype: dtype for the returned tensors

        :note: each data source has its own buff, and each buff has 2 columns. Column 0 is for data, column 1 is for target.

        :return data: input tokens, shaped (batch_size, context_len)
        :return target: target tokens, shaped (batch_size, context_len)
        :return mask: 2D padding mask, shaped (batch_size, context_len)
        :return indicator: per-row sum of the mask, shaped (batch_size,)
        """
        self.path = path
        self.epoch = 0
        self.ready = True
        self.pattern = pattern
        self.batch_size = batch_size
        self.context_len = context_len
        self.pad_token = pad_token
        self.threads = threads
        self.dtype = dtype

        # Validate the batch split before allocating buffers or starting threads
        self.sub_batch = [0 for _ in range(len(path))]
        self.distribute(ratio)

        self.data_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 2, dtype) for _ in range(len(path))]
        self.mask_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 1, dtype) for _ in range(len(path))]

        self.tokenizer = CFGTokenizer(context_length=context_len, pad_token=pad_token)

        self.shutdown_event = threading.Event()
        # One lock per source: a sample's data rows and mask rows must be written as a pair,
        # otherwise concurrent workers can leave the two rings in different orders.
        self._write_locks = [threading.Lock() for _ in range(len(path))]

        # Pre-allocation
        self.pairs = np.empty((2, batch_size, context_len), dtype=dtype)
        self.mask = np.empty((1, batch_size, context_len), dtype=dtype)

        # Multi-thread setup, done last so that workers only ever see a fully built loader
        self.processors = [
            threading.Thread(
                target=self.__process,
                args=(idx, p)
            ) for idx, p in enumerate(path)
        ]

        for p in self.processors:
            p.daemon = True
            p.start()

    def stop(self):
        """ Stop the worker threads and release the buffers, calling it again is a no-op """
        if self.processors is None:
            return

        self.ready = False
        self.shutdown_event.set()

        # Wakes every reader/writer blocked inside a buffer
        for buffer in self.data_buff:
            buffer.shutdown()
        for buffer in self.mask_buff:
            buffer.shutdown()

        for p in self.processors:
            p.join(timeout=5.0)

        for p in self.processors:
            if p.is_alive():
                print(f'[WARNING] DataLoader.stop: thread {p.name} did not terminate gracefully')

        for buffer in self.data_buff:
            buffer.buff = None
        for buffer in self.mask_buff:
            buffer.buff = None

        self.pairs = None
        self.mask = None
        self.data_buff = None
        self.mask_buff = None
        self.processors = None

        gc.collect()
        print(f'[INFO] DataLoader.stop: loader stopped, average data repetition is {self.epoch}')

    def __iter__(self):
        """ The loader is its own iterator """
        return self

    def __next__(self):
        """ Assemble the next batch from all sources
        :return: (data, target, mask, indicator); the arrays are views of pre-allocated
                 storage and are overwritten by the following call
        :raises StopIteration: once the loader has been stopped
        """
        # Snapshot the references so a concurrent ``stop`` cannot swap them for None mid-batch
        pairs, mask = self.pairs, self.mask
        data_buff, mask_buff = self.data_buff, self.mask_buff
        if not self.ready or pairs is None or mask is None or data_buff is None or mask_buff is None:
            raise StopIteration('[ERROR] DataLoader: shutting down')

        try:
            begin = 0
            for i, size in enumerate(self.sub_batch):
                end = begin + size
                pairs[:, begin:end] = data_buff[i].read(size)
                mask[:, begin:end] = mask_buff[i].read(size)
                begin = end
        except RuntimeError as e:
            if 'shutting down' in str(e):
                raise StopIteration('[ERROR] DataLoader: shutting down')
            else:
                raise

        return pairs[0], pairs[1], mask[0], np.sum(mask[0], axis=-1)

    def distribute(self, ratio: List[float]):
        """ Calculate sub-batch size of different data sources
        :param ratio: ratio of difference dataset sources, one value for every source except the last
        :raises ValueError: if there is no source, the ratios do not match the sources, or
                            rounding leaves a negative share for the last source
        """
        count = len(self.path)
        if count == 0:
            raise ValueError('[ERROR] DataLoader.distribute: At least one path is required.')

        if count != 1:
            if ratio is None:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {count} paths, but no ratios. The number of ratios is expected to be the number of paths - 1.')
            if len(ratio) != count - 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {len(self.path)} paths, but {len(ratio)} ratios. The number of ratios is expected to be the number of paths - 1.')
            if sum(ratio) > 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: The sum of ratios are expected no bigger than 1, but got {sum(ratio)}.')

        for i in range(count - 1):
            self.sub_batch[i] = round(self.batch_size * ratio[i])

        # The last source receives whatever is left of the batch
        remainder = self.batch_size - sum(self.sub_batch[:count - 1])
        if remainder < 0:
            raise ValueError(f'[ERROR] DataLoader.distribute: The rounded sub-batches exceed the batch size {self.batch_size}.')
        self.sub_batch[count - 1] = remainder

    def __process(self, idx: int, path: str):
        """ Process all dataset files found according to the pattern in a single source, pass after pass """
        paths = [str(p) for p in Path(path).rglob(self.pattern)]
        if not paths:
            # Without this guard the loop below would spin forever without producing anything
            print(f'[ERROR] DataLoader.__process: no file matching "{self.pattern}" found under {path}')
            return

        while self.ready and not self.shutdown_event.is_set():
            self.epoch += 1
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                futures = [executor.submit(self.__process_file, idx, p) for p in paths]

                for future in concurrent.futures.as_completed(futures):
                    if not self.ready or self.shutdown_event.is_set():
                        executor.shutdown(wait=True, cancel_futures=True)
                        break
                    future.result()

    def __process_file(self, idx: int, path: str):
        """ Process a single dataset file: tokenize every non-null ``text`` row and buffer it """
        if not self.ready or self.shutdown_event.is_set():
            return

        try:
            df = pl.read_parquet(path)
            df = df.select('text')
            # pathlib instead of splitting on '/': no IndexError for paths without a directory
            # part, and no nested quotes inside the f-string (a syntax error before Python 3.12)
            file = Path(path)
            print(f'[INFO] DataLoader.__process_file: processing file {file.parent.name}/{file.name}')
            for row in df.iter_rows():
                if not self.ready or self.shutdown_event.is_set():
                    break  # stop tokenizing as soon as the loader shuts down
                if row[0] is not None:
                    encoded = self.tokenizer.encode(row[0])
                    self.__process_token(idx, encoded)
        except Exception as e:
            # Best effort: a broken file is reported and skipped
            print(f'[ERROR] DataLoader.__process_file: failed to process file ({path}): {e}')

    def __process_token(self, idx: int, encoded):
        """ Split one encoded sample into rows of ``context_len`` and buffer data, target and mask
        :param idx: index of the data source
        :param encoded: mapping with ``input_ids`` and ``attention_mask``; assumes both have a
                        length that is a multiple of ``context_len`` - TODO confirm against CFGTokenizer
        """
        if not self.ready or self.shutdown_event.is_set():
            return

        token = np.asarray(encoded['input_ids'], dtype=self.dtype)
        input_seq = token.reshape(-1, self.context_len)

        # Target is the input shifted left by one token, padded at the very end.
        # ``np.concatenate`` is available on every NumPy version (``np.concat`` needs 2.0+).
        target = np.concatenate((token[1:], np.array([self.pad_token], dtype=self.dtype)), axis=-1)
        target_seq = target.reshape(-1, self.context_len)

        masks = np.asarray(encoded['attention_mask'], dtype=self.dtype)
        mask_seq = masks.reshape(-1, self.context_len)

        with self._write_locks[idx]:
            self.data_buff[idx].write(np.stack((input_seq, target_seq), axis=0))
            self.mask_buff[idx].write(np.expand_dims(mask_seq, axis=0))


class CoconutDataLoader:
    def __init__(self, path: List[str], pattern: str, mem: int, batch_size: int, context_len: int, tokenizer: str,
                 pad_token: int, threads: int, soft_thinking_steps: int, ratio: List[float] = None, dtype=np.int32):
        """ Data Loader
        :param path: root directories of the data sources, one entry per source
        :param pattern: file name pattern to filter dataset
        :param mem: memory budget in MB; a third of it is given to each buffer column of every source
        :param batch_size: batch size
        :param context_len: context length of each sample
        :param tokenizer: not used by this class, kept for interface compatibility
        :param pad_token: padding token
        :param threads: number of threads for processing each data source
        :param soft_thinking_steps: sizes the windows that ``__next__`` overwrites near the end of each sample
        :param ratio: share of the batch for every source except the last one, may be omitted for a single source
        :param dtype: dtype for the returned tensors

        :note: each data source has its own buff, and each buff has 2 columns. Column 0 is for data, column 1 is for target.

        :return data: input tokens with a window before the end of each sample set to 0, shaped (batch_size, context_len)
        :return target: target tokens, shaped (batch_size, context_len)
        :return mask: 2D padding mask with additional positions set to 0, shaped (batch_size, context_len)
        :return indicator: per-row mask sum minus ``soft_thinking_steps + 1``, shaped (batch_size,)
        """
        self.path = path
        self.epoch = 0
        self.ready = True
        self.pattern = pattern
        self.batch_size = batch_size
        self.context_len = context_len
        self.pad_token = pad_token
        self.threads = threads
        self.dtype = dtype

        # Validate the batch split before allocating buffers or starting threads
        self.sub_batch = [0 for _ in range(len(path))]
        self.distribute(ratio)

        self.data_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 2, dtype) for _ in range(len(path))]
        self.mask_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 1, dtype) for _ in range(len(path))]

        self.tokenizer = CFGTokenizer(context_length=context_len, pad_token=pad_token)

        self.shutdown_event = threading.Event()
        # One lock per source: a sample's data rows and mask rows must be written as a pair,
        # otherwise concurrent workers can leave the two rings in different orders.
        self._write_locks = [threading.Lock() for _ in range(len(path))]

        # Pre-allocation
        self.pairs = np.empty((2, batch_size, context_len), dtype=dtype)
        self.mask = np.empty((1, batch_size, context_len), dtype=dtype)

        self.soft_thinking_steps = soft_thinking_steps
        self.batch_indices = np.arange(batch_size, dtype=np.int32)[:, None]

        # Multi-thread setup, done last so that workers only ever see a fully built loader
        self.processors = [
            threading.Thread(
                target=self.__process,
                args=(idx, p)
            ) for idx, p in enumerate(path)
        ]

        for p in self.processors:
            p.daemon = True
            p.start()

    def stop(self):
        """ Stop the worker threads and release the buffers, calling it again is a no-op """
        if self.processors is None:
            return

        self.ready = False
        self.shutdown_event.set()

        # Wakes every reader/writer blocked inside a buffer
        for buffer in self.data_buff:
            buffer.shutdown()
        for buffer in self.mask_buff:
            buffer.shutdown()

        for p in self.processors:
            p.join(timeout=5.0)

        for p in self.processors:
            if p.is_alive():
                print(f'[WARNING] DataLoader.stop: thread {p.name} did not terminate gracefully')

        for buffer in self.data_buff:
            buffer.buff = None
        for buffer in self.mask_buff:
            buffer.buff = None

        self.pairs = None
        self.mask = None
        self.data_buff = None
        self.mask_buff = None
        self.processors = None

        gc.collect()
        print(f'[INFO] DataLoader.stop: loader stopped, average data repetition is {self.epoch}')

    def __iter__(self):
        """ The loader is its own iterator """
        return self

    def __next__(self):
        """ Assemble the next batch from all sources and blank the windows near the end of each sample
        :return: (data, target, mask, indicator); the arrays are views of pre-allocated
                 storage and are overwritten by the following call
        :raises StopIteration: once the loader has been stopped
        """
        # Snapshot the references so a concurrent ``stop`` cannot swap them for None mid-batch
        pairs, mask_store = self.pairs, self.mask
        data_buff, mask_buff = self.data_buff, self.mask_buff
        if not self.ready or pairs is None or mask_store is None or data_buff is None or mask_buff is None:
            raise StopIteration('[ERROR] DataLoader: shutting down')

        try:
            begin = 0
            for i, size in enumerate(self.sub_batch):
                end = begin + size
                pairs[:, begin:end] = data_buff[i].read(size)
                mask_store[:, begin:end] = mask_buff[i].read(size)
                begin = end
        except RuntimeError as e:
            if 'shutting down' in str(e):
                raise StopIteration('[ERROR] DataLoader: shutting down')
            else:
                raise

        mask = mask_store[0]
        # Per-row mask sum, used as the column right after the last kept token
        # (assumes the mask is a run of ones followed by zeros - TODO confirm against CFGTokenizer)
        indices = np.sum(mask, axis=-1)
        mask[self.batch_indices, indices[:, None] - 1] = 0
        data = pairs[0]
        target = pairs[1]

        # NOTE(review): rows whose mask sum is smaller than the window yield negative columns,
        # which NumPy resolves from the end of the row - confirm samples are always long enough.
        data = replace_window_values(data, indices, window_size=self.soft_thinking_steps + 1, custom_value=0)
        mask = replace_window_values(mask, indices, window_size=self.soft_thinking_steps, custom_value=0, offset_0=-2, offset_1=2)

        return data, target, mask, indices - self.soft_thinking_steps - 1

    def distribute(self, ratio: List[float]):
        """ Calculate sub-batch size of different data sources
        :param ratio: ratio of difference dataset sources, one value for every source except the last
        :raises ValueError: if there is no source, the ratios do not match the sources, or
                            rounding leaves a negative share for the last source
        """
        count = len(self.path)
        if count == 0:
            raise ValueError('[ERROR] DataLoader.distribute: At least one path is required.')

        if count != 1:
            if ratio is None:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {count} paths, but no ratios. The number of ratios is expected to be the number of paths - 1.')
            if len(ratio) != count - 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {len(self.path)} paths, but {len(ratio)} ratios. The number of ratios is expected to be the number of paths - 1.')
            if sum(ratio) > 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: The sum of ratios are expected no bigger than 1, but got {sum(ratio)}.')

        for i in range(count - 1):
            self.sub_batch[i] = round(self.batch_size * ratio[i])

        # The last source receives whatever is left of the batch
        remainder = self.batch_size - sum(self.sub_batch[:count - 1])
        if remainder < 0:
            raise ValueError(f'[ERROR] DataLoader.distribute: The rounded sub-batches exceed the batch size {self.batch_size}.')
        self.sub_batch[count - 1] = remainder

    def __process(self, idx: int, path: str):
        """ Process all dataset files found according to the pattern in a single source, pass after pass """
        paths = [str(p) for p in Path(path).rglob(self.pattern)]
        if not paths:
            # Without this guard the loop below would spin forever without producing anything
            print(f'[ERROR] DataLoader.__process: no file matching "{self.pattern}" found under {path}')
            return

        while self.ready and not self.shutdown_event.is_set():
            self.epoch += 1
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                futures = [executor.submit(self.__process_file, idx, p) for p in paths]

                for future in concurrent.futures.as_completed(futures):
                    if not self.ready or self.shutdown_event.is_set():
                        executor.shutdown(wait=True, cancel_futures=True)
                        break
                    future.result()

    def __process_file(self, idx: int, path: str):
        """ Process a single dataset file: tokenize every non-null ``text`` row and buffer it """
        if not self.ready or self.shutdown_event.is_set():
            return

        try:
            df = pl.read_parquet(path)
            df = df.select('text')
            # pathlib instead of splitting on '/': no IndexError for paths without a directory
            # part, and no nested quotes inside the f-string (a syntax error before Python 3.12)
            file = Path(path)
            print(f'[INFO] DataLoader.__process_file: processing file {file.parent.name}/{file.name}')
            for row in df.iter_rows():
                if not self.ready or self.shutdown_event.is_set():
                    break  # stop tokenizing as soon as the loader shuts down
                if row[0] is not None:
                    encoded = self.tokenizer.encode(row[0])
                    self.__process_token(idx, encoded)
        except Exception as e:
            # Best effort: a broken file is reported and skipped
            print(f'[ERROR] DataLoader.__process_file: failed to process file ({path}): {e}')

    def __process_token(self, idx: int, encoded):
        """ Split one encoded sample into rows of ``context_len`` and buffer data, target and mask
        :param idx: index of the data source
        :param encoded: mapping with ``input_ids`` and ``attention_mask``; assumes both have a
                        length that is a multiple of ``context_len`` - TODO confirm against CFGTokenizer
        """
        if not self.ready or self.shutdown_event.is_set():
            return

        token = np.asarray(encoded['input_ids'], dtype=self.dtype)
        input_seq = token.reshape(-1, self.context_len)

        # Target is the input shifted left by one token, padded at the very end.
        # ``np.concatenate`` is available on every NumPy version (``np.concat`` needs 2.0+).
        target = np.concatenate((token[1:], np.array([self.pad_token], dtype=self.dtype)), axis=-1)
        target_seq = target.reshape(-1, self.context_len)

        masks = np.asarray(encoded['attention_mask'], dtype=self.dtype)
        mask_seq = masks.reshape(-1, self.context_len)

        with self._write_locks[idx]:
            self.data_buff[idx].write(np.stack((input_seq, target_seq), axis=0))
            self.mask_buff[idx].write(np.expand_dims(mask_seq, axis=0))


class BenchmarkDataLoader:
    def __init__(self, path: List[str], pattern: str, mem: int, batch_size: int, context_len: int,
                 pad_token: int, threads: int, inference_steps: int, ratio: List[float] = None, dtype=np.int32):
        """ Data Loader
        :param path: root directories of the data sources, one entry per source
        :param pattern: file name pattern to filter dataset
        :param mem: memory budget in MB; a third of it is given to each buffer column of every source
        :param batch_size: batch size
        :param context_len: context length of each sample
        :param pad_token: padding token
        :param threads: number of threads for processing each data source
        :param inference_steps: sizes the windows that ``__next__`` overwrites near the end of each sample
        :param ratio: share of the batch for every source except the last one, may be omitted for a single source
        :param dtype: dtype for the returned tensors

        :note: each data source has its own buff, and each buff has 2 columns. Column 0 is for data, column 1 is for target.
        :note: unlike the training loaders, the target rows are identical to the input rows (no shift).

        :return data: input tokens with a window before the end of each sample set to 0, shaped (batch_size, context_len)
        :return target: target tokens, shaped (batch_size, context_len)
        :return mask: 2D padding mask with additional positions set to 0, shaped (batch_size, context_len)
        :return indicator: per-row mask sum minus ``inference_steps``, shaped (batch_size,)
        """
        self.path = path
        self.epoch = 0
        self.ready = True
        self.pattern = pattern
        self.batch_size = batch_size
        self.context_len = context_len
        self.pad_token = pad_token
        self.threads = threads
        self.dtype = dtype

        # Validate the batch split before allocating buffers or starting threads
        self.sub_batch = [0 for _ in range(len(path))]
        self.distribute(ratio)

        self.data_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 2, dtype) for _ in range(len(path))]
        self.mask_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 1, dtype) for _ in range(len(path))]

        self.tokenizer = CFGTokenizer(context_length=context_len, pad_token=pad_token)

        self.shutdown_event = threading.Event()
        # One lock per source: a sample's data rows and mask rows must be written as a pair,
        # otherwise concurrent workers can leave the two rings in different orders.
        self._write_locks = [threading.Lock() for _ in range(len(path))]

        # Pre-allocation
        self.pairs = np.empty((2, batch_size, context_len), dtype=dtype)
        self.mask = np.empty((1, batch_size, context_len), dtype=dtype)

        self.inference_steps = inference_steps
        self.batch_indices = np.arange(batch_size, dtype=np.int32)[:, None]

        # Multi-thread setup, done last so that workers only ever see a fully built loader
        self.processors = [
            threading.Thread(
                target=self.__process,
                args=(idx, p)
            ) for idx, p in enumerate(path)
        ]

        for p in self.processors:
            p.daemon = True
            p.start()

    def set_inference_steps(self, inference_steps: int):
        """ Change the window size used by the following ``__next__`` calls """
        self.inference_steps = inference_steps

    def stop(self):
        """ Stop the worker threads and release the buffers, calling it again is a no-op """
        if self.processors is None:
            return

        self.ready = False
        self.shutdown_event.set()

        # Wakes every reader/writer blocked inside a buffer
        for buffer in self.data_buff:
            buffer.shutdown()
        for buffer in self.mask_buff:
            buffer.shutdown()

        for p in self.processors:
            p.join(timeout=5.0)

        for p in self.processors:
            if p.is_alive():
                print(f'[WARNING] DataLoader.stop: thread {p.name} did not terminate gracefully')

        for buffer in self.data_buff:
            buffer.buff = None
        for buffer in self.mask_buff:
            buffer.buff = None

        self.pairs = None
        self.mask = None
        self.data_buff = None
        self.mask_buff = None
        self.processors = None

        gc.collect()
        print(f'[INFO] DataLoader.stop: loader stopped, average data repetition is {self.epoch}')

    def __iter__(self):
        """ The loader is its own iterator """
        return self

    def __next__(self):
        """ Assemble the next batch from all sources and blank the windows near the end of each sample
        :return: (data, target, mask, indicator); the arrays are views of pre-allocated
                 storage and are overwritten by the following call
        :raises StopIteration: once the loader has been stopped
        """
        # Snapshot the references so a concurrent ``stop`` cannot swap them for None mid-batch
        pairs, mask_store = self.pairs, self.mask
        data_buff, mask_buff = self.data_buff, self.mask_buff
        if not self.ready or pairs is None or mask_store is None or data_buff is None or mask_buff is None:
            raise StopIteration('[ERROR] DataLoader: shutting down')

        try:
            begin = 0
            for i, size in enumerate(self.sub_batch):
                end = begin + size
                pairs[:, begin:end] = data_buff[i].read(size)
                mask_store[:, begin:end] = mask_buff[i].read(size)
                begin = end
        except RuntimeError as e:
            if 'shutting down' in str(e):
                raise StopIteration('[ERROR] DataLoader: shutting down')
            else:
                raise

        mask = mask_store[0]
        # Per-row mask sum, used as the column right after the last kept token
        # (assumes the mask is a run of ones followed by zeros - TODO confirm against CFGTokenizer)
        indices = np.sum(mask, axis=-1)
        mask[self.batch_indices, indices[:, None] - 1] = 0
        data = pairs[0]
        target = pairs[1]

        # NOTE(review): rows whose mask sum is smaller than the window yield negative columns,
        # which NumPy resolves from the end of the row - confirm samples are always long enough.
        data = replace_window_values(data, indices, window_size=self.inference_steps, custom_value=0)
        mask = replace_window_values(mask, indices, window_size=self.inference_steps, custom_value=0, offset_0=-2, offset_1=2)

        return data, target, mask, indices - self.inference_steps

    def distribute(self, ratio: List[float]):
        """ Calculate sub-batch size of different data sources
        :param ratio: ratio of difference dataset sources, one value for every source except the last
        :raises ValueError: if there is no source, the ratios do not match the sources, or
                            rounding leaves a negative share for the last source
        """
        count = len(self.path)
        if count == 0:
            raise ValueError('[ERROR] DataLoader.distribute: At least one path is required.')

        if count != 1:
            if ratio is None:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {count} paths, but no ratios. The number of ratios is expected to be the number of paths - 1.')
            if len(ratio) != count - 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {len(self.path)} paths, but {len(ratio)} ratios. The number of ratios is expected to be the number of paths - 1.')
            if sum(ratio) > 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: The sum of ratios are expected no bigger than 1, but got {sum(ratio)}.')

        for i in range(count - 1):
            self.sub_batch[i] = round(self.batch_size * ratio[i])

        # The last source receives whatever is left of the batch
        remainder = self.batch_size - sum(self.sub_batch[:count - 1])
        if remainder < 0:
            raise ValueError(f'[ERROR] DataLoader.distribute: The rounded sub-batches exceed the batch size {self.batch_size}.')
        self.sub_batch[count - 1] = remainder

    def __process(self, idx: int, path: str):
        """ Process all dataset files found according to the pattern in a single source, pass after pass """
        paths = [str(p) for p in Path(path).rglob(self.pattern)]
        if not paths:
            # Without this guard the loop below would spin forever without producing anything
            print(f'[ERROR] DataLoader.__process: no file matching "{self.pattern}" found under {path}')
            return

        while self.ready and not self.shutdown_event.is_set():
            self.epoch += 1
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                futures = [executor.submit(self.__process_file, idx, p) for p in paths]

                for future in concurrent.futures.as_completed(futures):
                    if not self.ready or self.shutdown_event.is_set():
                        executor.shutdown(wait=True, cancel_futures=True)
                        break
                    future.result()

    def __process_file(self, idx: int, path: str):
        """ Process a single dataset file: tokenize every non-null ``text`` row and buffer it """
        if not self.ready or self.shutdown_event.is_set():
            return

        try:
            df = pl.read_parquet(path)
            df = df.select('text')
            for row in df.iter_rows():
                if not self.ready or self.shutdown_event.is_set():
                    break  # stop tokenizing as soon as the loader shuts down
                if row[0] is not None:
                    encoded = self.tokenizer.encode(row[0])
                    self.__process_token(idx, encoded)
        except Exception as e:
            # Best effort: a broken file is reported and skipped
            print(f'[ERROR] DataLoader.__process_file: failed to process file ({path}): {e}')

    def __process_token(self, idx: int, encoded):
        """ Split one encoded sample into rows of ``context_len`` and buffer data, target and mask
        :param idx: index of the data source
        :param encoded: mapping with ``input_ids`` and ``attention_mask``; assumes both have a
                        length that is a multiple of ``context_len`` - TODO confirm against CFGTokenizer
        :note: the target rows are the same as the input rows, nothing is shifted here.
        """
        if not self.ready or self.shutdown_event.is_set():
            return

        token = np.asarray(encoded['input_ids'], dtype=self.dtype)
        input_seq = token.reshape(-1, self.context_len)

        target_seq = token.reshape(-1, self.context_len)

        masks = np.asarray(encoded['attention_mask'], dtype=self.dtype)
        mask_seq = masks.reshape(-1, self.context_len)

        with self._write_locks[idx]:
            self.data_buff[idx].write(np.stack((input_seq, target_seq), axis=0))
            self.mask_buff[idx].write(np.expand_dims(mask_seq, axis=0))


class CoconutDataLoaderMultiAr:
    def __init__(self, path: List[str], pattern: str, mem: int, batch_size: int, context_len: int,
                 pad_token: int, threads: int, soft_thinking_steps: int, token_generation_step: int = 1, ratio: List[float] = None, dtype=np.int32):
        """ Data Loader
        :param path: root directories of the data sources, one entry per source
        :param pattern: file name pattern to filter dataset
        :param mem: memory budget in MB; a third of it is given to each buffer column of every source
        :param batch_size: batch size
        :param context_len: context length of each sample
        :param pad_token: padding token
        :param threads: number of threads for processing each data source
        :param soft_thinking_steps: sizes the windows that ``__next__`` overwrites near the end of each sample
        :param token_generation_step: added to ``soft_thinking_steps`` for the window applied to the input tokens
        :param ratio: share of the batch for every source except the last one, may be omitted for a single source
        :param dtype: dtype for the returned tensors

        :note: each data source has its own buff, and each buff has 2 columns. Column 0 is for data, column 1 is for target.

        :return data: input tokens with a window before the end of each sample set to 0, shaped (batch_size, context_len)
        :return target: target tokens, shaped (batch_size, context_len)
        :return mask: 2D padding mask with a window of positions set to -2, shaped (batch_size, context_len)
        :return indicator: per-row mask sum minus ``soft_thinking_steps + 1``, shaped (batch_size,)
        """
        self.path = path
        self.epoch = 0
        self.ready = True
        self.pattern = pattern
        self.batch_size = batch_size
        self.context_len = context_len
        self.pad_token = pad_token
        self.threads = threads
        self.dtype = dtype

        # Validate the batch split before allocating buffers or starting threads
        self.sub_batch = [0 for _ in range(len(path))]
        self.distribute(ratio)

        self.data_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 2, dtype) for _ in range(len(path))]
        self.mask_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 1, dtype) for _ in range(len(path))]

        self.tokenizer = CFGTokenizer(context_length=context_len, pad_token=pad_token)

        self.shutdown_event = threading.Event()
        # One lock per source: a sample's data rows and mask rows must be written as a pair,
        # otherwise concurrent workers can leave the two rings in different orders.
        self._write_locks = [threading.Lock() for _ in range(len(path))]

        # Pre-allocation
        self.pairs = np.empty((2, batch_size, context_len), dtype=dtype)
        self.mask = np.empty((1, batch_size, context_len), dtype=dtype)

        self.soft_thinking_steps = soft_thinking_steps
        self.token_generation_step = token_generation_step
        self.batch_indices = np.arange(batch_size, dtype=np.int32)[:, None]

        # Multi-thread setup, done last so that workers only ever see a fully built loader
        self.processors = [
            threading.Thread(
                target=self.__process,
                args=(idx, p)
            ) for idx, p in enumerate(path)
        ]

        for p in self.processors:
            p.daemon = True
            p.start()

    def stop(self):
        """ Stop the worker threads and release the buffers, calling it again is a no-op """
        if self.processors is None:
            return

        self.ready = False
        self.shutdown_event.set()

        # Wakes every reader/writer blocked inside a buffer
        for buffer in self.data_buff:
            buffer.shutdown()
        for buffer in self.mask_buff:
            buffer.shutdown()

        for p in self.processors:
            p.join(timeout=5.0)

        for p in self.processors:
            if p.is_alive():
                print(f'[WARNING] DataLoader.stop: thread {p.name} did not terminate gracefully')

        for buffer in self.data_buff:
            buffer.buff = None
        for buffer in self.mask_buff:
            buffer.buff = None

        self.pairs = None
        self.mask = None
        self.data_buff = None
        self.mask_buff = None
        self.processors = None

        gc.collect()
        print(f'[INFO] DataLoader.stop: loader stopped, average data repetition is {self.epoch}')

    def __iter__(self):
        """ The loader is its own iterator """
        return self

    def __next__(self):
        """ Assemble the next batch from all sources and mark the windows near the end of each sample
        :return: (data, target, mask, indicator); the arrays are views of pre-allocated
                 storage and are overwritten by the following call
        :raises StopIteration: once the loader has been stopped
        """
        # Snapshot the references so a concurrent ``stop`` cannot swap them for None mid-batch
        pairs, mask_store = self.pairs, self.mask
        data_buff, mask_buff = self.data_buff, self.mask_buff
        if not self.ready or pairs is None or mask_store is None or data_buff is None or mask_buff is None:
            raise StopIteration('[ERROR] DataLoader: shutting down')

        try:
            begin = 0
            for i, size in enumerate(self.sub_batch):
                end = begin + size
                pairs[:, begin:end] = data_buff[i].read(size)
                mask_store[:, begin:end] = mask_buff[i].read(size)
                begin = end
        except RuntimeError as e:
            if 'shutting down' in str(e):
                raise StopIteration('[ERROR] DataLoader: shutting down')
            else:
                raise

        mask = mask_store[0]
        # Per-row mask sum, used as the column right after the last kept token
        # (assumes the mask is a run of ones followed by zeros - TODO confirm against CFGTokenizer)
        indices = np.sum(mask, axis=-1)
        mask[self.batch_indices, indices[:, None] - 1] = 0
        data = pairs[0]
        target = pairs[1]

        # NOTE(review): rows whose mask sum is smaller than the window yield negative columns,
        # which NumPy resolves from the end of the row - confirm samples are always long enough.
        data = replace_window_values(data, indices, window_size=self.soft_thinking_steps + self.token_generation_step, custom_value=0)
        # The mask window is filled with -2 here rather than 0
        mask = replace_window_values(mask, indices, window_size=self.soft_thinking_steps, custom_value=-2, offset_0=-1, offset_1=-2)

        return data, target, mask, indices - self.soft_thinking_steps - 1

    def distribute(self, ratio: List[float]):
        """ Calculate sub-batch size of different data sources
        :param ratio: ratio of difference dataset sources, one value for every source except the last
        :raises ValueError: if there is no source, the ratios do not match the sources, or
                            rounding leaves a negative share for the last source
        """
        count = len(self.path)
        if count == 0:
            raise ValueError('[ERROR] DataLoader.distribute: At least one path is required.')

        if count != 1:
            if ratio is None:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {count} paths, but no ratios. The number of ratios is expected to be the number of paths - 1.')
            if len(ratio) != count - 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: Got {len(self.path)} paths, but {len(ratio)} ratios. The number of ratios is expected to be the number of paths - 1.')
            if sum(ratio) > 1:
                raise ValueError(f'[ERROR] DataLoader.distribute: The sum of ratios are expected no bigger than 1, but got {sum(ratio)}.')

        for i in range(count - 1):
            self.sub_batch[i] = round(self.batch_size * ratio[i])

        # The last source receives whatever is left of the batch
        remainder = self.batch_size - sum(self.sub_batch[:count - 1])
        if remainder < 0:
            raise ValueError(f'[ERROR] DataLoader.distribute: The rounded sub-batches exceed the batch size {self.batch_size}.')
        self.sub_batch[count - 1] = remainder

    def __process(self, idx: int, path: str):
        """ Process all dataset files found according to the pattern in a single source, pass after pass """
        paths = [str(p) for p in Path(path).rglob(self.pattern)]
        if not paths:
            # Without this guard the loop below would spin forever without producing anything
            print(f'[ERROR] DataLoader.__process: no file matching "{self.pattern}" found under {path}')
            return

        while self.ready and not self.shutdown_event.is_set():
            self.epoch += 1
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                futures = [executor.submit(self.__process_file, idx, p) for p in paths]

                for future in concurrent.futures.as_completed(futures):
                    if not self.ready or self.shutdown_event.is_set():
                        executor.shutdown(wait=True, cancel_futures=True)
                        break
                    future.result()

    def __process_file(self, idx: int, path: str):
        """ Process a single dataset file: tokenize every non-null ``text`` row and buffer it """
        if not self.ready or self.shutdown_event.is_set():
            return

        try:
            df = pl.read_parquet(path)
            df = df.select('text')
            print(f'[WARNING] DataLoader. This version is for coconut 4+4')
            # pathlib instead of splitting on '/': no IndexError for paths without a directory
            # part, and no nested quotes inside the f-string (a syntax error before Python 3.12)
            file = Path(path)
            print(f'[INFO] DataLoader.__process_file: processing file {file.parent.name}/{file.name}')
            for row in df.iter_rows():
                if not self.ready or self.shutdown_event.is_set():
                    break  # stop tokenizing as soon as the loader shuts down
                if row[0] is not None:
                    encoded = self.tokenizer.encode(row[0])
                    self.__process_token(idx, encoded)
        except Exception as e:
            # Best effort: a broken file is reported and skipped
            print(f'[ERROR] DataLoader.__process_file: failed to process file ({path}): {e}')

    def __process_token(self, idx: int, encoded):
        """ Split one encoded sample into rows of ``context_len`` and buffer data, target and mask
        :param idx: index of the data source
        :param encoded: mapping with ``input_ids`` and ``attention_mask``; assumes both have a
                        length that is a multiple of ``context_len`` - TODO confirm against CFGTokenizer
        """
        if not self.ready or self.shutdown_event.is_set():
            return

        token = np.asarray(encoded['input_ids'], dtype=self.dtype)
        input_seq = token.reshape(-1, self.context_len)

        # Target is the input shifted left by one token, padded at the very end.
        # ``np.concatenate`` is available on every NumPy version (``np.concat`` needs 2.0+).
        target = np.concatenate((token[1:], np.array([self.pad_token], dtype=self.dtype)), axis=-1)
        target_seq = target.reshape(-1, self.context_len)

        masks = np.asarray(encoded['attention_mask'], dtype=self.dtype)
        mask_seq = masks.reshape(-1, self.context_len)

        with self._write_locks[idx]:
            self.data_buff[idx].write(np.stack((input_seq, target_seq), axis=0))
            self.mask_buff[idx].write(np.expand_dims(mask_seq, axis=0))