# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Main training loop."""

import os
import time
import copy
import json
import pickle
import psutil
import wandb
import random
import numpy as np
import torch
import dnnlib
from training.datasets import *
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import misc
from einops import rearrange
from evaluate_checkpoints import evaluate_checkpoint
from plot_loss_curves import plot_loss_curve
from training.utils import analyze_uno, analyze_spectralconv2d
from training.pde_res_viz_utils import PDEResidualTracker   
import subprocess
#----------------------------------------------------------------------------

def training_loop(
    run_dir             = '.',      # Output directory.
    dataset_kwargs      = {},       # Options for training set.
    data_loader_kwargs  = {},       # Options for torch.utils.data.DataLoader.
    network_kwargs      = {},       # Options for model and preconditioning.
    loss_kwargs         = {},       # Options for loss function.
    sampler_kwargs      = {},       # Options for noise sampler.
    optimizer_kwargs    = {},       # Options for optimizer.
    augment_kwargs      = None,     # Options for augmentation pipeline, None = disable.
    seed                = 0,        # Global random seed.
    batch_size          = 512,      # Total batch size for one training iteration.
    batch_gpu           = None,     # Limit batch size per GPU, None = no limit.
    total_kimg          = 200000,   # Training duration, measured in thousands of training images.
    ema_halflife_kimg   = 500,      # Half-life of the exponential moving average (EMA) of model weights.
    ema_rampup_ratio    = 0.05,     # EMA ramp-up coefficient, None = no rampup.
    lr_rampup_kimg      = 10000,    # Learning rate ramp-up duration.
    loss_scaling        = 1,        # Loss scaling factor for reducing FP16 under/overflows.
    kimg_per_tick       = 50,       # Interval of progress prints.
    snapshot_ticks      = 50,       # How often to save network snapshots, None = disable. was 50
    state_dump_ticks    = 500,      # How often to dump training state, None = disable. was 500
    resume_pkl          = None,     # Start from the given network snapshot, None = random initialization.
    resume_state_dump   = None,     # Start from the given training state, None = reset training state.
    resume_kimg         = 0,        # Start from the given training progress.
    cudnn_benchmark     = True,     # Enable torch.backends.cudnn.benchmark?
    device              = torch.device('cuda'),
    use_fast_math       = True,     # Enable torch.backends.cudnn.allow_tf32 and torch.backends.cuda.matmul.allow_tf32?
    validate_mode       = False,
    validate_data       = None,
    debug               = False,
    pde_plot_ticks      = 10,       # How often to plot PDE residual, None = disable.
):
    # Initialize.
    start_time = time.time()
    
    # Set deterministic seeds correctly
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    
    # os.environ['NCCL_TIMEOUT'] = '1800' 
    os.environ['NCCL_DEBUG'] = 'DETAIL'
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'
    # For multi-GPU training, ensure different processes have different but deterministic seeds
    if dist.get_world_size() > 1:
        process_seed = seed * (dist.get_rank() + 1)
        torch.manual_seed(process_seed)
        np.random.seed(process_seed)
        random.seed(process_seed)
        print(f"Process {dist.get_rank()} using seed: {process_seed}")
    
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.allow_tf32 = use_fast_math
    torch.backends.cuda.matmul.allow_tf32 = use_fast_math
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = use_fast_math


    # Select batch size per GPU.
    batch_gpu_total = batch_size // dist.get_world_size()
    if batch_gpu is None or batch_gpu > batch_gpu_total:
        batch_gpu = batch_gpu_total
    num_accumulation_rounds = batch_gpu_total // batch_gpu
    assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()

    # Load dataset.
    dist.print0('Loading dataset...')
    dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
    dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed, shuffle=False)
    dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
    # dataset_obj.update_normalizer_path(run_dir)
    # dataset_length = len(dataset_obj)
    # # Randomly sample 15 indices from the dataset
    # random_indices = random.sample(range(dataset_length), 15)
    if hasattr(dataset_obj, 'set_training'):
        dataset_obj.set_training(True)
        print("Set dataset to training mode")
        print("Dataset obj training mode:", dataset_obj.training)
    if hasattr(dataset_obj, 'set_process_seed'):
        process_seed = seed * (dist.get_rank() + 1) if dist.get_world_size() > 1 else seed
        dataset_obj.set_process_seed(process_seed, rank=dist.get_rank())
    # breakpoint()

    # Load validation dataset once, if validation mode is enabled
    if validate_mode and dist.get_rank() == 0:
        dist.print0('Loading validation dataset...')
        val_dataset_kwargs = dataset_kwargs.copy()
        val_dataset_kwargs['path'] = validate_data
        dataset_obj_val = dnnlib.util.construct_class_by_name(**val_dataset_kwargs)
        dataset_iterator_val = iter(
            torch.utils.data.DataLoader(
                dataset=dataset_obj_val, 
                batch_size=batch_gpu,  
                **data_loader_kwargs
            )
        )
        dist.print0('Validation dataset loaded successfully!')
    else:
        dataset_obj_val = None
        dataset_iterator_val = None


    # Construct network.
    dist.print0('Constructing network...')
    interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
    print("Interface kwargs set:", interface_kwargs)
    # breakpoint()
    net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
    print("Network class constructed.")
    net.train().requires_grad_(True).to(device)
    print("done construction network")

    dist.print0("Number of params: {}".format(misc.count_parameters(net)))
    torch.cuda.empty_cache()
    # Print network statistics and run model diagnostics. Gated behind `debug`
    # (earlier note: the module summary was disabled because it ran into an error).
    if debug:
        if dist.get_rank() == 0:
            with torch.no_grad():
                images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
                sigma = torch.ones([batch_gpu], device=device)
                #modify labels to not be classes but proper images
                labels = torch.zeros([batch_gpu, net.label_dim, net.img_resolution, net.img_resolution], device=device)
                # print(images.shape)
                # print(labels.shape)
                misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)
            torch.cuda.empty_cache()
            # Check for parameters/buffers with complex dtypes, which cause DDP errors.
            for name, param in net.named_parameters():
                if param.dtype.is_complex:
                    print(f"Complex dtype found in parameter: {name} - dtype: {param.dtype}")
                # else:
                #     print(f"Parameter: {name} - dtype: {param.dtype}")

            for name, buffer in net.named_buffers():
                if buffer.dtype.is_complex:
                    print(f"Complex dtype found in buffer: {name} - dtype: {buffer.dtype}")
                # else:
                #     print(f"Buffer: {name} - dtype: {buffer.dtype}")

        # Analyze modes and parameters
        analyze_uno(net.model)
        analyze_spectralconv2d(net.model)   

    # breakpoint()
    dist.print0('Setting up loss function...')
    # Construct loss function.
    sampler_kwargs.n_in = dataset_obj.num_channels
    sampler_kwargs.Ln1 = dataset_obj.resolution
    sampler_kwargs.Ln2 = dataset_obj.resolution
    sampler_kwargs.device = device
    loss_kwargs.sampler = dnnlib.util.construct_class_by_name(**sampler_kwargs)
    if dataset_obj.pde_loss_function is not None:
        loss_kwargs['pde_loss_fn'] = dataset_obj.pde_loss_function
    # Pass dataset object to loss function for denormalization
    loss_kwargs['dataset_obj'] = dataset_obj
    loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss

    # Setup PDE residual tracker if using residual loss
    pde_res_tracker = None
    if hasattr(loss_fn, 'loss_fn') and hasattr(loss_fn.loss_fn, 'set_pde_res_tracker'):
        pde_res_tracker = PDEResidualTracker(
            max_samples=10, #num samples per batch
            log_frequency=10, #in kimg?
            detailed_log_frequency=10,  # Sigma vs residual plots every 100 steps
        )
        loss_fn.loss_fn.set_pde_res_tracker(pde_res_tracker)
        dist.print0('PDE residual tracker initialized')

    # Setup optimizer.
    dist.print0('Setting up optimizer...')
    optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
    augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
    ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[dist.get_rank()], broadcast_buffers=False, find_unused_parameters=True)
    ema = copy.deepcopy(net).eval().requires_grad_(False)

    # Resume training from previous snapshot.
    if resume_pkl is not None:
        dist.print0(f'Loading network weights from "{resume_pkl}"...')
        if dist.get_rank() != 0:
            torch.distributed.barrier() # rank 0 goes first
        with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
            data = pickle.load(f)
        if dist.get_rank() == 0:
            torch.distributed.barrier() # other ranks follow
        misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
        misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
        del data # conserve memory
    if resume_state_dump:
        dist.print0(f'Loading training state from "{resume_state_dump}"...')
        data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
        misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
        optimizer.load_state_dict(data['optimizer_state'])
        del data # conserve memory

    # Train.
    dist.print0(f'Training for {total_kimg} kimg...')
    dist.print0()
    cur_nimg = resume_kimg * 1000
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    dist.update_progress(cur_nimg // 1000, total_kimg)
    stats_jsonl = None
    while True:
        # Update current step in loss function
        if hasattr(loss_fn, 'loss_fn') and hasattr(loss_fn.loss_fn, 'set_current_step'):
            loss_fn.loss_fn.set_current_step(cur_nimg // 1000)
        
        # if dataset_obj.current_kimg != cur_nimg // 1000:
        #     print("current kimg:", cur_nimg // 1000)
        #     print("dataset current kimg:", dataset_obj.current_kimg)
            # breakpoint()    
        # Accumulate gradients.
        optimizer.zero_grad(set_to_none=True)
        for round_idx in range(num_accumulation_rounds):
            with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
                batch_data = next(dataset_iterator)

                # Handle both masked and unmasked data
                if len(batch_data) == 3:  # images, labels, masks
                    images, labels, masks = batch_data
                    # original_labels = labels.clone()
                    # masks = masks.to(device).to(torch.float32)
                else:  # images, labels (no masks)
                    images, labels = batch_data
                    masks = None
                images = images.to(device).to(torch.float32)
                labels = labels.to(device).to(torch.float32) if labels is not None else None
                masks = masks.to(device).to(torch.float32) if masks is not None else None
                # print("Dataset Training mode:", dataset_obj.training_mode)
                # print("images shape: ", images.shape)
                # print("labels shape: ", labels.shape)
                # print("Masks shape: ", masks.shape if masks is not None else "N/A")
                # print("Images min/max:", images.min().item(), images.max().item(), images.mean().item())
                # if images.shape[1] == 2:
                #     print("Images channel 0 min/max:", images[:,0].min().item(), images[:,0].max().item(), images[:,0].mean().item())
                #     print("Images channel 1 min/max:", images[:,1].min().item(), images[:,1].max().item(), images[:,1].mean().item())
                # print("Labels min/max:", labels.min().item(), labels.max().item(), labels.mean().item())
                # print("Masks min/max:", masks.min().item() if masks is not None else "N/A", masks.max().item() if masks is not None else "N/A", masks.mean().item() if masks is not None else "N/A")
                # if dataset_obj.training_mode == 'conditional':
                #     images_denorm = dataset_obj.denorm_output(images)
                #     labels_denorm = dataset_obj.denorm_input(labels)
                #     print("Images denorm min/max:", images_denorm.min().item(), images_denorm.max().item(), images_denorm.mean().item())
                #     print("Labels denorm min/max:", labels_denorm.min().item(), labels_denorm.max().item(), labels_denorm.mean().item())
                # if dataset_obj.training_mode == 'unified':
                #     images_denorm = dataset_obj.denorm_tensor(images)
                #     print("Images denorm min/max:", images_denorm.min().item(), images_denorm.max().item(), images_denorm.mean().item())
                #     print("Images channel 0 denorm min/max:", images_denorm[:,0].min().item(), images_denorm[:,0].max().item(), images_denorm[:,0].mean().item())
                #     print("Images channel 1 denorm min/max:", images_denorm[:,1].min().item(), images_denorm[:,1].max().item(), images_denorm[:,1].mean().item())
                # #
                # breakpoint()
                # debug_masking = False
                # if dataset_obj.current_kimg != cur_nimg // 1000:
                #     debug_masking = True
                #     print("Labels min/max:", labels.min().item(), labels.max().item(), labels.mean().item())
                if dataset_obj.training_mode == 'conditional' and dataset_obj.use_sparse_conditioning:
                    labels, masks = dataset_obj.apply_masking(labels, cur_nimg // 1000)
                    labels, masks = labels.to(device), masks.to(device)
                # if debug_masking: 
                #     print("Applied masking to labels")
                #     print("Labels shape after masking (if applied): ", labels.shape)
                #     print("Masks shape (if applicable): ", masks.shape if masks is not None else "N/A")
                #     print("Labels min/max:", labels.min().item(), labels.max().item(), labels.mean().item())
                # print("mask shape (if applicable): ", masks.shape if masks is not None else "N/A")
                # print("Masks min/max:", masks.min().item() if masks is not None else "N/A", masks.max().item() if masks is not None else "N/A", masks.mean().item() if masks is not None else "N/A")
                # print("Mask channel 0 min/max:", masks[:,0].min().item() if masks is not None else "N/A", masks[:,0].max().item() if masks is not None else "N/A", masks[:,0].mean().item() if masks is not None else "N/A")
                # print("Mask channel 1 min/max:", masks[:,1].min().item() if masks is not None else "N/A", masks[:,1].max().item() if masks is not None else "N/A", masks[:,1].mean().item() if masks is not None else "N/A")
                # Keep the curriculum state in sync across ranks: after apply_masking(),
                # broadcast rank 0's values and overwrite the local ones on every rank.
                if dataset_obj.training_mode == 'conditional' and dataset_obj.use_sparse_conditioning and dist.get_world_size() > 1:
                    # kimg = cur_nimg // 1000
                    # Create tensor with current curriculum values from rank 0
                    # if dist.get_rank() == 0:
                        # This updates the dataset_obj's internal state on rank 0
                        # dataset_obj.update_curriculum(kimg)
                    curriculum_tensor = torch.tensor([
                            dataset_obj.current_kimg,
                            dataset_obj.current_obs_rate,
                            dataset_obj.current_sample_rate
                        ], device=device)
                    # Broadcast curriculum values from rank 0 to all processes
                    torch.distributed.broadcast(curriculum_tensor, src=0)
            
                    # Update local values on all ranks
                    dataset_obj.current_kimg = curriculum_tensor[0].item()
                    dataset_obj.current_obs_rate = curriculum_tensor[1].item()
                    dataset_obj.current_sample_rate = curriculum_tensor[2].item()
                if dataset_obj.training_mode == 'conditional':  #unified model already has channel dimensions
                    images = rearrange(images, 'bs h w -> bs 1 h w')
                    if labels is not None: #labels is dummy in unified at this stage
                        labels = rearrange(labels, 'bs h w -> bs 1 h w')
                    if masks is not None:
                            masks = rearrange(masks, 'b h w -> b 1 h w')
                
                # original_labels = rearrange(original_labels, 'bs h w -> bs 1 h w')
                # print("images shape: ", images.shape)
                # print("labels shape: ", labels.shape)
                # breakpoint()
                # images = images.to(device).to(torch.float32) # / 127.5 - 1
                # labels = labels.to(device).to(torch.float32) 
                # masks  = masks.to(device).to(torch.float32) if masks is not None else None
                # original_labels = original_labels.to(device).to(torch.float32) 
                # print("going into loss")
                # print("Shape of images going into loss:", images.shape)
                # print("Shape of labels going into loss:", labels.shape)
                loss = loss_fn(net=ddp, images=images, labels=labels, masks=masks, augment_pipe=augment_pipe)
                training_stats.report('Loss/loss', loss)
                # print("back from loss, going to backward")
                torch.cuda.synchronize() 
                # print("loss backward staring")
                loss.sum().mul(loss_scaling / batch_gpu_total).backward()
                # print("back from loss")
                if dataset_obj.training_mode == 'conditional' and dataset_obj.use_sparse_conditioning and masks is not None:
                    # Calculate the mean of each mask in the batch along spatial dimensions
                    mask_means_per_sample = masks.mean(dim=[1, 2, 3]) # Shape: [batch_gpu]
                    
                    # Identify which samples were actually masked (mean < 1.0)
                    is_masked_sample = mask_means_per_sample < 1.0
                    num_masked_samples = is_masked_sample.sum().item()

                    actual_sample_rate = num_masked_samples / masks.shape[0]
                    training_stats.report('Mask/actual_masked_sample_rate', actual_sample_rate)
                    # print("Actual sample rate in batch:", actual_sample_rate)

                    if num_masked_samples > 0:
                        obs_rate_on_masked = mask_means_per_sample[is_masked_sample].mean().item()
                        training_stats.report('Mask/actual_obs_rate_on_masked', obs_rate_on_masked)
                        # print("Average observation rate on masked samples:", obs_rate_on_masked)
                # Log unified task distribution.
                if dataset_obj.training_mode == 'unified' and masks is not None:
                    m_a = masks[:, 0]  # Batch of m_a masks
                    m_u = masks[:, 1]  # Batch of m_u masks
                    
                    # Calculate sums for each sample in the batch
                    m_a_sum = m_a.view(m_a.shape[0], -1).sum(dim=1)
                    m_u_sum = m_u.view(m_u.shape[0], -1).sum(dim=1)
                    
                    # Total number of pixels in a single mask
                    num_pixels = m_a.shape[1] * m_a.shape[2]
                    
                    # Identify task type for each sample
                    is_full_fwd = (m_a_sum == num_pixels) & (m_u_sum == 0)
                    is_full_inv = (m_a_sum == 0) & (m_u_sum == num_pixels)
                    is_uncond = (m_a_sum == 0) & (m_u_sum == 0)
                    is_sparse_fwd = (m_a_sum > 0) & (m_a_sum < num_pixels) & (m_u_sum == 0)
                    is_sparse_inv = (m_a_sum == 0) & (m_u_sum > 0) & (m_u_sum < num_pixels)
                    
                    # Calculate percentage of each task in the batch
                    batch_size_float = masks.shape[0]
                    training_stats.report('UnifiedTasks/full_fwd_pct', is_full_fwd.sum() / batch_size_float)
                    training_stats.report('UnifiedTasks/full_inv_pct', is_full_inv.sum() / batch_size_float)
                    training_stats.report('UnifiedTasks/uncond_pct', is_uncond.sum() / batch_size_float)
                    training_stats.report('UnifiedTasks/sparse_fwd_pct', is_sparse_fwd.sum() / batch_size_float)
                    training_stats.report('UnifiedTasks/sparse_inv_pct', is_sparse_inv.sum() / batch_size_float)

                    # Calculate and log average observation rate for sparse tasks
                    if is_sparse_fwd.any():
                        sparse_fwd_obs_rate = m_a[is_sparse_fwd].mean()
                        training_stats.report('UnifiedTasks/sparse_fwd_obs_rate', sparse_fwd_obs_rate)
                    
                    if is_sparse_inv.any():
                        sparse_inv_obs_rate = m_u[is_sparse_inv].mean()
                        training_stats.report('UnifiedTasks/sparse_inv_obs_rate', sparse_inv_obs_rate)
                # if masks is not None:
                #     mask_ratio = masks.mean().item()
                #     training_stats.report('Mask/batch_observation_ratio', mask_ratio)
                # print("Mask mean:", masks.mean().item() if masks is not None else "N/A")
                # print("Mask std:", masks.std().item() if masks is not None else "N/A")
                # diff = (original_labels != labels).float().sum() / original_labels.numel()
                # training_stats.report('Mask/changed_values_ratio', diff.item())
                # if diff.item() < 0.001:  # If almost no values changed
                #     print("WARNING: Masking doesn't seem to be working - almost no values changed!")
                #     print(f"Original values: min={original_labels.min().item():.4f}, max={original_labels.max().item():.4f}, mean={original_labels.mean().item():.4f}")
                #     print(f"Masked values:   min={labels.min().item():.4f}, max={labels.max().item():.4f}, mean={labels.mean().item():.4f}")
                #     print(f"Mask ratio: {mask_ratio:.4f}, Actual sample rate: {actual_sample_rate:.4f}")
            
                # if num_masked_samples > 0:
                #     # Get index of first masked sample
                #     first_masked_idx = torch.where(is_masked_sample)[0][0].item()
                #     sample_mask = masks[first_masked_idx][0]
                #     masked_pts = (sample_mask < 0.99).nonzero()
                #     if len(masked_pts) > 0:
                #         pt = tuple(masked_pts[0].tolist())
                #         print(f"Example masked point at {pt}:")
                #         # Access with correct channel dimension (0 for the first channel)
                #         print(f"  Original value: {original_labels[first_masked_idx, 0, pt[0], pt[1]].item():.4f}")
                #         print(f"  Masked value:   {labels[first_masked_idx, 0, pt[0], pt[1]].item():.4f}")
                #         # Sample_mask already has the correct dimensions after rearrange
                #         print(f"  Mask value:     {sample_mask[pt[0], pt[1]].item():.4f}")
        # Update weights.
        for g in optimizer.param_groups:
            g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
        # for param in net.parameters():
        #     if param.grad is not None:
        #         torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
        # print("stepping optim")
        torch.cuda.synchronize()
        torch.distributed.barrier()
        optimizer.step()
        # breakpoint()
        # Update EMA.
        ema_halflife_nimg = ema_halflife_kimg * 1000
        if ema_rampup_ratio is not None:
            ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
        ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
        for p_ema, p_net in zip(ema.parameters(), net.parameters()):
            p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
        # print("ema done")

        # Perform maintenance tasks once per tick.
        cur_nimg += batch_size
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        dist.print0(' '.join(fields))
        # print("done printing")

        # Check for abort.
        if (not done) and dist.should_stop():
            done = True
            dist.print0()
            dist.print0('Aborting...')
        # #Vizualize results
        # if cur_tick % vizualise_ticks == 0:
        #     cond_norm_gt = dataset_obj.input_data[random_indices]
        #     images_norm_gt = dataset_obj.output_data[random_indices]


        # Save network snapshot.
        if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
            data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
            for key, value in data.items():
                if isinstance(value, torch.nn.Module):
                    value = copy.deepcopy(value).eval().requires_grad_(False)
                    misc.check_ddp_consistency(value)
                    data[key] = value.cpu()
                del value # conserve memory
            if dist.get_rank() == 0:
                with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
                    pickle.dump(data, f)
            del data # conserve memory
            if validate_mode and dist.get_rank() == 0: # Only one process should evaluate
                model_kimg = cur_nimg // 1000
                checkpoint_path = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
                outdir = os.path.join(run_dir, "model_validation")
                os.makedirs(outdir, exist_ok=True)
                # Evaluate model after saving
                print(f"Evaluating model at {model_kimg} kimg...")
                # dataset_kwargs['path'] = validate_data 
                # dataset_obj_val = dnnlib.util.construct_class_by_name(**dataset_kwargs) 
                # dataset_iterator_val = iter(
                #     torch.utils.data.DataLoader(
                #         dataset=dataset_obj_val, 
                #         batch_size=batch_gpu,  
                #         **data_loader_kwargs
                #     )
                # )
                # print("Loaded validation data")
                # Call evaluate_checkpoint
                # evaluate_checkpoint(
                #     net=ema, 
                #     device=device, 
                #     batch_size=batch_gpu, 
                #     dataset_obj=dataset_obj_val, 
                #     dataset_iterator=dataset_iterator_val, 
                #     outdir=outdir, 
                #     model_kimg=model_kimg, 
                #     sampler_kwargs={},  
                #     plot=True, 
                #     output_mat=False, 
                #     viz_indices=None,
                #     debug=debug,
                # )
                eval_command = [
                "python", "evaluate_checkpoints.py",  
                "--checkpoint_dir", run_dir,
                "--outdir", outdir,
                "--data", validate_data,  
                "--batch", str(batch_gpu),
                "--kimg_intervals", str(model_kimg)
                ]

                # Run evaluation in a separate process (Non-blocking)
                subprocess.Popen(eval_command)
            # print("saving snapshots done")
            torch.distributed.barrier()
            # print("Subprocess launched! Resuming training...")
            # print("barrier out")


        # Save full dump of the training state.
        if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
            torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
            
        # Log PDE residual plots periodically
        if pde_res_tracker is not None and dist.get_rank() == 0:
            # Main plots every pde_plot_ticks ticks (not snapshot_ticks).
            if cur_tick % pde_plot_ticks == 0:
                pde_res_tracker.log_to_wandb(cur_nimg // 1000)
                pde_res_tracker.create_global_scatter_plot(save_dir=run_dir)
                pde_res_tracker.create_comprehensive_analysis(save_dir=run_dir)
                
                # Log detailed statistics
                stats = pde_res_tracker.get_statistics()
                if stats:
                    wandb_stats = {f"PDE_detailed/{k}": v for k, v in stats.items()}
                    wandb.log(wandb_stats, step=cur_nimg // 1000)
                # Clear history periodically to prevent memory growth
                pde_res_tracker.clear_recent_history()
                # print("PDE residual plots logged.")
        # Update logs.
        training_stats.default_collector.update()
        if dist.get_rank() == 0:
            if stats_jsonl is None:
                stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at')
            log_dict = training_stats.default_collector.as_dict()
            stats_jsonl.write(json.dumps(dict(log_dict, timestamp=time.time())) + '\n')
            stats_jsonl.flush()
            # wandb.log(
            #     {
            #         "loss": log_dict["Loss/loss"]["mean"],
            #         "lr": optimizer.param_groups[0]["lr"],
            #         "tick": cur_tick,
            #         "kimg": cur_nimg // 1000,
            #     },
            #     step=cur_nimg,
            # )
            wandb_dict = {
                "loss": log_dict["Loss/loss"]["mean"],
                "lr": optimizer.param_groups[0]["lr"],
                "tick": cur_tick,
                "kimg": cur_nimg // 1000,
            }
            # Add PDE metrics if available
            pde_keys = ['PDE/residual_mean', 'PDE/residual_max', 'PDE/residual_sum', 'PDE/sigma_mean',
                    'PDE/residual_low_sigma', 'PDE/residual_med_sigma', 'PDE/residual_high_sigma']
            for key in pde_keys:
                if key in log_dict:
                    wandb_dict[key] = log_dict[key]["mean"]

            # Add mask metrics if available
            if 'Mask/observation_ratio' in log_dict:
                wandb_dict["mask_curriculum/observation_ratio"] = log_dict["Mask/observation_ratio"]["mean"]
            # mask_curriculum_keys = [
            # 'Mask/observation_ratio',     # Actual observed ratio in batch
            # 'Mask-Curriculum/obs_rate',        # Current sparsity curriculum rate
            # 'Mask-Curriculum/sample_rate',]     # Current sample masking curriculum rate

            # for key in mask_curriculum_keys:
            #     if key in log_dict:
            #         # Convert key to wandb-friendly format
            #         wandb_key = key.lower().replace('/', '_')
            #         wandb_dict[wandb_key] = log_dict[key]["mean"]
            # Add additional curriculum state directly from dataset
            if dataset_obj.training_mode == 'conditional' and dataset_obj.use_sparse_conditioning:
                if hasattr(dataset_obj, 'current_obs_rate'):
                    wandb_dict["mask_curriculum/curriculum_current_obs_rate"] = dataset_obj.current_obs_rate
                if hasattr(dataset_obj, 'current_sample_rate'):
                    wandb_dict["mask_curriculum/curriculum_current_sample_rate"] = dataset_obj.current_sample_rate
                if hasattr(dataset_obj, 'current_kimg'):
                    wandb_dict["mask_curriculum/curriculum_current_kimg"] = dataset_obj.current_kimg
                if 'Mask/observation_ratio' in log_dict:
                    wandb_dict["mask_curriculum/observation_ratio"] = log_dict["Mask/observation_ratio"]["mean"]
                if 'Mask/actual_masked_sample_rate' in log_dict:
                    wandb_dict["mask_curriculum/actual_masked_sample_rate"] = log_dict["Mask/actual_masked_sample_rate"]["mean"]
                if 'Mask/actual_obs_rate_on_masked' in log_dict:
                    wandb_dict["mask_curriculum/actual_obs_rate_on_masked"] = log_dict["Mask/actual_obs_rate_on_masked"]["mean"]
                # Add curriculum progress percentages for easier visualization
                if hasattr(dataset_obj, 'sparsity_curriculum_kimg') and dataset_obj.sparsity_curriculum_kimg > 0:
                    sparsity_progress = min(1.0, (cur_nimg // 1000) / dataset_obj.sparsity_curriculum_kimg)
                    wandb_dict["mask_curriculum/curriculum_sparsity_progress"] = sparsity_progress

                if hasattr(dataset_obj, 'sample_curriculum_kimg') and dataset_obj.sample_curriculum_kimg > 0:
                    sample_progress = min(1.0, (cur_nimg // 1000) / dataset_obj.sample_curriculum_kimg)
                    wandb_dict["mask_curriculum/curriculum_sample_progress"] = sample_progress

            # Add unified task metrics if available
            unified_task_keys = [
                'UnifiedTasks/full_fwd_pct', 'UnifiedTasks/full_inv_pct', 'UnifiedTasks/uncond_pct',
                'UnifiedTasks/sparse_fwd_pct', 'UnifiedTasks/sparse_inv_pct',
                'UnifiedTasks/sparse_fwd_obs_rate', 'UnifiedTasks/sparse_inv_obs_rate'
            ]
            for key in unified_task_keys:
                if key in log_dict:
                    wandb_dict[key] = log_dict[key]["mean"]

            wandb.log(wandb_dict, step=cur_nimg// 1000)
        # print("wandb logged")
        dist.update_progress(cur_nimg // 1000, total_kimg)
        # print("dist progress updated")
        # print("will plot loss curve")
        if dist.get_rank() == 0:  # Ensure only one process calls this
            # print("plotting loss curve")
            plot_loss_curve(os.path.join(run_dir, "stats.jsonl"), save_path=run_dir)
            # print("plotted loss curve")
        
        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        # update permissions #TODO remove when experimentation is over
        # misc.chmod_recursive(run_dir)
        if done:
            break

    # Done.
    dist.print0()
    dist.print0('Exiting...')

#----------------------------------------------------------------------------

