# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Train diffusion-based generative model using the techniques described in the
paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""

import os
import sys
import re
import json
import click
import wandb
import torch
import yaml
import dnnlib
from torch_utils import distributed as dist
from training import training_loop
from training.datasets import *
import warnings
warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.


# Load configuration file
def load_and_update_config(opts):
    """Merge a YAML experiment config into the parsed command-line options.

    The file is looked up as ``./configs/<opts.config>``. For every key in the
    file:

    * if the key is a Click option that was left at its default, the config
      value replaces it;
    * if the key is a Click option that was set explicitly, the command-line
      value wins;
    * if the key is not a Click option at all, it is attached to ``opts`` as a
      new attribute.

    Afterwards ``outdir``, ``data`` and ``validate_data`` are resolved against
    the server-specific path prefixes (unless given explicitly on the command
    line) and the total batch size is derived from ``num_gpus * batch_gpu``.

    Must be called while a Click context is active, because the parameter
    sources are read from ``click.get_current_context()``.

    Args:
        opts: Attribute-style options object built from the Click kwargs.
            It is modified in place.

    Returns:
        The same ``opts`` object, updated.

    Raises:
        click.ClickException: If the config file is missing or is not a
            mapping, if the server is unknown, or if a required dataset path
            is given neither on the command line nor in the config.
    """
    config_path = os.path.join('./configs', opts.config)
    if not os.path.exists(config_path):
        raise click.ClickException(f'Config file not found: {config_path}')

    with open(config_path, 'r') as f:
        # An empty YAML document parses to None; treat it as "no overrides".
        config = yaml.safe_load(f) or {}
    if not isinstance(config, dict):
        raise click.ClickException(f'Config file must contain a top-level mapping: {config_path}')

    # Record where each Click parameter's value came from, so that only
    # parameters still at their default are overridden by the config file.
    ctx = click.get_current_context()
    cli_params = {param.name: ctx.get_parameter_source(param.name) for param in ctx.command.params}
    default_source = click.core.ParameterSource.DEFAULT

    # Server-specific path prefixes: (project_prefix, scratch_prefix).
    server_prefixes = {
        'server': ("/projects/physics_diffusion", "/scratch"),
    }
    if opts.server not in server_prefixes:
        raise click.ClickException(f'Unknown server: {opts.server}')
    project_prefix, scratch_prefix = server_prefixes[opts.server]

    # Update opts with config values ONLY if they weren't provided via command line.
    for key, value in config.items():
        if key not in cli_params or cli_params[key] == default_source:
            setattr(opts, key, value)

    # Derive the output directory from the experiment description.
    if cli_params.get('outdir') == default_source:
        outdir_prefix = f"{scratch_prefix}/physics_diffusion_exp/experiments_{opts.user}/"
        opts.outdir = os.path.join(
            outdir_prefix,
            f"{config.get('main_dataset', 'unknown')}_data",
            config.get('equation', 'unknown'),
            opts.mode,
            opts.model_type,
            config.get('run_name', 'default_run'),
        )

    # Dataset paths in the config are relative to the project prefix.
    if cli_params.get('data') == default_source:
        if config.get('data') is None:
            raise click.ClickException("No dataset path: pass --data or set 'data' in the config file")
        opts.data = os.path.join(project_prefix, config['data'])
    if cli_params.get('validate_data') == default_source:
        if config.get('validate_data') is not None:
            opts.validate_data = os.path.join(project_prefix, config['validate_data'])
        elif opts.validate_mode:
            raise click.ClickException(
                "Validation is enabled but no validation set was given: "
                "pass --validate_data or set 'validate_data' in the config file")

    # Set batch size based on number of GPUs and batch per GPU. --batch-gpu has
    # no default, so keep the configured total batch size when it is unset.
    if opts.batch_gpu is not None:
        opts.batch = opts.num_gpus * opts.batch_gpu
    print(f"opts = {opts}")
    return opts


