import json
import os
import random
from PIL import Image
from io import BytesIO
from datasets import Dataset, DatasetDict, Features, Sequence, Value
from datasets import Image as HFImageData


# --- Configuration ---
# Root directory that each record's relative 'img' path is resolved against.
IMAGE_BASE_PATH = "path/to/ChartX/ChartX_png"

# NOTE(review): despite the TRAIN naming, the input is the ChartX *test*
# annotation file and the output is named as a benchmark — confirm intent.
INPUT_TRAIN_JSON = "path/to/ChartX/ChartX_annotation_test.json"
OUTPUT_TRAIN_FILE = "benchmark_chartx.parquet"
# How many records to sample; every record is used if fewer are available.
NUM_TRAIN_SAMPLES = 5_000

# --- Helper functions ---
def process_entry(entry, image_base_path):
    """Convert one ChartX annotation record into an RLHFDataset-style row.

    Args:
        entry: One record from the ChartX annotation JSON. Fields read are
            ``img`` (image path relative to ``image_base_path``, optionally
            prefixed with ``./``), ``imgname`` (only used in log messages),
            and ``QA`` with ``input`` / ``output`` strings.
        image_base_path: Directory that ``entry['img']`` is resolved against.

    Returns:
        A dict with keys ``images`` (a one-element list ``[{"bytes": ...}]``
        holding the PNG-encoded RGB image, the shape the ``datasets.Image``
        feature accepts), ``prompt`` (``"<image> "`` + question) and
        ``ground_truth`` (the answer). Returns ``None`` when the record is
        skipped; every skip reason is reported with ``print``.
    """
    try:
        # Build the complete image path. entry['img'] is expected to look like
        # './bar_chart/png/bar_87.png'.
        relative_img_path = entry.get("img", "")
        if not relative_img_path:
            print(f"Warning: Record {entry.get('imgname', 'N/A')} missing 'img' field. Skipping.")
            return None
        if relative_img_path.startswith("./"):
            relative_img_path = relative_img_path[2:]  # Remove leading './'
        full_img_path = os.path.join(image_base_path, relative_img_path)

        if not os.path.exists(full_img_path):
            print(f"Warning: Image not found: {full_img_path}. Skipping record {entry.get('imgname', 'N/A')}.")
            return None

        # Validate the QA pair before touching the image, so rejected records
        # don't pay for a decode + re-encode. `or {}` also covers "QA": null.
        qa_data = entry.get("QA") or {}
        qa_input = qa_data.get("input", "")
        qa_output = qa_data.get("output", "")

        if not qa_input or not qa_output:
            print(f"Warning: Record {entry.get('imgname', 'N/A')} missing QA input or output. Skipping.")
            return None

        # Open via a context manager so the underlying file handle is closed
        # promptly. convert() loads the pixels and returns a new image, so
        # pil_image stays valid after the file is closed.
        with Image.open(full_img_path) as src_image:
            pil_image = src_image.convert("RGB")

        # Build the format required by RLHFDataset.
        prompt_text = "<image> " + qa_input  # <image> placeholder + question
        ground_truth_text = qa_output        # target answer

        # Encode to PNG bytes so the datasets.Image feature can store them.
        img_byte_arr = BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_bytes = img_byte_arr.getvalue()

        return {
            "images": [{"bytes": img_bytes}],  # datasets.Image accepts {'bytes': ...}
            "prompt": prompt_text,
            "ground_truth": ground_truth_text
        }
    except Exception as e:
        # Best-effort: one bad record (corrupt image, unexpected field types)
        # is reported and skipped rather than aborting the whole run.
        print(f"Error processing record {entry.get('imgname', 'N/A')}: {e}")
        return None

# --- Process training data ---
# Seed for the random subset so the generated benchmark is reproducible.
# Set to None to draw a fresh random subset on every run.
SAMPLING_SEED = 42

print(f"Loading training data from {INPUT_TRAIN_JSON}...")
processed_train_data = []
try:
    with open(INPUT_TRAIN_JSON, 'r', encoding='utf-8') as f:
        chartx_train_raw = json.load(f)
except FileNotFoundError:
    print(f"Error: Training file {INPUT_TRAIN_JSON} not found. Please check the path.")
    # Exit non-zero so callers see the failure. The site-provided exit()
    # returned status 0 and is not guaranteed to exist.
    raise SystemExit(1)
except json.JSONDecodeError:
    print(f"Error: Training file {INPUT_TRAIN_JSON} has invalid format.")
    raise SystemExit(1)


# Random sampling, using a private RNG so global `random` state is untouched.
rng = random.Random(SAMPLING_SEED)
num_available = len(chartx_train_raw)
if num_available < NUM_TRAIN_SAMPLES:
    print(f"Warning: Need {NUM_TRAIN_SAMPLES} samples, but {INPUT_TRAIN_JSON} only has {num_available}. Will use all available samples.")
    sampled_indices = range(num_available)
else:
    print(f"Randomly sampling {NUM_TRAIN_SAMPLES} from {num_available} data entries for training set...")
    sampled_indices = rng.sample(range(num_available), NUM_TRAIN_SAMPLES)

# Process sampled data; process_entry returns None for records it skips.
for i in sampled_indices:
    processed = process_entry(chartx_train_raw[i], IMAGE_BASE_PATH)
    if processed is not None:
        processed_train_data.append(processed)

# --- Save as Parquet ---
if processed_train_data:
    print(f"Saving {len(processed_train_data)} processed training samples to {OUTPUT_TRAIN_FILE}...")
    # Declare the schema explicitly so 'images' is stored as a list of
    # datasets.Image values rather than raw byte dicts.
    features = Features({
        'images': Sequence(feature=HFImageData(decode=True), length=-1, id=None),
        'prompt': Value(dtype='string', id=None),
        'ground_truth': Value(dtype='string', id=None)
    })
    try:
        # Create the Hugging Face Dataset object and write it as Parquet.
        train_dataset = Dataset.from_list(processed_train_data, features=features)
        train_dataset.to_parquet(OUTPUT_TRAIN_FILE)
    except Exception as e:
        print(f"Error saving training data as Parquet: {e}")
        # Don't report success (or exit 0) when the output was not written.
        raise SystemExit(1)
    print(f"Successfully saved training data to {OUTPUT_TRAIN_FILE}")
else:
    print("No training data to save.")


print("Data preprocessing completed.")