#!/usr/bin/env python3
"""
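Process JSONL files to split text into prefix and suffix using a tokenizer.

For each known dataset subdirectory of --input-dir, reads whichever of these
files exist:
  - train.jsonl
  - val.jsonl
  - test.jsonl
(with --max-val-train: train_full.jsonl, train_val.jsonl, test.jsonl and
test_val.jsonl instead).

Splits the "text" field into "prefix" and "suffix", where the suffix is decoded
from the last suff_tokens token IDs of the tokenized text; records with no more
than suff_tokens + 100 tokens are skipped. Results are saved in a subdirectory
named after the model: <output-dir>/<model name, '/' replaced by '_'>/<dataset>/.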
"""
import argparse
import json
from pathlib import Path
from transformers import AutoTokenizer
import os


def process_jsonl_file(input_file, output_file, tokenizer, suff_tokens, min_prefix_tokens=100):
    """
    Process a single JSONL file, splitting text into prefix and suffix.

    Each non-blank line of ``input_file`` is parsed as JSON. For object
    records, the "text" field is tokenized without special tokens; the last
    ``suff_tokens`` token IDs are decoded into "suffix" and all preceding IDs
    into "prefix". The "text" key is dropped and every other key is copied
    through unchanged. Records that are not objects, lack "text", or have at
    most ``suff_tokens + min_prefix_tokens`` tokens are skipped (with a
    warning). Blank lines are ignored and not counted.

    Args:
        input_file: Path to input JSONL file
        output_file: Path to output JSONL file (overwritten)
        tokenizer: HuggingFace tokenizer instance; only its
            ``encode(text, add_special_tokens=False)`` and
            ``decode(ids, skip_special_tokens=True)`` methods are used
        suff_tokens: Number of tokens for the suffix (must be >= 0)
        min_prefix_tokens: A record is kept only if it has more than
            ``suff_tokens + min_prefix_tokens`` tokens. Defaults to 100,
            the margin this function has always used.

    Returns:
        Tuple ``(processed_count, skipped_count)``.

    Raises:
        ValueError: if ``suff_tokens`` or ``min_prefix_tokens`` is negative,
            or if a line is not valid JSON (message includes the line number).
    """
    if suff_tokens < 0:
        raise ValueError(f"suff_tokens must be non-negative, got {suff_tokens}")
    if min_prefix_tokens < 0:
        raise ValueError(f"min_prefix_tokens must be non-negative, got {min_prefix_tokens}")

    min_total = suff_tokens + min_prefix_tokens
    processed_count = 0
    skipped_count = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8') as outfile:

        for line_no, raw_line in enumerate(infile, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a stray empty line between records) carry
                # no record; json.loads("") would otherwise abort the file.
                continue

            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{input_file}:{line_no}: invalid JSON: {e}") from e

            # isinstance guard: for a JSON string, `'text' in record` would be
            # a substring test and record['text'] would then fail.
            if not isinstance(record, dict) or 'text' not in record:
                print("⚠️  Skipping record without 'text' field")
                skipped_count += 1
                continue

            tokens = tokenizer.encode(record['text'], add_special_tokens=False)

            # Require room for the suffix plus a non-trivial prefix.
            if len(tokens) <= min_total:
                print(f"⚠️  Skipping record: text has {len(tokens)} tokens, need more than {min_total}")
                skipped_count += 1
                continue

            # Split at an explicit index instead of tokens[:-suff_tokens]:
            # with suff_tokens == 0, -0 == 0 would move the whole text into
            # the suffix and leave the prefix empty.
            split_at = len(tokens) - suff_tokens
            prefix_text = tokenizer.decode(tokens[:split_at], skip_special_tokens=True)
            suffix_text = tokenizer.decode(tokens[split_at:], skip_special_tokens=True)

            # Copy all fields except the original text, then append the split.
            new_record = {key: value for key, value in record.items() if key != 'text'}
            new_record['prefix'] = prefix_text
            new_record['suffix'] = suffix_text

            outfile.write(json.dumps(new_record, ensure_ascii=False) + '\n')
            processed_count += 1

    return processed_count, skipped_count


# Dataset subdirectories looked for under --input-dir (same list as in the
# original script).
AVAILABLE_SUBSETS = [
    'arxiv', 'bookcorpus2', 'books3', 'cc', 'europarl',
    'freelaw', 'github', 'gutenberg', 'hackernews', 'math',
    'opensubtitles', 'openwebtext2', 'philpapers', 'stackexchange',
    'ubuntu', 'uspto', 'wikipedia', 'youtubesubtitles'
]

# Files processed in each dataset directory by default ...
DEFAULT_TARGET_FILES = ('train.jsonl', 'val.jsonl', 'test.jsonl')
# ... and the files produced by split_dataset.py with --max-val-train.
MAX_VAL_TRAIN_TARGET_FILES = ('train_full.jsonl', 'train_val.jsonl', 'test.jsonl', 'test_val.jsonl')


def process_dataset(dataset_name, input_dir, output_dir, model_name, tokenizer, suff_tokens,
                    target_files=None):
    """
    Process all JSONL files for a single dataset.

    For each name in ``target_files`` that exists as
    ``<input_dir>/<dataset_name>/<name>``, writes the prefix/suffix split to
    ``<output_dir>/<model_name with '/' replaced by '_'>/<dataset_name>/<name>``.
    A file that raises while being processed is reported, any partial output
    for it is removed, and the remaining files are still processed.

    Args:
        dataset_name: Name of the dataset subdirectory
        input_dir: Folder containing dataset subdirectories
        output_dir: Root folder for processed output
        model_name: HuggingFace model name (used to name the output folder)
        tokenizer: Tokenizer passed through to ``process_jsonl_file``
        suff_tokens: Number of tokens for the suffix
        target_files: File names to look for; defaults to
            ``DEFAULT_TARGET_FILES`` (train.jsonl, val.jsonl, test.jsonl)

    Returns:
        Tuple ``(total_processed, total_skipped)``; ``(0, 0)`` when none of
        the target files exist.
    """
    if target_files is None:
        target_files = DEFAULT_TARGET_FILES

    ds_input = Path(input_dir) / dataset_name
    # Replace slashes in the model name with underscores for filesystem safety.
    safe_model_name = model_name.replace('/', '_')
    # Output directory layout: <output_dir>/<model>/<dataset>
    ds_output = Path(output_dir) / safe_model_name / dataset_name

    print(f"📂 Processing dataset: {dataset_name}")

    existing_files = [name for name in target_files if (ds_input / name).exists()]
    if not existing_files:
        print(f"❌ No target files found in {dataset_name}, skipping.")
        return 0, 0

    ds_output.mkdir(parents=True, exist_ok=True)

    total_processed = 0
    total_skipped = 0

    for filename in existing_files:
        input_file = ds_input / filename
        output_file = ds_output / filename

        print(f"  🔄 Processing {filename}...")

        try:
            processed, skipped = process_jsonl_file(input_file, output_file, tokenizer, suff_tokens)
        except Exception as e:
            # Best-effort per file: report and move on to the next one.
            print(f"    ❌ Error processing {filename}: {e}")
            # Don't leave a truncated output file that looks like a finished result.
            try:
                output_file.unlink()
            except OSError:
                pass
            continue

        total_processed += processed
        total_skipped += skipped
        print(f"    ✅ Processed {processed} records, skipped {skipped}")

    print(f"  📊 Dataset {dataset_name}: {total_processed} total processed, {total_skipped} total skipped")
    return total_processed, total_skipped


def main():
    """Parse CLI arguments, load the tokenizer and process every selected dataset.

    Exits with status 1 if the tokenizer cannot be loaded, and with argparse's
    usage error status for invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Split text in JSONL files into prefix and suffix using a tokenizer."
    )
    parser.add_argument(
        "--input-dir",
        required=True,
        help="Folder containing dataset subdirectories with JSONL files"
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Folder to write processed JSONL files"
    )
    parser.add_argument(
        "--model-name",
        required=True,
        help="HuggingFace model name for tokenizer (e.g., 'gpt2', 'microsoft/DialoGPT-medium')"
    )
    parser.add_argument(
        "--suff-tokens",
        type=int,
        required=True,
        help="Number of tokens for the suffix"
    )
    parser.add_argument(
        "--max-val-train",
        action="store_true",
        help="If set, process files generated by split_dataset.py with --max-val-train: train_full.jsonl, train_val.jsonl, test.jsonl, test_val.jsonl. Otherwise, process train.jsonl, val.jsonl, test.jsonl."
    )
    parser.add_argument(
        "--datasets",
        nargs="+",
        choices=AVAILABLE_SUBSETS,
        default=AVAILABLE_SUBSETS,
        metavar="DATASET",
        help="Datasets to process (default: all known datasets)"
    )

    args = parser.parse_args()

    if args.suff_tokens < 0:
        parser.error("--suff-tokens must be non-negative")
    if not Path(args.input_dir).is_dir():
        parser.error(f"--input-dir {args.input_dir!r} is not a directory")

    print(f"🤖 Loading tokenizer for model: {args.model_name}")

    try:
        # Only encode/decode are used, so no pad_token setup is needed.
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    except Exception as e:
        print(f"❌ Error loading tokenizer for {args.model_name}: {e}")
        # Exit non-zero so shell pipelines / job schedulers see the failure.
        raise SystemExit(1) from e

    print(f"✅ Tokenizer loaded successfully")
    print(f"📝 Suffix will contain {args.suff_tokens} tokens")

    # Select which files to process based on --max-val-train.
    target_files = MAX_VAL_TRAIN_TARGET_FILES if args.max_val_train else DEFAULT_TARGET_FILES

    for dataset in args.datasets:
        process_dataset(
            dataset,
            args.input_dir,
            args.output_dir,
            args.model_name,
            tokenizer,
            args.suff_tokens,
            target_files=target_files,
        )

    print("🎉 Processing complete!")


if __name__ == "__main__":
    main()
