from torch.utils.data import DataLoader
import torch
import lightning as L
import yaml
import os
import time
import re

from datasets import load_dataset

from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset, SceneDataset
from .model import OminiModel
from .callbacks import TrainingCallback
import safetensors.torch
from peft import PeftModel

import os
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset

from torchvision import transforms
from torch.utils.data import DataLoader

class LocalSubjectsDataset(Dataset):
    """Map-style dataset of (imageA, prompt, imageB) samples described by a CSV file."""

    def __init__(self, csv_file, image_dir, transform=None):
        """
        Args:
            csv_file (str): Path to the CSV file holding image file names and
                descriptions; it should contain the three columns "imageA",
                "prompt" and "imageB".
            image_dir (str): Root directory in which the images are stored.
            transform (callable, optional): Preprocessing transform applied to
                each loaded image.
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform
        # Column name -> description of the value type returned per sample.
        self.features = {'imageA': 'PIL.Image', 'prompt': 'str', 'imageB': 'PIL.Image'}

    def __len__(self):
        """Return the number of rows read from the CSV file."""
        return len(self.data)

    def _image_path(self, row, column):
        """Build the path of the image named in ``row[column]`` under ``image_dir``."""
        file_name = row[column]
        if isinstance(file_name, pd.Series):
            # The lookup yielded a Series; only its first value is used.
            file_name = file_name.values[0]
        # str() guards against non-string cell values before joining.
        return os.path.join(self.image_dir, str(file_name))

    def __getitem__(self, idx):
        """Load both images of row ``idx`` as RGB and return them with the prompt.

        Returns:
            dict: ``{'imageA': ..., 'prompt': ..., 'imageB': ...}`` where the
            images have been passed through ``self.transform`` when one is set.
        """
        row = self.data.iloc[idx]

        # Resolve file locations and the description for this row.
        path_a = self._image_path(row, 'imageA')
        prompt = row['prompt']
        path_b = self._image_path(row, 'imageB')

        # Load the images.
        image_a = Image.open(path_a).convert("RGB")
        image_b = Image.open(path_b).convert("RGB")

        if self.transform:
            image_a = self.transform(image_a)
            image_b = self.transform(image_b)

        return {'imageA': image_a, 'prompt': prompt, 'imageB': image_b}
    
# Preprocessing handed to LocalSubjectsDataset in main(): every image is resized
# to 600x600. No ToTensor step is included in this pipeline.
transform = transforms.Compose([transforms.Resize((600, 600))])

# class LocalSubjectsDataset(Dataset):
#     def __init__(self, csv_file, image_dir, transform=None):
#         """
#         Args:
#             csv_file (str): 包含图片文件名和描述的CSV文件路径，文件中应包含 "imageA", "prompt", "imageB" 三列
#             image_dir (str): 存储图片的根目录
#             transform (callable, optional): 对图片进行预处理的变换函数
#         """
#         self.data = pd.read_csv(csv_file)
#         self.image_dir = image_dir
#         self.transform = transform
#         self.features = {
#             'imageA': 'PIL.Image',
#             'prompt': 'str',
#             'imageB': 'PIL.Image'
#         }

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         # 获取图片A、描述和图片B的文件名
#         imgA_value = self.data.iloc[idx]['imageA']
#         if isinstance(imgA_value, pd.Series):
#             imgA_value = imgA_value.values[0]  # 假设Series中只有一个值
#         imgA_name = os.path.join(self.image_dir, str(imgA_value))  # 转换为字符串

#         prompt = self.data.iloc[idx]['prompt']
#         imgB_value = self.data.iloc[idx]['imageB']
#         if isinstance(imgB_value, pd.Series):
#             imgB_value = imgB_value.values[0] 
#         imgB_name = os.path.join(self.image_dir, str(imgB_value))
#         imgC_name = re.sub(r'a(?=[^a]*$)', 'c', imgA_name)

#         # 加载图片
#         imageA = Image.open(imgA_name).convert("RGB")
#         imageB = Image.open(imgB_name).convert("RGB")
#         imageC = Image.open(imgC_name).convert("RGB")


#         if self.transform:
#             imageA = self.transform(imageA)
#             imageB = self.transform(imageB)
#             imageC = self.transform(imageC)


#         sample = {'imageA': imageA, 'prompt': prompt, 'imageB': imageB, 'imageC': imageC}
#         return sample
    
# transform = transforms.Compose([
#     transforms.Resize((600, 600)),
#     # transforms.ToTensor(),
# ])

def get_rank() -> int:
    """Return the local process rank taken from the ``LOCAL_RANK`` environment variable.

    Returns:
        int: The integer value of ``LOCAL_RANK``, or 0 when the variable is
        unset or does not hold a valid integer.
    """
    try:
        return int(os.environ.get("LOCAL_RANK"))
    except (TypeError, ValueError):
        # TypeError: variable is unset (int(None)); ValueError: not an integer.
        # Only these are caught so that KeyboardInterrupt/SystemExit propagate.
        return 0


def get_config():
    """Load the YAML configuration file named by the ``XFL_CONFIG`` environment variable.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.

    Raises:
        AssertionError: If ``XFL_CONFIG`` is not set.
    """
    config_path = os.environ.get("XFL_CONFIG")
    if config_path is None:
        # Raised explicitly instead of via `assert` so the check is not stripped
        # under `python -O`; AssertionError is kept so existing callers that
        # catch it keep working.
        raise AssertionError("Please set the XFL_CONFIG environment variable")
    # Explicit encoding: do not depend on the platform's locale default.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def init_wandb(wandb_config, run_name):
    """Start a Weights & Biases run.

    Args:
        wandb_config: Mapping whose ``"project"`` entry names the W&B project.
        run_name: Name given to the run.
    """
    # Imported lazily so wandb is only required when this function is called.
    import wandb

    project_name = wandb_config["project"]
    wandb.init(project=project_name, name=run_name, config={})

    # try:
    #     assert os.environ.get("WANDB_API_KEY") is not None
    #     wandb.init(
    #         project=wandb_config["project"],
    #         name=run_name,
    #         config={},
    #     )
    # except Exception as e:
    #     print("Failed to initialize WanDB:", e)


