# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# User prompt template for every sample; ``{Question}`` is filled with the row's
# ``problem`` text via ``str.format``.
QUESTION_TEMPLATE = (
    "What is the coordinate of [{Question}] in the image? "
    "Output the thinking process in <think> </think> and final answer "
    "(coordinate (x,y)) in <answer> </answer> tags."
)



import torch
import os
import io
from PIL import Image, ImageDraw
from dataclasses import dataclass, field
from typing import Optional
import logging

from datasets import load_dataset, load_from_disk, DatasetDict
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration, Qwen2VLProcessor, Qwen2_5_VLProcessor

from qwen_vl_utils import process_vision_info

from trainer import CropTrainer

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

import math
import numpy as np

import re
def extract_coordinates(text):
    """Parse the integer ``(x,y)`` pair wrapped in ``<answer>...</answer>``.

    Whitespace is allowed around the parenthesised pair and after the comma.

    Returns:
        tuple: ``(x, y)`` as ints from the first well-formed answer, or
        ``(None, None)`` when no such answer is present.
    """
    found = re.search(r'<answer>\s*\((\d+),\s*(\d+)\)\s*</answer>', text)
    if found is None:
        return None, None
    x_text, y_text = found.groups()
    return int(x_text), int(y_text)

def random_crop_image(image, bbox, crop_ratio=0.7):
    """
    Randomly crop an image while ensuring the bounding box is included in the crop.

    The crop keeps the image's aspect ratio and measures
    ``int(width * crop_ratio)`` x ``int(height * crop_ratio)`` pixels. If no crop of
    that size can contain the box, the ratio grows in steps of 0.1, capped at 1.0,
    until one can. A ratio of 1.0 (the full image) always contains the clamped box
    for any non-empty image.

    Args:
        image (PIL.Image): The input image to be cropped (only ``.size`` and ``.crop`` are used).
        bbox (list): The bounding box in relative format [x_min, y_min, x_max, y_max];
            values are clamped into [0, 1] first.
        crop_ratio (float): The initial ratio of the original image size to crop
            (default 0.7). Values above 1.0 are treated as 1.0.

    Returns:
        PIL.Image: The randomly cropped image.
        list: The bounding box relative to the cropped image [x_min, y_min, x_max, y_max].
    """
    logger = logging.getLogger(__name__)
    width, height = image.size
    # Clamp so the box lies inside the image; this guarantees a full-size crop fits it.
    bbox = [max(0, min(1, value)) for value in bbox[:4]]
    bbox_abs = [bbox[0] * width, bbox[1] * height, bbox[2] * width, bbox[3] * height]

    crop_ratio = min(crop_ratio, 1.0)
    while True:
        crop_width = int(width * crop_ratio)
        crop_height = int(height * crop_ratio)
        # Feasible top-left corners: the crop must start at or before the box's min edge,
        # end at or after its max edge, and stay inside the image.
        left_min = max(0, math.ceil(bbox_abs[2] - crop_width))
        left_max = min(math.floor(bbox_abs[0]), width - crop_width)
        top_min = max(0, math.ceil(bbox_abs[3] - crop_height))
        top_max = min(math.floor(bbox_abs[1]), height - crop_height)

        feasible = (
            crop_width > 0
            and crop_height > 0
            and left_min <= left_max
            and top_min <= top_max
        )
        if feasible:
            break
        if crop_ratio >= 1.0:
            # Only reachable for zero-width/height images. Return the original image with
            # the box still in RELATIVE coordinates, matching the normal return format.
            return image, bbox
        # Cap at 1.0 so the always-feasible full-image crop is actually tried (repeated
        # +0.1 steps can land just below 1.0 or overshoot it).
        crop_ratio = min(1.0, crop_ratio + 0.1)
        logger.debug(
            "Crop infeasible for bbox=%s (abs=%s); crop ratio increased to %.2f",
            bbox, bbox_abs, crop_ratio,
        )

    # Randomly select the crop position among the feasible corners.
    left = np.random.randint(left_min, left_max + 1)
    top = np.random.randint(top_min, top_max + 1)
    cropped_image = image.crop((left, top, left + crop_width, top + crop_height))

    # Re-express the box relative to the cropped image.
    updated_bbox = [
        (bbox_abs[0] - left) / crop_width,
        (bbox_abs[1] - top) / crop_height,
        (bbox_abs[2] - left) / crop_width,
        (bbox_abs[3] - top) / crop_height,
    ]
    return cropped_image, updated_bbox
from torch.utils.data import Dataset
from PIL import Image
import random

