import argparse
from typing import List

import apache_beam as beam
import pyarrow
from apache_beam.options.pipeline_options import PipelineOptions
from datasets import load_dataset
from transformers import AutoTokenizer

# fmt: off
# Command-line interface. Flags not declared here are returned by
# parse_known_args() and handed to Beam as pipeline options.
parser = argparse.ArgumentParser(
    description="Tokenize a Penn Treebank split and write fixed-length training examples to Parquet "
                "with Apache Beam. Unrecognized flags are forwarded to Beam as pipeline options.",
)
parser.add_argument("--tokenizer-model", type=str, default="gpt2", help="Hugging Face tokenizer name or path passed to AutoTokenizer.from_pretrained.")
parser.add_argument("--min-sequence-length", type=int, default=128, help="Chunks with at most (this - num-special-token-reserved) tokens are dropped.")
parser.add_argument("--max-sequence-length", type=int, default=256, help="Length of every output sequence, including the special tokens.")
parser.add_argument("--num-special-token-reserved", type=int, default=2, help="Positions reserved per sequence for special tokens (BOS/EOS).")
parser.add_argument("--ignore-label", type=int, default=-100, help="Label value written at padding positions.")
parser.add_argument("--stride", type=int, default=128, help="Step, in tokens, between the starts of consecutive chunks.")
# Restrict to the splits main() knows how to locate, so a typo fails at parse time.
parser.add_argument("--dataset-split-type", type=str, default="train", choices=["train", "valid", "test"], help="Penn Treebank split to process; read from data/penn/<split>.txt.")
parser.add_argument("--output-path", type=str, default="data/penn/", help="Output file path prefix passed to beam.io.WriteToParquet.")
# fmt: on

class TokenizingDoFn(beam.DoFn):
    """Beam DoFn mapping one text string to its list of token ids.

    Only the tokenizer name is kept on the instance at construction time; the
    tokenizer object itself is created in ``setup()``, which Beam calls on each
    DoFn instance before it processes elements.
    """

    def __init__(self, tokenizer_model: str):
        # Name or path handed to AutoTokenizer.from_pretrained in setup().
        self.tokenizer_model = tokenizer_model

    def setup(self):
        """Load the fast tokenizer, registering "<unk>" if it has no unknown token."""
        tok = AutoTokenizer.from_pretrained(self.tokenizer_model, use_fast=True)
        self.tokenizer = tok
        # NOTE(review): a newly added special token gets an id appended after the
        # existing vocabulary; confirm the downstream model's embeddings cover it.
        if tok.unk_token is None:
            tok.add_special_tokens({"unk_token": "<unk>"})

    def process(self, text: str):
        """Yield the token ids of ``text`` without BOS/EOS or other special tokens."""
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        yield token_ids

class SplitingChunksDoFN(beam.DoFn):
    """Beam DoFn that cuts one long token-id list into windows of bounded length.

    Windows start every ``stride`` tokens. Each window holds at most
    ``max_sequence_length - num_special_token_reserved + 1`` tokens. The extra
    token presumably lets the downstream step build labels shifted by one
    position. Windows with at most
    ``min_sequence_length - num_special_token_reserved`` tokens are dropped;
    these are the short tail windows near the end of the input.

    Args:
        min_sequence_length: Minimum length budget, including reserved special tokens.
        max_sequence_length: Maximum length budget, including reserved special tokens.
        num_special_token_reserved: Positions subtracted from both budgets for special tokens.
        stride: Step between consecutive window starts; must be positive.

    Raises:
        ValueError: If ``stride`` is not positive.
    """

    def __init__(self, min_sequence_length: int = 128, max_sequence_length: int = 256, num_special_token_reserved: int = 2, stride: int = 128):
        # A negative stride would make range() empty and silently emit nothing;
        # zero would only fail later inside process(). Reject both up front.
        if stride <= 0:
            raise ValueError(f"stride must be a positive integer, got {stride}")
        # Budgets below exclude the positions reserved for special tokens.
        self.min_sequence_length = min_sequence_length - num_special_token_reserved
        self.max_sequence_length = max_sequence_length - num_special_token_reserved
        self.stride = stride

    def process(self, input_ids: List[int]):
        """Yield every window of ``input_ids`` longer than the minimum budget."""
        window = self.max_sequence_length + 1  # +1 token for the label shift
        for start in range(0, len(input_ids), self.stride):
            chunk = input_ids[start : start + window]
            if len(chunk) > self.min_sequence_length:
                yield chunk

