"""Argument parsing for AutoFID training."""

import argparse
import os

def _str2bool(value):
    """Convert a command line token into a real ``bool``.

    argparse stores option values as strings unless a ``type`` callable is
    given, so a value such as ``"False"`` would otherwise be kept verbatim and
    evaluate as truthy. This converter accepts the usual spellings
    (case-insensitive) and rejects everything else.

    Args:
        value: The raw token from the command line, or an actual ``bool``.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling. argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


def parse_args(argv=None):
    """Parse command line flags for AutoFID training.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behaviour of calling ``parse_args()`` with no arguments.

    Returns:
        The populated ``argparse.Namespace``. ``local_rank`` is replaced by
        the ``LOCAL_RANK`` environment variable when that variable is set to
        a value other than -1 that differs from the parsed flag.

    Raises:
        SystemExit: Raised by argparse for invalid arguments, for a missing
            ``--gt_fid_stats``, or when ``--p_loss_for_each_img`` is combined
            with ``--flat_rollout``.
    """
    parser = argparse.ArgumentParser(
        description="Train diffusion model via policy gradient method."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on"
            " (could be your own, possibly private, dataset). It can also be a"
            " path pointing to a local copy of a dataset in your filesystem, or"
            " to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow"
            " the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In"
            " particular, a `metadata.jsonl` file must exist to provide the"
            " captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column",
        type=str,
        default="image",
        help="The column of the dataset containing an image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of"
            " training examples to this value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="online_model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the"
            " train/validation dataset will be resized to this resolution"
        ),
    )
    # Boolean options below use _str2bool so that "--flag False" really
    # disables the feature; nargs="?" with const=True lets the bare flag
    # ("--flag") act as an explicit enable.
    parser.add_argument(
        "--center_crop",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help=(
            "Whether to center crop the input images to the resolution. If not"
            " set, the images will be randomly cropped. The images will be"
            " resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=10000,
        help=(
            "Total number of training steps to perform. If provided, overrides"
            " num_train_epochs."
        ),
    )
    parser.add_argument(
        "--gradient_checkpointing",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help=(
            "Whether or not to use gradient checkpointing to save memory at the"
            " expense of slower backward pass."
        ),
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help=(
            "Scale the learning rate by the number of GPUs, gradient accumulation"
            " steps, and batch size."
        ),
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine",'
            ' "cosine_with_restarts", "polynomial", "constant",'
            ' "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up"
            " training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--use_ema",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether to use EMA model.",
    )
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch,"
            " tag or git identifier of the local or remote repository specified"
            " with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the"
            " data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use.",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory."
            " Will default to *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16"
            " (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
            " Default to the value of accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this"
            " argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            "The integration to report the results and logs to. Supported"
            ' platforms are `"tensorboard"` (default), `"wandb"` and'
            ' `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=100,
        help=(
            "Save a checkpoint of the training state every X updates. These"
            " checkpoints are only suitable for resuming training using"
            " `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the"
            " `Accelerator` `ProjectConfiguration`."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a"
            ' path saved by `--checkpointing_steps`, or `"latest"` to'
            " automatically select the last available checkpoint."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers.",
    )
    # Custom arguments
    parser.add_argument(
        "--p_step",
        type=int,
        default=5,
        help="The number of steps to update the policy per sampling step",
    )
    parser.add_argument(
        "--p_batch_size",
        type=int,
        default=2,
        help=(
            "batch size for policy update per gpu, before gradient accumulation;"
            " total batch size per gpu = gradient_accumulation_steps *"
            " p_batch_size"
        ),
    )
    parser.add_argument(
        "--g_batch_size",
        type=int,
        default=6,
        help="batch size of prompts for sampling per gpu",
    )
    parser.add_argument(
        "--sft_path",
        type=str,
        default="./checkpoints/models/finetune_b512_lr2e-05_max10000_w0.01",
        help="path to the pretrained supervised finetuned model",
    )
    parser.add_argument(
        "--reward_model_path",
        type=str,
        default="./checkpoints/reward/reward_model_5007.pkl",
        help="path to the pretrained reward model",
    )
    parser.add_argument(
        "--reward_weight", type=float, default=100, help="weight of reward loss"
    )
    parser.add_argument(
        "--reward_flag",
        type=int,
        default=0,
        help="0: ImageReward, 1: Custom reward model",
    )
    parser.add_argument(
        "--reward_filter",
        type=int,
        default=0,
        help="0: raw value, 1: took positive",
    )
    parser.add_argument(
        "--kl_weight", type=float, default=0.01, help="weight of kl loss"
    )
    parser.add_argument(
        "--kl_warmup", type=int, default=-1, help="warm up for kl weight"
    )
    parser.add_argument(
        "--buffer_size", type=int, default=1000, help="size of replay buffer"
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=100,
        help="save model every save_interval steps",
    )
    parser.add_argument(
        "--clip_norm", type=float, default=0.1, help="norm for gradient clipping"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=(
            "Number of updates steps to accumulate before performing a"
            " backward/update pass for policy"
        ),
    )
    parser.add_argument("--lora_rank", type=int, default=4, help="rank for LoRA")
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate for policy",
    )
    parser.add_argument(
        "--prompt_path",
        type=str,
        default="./dataset/imagenet1k/data_meta_9.json",
        help="path to the prompt dataset",
    )
    parser.add_argument(
        "--prompt_category",
        type=str,
        default="all",
        help="all or specific categories with comma [e.g., color,count]",
    )
    parser.add_argument(
        "--single_flag",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--single_prompt",
        type=str,
        default="A green colored rabbit.",
    )
    parser.add_argument(
        "--sft_initialization",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--multi_gpu",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--ratio_clip",
        type=float,
        default=1e-4,
    )
    parser.add_argument(
        "--project_name",
        type=str,
        default="fidrl",
        help="Name of the project for tracking."
    )
    parser.add_argument(
        "--use_non_repeating_samples",
        action="store_true",
        help="Whether to use non-repeating samples for value and policy function training."
    )
    parser.add_argument(
        "--first_n_vis",
        type=int,
        default=8,
        help="Number of samples to visualize."
    )
    parser.add_argument(
        "--resume_from_saved_model",
        type=str,
        default=None,
        help="Path to a saved model to resume training from."
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=50,
        help="Number of denoising steps for the diffusion model."
    )
    parser.add_argument(
        "--num_vis_images",
        type=int,
        default=8,
        help="Number of images to visualize."
    )

    # * AutoFID
    parser.add_argument(
        "--num_fid_images",
        type=int,
        default=5000,
        help="Total number of images to generate for FID calculation."
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=50,
        help="Number of images to replace per rollout."
    )
    parser.add_argument(
        "--num_groups",
        type=int,
        default=12,
        help="Number of fat buffer rows per rollout_collect"
    )
    parser.add_argument(
        "--image_pool_dir",
        type=str,
        default=None,
        help="Directory to save/load the image pool. If None, uses output_dir/image_pool"
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Guidance scale for classifier-free guidance during image generation"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Size of generated images"
    )
    parser.add_argument(
        "--gt_fid_stats",
        type=str,
        required=True,
        help="Path to the FID statistics of the ground truth images."
    )
    # * Group policy flag: whether to use the group policy (0 = disabled,
    #   1 = GRPO with normalization synchronized across multiple GPUs).
    parser.add_argument(
        "--grpo_flag",
        type=int,
        default=1,
        help=(
            "0: No advantage normalization, "
            "1: Synchronized normalization across GPUs, "
            "2: Normalize in advance, "
            "3: Remove duplicates before normalization"
        ),
    )

    parser.add_argument(
        "--p_loss_for_each_img",
        action="store_true",
        help="Whether to use p_loss_for_each_img for policy training."
    )

    parser.add_argument(
        "--p_num_groups",
        type=int,
        default=12,
        help="Number of groups for each step of policy training."
    )

    parser.add_argument(
        "--flat_rollout",
        action="store_true",
        help="Whether to use flat rollout for policy training."
    )

    # Model type and related arguments
    parser.add_argument(
        "--model_type",
        type=str,
        default="sd15",
        choices=["sd15", "edm2", "sit"],
        help="Type of diffusion model to use (sd15, edm2, sit)"
    )
    parser.add_argument(
        "--model_output_dir",
        type=str,
        default=None,
        help="Directory to save model outputs separately from logs. If None, uses output_dir"
    )

    # * from Rejection Sampling
    parser.add_argument(
        "--dist_method",
        type=str,
        default="gpu",
        choices=["filesystem", "gpu"],
        help="Method to distribute the image pool."
    )
    parser.add_argument(
        "--refill_interval",
        type=int,
        default=10,
        help="Interval to refill the image pool."
    )
    # Bin loss parameters
    parser.add_argument(
        "--num_loss_bins",
        type=int,
        default=10,
        help="Number of bins to divide num_inference_steps for bin loss calculation"
    )
    parser.add_argument(
        "--global_flag",
        type=int,
        default=-1,
        help=(
            "-1: no filtering (default), "
            "0: local best samples, "
            "1: global best samples, "
            "2: local best and worst samples"
        ),
    )
    parser.add_argument(
        "--num_best_samples",
        type=int,
        default=4,
        help="Number of best samples to keep for rejection sampling"
    )

    args = parser.parse_args(argv)

    # The LOCAL_RANK environment variable, when set, takes precedence over
    # the --local_rank flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, and a usage message is clearer for a CLI user.
    if args.p_loss_for_each_img and args.flat_rollout:
        parser.error("p_loss_for_each_img is for non-flat rollout")

    return args
