#!/usr/bin/env python3
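"""
Split downloaded Parquet files into JSONL shards for train/validation/inference.

For each dataset listed in settings.subsets.SUBSETS, reads from
<input-dir>/<dataset>/:
  - train-00000-of-00001.parquet
  - val-00000-of-00001.parquet

Rows whose 'text' field has fewer than --min-tokens tokens (counted with the
given HuggingFace tokenizer, without special tokens) are dropped first.

Default mode computes N = min(num_train, num_val // 2), then samples and
writes to <output-dir>/<dataset>/:
  - train → train.jsonl (N samples)
  - val   → test.jsonl  (N samples)
            val.jsonl   (N samples)
Every file has exactly N samples. test.jsonl and val.jsonl are disjoint
samples of the val file; train.jsonl comes from the separate train file.

With --max-val-train, the script instead writes all of train to
train_full.jsonl and all of val to train_val.jsonl. It then splits val into
two disjoint halves, test.jsonl and test_val.jsonl (num_val // 2 samples
each). Both halves overlap train_val.jsonl.
"""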
import argparse
import random
from pathlib import Path
import sys
import pandas as pd
import json
from transformers import AutoTokenizer

sys.path.append(str(Path(__file__).parent.parent))

from settings.subsets import SUBSETS

def _json_default(value):
    """
    Convert cell values that json cannot encode natively.

    json.dumps calls this only for objects it does not understand.
    - Datetime-like values (datetime, pandas Timestamp) become ISO-8601 strings.
    - Parquet list columns typically arrive in pandas as numpy arrays, and
      numeric cells may be numpy scalars. Both expose tolist(), which returns
      plain Python lists/scalars.

    Raises:
        TypeError: for any other type, mirroring json's own error.
    """
    if hasattr(value, "isoformat"):
        return value.isoformat()
    if hasattr(value, "tolist"):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def split_and_save(df, counts, names, output_dir):
    """
    Shuffle df's rows and write disjoint consecutive slices as JSONL files.

    Rows are permuted with the module-level `random` generator, so callers
    control reproducibility via random.seed. The first counts[0] rows go to
    output_dir/names[0].jsonl, the next counts[1] rows to
    output_dir/names[1].jsonl, and so on. Rows beyond sum(counts) are not
    written. Each row becomes one JSON object per line (UTF-8, non-ASCII
    characters kept as-is).

    Args:
        df: DataFrame whose rows are written.
        counts: Number of rows for each output file.
        names: Output file stems, one per entry in counts.
        output_dir: Directory to write into (str or Path); created if missing.

    Raises:
        ValueError: if counts and names differ in length, or sum(counts)
            exceeds the number of rows in df.
    """
    if len(counts) != len(names):
        raise ValueError(
            f"counts ({len(counts)}) and names ({len(names)}) must have the same length"
        )
    total = sum(counts)
    if total > len(df):
        raise ValueError(f"requested {total} rows but DataFrame has only {len(df)}")
    output_dir = Path(output_dir)
    # Shuffle row *positions* rather than index labels: positional selection
    # stays correct even when df's index contains duplicate labels. The
    # permutation depends only on the list length, so it matches what
    # shuffling the index labels would have produced.
    positions = list(range(len(df)))
    random.shuffle(positions)
    output_dir.mkdir(parents=True, exist_ok=True)
    start = 0
    for count, name in zip(counts, names):
        subset = df.iloc[positions[start:start + count]]
        start += count
        out_path = output_dir / f"{name}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            for record in subset.to_dict(orient="records"):
                f.write(json.dumps(record, ensure_ascii=False, default=_json_default) + "\n")
        print(f"  ✎ Wrote {count} samples to {out_path}")

def _drop_short_texts(df, tokenizer, min_tokens, split_name):
    """
    Return the rows of df whose 'text' has at least min_tokens tokens.

    Tokens are counted with tokenizer.encode(str(text), add_special_tokens=False).
    Prints how many rows were removed, labelled with split_name, if any.
    """
    if min_tokens <= 0:
        # Every text has >= 0 tokens, so the per-row tokenization can be skipped.
        return df

    def has_enough_tokens(text):
        return len(tokenizer.encode(str(text), add_special_tokens=False)) >= min_tokens

    # apply() on an empty column may not return a bool dtype; astype(bool)
    # guarantees df.loc receives a boolean row mask in every case.
    mask = df["text"].apply(has_enough_tokens).astype(bool)
    kept = df.loc[mask]
    removed = len(df) - len(kept)
    if removed > 0:
        print(f"  🗑️  Removed {removed} {split_name} samples with < {min_tokens} tokens")
    return kept


def process_dataset(dataset_name, input_dir, output_dir, tokenizer, min_tokens, max_val_train=False, seed=42):
    """
    Filter one dataset's train/val Parquet files and write JSONL splits.

    Reads <input_dir>/<dataset_name>/{train,val}-00000-of-00001.parquet and
    drops rows whose 'text' has fewer than min_tokens tokens. It then writes
    JSONL files to <output_dir>/<dataset_name>/:

    * default: N = min(num_train, num_val // 2). train.jsonl gets N train
      rows; test.jsonl and val.jsonl get N disjoint val rows each.
    * max_val_train=True: train_full.jsonl (all train rows), train_val.jsonl
      (all val rows), and test.jsonl / test_val.jsonl (two disjoint halves of
      val, num_val // 2 rows each).

    The dataset is skipped with a printed message, not an exception, when:
    - a Parquet file is missing;
    - the 'text' column is missing;
    - too few rows survive filtering.

    Args:
        dataset_name: Sub-directory name under input_dir and output_dir.
        input_dir: Root directory containing the downloaded Parquet folders.
        output_dir: Root directory for the JSONL output.
        tokenizer: Object with an encode(text, add_special_tokens=False)
            method returning a token sequence (e.g. a HuggingFace tokenizer).
        min_tokens: Minimum token count a row's text must have to be kept.
        max_val_train: Use the alternate output layout described above.
        seed: Seed applied to the module-level `random` generator before any
            shuffling, so each dataset's split is reproducible on its own.
    """
    random.seed(seed)
    ds_in = Path(input_dir) / dataset_name
    ds_out = Path(output_dir) / dataset_name
    train_path = ds_in / "train-00000-of-00001.parquet"
    val_path = ds_in / "val-00000-of-00001.parquet"
    if not train_path.exists() or not val_path.exists():
        print(f"❌ Missing parquet in {dataset_name}, skipping.")
        return
    print(f"📂 Processing dataset: {dataset_name}")

    train_df = pd.read_parquet(train_path)
    val_df = pd.read_parquet(val_path)
    # Skip rather than raise KeyError, so one malformed dataset does not abort
    # the whole run (same policy as the missing-parquet check above).
    if "text" not in train_df.columns or "text" not in val_df.columns:
        print(f"❌ No 'text' column in {dataset_name}, skipping.")
        return
    train_df = _drop_short_texts(train_df, tokenizer, min_tokens, "train")
    val_df = _drop_short_texts(val_df, tokenizer, min_tokens, "val")
    num_train = len(train_df)
    num_val = len(val_df)

    if max_val_train:
        if num_train == 0 or num_val < 2:
            print(f"⚠️  Not enough samples in {dataset_name} ({num_train=}, {num_val=}), skipping.")
            return
        print(f"  📜 Saving full train ({num_train} rows) and full val ({num_val} rows)")
        split_and_save(train_df, counts=[num_train], names=["train_full"], output_dir=ds_out)
        split_and_save(val_df, counts=[num_val], names=["train_val"], output_dir=ds_out)
        # num_val >= 2 is guaranteed above, so N >= 1 here.
        N = num_val // 2
        print(f"  📜 Splitting val ({num_val} rows) into test ({N}) and test_val ({N})")
        split_and_save(val_df, counts=[N, N], names=["test", "test_val"], output_dir=ds_out)
        return

    N = min(num_train, num_val // 2)
    if N <= 0:
        print(f"⚠️  Not enough samples in {dataset_name} ({num_train=}, {num_val=}), skipping.")
        return
    print(f"  ✂️  Splitting train ({num_train} rows) into {N}")
    split_and_save(train_df, counts=[N], names=["train"], output_dir=ds_out)
    print(f"  ✂️  Splitting val ({num_val} rows) into 2×{N}")
    split_and_save(val_df, counts=[N, N], names=["test", "val"], output_dir=ds_out)

def _build_arg_parser():
    """Create the argparse parser describing this script's command line."""
    parser = argparse.ArgumentParser(
        description="Split Parquet datasets into JSONL shards for train/val/inference."
    )
    # Registration order determines the order options appear in --help.
    parser.add_argument("--input-dir", required=True,
                        help="Folder where Parquet subfolders reside")
    parser.add_argument("--output-dir", required=True,
                        help="Folder to write JSONL shards")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for shuffling (default: 42)")
    parser.add_argument("--tokenizer", required=True,
                        help="HuggingFace tokenizer name (e.g. 'gpt2')")
    parser.add_argument("--min-tokens", type=int, required=True,
                        help="Minimal number of tokens in 'text' field")
    parser.add_argument(
        "--max-val-train",
        action="store_true",
        help="If set, save full train as train_full.jsonl, full val as train_val.jsonl, and split val into test.jsonl and test_val.jsonl (each half of val).",
    )
    return parser


def main():
    """Parse the command line, load the tokenizer once, then process every subset."""
    options = _build_arg_parser().parse_args()
    tok = AutoTokenizer.from_pretrained(options.tokenizer)
    for subset_name in SUBSETS:
        process_dataset(
            subset_name,
            options.input_dir,
            options.output_dir,
            tok,
            options.min_tokens,
            max_val_train=options.max_val_train,
            seed=options.seed,
        )


if __name__ == "__main__":
    main()
