import torch
import glob
import os
import io
import dataclasses as dc
import re
from litgpt.tokenizer import Tokenizer
from tokenizers import Tokenizer as HFTokenizer
from litgpt.data import TextFiles as LitTextFiles, DataModule, SFTDataset
from pathlib import Path
from functools import partial

from torch.utils.data import DataLoader


@dc.dataclass
class TextFiles(LitTextFiles):
    """
    Inherit from litgpt.data.TextFiles. This will segment sequences in a text file.

    Input files are read ``read_stride`` characters at a time and only cut at
    separator boundaries before tokenization (see ``_tokenize``). Unfinished
    sequences that grow past ``drop_size`` characters can be abandoned.
    """

    # Sequence separator(s). A str is matched literally; a tuple of strs matches
    # any one of them, each taken literally; None uses the tokenizer's EOS token
    # (Hugging Face tokenizer backend only).
    sep: str | tuple[str, ...] | None = None
    # If set, an unfinished sequence is abandoned once the text carried over
    # between reads reaches this many characters.
    drop_size: int | None = None
    # Number of characters requested per read of an input file.
    read_stride: int = 2048

    def prepare_data(
        self,
        max_workers: int | None = None
    ) -> None:
        """Tokenize the ``.txt`` files into litdata-optimized train/val datasets.

        Training files come from ``train_data_path``. Validation files come from
        ``val_data_path`` when it is set; otherwise the first (sorted) training
        file is held out for validation. A split whose output directory already
        exists is skipped with a warning.

        Args:
            max_workers: upper bound on worker processes per split. When falsy,
                defaults to one less than the CPU count, but at least 1.
        """
        from litdata.streaming.item_loader import TokensLoader

        train_files = sorted(glob.glob(str(self.train_data_path / "*.txt")))
        assert len(train_files) > 0, f"No .txt files found in train data {self.train_data_path}"

        if self.val_data_path is not None:
            self.val_data_path = Path(self.val_data_path)
            val_files = sorted(glob.glob(str(self.val_data_path / "*.txt")))
            assert len(val_files) > 0, f"No .txt files found in validation data {self.val_data_path}"
        else:
            # No validation directory: hold out shard 0 for validation, train on the rest.
            assert len(train_files) > 1, f"Expected at least two .txt files in {train_files}"
            val_files, train_files = train_files[:1], train_files[1:]

        tokenizer = _validate_tokenizer(self.tokenizer)
        sep = self._resolve_sep(tokenizer)

        item_loader = TokensLoader(block_size=self.max_seq_length)
        fn = partial(_tokenize, tokenizer=tokenizer, stride=self.read_stride, sep=sep, drop_size=self.drop_size)

        # It's ok to use almost all CPUs here because this runs in a single process.
        # os.cpu_count() may return None, and on a single-CPU machine cpu_count - 1 is 0.
        max_workers = max_workers or max((os.cpu_count() or 1) - 1, 1)

        self._optimize_split(fn, item_loader, train_files, self.out_path_train, "training", max_workers)
        self._optimize_split(fn, item_loader, val_files, self.out_path_val, "validation", max_workers)

    def _resolve_sep(self, tokenizer: Tokenizer) -> str | re.Pattern:
        """Convert ``self.sep`` into a literal str or a compiled pattern for ``_tokenize``.

        Raises:
            ValueError: if a separator is empty, or the EOS token cannot be resolved.
            NotImplementedError: if ``sep`` is None and the tokenizer backend is
                not Hugging Face.
        """
        if isinstance(self.sep, tuple):
            if not self.sep or not all(self.sep):
                raise ValueError(f"`sep` must contain only non-empty strings, got {self.sep!r}")
            # Escape each separator so it is matched literally (EOS tokens such as
            # "<|endoftext|>" contain regex metacharacters). Try longer separators
            # first so one that is a prefix of another does not shadow it.
            alternatives = '|'.join(re.escape(s) for s in sorted(self.sep, key=len, reverse=True))
            # Exactly one capturing group: `_read` relies on `re.split` returning
            # each matched separator as one extra element.
            return re.compile('(' + alternatives + ')')

        if self.sep is None:
            if tokenizer.backend != 'huggingface':
                raise NotImplementedError(
                    "Deriving `sep` from the EOS token is only supported for the 'huggingface' "
                    f"tokenizer backend (got {tokenizer.backend!r}); pass `sep` explicitly."
                )
            processor = tokenizer.processor
            assert isinstance(processor, HFTokenizer) and tokenizer.eos_id is not None
            eos = processor.id_to_token(tokenizer.eos_id)
            if not eos:
                raise ValueError(
                    f"EOS id {tokenizer.eos_id} has no token in the tokenizer vocabulary; pass `sep` explicitly."
                )
            return eos

        if not self.sep:
            raise ValueError("`sep` must be a non-empty string")
        return self.sep

    @staticmethod
    def _optimize_split(fn, item_loader, files: list[str], out_path, label: str, max_workers: int) -> None:
        """Run litdata ``optimize`` over ``files`` into ``out_path`` unless it already exists.

        ``label`` ("training" / "validation") is only used in the skip warning.
        """
        from litdata import optimize

        if Path(out_path).is_dir():
            print(
                f"\nWarning: Preprocessed {label} data found in {out_path}."
                " For efficiency, reprocessing is skipped. If your text input has changed since"
                " the last `litgpt pretrain` command, remove the preprocessed file(s) to trigger"
                f" reprocessing: `rm -rf {out_path}`\n"
            )
            return
        optimize(
            fn=fn,
            inputs=files,
            output_dir=str(out_path),
            num_workers=min(max_workers, len(files)),
            chunk_bytes="50MB",
            item_loader=item_loader,
        )


