"""
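Aircraft-target dual-condition training dataset - Background + Crop variant.
The position-mask condition input is removed; only Subject (crop) and
Background are kept as conditions. One or more dataset root folders are
supported. The mask is still loaded and returned for loss weighting.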
"""

import csv
import os
import random
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFilter

# Reuse most of the logic from AircraftMaskWeightedDataset
from .train_aircraft_mask_weighted import AircraftMaskWeightedDataset

class AircraftBackgroundCropDataset(AircraftMaskWeightedDataset):
    """
    Aircraft-target dataset - Background + Crop variant.

    Dual-condition control: Subject (aircraft crop) and Background.
    Inherits from AircraftMaskWeightedDataset to reuse most of its helpers
    (synchronised geometric transforms, color jitter, blur, RGBA crop
    loading, ``to_tensor``), but re-scans one or more dataset roots and
    overrides ``__getitem__`` so that the position mask is no longer a
    condition input. The mask is still loaded and returned as
    ``target_mask`` for loss weighting.

    Each root is expected to contain ``Original/``, ``Masks2/``,
    ``Background_Erased/`` and ``Crops_Blurred/`` or ``Crops/`` sub-folders
    (files as .png or .jpg), plus an optional ``crop_info.csv`` with
    ``output_name`` and ``category`` columns.
    """

    # Number of replacement samples __getitem__ tries after a load failure
    # before giving up (instead of recursing without bound).
    _MAX_LOAD_RETRIES = 10

    # Extensions probed for every per-sample file, in priority order.
    _IMAGE_EXTENSIONS = (".png", ".jpg")

    def __init__(
        self,
        dataset_root: Union[str, List[str]],
        condition_size: Tuple[int, int] = (512, 512),
        target_size: Tuple[int, int] = (512, 512),
        # Dropout probabilities
        drop_text_prob: float = 0.1,
        drop_subject_prob: float = 0.1,
        drop_background_prob: float = 0.1,
        # Background processing parameters
        background_blur_prob: float = 0.3,
        # Data augmentation parameters
        augmentation_prob: float = 0.5,
        rotation_prob: float = 0.5,
        flip_prob: float = 0.5,
        color_jitter_prob: float = 0.5,
        brightness_range: Tuple[float, float] = (0.7, 1.3),
        contrast_range: Tuple[float, float] = (0.75, 1.25),
        saturation_range: Tuple[float, float] = (0.8, 1.2),
        hue_range: Tuple[float, float] = (-0.05, 0.05),
        return_pil_image: bool = False,
        **kwargs  # absorbs extra/unused config keys
    ):
        """
        Args:
            dataset_root: A single dataset root, or a list/tuple of roots whose
                samples are concatenated. The first root is handed to the
                parent constructor.
            condition_size: Size the subject and background conditions are
                resized to (PIL ``(width, height)``).
            target_size: Size the target image and mask are resized to.
            drop_text_prob: Probability of replacing the prompt with "".
            drop_subject_prob: Probability of replacing the subject with flat grey.
            drop_background_prob: Probability of replacing the background with flat grey.
            background_blur_prob: Probability of blurring the background.
            augmentation_prob: Probability of applying the parent's
                synchronised geometric transforms.
            rotation_prob, flip_prob: Forwarded to the parent.
            color_jitter_prob: Probability of color jitter on target + background.
            brightness_range, contrast_range, saturation_range, hue_range:
                Ranges the shared jitter factors are drawn from.
            return_pil_image: Return PIL images instead of tensors.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``dataset_root`` is an empty list/tuple.
        """
        if isinstance(dataset_root, (list, tuple)):
            if not dataset_root:
                raise ValueError("dataset_root must contain at least one path")
            self.dataset_roots = [Path(root) for root in dataset_root]
            primary_root = self.dataset_roots[0]
        else:
            self.dataset_roots = [Path(dataset_root)]
            primary_root = dataset_root

        # dataset_roots is assigned before the parent constructor runs, so the
        # overridden scanning methods below already see every root should the
        # parent call them. Mask-related parent options are disabled.
        super().__init__(
            dataset_root=primary_root,
            condition_size=condition_size,
            target_size=target_size,
            drop_text_prob=drop_text_prob,
            drop_subject_prob=drop_subject_prob,
            drop_position_prob=0.0,  # position condition is no longer used
            drop_background_prob=drop_background_prob,
            min_mask_ratio=0.0,
            max_mask_ratio=0.0,
            background_blur_prob=background_blur_prob,
            augmentation_prob=augmentation_prob,
            rotation_prob=rotation_prob,
            flip_prob=flip_prob,
            color_jitter_prob=color_jitter_prob,
            brightness_range=brightness_range,
            contrast_range=contrast_range,
            saturation_range=saturation_range,
            hue_range=hue_range,
            return_pil_image=return_pil_image
        )

        # Re-scan all roots after the parent constructor so category_map and
        # samples always cover every root, regardless of what the parent did.
        self.category_map = self._load_category_map()
        self.samples = self._scan_dataset()

    def _load_category_map(self) -> Dict[str, str]:
        """
        Merge ``crop_info.csv`` from every root into ``{output_name: category}``.

        Missing or unreadable CSV files are skipped with a warning (best
        effort). On duplicate names, entries from later roots win.
        """
        category_map: Dict[str, str] = {}

        for root in self.dataset_roots:
            csv_path = root / "crop_info.csv"
            if not csv_path.exists():
                print(f"⚠️  Warning: {csv_path} not found")
                continue

            try:
                # "utf-8-sig" strips a leading BOM (e.g. from Excel exports)
                # that would otherwise corrupt the first header name;
                # newline="" is what the csv module expects.
                with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
                    for row in csv.DictReader(f):
                        category_map[row["output_name"]] = row["category"]
            except Exception as e:
                print(f"⚠️  Warning: Failed to load crop_info.csv from {root}: {e}")

        print(f"✓ Loaded {len(category_map)} category mappings from {len(self.dataset_roots)} roots")
        return category_map

    @classmethod
    def _first_existing(cls, folder: Path, base_name: str):
        """Return ``folder / (base_name + ext)`` for the first existing ext, else None."""
        for ext in cls._IMAGE_EXTENSIONS:
            candidate = folder / f"{base_name}{ext}"
            if candidate.exists():
                return candidate
        return None

    def _scan_dataset(self) -> List[dict]:
        """
        Collect complete samples from every dataset root.

        A sample is an image in ``Original/`` (.png/.jpg) that also has a mask
        in ``Masks2/``, a non-empty background in ``Background_Erased/`` and
        a crop in ``Crops_Blurred/`` (preferred) or ``Crops/``. Incomplete
        samples are skipped silently; a missing ``Original/`` folder skips the
        whole root with an error message.
        """
        samples = []

        for root in self.dataset_roots:
            original_dir = root / "Original"
            if not original_dir.exists():
                print(f"❌ Error: {original_dir} does not exist!")
                continue

            print(f"Scanning {root}...")
            # Sorted per extension so the sample order is deterministic
            # across filesystems (all .png files first, then .jpg).
            all_files = [
                path
                for ext in self._IMAGE_EXTENSIONS
                for path in sorted(original_dir.glob(f"*{ext}"))
            ]

            root_samples_count = 0
            for orig_file in all_files:
                base_name = orig_file.stem

                mask_file = self._first_existing(root / "Masks2", base_name)
                bg_file = self._first_existing(root / "Background_Erased", base_name)
                # Blurred crops take precedence over plain crops.
                crop_file = self._first_existing(root / "Crops_Blurred", base_name)
                if crop_file is None:
                    crop_file = self._first_existing(root / "Crops", base_name)

                if mask_file is None or bg_file is None or crop_file is None:
                    continue

                # Cheap validity check: skip empty background files. Full
                # decoding is deferred to __getitem__ to keep scanning fast.
                try:
                    if bg_file.stat().st_size == 0:
                        continue
                except OSError:
                    continue

                # For names with >= 4 "_"-separated parts, the 2nd part is
                # taken as the scene id and the last part as the object id.
                parts = base_name.split("_")
                if len(parts) >= 4:
                    scene_id, object_id = parts[1], parts[-1]
                else:
                    scene_id = object_id = "unknown"

                samples.append({
                    "original": str(orig_file),
                    "mask": str(mask_file),
                    "background": str(bg_file),
                    "crop": str(crop_file),
                    "scene_id": scene_id,
                    "object_id": object_id,
                    "category": self.category_map.get(base_name, "object"),
                })
                root_samples_count += 1

            print(f"  Found {root_samples_count} samples in {root}")

        return samples

    @staticmethod
    def _open_image(path: str, mode: str) -> Image.Image:
        """Open *path*, convert it to *mode* and close the underlying file."""
        with Image.open(path) as img:
            return img.convert(mode)

    def _load_raw_sample(self, idx):
        """
        Load the raw images for ``self.samples[idx]``.

        If loading fails, a random other sample is tried instead, up to
        ``_MAX_LOAD_RETRIES`` replacements.

        Returns:
            ``(sample_info, image, mask, background, subject)`` for the sample
            that loaded successfully.

        Raises:
            IndexError: If the initial ``idx`` is out of range.
            RuntimeError: If every attempt fails.
        """
        last_error = None
        for _ in range(self._MAX_LOAD_RETRIES + 1):
            # Indexing stays outside the try so an invalid idx still raises IndexError.
            sample_info = self.samples[idx]
            try:
                image = self._open_image(sample_info["original"], "RGB")
                mask = self._open_image(sample_info["mask"], "L")
                background = self._open_image(sample_info["background"], "RGB")
                # Parent helper for the crop; it handles transparent backgrounds
                # (per its name, filling them with white).
                subject = self._load_rgba_with_white_background(sample_info["crop"])
            except Exception as e:
                print(f"Error loading sample {idx}: {e}")
                last_error = e
                # Fall back to a random other sample.
                idx = random.randint(0, len(self) - 1)
                continue
            return sample_info, image, mask, background, subject

        raise RuntimeError(
            f"Failed to load a sample after {self._MAX_LOAD_RETRIES + 1} attempts"
        ) from last_error

    def __getitem__(self, idx):
        """
        Build one training example.

        Returns a dict with ``image`` (target, ``target_size``),
        ``condition_0`` (subject crop, ``condition_size``), ``condition_1``
        (background, ``condition_size``), ``description`` (prompt, "" when
        text is dropped), ``target_mask`` (mask tensor for loss weighting)
        and ``position_delta_0/1``. In tensor mode ``condition_type_0/1`` are
        also included. If the sample cannot be loaded, a random other sample
        is returned instead.
        """
        # 1. Load raw data (with bounded retries on unreadable files)
        sample_info, image, mask, background, subject = self._load_raw_sample(idx)

        # 2. Resize: target image/mask to target_size, conditions to
        #    condition_size. NEAREST keeps mask values un-interpolated.
        image = image.resize(self.target_size, Image.BILINEAR)
        mask = mask.resize(self.target_size, Image.NEAREST)
        background = background.resize(self.condition_size, Image.BILINEAR)
        subject = subject.resize(self.condition_size, Image.BILINEAR)

        # 3. Geometric augmentation, applied in sync by the parent helper
        if random.random() < self.augmentation_prob:
            image, background, mask, subject = self._apply_sync_transforms(
                image, background, mask, subject
            )

        # 4. Color jitter: target and background share identical factors so
        #    they stay photometrically consistent; the subject is not jittered.
        if random.random() < self.color_jitter_prob:
            jitter = (
                random.uniform(*self.brightness_range),
                random.uniform(*self.contrast_range),
                random.uniform(*self.saturation_range),
                random.uniform(*self.hue_range),
            )
            image = self._apply_color_jitter(image, *jitter)
            background = self._apply_color_jitter(background, *jitter)

        # 5. Background blur with a random odd kernel size; the target image
        #    is blurred with the same kernel half of the time (following the
        #    parent's logic).
        if random.random() < self.background_blur_prob:
            kernel_size = random.choice([3, 5, 7])
            background = self._apply_random_blur(background, kernel_size)
            if random.random() < 0.5:
                image = self._apply_random_blur(image, kernel_size)

        # 6. Condition dropout: replace with a flat grey image
        if random.random() < self.drop_subject_prob:
            subject = Image.new("RGB", self.condition_size, (128, 128, 128))

        if random.random() < self.drop_background_prob:
            background = Image.new("RGB", self.condition_size, (128, 128, 128))

        # Prompt from the category (hyphens -> spaces, lower-cased), then text dropout
        category = sample_info.get("category", "object")
        category_name = category.replace("-", " ").lower()
        prompt = f"Place a {category_name} at the specified position"

        if random.random() < self.drop_text_prob:
            prompt = ""

        # 7. Outputs. The mask is always converted to a tensor since the loss uses it.
        mask_tensor = self.to_tensor(mask)

        if self.return_pil_image:
            # NOTE(review): position deltas here are zero tensors, while the
            # tensor path below returns np arrays [-16, -32] / [16, -32];
            # confirm the PIL path is only used where deltas are ignored.
            return {
                "image": image,
                "condition_0": subject,     # Subject
                "condition_1": background,  # Background (formerly condition_2)
                "description": prompt,
                "target_mask": mask_tensor,  # still returned for loss computation
                "position_delta_0": torch.tensor([0, 0]),
                "position_delta_1": torch.tensor([0, 0]),
            }

        # Tensors are left in [0, 1] rather than normalised to [-1, 1]: per
        # the original note, diffusers' image_processor normalises internally,
        # so normalising here would double-normalise (to [-3, 1]) and warn.
        image_tensor = self.to_tensor(image)
        subject_tensor = self.to_tensor(subject)
        background_tensor = self.to_tensor(background)

        # Position deltas (offsets in latent space), kept consistent with the parent
        position_delta_subject = np.array([-16, -32])
        position_delta_background = np.array([16, -32])

        return {
            "image": image_tensor,
            "condition_0": subject_tensor,
            "condition_1": background_tensor,
            "description": prompt,
            "target_mask": mask_tensor,
            "condition_type_0": "subject",
            "condition_type_1": "background",
            "position_delta_0": position_delta_subject,
            "position_delta_1": position_delta_background,
        }

@torch.no_grad()
def test_function_for_training(model, save_path, file_name):
    """
    Generate sample images during training as a visual sanity check.

    Builds one synthetic test case (flat grey subject, flat light-blue
    background), runs ``generate`` with the subject and background adapters
    and saves each result as ``{file_name}_aircraft_{i}.jpg`` in
    ``save_path``. Runs under ``torch.no_grad()``.

    Args:
        model: Training wrapper exposing ``training_config``,
            ``adapter_names``, ``flux_pipe``, ``model_config`` and ``device``.
        save_path: Output directory; created if it does not exist.
        file_name: Prefix for the saved image files.
    """
    dataset_config = model.training_config["dataset"]
    condition_size = dataset_config["condition_size"]
    target_size = dataset_config["target_size"]

    # main() builds the model with adapter_names=[None, None, "subject", "background"].
    subject_adapter = model.adapter_names[2]
    background_adapter = model.adapter_names[3]

    from ..pipeline.flux_omini import Condition, generate

    # Both condition images use condition_size, matching
    # AircraftBackgroundCropDataset, which resizes the subject AND the
    # background to condition_size during training.
    subject_img = Image.new("RGB", condition_size, (128, 128, 128))
    background_img = Image.new("RGB", condition_size, (200, 220, 255))

    # Position deltas match the ones the dataset returns for each condition.
    subject_condition = Condition(subject_img, subject_adapter, [-16, -32])
    background_condition = Condition(background_img, background_adapter, [16, -32])

    prompt = "Place an aircraft at the specified position"
    test_list = [([subject_condition, background_condition], prompt)]

    os.makedirs(save_path, exist_ok=True)
    for i, (conditions, prompt) in enumerate(test_list):
        # Fixed seed so samples are comparable across training steps.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=conditions,
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(save_path, f"{file_name}_aircraft_{i}.jpg")
        res.images[0].save(file_path)


def main():
    """
    Training entry point.

    Loads the config via ``get_config``, builds an
    AircraftBackgroundCropDataset from ``config["train"]["dataset"]``
    (returning early if it is empty), builds an OminiModel with the
    subject/background adapters and starts ``train`` with
    ``test_function_for_training`` as the sampling hook.
    """
    from .trainer_mask_weighted import get_config, OminiModel, train

    config = get_config()
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    separator = "=" * 70
    for banner_line in (separator, "Aircraft Background+Crop Training (No Mask Input)", separator):
        print(banner_line)

    print("\n[1/3] Loading dataset...")
    # Required keys: a missing one raises KeyError.
    dataset_root = dataset_config["dataset_root"]
    condition_size = tuple(dataset_config["condition_size"])
    target_size = tuple(dataset_config["target_size"])

    # Optional probabilities and the fallbacks used when absent from the config.
    probability_fallbacks = {
        "drop_text_prob": 0.1,
        "drop_subject_prob": 0.1,
        "drop_background_prob": 0.1,
        "background_blur_prob": 0.3,
        "augmentation_prob": 0.5,
        "rotation_prob": 0.5,
        "flip_prob": 0.5,
        "color_jitter_prob": 0.5,
    }
    # Optional (low, high) ranges; config lists are converted to tuples.
    # NOTE(review): these fallbacks differ from the dataset class defaults
    # (brightness (0.7, 1.3), contrast (0.75, 1.25), hue (-0.05, 0.05));
    # confirm which set is intended.
    range_fallbacks = {
        "brightness_range": [0.8, 1.2],
        "contrast_range": [0.8, 1.2],
        "saturation_range": [0.8, 1.2],
        "hue_range": [-0.1, 0.1],
    }

    options = {
        name: dataset_config.get(name, fallback)
        for name, fallback in probability_fallbacks.items()
    }
    for name, fallback in range_fallbacks.items():
        options[name] = tuple(dataset_config.get(name, fallback))

    dataset = AircraftBackgroundCropDataset(
        dataset_root=dataset_root,
        condition_size=condition_size,
        target_size=target_size,
        **options,
    )

    sample_count = len(dataset)
    print(f"  ✓ Dataset created: {sample_count} samples")
    for detail in (
        "Subject: aircraft crops",
        "Background: background image",
        "(No Position Mask Input)",
        "Loss weighting: Mask-weighted MSE loss (using internal mask)",
    ):
        print(f"    - {detail}")

    if sample_count == 0:
        print("\n❌ Error: Dataset is empty!")
        return

    print("\n[2/3] Initializing model...")
    # Only two adapters are trained: "subject" and "background".
    adapter_names = [None, None, "subject", "background"]
    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=(torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32),
        model_config=config.get("model", {}),
        adapter_names=adapter_names,
        optimizer_config=training_config.get("optimizer", None),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )
    model.training_config = training_config
    print("  ✓ Model initialized")

    print("\n[3/3] Starting training...")
    train(dataset, model, config, test_function_for_training)


# Script entry point. The module uses relative imports (``from .…``), so it
# has to be run as part of its package (e.g. ``python -m <package>.<module>``).
if __name__ == "__main__":
    main()