"""
航空目标三条件训练 Dataloader
数据集：FAIR1M2.0, MAR20, rareplanes

条件映射：
- Original → Target Image
- Crops → Subject Reference
- Background_Erased → Background
- Masks → Position Mask (生成 fill mask)
"""

import os
import random
from pathlib import Path
from typing import List, Tuple, Optional
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw
import torchvision.transforms as T
import numpy as np
import cv2


class AircraftTripleConditionDataset(Dataset):
    """
    Aircraft dataset with three conditions: subject, position (fill) and background.

    Expected layout::

        dataset_root/
        ├── FAIR1M2.0/
        │   ├── Original/{id}_{aircraft}_orig.jpg
        │   ├── Crops/{id}_{aircraft}_{idx}_crop.jpg
        │   ├── Masks/{id}_{aircraft}_mask.png
        │   └── Background_Erased/{id}_{aircraft}_bg.jpg
        ├── MAR20/
        └── rareplanes/

    Each item pairs an ``Original`` image (the generation target) with one
    randomly chosen crop (subject condition), the binarised mask (fill /
    position condition) and the background-erased image (background
    condition).
    """

    # Subdatasets scanned when ``subdatasets`` is not given.
    DEFAULT_SUBDATASETS: Tuple[str, ...] = ("FAIR1M2.0", "MAR20", "rareplanes")

    def __init__(
        self,
        dataset_root: str,
        subdatasets: Optional[List[str]] = None,
        condition_size: Tuple[int, int] = (512, 512),
        target_size: Tuple[int, int] = (512, 512),
        # Condition dropout probabilities
        drop_text_prob: float = 0.1,
        drop_subject_prob: float = 0.1,
        drop_position_prob: float = 0.1,
        drop_background_prob: float = 0.1,
        # Random fill-mask parameters (used by _generate_random_fill_mask)
        min_mask_ratio: float = 0.15,
        max_mask_ratio: float = 0.6,
        # Background parameters
        background_blur_prob: float = 0.3,
        # Data augmentation parameters
        augmentation_prob: float = 0.5,
        rotation_prob: float = 0.5,
        flip_prob: float = 0.5,
        color_jitter_prob: float = 0.5,
        brightness_range: Tuple[float, float] = (0.8, 1.2),
        contrast_range: Tuple[float, float] = (0.8, 1.2),
        saturation_range: Tuple[float, float] = (0.8, 1.2),
        hue_range: Tuple[float, float] = (-0.1, 0.1),
        return_pil_image: bool = False,
    ):
        """
        Args:
            dataset_root: Directory containing the subdataset folders.
            subdatasets: Subdataset folder names to scan; defaults to
                ``DEFAULT_SUBDATASETS``.
            condition_size: (width, height) of the subject/position/background
                condition images.
            target_size: (width, height) of the target image.
            drop_*_prob: Per-item probability of blanking each condition.
            min_mask_ratio / max_mask_ratio: Area range (fraction of the
                image) for ``_generate_random_fill_mask``.
            background_blur_prob: Probability of Gaussian-blurring the background.
            augmentation_prob: Probability of applying the shared geometric
                augmentation (rotation / flips) to an item.
            rotation_prob / flip_prob: Per-op probabilities inside that augmentation.
            color_jitter_prob: Probability of colour jitter (sampled separately
                for target+background and for the subject).
            *_range: Uniform sampling ranges for the jitter factors.
            return_pil_image: Also return the final PIL images under ``pil_images``.
        """
        self.dataset_root = Path(dataset_root)
        # Copy so the instance never shares the caller's (or a default) list.
        self.subdatasets = list(
            self.DEFAULT_SUBDATASETS if subdatasets is None else subdatasets
        )
        self.condition_size = condition_size
        self.target_size = target_size

        self.drop_text_prob = drop_text_prob
        self.drop_subject_prob = drop_subject_prob
        self.drop_position_prob = drop_position_prob
        self.drop_background_prob = drop_background_prob

        self.min_mask_ratio = min_mask_ratio
        self.max_mask_ratio = max_mask_ratio
        self.background_blur_prob = background_blur_prob

        self.augmentation_prob = augmentation_prob
        self.rotation_prob = rotation_prob
        self.flip_prob = flip_prob
        self.color_jitter_prob = color_jitter_prob
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range
        self.saturation_range = saturation_range
        self.hue_range = hue_range

        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

        self.samples = self._scan_dataset()

    # ------------------------------------------------------------------ #
    # Indexing
    # ------------------------------------------------------------------ #
    def _scan_dataset(self) -> List[dict]:
        """Scan every subdataset and build the sample index.

        A sample is kept only if its mask and background files exist, it has
        at least one crop, and its background is readable and not (nearly)
        black. Files are visited in sorted order so sample indices are the
        same in every process / on every machine.
        """
        samples = []

        for subdataset in self.subdatasets:
            subdataset_path = self.dataset_root / subdataset
            if not subdataset_path.exists():
                print(f"⚠️  Warning: {subdataset_path} does not exist, skipping...")
                continue

            original_dir = subdataset_path / "Original"
            if not original_dir.exists():
                print(f"Warning: {original_dir} does not exist, skipping {subdataset}")
                continue

            mask_dir = subdataset_path / "Masks"
            bg_dir = subdataset_path / "Background_Erased"
            # Index crops once per subdataset (exact base-name match) instead of
            # globbing the whole Crops directory for every original image.
            crops_by_base = self._index_crops(subdataset_path / "Crops")

            for orig_file in sorted(original_dir.glob("*_orig.jpg")):
                # Stem layout: {id}_{aircraft}_orig
                parts = orig_file.stem.rsplit("_", 1)
                if len(parts) != 2 or parts[1] != "orig":
                    continue
                base_name = parts[0]  # {id}_{aircraft}

                mask_file = mask_dir / f"{base_name}_mask.png"
                bg_file = bg_dir / f"{base_name}_bg.jpg"
                # Mask and background are mandatory.
                if not mask_file.exists() or not bg_file.exists():
                    continue

                crop_files = crops_by_base.get(base_name)
                if not crop_files:
                    continue

                # Samples with an empty mask are kept; black backgrounds are not.
                if not self._is_valid_background(bg_file):
                    continue

                samples.append({
                    "original": str(orig_file),
                    "mask": str(mask_file),
                    "background": str(bg_file),
                    "crops": list(crop_files),
                    "aircraft_type": self._extract_aircraft_type(base_name),
                    "image_id": base_name.split("_")[0],
                    "base_name": base_name,
                })

        return samples

    @staticmethod
    def _index_crops(crops_dir: Path) -> dict:
        """Map each base name to its (sorted) crop file paths.

        Crop stems follow ``{base_name}_{idx}_crop``; the base name is
        recovered by stripping the last two ``_``-separated fields, so a crop
        of ``A_B`` is never attributed to a shorter base name ``A``.
        """
        index: dict = {}
        if not crops_dir.is_dir():
            return index
        for crop_file in sorted(crops_dir.glob("*_crop.jpg")):
            parts = crop_file.stem.rsplit("_", 2)
            if len(parts) == 3:
                index.setdefault(parts[0], []).append(str(crop_file))
        return index

    @staticmethod
    def _is_valid_background(bg_file: Path) -> bool:
        """Return False if the background is unreadable or (nearly) all black."""
        try:
            with Image.open(bg_file) as img:
                bg_array = np.asarray(img.convert("RGB"))
            return not (bg_array.max() < 10 or bg_array.mean() < 15)
        except Exception:
            # Best-effort filter: any unreadable / corrupt file drops the sample.
            return False

    def _extract_aircraft_type(self, base_name: str) -> str:
        """
        Extract the aircraft type from a base name.

        Format 1 (FAIR1M2.0): {id}_{aircraft}
            e.g. "10001_A321" -> "A321"

        Format 2 (rareplanes): {prefix}_{id}_tile_{num}_{aircraft_type}_{category}
            e.g. "1_104005000FDC8D00_tile_28_Medium Civil Transport_Utility"
            -> "Medium Civil Transport Utility"
        """
        parts = base_name.split("_")

        if "tile" in parts:
            # Everything after "tile_{num}" is the type, joined with spaces.
            # Yields "" when nothing follows the tile number.
            tile_idx = parts.index("tile")
            return " ".join(parts[tile_idx + 2:])

        # FAIR1M2.0 or other simple formats: last field.
        return parts[-1]

    def __len__(self):
        return len(self.samples)

    # ------------------------------------------------------------------ #
    # Mask generation
    # ------------------------------------------------------------------ #
    def _generate_fill_mask_from_real_mask(
        self,
        target_image: Image.Image,
        real_mask: Image.Image
    ) -> Tuple[Image.Image, Image.Image]:
        """
        Use the real mask, binarised at 128, as the position condition.

        Returns:
            (position_condition, fill_mask): the mask as RGB and as "L".
            Both are all black when the mask has no pixel >= 128.
        """
        mask_array = np.array(real_mask.convert("L"))

        if mask_array.max() < 128:
            fill_mask = Image.new("L", target_image.size, 0)
            return fill_mask.convert("RGB"), fill_mask

        # THRESH_BINARY: pixels > 128 -> 255, others -> 0.
        _, binary_mask = cv2.threshold(mask_array, 128, 255, cv2.THRESH_BINARY)
        fill_mask = Image.fromarray(binary_mask).convert("L")

        return fill_mask.convert("RGB"), fill_mask

    def _generate_random_fill_mask(
        self,
        target_image: Image.Image
    ) -> Tuple[Image.Image, Image.Image]:
        """Generate a random rectangular fill mask (fallback strategy).

        The rectangle's area is drawn uniformly from
        [min_mask_ratio, max_mask_ratio] of the image area, with aspect ratio
        in [0.5, 2.0], clamped to the image.

        Returns:
            (position_condition, fill_mask): the target with the rectangle
            blacked out, and the "L" mask (255 inside the rectangle).
        """
        w, h = target_image.size

        target_area = w * h
        mask_area = random.uniform(
            self.min_mask_ratio * target_area,
            self.max_mask_ratio * target_area
        )

        aspect_ratio = random.uniform(0.5, 2.0)
        # max(1, ...) guards against a zero height (division by zero below).
        mask_h = max(1, int(np.sqrt(mask_area / aspect_ratio)))
        mask_w = max(1, int(mask_area / mask_h))

        mask_h = min(mask_h, h)
        mask_w = min(mask_w, w)

        x1 = random.randint(0, w - mask_w)
        y1 = random.randint(0, h - mask_h)
        # PIL rectangle corners are inclusive: subtract 1 so the mask is exactly
        # mask_w x mask_h pixels and never extends past the image edge.
        x2 = x1 + mask_w - 1
        y2 = y1 + mask_h - 1

        fill_mask = Image.new("L", target_image.size, 0)
        ImageDraw.Draw(fill_mask).rectangle([x1, y1, x2, y2], fill=255)

        position_condition = Image.composite(
            Image.new("RGB", target_image.size, (0, 0, 0)),
            target_image,
            fill_mask
        )

        return position_condition, fill_mask

    # ------------------------------------------------------------------ #
    # Augmentation
    # ------------------------------------------------------------------ #
    def _apply_color_jitter(self, image: Image.Image,
                            brightness_factor: Optional[float] = None,
                            contrast_factor: Optional[float] = None,
                            saturation_factor: Optional[float] = None,
                            hue_shift: Optional[float] = None) -> Image.Image:
        """Apply brightness / contrast / saturation / hue changes.

        Each adjustment is skipped when its factor is None. ``hue_shift`` is a
        fraction of the full hue circle and is applied in PIL's 8-bit HSV space.
        Expects an RGB image.
        """
        from PIL import ImageEnhance

        if brightness_factor is not None:
            image = ImageEnhance.Brightness(image).enhance(brightness_factor)

        if contrast_factor is not None:
            image = ImageEnhance.Contrast(image).enhance(contrast_factor)

        if saturation_factor is not None:
            image = ImageEnhance.Color(image).enhance(saturation_factor)

        if hue_shift is not None and abs(hue_shift) > 1e-6:
            h, s, v = image.convert('HSV').split()
            h_array = np.array(h, dtype=np.float32)
            # The hue channel is cyclic over 0-255.
            h_array = (h_array + hue_shift * 255) % 256
            h = Image.fromarray(h_array.astype(np.uint8))
            image = Image.merge('HSV', (h, s, v)).convert('RGB')

        return image

    def _sample_jitter_factors(self) -> Tuple[float, float, float, float]:
        """Draw (brightness, contrast, saturation, hue) factors from their ranges."""
        brightness_factor = random.uniform(*self.brightness_range)
        contrast_factor = random.uniform(*self.contrast_range)
        saturation_factor = random.uniform(*self.saturation_range)
        hue_shift = random.uniform(*self.hue_range)
        return brightness_factor, contrast_factor, saturation_factor, hue_shift

    def _random_geometric(self, images: List[Image.Image]) -> List[Image.Image]:
        """Apply one random rotation / flip sequence identically to all images.

        Rotation is by 90/180/270 degrees with ``expand=False``. For square
        images this is lossless. For non-square images a 90/270 rotation keeps
        the original size and clips the corners.
        """
        if random.random() < self.rotation_prob:
            angle = random.choice([90, 180, 270])
            images = [im.rotate(angle, expand=False) for im in images]

        if random.random() < self.flip_prob:
            images = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in images]

        if random.random() < self.flip_prob:
            images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in images]

        return images

    def _apply_sync_transforms(self, target, background, mask, subject):
        """Apply the same random rotation / flips to target, background, mask and subject."""
        target, background, mask, subject = self._random_geometric(
            [target, background, mask, subject]
        )
        return target, background, mask, subject

    def _apply_subject_transforms(self, subject):
        """Apply an independent random rotation / flips to the subject only."""
        return self._random_geometric([subject])[0]

    # ------------------------------------------------------------------ #
    # Item assembly
    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_image(path: str, mode: str) -> Image.Image:
        """Open ``path``, convert it to ``mode`` and close the file handle."""
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # 1. Load images; one crop is chosen at random as the subject.
        target_image = self._load_image(sample["original"], "RGB")
        real_mask = self._load_image(sample["mask"], "L")
        background_image = self._load_image(sample["background"], "RGB")
        subject_image = self._load_image(random.choice(sample["crops"]), "RGB")

        # 2. Resize.
        target_image = target_image.resize(self.target_size)
        background_image = background_image.resize(self.condition_size)
        subject_image = subject_image.resize(self.condition_size)
        real_mask = real_mask.resize(self.target_size)

        # 3. Shared geometric augmentation (target, background, mask, subject).
        if random.random() < self.augmentation_prob:
            target_image, background_image, real_mask, subject_image = self._apply_sync_transforms(
                target_image, background_image, real_mask, subject_image
            )

        # 4. Position condition from the real mask. NEAREST keeps it binary
        #    when condition_size differs from target_size.
        position_mask, _ = self._generate_fill_mask_from_real_mask(target_image, real_mask)
        position_mask = position_mask.resize(self.condition_size, Image.NEAREST)

        # 5. Colour jitter: target and background share the same factors...
        if random.random() < self.color_jitter_prob:
            factors = self._sample_jitter_factors()
            target_image = self._apply_color_jitter(target_image, *factors)
            background_image = self._apply_color_jitter(background_image, *factors)

        # ...while the subject gets independent factors.
        if random.random() < self.color_jitter_prob:
            subject_image = self._apply_color_jitter(
                subject_image, *self._sample_jitter_factors()
            )

        # 6. Optional background blur.
        if random.random() < self.background_blur_prob:
            from PIL import ImageFilter
            blur_radius = random.uniform(1, 5)
            background_image = background_image.filter(
                ImageFilter.GaussianBlur(radius=blur_radius)
            )

        # 7. Fixed prompt, so training focuses on the image conditions.
        description = "Place an aircraft at the specified position"

        # 8. Condition dropout (all four flags are sampled every time).
        drop_text = random.random() < self.drop_text_prob
        drop_subject = random.random() < self.drop_subject_prob
        drop_position = random.random() < self.drop_position_prob
        drop_background = random.random() < self.drop_background_prob

        if drop_text:
            description = ""
        if drop_subject:
            subject_image = Image.new("RGB", self.condition_size, (0, 0, 0))
        if drop_position:
            position_mask = Image.new("RGB", self.condition_size, (0, 0, 0))
        if drop_background:
            background_image = Image.new("RGB", self.condition_size, (0, 0, 0))

        # 9. Per-condition position deltas. The original note mentions a VAE
        #    downsampling factor of 16. These are presumably offsets in latent
        #    token units; confirm against the pipeline's Condition handling.
        position_delta_subject = np.array([-16, -32])
        position_delta_position = np.array([0, 0])
        position_delta_background = np.array([16, -32])

        # 10. Convert to tensors.
        result = {
            "image": self.to_tensor(target_image),
            "condition_0": self.to_tensor(subject_image),
            "condition_type_0": "subject",
            "position_delta_0": position_delta_subject,
            "condition_1": self.to_tensor(position_mask),
            "condition_type_1": "fill",
            "position_delta_1": position_delta_position,
            "condition_2": self.to_tensor(background_image),
            "condition_type_2": "background",
            "position_delta_2": position_delta_background,
            "description": description,
        }

        if self.return_pil_image:
            result["pil_images"] = {
                "target": target_image,
                "subject": subject_image,
                "position": position_mask,
                "background": background_image,
            }

        return result


