import argparse
import csv
import os
from typing import Tuple, Iterable, Dict

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from transformers import AutoTokenizer
import pyarrow as pa

# Command-line interface.
# Options that are not declared here are not an error: parse_known_args()
# returns them separately as `beam_args`, which run_pipeline() hands to
# Beam's PipelineOptions.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--tokenizer-model",
    type=str,
    default="bert-base-uncased",
    help="Model name or path passed to AutoTokenizer.from_pretrained (default: %(default)s)",
)
parser.add_argument(
    "--max-sequence-length",
    type=int,
    default=256,
    help="max_length used when padding/truncating each example (default: %(default)s)",
)
parser.add_argument("--data-dir", type=str, required=True, help="Path to data/dbpedia/")
parser.add_argument("--output-dir", type=str, required=True, help="Path to save processed output")
args, beam_args = parser.parse_known_args()

# Tokenizer
# Loaded once at module level from --tokenizer-model and used by TokenizeText.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model)
if tokenizer.pad_token is None:
    # TokenizeText calls the tokenizer with padding="max_length", which
    # presumably requires a pad token; reuse the EOS token when none is set.
    # NOTE(review): if eos_token is also None the pad token stays unset —
    # confirm this cannot happen for the tokenizers in use.
    tokenizer.pad_token = tokenizer.eos_token


# Build the label vocabulary from DBPEDIA_train.csv.
# Every distinct value of the "l3" column gets an integer id; ids follow the
# sorted order of the label strings, so the same training file always yields
# the same mapping.
train_csv_path = os.path.join(args.data_dir, "DBPEDIA_train.csv")
with open(train_csv_path, newline='', encoding='utf-8') as f:
    label_set = {row["l3"] for row in csv.DictReader(f)}
label_to_id = dict(zip(sorted(label_set), range(len(label_set))))

class ReadCSV(beam.DoFn):
    """Expand a CSV file path into one ``(text, label)`` pair per data row."""

    def process(self, filepath: str) -> Iterable[Tuple[str, str]]:
        """Read ``filepath`` as a CSV with a header row and emit its rows.

        Args:
            filepath: Path of a UTF-8 encoded CSV file whose header includes
                the ``text`` and ``l3`` columns.

        Yields:
            ``(row["text"], row["l3"])`` for each row, in file order.
        """
        with open(filepath, newline='', encoding='utf-8') as csv_file:
            yield from (
                (record["text"], record["l3"])
                for record in csv.DictReader(csv_file)
            )


class TokenizeText(beam.DoFn):
    """Turn a ``(text, label)`` pair into a tokenized training record."""

    def process(self, element: Tuple[str, str]) -> Iterable[Dict]:
        """Tokenize the text and map the label string to its integer id.

        Uses the module-level ``tokenizer``, ``args.max_sequence_length`` and
        ``label_to_id``.

        Args:
            element: ``(text, label)`` tuple.

        Yields:
            A single dict with ``input_ids``, ``attention_mask`` and ``label``.

        Raises:
            KeyError: If ``label`` is not a key of ``label_to_id``.
        """
        raw_text, label_name = element
        encoding = tokenizer(
            raw_text,
            padding="max_length",
            truncation=True,
            max_length=args.max_sequence_length,
        )
        # Keep only the two encoder outputs the Parquet schema declares.
        record = {field: encoding[field] for field in ("input_ids", "attention_mask")}
        record["label"] = label_to_id[label_name]
        yield record


class ToArrowRecord(beam.DoFn):
    """Pass-through step placed in front of the Parquet writer."""

    def process(self, element: Dict) -> Iterable[Dict[str, object]]:
        """Return ``element`` unchanged, wrapped in a one-item list.

        No conversion is performed here; the dict produced upstream is
        forwarded as-is.
        """
        passthrough = [element]
        return passthrough


def run_pipeline(split_name: str):
    """Tokenize one dataset split and write the result as Parquet.

    Reads ``DBPEDIA_<split_name>.csv`` from ``args.data_dir``, sends its rows
    through ``ReadCSV`` -> ``TokenizeText`` -> ``ToArrowRecord`` and writes
    records with ``input_ids``, ``attention_mask`` and ``label`` columns.
    The pipeline runs when the ``with`` block exits.

    Args:
        split_name: Split identifier; it is used in the input file name, in
            every transform label and in the output path.
    """
    source_csv = os.path.join(args.data_dir, f"DBPEDIA_{split_name}.csv")
    # Passed to the sink as a file path *prefix*: <output-dir>/<split_name>,
    # to which the ".parquet" suffix is appended.
    sink_prefix = os.path.join(args.output_dir, split_name)

    record_schema = pa.schema([
        ("input_ids", pa.list_(pa.int32())),
        ("attention_mask", pa.list_(pa.int32())),
        ("label", pa.int32()),
    ])

    with beam.Pipeline(options=PipelineOptions(beam_args)) as pipeline:
        file_paths = pipeline | f"Create-{split_name}" >> beam.Create([source_csv])
        labeled_texts = file_paths | f"ReadCSV-{split_name}" >> beam.ParDo(ReadCSV())
        records = labeled_texts | f"Tokenize-{split_name}" >> beam.ParDo(TokenizeText())

        # Save as Parquet
        arrow_records = records | f"ToArrow-{split_name}" >> beam.ParDo(ToArrowRecord())
        _ = arrow_records | f"WriteParquet-{split_name}" >> beam.io.WriteToParquet(
            file_path_prefix=sink_prefix,
            schema=record_schema,
            file_name_suffix=".parquet",
            # NOTE(review): an empty template presumably disables the
            # "-SSSSS-of-NNNNN" shard suffix so one file is written per
            # split — confirm against the installed Beam version.
            shard_name_template="",
        )


if __name__ == "__main__":
    # One independent pipeline run per dataset split, in this order.
    for split in ("train", "val", "test"):
        run_pipeline(split)
