import os
import re
import json
import torch
import torch.multiprocessing as mp
from vllm import LLM, SamplingParams
from transformers import AutoProcessor
from datasets import load_from_disk

from PIL import Image
from io import BytesIO
from tqdm import tqdm
import argparse

# ------------------------
# Configuration
# ------------------------
# Paths. These are placeholders; the `__main__` section at the bottom of the
# file reassigns model_dir, dataset_dir and output_dir from the command line.
# NOTE: despite its name, output_dir is opened as a file and results are
# appended to it line by line.
model_dir = 'DIR_TO_MODEL_DIR'
dataset_dir = 'DIR_TO_UGROUND20K'
output_dir = 'DIR_TO_OUTPUT_DIR'
output_prefix = 'output_grpo'

# Parallelism and batching.
world_size = 1  # Number of parallel vLLM processes
batch_size = 8  # Batch size per process

# Sampling parameters for inference.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=1024,
)

# ------------------------
# Utility functions
# ------------------------
# Prompt template: filled with the target description ({Question}) and the
# image width/height ({size_x}, {size_y}) before being sent to the model.
QUESTION_TEMPLATE = (
    "What is the coordinate of [{Question}] in the image?\n"
    "The size of image is ({size_x},{size_y}).\n"
    "Output the thinking process in <think> </think> and final answer "
    "(coordinate (x,y)) in <answer> </answer> tags."
)


# Matches "<answer>(x, y)</answer>" where x and y are unsigned decimal numbers.
_ANSWER_COORD_RE = re.compile(
    r'<answer>\s*\((?P<x>\d+\.?\d*),\s*(?P<y>\d+\.?\d*)\)\s*</answer>'
)


def extract_coordinates(text):
    """Return the first ``(x, y)`` pair found inside ``<answer>`` tags.

    Args:
        text: Model response to search.

    Returns:
        A tuple ``(x, y)`` of floats for the first match, or
        ``(None, None)`` when the response contains no well-formed answer.
    """
    found = _ANSWER_COORD_RE.search(text)
    if found is None:
        return None, None
    x_str, y_str = found.group('x', 'y')
    return float(x_str), float(y_str)

# ------------------------
# Sample to LLM input conversion
# ------------------------
def make_input_from_sample(sample, processor):
    """Convert one dataset sample into a prompt dict for ``LLM.generate``.

    Args:
        sample: Mapping with an ``'image'`` entry holding encoded image bytes
            and a ``'conversations'`` entry holding a JSON string. The
            ``'value'`` of the first conversation turn must contain the
            target description between the markers ``'Description:'`` and
            ``'Answer'``.
        processor: Object exposing ``apply_chat_template``; used to render
            the chat-formatted text prompt.

    Returns:
        A dict with the rendered text under ``'prompt'`` and the opened PIL
        image under ``'multi_modal_data'``.

    Raises:
        IndexError: If the first turn lacks the ``'Description:'`` marker or
            the conversation list is empty.
    """
    image = Image.open(BytesIO(sample['image']))

    # The description sits between "Description:" and "Answer" in turn 0.
    first_turn = json.loads(sample['conversations'])[0]['value']
    question = first_turn.split('Description:')[1].split('Answer')[0].strip()

    width, height = image.size
    conversation = [
        {
            "role": "user",
            "content": [
                # Image placeholder; the pixels travel via multi_modal_data.
                {"type": "image"},
                {
                    "type": "text",
                    "text": QUESTION_TEMPLATE.format(
                        Question=question, size_x=width, size_y=height
                    ),
                },
            ],
        }
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    return {'prompt': text_prompt, 'multi_modal_data': {'image': image}}

# ------------------------
# Inference worker with on-the-fly batch preprocessing
# ------------------------
def inference_worker(rank, world_size):
    """Run batched inference over this rank's shard and append JSONL results.

    The worker takes every ``world_size``-th sample starting at ``rank``,
    builds the prompts batch by batch, generates responses and appends one
    JSON record per sample to the file named by the module-level
    ``output_dir``.

    Each record contains:
        * ``global_idx``: index of the sample in the dataset,
        * ``response``: generated text with newlines replaced by spaces,
        * ``coordinates``: result of ``extract_coordinates`` on the response,
        * ``answer``: ``'value'`` of the second conversation turn.

    Args:
        rank: Index of this worker; selects the shard of samples to process.
        world_size: Total number of workers the dataset is split across.
    """
    # Expose GPUs 0-3 to this process; the model below is sharded across all
    # four of them (tensor_parallel_size=4), not pinned to a single GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

    # Create the output file's parent directory up front so a missing
    # directory fails fast instead of after the first generated batch.
    out_parent = os.path.dirname(output_dir)
    if out_parent:
        os.makedirs(out_parent, exist_ok=True)

    # Load model and dataset
    model = LLM(
        model=model_dir,
        tensor_parallel_size=4,
        gpu_memory_utilization=0.9,
        dtype=torch.bfloat16,
    )
    dataset = load_from_disk(dataset_dir)
    processor = AutoProcessor.from_pretrained(model_dir)

    # Distribute sample indices among ranks (round-robin)
    indices = list(range(rank, len(dataset), world_size))

    # Process in batches
    for start in tqdm(range(0, len(indices), batch_size)):
        batch_indices = indices[start:start + batch_size]

        # Fetch each sample once and reuse it for both the prompt and the
        # ground-truth answer, rather than reading it from the dataset twice.
        batch_samples = [dataset[idx] for idx in batch_indices]
        llm_inputs = [
            make_input_from_sample(sample, processor)
            for sample in batch_samples
        ]

        # Generate outputs
        outputs = model.generate(llm_inputs, sampling_params=sampling_params)

        # Append after every batch so finished work survives an interruption.
        with open(output_dir, 'a') as fout:
            for idx, sample, out in zip(batch_indices, batch_samples, outputs):
                text = out.outputs[0].text.replace('\n', ' ')
                record = {
                    'global_idx': idx,
                    'response': text,
                    'coordinates': extract_coordinates(text),
                    'answer': json.loads(sample['conversations'])[1]['value'],
                }
                fout.write(json.dumps(record) + '\n')

# ------------------------
# Main execution
# ------------------------
if __name__ == '__main__':
    # Command-line arguments replace the placeholder module-level paths.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', required=True)
    parser.add_argument('--dataset_dir', default='./datasets/uground_21k')
    parser.add_argument('--output_dir', default='./datasets/output.jsonl')
    # Parse the arguments; without this call `args` is undefined and the
    # assignments below raise NameError.
    args = parser.parse_args()

    # These assignments run at module scope, so they rebind the globals that
    # inference_worker reads.
    dataset_dir = args.dataset_dir
    model_dir = args.model_dir
    output_dir = args.output_dir

    inference_worker(0, world_size)  # For debugging

