"""
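Single-folder aircraft dataset for three-condition training (mask-weighted loss variant).

Directory layout, as read by AircraftMaskWeightedDataset._scan_dataset:
dataset_root/
├── crop_info.csv       # optional: "output_name,category" rows
├── Original/           # original (target) images
│   ├── <name>.png
│   └── ...
├── Masks2/             # object masks (the Masks2 folder is used, not Masks)
│   ├── <name>.png
│   └── ...
├── Background_Erased/  # background images (aircraft erased)
│   ├── <name>.png
│   └── ...
└── Crops_Blurred/      # aircraft crops, one per sample
    ├── <name>.png
    └── ...

All four folders share the same file name for a sample, e.g.
train_{scene_id}_{frame_id}_obj{obj_id}.png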
"""

import os
import random
import numpy as np
import csv
from pathlib import Path
from typing import Tuple, List, Dict
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as T


class AircraftMaskWeightedDataset(torch.utils.data.Dataset):
    """
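    Single-folder aircraft dataset (mask-weighted loss variant).
    Supports three control conditions: Subject (aircraft crop), Position (mask), Background.
    Also returns ``target_mask`` so that a mask-weighted loss can be computed.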
    """
    
    def __init__(
        self,
        dataset_root: str,
        condition_size: Tuple[int, int] = (512, 512),
        target_size: Tuple[int, int] = (512, 512),
        # Dropout 概率
        drop_text_prob: float = 0.1,
        drop_subject_prob: float = 0.1,
        drop_position_prob: float = 0.1,
        drop_background_prob: float = 0.1,
        # Mask 生成参数
        min_mask_ratio: float = 0.15,
        max_mask_ratio: float = 0.6,
        # Background 生成参数
        background_blur_prob: float = 0.3,
        # 数据增强参数
        augmentation_prob: float = 0.5,
        rotation_prob: float = 0.5,
        flip_prob: float = 0.5,
        color_jitter_prob: float = 0.5,
        brightness_range: Tuple[float, float] = (0.7, 1.3),
        contrast_range: Tuple[float, float] = (0.75, 1.25),
        saturation_range: Tuple[float, float] = (0.8, 1.2),
        hue_range: Tuple[float, float] = (-0.05, 0.05),
        return_pil_image: bool = False,
    ):
        self.dataset_root = Path(dataset_root)
        self.condition_size = condition_size
        self.target_size = target_size
        
        # Dropout 概率
        self.drop_text_prob = drop_text_prob
        self.drop_subject_prob = drop_subject_prob
        self.drop_position_prob = drop_position_prob
        self.drop_background_prob = drop_background_prob
        
        # Mask 生成参数
        self.min_mask_ratio = min_mask_ratio
        self.max_mask_ratio = max_mask_ratio
        
        # Background 参数
        self.background_blur_prob = background_blur_prob
        
        # 数据增强参数
        self.augmentation_prob = augmentation_prob
        self.rotation_prob = rotation_prob
        self.flip_prob = flip_prob
        self.color_jitter_prob = color_jitter_prob
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range
        self.saturation_range = saturation_range
        self.hue_range = hue_range
        
        self.return_pil_image = return_pil_image
        self.to_tensor = T.ToTensor()
        
        # 加载类别映射（从 crop_info.csv）
        self.category_map = self._load_category_map()
        
        # 扫描数据集
        self.samples = self._scan_dataset()
    
    def _load_category_map(self) -> Dict[str, str]:
        """
        从 crop_info.csv 加载文件名到类别的映射
        
        CSV 格式:
        output_name,category
        train_10013_0001_obj000,plane
        train_10020_0000_obj000,plane
        ...
        """
        category_map = {}
        csv_path = self.dataset_root / "crop_info.csv"
        
        if not csv_path.exists():
            print(f"⚠️  Warning: {csv_path} not found, using default category")
            return category_map
        
        try:
            with open(csv_path, 'r', encoding='utf-8') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    filename = row['output_name']  # 不含扩展名
                    category = row['category']
                    category_map[filename] = category
            
            print(f"✓ Loaded {len(category_map)} category mappings from crop_info.csv")
            
            # 统计类别分布
            category_counts = {}
            for cat in category_map.values():
                category_counts[cat] = category_counts.get(cat, 0) + 1
            
            print(f"  Top 5 categories:")
            for cat, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
                print(f"    - {cat}: {count}")
        
        except Exception as e:
            print(f"⚠️  Warning: Failed to load crop_info.csv: {e}")
        
        return category_map
    
    def _scan_dataset(self) -> List[dict]:
        """
        扫描单文件夹数据集，建立索引
        注意：使用Masks2文件夹而不是Masks
        
        文件名格式: train_{scene_id}_{frame_id}_obj{obj_id}.png
        所有文件夹（Original, Masks2, Background_Erased, Crops）使用相同的文件名
        """
        samples = []
        
        # Determine roots to scan
        roots_to_scan = getattr(self, 'dataset_roots', [self.dataset_root])
        
        for root in roots_to_scan:
            # 直接扫描 Original 文件夹
            original_dir = root / "Original"
            if not original_dir.exists():
                print(f"❌ Error: {original_dir} does not exist!")
                continue
            
            # 遍历所有 .png 文件
            for orig_file in original_dir.glob("*.png"):
                # 文件名（不含扩展名）
                base_name = orig_file.stem  # 例如: train_10013_0001_obj000
                
                # 构造对应的文件路径（使用相同的文件名，但mask从Masks2读取）
                mask_file = root / "Masks2" / f"{base_name}.png"
                bg_file = root / "Background_Erased" / f"{base_name}.png"
                crop_file = root / "Crops_Blurred" / f"{base_name}.png"
                # crop_file = root / "Crops" / f"{base_name}.png"
                # 必须所有文件都存在
                if not mask_file.exists() or not bg_file.exists() or not crop_file.exists():
                    continue
                
                # 验证 background 是否有效（不能全黑）
                try:
                    bg_img = Image.open(bg_file).convert("RGB")
                    bg_array = np.array(bg_img)
                    if bg_array.max() < 10 or bg_array.mean() < 15:
                        # background 全黑或过暗，跳过
                        continue
                except Exception:
                    continue
                
                # 解析场景和对象信息
                # 格式: train_{scene_id}_{frame_id}_obj{obj_id}
                parts = base_name.split("_")
                if len(parts) >= 4:
                    scene_id = parts[1]
                    object_id = parts[-1]  # obj000, obj001, ...
                else:
                    scene_id = "unknown"
                    object_id = "unknown"
                
                # 查找类别
                category = self.category_map.get(base_name, "object")
                
                # 添加到样本列表（注意：每个样本只有一个 crop）
                samples.append({
                    "original": str(orig_file),
                    "mask": str(mask_file),
                    "background": str(bg_file),
                    "crop": str(crop_file),  # 单个 crop，不是列表
                    "scene_id": scene_id,
                    "object_id": object_id,
                    "category": category,  # 添加类别信息
                    "root": str(root)
                })
        
        print(f"✓ Scanned {len(samples)} samples from {len(roots_to_scan)} roots")
        return samples
    
    def __len__(self):
        return len(self.samples)
    def _load_rgba_with_white_background(self, image_path: str) -> Image.Image:
        """
        加载 RGBA 图像，使用 alpha 通道去除背景，将透明区域替换为白色
        
        Args:
            image_path: 图像路径
        
        Returns:
            RGB 图像，透明区域为白色背景
        """
        # 加载图像
        img = Image.open(image_path)
        
        # 如果已经是 RGB，直接返回
        if img.mode == 'RGB':
            return img
        
        # 如果是 RGBA，使用 alpha 通道处理背景
        if img.mode == 'RGBA':
            # 创建白色背景
            background = Image.new('RGB', img.size, (255, 255, 255))
            # background = Image.new('RGB', img.size, (0, 0, 0))
            # 使用 alpha 通道作为 mask 合成
            background.paste(img, mask=img.split()[3])  # 第4个通道是 alpha
            return background
        
        # 其他模式，转换为 RGB
        return img.convert('RGB')
    def _generate_fill_mask_from_real_mask(
        self, target_image: Image.Image, real_mask: Image.Image
    ) -> Tuple[Image.Image, Image.Image]:
        """
        基于真实 mask 生成 position mask
        直接使用真实 mask 的二值化版本
        """
        # 转换为灰度
        mask_array = np.array(real_mask)
        
        # 检查是否为空 mask
        if mask_array.max() < 128:
            # 空 mask，返回全黑
            fill_mask = Image.new("L", target_image.size, 0)
            position_condition = fill_mask.convert("RGB")
            return position_condition, fill_mask
        
        # 二值化
        binary_mask = (mask_array > 128).astype(np.uint8) * 255
        fill_mask = Image.fromarray(binary_mask, mode="L")
        
        # Position condition: 白色区域表示目标位置
        position_condition = fill_mask.convert("RGB")
        
        return position_condition, fill_mask
    
    def _apply_random_blur(self, image, kernel_size):
        """应用随机高斯模糊"""
        import cv2
        image_array = np.array(image)
        blurred = cv2.GaussianBlur(image_array, (kernel_size, kernel_size), 0)
        return Image.fromarray(blurred)
    
    def _apply_color_jitter(self, image, brightness_factor, contrast_factor, 
                           saturation_factor, hue_shift):
        """应用颜色抖动"""
        # 亮度调整
        if brightness_factor != 1.0:
            image_array = np.array(image, dtype=np.float32)
            image_array = image_array * brightness_factor
            image_array = np.clip(image_array, 0, 255)
            image = Image.fromarray(image_array.astype(np.uint8))
        
        # 对比度调整
        if contrast_factor != 1.0:
            image_array = np.array(image, dtype=np.float32)
            mean = image_array.mean()
            image_array = (image_array - mean) * contrast_factor + mean
            image_array = np.clip(image_array, 0, 255)
            image = Image.fromarray(image_array.astype(np.uint8))
        
        # 饱和度调整
        if saturation_factor != 1.0:
            hsv_image = image.convert('HSV')
            h, s, v = hsv_image.split()
            s_array = np.array(s, dtype=np.float32)
            s_array = s_array * saturation_factor
            s_array = np.clip(s_array, 0, 255)
            s = Image.fromarray(s_array.astype(np.uint8), mode='L')
            hsv_image = Image.merge('HSV', (h, s, v))
            image = hsv_image.convert('RGB')
        
        # 色调调整
        if hue_shift != 0:
            hsv_image = image.convert('HSV')
            h, s, v = hsv_image.split()
            h_array = np.array(h, dtype=np.float32)
            h_array = (h_array + hue_shift * 255) % 256
            h = Image.fromarray(h_array.astype(np.uint8), mode='L')
            hsv_image = Image.merge('HSV', (h, s, v))
            image = hsv_image.convert('RGB')
        
        return image
    
    def _apply_sync_transforms(self, target, background, mask, subject):
        """
        对 target, background, mask, subject 应用同步的随机变换（一起旋转）
        """
        # 随机旋转 (90, 180, 270度)
        if random.random() < self.rotation_prob:
            angle = random.choice([90, 180, 270])
            target = target.rotate(angle, expand=False)
            background = background.rotate(angle, expand=False)
            mask = mask.rotate(angle, expand=False)
            subject = subject.rotate(angle, expand=False)
        
        # 随机水平翻转
        if random.random() < self.flip_prob:
            target = target.transpose(Image.FLIP_LEFT_RIGHT)
            background = background.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
            subject = subject.transpose(Image.FLIP_LEFT_RIGHT)
        
        # 随机垂直翻转
        if random.random() < self.flip_prob:
            target = target.transpose(Image.FLIP_TOP_BOTTOM)
            background = background.transpose(Image.FLIP_TOP_BOTTOM)
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
            subject = subject.transpose(Image.FLIP_TOP_BOTTOM)
        
        return target, background, mask, subject
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # 1. 加载图像
        target_image = Image.open(sample["original"]).convert("RGB")
        real_mask = Image.open(sample["mask"]).convert("L")
        background_image = Image.open(sample["background"]).convert("RGB")
        
        # 加载 crop 作为 subject（每个样本只有一个 crop）
        subject_image = self._load_rgba_with_white_background(sample["crop"])

        
        # 2. Resize
        target_image = target_image.resize(self.target_size)
        background_image = background_image.resize(self.condition_size)
        subject_image = subject_image.resize(self.condition_size)
        real_mask = real_mask.resize(self.target_size)
        
        # 3. 数据增强 - 同步变换 (所有图像一起旋转)
        if random.random() < self.augmentation_prob:
            target_image, background_image, real_mask, subject_image = self._apply_sync_transforms(
                target_image, background_image, real_mask, subject_image
            )
        
        # 4. 生成 position mask
        position_mask, fill_mask = self._generate_fill_mask_from_real_mask(
            target_image, real_mask
        )
        position_mask = position_mask.resize(self.condition_size)
        
        # 5. 数据增强 - 颜色抖动
        if random.random() < self.color_jitter_prob:
            # target 和 background 使用相同的颜色抖动参数
            brightness_factor_tb = random.uniform(*self.brightness_range)
            contrast_factor_tb = random.uniform(*self.contrast_range)
            saturation_factor_tb = random.uniform(*self.saturation_range)
            hue_shift_tb = random.uniform(*self.hue_range)
            
            target_image = self._apply_color_jitter(
                target_image, brightness_factor_tb, contrast_factor_tb, 
                saturation_factor_tb, hue_shift_tb
            )
            background_image = self._apply_color_jitter(
                background_image, brightness_factor_tb, contrast_factor_tb, 
                saturation_factor_tb, hue_shift_tb
            )
            # 暂时sunject别抖动了
            # # subject 使用独立的颜色抖动参数
            # brightness_factor_s = random.uniform(*self.brightness_range)
            # contrast_factor_s = random.uniform(*self.contrast_range)
            # saturation_factor_s = random.uniform(*self.saturation_range)
            # hue_shift_s = random.uniform(*self.hue_range)
            
            # subject_image = self._apply_color_jitter(
            #     subject_image, brightness_factor_s, contrast_factor_s, 
            #     saturation_factor_s, hue_shift_s
            # )
        
        # 6. 随机模糊 - target 和 background 同步模糊
        blur_prob = 0.3  # 模糊概率
        if random.random() < blur_prob:
            # 随机选择模糊核大小 (必须是奇数)
            kernel_size = random.choice([3, 5, 7])
            target_image = self._apply_random_blur(target_image, kernel_size)
            background_image = self._apply_random_blur(background_image, kernel_size)
        
        # 7. 生成文本描述（使用类别信息）
        category = sample.get("category", "object")
        # 格式化类别名称（将连字符替换为空格，首字母小写）
        category_name = category.replace("-", " ").lower()
        description = f"Place a {category_name} at the specified position"
        
        # 8. Dropout
        if random.random() < self.drop_text_prob:
            description = ""
        
        if random.random() < self.drop_subject_prob:
            subject_image = Image.new("RGB", self.condition_size, (128, 128, 128))
        
        if random.random() < self.drop_position_prob:
            position_mask = Image.new("RGB", self.condition_size, (128, 128, 128))
        
        if random.random() < self.drop_background_prob:
            background_image = Image.new("RGB", self.condition_size, (128, 128, 128))
        
        # 9. 转换为 tensor
        # fill_mask 总是需要转换为 tensor（用于 loss 计算）
        fill_mask_tensor = self.to_tensor(fill_mask)
        
        if not self.return_pil_image:
            target_image = self.to_tensor(target_image)
            subject_image = self.to_tensor(subject_image)
            position_mask = self.to_tensor(position_mask)
            background_image = self.to_tensor(background_image)
        
        # Position delta (在潜在空间中的偏移)
        position_delta_subject = np.array([-16, -32])
        position_delta_position = np.array([0, 0])
        position_delta_background = np.array([16, -32])
        
        # 返回结果（添加target_mask用于mask加权loss）
        result = {
            "image": target_image,
            "condition_0": subject_image,
            "condition_1": position_mask,
            "condition_2": background_image,
            "condition_type_0": "subject",
            "condition_type_1": "fill",
            "condition_type_2": "background",
            "position_delta_0": position_delta_subject,
            "position_delta_1": position_delta_position,
            "position_delta_2": position_delta_background,
            "description": description,
            "target_mask": fill_mask_tensor,  # 添加target_mask
        }
        
        return result


@torch.no_grad()
def test_function_for_training(model, save_path, file_name):
    """Generate a fixed preview image during training.

    Builds one synthetic sample (gray subject, light-blue background and a
    centered rectangular fill mask), runs the generation pipeline with a fixed
    seed and writes the result to ``save_path`` as
    ``{file_name}_aircraft_{i}.jpg``.
    """
    dataset_cfg = model.training_config["dataset"]
    condition_size = dataset_cfg["condition_size"]
    target_size = dataset_cfg["target_size"]

    # Adapter slots 2, 3 and 4: subject, fill and background
    adapters = model.adapter_names
    subject_adapter, fill_adapter, background_adapter = adapters[2], adapters[3], adapters[4]

    from ..pipeline.flux_omini import Condition, generate

    # Placeholder subject and background
    gray_subject = Image.new("RGB", condition_size, (128, 128, 128))
    sky_background = Image.new("RGB", target_size, (200, 220, 255))

    # Fill mask: white rectangle over the central region, black elsewhere
    width, height = target_size
    white = Image.new("RGB", target_size, (255, 255, 255))
    region = Image.new("L", target_size, 0)
    ImageDraw.Draw(region).rectangle(
        [width // 4, height // 4, width * 3 // 4, height * 3 // 4], fill=255
    )
    black = Image.new("RGB", target_size, (0, 0, 0))
    fill_mask_img = Image.composite(white, black, region)

    # One test case made of the three conditions plus its prompt
    test_list = [
        (
            [
                Condition(gray_subject, subject_adapter, [-16, -32]),
                Condition(fill_mask_img, fill_adapter, [0, 0]),
                Condition(sky_background, background_adapter, [16, -32]),
            ],
            "Place an aircraft at the specified position",
        )
    ]

    os.makedirs(save_path, exist_ok=True)
    for index, (conditions, prompt) in enumerate(test_list):
        # Fixed seed so previews from different checkpoints are comparable
        generator = torch.Generator(device=model.device).manual_seed(42)

        result = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=conditions,
            height=height,
            width=width,
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        out_path = os.path.join(save_path, f"{file_name}_aircraft_{index}.jpg")
        result.images[0].save(out_path)


def main():
    """Training entry point: build the dataset, create the model, start training.

    Configuration comes from ``get_config()``; dataset options are read from
    ``config["train"]["dataset"]``. Returns early, after printing
    troubleshooting hints, when the dataset scan finds no samples.
    """
    from .trainer_mask_weighted import get_config, OminiModel, train
    from torch.utils.data import DataLoader

    # Load configuration
    config = get_config()
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    print("=" * 70)
    print("Aircraft Mask Weighted Training (Using Masks2)")
    print("=" * 70)

    print("\n[1/3] Loading dataset...")
    # NOTE(review): the fallback ranges below differ from the defaults in the
    # dataset class signature (e.g. brightness 0.7-1.3 there); confirm which
    # set is intended.
    dataset = AircraftMaskWeightedDataset(
        dataset_root=dataset_config["dataset_root"],
        condition_size=tuple(dataset_config["condition_size"]),
        target_size=tuple(dataset_config["target_size"]),
        drop_text_prob=dataset_config.get("drop_text_prob", 0.1),
        drop_subject_prob=dataset_config.get("drop_subject_prob", 0.1),
        drop_position_prob=dataset_config.get("drop_position_prob", 0.1),
        drop_background_prob=dataset_config.get("drop_background_prob", 0.1),
        min_mask_ratio=dataset_config.get("min_mask_ratio", 0.15),
        max_mask_ratio=dataset_config.get("max_mask_ratio", 0.6),
        background_blur_prob=dataset_config.get("background_blur_prob", 0.3),
        # Augmentation settings
        augmentation_prob=dataset_config.get("augmentation_prob", 0.5),
        rotation_prob=dataset_config.get("rotation_prob", 0.5),
        flip_prob=dataset_config.get("flip_prob", 0.5),
        color_jitter_prob=dataset_config.get("color_jitter_prob", 0.5),
        brightness_range=tuple(dataset_config.get("brightness_range", [0.8, 1.2])),
        contrast_range=tuple(dataset_config.get("contrast_range", [0.8, 1.2])),
        saturation_range=tuple(dataset_config.get("saturation_range", [0.8, 1.2])),
        hue_range=tuple(dataset_config.get("hue_range", [-0.1, 0.1])),
    )
    print(f"  ✓ Dataset created: {len(dataset)} samples")
    print("    - Subject: aircraft crops")
    print("    - Position: aircraft mask (from Masks2)")
    print("    - Background: background image")
    # The dataset builds the prompt per sample from its category.
    print("    - Prompt: Place a <category> at the specified position")
    print("    - Loss weighting: Mask-weighted MSE loss")

    if len(dataset) == 0:
        # List the folders the dataset scan actually requires.
        print("\n❌ Error: Dataset is empty!")
        print("Please check:")
        print(f"  1. Dataset root: {dataset_config['dataset_root']}")
        print("  2. Folder structure (same <name>.png in every folder):")
        print("     - Original/")
        print("     - Masks2/  (注意：使用Masks2而不是Masks)")
        print("     - Background_Erased/")
        print("     - Crops_Blurred/")
        print("  3. Background images are readable and not black / too dark")
        return

    print("\n[2/3] Initializing model...")
    model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        adapter_names=[None, None, "subject", "fill", "background"],
        optimizer_config=training_config.get("optimizer", None),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )
    # The preview function reads the dataset sizes from model.training_config.
    model.training_config = training_config
    print("  ✓ Model initialized")

    print("\n[3/3] Starting training...")
    train(dataset, model, config, test_function_for_training)


# Script entry point.
# NOTE(review): main() uses a package-relative import
# (`from .trainer_mask_weighted import ...`), so this presumably has to be
# launched as a module (`python -m <package>.<module>`) rather than as a bare
# file path — confirm against how the training script is invoked.
if __name__ == "__main__":
    main()
