"""
Train a GAN using the techniques described in the paper
"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
"""

import os
import json
import torch
import click
import dnnlib

from glob import glob
from metrics import metric_main
from typing import Union, Optional
from training import training_loop
from torch_utils import custom_ops
from torch_utils import distributed as dist


# Utility functions.
def find_latest_network_snapshot(outdir: str) -> Optional[str]:
    """Find the latest network snapshot in the output directory.

    Snapshots are files named ``network-snapshot-<kimg>.pkl`` located directly
    inside ``outdir``; the one with the largest integer ``<kimg>`` wins.

    Args:
        outdir: Directory to search. It is matched literally, so glob
            metacharacters in the path (``[``, ``]``, ``*``, ``?``) are safe.

    Returns:
        Path of the snapshot with the highest kimg, or ``None`` when ``outdir``
        does not exist or contains no snapshot whose kimg part is an integer.
    """
    # Function-scope import: at module level the name `glob` refers to the
    # function, not the module, so `escape` has to be imported explicitly.
    from glob import escape as glob_escape

    if not os.path.exists(outdir):
        return None

    # Escape only the directory part so that the trailing "*" is the sole
    # wildcard; otherwise an outdir such as "runs/[exp1]" would never match.
    pattern = os.path.join(glob_escape(outdir), 'network-snapshot-*.pkl')

    latest_snapshot = None
    latest_kimg = -1

    for snapshot in glob(pattern):
        # Extract kimg from a filename like "network-snapshot-00053248.pkl".
        filename = os.path.basename(snapshot)
        kimg_str = filename.replace('network-snapshot-', '').replace('.pkl', '')
        try:
            kimg = int(kimg_str)
        except ValueError:
            # Not a numbered snapshot (e.g. "network-snapshot-final.pkl").
            continue
        if kimg > latest_kimg:
            latest_kimg = kimg
            latest_snapshot = snapshot

    return latest_snapshot

def parse_comma_separated_list(s: Union[None, str, list]) -> Union[list, str]:
    """Turn a comma-separated option string into a list of strings.

    A value that is already a list is handed back untouched. ``None``, the
    empty string and any capitalisation of ``"none"`` mean "no items" and give
    an empty list. Splitting happens on ``,`` only; whitespace around items is
    kept as-is.
    """
    if isinstance(s, list):
        return s
    no_items = s is None or s.lower() == 'none' or s == ''
    return [] if no_items else s.split(',')


def parse_comma_separated_int_list(s: Union[None, str, list]) -> Union[list, str]:
    """Turn a comma-separated option string into a list of integers.

    A list argument is returned unchanged (its elements are not converted).
    ``None``, the empty string and any capitalisation of ``"none"`` give an
    empty list. Every other string is split on ``,`` and each piece is passed
    to ``int``, so a non-integer piece raises ``ValueError``.
    """
    if isinstance(s, list):
        return s
    if s is not None and s.lower() != 'none' and s != '':
        return list(map(int, s.split(',')))
    return []


def parse_comma_separated_float_list(s: Union[None, str, list]) -> Union[list, str]:
    """Turn a comma-separated option string into a list of floats.

    A list argument is returned unchanged (its elements are not converted).
    ``None``, the empty string and any capitalisation of ``"none"`` give an
    empty list. Every other string is split on ``,`` and each piece is passed
    to ``float``, so a non-numeric piece raises ``ValueError``.
    """
    if isinstance(s, list):
        return s
    treat_as_empty = s is None or s.lower() == 'none' or s == ''
    if treat_as_empty:
        return []
    pieces = s.split(',')
    return list(map(float, pieces))


def is_power_of_two(n: int) -> bool:
    """Return True when ``n`` is a power of two (1, 2, 4, 8, ...).

    Zero is rejected explicitly; for any other integer the check is whether
    clearing the lowest set bit (``n & (n - 1)``) leaves nothing behind.
    """
    if n == 0:
        return False
    lowest_bit_cleared = n & (n - 1)
    return lowest_bit_cleared == 0


