import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Exported before the third-party imports further down, so the variable is
# already present in the environment when those modules are loaded.
# NOTE(review): presumably consumed by the Hugging Face `tokenizers` library
# to control its internal parallelism — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# os.environ['NCCL_BLOCKING_WAIT'] = '0'

import sys
# Put the current working directory first on the import path so that the
# `src.*` imports below resolve against it.
sys.path.insert(0, './')
# Fallback locations appended after every other entry on the import path.
# NOTE(review): these are machine-specific absolute paths; on any other host
# they do not exist and the script must be launched from the project root.
sys.path.append('/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl')
sys.path.append('/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/src')

from torch.utils.data import DataLoader
import torch
# import lightning as L
import pytorch_lightning as L
import yaml
import os
import time
import argparse

from datasets import load_dataset

from src.train.data import (
    ImageConditionDataset,
    Subject200KDateset,
    LooseConditionDataset,
    EligenDepthDataset,
    EligenLooseDataset,
    EligenLoose2DDataset,
    RealEstate10KPose_image,
    EligenPoseDataset,
    PoseDataset,
    EligenCameraDataset,
)
from src.train.model import OminiModel
# from src.train.model_3d import OminiModel_3D
from src.train.callbacks import TrainingCallback

# Release unoccupied memory held by PyTorch's CUDA caching allocator.
# NOTE(review): this runs at import time, before this script allocates any
# tensors itself, so it is presumably a defensive no-op — confirm it is needed.
torch.cuda.empty_cache()
# Process-wide setting: lets float32 matrix multiplications use faster,
# reduced-precision internal formats when the hardware supports them.
torch.set_float32_matmul_precision('high')

def get_rank():
    """Return this process's rank as read from the ``LOCAL_RANK`` env var.

    Returns:
        int: The integer value of ``LOCAL_RANK``, or ``0`` when the variable
        is unset or does not hold a valid integer.
    """
    try:
        # Default to "0" so an unset variable takes the same path as a set one.
        return int(os.environ.get("LOCAL_RANK", "0"))
    except ValueError:
        # Only a malformed value is tolerated; anything else (for example
        # KeyboardInterrupt) must propagate instead of being swallowed.
        return 0


def get_config(config_path):
    """Load the YAML training configuration.

    Args:
        config_path: Path to the YAML file. When ``None``, the path is taken
            from the ``XFL_CONFIG`` environment variable instead.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.

    Raises:
        ValueError: If no path was given and ``XFL_CONFIG`` is not set.
        OSError: If the file cannot be opened.
    """
    if config_path is None:
        config_path = os.environ.get("XFL_CONFIG")
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would let `open(None)` fail with an unrelated error.
    if config_path is None:
        raise ValueError("Please set the XFL_CONFIG environment variable")
    with open(config_path, "r") as f:
        return yaml.safe_load(f)


def init_wandb(wandb_config, run_name):
    """Start a Weights & Biases run on a best-effort basis.

    Failures are reported on stdout and never raised, so a missing API key or
    a failing ``wandb.init`` call does not abort training.

    Args:
        wandb_config: Mapping that provides the ``"project"`` entry.
        run_name: Name given to the run.
    """
    # Explicit guard instead of `assert` inside the `try`: assertions vanish
    # under `python -O`, and an AssertionError carries no message, so the
    # printed failure reason used to be empty.
    if os.environ.get("WANDB_API_KEY") is None:
        print("Failed to initialize WanDB:", "WANDB_API_KEY is not set")
        return

    # Imported only once we know logging can actually be attempted.
    import wandb

    try:
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize WanDB:", e)


def _shard_paths(num_shards):
    """Return the first ``num_shards`` JSON shard paths of the text-to-image-2M data."""
    base_url = "data/data_512_2M/data_{i:06d}.json"
    return [base_url.format(i=i) for i in range(num_shards)]


