"""
单文件夹航空目标双条件训练数据集 - Background + Crop (Normalized Loss) 版本
取消 Mask 输入，只保留 Subject (Crop) 和 Background
使用 Normalized Mask-Weighted Loss
"""

import os
import random
import numpy as np
import csv
from pathlib import Path
from typing import Tuple, List, Dict
from PIL import Image, ImageDraw, ImageFilter
import torch
import torchvision.transforms as T

# Reuse most of the logic from AircraftMaskWeightedDataset
from .train_aircraft_mask_weighted import AircraftMaskWeightedDataset

class AircraftBackgroundCropDataset(AircraftMaskWeightedDataset):
    """Single-folder aircraft dataset, Background + Crop variant.

    Two conditions are produced per sample: the Subject (aircraft crop) and the
    Background. The class inherits from ``AircraftMaskWeightedDataset`` to reuse
    its helpers, and overrides ``__getitem__`` so that the mask is no longer
    returned as an input condition. The mask is still returned under
    ``"target_mask"`` so the loss can be weighted with it.
    """

    # Upper bound on how many samples are tried when image loading keeps
    # failing. Large enough to keep the "skip broken files" behaviour, but
    # finite so a fully broken dataset raises a clear error instead of
    # exhausting the recursion limit.
    _MAX_LOAD_ATTEMPTS = 100

    def __init__(
        self,
        dataset_root: str,
        condition_size: Tuple[int, int] = (512, 512),
        target_size: Tuple[int, int] = (512, 512),
        # Dropout probabilities
        drop_text_prob: float = 0.1,
        drop_subject_prob: float = 0.1,
        drop_background_prob: float = 0.1,
        # Background generation parameters
        background_blur_prob: float = 0.3,
        # Data augmentation parameters
        augmentation_prob: float = 0.5,
        rotation_prob: float = 0.5,
        flip_prob: float = 0.5,
        color_jitter_prob: float = 0.5,
        brightness_range: Tuple[float, float] = (0.7, 1.3),
        contrast_range: Tuple[float, float] = (0.75, 1.25),
        saturation_range: Tuple[float, float] = (0.8, 1.2),
        hue_range: Tuple[float, float] = (-0.05, 0.05),
        return_pil_image: bool = False,
        **kwargs  # absorbs extra arguments; they are ignored
    ):
        """Initialise the dataset by delegating to the parent constructor.

        Args:
            dataset_root: Root folder of the dataset (forwarded to the parent).
            condition_size: Size the subject and background conditions are
                resized to in ``__getitem__``.
            target_size: Size the target image and mask are resized to.
            drop_text_prob: Probability of replacing the prompt with ``""``.
            drop_subject_prob: Probability of replacing the subject with a
                flat grey image.
            drop_background_prob: Probability of replacing the background with
                a flat grey image.
            background_blur_prob: Probability of blurring the background.
            augmentation_prob: Probability of applying the parent's
                synchronised transforms.
            rotation_prob: Forwarded to the parent unchanged.
            flip_prob: Forwarded to the parent unchanged.
            color_jitter_prob: Probability of colour-jittering target and
                background.
            brightness_range: Sampling range for the brightness factor.
            contrast_range: Sampling range for the contrast factor.
            saturation_range: Sampling range for the saturation factor.
            hue_range: Sampling range for the hue shift.
            return_pil_image: When true, ``__getitem__`` returns PIL images
                instead of tensors for the image and conditions.
            **kwargs: Accepted for call-site compatibility and discarded.
        """
        # The parent signature still has position/mask parameters; pass
        # neutral dummy values because this variant does not use them.
        super().__init__(
            dataset_root=dataset_root,
            condition_size=condition_size,
            target_size=target_size,
            drop_text_prob=drop_text_prob,
            drop_subject_prob=drop_subject_prob,
            drop_position_prob=0.0,  # no longer used
            drop_background_prob=drop_background_prob,
            min_mask_ratio=0.0,
            max_mask_ratio=0.0,
            background_blur_prob=background_blur_prob,
            augmentation_prob=augmentation_prob,
            rotation_prob=rotation_prob,
            flip_prob=flip_prob,
            color_jitter_prob=color_jitter_prob,
            brightness_range=brightness_range,
            contrast_range=contrast_range,
            saturation_range=saturation_range,
            hue_range=hue_range,
            return_pil_image=return_pil_image
        )

    def _load_background_crop_sample(self, idx):
        """Load the four images of sample ``idx``, retrying on failure.

        If loading raises, the error is printed and a random other index is
        tried, up to ``_MAX_LOAD_ATTEMPTS`` times in total.

        Returns:
            Tuple ``(sample_info, image, mask, background, subject)`` for the
            sample that was actually loaded (which may differ from ``idx``).

        Raises:
            RuntimeError: If every attempt failed; chained to the last error.
        """
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            sample_info = self.samples[idx]

            # Key lookups stay outside the try block so that a malformed
            # sample record surfaces immediately instead of being retried.
            image_path = sample_info["original"]
            mask_path = sample_info["mask"]
            bg_path = sample_info["background"]
            crop_path = sample_info["crop"]

            try:
                # convert() returns a new, fully loaded image, so the source
                # file can be closed as soon as the with-block exits.
                with Image.open(image_path) as src:
                    image = src.convert("RGB")
                with Image.open(mask_path) as src:
                    mask = src.convert("L")
                with Image.open(bg_path) as src:
                    background = src.convert("RGB")
                # Parent helper for loading the crop (per the original
                # comment, it handles transparent backgrounds).
                subject = self._load_rgba_with_white_background(crop_path)
            except Exception as e:
                print(f"Error loading sample {idx}: {e}")
                last_error = e
                # Fall back to a randomly chosen sample.
                idx = random.randint(0, len(self) - 1)
            else:
                return sample_info, image, mask, background, subject

        raise RuntimeError(
            f"Could not load any sample after {self._MAX_LOAD_ATTEMPTS} attempts"
        ) from last_error

    def __getitem__(self, idx):
        """Return one training sample.

        Returns:
            A dict with the target ``"image"``, the two conditions
            (``"condition_0"`` = subject, ``"condition_1"`` = background), the
            text ``"description"``, the ``"target_mask"`` tensor and the
            position deltas. With ``return_pil_image`` the image and
            conditions are PIL images; otherwise they are ``to_tensor``
            outputs and ``condition_type_*`` keys are included.

        Raises:
            RuntimeError: If no sample could be loaded after repeated retries.
        """
        # 1. Load the base data (with bounded retry on broken files)
        sample_info, image, mask, background, subject = (
            self._load_background_crop_sample(idx)
        )

        # 2. Resize (nearest-neighbour for the mask so no intermediate grey
        #    values are interpolated in)
        image = image.resize(self.target_size, Image.BILINEAR)
        mask = mask.resize(self.target_size, Image.NEAREST)
        background = background.resize(self.condition_size, Image.BILINEAR)
        subject = subject.resize(self.condition_size, Image.BILINEAR)

        # 3. Data augmentation (geometric transforms)
        if random.random() < self.augmentation_prob:
            # Reuse the parent's synchronised transforms
            image, background, mask, subject = self._apply_sync_transforms(
                image, background, mask, subject
            )

        # 4. Colour augmentation (colour jitter)
        if random.random() < self.color_jitter_prob:
            # Target and background share the same jitter parameters
            brightness_factor_tb = random.uniform(*self.brightness_range)
            contrast_factor_tb = random.uniform(*self.contrast_range)
            saturation_factor_tb = random.uniform(*self.saturation_range)
            hue_shift_tb = random.uniform(*self.hue_range)

            image = self._apply_color_jitter(
                image, brightness_factor_tb, contrast_factor_tb,
                saturation_factor_tb, hue_shift_tb
            )
            background = self._apply_color_jitter(
                background, brightness_factor_tb, contrast_factor_tb,
                saturation_factor_tb, hue_shift_tb
            )

        # 5. Background processing (blur)
        if random.random() < self.background_blur_prob:
            # Randomly pick the blur kernel size (must be odd)
            kernel_size = random.choice([3, 5, 7])
            background = self._apply_random_blur(background, kernel_size)
            # The target image may be blurred too, following the parent logic
            if random.random() < 0.5:
                image = self._apply_random_blur(image, kernel_size)

        # 6. Dropout: a dropped condition becomes a flat mid-grey image
        if random.random() < self.drop_subject_prob:
            subject = Image.new("RGB", self.condition_size, (128, 128, 128))

        if random.random() < self.drop_background_prob:
            background = Image.new("RGB", self.condition_size, (128, 128, 128))

        # Text prompt and text dropout
        category = sample_info.get("category", "object")
        category_name = category.replace("-", " ").lower()
        prompt = f"Place a {category_name} at the specified position"

        if random.random() < self.drop_text_prob:
            prompt = ""

        # 7. Build the return value
        # The mask is always needed as a tensor for the loss
        mask_tensor = self.to_tensor(mask)

        if self.return_pil_image:
            # NOTE(review): this branch returns zero position deltas and no
            # condition_type_* keys, unlike the tensor branch below - confirm
            # that the difference is intentional.
            return {
                "image": image,
                "condition_0": subject,     # Subject
                "condition_1": background,  # Background (formerly condition_2)
                "description": prompt,
                "target_mask": mask_tensor,  # still returned for the loss
                "position_delta_0": torch.tensor([0, 0]),
                "position_delta_1": torch.tensor([0, 0]),
            }

        # Convert to tensors without any further normalisation here. Per the
        # original author's note, the diffusers image_processor expects
        # [0, 1] inputs and normalises internally.
        # (presumably to_tensor is torchvision's ToTensor - confirm in parent)
        image_tensor = self.to_tensor(image)
        subject_tensor = self.to_tensor(subject)
        background_tensor = self.to_tensor(background)

        # Position deltas (offsets in latent space), kept consistent with the
        # parent class according to the original comment
        position_delta_subject = np.array([-16, -32])
        position_delta_background = np.array([16, -32])

        return {
            "image": image_tensor,
            "condition_0": subject_tensor,
            "condition_1": background_tensor,
            "description": prompt,
            "target_mask": mask_tensor,
            "condition_type_0": "subject",
            "condition_type_1": "background",
            "position_delta_0": position_delta_subject,
            "position_delta_1": position_delta_background,
        }