# Assemble the masking-strategy dictionary from the merged options.
def build_masking_strategy(opts):
    """Collect the sparse-conditioning masking settings into one dictionary.

    Every setting is read from ``opts`` by attribute name and falls back to a
    built-in default when the attribute is absent.

    Args:
        opts: Options object. ``opts.use_sparse_conditioning`` must exist; all
            other attributes are optional.

    Returns:
        ``None`` when ``opts.use_sparse_conditioning`` is falsy. Otherwise a
        dict with the base masking settings, plus a nested
        ``'sparsity_curriculum'`` and/or ``'sample_curriculum'`` dict when the
        corresponding ``enable_*`` flag is truthy.
    """
    if not opts.use_sparse_conditioning:
        return None

    def option(name, fallback):
        # Missing attributes resolve to the documented default.
        return getattr(opts, name, fallback)

    # (option name, default) pairs, in the order they appear in the result.
    base_fields = (
        ('random_sample_masking', False),
        ('final_mask_sample_rate', 0.0),
        ('final_mask_observation_rate', 1.0),
        ('enable_sparsity_curriculum', False),
        ('enable_sample_curriculum', False),
        ('fill_strategy', 'mean'),
        ('noise_scale', 0.01),
    )
    strategy = {name: option(name, fallback) for name, fallback in base_fields}

    # (enabling flag, nested key, nested (option name, default) pairs).
    curricula = (
        ('enable_sparsity_curriculum', 'sparsity_curriculum', (
            ('initial_obs_rate', 1.0),
            ('sparsity_curriculum_kimg', 0),
            ('sparsity_schedule', 'linear'),
        )),
        ('enable_sample_curriculum', 'sample_curriculum', (
            ('initial_sample_rate', 1.0),
            ('sample_curriculum_kimg', 0),
            ('sample_schedule', 'linear'),
        )),
    )
    for flag, section, fields in curricula:
        if strategy[flag]:
            strategy[section] = {name: option(name, fallback) for name, fallback in fields}

    return strategy