def _tokenize(filename: str, tokenizer: Tokenizer, stride: int, sep: str | re.Pattern, drop_size: int | None = None):
    """Yield token ids for ``filename``, cutting the text only at separators.

    The file is read ``stride`` characters at a time through ``_read``. After
    each read, the text up to and including the last separator seen so far is
    stripped, encoded without BOS/EOS, and yielded. The unfinished remainder
    (the "tail") is carried into the next read. Whatever is left at end of
    file is yielded last.

    If ``drop_size`` is set and the carried tail reaches ``drop_size``
    characters, that sequence is abandoned. The tail is discarded, and all text
    up to and including the next separator is skipped, even across several
    reads, so no fragment of the abandoned sequence is emitted.

    NOTE(review): only text carried across a read boundary is measured. A
    sequence that starts and ends within a single read is never dropped, however
    long it is. Confirm this is intended when ``drop_size < stride``.

    Args:
        filename: path of a UTF-8 text file.
        tokenizer: used as ``tokenizer.encode(text, bos=False, eos=False)``.
        stride: number of characters requested per read.
        sep: a literal separator string, or a compiled pattern whose matches are
            separators.
        drop_size: optional length threshold, in characters, for abandoning an
            unfinished sequence.

    Yields:
        The result of ``tokenizer.encode`` for each non-empty piece of text.
    """
    # While skipping an abandoned sequence, keep this many trailing characters
    # between reads, so a literal separator split across two reads is still
    # found. Pattern separators have no known maximum length, so nothing is kept.
    overlap = len(sep) - 1 if isinstance(sep, str) else 0

    with open(filename, "r", encoding="utf-8") as f:
        tail: str | None = ''
        dropping = False  # True while skipping the remainder of an abandoned sequence
        while tail is not None:
            if not dropping and drop_size is not None and len(tail) >= drop_size:
                dropping = True
                tail = tail[max(len(tail) - overlap, 0):]

            # Always pass a str tail, so `_read` never drops text on its own;
            # the dropping state is tracked here instead.
            text, tail = _read(f, stride, sep, tail)

            if dropping:
                if tail is None:
                    break  # end of file reached inside the abandoned sequence
                if not text:
                    # No separator yet: everything read still belongs to the abandoned sequence.
                    tail = tail[max(len(tail) - overlap, 0):]
                    continue
                # Skip up to and including the first separator, which ends the
                # abandoned sequence. `text` is non-empty only when `_read` found
                # a separator, so a match exists.
                if isinstance(sep, str):
                    cut = text.find(sep) + len(sep)
                else:
                    match = re.search(sep, text)
                    assert match is not None
                    cut = match.end()
                text = text[cut:]
                dropping = False

            text = text.strip()
            if text:
                yield tokenizer.encode(text, bos=False, eos=False)


def _read(f: io.TextIOBase, size: int, sep: re.Pattern | str, tail: str | None = ''):

    text = f.read(size)
    if text == '':
        return tail or '', None

    text = (tail + text) if tail else text

    drop = 0 if tail is not None else (1 if isinstance(sep, str) else 2)

    if isinstance(sep, str):
        splits = text.split(sep)
        seqs = splits[drop : -1]
        if seqs:
            text = sep.join(seqs) + sep
        else:
            text = ''
    else:
        splits = re.split(sep, text)
        seqs = splits[drop : -1]
        text = ''.join(seqs)

    tail = splits[-1]
    return text, tail


def _validate_tokenizer(tokenizer: Tokenizer | None) -> Tokenizer:
    """Return ``tokenizer`` unchanged, refusing a missing one.

    Raises:
        ValueError: if ``tokenizer`` is None.
    """
    if tokenizer is not None:
        return tokenizer
    message = (
        "Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
        "please provide a valid `--tokenizer_dir` path."
    )
    raise ValueError(message)
