import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import os
import random
import numpy as np

from PIL import Image, ImageDraw

from datasets import load_dataset

from .trainer import OminiModel, get_config, train
from ..pipeline.flux_omini import Condition, generate


class SubjectPositionDataset(Dataset):
    """
    结合 Subject 和 Position 控制的数据集
    - 使用 Subjects200K 数据集的双图结构
    - 一张图作为 subject 参考（条件1）
    - 另一张图作为目标，并生成 fill mask 指示位置（条件2）
    """
    def __init__(
        self,
        base_dataset,
        condition_size=(512, 512),
        target_size=(512, 512),
        image_size: int = 512,
        padding: int = 0,
        drop_text_prob: float = 0.1,
        drop_subject_prob: float = 0.1,
        drop_position_prob: float = 0.1,
        min_mask_ratio: float = 0.15,
        max_mask_ratio: float = 0.6,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.drop_text_prob = drop_text_prob
        self.drop_subject_prob = drop_subject_prob
        self.drop_position_prob = drop_position_prob
        self.min_mask_ratio = min_mask_ratio
        self.max_mask_ratio = max_mask_ratio
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset) * 2

    def _generate_fill_mask(self, image, min_ratio=0.15, max_ratio=0.6):
        """
        在图像上生成随机矩形遮挡，用于指示物体应该出现的位置
        
        Args:
            image: PIL Image
            min_ratio: 遮挡区域最小占比
            max_ratio: 遮挡区域最大占比
        
        Returns:
            masked_image: 带遮挡的图像
            mask: 遮挡mask (0=保留, 255=遮挡)
            bbox: 遮挡区域的坐标 [x1, y1, x2, y2]
        """
        w, h = image.size
        
        # 随机生成遮挡区域的大小比例
        area_ratio = random.uniform(min_ratio, max_ratio)
        aspect_ratio = random.uniform(0.5, 2.0)  # 宽高比
        
        # 计算遮挡区域的宽高
        mask_area = w * h * area_ratio
        mask_h = int(np.sqrt(mask_area / aspect_ratio))
        mask_w = int(mask_h * aspect_ratio)
        
        # 确保不超出边界
        mask_w = min(mask_w, w)
        mask_h = min(mask_h, h)
        
        # 随机选择遮挡区域的位置
        x1 = random.randint(0, w - mask_w)
        y1 = random.randint(0, h - mask_h)
        x2 = x1 + mask_w
        y2 = y1 + mask_h
        
        # 创建遮挡mask
        mask = Image.new("L", image.size, 0)
        draw = ImageDraw.Draw(mask)
        draw.rectangle([x1, y1, x2, y2], fill=255)
        
        # 生成带遮挡的图像（遮挡区域填充为黑色）
        masked_image = Image.composite(
            image, Image.new("RGB", image.size, (0, 0, 0)), mask
        )
        
        return masked_image, mask, [x1, y1, x2, y2]

    def __getitem__(self, idx):
        # 决定哪张图作为目标（0=左图, 1=右图）
        target = idx % 2
        item = self.base_dataset[idx // 2]

        # 从 Subjects200K 数据集裁剪左右两张图
        image = item["image"]
        left_img = image.crop(
            (
                self.padding,
                self.padding,
                self.image_size + self.padding,
                self.image_size + self.padding,
            )
        )
        right_img = image.crop(
            (
                self.image_size + self.padding * 2,
                self.padding,
                self.image_size * 2 + self.padding * 2,
                self.image_size + self.padding,
            )
        )

        # 确定目标图像和参考图像
        target_image, reference_img = (
            (left_img, right_img) if target == 0 else (right_img, left_img)
        )

        # 调整尺寸
        reference_img = reference_img.resize(self.condition_size).convert("RGB")
        target_image = target_image.resize(self.target_size).convert("RGB")

        # 生成 fill mask（在目标图像上）
        position_mask_img, mask, bbox = self._generate_fill_mask(
            target_image, self.min_mask_ratio, self.max_mask_ratio
        )

        # 获取描述文本
        description = item["description"][
            "description_0" if target == 0 else "description_1"
        ]

        # 数据增强：随机 dropout
        drop_text = random.random() < self.drop_text_prob
        drop_subject = random.random() < self.drop_subject_prob
        drop_position = random.random() < self.drop_position_prob

        if drop_text:
            description = ""
        
        if drop_subject:
            # Subject 条件 dropout：用黑色图像替换
            reference_img = Image.new("RGB", self.condition_size, (0, 0, 0))
        
        if drop_position:
            # Position 条件 dropout：不遮挡任何区域（全保留）
            position_mask_img = target_image

        # 条件1: Subject 参考图（位于上方）
        # 16 是 VAE 的下采样因子
        subject_position_delta = np.array([0, -self.condition_size[0] // 16])
        
        # 条件2: Fill mask（位于原位置）
        position_position_delta = np.array([0, 0])

        return {
            "image": self.to_tensor(target_image),
            # 条件0: Subject reference image
            "condition_0": self.to_tensor(reference_img),
            "condition_type_0": "subject",
            "position_delta_0": subject_position_delta,
            # 条件1: Position fill mask
            "condition_1": self.to_tensor(position_mask_img),
            "condition_type_1": "fill",
            "position_delta_1": position_position_delta,
            "description": description,
            **({"pil_image": [target_image, reference_img, position_mask_img]} 
               if self.return_pil_image else {}),
        }


@torch.no_grad()
def test_function(model, save_path, file_name):
    """测试函数：生成示例图像"""
    condition_size = model.training_config["dataset"]["condition_size"]
    target_size = model.training_config["dataset"]["target_size"]

    # 获取两个 adapter
    subject_adapter = model.adapter_names[2]  # subject adapter
    fill_adapter = model.adapter_names[3]     # fill adapter
    
    test_list = []

    # 测试用例1: 使用真实参考图像
    if os.path.exists("assets/test_in.jpg"):
        reference_img = Image.open("assets/test_in.jpg").resize(condition_size)
        
        # 创建一个 fill mask（指示生成位置）
        target_blank = Image.new("RGB", target_size, (255, 255, 255))
        mask = Image.new("L", target_size, 0)
        draw = ImageDraw.Draw(mask)
        # 在中心区域创建遮挡
        w, h = target_size
        x1, y1 = w // 4, h // 4
        x2, y2 = w * 3 // 4, h * 3 // 4
        draw.rectangle([x1, y1, x2, y2], fill=255)
        fill_mask_img = Image.composite(
            target_blank, Image.new("RGB", target_size, (0, 0, 0)), mask
        )
        
        # Subject 条件
        subject_condition = Condition(
            reference_img, 
            subject_adapter, 
            [0, -condition_size[0] // 16]
        )
        # Position 条件
        position_condition = Condition(
            fill_mask_img,
            fill_adapter,
            [0, 0]
        )
        
        prompt = "A beautiful object placed in the center of a bright room."
        test_list.append(([subject_condition, position_condition], prompt))

    # 测试用例2: OOD 测试
    if os.path.exists("assets/test_out.jpg"):
        reference_img = Image.open("assets/test_out.jpg").resize(condition_size)
        
        # 在不同位置创建遮挡
        target_blank = Image.new("RGB", target_size, (255, 255, 255))
        mask = Image.new("L", target_size, 0)
        draw = ImageDraw.Draw(mask)
        w, h = target_size
        # 左上角区域
        draw.rectangle([0, 0, w // 2, h // 2], fill=255)
        fill_mask_img = Image.composite(
            target_blank, Image.new("RGB", target_size, (0, 0, 0)), mask
        )
        
        subject_condition = Condition(
            reference_img, 
            subject_adapter, 
            [0, -condition_size[0] // 16]
        )
        position_condition = Condition(
            fill_mask_img,
            fill_adapter,
            [0, 0]
        )
        
        prompt = "It is placed in the upper left corner of a table."
        test_list.append(([subject_condition, position_condition], prompt))

    # 生成测试图像
    os.makedirs(save_path, exist_ok=True)
    for i, (conditions, prompt) in enumerate(test_list):
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=conditions,
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(save_path, f"{file_name}_subject_position_{i}.jpg")
        res.images[0].save(file_path)


def main():
    """Training entry point for subject + position conditioning.

    Reads the config via ``get_config``, pins the CUDA device to
    ``LOCAL_RANK`` (default 0), and loads Subjects200K. It keeps only samples
    whose quality scores are all >= 5, builds the dataset and model, and
    hands them to ``train``.
    """
    config = get_config()
    training_config = config["train"]
    dataset_config = training_config["dataset"]
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    raw_dataset = load_dataset("Yuanshi/Subjects200K")

    def filter_func(item):
        """Keep items whose three quality scores are all >= 5."""
        if not item.get("quality_assessment"):
            return False
        return all(
            item["quality_assessment"].get(key, 0) >= 5
            for key in ["compositeStructure", "objectConsistency", "imageQuality"]
        )

    # exist_ok avoids the check-then-create race when several ranks start at
    # once (a plain makedirs after an exists() check can raise FileExistsError).
    os.makedirs("./cache/dataset", exist_ok=True)
    data_valid = raw_dataset["train"].filter(
        filter_func,
        num_proc=16,
        cache_file_name="./cache/dataset/data_valid_subject_position.arrow",
    )

    dataset = SubjectPositionDataset(
        data_valid,
        condition_size=dataset_config["condition_size"],
        target_size=dataset_config["target_size"],
        image_size=dataset_config["image_size"],
        padding=dataset_config["padding"],
        drop_text_prob=dataset_config["drop_text_prob"],
        drop_subject_prob=dataset_config.get("drop_subject_prob", 0.1),
        drop_position_prob=dataset_config.get("drop_position_prob", 0.1),
        min_mask_ratio=dataset_config.get("min_mask_ratio", 0.15),
        max_mask_ratio=dataset_config.get("max_mask_ratio", 0.6),
    )

    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, trainable_model, config, test_function)


if __name__ == "__main__":
    main()
