import argparse
import json
import os
import random

import datasets
import pyarrow as pa  # Ensure PyArrow data is properly formatted
from datasets import Dataset, Features, Sequence, Image, Value, DatasetDict
from PIL import Image as PILImage
from transformers import AutoProcessor

def _read_jsonl(json_file_path):
    """Return the JSON objects stored one per line in *json_file_path*.

    Lines containing only whitespace are skipped so that a trailing newline
    or an accidental empty line does not abort the whole run.
    """
    with open(json_file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _load_thumbnail(image_path, max_size):
    """Load *image_path* as RGB and shrink it to fit a max_size x max_size box.

    ``thumbnail`` preserves the aspect ratio and never enlarges, so the result
    is at most ``max_size`` pixels on each side, not necessarily square.
    """
    # convert() returns an independent in-memory copy, so the underlying file
    # can be closed immediately instead of waiting for garbage collection.
    with PILImage.open(image_path) as source:
        image = source.convert("RGB")
    image.thumbnail((max_size, max_size))
    return image


def load_and_process_json(
    json_file_path,
    chunk_size=50,
    *,
    model_name="meta-llama/Llama-3.2-11B-Vision-Instruct",
    max_size=512,
    seed=None,
):
    """Build a preference-pair ``DatasetDict`` from a JSONL file.

    Each non-blank line of the input must be a JSON object providing the keys
    ``image_path``, ``prompt``, ``chosen`` and ``rejected``. The samples are
    shuffled, every image is loaded as RGB and shrunk to fit ``max_size``, and
    the prompt / chosen / rejected texts are rendered to strings with the
    processor's chat template.

    Args:
        json_file_path: Path to the JSONL input file.
        chunk_size: Number of samples converted into one intermediate
            ``Dataset`` before all chunks are concatenated.
        model_name: Identifier passed to ``AutoProcessor.from_pretrained``.
        max_size: Maximum width and height, in pixels, of the stored images.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` state is used, as before.

    Returns:
        A ``DatasetDict`` with a single ``'train'`` split whose columns are
        ``images`` (a one-element list per sample), ``prompt``, ``chosen``
        and ``rejected``.

    Raises:
        ValueError: If ``chunk_size`` is not positive or the file contains no
            samples.
    """
    # Validate cheap preconditions before loading the processor.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    data = _read_jsonl(json_file_path)
    print(f"Total Samples: {len(data)}")
    if not data:
        raise ValueError(f"No samples found in {json_file_path!r}")

    # Shuffle
    if seed is None:
        random.shuffle(data)
    else:
        random.Random(seed).shuffle(data)

    # Initialize the processor
    processor = AutoProcessor.from_pretrained(model_name, do_image_splitting=False)

    # The schema is identical for every chunk, so build it once.
    features = Features({
        'images': Sequence(Image(decode=True)),
        'prompt': Value('string'),
        'chosen': Value('string'),
        'rejected': Value('string'),
    })

    # Process in chunks to avoid memory issues
    processed_chunks = []

    for i in range(0, len(data), chunk_size):
        batch = data[i:i + chunk_size]
        print("processing for full", i, i + chunk_size)
        processed_data = {'images': [], 'prompt': [], 'chosen': [], 'rejected': []}

        for item in batch:
            image = _load_thumbnail(item['image_path'], max_size)

            # Format prompts and responses as chat messages
            prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": item["prompt"]}]}]
            chosen = [{"role": "assistant", "content": [{"type": "text", "text": item["chosen"]}]}]
            rejected = [{"role": "assistant", "content": [{"type": "text", "text": item["rejected"]}]}]

            # Apply chat template (rendered to text, not token ids)
            formatted_prompt = processor.apply_chat_template(prompt, tokenize=False)
            formatted_chosen = processor.apply_chat_template(chosen, tokenize=False)
            formatted_rejected = processor.apply_chat_template(rejected, tokenize=False)

            # Add processed data
            processed_data['images'].append([image])  # Keep as a list of lists
            processed_data['prompt'].append(str(formatted_prompt))
            processed_data['chosen'].append(str(formatted_chosen))
            processed_data['rejected'].append(str(formatted_rejected))

        # Plain Python lists are sufficient here: the explicit `features`
        # schema already fixes the column types.
        processed_chunks.append(Dataset.from_dict(processed_data, features=features))

    # Combine all chunks into one dataset
    full_dataset = datasets.concatenate_datasets(processed_chunks)

    # Create a DatasetDict with train split
    return DatasetDict({'train': full_dataset})



def main():
    """Parse command-line arguments, build the dataset and save it to disk.

    Expects two positional arguments: the input JSONL file and the output
    directory (created if missing).

    Returns:
        0 on success.
    """
    parser = argparse.ArgumentParser(description="Process a JSONL file containing image-prompt-response triples")
    parser.add_argument("input_file", help="Path to the input JSONL file")
    parser.add_argument("output_dir", help="Directory to save the processed dataset")

    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Processing input file: {args.input_file}")
    print(f"Output directory: {args.output_dir}")

    # Process the dataset
    dataset = load_and_process_json(args.input_file)

    # Only a 'train' split is produced, so only that split gets a shard
    # count. Cap it so small datasets never request more shards than samples.
    num_train_shards = max(1, min(50, len(dataset['train'])))

    # Save the dataset
    print(f"Saving dataset to {args.output_dir}")
    dataset.save_to_disk(
        args.output_dir,
        num_shards={'train': num_train_shards},
    )

    print(f"Successfully processed and saved dataset with {len(dataset['train'])} samples")
    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())

# Example usage:
#   python script.py ./it1_dpopairs.jsonl ./it1_dpopairs_llama11b