# Custom Dataset Class
class CustomDataset(Dataset):
    """Grounding dataset that optionally pairs each sample with a random crop around its target.

    Training mode (``eval_data=False``): the dataset is twice as long as ``original_dataset``.
    Even index ``2*i`` returns sample ``i`` unchanged. Odd index ``2*i + 1`` returns a random
    crop around sample ``i``'s predicted point, with the ``<answer>`` coordinate rewritten in
    crop pixels. Samples without a usable prediction are served unchanged at both indices.
    Eval mode (``eval_data=True``): indices map 1:1 onto ``original_dataset`` with no cropping.

    Every item is a dict ``{'image': <PIL image or transform output>, 'messages': <chat list>}``.
    """

    def __init__(self, original_dataset, transform=None, crop_ratio=None, eval_data=False, is_hinted=False,
                 box_half_size=25):
        """
        Args:
            original_dataset: Indexable rows with ``'image'`` (encoded image bytes) and ``'messages'``.
                Training rows also need ``'prediction'`` (an ``(x, y)`` pixel point, or ``None``),
                ``'response'`` and ``'problem'``.
            transform: Optional callable applied to the image before it is returned.
            crop_ratio: Initial crop ratio for ``random_crop_image``. ``None`` (the default) draws
                a ratio uniformly from [0.3, 0.9] for every cropped item.
            eval_data: If True, serve only the original samples (no cropped copies).
            is_hinted: If True, draw a red rectangle around the target box on cropped images.
            box_half_size: Half side length, in pixels, of the square box placed around the
                predicted point (default 25, i.e. a 50x50 box).
        """
        self.original_dataset = original_dataset
        self.transform = transform
        self.crop_ratio = crop_ratio
        self.eval_data = eval_data
        self.is_hinted = is_hinted
        self.box_half_size = box_half_size

    def __len__(self):
        n_samples = len(self.original_dataset)
        # Training mode interleaves an original and a cropped copy of every sample.
        return n_samples if self.eval_data else 2 * n_samples

    def __getitem__(self, idx):
        if self.eval_data:
            sample = self.original_dataset[idx]
            return self._finalize(self._load_image(sample), sample['messages'])

        original_idx, is_cropped = divmod(idx, 2)
        sample = self.original_dataset[original_idx]
        image = self._load_image(sample)
        prediction = sample['prediction']
        has_point = prediction is not None and prediction[0] is not None and prediction[1] is not None
        if not (is_cropped and has_point):
            # Even index, or no usable point to crop around: serve the sample unchanged.
            return self._finalize(image, sample['messages'])

        width, height = image.size
        center_x, center_y = prediction[0], prediction[1]
        half = self.box_half_size
        # Square box around the predicted point, normalized to image-relative [0, 1] coords.
        bbox = [
            (center_x - half) / width,
            (center_y - half) / height,
            (center_x + half) / width,
            (center_y + half) / height,
        ]

        crop_ratio = random.uniform(0.3, 0.9) if self.crop_ratio is None else self.crop_ratio
        cropped_image, cropped_bbox = random_crop_image(image, bbox, crop_ratio)
        crop_width, crop_height = cropped_image.size
        # New answer point: the box centre inside the crop, in crop pixels.
        new_center = [
            (cropped_bbox[0] + cropped_bbox[2]) / 2 * crop_width,
            (cropped_bbox[1] + cropped_bbox[3]) / 2 * crop_height,
        ]
        if self.is_hinted:
            # cropped_bbox is relative to the crop; ImageDraw expects pixel coordinates.
            pixel_bbox = [
                cropped_bbox[0] * crop_width,
                cropped_bbox[1] * crop_height,
                cropped_bbox[2] * crop_width,
                cropped_bbox[3] * crop_height,
            ]
            ImageDraw.Draw(cropped_image).rectangle(pixel_bbox, outline="red", width=3)

        # Keep everything before the first <answer> tag; replace the answer with the crop-space point.
        response = sample['response'].split('<answer>')[0] + f'<answer>({int(new_center[0])},{int(new_center[1])})</answer>'
        messages = self._build_messages(sample['problem'], response)
        return self._finalize(cropped_image, messages)

    def _finalize(self, image, messages):
        """Apply the optional transform and package one item."""
        if self.transform:
            image = self.transform(image)
        return {'image': image, 'messages': messages}

    @staticmethod
    def _load_image(sample):
        """Decode a row's encoded ``'image'`` bytes into a PIL image."""
        return Image.open(io.BytesIO(sample['image']))

    @staticmethod
    def _build_messages(problem, response):
        """Build the single-turn user/assistant chat for a cropped sample."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": QUESTION_TEMPLATE.format(Question=problem)},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": f"{response}"},
                ],
            },
        ]

@dataclass
class SFTScriptArguments(ScriptArguments):
    """TRL ``ScriptArguments`` extended with the crop-augmentation options used by this script."""

    # Initial crop ratio forwarded to CustomDataset / random_crop_image. None means each
    # cropped training item samples a ratio uniformly from [0.3, 0.9].
    crop_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "The ratio of the original image size to crop."},
    )
    # NOTE(review): not referenced anywhere in this script as shown; confirm whether it is
    # consumed elsewhere (e.g. by CropTrainer) or is dead configuration.
    freeze_llm: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to freeze the LLM model."},
    )
    # Forwarded to the training CustomDataset; when True, cropped images get a red box
    # drawn around the target region.
    is_hinted: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use hinting."},
    )
    # Forwarded to CropTrainer(sequential_sampler=...); presumably selects an in-order
    # sampler instead of shuffling. Confirm against CropTrainer.
    sequential_sampler: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use sequential sampling."},
    )

if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    # Overrides applied regardless of CLI/config:
    # - use non-reentrant gradient checkpointing;
    # - keep the raw 'image'/'messages' fields for collate_fn;
    # - skip SFTTrainer's dataset preparation, because collate_fn tokenizes.
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    # "auto"/None pass through unchanged; any other value names a torch dtype attribute.
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # A device map is only supplied for quantized loading.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )

    model = AutoModelForVision2Seq.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    ################
    # Create a data collator to encode text and image pairs
    ################
    def collate_fn(examples):
        """Tokenize a batch of {'image', 'messages'} examples and attach masked labels.

        Labels are a copy of ``input_ids`` with padding and image-placeholder tokens set to
        -100 so the loss ignores them. Prompt (user-turn) tokens are not masked.
        """
        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
        images = [example["image"] for example in examples]
        if isinstance(model, LlavaForConditionalGeneration):
            # LLaVA-1.5 takes a single image per prompt. CustomDataset already yields one PIL
            # image per example (PIL images are not indexable), so only unwrap examples that
            # carry a list/tuple of images.
            images = [image[0] if isinstance(image, (list, tuple)) else image for image in images]

        # Tokenize the texts and process the images.
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        if isinstance(processor, (Qwen2VLProcessor, Qwen2_5_VLProcessor)):
            # NOTE(review): presumably the ids of <|vision_start|>, <|vision_end|> and
            # <|image_pad|> in the Qwen2-VL vocabulary; confirm against the tokenizer.
            image_tokens = [151652, 151653, 151655]
        else:
            image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
        for image_token_id in image_tokens:
            labels[labels == image_token_id] = -100
        batch["labels"] = labels
        return batch

    def make_conversation_image(example):
        """Build the user/assistant chat messages for one raw row ('problem', 'response')."""
        response = example['response']
        return {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": f"{response}"},
                    ],
                },
            ],
        }

    # An existing local path is loaded with load_from_disk; anything else goes to load_dataset.
    if os.path.exists(script_args.dataset_name):
        dataset = load_from_disk(script_args.dataset_name)
    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    dataset = dataset.map(make_conversation_image)

    # Hold out a fixed-size, seeded random test split from the 'train' split.
    test_size = 1000
    total_examples = len(dataset['train'])
    if total_examples <= test_size:
        raise ValueError(
            f"The 'train' split has {total_examples} examples; need more than {test_size} "
            "to reserve a test split."
        )
    train_size = total_examples - test_size
    split_dataset = dataset['train'].train_test_split(test_size=test_size, train_size=train_size, seed=42)

    # Store the splits under the configured split names and read them back by the same names.
    dataset = DatasetDict({
        script_args.dataset_train_split: split_dataset["train"],
        script_args.dataset_test_split: split_dataset["test"],
    })
    train_dataset = CustomDataset(
        dataset[script_args.dataset_train_split],
        crop_ratio=script_args.crop_ratio,
        is_hinted=script_args.is_hinted,
    )
    eval_dataset = CustomDataset(
        dataset[script_args.dataset_test_split],
        crop_ratio=script_args.crop_ratio,
        eval_data=True,
    )

    ################
    # Training
    ################
    # collate_fn reads the processor at call time, so this applies to every batch.
    processor.tokenizer.padding_side = "left"

    trainer = CropTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset if training_args.eval_strategy != "no" else None,
        processing_class=processor.tokenizer,
        sequential_sampler=script_args.sequential_sampler,
    )

    trainer.train()

    # Save the model
    trainer.save_model(training_args.output_dir)