def init_dataset_kwargs(
    data: str,
    resolution: int,
    label_type: Optional[str] = None,
    filter_keys_path: Optional[str] = None,
    cls_to_text_path: Optional[str] = None,
    data_augmentation: bool = False,
    one_epoch: bool = False,
    processed_tar_read_dir: Optional[str] = None,
    processed_tar_write_dir: Optional[str] = None,
) -> dnnlib.EasyDict:
    """Assemble the keyword arguments used to construct a dataset.

    The dataset flavour is picked from what is found under ``data``:

    * if ``glob(f'{data}/**/*.tar')`` matches anything, the data is treated as
      a webdataset (``training.data_wds.WdsWrapper``) and all label, filter,
      augmentation, one-epoch and processed-tar options are forwarded;
    * otherwise it is treated as an image-folder dataset
      (``training.data_zip.ImageFolderDataset``), which is instantiated once to
      check that the requested resolution does not exceed the last entry of
      its ``_raw_shape``. The extra options are not forwarded in this case.

    Args:
        data: Path to the dataset.
        resolution: Requested image resolution; must be a power of two.
        label_type, filter_keys_path, cls_to_text_path, data_augmentation,
        one_epoch, processed_tar_read_dir, processed_tar_write_dir:
            Stored in the result only for webdatasets.

    Returns:
        An ``EasyDict`` with ``path``, ``xflip``, ``use_labels``,
        ``class_name``, ``resolution`` and, for webdatasets, the extra options.

    Raises:
        AssertionError: if the resolution is missing for a webdataset, exceeds
            the image-folder dataset's native size, or is not a power of two.
    """
    d_kwargs = dnnlib.EasyDict(path=data, xflip=False, use_labels=True)

    # Tar files found by this pattern mark the data as a webdataset.
    # NOTE(review): glob is called without recursive=True, so "**" presumably
    # behaves like a single "*" level here — confirm the expected shard layout.
    tar_shards = glob(f'{data}/**/*.tar')

    if len(tar_shards) > 0:
        assert resolution, "Provide desired resolution when training on webdatasets."
        d_kwargs.class_name = 'training.data_wds.WdsWrapper'
        # Options that are only forwarded to the webdataset wrapper.
        wds_only_options = {
            'label_type': label_type,
            'filter_keys_path': filter_keys_path,
            'cls_to_text_path': cls_to_text_path,
            'data_augmentation': data_augmentation,
            'one_epoch': one_epoch,
            'processed_tar_read_dir': processed_tar_read_dir,
            'processed_tar_write_dir': processed_tar_write_dir,
        }
        for option_name, option_value in wds_only_options.items():
            setattr(d_kwargs, option_name, option_value)
    else:
        d_kwargs.class_name = 'training.data_zip.ImageFolderDataset'
        # Build the dataset once just to read its native shape.
        probe_dataset = dnnlib.util.construct_class_by_name(**d_kwargs)
        native_size = probe_dataset._raw_shape[-1]
        assert resolution <= native_size, f"Native dataset resolution is smaller than {resolution}"

    assert is_power_of_two(resolution)
    d_kwargs.resolution = resolution
    return d_kwargs