#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    """Parse a comma-separated list of integers and inclusive ranges.

    Whitespace around each comma-separated token is ignored, so
    ``'1, 2, 5-10'`` parses the same as ``'1,2,5-10'``.

    Args:
        s: Either a string such as ``'1,2,5-10'`` or an already-parsed list,
            which is returned unchanged.

    Returns:
        A list of ints; ``'1,2,5-10'`` yields ``[1, 2, 5, 6, 7, 8, 9, 10]``.

    Raises:
        ValueError: If a token is neither an integer nor an ``A-B`` range.
    """
    if isinstance(s, list):
        return s
    values = []
    for token in s.split(','):
        token = token.strip()
        m = re.fullmatch(r'(\d+)-(\d+)', token)
        if m:
            first, last = int(m.group(1)), int(m.group(2))
            values.extend(range(first, last + 1))
        else:
            values.append(int(token))
    return values

#----------------------------------------------------------------------------

@click.command()

# New configuration options
@click.option('--config',        help='Path to the configuration file', metavar='PATH',             type=str, required=True)
@click.option('--user',          help='Username for path configuration', metavar='STR',             type=str, required=True)
@click.option('--server',        help='Server name ', metavar='STR',                  type=click.Choice(['server']), required=True)
@click.option('--num_gpus',      help='Number of GPUs to use', metavar='INT',                      type=int, required=True)

# Main options.
@click.option('--outdir',        help='Where to save the results', metavar='DIR',                   type=str)
@click.option('--data',          help='Path to the dataset', metavar='STR',                         type=str)
@click.option('--dataset',       help='Dataset name', metavar='STR',                                type=str)
@click.option('--training_mode', help='Training mode: conditional or unified', metavar='STR',       type=click.Choice(['conditional', 'unified']), default='conditional', show_default=True)
@click.option('--mode',          help='PDE modeling mode: forward or inverse', metavar='STR',       type=str, default="forward", show_default=True)
@click.option('--resolution',    help='Spatial resolution', metavar='INT',                          type=int)
@click.option('--cond',          help='Train class-conditional model', metavar='BOOL',              type=bool, default=False, show_default=True)
@click.option('--arch',          help='Network architecture', metavar='ddpmpp|ncsnpp|adm',          type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True)
@click.option('--precond',       help='Preconditioning & loss function', metavar='vp|ve|edm',       type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True)
@click.option('--model_type',    help='Model Type for Diffusion Sampling', metavar='SongUNO|SongUNet|SongUNOResidual', type=click.Choice(['SongUNO', 'SongUNet', 'SongUNOResidual']), default='SongUNet', show_default=True)
@click.option('--sigma_data',    help='Sigma data for EDM loss', metavar='FLOAT',                   type=click.FloatRange(min=0, min_open=True), default=0.5, show_default=True)
@click.option('--train_downsample', help='Train on downsampled data', metavar='BOOL',               type=bool, default=False, show_default=True)
@click.option('--normalizer',    help='Type of normalizer to use', metavar='UnitGaussian|ScaledGaussian|ScaledGaussian2',type=click.Choice(['UnitGaussian', 'ScaledGaussian', 'ScaledGaussian2']), default='UnitGaussian', show_default=True)

# COND DIFF EDM RESIDUAL Specific options
@click.option('--pde-residual-mode', help='Mode for incorporating residuals', metavar='concat|freq_mag|freq_complex|freq_attn_real|freq_attn|hybrid|residual_guided_freq_mask|no_pde_res', type=click.Choice(['concat', 'freq_mag', 'freq_complex', 'freq_attn_real', 'freq_attn', 'hybrid', 'residual_guided_freq_mask', 'no_pde_res']), default='concat', show_default=True)
@click.option('--pde-residual-mode-secondary', help='Secondary mode for incorporating residuals with concat', metavar='concat|freq_mag|freq_complex|freq_attn_real|freq_attn|hybrid', type=click.Choice(['concat', 'freq_mag', 'freq_complex', 'freq_attn_real', 'freq_attn', 'hybrid']), default='concat', show_default=True)
@click.option('--pde-residual-step-mode', help='Step mode for training', metavar='one_step|two_step', type=click.Choice(['one_step', 'two_step']), default='two_step', show_default=True)
@click.option('--pde-residual-gate-type', help='Gate type for PDE residual: scalar|spatial',                        type=click.Choice(['scalar','spatial']), default='scalar', show_default=True)
@click.option('--use-alpha', help='Use learnable alpha for pde res attn', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--use-gating', help='Use gating for PDE residual', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--gating-mode', help='Gating mode for PDE residual: 1|2|3|4|5', metavar='INT', type=click.IntRange(min=1, max=5), default=1, show_default=True)
@click.option('--do_conjugate', help='Use conjugate symmetry in freq_complex mode', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--normalize-pde-residual', help='Normalize PDE residual', metavar='BOOL', type=bool, default=False, show_default=True   )
@click.option('--spectral-inject-pos', help='Position for spectral injection: pre|post', metavar='STR', type=click.Choice(['pre', 'post']), default='post', show_default=True   )
@click.option('--spatial-film-pos', help='Position for spatial FILM: none|pre|post', metavar='STR', type=click.Choice(['none', 'pre', 'post']), default='pre', show_default=True   )
@click.option('--guided-pde-residual-mode', help='Use ground truth to guide PDE residual computation in unified mode', metavar='BOOL', type=bool, default=False, show_default=True)

@click.option('--use-sparse-conditioning', help='Enable sparse conditioning with masks', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--unified-task-prob-full-fwd', help='Probability of full forward task (a->u) in unified mode', type=click.FloatRange(min=0, max=1), default=0.25, show_default=True)
@click.option('--unified-task-prob-full-inv', help='Probability of full inverse task (u->a) in unified mode', type=click.FloatRange(min=0, max=1), default=0.25, show_default=True)
@click.option('--unified-task-prob-sparse-fwd', help='Probability of sparse forward task (a->u) in unified mode', type=click.FloatRange(min=0, max=1), default=0.2, show_default=True)
@click.option('--unified-task-prob-sparse-inv', help='Probability of sparse inverse task (u->a) in unified mode', type=click.FloatRange(min=0, max=1), default=0.2, show_default=True)
@click.option('--unified-task-prob-uncond', help='Probability of unconditional task in unified mode', type=click.FloatRange(min=0, max=1), default=0.1, show_default=True)
@click.option('--sparse-obs-range-start', help='Start of observation rate range for sparse tasks', type=click.FloatRange(min=0, max=1), default=0.03, show_default=True)
@click.option('--sparse-obs-range-end', help='End of observation rate range for sparse tasks', type=click.FloatRange(min=0, max=1), default=0.5, show_default=True)


# Hyperparameters.
@click.option('--duration',      help='Training duration', metavar='MIMG',                          type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
@click.option('--batch',         help='Total batch size', metavar='INT',                            type=click.IntRange(min=1), default=512, show_default=True)
@click.option('--batch-gpu',     help='Limit batch size per GPU', metavar='INT',                    type=click.IntRange(min=1))
@click.option('--lr',            help='Learning rate', metavar='FLOAT',                             type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
@click.option('--lr-rampup',     help='Learning rate rampup', metavar='FLOAT',                      type=click.FloatRange(min=0, min_open=True), default=10, show_default=True)
@click.option('--weight-decay',  help='Learning rate weight decay', metavar='FLOAT',                type=click.FloatRange(min=0, min_open=False), default=0.0, show_default=True)
@click.option('--ema',           help='EMA half-life', metavar='MIMG',                              type=click.FloatRange(min=0), default=0.5, show_default=True)
@click.option('--dropout',       help='Dropout probability', metavar='FLOAT',                       type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
@click.option('--augment',       help='Augment probability', metavar='FLOAT',                       type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
@click.option('--xflip',         help='Enable dataset x-flips', metavar='BOOL',                     type=bool, default=False, show_default=True)
# Validation options.
@click.option('--validate_mode', help='Validate during training', metavar='BOOL',                   type=bool, default=False, show_default=True)
@click.option('--validate_data', help='Path to val dataset', metavar='STR',                         type=str, default=None)

# Architecture-related options
@click.option('--channel-mult',   help='Channel multipliers per resolution', metavar='LIST',        type=parse_int_list, default=[2, 2, 2], show_default=True)
@click.option('--num-blocks',     help='Number of blocks per layer', metavar='INT',                 type=click.IntRange(min=1), default=4, show_default=True)
@click.option('--model-channels', help='Base number of channels for the model', metavar='INT', type=click.IntRange(min=1), default=128, show_default=True)
# UNO based architecture's specific options
@click.option('--fmult',          help='Frequency multiplier to determine modes', metavar='FLOAT',  type=click.FloatRange(0, 1, min_open=True), default=1.0, show_default=True)
@click.option('--spectral-conv',  help='Type of spectral convolution', metavar='STR',               type=click.Choice(['standard', 'tucker']), default='standard', show_default=True)
@click.option('--rank',           help='Rank for factorisation of weight matrices', metavar='FLOAT',type=click.FloatRange(0, 1, min_open=True), default=1.0, show_default=True)
@click.option('--rbf-scale',      help='RBF kernel scale for noise sampler', metavar='FLOAT',        type=click.FloatRange(min=0, min_open=True), default=0.05, show_default=True)
@click.option('--noise_src',     help='Noise source Gaussian vs Gaussian Random Field: gauss|grf', type=click.Choice(['gauss','grf']), default='grf', show_default=True)

# Performance-related.
@click.option('--fp16',          help='Enable mixed-precision training', metavar='BOOL',            type=bool, default=False, show_default=True)
@click.option('--ls',            help='Loss scaling', metavar='FLOAT',                              type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
@click.option('--bench',         help='Enable cuDNN benchmarking', metavar='BOOL',                  type=bool, default=True, show_default=True)
@click.option('--cache',         help='Cache dataset in CPU memory', metavar='BOOL',                type=bool, default=True, show_default=True)
@click.option('--workers',       help='DataLoader worker processes', metavar='INT',                 type=click.IntRange(min=1), default=1, show_default=True)
@click.option('--use_fast_math', help='Enable torch.backends.cudnn.allow_tf32 and torch.backends.cuda.matmul.allow_tf32', metavar='BOOL', type=bool, default=True, show_default=True)
# I/O-related.
@click.option('--desc',          help='String to include in result dir name', metavar='STR',        type=str)
@click.option('--nosubdir',      help='Do not create a subdirectory for results',                   is_flag=True)
@click.option('--tick',          help='How often to print progress', metavar='KIMG',                type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--snap',          help='How often to save snapshots', metavar='TICKS',               type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--dump',          help='How often to dump state', metavar='TICKS',                   type=click.IntRange(min=1), default=500, show_default=True)
@click.option('--pde-plot-ticks', help='How often to plot PDE residual', metavar='TICKS',        type=click.IntRange(min=1), default=10, show_default=True)
@click.option('--seed',          help='Random seed  [default: random]', metavar='INT',              type=int)
@click.option('--transfer',      help='Transfer learning from network pickle', metavar='PKL|URL',   type=str)
@click.option('--resume',        help='Resume from previous training state', metavar='PT',          type=str)
@click.option('-n', '--dry-run', help='Print training options and exit',                            is_flag=True)
@click.option('--debug',         help='Print statements for architecture/debug mode',               is_flag=True)
@click.option('--wandb_mode',    help='WandB mode', metavar='online|offline|disabled',              type=str, default='online', show_default=True)
@click.option('--wandb_project', help='WandB project', metavar='STR',                               type=str, default='PgDNO', show_default=True)
@click.option('--wandb_team',    help='WandB team', metavar='STR',                                  type=str, default='physics-diffusion', show_default=True)

def main(**kwargs):
    """Train diffusion-based generative model using the techniques described in the
    paper "Elucidating the Design Space of Diffusion-Based Generative Models".

    Options left at their defaults are filled in from the YAML file named by
    --config (looked up under ./configs); options given explicitly on the
    command line take precedence over the file.
    """
    opts = dnnlib.EasyDict(kwargs)
    torch.multiprocessing.set_start_method('spawn')
    dist.init()
    opts = load_and_update_config(opts)

    # Initialize config dict; its entries become the keyword arguments of
    # training_loop.training_loop().
    c = dnnlib.EasyDict()

    # Random seed. Every rank must pass a tensor of the same dtype and shape to
    # broadcast(), so the placeholder on non-zero ranks is int64 like the
    # randint() result on rank 0.
    device = torch.device('cuda')
    if opts.seed is not None:
        seed = torch.tensor([opts.seed], dtype=torch.int64, device=device)
    elif dist.get_rank() == 0:
        # Only generate a random seed on the main process.
        seed = torch.randint(1 << 31, size=[1], device=device)
    else:
        seed = torch.zeros([1], dtype=torch.int64, device=device)

    # Make sure all processes have the same seed.
    torch.distributed.broadcast(seed, src=0)
    c.seed = int(seed.item())

    # Initialize wandb on the main process only.
    if dist.get_rank() == 0:
        wandb.init(
            config=opts,
            # run_name exists only when the config file defines it (there is no
            # --run_name option); use the same fallback as the output path.
            name=getattr(opts, 'run_name', 'default_run'),
            mode=opts.wandb_mode,
            project=opts.wandb_project,
            entity=opts.wandb_team,
        )
        wandb.run.log_code(root=".")

    # Dataset.
    DatasetClass = get_dataset_class(opts.dataset)
    masking_strategy = build_masking_strategy(opts)
    # Configure dataset_kwargs using the selected class
    c.dataset_kwargs = dnnlib.EasyDict(class_name=f"{DatasetClass.__module__}.{DatasetClass.__name__}", path=opts.data,  # dataset path
        use_labels=opts.cond, pde_direction = opts.mode, resolution=opts.resolution, train_downsample=opts.train_downsample, normalizer=opts.normalizer,
        training_mode=opts.training_mode, use_sparse_conditioning=opts.use_sparse_conditioning, masking_strategy=masking_strategy)
    if opts.training_mode == 'unified':
        # Task-sampling probabilities and the observation-rate range shared by
        # both sparse tasks.
        unified_regime = {
            'task_probs': {
                'full_fwd': opts.unified_task_prob_full_fwd,
                'full_inv': opts.unified_task_prob_full_inv,
                'sparse_fwd': opts.unified_task_prob_sparse_fwd,
                'sparse_inv': opts.unified_task_prob_sparse_inv,
                'uncond': opts.unified_task_prob_uncond,
            },
            'sparse_obs_rates': {
                'sparse_fwd': [opts.sparse_obs_range_start, opts.sparse_obs_range_end],
                'sparse_inv': [opts.sparse_obs_range_start, opts.sparse_obs_range_end],
            }
        }
        c.dataset_kwargs.update(unified_regime=unified_regime)

    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)

    # Validate dataset options by constructing the dataset once; keep only its
    # name and normalizers.
    try:
        dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
        dataset_name = dataset_obj.name
        normalizer_data = dataset_obj.get_normalizers()
        del dataset_obj # conserve memory
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

    c.network_kwargs = dnnlib.EasyDict()
    c.loss_kwargs = dnnlib.EasyDict()
    c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8, weight_decay=opts.weight_decay)
    c.sampler_kwargs = dnnlib.EasyDict(class_name="training.noise_samplers.RBFKernel", scale=opts.rbf_scale)

    # Network architecture.
    if opts.arch == 'ddpmpp':
        c.network_kwargs.update(model_type=opts.model_type, embedding_type='positional', encoder_type='standard', decoder_type='standard')
        c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1])
    elif opts.arch == 'ncsnpp':
        c.network_kwargs.update(model_type=opts.model_type, embedding_type='fourier', encoder_type='residual', decoder_type='standard')
        c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2])
    else:
        assert opts.arch == 'adm'
        c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])

    # Update `fmult` and `spectral_conv` if model_type is SongUNO
    if opts.model_type == 'SongUNO':
        c.network_kwargs.update(fmult=opts.fmult, spectral_conv=opts.spectral_conv, rank=opts.rank)
    if opts.model_type == 'SongUNOResidual':
        c.network_kwargs.update(fmult=opts.fmult, spectral_conv=opts.spectral_conv, training_mode=opts.training_mode,
                                pde_residual_mode=opts.pde_residual_mode, pde_residual_mode_secondary=opts.pde_residual_mode_secondary,
                                residual_gate_type=opts.pde_residual_gate_type, use_alpha=opts.use_alpha, gating_mode=opts.gating_mode,
                                use_gating=opts.use_gating, do_conjugate=opts.do_conjugate,
                                rank=opts.rank, use_sparse_conditioning=opts.use_sparse_conditioning,
                                spectral_inject_pos=opts.spectral_inject_pos, spatial_film_pos=opts.spatial_film_pos)

    # Network options.
    # NOTE(review): these three options have non-None CLI defaults, so the
    # arch-specific model_channels/channel_mult set above are replaced unless
    # the config file sets the option to null - confirm this is intended.
    if opts.model_channels is not None:
        c.network_kwargs.model_channels = opts.model_channels
    if opts.channel_mult is not None:
        c.network_kwargs.channel_mult = opts.channel_mult
    if opts.num_blocks is not None:
        c.network_kwargs.num_blocks = opts.num_blocks
    c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)

    # Preconditioning & loss function.
    if opts.precond == 'vp':
        c.network_kwargs.class_name = 'training.networks.VPPrecond'
        c.loss_kwargs.class_name = 'training.loss.VPLoss'
    elif opts.precond == 've':
        c.network_kwargs.class_name = 'training.networks.VEPrecond'
        c.loss_kwargs.class_name = 'training.loss.VELoss'
    else:
        assert opts.precond == 'edm'
        c.network_kwargs.class_name = 'training.networks.EDMPrecond'
        c.loss_kwargs.class_name = 'training.loss.EDMLossWrapper'
        c.loss_kwargs.update(loss_type="edm")

    # The residual model always uses the EDM preconditioner and the residual
    # loss, regardless of --precond.
    if opts.model_type == 'SongUNOResidual':
        c.network_kwargs.class_name = 'training.networks.EDMPrecond'
        c.loss_kwargs.class_name = 'training.loss.EDMLossWrapper'
        c.loss_kwargs.update(loss_type="edm_residual",  training_mode=opts.training_mode, pde_direction=opts.mode, pde_residual_step_mode=opts.pde_residual_step_mode,
                             noise_src=opts.noise_src, normalize_pde_residual=opts.normalize_pde_residual, guided_pde_residual_mode=opts.guided_pde_residual_mode,)
        # NOTE(review): --sigma_data is forwarded only for SongUNOResidual;
        # confirm it should not also reach the plain EDM loss.
        c.loss_kwargs.update(sigma_data=opts.sigma_data)

    # Training options. Durations given in MIMG are converted to KIMG.
    c.total_kimg = max(int(opts.duration * 1000), 1)
    c.ema_halflife_kimg = int(opts.ema * 1000)
    c.lr_rampup_kimg = int(opts.lr_rampup * 1000)
    c.use_fast_math = opts.use_fast_math
    c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
    c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
    c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump, pde_plot_ticks=opts.pde_plot_ticks)
    c.update(validate_mode=opts.validate_mode, validate_data=opts.validate_data)

    # Transfer learning and resume.
    if opts.transfer is not None:
        if opts.resume is not None:
            raise click.ClickException('--transfer and --resume cannot be specified at the same time')
        c.resume_pkl = opts.transfer
        c.ema_rampup_ratio = None
    elif opts.resume is not None:
        match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
        if not match or not os.path.isfile(opts.resume):
            raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
        # The network snapshot with the same numeric suffix is read from the
        # directory that holds the training state.
        c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
        c.resume_kimg = int(match.group(1))
        c.resume_state_dump = opts.resume
    if opts.debug:
        c.debug = True
        c.loss_kwargs.update(debug=True)
        c.network_kwargs.update(debug=True)
    else:
        c.debug=False

    # Description string.
    cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
    dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
    desc = f'{str(dataset_name):s}-{cond_str:s}-{opts.arch:s}-{opts.model_type}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Pick output directory. Only rank 0 gets one; new runs are numbered one
    # past the highest numeric prefix already present in outdir.
    if dist.get_rank() != 0:
        c.run_dir = None
    elif opts.nosubdir:
        c.run_dir = opts.outdir
    else:
        prev_run_dirs = []
        if os.path.isdir(opts.outdir):
            prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
        prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
        assert not os.path.exists(c.run_dir)

    # Print options.
    dist.print0()
    dist.print0('Training options:')
    dist.print0(json.dumps(c, indent=2))
    dist.print0()
    dist.print0(f'Output directory:        {c.run_dir}')
    dist.print0(f'Dataset path:            {c.dataset_kwargs.path}')
    dist.print0(f'Class-conditional:       {c.dataset_kwargs.use_labels}')
    dist.print0(f'Network architecture:    {opts.arch}')
    dist.print0(f'Preconditioning & loss:  {opts.precond}')
    dist.print0(f'Number of GPUs:          {dist.get_world_size()}')
    dist.print0(f'Batch size:              {c.batch_size}')
    dist.print0(f'Mixed-precision:         {c.network_kwargs.use_fp16}')
    dist.print0(f'Debug mode:              {c.debug}')
    dist.print0()

    # Dry run?
    if opts.dry_run:
        dist.print0('Dry run; exiting.')
        return

    # Create output directory and persist the run configuration (rank 0 only).
    dist.print0('Creating output directory...')
    if dist.get_rank() == 0:
        os.makedirs(c.run_dir, exist_ok=True)
        with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
            json.dump(c, f, indent=2)

        # Written with torch.save(), despite the .npy extension.
        normalizer_path = os.path.join(c.run_dir, "normalizers.npy")
        torch.save(normalizer_data, normalizer_path)
        print(f"Normalization parameters saved to {normalizer_path}")

        dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Train.
    training_loop.training_loop(**c)

    if dist.get_rank() == 0:
        wandb.finish()

#----------------------------------------------------------------------------

# Script entry point: main is a click command, so calling it without arguments
# makes click parse the process command line.
if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