# Fallback locations used when the config does not override them. These are the
# paths that were previously hard-coded inside main(); keeping them as defaults
# means existing config files behave exactly as before.
_DEFAULT_SCENE_CSV = (
    "/home/svu/e0787962/scratch_ln/yujia/dataset_replace/dataset_merge1.csv"
)
_DEFAULT_SCENE_IMAGE_DIR = "/home/svu/e0787962/scratch_ln/yujia/dataset_replace/images"
_DEFAULT_LORA_CHECKPOINT = (
    "/home/svu/e0787962/scratch_ln/yujia/wandb/runs/20250503-110300/ckpt/20000/"
    "pytorch_lora_weights.safetensors"
)


def _build_dataset(training_config):
    """Create the training dataset selected by ``training_config["dataset"]["type"]``.

    Supported types are ``"scene"`` (local CSV + image directory wrapped in
    ``SceneDataset``), ``"img"`` (webdataset wrapped in ``ImageConditionDataset``)
    and ``"cartoon"`` (``saquiboye/oye-cartoon`` wrapped in ``CartoonDataset``).

    For ``"scene"``, the optional dataset keys ``csv_file`` and ``image_dir``
    override the built-in default locations.

    Raises:
        NotImplementedError: If the dataset type is not one of the above.
    """
    dataset_config = training_config["dataset"]
    dataset_type = dataset_config["type"]

    if dataset_type == "scene":
        data_valid = LocalSubjectsDataset(
            csv_file=dataset_config.get("csv_file", _DEFAULT_SCENE_CSV),
            image_dir=dataset_config.get("image_dir", _DEFAULT_SCENE_IMAGE_DIR),
            transform=transform,
        )
        # Diagnostic output: dataset features (column names), size and settings.
        print(data_valid.features)
        print(len(data_valid))
        print(dataset_config)
        return SceneDataset(
            data_valid,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            image_size=dataset_config["image_size"],
            padding=dataset_config["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
        )

    if dataset_type == "img":
        # Load dataset text-to-image-2M
        raw_dataset = load_dataset(
            "webdataset",
            data_files={"train": dataset_config["urls"]},
            split="train",
            cache_dir="cache/t2i2m",
            num_proc=32,
        )
        return ImageConditionDataset(
            raw_dataset,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
            position_scale=dataset_config.get("position_scale", 1.0),
        )

    if dataset_type == "cartoon":
        raw_dataset = load_dataset("saquiboye/oye-cartoon", split="train")
        return CartoonDataset(
            raw_dataset,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            image_size=dataset_config["image_size"],
            padding=dataset_config["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
        )

    raise NotImplementedError(f"Unsupported dataset type: {dataset_type!r}")


def _load_lora_weights(model, checkpoint_path):
    """Load LoRA weights (safetensors format) into ``model`` if the file exists.

    Loading is non-strict, so parameter names that do not match are skipped
    rather than raising; the number of skipped/missing keys is printed so that a
    checkpoint which silently matched nothing is noticeable.
    """
    if not checkpoint_path or not os.path.exists(checkpoint_path):
        print("Checkpoint not found. Starting from scratch.")
        return

    print(f"Loading LoRA weights from {checkpoint_path}")
    state_dict = safetensors.torch.load_file(checkpoint_path)

    # Parameter names must match for anything to be loaded (strict=False).
    result = model.load_state_dict(state_dict, strict=False)

    # NOTE(review): assumes load_state_dict returns torch's missing/unexpected
    # key report; getattr keeps this safe if OminiModel returns something else.
    unexpected_keys = getattr(result, "unexpected_keys", None)
    missing_keys = getattr(result, "missing_keys", None)
    if unexpected_keys:
        print(
            f"Warning: {len(unexpected_keys)} of {len(state_dict)} checkpoint keys "
            "did not match any model parameter and were ignored."
        )
    if missing_keys is not None:
        print(f"Model parameters not present in checkpoint: {len(missing_keys)}")


def main():
    """Run training as configured by the file named in ``XFL_CONFIG``.

    Steps: read the config, optionally start W&B on the main process, build the
    dataset and dataloader, create the ``OminiModel``, load LoRA weights from
    ``train.resume_lora_path`` (defaulting to the built-in checkpoint path) when
    that file exists, save the config under ``<save_path>/<run_name>`` and start
    the Lightning trainer.
    """
    # Initialize
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WanDB (main process only)
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader
    dataset = _build_dataset(training_config)

    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )
    print("Trainloader generated.")

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    # Warm-start from previously trained LoRA weights when available.
    _load_lora_weights(
        trainable_model,
        training_config.get("resume_lora_path", _DEFAULT_LORA_CHECKPOINT),
    )

    # Callbacks for logging and saving checkpoints (main process only)
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )

    # Expose the training config on the trainer object.
    trainer.training_config = training_config

    # Save config
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        run_dir = os.path.join(save_path, run_name)
        # exist_ok: do not crash if the run directory has already been created.
        os.makedirs(run_dir, exist_ok=True)
        with open(os.path.join(run_dir, "config.yaml"), "w", encoding="utf-8") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)


if __name__ == "__main__":
    main()