@torch.no_grad()
def test_function_for_training(model, save_path, file_name):
    """Generate example images during training as a quick visual check.

    Builds one synthetic test case (flat grey subject, flat light-blue
    background), runs the generation pipeline with a fixed seed, and writes
    the result to ``save_path`` as ``{file_name}_aircraft_{i}.jpg``.

    Args:
        model: Training model; must expose ``training_config``,
            ``adapter_names``, ``device``, ``flux_pipe`` and ``model_config``.
        save_path: Output directory; created if missing.
        file_name: Prefix for the saved image file names.
    """
    dataset_config = model.training_config["dataset"]
    # Config values may be lists; coerce to tuples as main() does.
    condition_size = tuple(dataset_config["condition_size"])
    target_size = tuple(dataset_config["target_size"])

    # Adapter lookup.
    # Note: adapter_names=[None, None, "subject", "background"]
    subject_adapter = model.adapter_names[2]
    background_adapter = model.adapter_names[3]

    from ..pipeline.flux_omini import Condition, generate

    test_list = []

    # Simple placeholder inputs. Both conditions use condition_size because
    # the dataset resizes subject AND background to condition_size; using
    # target_size for the background would differ from training whenever the
    # two sizes are not equal.
    subject_img = Image.new("RGB", condition_size, (128, 128, 128))
    background_img = Image.new("RGB", condition_size, (200, 220, 255))

    # Build the two conditions; the third argument mirrors the position
    # deltas the dataset returns for the tensor path.
    subject_condition = Condition(subject_img, subject_adapter, [-16, -32])
    background_condition = Condition(background_img, background_adapter, [16, -32])

    prompt = "Place an aircraft at the specified position"
    test_list.append(([subject_condition, background_condition], prompt))

    # Generate and save the test images
    os.makedirs(save_path, exist_ok=True)
    for i, (conditions, case_prompt) in enumerate(test_list):
        # Fixed seed so outputs are comparable across checkpoints
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=case_prompt,
            conditions=conditions,
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(save_path, f"{file_name}_aircraft_{i}.jpg")
        res.images[0].save(file_path)


