import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import os
import random
import numpy as np

from PIL import Image, ImageDraw

from datasets import load_dataset

from .trainer import OminiModel, get_config, train
from ..pipeline.flux_omini import Condition, generate


class SubjectPositionBackgroundDataset(Dataset):
    """
    Triple-condition dataset: Subject + Position + Background.

    Every item of ``base_dataset`` must provide ``item["image"]`` (a PIL image
    holding two square pictures side by side) and ``item["description"]`` with
    the keys ``"description_0"`` and ``"description_1"``. Each base item yields
    two samples: an even index uses the left picture as the target and the
    right one as the subject reference, an odd index does the opposite.

    - condition_0: subject reference image (object appearance)
    - condition_1: fill-mask image (where to generate)
    - condition_2: background image (scene)

    Args:
        base_dataset: indexable dataset of paired images, see above.
        background_dataset: optional, independent dataset of background
            images; items are dicts with a ``"jpg"`` or ``"image"`` entry.
        condition_size: size the subject reference is resized to.
        target_size: size the target (and the images derived from it) is
            resized to.
        image_size: side length, in pixels, of each square picture inside the
            stitched base image.
        padding: padding, in pixels, around each picture in the stitched image.
        drop_text_prob: probability of replacing the description with "".
        drop_subject_prob: probability of replacing the subject reference
            with a black image.
        drop_position_prob: probability of replacing the fill-mask image with
            the unmasked target image.
        drop_background_prob: probability of replacing the background with a
            uniform gray image.
        min_mask_ratio: lower bound of the mask area as a fraction of the
            image area.
        max_mask_ratio: upper bound of the mask area as a fraction of the
            image area.
        background_blur_prob: probability of blurring the extracted background.
        use_external_background: when True, roughly half of the samples try
            to take their background from ``background_dataset``.
        return_pil_image: when True, the PIL images are returned as well
            under the ``"pil_image"`` key.
    """

    def __init__(
        self,
        base_dataset,
        background_dataset=None,  # optional independent background dataset
        condition_size=(512, 512),
        target_size=(512, 512),
        image_size: int = 512,
        padding: int = 0,
        drop_text_prob: float = 0.1,
        drop_subject_prob: float = 0.1,
        drop_position_prob: float = 0.1,
        drop_background_prob: float = 0.1,
        min_mask_ratio: float = 0.15,
        max_mask_ratio: float = 0.6,
        background_blur_prob: float = 0.3,  # probability of blurring the background
        use_external_background: bool = False,  # whether to use the external background dataset
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.background_dataset = background_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.drop_text_prob = drop_text_prob
        self.drop_subject_prob = drop_subject_prob
        self.drop_position_prob = drop_position_prob
        self.drop_background_prob = drop_background_prob
        self.min_mask_ratio = min_mask_ratio
        self.max_mask_ratio = max_mask_ratio
        self.background_blur_prob = background_blur_prob
        self.use_external_background = use_external_background

        self.return_pil_image = return_pil_image
        self.to_tensor = T.ToTensor()

    def __len__(self):
        """Return two samples per base item (left-as-target and right-as-target)."""
        return len(self.base_dataset) * 2

    def _generate_fill_mask(self, image, min_ratio=0.15, max_ratio=0.6):
        """Draw a random rectangle and build the fill-condition image from it.

        Args:
            image: PIL image the rectangle is placed on.
            min_ratio: lower bound of the rectangle area as a fraction of the
                image area.
            max_ratio: upper bound of the rectangle area as a fraction of the
                image area.

        Returns:
            ``(masked_image, mask, bbox)`` where ``mask`` is an "L" image that
            is 255 inside the rectangle and 0 elsewhere, ``bbox`` is
            ``[x1, y1, x2, y2]`` with ``x2``/``y2`` exclusive, and
            ``masked_image`` shows ``image`` inside the rectangle and black
            everywhere else.
        """
        w, h = image.size

        # Sample the area and aspect ratio (width / height) of the rectangle.
        area_ratio = random.uniform(min_ratio, max_ratio)
        aspect_ratio = random.uniform(0.5, 2.0)

        # Derive the rectangle width and height from area and aspect ratio.
        mask_area = w * h * area_ratio
        mask_h = int(np.sqrt(mask_area / aspect_ratio))
        mask_w = int(mask_h * aspect_ratio)

        # Never exceed the image bounds.
        mask_w = min(mask_w, w)
        mask_h = min(mask_h, h)

        # Pick a random position such that the rectangle fits in the image.
        x1 = random.randint(0, w - mask_w)
        y1 = random.randint(0, h - mask_h)
        x2 = x1 + mask_w
        y2 = y1 + mask_h

        mask = Image.new("L", image.size, 0)
        if mask_w > 0 and mask_h > 0:
            # ImageDraw.rectangle includes both corner points, so the far
            # corner is the last pixel inside the box, not x2/y2 themselves;
            # otherwise the mask is one pixel larger than the bbox.
            ImageDraw.Draw(mask).rectangle([x1, y1, x2 - 1, y2 - 1], fill=255)

        # Image.composite takes the first image where the mask is 255 and the
        # second where it is 0: the target is kept INSIDE the rectangle and
        # the rest is black.
        # NOTE(review): the original comment described the opposite ("the
        # masked region is filled with black") — confirm which one is intended.
        masked_image = Image.composite(
            image, Image.new("RGB", image.size, (0, 0, 0)), mask
        )

        return masked_image, mask, [x1, y1, x2, y2]

    def _extract_background(self, target_image, mask):
        """
        Extract the background (the region outside the mask) of the target.

        Args:
            target_image: target PIL image.
            mask: foreground mask (255 = foreground, 0 = background).

        Returns:
            The target image with the foreground region filled with white or
            black (chosen at random), optionally Gaussian-blurred.
        """
        # Invert the mask so that 255 marks the background to keep.
        inverted_mask = Image.eval(mask, lambda a: 255 - a)

        # Fill the foreground with white or black, each about half of the time.
        fill_color = (255, 255, 255) if random.random() > 0.5 else (0, 0, 0)
        background = Image.composite(
            target_image,
            Image.new("RGB", target_image.size, fill_color),
            inverted_mask,
        )

        # Optionally apply a light blur to the background.
        if random.random() < self.background_blur_prob:
            from PIL import ImageFilter

            background = background.filter(ImageFilter.GaussianBlur(radius=2))

        return background

    def _get_external_background(self, idx):
        """Return a random background image from the external dataset.

        Args:
            idx: index of the requesting sample (currently unused; the
                background is drawn at random).

        Returns:
            An RGB image resized to ``target_size``, or ``None`` when no
            external dataset is configured, the dataset is empty, or the
            drawn item has neither a ``"jpg"`` nor an ``"image"`` entry.
        """
        backgrounds = self.background_dataset
        # An empty dataset would make random.randint(0, -1) raise; report
        # "nothing available" instead so the caller can fall back.
        if backgrounds is None or len(backgrounds) == 0:
            return None

        bg_item = backgrounds[random.randint(0, len(backgrounds) - 1)]

        # WebDataset-style items store the picture under "jpg"; other
        # datasets use "image".
        if "jpg" in bg_item:
            bg_image = bg_item["jpg"]
        elif "image" in bg_item:
            bg_image = bg_item["image"]
        else:
            return None

        return bg_image.resize(self.target_size).convert("RGB")

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            A dict with the target tensor under ``"image"``, three condition
            tensors ``"condition_{0,1,2}"`` with their ``"condition_type_*"``
            and ``"position_delta_*"`` entries, the ``"description"`` text and,
            if ``return_pil_image`` is set, the PIL images under
            ``"pil_image"``.
        """
        # Which half is the target (0 = left picture, 1 = right picture).
        target = idx % 2
        item = self.base_dataset[idx // 2]

        # Crop the left and right pictures out of the stitched image.
        image = item["image"]
        left_img = image.crop(
            (
                self.padding,
                self.padding,
                self.image_size + self.padding,
                self.image_size + self.padding,
            )
        )
        right_img = image.crop(
            (
                self.image_size + self.padding * 2,
                self.padding,
                self.image_size * 2 + self.padding * 2,
                self.image_size + self.padding,
            )
        )

        # The non-target half serves as the subject reference.
        target_image, reference_img = (
            (left_img, right_img) if target == 0 else (right_img, left_img)
        )

        reference_img = reference_img.resize(self.condition_size).convert("RGB")
        target_image = target_image.resize(self.target_size).convert("RGB")

        # Fill mask, generated on the target image.
        position_mask_img, mask, _ = self._generate_fill_mask(
            target_image, self.min_mask_ratio, self.max_mask_ratio
        )

        # Background: about half of the time from the external dataset (when
        # enabled), otherwise extracted from the target image.
        background_img = None
        if self.use_external_background and random.random() > 0.5:
            background_img = self._get_external_background(idx)
        if background_img is None:
            # Also the fallback when the external lookup yields nothing.
            background_img = self._extract_background(target_image, mask)

        description = item["description"][
            "description_0" if target == 0 else "description_1"
        ]

        # Augmentation: drop each condition independently at random.
        drop_text = random.random() < self.drop_text_prob
        drop_subject = random.random() < self.drop_subject_prob
        drop_position = random.random() < self.drop_position_prob
        drop_background = random.random() < self.drop_background_prob

        if drop_text:
            description = ""

        if drop_subject:
            # Subject dropout: replace the reference with a black image.
            reference_img = Image.new("RGB", self.condition_size, (0, 0, 0))

        if drop_position:
            # Position dropout: nothing is masked.
            # NOTE(review): this feeds the complete target image in as a
            # condition — confirm that this is intended.
            position_mask_img = target_image

        if drop_background:
            # Background dropout: replace with a uniform gray image.
            background_img = Image.new("RGB", self.target_size, (128, 128, 128))

        # Position offsets of the three conditions.
        # NOTE(review): the original comments called the subject and
        # background offsets "upper-left" and "upper-right"; the axis order is
        # defined by the consumer of position_delta — confirm.
        subject_position_delta = np.array(
            [-self.condition_size[0] // 32, -self.condition_size[0] // 16]
        )
        # The fill mask stays at the original position.
        position_position_delta = np.array([0, 0])
        background_position_delta = np.array(
            [self.condition_size[0] // 32, -self.condition_size[0] // 16]
        )

        sample = {
            "image": self.to_tensor(target_image),
            # condition 0: subject reference image
            "condition_0": self.to_tensor(reference_img),
            "condition_type_0": "subject",
            "position_delta_0": subject_position_delta,
            # condition 1: position fill mask
            "condition_1": self.to_tensor(position_mask_img),
            "condition_type_1": "fill",
            "position_delta_1": position_position_delta,
            # condition 2: background image
            "condition_2": self.to_tensor(background_img),
            "condition_type_2": "background",
            "position_delta_2": background_position_delta,
            "description": description,
        }
        if self.return_pil_image:
            sample["pil_image"] = [
                target_image,
                reference_img,
                position_mask_img,
                background_img,
            ]
        return sample


def _load_rgb_resized(path, size):
    """Open the image at ``path`` and return it resized to ``size`` in RGB mode."""
    # The context manager closes the file handle once the resized copy exists.
    with Image.open(path) as img:
        return img.resize(size).convert("RGB")


def _white_box_on_black(size, box):
    """Return an RGB image of ``size``: white inside ``box``, black elsewhere."""
    mask = Image.new("L", size, 0)
    ImageDraw.Draw(mask).rectangle(box, fill=255)
    return Image.composite(
        Image.new("RGB", size, (255, 255, 255)),
        Image.new("RGB", size, (0, 0, 0)),
        mask,
    )


@torch.no_grad()
def test_function(model, save_path, file_name):
    """Generate sample images with all three conditions.

    Args:
        model: object providing ``training_config``, ``adapter_names``,
            ``device``, ``flux_pipe`` and ``model_config``.
        save_path: directory the results are written to (created if missing).
        file_name: prefix of the output files; each result is saved as
            ``{file_name}_triple_control_{i}.jpg``.

    A test case is only run when its input asset exists under ``assets/``.
    """
    condition_size = model.training_config["dataset"]["condition_size"]
    target_size = model.training_config["dataset"]["target_size"]

    # Adapters of the three conditions.
    # NOTE(review): assumes adapter_names[2], [3] and [4] are the subject, fill
    # and background adapters, in that order — confirm against the model.
    subject_adapter = model.adapter_names[2]
    fill_adapter = model.adapter_names[3]
    background_adapter = model.adapter_names[4]

    # Same position offsets as the ones used by the training dataset.
    subject_delta = [-condition_size[0] // 32, -condition_size[0] // 16]
    background_delta = [condition_size[0] // 32, -condition_size[0] // 16]

    def build_conditions(reference_img, fill_img, background_img):
        """Wrap the three images into conditions, in subject/fill/background order."""
        return [
            Condition(reference_img, subject_adapter, subject_delta),
            Condition(fill_img, fill_adapter, [0, 0]),
            Condition(background_img, background_adapter, background_delta),
        ]

    w, h = target_size
    test_list = []

    # Test case 1: full triple control, object in the center region.
    if os.path.exists("assets/test_in.jpg"):
        reference_img = _load_rgb_resized("assets/test_in.jpg", condition_size)

        # Background: a second asset if present, otherwise a solid color.
        if os.path.exists("assets/vase_hq.jpg"):
            background_img = _load_rgb_resized("assets/vase_hq.jpg", target_size)
        else:
            background_img = Image.new("RGB", target_size, (200, 220, 255))

        fill_mask_img = _white_box_on_black(
            target_size, [w // 4, h // 4, w * 3 // 4, h * 3 // 4]
        )

        prompt = "The object is placed in the center of a beautiful scene."
        test_list.append(
            (build_conditions(reference_img, fill_mask_img, background_img), prompt)
        )

    # Test case 2: a different position (upper-left quarter).
    if os.path.exists("assets/test_out.jpg"):
        reference_img = _load_rgb_resized("assets/test_out.jpg", condition_size)
        background_img = Image.new("RGB", target_size, (180, 200, 220))
        fill_mask_img = _white_box_on_black(target_size, [0, 0, w // 2, h // 2])

        prompt = "The object is in the upper left corner."
        test_list.append(
            (build_conditions(reference_img, fill_mask_img, background_img), prompt)
        )

    # Generate and save one image per test case.
    os.makedirs(save_path, exist_ok=True)
    for i, (conditions, prompt) in enumerate(test_list):
        # Fixed seed so that results are comparable between runs.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=conditions,
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(save_path, f"{file_name}_triple_control_{i}.jpg")
        res.images[0].save(file_path)


def main():
    """Load the configuration and data, build the model and start training.

    Reads the config via ``get_config()``, selects the CUDA device from the
    ``LOCAL_RANK`` environment variable (default 0), loads and quality-filters
    Subjects200K, optionally loads an external background dataset, and hands
    everything to ``train``.
    """
    config = get_config()
    training_config = config["train"]
    dataset_config = training_config["dataset"]
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    # Load the Subjects200K dataset.
    raw_dataset = load_dataset("Yuanshi/Subjects200K")

    # Optional: load an external background dataset (e.g. text-to-image-2M).
    use_external_background = dataset_config.get("use_external_background", False)
    background_dataset = None
    if use_external_background:
        print("Loading external background dataset...")
        background_urls = dataset_config.get("background_urls", [])
        if background_urls:
            background_dataset = load_dataset(
                "webdataset",
                data_files={"train": background_urls},
                split="train",
                cache_dir="cache/background",
                num_proc=16,
            )
            print(f"Loaded {len(background_dataset)} background images")
        else:
            # Without URLs the dataset falls back to backgrounds extracted
            # from the target images; say so instead of failing silently.
            print(
                "Warning: use_external_background is enabled but no "
                "background_urls are configured; using extracted backgrounds."
            )

    def filter_func(item):
        """Keep items whose three quality scores are all at least 5."""
        if not item.get("quality_assessment"):
            return False
        return all(
            item["quality_assessment"].get(key, 0) >= 5
            for key in ["compositeStructure", "objectConsistency", "imageQuality"]
        )

    # exist_ok avoids a check-then-create race when several ranks start at once.
    os.makedirs("./cache/dataset", exist_ok=True)
    data_valid = raw_dataset["train"].filter(
        filter_func,
        num_proc=16,
        cache_file_name="./cache/dataset/data_valid_triple.arrow",
    )

    # Build the training dataset.
    dataset = SubjectPositionBackgroundDataset(
        data_valid,
        background_dataset=background_dataset,
        condition_size=dataset_config["condition_size"],
        target_size=dataset_config["target_size"],
        image_size=dataset_config["image_size"],
        padding=dataset_config["padding"],
        drop_text_prob=dataset_config["drop_text_prob"],
        drop_subject_prob=dataset_config.get("drop_subject_prob", 0.1),
        drop_position_prob=dataset_config.get("drop_position_prob", 0.1),
        drop_background_prob=dataset_config.get("drop_background_prob", 0.1),
        min_mask_ratio=dataset_config.get("min_mask_ratio", 0.15),
        max_mask_ratio=dataset_config.get("max_mask_ratio", 0.6),
        background_blur_prob=dataset_config.get("background_blur_prob", 0.3),
        use_external_background=use_external_background,
    )

    # Build the trainable model.
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, trainable_model, config, test_function)


# Start training only when this file is executed as a script. The module uses
# package-relative imports, so it has to be launched from within its package
# (e.g. with ``python -m``).
if __name__ == "__main__":
    main()