class PaddingAndPackagingDoFN(beam.DoFn):
    """Beam DoFn turning a token-id chunk into one fixed-length training example.

    The chunk is wrapped as ``[bos] + chunk + [eos]`` and right-padded with
    ``pad_token_id``. ``input_ids`` holds the first ``sequence_length`` tokens.
    ``labels`` holds the same tokens shifted left by one, so ``labels[i]`` is
    the token that follows ``input_ids[i]``.

    Args:
        sequence_length: Length of every emitted ``input_ids``/``attention_mask``/``labels`` list.
        bos_token_id: Id prepended to each chunk.
        eos_token_id: Id appended to each chunk.
        pad_token_id: Id used to fill positions past the end of the chunk.
        ignore_label: Label written wherever the target position is padding.
    """

    def __init__(self, sequence_length: int = 256, bos_token_id: int = 1, eos_token_id: int = 2, pad_token_id: int = 0, ignore_label: int = -100):
        self.sequence_length = sequence_length
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.ignore_label = ignore_label

    def process(self, sequence: List[int]):
        """Yield a dict with ``input_ids``, ``attention_mask`` (floats) and ``labels``.

        Padding is identified by position (everything after the real tokens),
        not by comparing token ids. So a pad id that is also an ordinary
        vocabulary id, or equals ``eos_token_id``, or is non-zero, still yields
        a correct mask and correct labels.
        """
        tokens = [self.bos_token_id] + list(sequence) + [self.eos_token_id]
        real_length = len(tokens)
        # Pad to sequence_length + 1: the extra slot supplies the last label after the shift.
        pad_count = max(0, self.sequence_length + 1 - real_length)
        padded = tokens + [self.pad_token_id] * pad_count

        input_ids = padded[: self.sequence_length]
        attention_mask = [1.0 if pos < real_length else 0.0 for pos in range(self.sequence_length)]
        # labels[i] is the token at position i + 1; targets that land in the padding are ignored.
        labels = [
            padded[pos] if pos < real_length else self.ignore_label
            for pos in range(1, self.sequence_length + 1)
        ]
        assert len(input_ids) == len(attention_mask) == len(labels)
        yield {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

def main(args: argparse.Namespace, pipeline_options: List[str]):
    """Tokenize one Penn Treebank split and write fixed-length examples to Parquet.

    The whole split file is joined into one string and tokenized. The token
    ids are cut into chunks, each chunk is padded/packaged, and the result is
    written with ``beam.io.WriteToParquet`` under ``args.output_path``.

    Args:
        args: Parsed script options (the known part of ``parser.parse_known_args()``).
        pipeline_options: Leftover command-line flags forwarded to Beam's
            ``PipelineOptions`` (the unknown part of ``parser.parse_known_args()``).

    Raises:
        ValueError: If ``args.dataset_split_type`` is not a known split, or the
            tokenizer has no EOS token id.
    """
    split_files = {
        "train": "data/penn/train.txt",
        "valid": "data/penn/valid.txt",
        "test": "data/penn/test.txt",
    }
    try:
        split_file = split_files[args.dataset_split_type]
    except KeyError:
        raise ValueError(
            f"Unknown dataset split {args.dataset_split_type!r}; expected one of {sorted(split_files)}"
        ) from None

    with open(split_file, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        text = " ".join(line for line in stripped_lines if line)
    # A single element: the split is tokenized as one document, so chunks can span line boundaries.
    lines = [text]

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model, use_fast=True)
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError(f"Tokenizer {args.tokenizer_model!r} has no EOS token id")
    # Fall back to EOS only when BOS is truly missing; `or` would also discard a valid id of 0.
    bos_token_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else eos_token_id
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    # Build options from the caller-supplied flags rather than the module-level
    # `pipeline_args`, which only exists when this file is run as a script.
    beam_options = PipelineOptions(pipeline_options)
    with beam.Pipeline(options=beam_options) as pipeline:
        split_chunk_do_fn = SplitingChunksDoFN(
            min_sequence_length=args.min_sequence_length,
            max_sequence_length=args.max_sequence_length,
            num_special_token_reserved=args.num_special_token_reserved,
            stride=args.stride,
        )
        padding_and_packaging_do_fn = PaddingAndPackagingDoFN(
            sequence_length=args.max_sequence_length,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            ignore_label=args.ignore_label,
        )
        write_fn = beam.io.WriteToParquet(
            args.output_path,
            pyarrow.schema(
                [
                    ("input_ids", pyarrow.list_(pyarrow.int64())),
                    ("attention_mask", pyarrow.list_(pyarrow.float32())),
                    ("labels", pyarrow.list_(pyarrow.int64())),
                ]
            ),
        )
        _ = (
            pipeline
            | "Create PCollection from dataset" >> beam.Create(lines)
            | "Tokenize text" >> beam.ParDo(TokenizingDoFn(args.tokenizer_model))
            | "Split into chunks" >> beam.ParDo(split_chunk_do_fn)
            | "Pad and package" >> beam.ParDo(padding_and_packaging_do_fn)
            | "Write output" >> write_fn
        )

if __name__ == "__main__":
    # Flags declared on `parser` go to `script_args`; every other flag stays in
    # a raw list that is handed to Beam. `pipeline_args` keeps its original
    # module-level name in case other code in this file reads it.
    script_args, pipeline_args = parser.parse_known_args()
    main(script_args, pipeline_args)