def test_function(
    dataset_root: str = "/data2/aaa/datasets/targetgen/",
    output_dir: str = "test_outputs_aircraft",
    num_samples: int = 3,
):
    """Smoke-test dataset loading by saving visualisations of a few samples.

    Builds an ``AircraftTripleConditionDataset`` on the FAIR1M2.0 subset with
    all condition dropout disabled. For each of the first ``num_samples``
    items, it writes a 4-panel figure (target / subject / position /
    background) to ``output_dir/sample_{i}.jpg``.

    Args:
        dataset_root: Root folder containing the subdatasets.
        output_dir: Directory the figures are written to (created if missing).
        num_samples: Maximum number of samples to visualise.
    """
    import matplotlib.pyplot as plt

    print("=" * 70)
    print("Testing Aircraft Triple Condition Dataset")
    print("=" * 70)

    dataset = AircraftTripleConditionDataset(
        dataset_root=dataset_root,
        subdatasets=["FAIR1M2.0"],
        condition_size=(512, 512),
        target_size=(512, 512),
        drop_text_prob=0.0,
        drop_subject_prob=0.0,
        drop_position_prob=0.0,
        drop_background_prob=0.0,
        return_pil_image=True,
    )

    print(f"\nDataset size: {len(dataset)}")

    os.makedirs(output_dir, exist_ok=True)
    # (pil_images key, panel title); the target title is filled in per sample.
    panels = (
        ("target", None),
        ("subject", "Subject Reference"),
        ("position", "Position Mask"),
        ("background", "Background"),
    )

    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        pil_images = sample["pil_images"]

        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        try:
            for ax, (key, title) in zip(axes, panels):
                ax.imshow(pil_images[key])
                ax.set_title(title if title is not None else f"Target\n{sample['description']}")
                ax.axis("off")

            plt.tight_layout()
            out_path = os.path.join(output_dir, f"sample_{i}.jpg")
            plt.savefig(out_path, dpi=150)
            print(f"✓ Saved visualization: {out_path}")
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

    print("\n" + "=" * 70)
    print("✓ Test completed!")
    print("=" * 70)


