from pathlib import Path
from typing import List
from ctypes import c_uint64
import threading
import concurrent.futures
from tokenizers import Tokenizer
import numpy as np
import polars as pl
import time
import gc


class RingBuffer:
    """ Thread-safe, fixed-capacity ring buffer of rows backed by one NumPy array

    The storage has shape ``(column, height, width)``. All columns share a single
    read index and a single write index, so row ``i`` of every column always belongs
    to the same logical record. Readers block until enough rows are available and
    writers block until space is available; ``shutdown`` wakes up both sides.
    """

    def __init__(self, mem: int, width: int, column: int, dtype=np.int32):
        """ Ring Buffer
        :param mem: maximum allocated memory for each column in bytes
        :param width: width of buffer, height is calculated using ``mem``
        :param column: number of columns in the buffer, the read and write indicator of each column are kept the same
        :param dtype: dtype for the returned tensors
        :raises ValueError: if ``mem`` is too small to hold a single row
        """
        # Monotonic row counters; they are only read or replaced while ``self.lock`` is held.
        self.read_idx = c_uint64(0)
        self.write_idx = c_uint64(0)     # Max out at 18,446,744,073,709,551,615 (2**64 - 1)
        self.pause = False               # Pause writing

        self.lock = threading.RLock()  # Reentrant lock for thread safety
        self.data_available = threading.Condition(self.lock)  # Condition variable for reader
        self.space_available = threading.Condition(self.lock)  # Condition variable for writer
        self.shutdown_flag = False
        # Serialises writers so the pieces of one ``write`` call stay contiguous in the buffer.
        self._write_lock = threading.Lock()

        self.mem = mem
        self.width = width
        self.column = column
        # ``itemsize`` works for every NumPy dtype, not only the integer ones.
        self.height = mem // (width * np.dtype(dtype).itemsize)
        if self.height < 1:
            raise ValueError(f'RingBuffer: {mem} bytes cannot hold a single row of width {width}.')
        self.dtype = dtype
        self.buff = np.zeros((column, self.height, width), dtype=dtype)

    def read(self, length: int):
        """ Gate for buffer reader
        :param length: number of rows to read
        :return: a new array shaped ``(column, length, width)``; it is a copy, so it stays valid
                 after the writer reuses the rows
        :raises ValueError: if ``length`` is negative or exceeds the buffer capacity
        :raises RuntimeError: if the buffer is shutting down
        :note: blocks until ``length`` rows are available or the buffer is shut down
        """
        if length < 0:
            raise ValueError(f'read: Requested length {length} is negative.')
        if length > self.height:
            raise ValueError(f'read: Requested length {length} exceeds buffer capacity {self.height}.')

        with self.data_available:
            while not self.shutdown_flag and self.read_idx.value + length > self.write_idx.value:
                self.data_available.wait()

            if self.shutdown_flag:
                raise RuntimeError('Buffer is shutting down')

            result = self.__read_block(length)
            self.read_idx = c_uint64(self.read_idx.value + length)
            # Wake every waiter: a single notify could pick one that still cannot proceed.
            self.space_available.notify_all()

            return result

    def write(self, data: np.ndarray):
        """ Gate for buffer writer
        :param data: data to write, shaped ``(column, rows, width)``
        :raises ValueError: if the number of columns does not match the buffer
        :note: blocks while the buffer is full. Rows are written as soon as space is free, so
               ``data`` may hold more rows than the buffer capacity. Rows that have already
               been read are overwritten. Returns without writing the remaining rows if the
               buffer is shut down.
        """
        if data.shape[0] != self.column:
            raise ValueError(f'write: Column count not match. Expect {self.column} columns but got {data.shape[0]} instead.')

        total = data.shape[1]
        with self._write_lock:
            written = 0
            while written < total:
                with self.space_available:
                    while not self.shutdown_flag and self.write_idx.value - self.read_idx.value >= self.height:
                        self.space_available.wait()

                    if self.shutdown_flag:
                        return  # Don't write if shutting down

                    free = self.height - (self.write_idx.value - self.read_idx.value)
                    count = min(total - written, free)
                    self.__write_block(data[:, written:written + count])
                    self.write_idx = c_uint64(self.write_idx.value + count)
                    self.data_available.notify_all()
                written += count

    def __read_block(self, length: int):
        """ Read from buffer
        :param length: number of rows to read, at most ``height``
        :return: copy of the next ``length`` rows of every column
        """
        start_idx = self.read_idx.value % self.height
        # Rows available before the physical end of the storage; the rest wraps to the front.
        head = min(length, self.height - start_idx)

        result = np.empty((self.column, length, self.width), dtype=self.dtype)
        result[:, :head] = self.buff[:, start_idx:start_idx + head]
        if head < length:
            result[:, head:] = self.buff[:, :length - head]
        return result

    def __write_block(self, data: np.ndarray):
        """ Write to buffer
        :param data: data to write, holding at most ``height`` rows
        :note: the current data between ``write_idx`` and ``write_idx + len(data)``
               will be overwritten if it has already been read.
        """
        rows = data.shape[1]
        start_idx = self.write_idx.value % self.height
        # Rows that fit before the physical end of the storage; the rest wraps to the front.
        head = min(rows, self.height - start_idx)

        self.buff[:, start_idx:start_idx + head] = data[:, :head]
        if head < rows:
            self.buff[:, :rows - head] = data[:, head:]

    def writable(self):
        """ Get current writing state """
        return not self.pause

    def shutdown(self):
        """ Shutdown the buffer and wake up every blocked reader and writer """
        with self.lock:
            self.shutdown_flag = True
            self.data_available.notify_all()
            self.space_available.notify_all()