def _build_dataset(training_config):
    """Create the dataset selected by ``training_config["dataset"]["type"]``.

    Args:
        training_config: The ``"train"`` section of the loaded configuration.

    Returns:
        The instantiated dataset object.

    Raises:
        NotImplementedError: If the dataset type is not supported.
        KeyError: If a required configuration entry is missing.
    """
    dataset_config = training_config["dataset"]
    dataset_type = dataset_config["type"]

    # Dataset classes that are all constructed from data/json/depth paths.
    path_based_datasets = {
        "pose": PoseDataset,
        "eligen_camera": EligenCameraDataset,
        "eligen_pose": EligenPoseDataset,
        "loose_condition": LooseConditionDataset,
        "eligen_depth": EligenDepthDataset,
        "eligen_loose": EligenLooseDataset,
        "eligen_loose_2d": EligenLoose2DDataset,
    }
    if (
        dataset_type not in ("subject", "img", "camera")
        and dataset_type not in path_based_datasets
    ):
        raise NotImplementedError(f"Unsupported dataset type: {dataset_type!r}")

    # Keyword arguments passed to every dataset constructor below.
    common_kwargs = {
        "condition_size": dataset_config["condition_size"],
        "target_size": dataset_config["target_size"],
        "condition_type": training_config["condition_type"],
        "drop_text_prob": dataset_config["drop_text_prob"],
        "drop_image_prob": dataset_config["drop_image_prob"],
    }

    if dataset_type == "subject":
        raw_dataset = load_dataset("Yuanshi/Subjects200K")

        # Keep only items whose three quality scores are all at least 5.
        def filter_func(item):
            if not item.get("quality_assessment"):
                return False
            return all(
                item["quality_assessment"].get(key, 0) >= 5
                for key in ["compositeStructure", "objectConsistency", "imageQuality"]
            )

        # exist_ok avoids the check-then-create race when several ranks
        # reach this point at the same time.
        os.makedirs("./cache/dataset", exist_ok=True)
        data_valid = raw_dataset["train"].filter(
            filter_func,
            num_proc=16,
            cache_file_name="./cache/dataset/data_valid.arrow",
        )
        return Subject200KDateset(
            data_valid,
            image_size=dataset_config["image_size"],
            padding=dataset_config["padding"],
            **common_kwargs,
        )

    if dataset_type == "img":
        # Only the first 2 shards are used here (the original note mentions 46).
        return ImageConditionDataset(_shard_paths(2), **common_kwargs)

    if dataset_type == "camera":
        root_path = "/mnt/nas_jianchong/datasets/RealEstate10K/copy1/dataset/"
        data_path = "/mnt/nas_jianchong/datasets/RealEstate10K/copy1/RealEstate10K"
        caption_path = "RealEstate10K/train_captions.json"
        return RealEstate10KPose_image(
            root_path,
            data_path,
            caption_path,
            mix_data_rate=0.,
            mix_data_path=_shard_paths(40),
            **common_kwargs,
        )

    extra_kwargs = {"max_entity_len": dataset_config.get("max_entity_len", 2)}
    # EligenDepthDataset is the one path-based dataset not given `aug`.
    if dataset_type != "eligen_depth":
        extra_kwargs["aug"] = dataset_config.get("aug", False)
    return path_based_datasets[dataset_type](
        data_path=dataset_config["data_path"],
        json_path=dataset_config["json_path"],
        depth_path=dataset_config["depth_path"],
        **common_kwargs,
        **extra_kwargs,
    )


def _build_model(config, training_config):
    """Create the trainable model; the "pose" dataset type uses ``OminiModel_3D``."""
    if training_config["dataset"]["type"] == "pose":
        # The module-level import of OminiModel_3D is commented out, so the
        # name would otherwise be undefined on this path. Importing it here
        # keeps the dependency limited to the one branch that needs it.
        # NOTE(review): module path taken from that commented-out import —
        # confirm `src.train.model_3d` is still available.
        from src.train.model_3d import OminiModel_3D

        model_cls = OminiModel_3D
    else:
        model_cls = OminiModel

    return model_cls(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        lora_path=training_config.get("lora_path", None),
        condition_type=training_config["condition_type"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )


def main(config_path: "str | None" = None):
    """Run one training job described by a YAML configuration file.

    Loads the configuration, optionally starts WanDB logging on the main
    process, builds the dataset, dataloader, model and Lightning trainer,
    writes a copy of the configuration next to the run outputs, and starts
    fitting.

    Args:
        config_path: Path to the YAML configuration. When ``None`` the path
            is resolved by ``get_config`` from the ``XFL_CONFIG`` env var.

    Raises:
        NotImplementedError: If the configured dataset type is not supported.
        KeyError: If a required configuration entry is missing.
    """
    # Initialize
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)
    config = get_config(config_path)
    training_config = config["train"]
    run_name = (
        time.strftime("%Y%m%d-%H%M%S")
        + f"_{training_config['condition_type']}"
        + f"_{training_config['dataset']['type']}"
        + f"_{config['model']['inter_controller_type']}"
    )

    # Initialize WanDB (main process only)
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader
    dataset = _build_dataset(training_config)
    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )

    # Initialize model
    trainable_model = _build_model(config, training_config)

    # Callbacks for logging and saving checkpoints are attached on the main
    # process only.
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
        strategy=training_config["training_strategy"],
        # strategy='ddp_find_unused_parameters_true',
    )

    # Attach the training section to the trainer object.
    # NOTE(review): presumably read back by TrainingCallback — confirm.
    trainer.training_config = training_config

    # Save config
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        run_dir = f"{save_path}/{run_name}"
        os.makedirs(run_dir, exist_ok=True)
        with open(f"{run_dir}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)


if __name__ == "__main__":
    # Script entry point: read the config file location from the command
    # line and hand it to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        type=str,
        default="./train/config/loose_512.yaml",
        help='training config file',
    )
    cli_args = cli.parse_args()

    # For a quick local run the path can be overridden here, e.g.
    # cli_args.config = "train/loose_ablation/loose_depth.yaml"
    main(cli_args.config)
