#!/usr/bin/env python3
"""
Process JSONL files to split text into prefix and suffix using a tokenizer.

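For each dataset listed in SUBSETS, reads from <input-dir>/<dataset>/:
  - train.jsonl, val.jsonl, test.jsonl (default), or
  - train_full.jsonl, train_val.jsonl, test.jsonl, test_val.jsonl
    (with --max-val-train)

Splits the "text" field of each record into "prefix" and "suffix", where the
suffix is the decoded text of the last suff_tokens tokens and the prefix is
the decoded text of the remaining tokens. Records with too few tokens are
skipped. Each configured text transformation adds a "prefix_<name>" field,
and the original "text" field is dropped. Results are saved under
<output-dir>/<model-name>/<dataset>/, with "/" in the model name replaced
by "_".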
"""
import argparse
import json
from pathlib import Path
from transformers import AutoTokenizer
import os
import sys
sys.path.append("../utils/NL-Augmenter")
sys.path.append("../utils")
# torch 2.3 is needed, otherwise torchtext won't be compatible with pytorch in versions >=2.4
from transform import get_aug_generator
from settings.subsets import SUBSETS
from settings.text_transformations import TEXT_TRANSFORMATIONS


def process_jsonl_file(input_file, output_file, tokenizer, suff_tokens, text_transformtions,
                       min_prefix_tokens=100):
    """
    Process a single JSONL file, splitting text into prefix and suffix.

    Each written record is a copy of the input record in which the "text"
    field is replaced by:
      - "prefix": decoded text of all but the last ``suff_tokens`` tokens
      - "suffix": decoded text of the last ``suff_tokens`` tokens
      - "prefix_<name>": the prefix passed through each text transformation

    Blank lines are ignored. Records without a "text" field, or whose text
    has at most ``suff_tokens + min_prefix_tokens`` tokens, are skipped.

    Args:
        input_file: Path to input JSONL file
        output_file: Path to output JSONL file (overwritten)
        tokenizer: HuggingFace tokenizer instance
        suff_tokens: Number of tokens for the suffix (must be >= 0)
        text_transformtions: Names of text transformations applied to the
            prefix; each is resolved with ``get_aug_generator``
        min_prefix_tokens: A record is kept only if it has more than
            ``suff_tokens + min_prefix_tokens`` tokens (default 100)

    Returns:
        Tuple ``(processed_count, skipped_count)``.

    Raises:
        ValueError: If ``suff_tokens`` is negative.
    """
    if suff_tokens < 0:
        raise ValueError(f"suff_tokens must be non-negative, got {suff_tokens}")

    # Build the generators before opening (and truncating) the output file so
    # that an invalid transformation name does not leave an empty output behind.
    trans_generators = {}
    for transformation in text_transformtions:
        trans_generators[transformation] = get_aug_generator(transformation)

    min_tokens = suff_tokens + min_prefix_tokens
    processed_count = 0
    skipped_count = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8') as outfile:

        for line in infile:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # aborting the whole file with a JSONDecodeError.
                continue
            record = json.loads(line)

            if 'text' not in record:
                print("⚠️  Skipping record without 'text' field", flush=True)
                skipped_count += 1
                continue

            tokens = tokenizer.encode(record['text'], add_special_tokens=False)

            # Require enough tokens for the suffix plus a minimum-length prefix.
            if len(tokens) <= min_tokens:
                print(f"⚠️  Skipping record: text has {len(tokens)} tokens, need more than {min_tokens}", flush=True)
                skipped_count += 1
                continue

            # Split at an explicit index: a negative slice would break
            # suff_tokens == 0 (tokens[:-0] is empty, tokens[-0:] is everything).
            split_at = len(tokens) - suff_tokens
            prefix_text = tokenizer.decode(tokens[:split_at], skip_special_tokens=True)
            suffix_text = tokenizer.decode(tokens[split_at:], skip_special_tokens=True)

            new_record = record.copy()
            new_record['prefix'] = prefix_text
            new_record['suffix'] = suffix_text
            for transformation in text_transformtions:
                # NOTE(review): assumes generate() returns a list of candidate
                # texts; the first one is used.
                new_record[f"prefix_{transformation}"] = (
                    trans_generators[transformation].generate(prefix_text)[0]
                )
                # Drop a same-named input field if present; a plain `del`
                # raised KeyError for records that don't carry one.
                new_record.pop(transformation, None)
            del new_record['text']  # Remove original text field

            outfile.write(json.dumps(new_record, ensure_ascii=False) + '\n')
            processed_count += 1

    return processed_count, skipped_count


def _process_dataset(dataset, ds_input, ds_output, target_files, tokenizer, suff_tokens):
    """Split every existing target file of one dataset; return the number of files that failed."""
    existing_files = [name for name in target_files if (ds_input / name).exists()]
    if not existing_files:
        print(f"❌ No target files found in {dataset}, skipping.", flush=True)
        return 0

    ds_output.mkdir(parents=True, exist_ok=True)

    total_processed = 0
    total_skipped = 0
    failed = 0
    for filename in existing_files:
        print(f"  🔄 Processing {filename}...", flush=True)
        try:
            processed, skipped = process_jsonl_file(
                ds_input / filename,
                ds_output / filename,
                tokenizer,
                suff_tokens,
                TEXT_TRANSFORMATIONS,
            )
        except Exception as e:
            # Keep going with the remaining files, but remember the failure so
            # the script can report a non-zero exit status at the end.
            failed += 1
            print(f"    ❌ Error processing {filename}: {e}", flush=True)
            continue
        total_processed += processed
        total_skipped += skipped
        print(f"    ✅ Processed {processed} records, skipped {skipped}", flush=True)

    print(f"  📊 Dataset {dataset}: {total_processed} total processed, {total_skipped} total skipped", flush=True)
    return failed


def main():
    """Command-line entry point: split the target JSONL files of every subset.

    For each dataset name in ``SUBSETS``, reads the selected JSONL files from
    ``<input-dir>/<dataset>/`` and writes the processed copies to
    ``<output-dir>/<model-name with '/' replaced by '_'>/<dataset>/``.

    Exits with status 2 (via argparse) on invalid arguments, and with status 1
    if the tokenizer cannot be loaded or if any file fails to process.
    """
    parser = argparse.ArgumentParser(
        description="Split text in JSONL files into prefix and suffix using a tokenizer."
    )
    parser.add_argument(
        "--input-dir",
        required=True,
        help="Folder containing dataset subdirectories with JSONL files"
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Folder to write processed JSONL files"
    )
    parser.add_argument(
        "--model-name",
        required=True,
        help="HuggingFace model name for tokenizer (e.g., 'gpt2', 'microsoft/DialoGPT-medium')"
    )
    parser.add_argument(
        "--suff-tokens",
        type=int,
        required=True,
        help="Number of tokens for the suffix"
    )
    parser.add_argument(
        "--max-val-train",
        action="store_true",
        help="If set, process files generated by split_dataset.py with --max-val-train: train_full.jsonl, train_val.jsonl, test.jsonl, test_val.jsonl. Otherwise, process train.jsonl, val.jsonl, test.jsonl."
    )

    args = parser.parse_args()
    if args.suff_tokens < 0:
        parser.error("--suff-tokens must be non-negative")

    print(f"🤖 Loading tokenizer for model: {args.model_name}", flush=True)

    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)

        # Some tokenizers don't have a pad_token; fall back to the EOS token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    except Exception as e:
        print(f"❌ Error loading tokenizer for {args.model_name}: {e}", flush=True)
        # Nothing can be processed without a tokenizer: signal failure to the caller.
        sys.exit(1)

    print("✅ Tokenizer loaded successfully", flush=True)
    print(f"📝 Suffix will contain {args.suff_tokens} tokens", flush=True)

    # Select which files to process based on --max-val-train
    if args.max_val_train:
        target_files = ['train_full.jsonl', 'train_val.jsonl', 'test.jsonl', 'test_val.jsonl']
    else:
        target_files = ['train.jsonl', 'val.jsonl', 'test.jsonl']

    # '/' in hub model names (e.g. 'org/model') would otherwise create nested dirs.
    safe_model_name = args.model_name.replace('/', '_')
    input_root = Path(args.input_dir)
    output_root = Path(args.output_dir) / safe_model_name

    failed_files = 0
    for dataset in SUBSETS:
        print(f"📂 Processing dataset: {dataset}", flush=True)
        failed_files += _process_dataset(
            dataset,
            input_root / dataset,
            output_root / dataset,
            target_files,
            tokenizer,
            args.suff_tokens,
        )

    if failed_files:
        print(f"⚠️  Processing finished with {failed_files} failed file(s).", flush=True)
        sys.exit(1)

    print("🎉 Processing complete!", flush=True)


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not when imported.
    main()
