import os
import io
import json
import math
import random
import struct
from PIL import Image
from tqdm import tqdm
import numpy as np
import re
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import os
import json


# Lookup table for the words needed to spell out 1-50.
_NUMBER_WORDS = {
    1: "one", 2: "two", 3: "three", 4: "four", 5: "five",
    6: "six", 7: "seven", 8: "eight", 9: "nine", 10: "ten",
    11: "eleven", 12: "twelve", 13: "thirteen", 14: "fourteen",
    15: "fifteen", 16: "sixteen", 17: "seventeen", 18: "eighteen",
    19: "nineteen", 20: "twenty", 30: "thirty", 40: "forty",
    50: "fifty",
}


def number_to_words(n):
    """Spell out an integer in the range 1-50 as lowercase English words.

    Compound numbers are hyphenated, e.g. ``21 -> "twenty-one"``.

    Args:
        n: The integer to convert.

    Returns:
        The English spelling of ``n``, or the sentinel string
        ``"Number out of range"`` when ``n`` is outside 1-50 (including
        zero and negative values, which previously raised ``KeyError``).
    """
    if not 1 <= n <= 50:
        return "Number out of range"  # adjust the supported range here if needed
    if n <= 20:
        return _NUMBER_WORDS[n]
    tens, ones = divmod(n, 10)
    tens_word = _NUMBER_WORDS[tens * 10]
    return tens_word if ones == 0 else f"{tens_word}-{_NUMBER_WORDS[ones]}"