@torch.no_grad()
def test_function_for_training(model, save_path, file_name):
    """Render sample generations during training for visual inspection.

    Builds one synthetic case and saves the result as
    ``{save_path}/{file_name}_aircraft_{i}.jpg``. The case has three
    conditions:

    - subject: flat grey, ``condition_size``
    - position: white centre rectangle on black, ``target_size``
    - background: flat light blue, ``target_size``

    It is generated with the fixed prompt and seed 42.

    Args:
        model: Trainable model exposing ``training_config``, ``adapter_names``,
            ``device``, ``flux_pipe`` and ``model_config``.
        save_path: Output directory (created if missing).
        file_name: Prefix for the saved image files.
    """
    dataset_cfg = model.training_config["dataset"]
    condition_size = dataset_cfg["condition_size"]
    target_size = dataset_cfg["target_size"]

    # Adapter slots 2 / 3 / 4 hold the subject, fill and background adapters.
    subject_adapter = model.adapter_names[2]
    fill_adapter = model.adapter_names[3]
    background_adapter = model.adapter_names[4]

    from ..pipeline.flux_omini import Condition, generate

    # Synthetic flat-colour inputs.
    subject_img = Image.new("RGB", condition_size, (128, 128, 128))
    background_img = Image.new("RGB", target_size, (200, 220, 255))

    # Position mask: white over the central half of the canvas, black elsewhere.
    width, height = target_size
    fill_mask_img = Image.new("RGB", target_size, (0, 0, 0))
    ImageDraw.Draw(fill_mask_img).rectangle(
        [width // 4, height // 4, width * 3 // 4, height * 3 // 4],
        fill=(255, 255, 255),
    )

    cases = [
        (
            [
                Condition(subject_img, subject_adapter, [-16, -32]),
                Condition(fill_mask_img, fill_adapter, [0, 0]),
                Condition(background_img, background_adapter, [16, -32]),
            ],
            "Place an aircraft at the specified position",
        ),
    ]

    os.makedirs(save_path, exist_ok=True)
    for case_idx, (case_conditions, case_prompt) in enumerate(cases):
        seeded = torch.Generator(device=model.device)
        seeded.manual_seed(42)

        output = generate(
            model.flux_pipe,
            prompt=case_prompt,
            conditions=case_conditions,
            height=target_size[1],
            width=target_size[0],
            generator=seeded,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        output.images[0].save(
            os.path.join(save_path, f"{file_name}_aircraft_{case_idx}.jpg")
        )


def main():
    """Training entry point for aircraft triple-condition training.

    Steps:

    1. Read the config via ``trainer.get_config()``.
    2. Bind the process to the CUDA device given by ``LOCAL_RANK`` (default 0).
    3. Build ``AircraftTripleConditionDataset`` from ``config["train"]["dataset"]``.
    4. Create an ``OminiModel`` with adapter slots
       ``[None, None, "subject", "fill", "background"]``.
    5. Run ``trainer.train`` with ``test_function_for_training`` as the
       sampling callback.
    """
    from .trainer import get_config, OminiModel, train

    # Configuration
    config = get_config()
    training_config = config["train"]
    # LOCAL_RANK is presumably set by a distributed launcher (e.g. torchrun).
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    dataset_config = training_config["dataset"]

    print("=" * 80)
    print("Aircraft Triple Condition Training")
    print("=" * 80)

    # Dataset
    print("\n[1/3] Loading dataset...")
    dataset = AircraftTripleConditionDataset(
        dataset_root=dataset_config["dataset_root"],
        subdatasets=dataset_config.get("subdatasets", ["FAIR1M2.0"]),
        condition_size=tuple(dataset_config["condition_size"]),
        target_size=tuple(dataset_config["target_size"]),
        drop_text_prob=dataset_config.get("drop_text_prob", 0.1),
        drop_subject_prob=dataset_config.get("drop_subject_prob", 0.1),
        drop_position_prob=dataset_config.get("drop_position_prob", 0.1),
        drop_background_prob=dataset_config.get("drop_background_prob", 0.1),
        min_mask_ratio=dataset_config.get("min_mask_ratio", 0.15),
        max_mask_ratio=dataset_config.get("max_mask_ratio", 0.6),
        background_blur_prob=dataset_config.get("background_blur_prob", 0.3),
        # Data augmentation
        augmentation_prob=dataset_config.get("augmentation_prob", 0.5),
        rotation_prob=dataset_config.get("rotation_prob", 0.5),
        flip_prob=dataset_config.get("flip_prob", 0.5),
        color_jitter_prob=dataset_config.get("color_jitter_prob", 0.5),
        brightness_range=tuple(dataset_config.get("brightness_range", [0.8, 1.2])),
        contrast_range=tuple(dataset_config.get("contrast_range", [0.8, 1.2])),
        saturation_range=tuple(dataset_config.get("saturation_range", [0.8, 1.2])),
        hue_range=tuple(dataset_config.get("hue_range", [-0.1, 0.1])),
    )
    print(f"  ✓ Dataset created: {len(dataset)} samples")
    print("    - Subject: aircraft crops")
    print("    - Position: aircraft mask")
    print("    - Background: background image")
    print("    - Prompt: Place an aircraft at the specified position")

    # Model
    print("\n[2/3] Initializing model...")
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
        adapter_names=[None, None, "subject", "fill", "background"],
    )
    print("  ✓ Model initialized")

    # Training
    print("\n[3/3] Starting training...")
    train(dataset, trainable_model, config, test_function_for_training)

    print("\n" + "=" * 80)
    print("✓ Training completed!")
    print("=" * 80)


if __name__ == "__main__":
    # Run training when executed as a script.
    main()