def main():
    """Training entry point: build the dataset and model, then start training."""
    # trainer_mask_weighted_normalized is used to get the normalized
    # mask-weighted loss (per the original author's note).
    from .trainer_mask_weighted_normalized import get_config, OminiModel, train

    # Load configuration
    config = get_config()
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    separator = "=" * 70
    print(separator)
    print("Aircraft Background+Crop Training (Normalized Loss)")
    print(separator)

    print("\n[1/3] Loading dataset...")
    # Required settings first, so a missing key fails before anything else.
    dataset_root = dataset_config["dataset_root"]
    condition_size = tuple(dataset_config["condition_size"])
    target_size = tuple(dataset_config["target_size"])

    # Optional settings with their fallback values.
    # NOTE(review): the range fallbacks below differ from the dataset
    # constructor's own defaults - confirm which set is intended.
    scalar_fallbacks = (
        ("drop_text_prob", 0.1),
        ("drop_subject_prob", 0.1),
        ("drop_background_prob", 0.1),
        ("background_blur_prob", 0.3),
        ("augmentation_prob", 0.5),
        ("rotation_prob", 0.5),
        ("flip_prob", 0.5),
        ("color_jitter_prob", 0.5),
    )
    range_fallbacks = (
        ("brightness_range", [0.8, 1.2]),
        ("contrast_range", [0.8, 1.2]),
        ("saturation_range", [0.8, 1.2]),
        ("hue_range", [-0.1, 0.1]),
    )
    dataset_options = {}
    for key, fallback in scalar_fallbacks:
        dataset_options[key] = dataset_config.get(key, fallback)
    for key, fallback in range_fallbacks:
        # Ranges are converted to tuples before being handed to the dataset.
        dataset_options[key] = tuple(dataset_config.get(key, fallback))

    dataset = AircraftBackgroundCropDataset(
        dataset_root=dataset_root,
        condition_size=condition_size,
        target_size=target_size,
        **dataset_options,
    )
    print(f"  ✓ Dataset created: {len(dataset)} samples")
    for summary_line in (
        "    - Subject: aircraft crops",
        "    - Background: background image",
        "    - (No Position Mask Input)",
        "    - Loss weighting: Normalized Mask-weighted MSE loss",
    ):
        print(summary_line)

    # Nothing to train on: report and stop.
    if not len(dataset):
        print("\n❌ Error: Dataset is empty!")
        return

    print("\n[2/3] Initializing model...")
    # Only two adapters are used here: Subject and Background.
    flux_pipe_id = config["flux_path"]
    lora_config = training_config.get("lora_config", None)
    if config["dtype"] == "bfloat16":
        model_dtype = torch.bfloat16
    else:
        model_dtype = torch.float32

    model = OminiModel(
        flux_pipe_id=flux_pipe_id,
        lora_path=None,
        lora_config=lora_config,
        device="cuda",
        dtype=model_dtype,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "background"],
        optimizer_config=training_config.get("optimizer", None),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )
    # The training-time test function reads this attribute.
    model.training_config = training_config
    print("  ✓ Model initialized")

    print("\n[3/3] Starting training...")
    train(dataset, model, config, test_function_for_training)


# Start training only when this file is executed as a script, not on import.
# The module uses relative imports, so it has to be launched as part of its
# package (e.g. with ``python -m``) rather than as a bare file path.
if __name__ == "__main__":
    main()