# Main function.
@click.command("click", context_settings={'show_default': True})
# Generator size settings.
@click.option('--cfg',                                  help='Base config.',                                type=click.Choice(['custom', 'lite', 'base', 'large']), default='lite')
@click.option('--cbase',                                help='Capacity multiplier',                         type=click.IntRange(min=1), default=32768)
@click.option('--cmax',                                 help='Max. feature maps',                           type=click.IntRange(min=1), default=512)
@click.option('--res-blocks',                           help='Number of residual blocks',                   type=click.IntRange(min=1), default=2)
# Required.
@click.option('--outdir',                               help='Where to save the results',                   type=str, required=True)
@click.option('--training-data',                        help='Training data',                               type=str, required=True)
@click.option('--input-img-resolution',                 help='Resolution for input images',                 type=click.IntRange(min=8), required=True)
@click.option('--output-img-resolution',                help='Resolution for output images',                type=click.IntRange(min=8), required=True)
@click.option('--batch',                                help='Total batch size',                            type=click.IntRange(min=1), required=True)
@click.option('--batch-gpu',                            help='Limit batch size per GPU',                    type=click.IntRange(min=1), default=8)
# Training dataset settings.
@click.option('--conditional',                          help='Use label condition or not',                  type=bool, default=True)
@click.option('--label-type',                           help='Label type',                                  type=click.Choice(['text', 'cls2text', 'cls2id']), default='text')
@click.option('--filter-keys-path',                     help='Path to filter keys',                         type=str, default=None)
@click.option('--cls-to-text-path',                     help='Path to class to text mapping',               type=str, default=None)
# Validation dataset settings.
@click.option('--validation-data',                      help='Validation data',                             type=str, default=None)
# Blur settings (default off).
@click.option('--blur-init',                            help='Init blur width',                             type=click.IntRange(min=0), default=32,)
@click.option('--blur-fade-kimg',                       help='Discriminator blur duration',                 type=click.IntRange(min=0), default=0)
# Misc settings (default).
@click.option('--suffix',                               help='Suffix of result dirname',                    type=str, default='')
@click.option('--metrics',                              help='Quality metrics',                             type=parse_comma_separated_list, default=[])
@click.option('--seed',                                 help='Random seed',                                 type=click.IntRange(min=0), default=0)
@click.option('--nobench',                              help='Disable cuDNN benchmarking',                  type=bool, default=False)
@click.option('--workers',                              help='DataLoader worker processes',                 type=click.IntRange(min=1), default=3)
@click.option('--dry-run',                              help='Print training options and exit',             type=bool, is_flag=True)
# Wandb settings.
@click.option('--wandb',                                help='Enable wandb logging',                        type=bool, default=False)
@click.option('--wandb-project-name',                   help='Wandb project name',                          type=str, default='stylegan-t')
@click.option('--wandb-run-name',                       help='Wandb run name',                              type=str, default='train')
# Vision foundation model settings.
@click.option('--vfm-name',                             help='Vision foundation model name or path',        type=str, default=None)
@click.option('--patch-from-layers',                    help='Patch featuresfrom layers for VFM',           type=parse_comma_separated_int_list, default=None)
@click.option('--patch-resolutions',                    help='Patch resolutions for VFM',                   type=parse_comma_separated_int_list, default=None)
@click.option('--patch-in-dimensions',                  help='Patch dimensions for VFM',                    type=parse_comma_separated_int_list, default=None)
@click.option('--patch-out-dimensions',                 help='Patch output dimensions for VFM',             type=parse_comma_separated_int_list, default=None)
# Compression & decompression settings.
@click.option('--compression-mode',                     help='Compression mode',                            type=click.Choice(['continuous', 'discrete', 'no']), default='continuous')
@click.option('--how-to-compress',                      help='How to compress',                             type=click.Choice(['conv', 'attnproj']), default='attnproj')
@click.option('--how-to-decompress',                    help='How to decompress',                           type=click.Choice(['conv', 'attnproj']), default='attnproj')
@click.option('--decompress-factor',                    help='Decompression factor',                        type=click.IntRange(min=1), default=4)
@click.option('--attnproj-quant-layers',                help='Number of attnproj layers for quant',         type=click.IntRange(min=1), default=1)
@click.option('--attnproj-post-quant-layers',           help='Number of attnproj layers for post-quant',    type=click.IntRange(min=1), default=1)
# Latent settings.
@click.option('--z-resolution',                         help='Resolution of z',                             type=click.IntRange(min=1), default=16)
@click.option('--z-dimension',                          help='Latent dimension for continuous tokenzier',   type=click.IntRange(min=1), default=64)
@click.option('--vocab-width',                          help='Vocabulary width for discrete VQ',            type=click.IntRange(min=1), default=64)
@click.option('--z-pooled-resolution',                  help='Resolution of pooled z',                      type=click.IntRange(min=1), default=1)
@click.option('--z-dim-for-mapping-mlp-output',         help='Dimension of z for mapping MLP output',       type=click.IntRange(min=1), default=128)
# Discrete VQ settings.
@click.option('--vocab-size',                           help='Vocabulary size for discrete VQ',             type=click.IntRange(min=1), default=32768)
@click.option('--vocab-beta',                           help='Beta for discrete VQ',                        type=float, default=0.25)
@click.option('--entropy-loss-weight',                  help='Entropy loss weight for discrete VQ',         type=float, default=0.0)
@click.option('--entropy-temp',                         help='Entropy temperature for discrete VQ',         type=float, default=0.01)
@click.option('--num-codebooks',                        help='Number of codebooks for discrete VQ',         type=click.IntRange(min=1), default=8)
# Quantization loss settings.
@click.option('--kl-loss-weight',                       help='KL loss weight',                              type=float, default=1e-6)
@click.option('--vq-loss-weight',                       help='VQ loss weight for discrete VQ',              type=float, default=1.0)
@click.option('--vf-loss-weight',                       help='VF loss weight',                              type=float, default=0.0)
@click.option('--use-adaptive-vf-loss',                 help='Use adaptive VF loss',                        type=bool, default=False)
@click.option('--distmat-margin',                       help='Margin for distance matrix in VF loss',       type=float, default=0.0)
@click.option('--cos-margin',                           help='Margin for cosine similarity in VF loss',     type=float, default=0.0)
@click.option('--distmat-weight',                       help='Weight for distance matrix loss in VF loss',  type=float, default=1.0)
@click.option('--cos-weight',                           help='Weight for cosine sim. loss in VF loss',      type=float, default=1.0)
# Concatenated z settings.
@click.option('--concat-z-resolutions',                 help='Resolution for concatenated z',               type=parse_comma_separated_int_list, default=[])
@click.option('--concat-z-mapped-dims',                 help='Dimension of concatenated z after mapping',   type=parse_comma_separated_int_list, default=[])
@click.option('--how-to-process-concat-z',              help='How to process concat z',                     type=click.Choice(['unshuffle', 'pooling']), default='unshuffle')
@click.option('--activation-for-concat-z',              help='Activation for concat z',                     type=click.Choice(['lrelu', 'silu', 'gelu']), default='gelu')
# Generator settings.
@click.option('--attn-resolutions',                     help='Resolutions for attention',                   type=parse_comma_separated_int_list, default=[])
@click.option('--attn-depths',                          help='Depths for attention',                        type=parse_comma_separated_int_list, default=[])
@click.option('--use-self-attn',                        help='Use self-attention in the generator',         type=bool, default=False)
@click.option('--use-cross-attn',                       help='Use cross-attention in the generator',        type=bool, default=False)
@click.option('--use-convnext',                         help='Use ConvNeXt in the generator',               type=bool, default=False)
@click.option('--use-gaussian-blur',                    help='Use Gaussian blur for upsampling',            type=bool, default=False)
@click.option('--add-additional-convnext',              help='Add additional ConvNeXt in the lower res',    type=bool, default=False)
# Pixel loss settings.
@click.option('--l1-pixel-loss-weight',                 help='L1 Pixel loss weight',                        type=float, default=1.0)
@click.option('--l2-pixel-loss-weight',                 help='L2 Pixel loss weight',                        type=float, default=0.0)
@click.option('--pixel-loss-resolution',                help='Pixel loss resolution',                       type=click.IntRange(min=8), default=256)
# Perceptual loss settings.
@click.option('--perceptual-loss-weight',               help='Perceptual loss weight',                      type=float, default=10.0)
# SSIM loss settings.
@click.option('--ssim-loss-weight',                     help='SSIM loss weight',                            type=float, default=0.0)
# Multiscale pixel loss L1 settings.
@click.option('--multiscale-pixel-loss-weights',        help='Multi-scale pixel loss weights',              type=parse_comma_separated_float_list, default=[])
@click.option('--multiscale-resolutions',               help='Resolutions for multi-scale pixel loss',      type=parse_comma_separated_int_list, default=[])
@click.option('--multiscale-pixel-loss-start-kimg',     help='Start kimg for multi-scale pixel loss',       type=click.IntRange(min=0), default=0)
@click.option('--multiscale-pixel-loss-end-kimg',       help='End kimg for multi-scale pixel loss',         type=click.IntRange(min=0), default=1e9)
# CLIP loss settings (default off).
@click.option('--clip-loss-weight',                     help='CLIP loss weight',                            type=float, default=0.0)
@click.option('--clip-loss-start-kimg',                 help='Start kimg for CLIP loss',                    type=click.IntRange(min=0), default=0)
# Matching aware loss settings (default off).
@click.option('--matching-aware-loss-weight',           help='Matching aware loss weight',                  type=float, default=0.0)
@click.option('--matching-aware-loss-start-kimg',       help='Start kimg for matching aware loss',          type=click.IntRange(min=0), default=100)
# Discriminator loss settings.
@click.option('--stylegan-t-discriminator-loss-weight', help='StyleGAN-T discriminator loss weight',        type=float, default=1.0)
@click.option('--patchgan-discriminator-loss-weight',   help='PatchGAN discriminator loss weight',          type=float, default=1.0)
# PatchGAN's feature matching loss settings.
@click.option('--feature-matching-loss-weight',         help='Feature matching loss weight for PatchGAN',   type=float, default=10.0)
# Discriminator warm-up settings.
@click.option('--use-stylegan-t-disc-warmup',           help='Use StyleGAN-T discriminator warm-up',        type=bool, default=False)
@click.option('--use-patchgan-disc-warmup',             help='Use PatchGAN discriminator warm-up',          type=bool, default=False)
# Equivariance Regularization settings.
@click.option('--use-equivariance-regularization',      help='Use equivariance regularization',             type=bool, default=False)
# Training phase settings.
@click.option('--force-phase-training',                 help='Force phase training',                        type=bool, default=False)
# Resuming settings.
@click.option('--resume',                               help='Resume from given network pickle',            type=str)
@click.option('--resume-kimg',                          help='Resume from given kimg',                      type=click.IntRange(min=0), default=0)
@click.option('--resume-discriminator',                 help='Whether to resume discriminator',             type=bool, default=True)
# Training settings.
@click.option('--lr-multiplier',                        help='Learning rate multiplier for generator',      type=float, default=1.0)
@click.option('--train-mode',                           help='Which layers to train',                       type=click.Choice(['all', 'text_encoder', 'freeze32', 'freeze_encoder']), default='all')
@click.option('--fp32',                                 help='Disable mixed-precision',                     type=bool, default=False)
@click.option('--num-fp16-res',                         help='Number of fp16 resolutions',                  type=click.IntRange(min=0), default=4)
@click.option('--base-mult',                            help='Start resolution of log2',                    type=click.IntRange(min=1), default=3)
@click.option('--tick',                                 help='How often to print progress',                 type=click.IntRange(min=1), default=4)
@click.option('--image-snap',                           help='How often to save image snapshots',           type=click.IntRange(min=1), default=50)
@click.option('--network-snap',                         help='How often to save network snapshots',         type=click.IntRange(min=1), default=50)
@click.option('--kimg',                                 help='Total training duration',                     type=click.IntRange(min=1), default=1000)
@click.option('--one-epoch',                            help='Train for one epoch',                         type=bool, default=False)