def process_bbox_info(bbox_info, original_width, original_height, target_width=1024, target_height=1024, grid_size=16):
    """Rescale bounding boxes to a target canvas and snap them outward to a grid.

    Each box is scaled from ``original_width x original_height`` to
    ``target_width x target_height``. Its min corner is rounded down and its
    max corner rounded up to a multiple of ``grid_size``, so the snapped box
    encloses the scaled box. Every coordinate is then clamped to
    ``[0, target_width]`` / ``[0, target_height]``.

    Args:
        bbox_info: Iterable of dicts whose ``'bbox'`` entry is
            ``[x_min, y_min, x_max, y_max]`` in original-image pixels.
        original_width: Width of the source image.
        original_height: Height of the source image.
        target_width: Width of the target canvas.
        target_height: Height of the target canvas.
        grid_size: Snapping granularity in target pixels (default 16).

    Returns:
        A list of ``[x_min, y_min, x_max, y_max]`` integer lists, one per
        input item, in input order.
    """
    scale_x = target_width / original_width
    scale_y = target_height / original_height

    def snap_down(value):
        # Floor to the pixel, then down to the previous grid line.
        return (math.floor(value) // grid_size) * grid_size

    def snap_up(value):
        # Ceil to the pixel (not truncate: 32.5 must not collapse to 32),
        # then up to the next grid line.
        return -(-math.ceil(value) // grid_size) * grid_size

    def clamp(value, upper):
        return max(0, min(value, upper))

    processed_bbox_info = []
    for item in bbox_info:
        bbox = item['bbox']
        x_min = clamp(snap_down(bbox[0] * scale_x), target_width)
        y_min = clamp(snap_down(bbox[1] * scale_y), target_height)
        x_max = clamp(snap_up(bbox[2] * scale_x), target_width)
        y_max = clamp(snap_up(bbox[3] * scale_y), target_height)
        processed_bbox_info.append([x_min, y_min, x_max, y_max])

    return processed_bbox_info



class Text2ImageRGBDataset(Dataset):
    """Text-to-image dataset backed by a JSON manifest of RGB images.

    Each manifest entry must provide ``img_path``, ``prompt`` and
    ``gt_num_obj``. ``bbox_info`` and ``number_positions['positions']`` are
    optional and are returned as ``None`` when absent.
    """

    def __init__(self, resolution, center_crop, random_flip, data_json_path, image_folder, train_with_one_sample, replace_number=False):
        """
        Args:
            resolution (int): Target side length; images are resized to
                ``resolution x resolution``.
            center_crop (bool): Use a center crop instead of a random crop.
            random_flip (bool): Horizontally flip each image with probability 0.5.
            data_json_path (str): Path to the JSON file containing image metadata.
            image_folder (str): Folder containing the image files. Stored on the
                instance but not read by ``__getitem__``, which uses each
                entry's ``img_path`` as given.
            train_with_one_sample (bool): If True, always return the first sample.
            replace_number (bool): If True, replace the spelled-out object count
                in each caption with a random count in 1-50 and report the
                matching zero-based count as ``label_new``.
        """
        self.image_folder = image_folder

        # Load the JSON manifest (JSON text is UTF-8 by specification).
        with open(data_json_path, 'r', encoding='utf-8') as f:
            self.data_json_list = json.load(f)

        # The resize already yields a resolution x resolution image, so the
        # crop below keeps the full frame.
        self.train_resize = transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR)
        self.train_crop = transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution)
        # p=1.0: always flips when applied; the 50% coin toss happens in __getitem__.
        self.train_flip = transforms.RandomHorizontalFlip(p=1.0)
        # ToTensor gives CHW floats in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
        self.train_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.random_flip = random_flip
        self.center_crop = center_crop
        self.resolution = resolution
        self.train_with_one_sample = train_with_one_sample
        self.replace_number = replace_number

    def __len__(self):
        """Return the number of entries in the JSON manifest."""
        return len(self.data_json_list)

    @staticmethod
    def _swap_number_word(caption, old_word, new_word):
        """Replace whole-word occurrences of ``old_word`` in ``caption``.

        Word boundaries prevent corrupting words that merely contain the
        number (e.g. "ten" in "kitten", "one" in "someone").

        Returns:
            ``(new_caption, number_of_replacements)``.
        """
        pattern = rf"\b{re.escape(old_word)}\b"
        # Use a callable so the replacement text is inserted literally.
        return re.subn(pattern, lambda _match: new_word, caption)

    def __getitem__(self, idx):
        """Load, augment and return one training sample.

        Returns:
            dict with keys:
                image_pt: tensor of shape (3, resolution, resolution) in [-1, 1].
                caption: the prompt, possibly with its count word replaced.
                number_positions: ``number_positions['positions']`` or None.
                label: zero-based object count (``gt_num_obj - 1``).
                label_new: zero-based count matching the (possibly rewritten)
                    caption when ``replace_number`` is set, else None.
                bbox_info: the entry's ``bbox_info`` or None.
        """
        # Always use the first sample if training with one sample.
        if self.train_with_one_sample:
            idx = 0

        data_item = self.data_json_list[idx]
        caption = data_item['prompt']
        bbox_info = data_item.get('bbox_info')
        label = int(data_item['gt_num_obj']) - 1

        try:
            number_positions = data_item['number_positions']['positions']
        except (KeyError, TypeError):
            # Key missing at either level, or `number_positions` is null.
            number_positions = None

        label_new = None
        if self.replace_number:
            original_number_str = number_to_words(label + 1)
            new_number = random.randint(1, 50)
            caption, n_replaced = self._swap_number_word(
                caption, original_number_str, number_to_words(new_number)
            )
            # If the count word was not found, the caption still describes the
            # original count, so keep label_new consistent with it.
            label_new = new_number - 1 if n_replaced else label

        # Use the manifest path directly; the context manager closes the file.
        with Image.open(data_item['img_path']) as img:
            image_pil = img.convert('RGB')

        image_pil = self.train_resize(image_pil)

        if self.random_flip and random.random() < 0.5:
            image_pil = self.train_flip(image_pil)

        if self.center_crop:
            image_pil = self.train_crop(image_pil)
        else:
            top, left, height, width = self.train_crop.get_params(
                image_pil, (self.resolution, self.resolution)
            )
            image_pil = transforms.functional.crop(image_pil, top, left, height, width)

        image_pt = self.train_transforms(image_pil)

        return {
            "image_pt": image_pt,  # [3, resolution, resolution]
            "caption": caption,
            "number_positions": number_positions,
            "label": label,
            "label_new": label_new,
            "bbox_info": bbox_info,
        }