class DataLoader:
    """ Multi-source batch loader that tokenizes parquet files in background threads

    One processor thread per data source walks the matching files over and over, a thread
    pool tokenizes them, and the resulting rows are pushed into per-source ring buffers.
    ``next(loader)`` assembles one batch from the buffers according to ``sub_batch``.
    """

    def __init__(self, path: List[str], pattern: str, mem: int, batch_size: int, context_len: int, tokenizer: str,
                 pad_token: int, threads: int, ratio: List[float] = None, dtype=np.int32):
        """ Data Loader
        :param path: path to the dataset, one entry per data source
        :param pattern: file name pattern to filter dataset
        :param mem: memory budget of each data source in MB, shared evenly by the two data columns and the mask column
        :param batch_size: batch size
        :param context_len: context length of each sample
        :param tokenizer: path of the tokenizer file without the ``.json`` suffix
        :param pad_token: padding token
        :param threads: number of threads for processing each data source
        :param ratio: ratio of difference dataset sources
        :param dtype: dtype for the returned tensors

        :note: each data source has its own buff, and each buff has 2 columns. Column 0 is for data, column 1 is for target.
        :note: the arrays returned by ``next`` are views of pre-allocated storage and are overwritten by the following call.

        :return data: input tokens, shaped (batch_size, context_len)
        :return target: target tokens, shaped (batch_size, context_len)
        :return mask: 2D padding mask, shaped (batch_size, context_len)
        :return indicator: sum of the padding mask of each sample, shaped (batch_size,)
        """
        self.path = path
        self.epoch = 0
        self.ready = True
        self.pattern = pattern
        self.batch_size = batch_size
        self.context_len = context_len
        self.pad_token = pad_token
        self.threads = threads
        self.dtype = dtype
        self.sub_batch = [0 for _ in path]

        # Validate the ratios before any large allocation happens.
        self.distribute(ratio)

        self.data_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 2, dtype) for _ in path]
        self.mask_buff = [RingBuffer(mem * 1024 * 1024 // 3, context_len, 1, dtype) for _ in path]

        self.tokenizer = Tokenizer.from_file(f'{tokenizer}.json')
        self.tokenizer.enable_padding(
            direction='right',
            pad_id=pad_token,
            pad_type_id=pad_token,
            pad_token='{PAD}',
            pad_to_multiple_of=context_len
        )

        self.shutdown_event = threading.Event()
        # One lock per source keeps the data rows and the mask rows of a sample aligned
        # when several worker threads feed the same pair of buffers.
        self._write_locks = [threading.Lock() for _ in path]
        self._epoch_lock = threading.Lock()

        # Pre-allocation
        self.pairs = np.empty((2, batch_size, context_len), dtype=dtype)
        self.mask = np.empty((1, batch_size, context_len), dtype=dtype)

        # Multi-thread setup, started last so the workers only ever see a fully built loader
        self.processors = [
            threading.Thread(
                target=self.__process,
                args=(idx, p),
                daemon=True
            ) for idx, p in enumerate(path)
        ]

        for p in self.processors:
            p.start()

    def _stopping(self):
        """ Tell whether the loader has been asked to stop """
        return not self.ready or self.shutdown_event.is_set()

    def stop(self):
        """ Stop the background threads and release the buffers, calling it again is a no-op """
        if self.data_buff is None:
            return

        self.ready = False
        self.shutdown_event.set()

        # Wake up every thread that is blocked inside a buffer.
        for buffer in self.data_buff:
            buffer.shutdown()
        for buffer in self.mask_buff:
            buffer.shutdown()

        for p in self.processors:
            p.join(timeout=5.0)

        for p in self.processors:
            if p.is_alive():
                print(f'[WARNING] DataLoader.stop: thread {p.name} did not terminate gracefully')

        for buffer in self.data_buff:
            buffer.buff = None
        for buffer in self.mask_buff:
            buffer.buff = None

        self.pairs = None
        self.mask = None
        self.data_buff = None
        self.mask_buff = None
        self.processors = None

        gc.collect()
        # ``epoch`` counts the passes started by all sources together, hence the division.
        average = self.epoch / max(1, len(self.path))
        print(f'[INFO] DataLoader.stop: loader stopped, average data repetition is {average:.2f}')

    def __iter__(self):
        """ The loader is its own iterator """
        return self

    def __next__(self):
        """ Assemble the next batch
        :return: ``(data, target, mask, indicator)`` as described in the constructor
        :raises StopIteration: if the loader is stopped or shutting down
        """
        if self.data_buff is None:
            raise StopIteration('[ERROR] DataLoader: shutting down')

        try:
            begin = 0
            for i, rows in enumerate(self.sub_batch):
                end = begin + rows
                self.pairs[:, begin:end] = self.data_buff[i].read(rows)
                self.mask[:, begin:end] = self.mask_buff[i].read(rows)
                begin = end
        except RuntimeError as e:
            if 'shutting down' in str(e):
                raise StopIteration('[ERROR] DataLoader: shutting down') from None
            else:
                raise

        return self.pairs[0], self.pairs[1], self.mask[0], np.sum(self.mask[0], axis=-1)

    def distribute(self, ratio: List[float]):
        """ Calculate sub-batch size of different data sources
        :param ratio: ratio of difference dataset sources, one value for every source but the last;
                      may be ``None`` when there is a single source
        :raises ValueError: if the number of ratios is wrong, a ratio is negative or the ratios sum to more than 1
        :note: the last source receives whatever is left of the batch
        """
        sources = len(self.path)
        ratio = [] if ratio is None else list(ratio)

        if sources != 1 and len(ratio) != sources - 1:
            raise ValueError(f'[ERROR] DataLoader.distribute: Got {len(self.path)} paths, but {len(ratio)} ratios. The number of ratios is expected to be the number of paths - 1.')

        if sources != 1 and sum(ratio) > 1:
            raise ValueError(f'[ERROR] DataLoader.distribute: The sum of ratios are expected no bigger than 1, but got {sum(ratio)}.')

        if sources != 1 and any(r < 0 for r in ratio):
            raise ValueError(f'[ERROR] DataLoader.distribute: Ratios are expected to be non-negative, but got {ratio}.')

        sizes = []
        remaining = self.batch_size
        for i in range(sources - 1):
            # Rounding up several shares may exceed the batch, never hand out more than what is left.
            share = min(round(self.batch_size * ratio[i]), remaining)
            sizes.append(share)
            remaining -= share
        sizes.append(remaining)
        self.sub_batch = sizes

    def __process(self, idx: int, path: str):
        """ Process all dataset files found according to the pattern in a single source
        :param idx: index of the data source
        :param path: root directory of the data source
        """
        paths = list(Path(path).rglob(self.pattern))
        if not paths:
            # Without this guard the loop below would spin forever without producing data.
            print(f'[WARNING] DataLoader.__process: no file matching {self.pattern} found in {path}')
            return

        while not self._stopping():
            with self._epoch_lock:
                self.epoch += 1
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                futures = [executor.submit(self.__process_file, idx, str(p)) for p in paths]

                for future in concurrent.futures.as_completed(futures):
                    if self._stopping():
                        executor.shutdown(wait=True, cancel_futures=True)
                        break
                    future.result()

    def __process_file(self, idx: int, path: str):
        """ Process a single dataset file
        :param idx: index of the data source
        :param path: path of a parquet file with a ``text`` column
        :note: failures are reported and the file is skipped
        """
        if self._stopping():
            return

        try:
            df = pl.read_parquet(path, columns=['text'])
            file = Path(path)
            print(f'[INFO] DataLoader.__process_file: processing file {file.parent.name}/{file.name}')
            for row in df.iter_rows():
                if self._stopping():
                    break  # do not keep tokenizing the rest of the file during shutdown
                text = row[0]
                if text is None:
                    continue  # a null row must not abort the whole file
                encoded = self.tokenizer.encode(text + '{EOS}')
                self.__process_token(idx, encoded)
            del df
        except Exception as e:
            print(f'[ERROR] DataLoader.__process_file: failed to process file ({path}): {e}')

    def __process_token(self, idx: int, encoded):
        """ Turn one encoded text into input, target and mask rows and push them to the buffers
        :param idx: index of the data source
        :param encoded: encoding whose ``ids`` and ``attention_mask`` lengths are a multiple of ``context_len``
        :note: the target is the input shifted left by one token and completed with ``pad_token``
        """
        if self._stopping():
            return

        token = np.asarray(encoded.ids, dtype=self.dtype)
        if token.size == 0:
            return
        input_seq = token.reshape(-1, self.context_len)

        target = np.concatenate((token[1:], np.array([self.pad_token], dtype=self.dtype)), axis=-1)
        target_seq = target.reshape(-1, self.context_len)

        masks = np.asarray(encoded.attention_mask, dtype=self.dtype)
        mask_seq = masks.reshape(-1, self.context_len)

        self.__write_rows(idx, np.stack((input_seq, target_seq), axis=0), np.expand_dims(mask_seq, axis=0))

    def __write_rows(self, idx: int, data: np.ndarray, mask: np.ndarray):
        """ Push matching data and mask rows to the buffers of one source
        :param idx: index of the data source
        :param data: rows for the data buffer, shaped (columns, rows, context_len)
        :param mask: rows for the mask buffer, shaped (1, rows, context_len)
        """
        data_buff = self.data_buff[idx]
        mask_buff = self.mask_buff[idx]

        # A chunk stays below the buffer capacity and leaves room for one sub-batch, so a
        # writer waiting for space and the reader waiting for rows cannot block each other.
        capacity = min(data_buff.height, mask_buff.height)
        chunk = max(1, min(capacity - 1, capacity - self.sub_batch[idx] + 1))

        # Both buffers are fed under one lock so that their rows stay in the same order.
        with self._write_locks[idx]:
            for begin in range(0, data.shape[1], chunk):
                if self._stopping():
                    return
                data_buff.write(data[:, begin:begin + chunk])
                mask_buff.write(mask[:, begin:begin + chunk])


class MultiShiftDataLoader:
    """ Multi-source batch loader producing several shifted targets for every input

    One processor thread per data source walks the matching files over and over, a thread
    pool tokenizes them, and the resulting rows are pushed into per-source ring buffers.
    ``next(loader)`` assembles one batch from the buffers according to ``sub_batch``.
    """

    def __init__(self, path: List[str], pattern: str, mem: int, batch_size: int, context_len: int, tokenizer: str,
                 pad_token: int, threads: int, shift: int = 1, ratio: List[float] = None, dtype=np.int32):
        """ Multi-Shifting Data Loader
        :param path: path to the dataset, one entry per data source
        :param pattern: file name pattern to filter dataset
        :param mem: memory budget of each data source in MB, shared evenly by the ``shift + 1`` data columns and the mask column
        :param batch_size: batch size
        :param context_len: context length of each sample
        :param tokenizer: path of the tokenizer file without the ``.json`` suffix
        :param pad_token: padding token
        :param threads: number of threads for processing each data source
        :param shift: how many times shifting should be applied to generate targets
        :param ratio: ratio of difference dataset sources
        :param dtype: dtype for the returned tensors
        :raises ValueError: if ``shift`` is negative or not smaller than ``context_len``

        :note: each data source has its own buff, and each buff has ``shift + 1`` columns. Column 0 is for data,
               column ``i`` is for the target shifted by ``i`` tokens.
        :note: the arrays returned by ``next`` are views of pre-allocated storage and are overwritten by the following call.

        :return data: input tokens, shaped (batch_size, context_len)
        :return target: target tokens, shaped (shift, batch_size, context_len)
        :return mask: 2D padding mask, shaped (batch_size, context_len)
        :return indicator: sum of the padding mask of each sample, shaped (batch_size,)
        """
        # Validate the cheap arguments before any large allocation happens.
        if shift < 0:
            raise ValueError(f'[ERROR] MultiShiftDataLoader: Shift {shift} is negative')
        if shift >= context_len:
            raise ValueError(f'[ERROR] MultiShiftDataLoader: Shift {shift} exceeds context length')

        self.path = path
        self.epoch = 0
        self.ready = True
        self.pattern = pattern
        self.batch_size = batch_size
        self.shift = shift
        self.context_len = context_len
        self.pad_token = pad_token
        self.threads = threads
        self.dtype = dtype
        self.sub_batch = [0 for _ in path]

        self.distribute(ratio)

        self.data_buff = [RingBuffer(mem * 1024 * 1024 // (shift + 2), context_len, shift + 1, dtype) for _ in path]
        self.mask_buff = [RingBuffer(mem * 1024 * 1024 // (shift + 2), context_len, 1, dtype) for _ in path]

        self.tokenizer = Tokenizer.from_file(f'{tokenizer}.json')
        self.tokenizer.enable_padding(
            direction='right',
            pad_id=pad_token,
            pad_type_id=pad_token,
            pad_token='{PAD}',
            pad_to_multiple_of=context_len
        )

        self.shutdown_event = threading.Event()
        # One lock per source keeps the data rows and the mask rows of a sample aligned
        # when several worker threads feed the same pair of buffers.
        self._write_locks = [threading.Lock() for _ in path]
        self._epoch_lock = threading.Lock()

        # Pre-allocation
        self.pairs = np.empty((shift + 1, batch_size, context_len), dtype=dtype)
        self.mask = np.empty((1, batch_size, context_len), dtype=dtype)

        # Multi-thread setup, started last so the workers only ever see a fully built loader
        self.processors = [
            threading.Thread(
                target=self.__process,
                args=(idx, p),
                daemon=True
            ) for idx, p in enumerate(path)
        ]

        for p in self.processors:
            p.start()

    def _stopping(self):
        """ Tell whether the loader has been asked to stop """
        return not self.ready or self.shutdown_event.is_set()

    def stop(self):
        """ Stop the background threads and release the buffers, calling it again is a no-op """
        if self.data_buff is None:
            return

        self.ready = False
        self.shutdown_event.set()

        # Wake up every thread that is blocked inside a buffer.
        for buffer in self.data_buff:
            buffer.shutdown()
        for buffer in self.mask_buff:
            buffer.shutdown()

        for p in self.processors:
            p.join(timeout=5.0)

        for p in self.processors:
            if p.is_alive():
                print(f'[WARNING] MultiShiftDataLoader.stop: thread {p.name} did not terminate gracefully')

        for buffer in self.data_buff:
            buffer.buff = None
        for buffer in self.mask_buff:
            buffer.buff = None

        self.pairs = None
        self.mask = None
        self.data_buff = None
        self.mask_buff = None
        self.processors = None

        gc.collect()
        # ``epoch`` counts the passes started by all sources together, hence the division.
        average = self.epoch / max(1, len(self.path))
        print(f'[INFO] MultiShiftDataLoader.stop: loader stopped, average data repetition is {average:.2f}')

    def __iter__(self):
        """ The loader is its own iterator """
        return self

    def __next__(self):
        """ Assemble the next batch
        :return: ``(data, target, mask, indicator)`` as described in the constructor
        :raises StopIteration: if the loader is stopped or shutting down
        """
        if self.data_buff is None:
            raise StopIteration('[ERROR] MultiShiftDataLoader: shutting down')

        try:
            begin = 0
            for i, rows in enumerate(self.sub_batch):
                end = begin + rows
                self.pairs[:, begin:end] = self.data_buff[i].read(rows)
                self.mask[:, begin:end] = self.mask_buff[i].read(rows)
                begin = end
        except RuntimeError as e:
            if 'shutting down' in str(e):
                raise StopIteration('[ERROR] MultiShiftDataLoader: shutting down') from None
            else:
                raise

        return self.pairs[0], self.pairs[1:], self.mask[0], np.sum(self.mask[0], axis=-1)

    def distribute(self, ratio: List[float]):
        """ Calculate sub-batch size of different data sources
        :param ratio: ratio of difference dataset sources, one value for every source but the last;
                      may be ``None`` when there is a single source
        :raises ValueError: if the number of ratios is wrong, a ratio is negative or the ratios sum to more than 1
        :note: the last source receives whatever is left of the batch
        """
        sources = len(self.path)
        ratio = [] if ratio is None else list(ratio)

        if sources != 1 and len(ratio) != sources - 1:
            raise ValueError(f'[ERROR] MultiShiftDataLoader.distribute: Got {len(self.path)} paths, but {len(ratio)} ratios. The number of ratios is expected to be the number of paths - 1.')

        if sources != 1 and sum(ratio) > 1:
            raise ValueError(f'[ERROR] MultiShiftDataLoader.distribute: The sum of ratios are expected no bigger than 1, but got {sum(ratio)}.')

        if sources != 1 and any(r < 0 for r in ratio):
            raise ValueError(f'[ERROR] MultiShiftDataLoader.distribute: Ratios are expected to be non-negative, but got {ratio}.')

        sizes = []
        remaining = self.batch_size
        for i in range(sources - 1):
            # Rounding up several shares may exceed the batch, never hand out more than what is left.
            share = min(round(self.batch_size * ratio[i]), remaining)
            sizes.append(share)
            remaining -= share
        sizes.append(remaining)
        self.sub_batch = sizes

    def __process(self, idx: int, path: str):
        """ Process all dataset files found according to the pattern in a single source
        :param idx: index of the data source
        :param path: root directory of the data source
        """
        paths = list(Path(path).rglob(self.pattern))
        if not paths:
            # Without this guard the loop below would spin forever without producing data.
            print(f'[WARNING] MultiShiftDataLoader.__process: no file matching {self.pattern} found in {path}')
            return

        while not self._stopping():
            with self._epoch_lock:
                self.epoch += 1
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                futures = [executor.submit(self.__process_file, idx, str(p)) for p in paths]

                for future in concurrent.futures.as_completed(futures):
                    if self._stopping():
                        executor.shutdown(wait=True, cancel_futures=True)
                        break
                    future.result()

    def __process_file(self, idx: int, path: str):
        """ Process a single dataset file
        :param idx: index of the data source
        :param path: path of a parquet file with a ``text`` column
        :note: failures are reported and the file is skipped
        """
        if self._stopping():
            return

        try:
            df = pl.read_parquet(path, columns=['text'])
            file = Path(path)
            print(f'[INFO] MultiShiftDataLoader.__process_file: processing file {file.parent.name}/{file.name}')
            for row in df.iter_rows():
                if self._stopping():
                    break  # do not keep tokenizing the rest of the file during shutdown
                text = row[0]
                if text is None:
                    continue  # a null row must not abort the whole file
                encoded = self.tokenizer.encode(text + '{EOS}')
                self.__process_token(idx, encoded)
            del df
        except Exception as e:
            print(f'[ERROR] MultiShiftDataLoader.__process_file: failed to process file ({path}): {e}')

    def __process_token(self, idx: int, encoded):
        """ Turn one encoded text into input, shifted target and mask rows and push them to the buffers
        :param idx: index of the data source
        :param encoded: encoding whose ``ids`` and ``attention_mask`` lengths are a multiple of ``context_len``
        :note: column ``i`` is the input shifted left by ``i`` tokens and completed with ``pad_token``
        """
        if self._stopping():
            return

        token = np.asarray(encoded.ids, dtype=self.dtype)
        if token.size == 0:
            return

        # Pad once, then every shifted sequence is a window of the padded array.
        padded = np.concatenate((token, np.full(self.shift, self.pad_token, dtype=self.dtype)), axis=-1)
        seq_batch = np.stack(
            [padded[i:i + token.size].reshape(-1, self.context_len) for i in range(self.shift + 1)],
            axis=0
        )

        masks = np.asarray(encoded.attention_mask, dtype=self.dtype)
        mask_seq = masks.reshape(-1, self.context_len)

        self.__write_rows(idx, seq_batch, np.expand_dims(mask_seq, axis=0))

    def __write_rows(self, idx: int, data: np.ndarray, mask: np.ndarray):
        """ Push matching data and mask rows to the buffers of one source
        :param idx: index of the data source
        :param data: rows for the data buffer, shaped (columns, rows, context_len)
        :param mask: rows for the mask buffer, shaped (1, rows, context_len)
        """
        data_buff = self.data_buff[idx]
        mask_buff = self.mask_buff[idx]

        # A chunk stays below the buffer capacity and leaves room for one sub-batch, so a
        # writer waiting for space and the reader waiting for rows cannot block each other.
        capacity = min(data_buff.height, mask_buff.height)
        chunk = max(1, min(capacity - 1, capacity - self.sub_batch[idx] + 1))

        # Both buffers are fed under one lock so that their rows stay in the same order.
        with self._write_locks[idx]:
            for begin in range(0, data.shape[1], chunk):
                if self._stopping():
                    return
                data_buff.write(data[:, begin:begin + chunk])
                mask_buff.write(mask[:, begin:begin + chunk])