def main(**kwargs) -> None:
    # Initialize config.
    torch.multiprocessing.set_start_method('spawn')
    dist.init()
    opts = dnnlib.EasyDict(kwargs)
    c = dnnlib.EasyDict()

    # Auto-resume logic: find latest snapshot if no resume specified
    if opts.resume is None:
        latest_snapshot = find_latest_network_snapshot(opts.outdir)
        if latest_snapshot:
            opts.resume = latest_snapshot
            # Extract kimg from filename for resume_kimg if not specified
            if opts.resume_kimg == 0:
                filename = os.path.basename(latest_snapshot)
                try:
                    kimg_str = filename.replace('network-snapshot-', '').replace('.pkl', '')
                    opts.resume_kimg = int(kimg_str)
                except ValueError:
                    pass
            dist.print0(f'Auto-resuming from: {opts.resume} at kimg {opts.resume_kimg}')
        else:
            dist.print0('No existing snapshots found, starting training from scratch.')

    # One epoch settings.
    c.one_epoch = opts.one_epoch
    if opts.one_epoch:
        opts.kimg = 1e9 # 1e9 is a large number, just contain the one epoch
        
    # Wandb.
    if opts.wandb:
        c.wandb_project_name = opts.wandb_project_name
        c.wandb_run_name = opts.wandb_run_name

    # Networks.
    # --------------------------------------------------
    # Discriminator settings.
    # ---------------------------------------------------
    c.D_kwargs = dnnlib.EasyDict(
        class_name='networks.discriminator.ProjectedDiscriminator',
        vfm_name=opts.vfm_name,                                                     # align the preprocessing of images with the VFM name
        use_stylegan_t_discriminator=opts.stylegan_t_discriminator_loss_weight > 0, # whether to use StyleGAN-T discriminator
        use_patchgan_discriminator=opts.patchgan_discriminator_loss_weight > 0,     # whether to use PatchGAN discriminator
        get_interm_feat=opts.feature_matching_loss_weight > 0,                      # whether to get intermediate features for PatchGAN
    )

    # --------------------------------------------------
    # Generator settings.
    # --------------------------------------------------
    # Conditioning settings.
    conditioning_kwargs = {
        'conditional': opts.conditional,                                            # whether to use label condition
        'label_type': opts.label_type,                                              # the type of label: text, cls2text, cls2id
    }
    
    # Vision Foundation Model (VFM) settings.
    assert opts.vfm_name is not None, "VFM name or path is required."
    vfm_kwargs = {
        'vfm_name': opts.vfm_name,
        'patch_from_layers': opts.patch_from_layers,
        'patch_resolutions': opts.patch_resolutions,
        'patch_in_dimensions': opts.patch_in_dimensions,
        'patch_out_dimensions': opts.patch_out_dimensions,
    }

    ldm_kwargs = {
        # Compression and decompression settings.
        'compression_mode': opts.compression_mode,
        'how_to_compress': opts.how_to_compress,
        'how_to_decompress': opts.how_to_decompress,
        'decompress_factor': opts.decompress_factor,
        'attnproj_quant_layers': opts.attnproj_quant_layers,
        'attnproj_post_quant_layers': opts.attnproj_post_quant_layers,
        # Latent (Z) settings.
        'z_resolution': opts.z_resolution,
        'z_dimension': opts.z_dimension,  # for continuous tokenizer    
        'vocab_width': opts.vocab_width,  # for discrete tokenizer
        'z_pooled_resolution': opts.z_pooled_resolution,
        'z_dim_for_mapping_mlp_output': opts.z_dim_for_mapping_mlp_output,
        # VQ settings.
        'vocab_size': opts.vocab_size,
        'vocab_width': opts.vocab_width,
        'vocab_beta': opts.vocab_beta,
        'use_entropy_loss': opts.entropy_loss_weight > 0,
        'entropy_temp': opts.entropy_temp,
        'num_codebooks': opts.num_codebooks,
        # Losses settings.
        'use_kl_loss': opts.kl_loss_weight > 0,
        'use_vf_loss': opts.vf_loss_weight > 0,
        'use_adaptive_vf_loss': opts.use_adaptive_vf_loss,
        'distmat_margin': opts.distmat_margin,
        'cos_margin': opts.cos_margin,
        'distmat_weight': opts.distmat_weight,
        'cos_weight': opts.cos_weight,
    }

    # Concatenated z settings.
    assert any([isinstance(res, int) for res in opts.concat_z_resolutions]), "All resolutions must be integers."
    if opts.concat_z_mapped_dims:
        assert len(opts.concat_z_mapped_dims) == len(opts.concat_z_resolutions), "Each resolution must have a corresponding mapped dimension."
        dist.print0("Using manually defined mapped dimensions for concatenated z.")
    else:
        dist.print0("Using default mapped dimensions for concatenated z.")
    concat_z_kwargs = {
        'concat_z_resolutions': opts.concat_z_resolutions,
        'concat_z_mapped_dims': opts.concat_z_mapped_dims,
        'how_to_process_concat_z': opts.how_to_process_concat_z,
        'activation_for_concat_z': opts.activation_for_concat_z,
    }

    assert not opts.use_cross_attn if opts.label_type == 'cls2id' else True, "Cross-attention is not supported for cls2id label type."
    assert len(opts.attn_resolutions) == len(opts.attn_depths), "attn_resolutions and attn_depths must have the same length."
    arch_kwargs = {
        'use_multiscale_output': len(opts.multiscale_resolutions) > 0,
        'attn_resolutions': opts.attn_resolutions,
        'attn_depths': opts.attn_depths,
        'use_self_attn': opts.use_self_attn,
        'use_cross_attn': opts.use_cross_attn,
        'use_convnext': opts.use_convnext,
        'use_gaussian_blur': opts.use_gaussian_blur,
        'add_additional_convnext': opts.add_additional_convnext,
    }

    equivariance_regularization_kwargs = {
        'use_equivariance_regularization': opts.use_equivariance_regularization,
        'equivariance_regularization_p_prior': 0.5,
        'equivariance_regularization_p_prior_scale': 0.25,
    }

    output_image_kwargs = {
        'img_resolution': opts.output_img_resolution,   # resolution of the output image
        'img_channels': 3,                              # number of channels in the output image
    }

    training_kwargs = {
        'train_mode': opts.train_mode,                  # training mode: all, text_encoder, freeze32
        'num_fp16_res': opts.num_fp16_res,              # number of resolutions to use fp16 from the end to the beginning
        'base_mult': opts.base_mult,                    # start resolution = 2 ** {base_mult}
    }

    c.G_kwargs = dnnlib.EasyDict(
        class_name='networks.generator.Generator', # the mapping network is slightly different from the StyleGAN-T
        **conditioning_kwargs,
        **vfm_kwargs,
        **ldm_kwargs,
        **concat_z_kwargs,
        **arch_kwargs,
        **equivariance_regularization_kwargs,
        **output_image_kwargs,
        **training_kwargs,
    )

    # Synthesis settings.
    cfg_synthesis = {
        'large':    dnnlib.EasyDict(channel_base=65536, channel_max=1024, num_res_blocks=4),
        'base':     dnnlib.EasyDict(channel_base=32768, channel_max=1024, num_res_blocks=3),
        'lite':     dnnlib.EasyDict(channel_base=32768, channel_max=512,  num_res_blocks=2),
        'custom':   dnnlib.EasyDict(channel_base=opts.cbase, channel_max=opts.cmax, num_res_blocks=opts.res_blocks),
    }
    c.G_kwargs.synthesis_kwargs = cfg_synthesis[opts.cfg]
    c.G_kwargs.synthesis_kwargs.architecture = 'skip'

    # Optimizer.
    # Learning rate should be set according to the total batch size: lr = 0.002 is great for 'lite' with image size 256, batch size per GPU = 8 on 4 A100 GPUs.
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=(0.0, 0.99), eps=1e-8, lr=0.002 * opts.lr_multiplier)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=(0.0, 0.99), eps=1e-8, lr=0.002 * opts.lr_multiplier)
    if c.G_kwargs.train_mode == 'freeze32':
        c.G_opt_kwargs.lr = 3e-6

    # Loss.
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.ProjectedGANLoss')

    c.loss_kwargs.vfm_name = opts.vfm_name # align the preprocessing of images with the VFM name
    c.loss_kwargs.resume_kimg = opts.resume_kimg
    c.loss_kwargs.output_img_resolution = opts.output_img_resolution
    c.loss_kwargs.use_equivariance_regularization = opts.use_equivariance_regularization
    c.loss_kwargs.blur_init_sigma = opts.blur_init
    c.loss_kwargs.blur_fade_kimg = opts.blur_fade_kimg

    # Pixel loss settings.
    c.loss_kwargs.l1_pixel_loss_weight = opts.l1_pixel_loss_weight
    c.loss_kwargs.l2_pixel_loss_weight = opts.l2_pixel_loss_weight
    c.loss_kwargs.pixel_loss_resolution = opts.pixel_loss_resolution

    # Perceptual loss settings.
    c.loss_kwargs.perceptual_loss_weight = opts.perceptual_loss_weight

    # SSIM loss settings.
    c.loss_kwargs.ssim_loss_weight = opts.ssim_loss_weight

    # Multi-scale pixel loss L1 settings.
    c.loss_kwargs.multiscale_pixel_loss_weights = opts.multiscale_pixel_loss_weights
    c.loss_kwargs.multiscale_resolutions = opts.multiscale_resolutions
    c.loss_kwargs.multiscale_pixel_loss_start_kimg = opts.multiscale_pixel_loss_start_kimg
    c.loss_kwargs.multiscale_pixel_loss_end_kimg = opts.multiscale_pixel_loss_end_kimg

    # VF loss settings.
    c.loss_kwargs.vf_loss_weight = opts.vf_loss_weight
    c.loss_kwargs.use_adaptive_vf_loss = opts.use_adaptive_vf_loss

    # CLIP loss settings.
    assert opts.clip_loss_weight == 0. if opts.label_type == 'cls2id' else True, "CLIP loss is not supported for cls2id label type."
    c.loss_kwargs.clip_loss_weight = opts.clip_loss_weight
    c.loss_kwargs.clip_loss_start_kimg = opts.clip_loss_start_kimg

    # Matching aware loss settings.
    c.loss_kwargs.matching_aware_loss_weight = opts.matching_aware_loss_weight
    c.loss_kwargs.matching_aware_loss_start_kimg = opts.matching_aware_loss_start_kimg

    # KL loss settings.
    c.loss_kwargs.compression_mode = opts.compression_mode
    c.loss_kwargs.kl_loss_weight = opts.kl_loss_weight
    c.loss_kwargs.entropy_loss_weight = opts.entropy_loss_weight
    c.loss_kwargs.vq_loss_weight = opts.vq_loss_weight

    # Discriminator loss settings.
    c.loss_kwargs.stylegan_t_discriminator_loss_weight = opts.stylegan_t_discriminator_loss_weight
    c.loss_kwargs.patchgan_discriminator_loss_weight = opts.patchgan_discriminator_loss_weight
    
    # PatchGAN's feature matching loss settings.
    c.loss_kwargs.feature_matching_loss_weight = opts.feature_matching_loss_weight

    # Discriminator warm-up settings.
    c.loss_kwargs.use_stylegan_t_disc_warmup = opts.use_stylegan_t_disc_warmup
    c.loss_kwargs.use_patchgan_disc_warmup = opts.use_patchgan_disc_warmup

    # Total training duration.
    c.loss_kwargs.total_kimg = opts.kimg

    # Force phase training settings.
    # ------------------------------------------------------------------------------
    # 2-Phase Training Strategy (auto or manual via `force_phase_training`):
    #
    # Phase 0 – Pixel Reconstruction Warmup:
    #     • StyleGAN-T Discriminator ON
    #     • Optimize: Pixel losses + LPIPS + SSIM
    #     • Trigger to Phase 1: D loss becomes stable
    #         → the absolute difference between the two windows of D loss is less than threshold
    #
    # Phase 1 – Detail Enhancement (StyleGAN-T + PatchGAN):
    #     • Both Discriminators ON
    #     • Freeze generator blocks ≤ 32px resolution
    #     • Optimize: disable all reconstruction losses, including pixel losses, LPIPS, and SSIM
    #
    # `force_phase_training` will automatically set the training phases based on the loss curves.
    #
    # ------------------------- Phase Control Parameters --------------------------
    #
    # The following CLI argument controls how phases are entered:
    #
    #   --use_patchgan_discriminator:
    #       - Only meaningful if force_phase == -2 (custom mode).
    #       - Enables PatchGAN discriminator during training.
    #
    #   --pixel_loss_resolution <int>:
    #       - If set (e.g., 64), input and generated images are downsampled to this
    #         resolution before computing pixel-based losses.
    #       - Useful in Phase 1 or fine-tuning when high-frequency generation is desired,
    #         and pixel alignment is no longer strictly enforced.
    #
    # Typical usage:
    #   - For full 2-phase training from scratch:
    #       --force_phase_training True
    #
    #   - For manual fine-tuning with PatchGAN:
    #       --force_phase_training False -- use_patchgan_discriminator True
    #
    #   - For static reconstruction-only training (no GANs):
    #       --force_phase_training False -- use_patchgan_discriminator False
    # ------------------------------------------------------------------------------
    c.D_kwargs.force_phase_training = opts.force_phase_training     # whether to use the related discriminator
    c.loss_kwargs.force_phase_training = opts.force_phase_training  # whether to use 2-phase training strategy

    # Data.
    # Training set: a class-to-text mapping is required for the 'cls2text' and 'cls2id' label types.
    assert opts.cls_to_text_path is not None if opts.label_type in ['cls2text', 'cls2id'] else True, "cls_to_text_path is required for cls2text and cls2id label types."
    c.training_set_kwargs = init_dataset_kwargs(
        data=opts.training_data,
        resolution=opts.input_img_resolution, 
        label_type=opts.label_type,
        filter_keys_path=opts.filter_keys_path,
        cls_to_text_path=opts.cls_to_text_path,
        data_augmentation=True if opts.label_type in ['cls2text', 'cls2id'] else False,
        one_epoch=opts.one_epoch,
        processed_tar_read_dir = os.path.join(os.path.dirname(opts.resume), 'processed_tars') if opts.resume is not None and opts.one_epoch else None,
        processed_tar_write_dir = os.path.abspath(os.path.join(opts.outdir, 'processed_tars')) if opts.one_epoch else None,
    )
    c.training_data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)

    # Validation: same as the training set except for the data path; only image-folder datasets are supported.
    if opts.validation_data is not None:
        c.validation_set_kwargs = init_dataset_kwargs(
            data=opts.validation_data, 
            resolution=opts.input_img_resolution,
        )
        c.validation_data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
    
    else:
        c.validation_set_kwargs = {}
        c.validation_data_loader_kwargs = {}

    # Logging.
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.image_snapshot_ticks = opts.image_snap
    c.network_snapshot_ticks = opts.network_snap
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.kimg_per_tick = opts.tick

    # GPUs and batch size.
    c.batch_size = opts.batch
    c.batch_gpu = opts.batch_gpu
    c.ema_kimg = c.batch_size * 10 / 32

    # Sanity checks.
    if c.batch_size % dist.get_world_size() != 0:
        raise click.ClickException('--batch must be a multiple of --gpus')
    if c.batch_size % (dist.get_world_size() * c.batch_gpu) != 0:
        raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        err = ['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()
        raise click.ClickException('\n'.join(err))

    # Resume.
    if opts.resume is not None:
        c.resume_pkl = opts.resume
        c.resume_kimg = opts.resume_kimg
        c.resume_discriminator = opts.resume_discriminator
        c.ema_rampup = None  # disable EMA rampup

    # Performance-related toggles.
    if opts.fp32:
        c.G_kwargs.num_fp16_res = 0
        c.G_kwargs.conv_clamp = None
    if opts.nobench:
        c.cudnn_benchmark = False

    # Output directory
    c.run_dir=opts.outdir
    c.train_sample_dir=os.path.join(c.run_dir, 'train_samples')

    # Print options.
    dist.print0()
    dist.print0('Training options:')
    dist.print0(json.dumps(c, indent=2))
    dist.print0(f'Output directory:         {c.run_dir}')
    dist.print0(f'Number of GPUs:           {dist.get_world_size()}')
    dist.print0(f'Batch size:               {c.batch_size} images')
    dist.print0(f'Training duration:        {c.total_kimg} kimg')
    dist.print0(f'Dataset path:             {c.training_set_kwargs.path}')
    dist.print0(f'Dataset resolution:       {c.training_set_kwargs.resolution}')
    dist.print0(f'Dataset labels:           {c.training_set_kwargs.use_labels}')
    dist.print0(f'Label type:               {c.training_set_kwargs.label_type}')
    dist.print0(f'Label filter keys:        {c.training_set_kwargs.filter_keys_path}')
    dist.print0(f'Dataset augmentations:    {c.training_set_kwargs.data_augmentation}')
    dist.print0()

    # Dry run?
    if opts.dry_run:
        dist.print0('Dry run; exiting.')
        return

    # Create output directory.
    dist.print0('Creating output directory...')
    if dist.get_rank() == 0:
        os.makedirs(c.run_dir, exist_ok=True)                                       # create output directory
        os.makedirs(c.train_sample_dir, exist_ok=True)                              # create train sample directory
        with open(os.path.join(c.run_dir, 'training_options.json'), 'wt+') as f:     
            json.dump(c, f, indent=2)
        dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
    else:
        custom_ops.verbosity = 'none'

    # Train.
    training_loop.training_loop(**c)

# Script entry point: run the CLI only when this file is executed directly, not when imported.
# NOTE(review): main() is called with no arguments; the pylint suppression suggests its
# parameters are injected by a decorator (presumably click) — confirm against main's definition.
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
