import os
import csv
import time
import copy
import json
import pickle
import random
import dnnlib
import numpy as np
import torch
from torch import autocast
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import misc
from torch_utils.download_util import check_file_by_key
from models.Plugin import MLP
from models.Plugin_Conv import Plugin_UNetBlock
from models.Traj import CustomVectorModel, load_model
from tqdm import tqdm
import copy
from models.lora_custom_layers import add_lora_to_custom_model
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

def plot_loss_curves(epoch_loss_means, epoch_loss_stds, 
                     test_epoch_loss_means, test_epoch_loss_stds,
                     save_path, title="Loss Curves with Std Deviation"):
    """Plot per-epoch train/test loss curves with a +/- std band and save the figure.

    Args:
        epoch_loss_means: Sequence of mean training losses, one entry per epoch.
        epoch_loss_stds: Sequence of training loss spreads, same length as the means.
        test_epoch_loss_means: Sequence of mean test losses, plotted on the same
            epoch axis as the training curve.
        test_epoch_loss_stds: Sequence of test loss spreads.
        save_path: Destination passed to ``plt.savefig``.
        title: Figure title.
    """
    plt.rcParams["axes.unicode_minus"] = False
    # The x axis is 1-based and sized from the training series.
    epoch_axis = range(1, len(epoch_loss_means) + 1)
    plt.figure(figsize=(10, 6))

    # (curve label, band label, colour, means, stds) -- train first, then test.
    series = (
        ('train loss', 'train loss (std)', 'blue', epoch_loss_means, epoch_loss_stds),
        ('test loss', 'test loss (std)', 'red', test_epoch_loss_means, test_epoch_loss_stds),
    )
    for curve_label, band_label, colour, means, stds in series:
        plt.plot(epoch_axis, means, label=curve_label, color=colour, linewidth=2)
        lower = np.array(means) - np.array(stds)
        upper = np.array(means) + np.array(stds)
        plt.fill_between(epoch_axis, lower, upper,
                         color=colour, alpha=0.2, label=band_label)

    plt.title(title, fontsize=14)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    # Epoch indices are integers, so keep the tick labels integral.
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=10)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close the current figure so repeated calls do not accumulate open figures.
    plt.close()

def load_and_move_to_cpu(file_path):
    """Unpickle ``file_path`` and return its content with tensors moved to the CPU.

    Lists and tuples are traversed recursively (and rebuilt as plain ``list`` /
    ``tuple``); every ``torch.Tensor`` found is replaced by ``tensor.cpu()``.
    Any other object, including dicts, is returned untouched.

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    def _to_cpu(obj):
        if isinstance(obj, torch.Tensor):
            return obj.cpu()
        if isinstance(obj, (list, tuple)):
            moved = map(_to_cpu, obj)
            return list(moved) if isinstance(obj, list) else tuple(moved)
        return obj

    with open(file_path, 'rb') as handle:
        payload = pickle.load(handle)
    return _to_cpu(payload)

def load_ldm_model(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: Object exposing a ``model`` entry understood by
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint whose name ends with ``"ckpt"``; it must
            contain a ``"state_dict"`` entry.
        verbose: When true, print the missing / unexpected state-dict keys.

    Returns:
        The instantiated model with the state dict loaded non-strictly.

    Raises:
        NotImplementedError: If ``ckpt`` does not end with ``"ckpt"``.
    """
    from models.ldm.util import instantiate_from_config

    if not ckpt.endswith("ckpt"):
        raise NotImplementedError

    # weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
    if "global_step" in checkpoint:
        dist.print0(f"Global Step: {checkpoint['global_step']}")
    state = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if verbose:
        for heading, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
            if len(keys) > 0:
                print(heading)
                print(keys)
    return model

def check_model_dtype(model, device=None):
    """Print a parameter report for ``model``.

    Prints the device of the first parameter, the total number of parameter
    elements, and one row per named parameter with its dtype, whether that
    dtype is ``torch.float16``, and its shape.

    Args:
        model: Module to inspect; it must have at least one parameter.
        device: If given, the model is moved there before inspection.
    """
    if device is not None:
        model = model.to(device)

    first_param = next(model.parameters())
    print(f"===  {first_param.device} ===")
    n_elements = sum(p.numel() for p in model.parameters())
    print(f"=== : {n_elements} ===")

    rule = "="*80
    print("\n" + rule)
    for param_name, tensor in model.named_parameters():
        param_dtype = tensor.dtype
        half_flag = param_dtype == torch.float16
        print(f"name: {param_name:<40} | dtype: {str(param_dtype):<15} | is_half: {half_flag!s:<5} | shape: {tensor.shape}")
    print(rule + "\n")
#----------------------------------------------------------------------------

def create_model(dataset_name=None, model_path=None, guidance_type=None, guidance_rate=None, device=None, is_second_stage=False):
    """Load the diffusion network for ``dataset_name`` and report its source family.

    Args:
        dataset_name: Selects the architecture / loader branch.
        model_path: Checkpoint location; when ``None`` (first stage only) it is
            resolved through ``check_file_by_key(dataset_name)``.
        guidance_type: Must be ``'uncond'`` for the LDM bedroom/FFHQ models and
            ``'cfg'`` for ``'ms_coco'`` (checked with ``assert``).
        guidance_rate: Forwarded to ``CFGPrecond`` for ``'ms_coco'``.
        device: Device the network is moved to.
        is_second_stage: Load a previously pickled ``'model'`` entry from
            ``model_path`` instead of a pre-trained checkpoint.

    Returns:
        ``(net, model_source)`` where ``model_source`` is ``'edm'``, ``'cm'`` or
        ``'ldm'``.

    Raises:
        ValueError: If ``dataset_name`` matches no supported branch.
    """
    edm_datasets = ['cifar10', 'ffhq', 'afhqv2', 'imagenet64']

    # Second-stage distillation: the teacher is a pickled snapshot from stage one.
    if is_second_stage:
        assert model_path is not None
        dist.print0(f'Loading the second-stage teacher model from "{model_path}"...')
        with dnnlib.util.open_url(model_path, verbose=(dist.get_rank() == 0)) as f:
            net = pickle.load(f)['model'].to(device)
        return net, ('edm' if dataset_name in edm_datasets else 'ldm')

    if model_path is None:
        model_path, _ = check_file_by_key(dataset_name)
    dist.print0(f'Loading the pre-trained diffusion model from "{model_path}"...')

    # Models from EDM: rebuild the network from explicit kwargs, then copy the
    # pickled EMA weights into it.
    if dataset_name in edm_datasets:
        with dnnlib.util.open_url(model_path, verbose=(dist.get_rank() == 0)) as f:
            pretrained = pickle.load(f)['ema'].to(device)
        network_kwargs = dnnlib.EasyDict()
        if dataset_name in ['cifar10']:
            network_kwargs.update(
                model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard',
                channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2],
                dropout=0.13, use_fp16=False,
            )
            network_kwargs.augment_dim = 9
            interface_kwargs = dict(img_resolution=32, img_channels=3, label_dim=0)
        elif dataset_name in ['ffhq', 'afhqv2']:
            network_kwargs.update(
                model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard',
                channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[1,2,2,2],
                dropout=0.05, use_fp16=False,
            )
            network_kwargs.augment_dim = 9
            interface_kwargs = dict(img_resolution=64, img_channels=3, label_dim=0)
        else:
            # Remaining EDM dataset: 'imagenet64'.
            network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
            interface_kwargs = dict(img_resolution=64, img_channels=3, label_dim=1000)
        network_kwargs.class_name = 'models.networks_edm.EDMPrecond'
        net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
        net.to(device)
        net.load_state_dict(pretrained.state_dict(), strict=False)
        del pretrained
        net.sigma_min = 0.006
        net.sigma_max = 80.0
        return net, 'edm'

    # Models from Consistency Models.
    if dataset_name in ['lsun_bedroom', 'lsun_cat']:
        from models.cm.cm_model_loader import load_cm_model
        from models.networks_edm import CMPrecond
        net = CMPrecond(load_cm_model(model_path)).to(device)
        return net, 'cm'

    # Models from LDM / Stable Diffusion.
    if dataset_name in ['lsun_bedroom_ldm', 'ffhq_ldm', 'ms_coco']:
        from omegaconf import OmegaConf
        from models.networks_edm import CFGPrecond
        if dataset_name in ['ms_coco']:
            assert guidance_type == 'cfg'
            config = OmegaConf.load('./models/ldm/configs/stable-diffusion/v1-inference.yaml')
            net = load_ldm_model(config, model_path)
            net = CFGPrecond(net, img_resolution=64, img_channels=4, guidance_rate=guidance_rate, guidance_type='classifier-free', label_dim=True).to(device)
            net.sigma_min = 0.1
        else:
            # The two unconditional LDMs differ only in their config file.
            assert guidance_type == 'uncond'
            if dataset_name in ['lsun_bedroom_ldm']:
                config_file = './models/ldm/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml'
            else:
                config_file = './models/ldm/configs/latent-diffusion/ffhq-ldm-vq-4.yaml'
            config = OmegaConf.load(config_file)
            net = load_ldm_model(config, model_path)
            net = CFGPrecond(net, img_resolution=64, img_channels=3, guidance_rate=1., guidance_type='uncond', label_dim=0).to(device)
            net.sigma_min = 0.006
        return net, 'ldm'

    raise ValueError(f"Unsupported dataset_name: {dataset_name}")

#----------------------------------------------------------------------------
# Check model structure
# Names of modules whose class is called 'Conv2d' or 'Linear'; filled in by
# print_network_layers() and accumulated across calls.
lora_layer = []
def print_network_layers(model):
    """Log every sub-module of ``model`` and collect LoRA candidate layer names.

    Each module's dotted path and class name are printed through
    ``dist.print0``. Modules whose class *name* is ``Conv2d`` or ``Linear`` are
    appended to the module-level ``lora_layer`` list; matching is by name, so
    project-defined layer classes with those names are included too.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
    """
    for name, module in model.named_modules():
        module_type = type(module).__name__
        # The original "\type" was parsed as a tab followed by "ype"; spell the
        # tab separator and the word "type" out explicitly.
        dist.print0(f"path: {name}\ttype: {module_type}")
        if module_type in ('Conv2d', 'Linear'):
            lora_layer.append(name)
    dist.print0(f"\nLora layers:\n{lora_layer}\n")

#----------------------------------------------------------------------------

class RandomIntGenerator:
    """Endless source of random integers drawn from the global ``random`` module.

    Constructing an instance re-seeds the *global* ``random`` state, so it also
    affects every other user of ``random`` in the process.
    """

    def __init__(self, seed=42):
        random.seed(seed)

    def randint(self, int_min, int_max):
        """Return a generator yielding ``random.randint(int_min, int_max)`` forever."""
        draw = lambda: random.randint(int_min, int_max)
        # A fresh object() never compares equal to an int, so this never stops.
        yield from iter(draw, object())

#----------------------------------------------------------------------------

def _save_module_snapshot(module, path):
    """Pickle a detached, eval-mode CPU copy of ``module`` to ``path``.

    The file holds ``{'model': copy}``. Every rank must call this function:
    the copy is passed to ``misc.check_ddp_consistency`` on all ranks
    (presumably a collective cross-rank check -- TODO confirm against
    torch_utils.misc), while only rank 0 writes the file.
    """
    snapshot = copy.deepcopy(module).eval().requires_grad_(False)
    misc.check_ddp_consistency(snapshot)
    data = dict(model=snapshot.cpu())
    del snapshot # conserve memory
    if dist.get_rank() == 0:
        with open(path, 'wb') as f:
            pickle.dump(data, f)
    del data # conserve memory


def training_loop(
    run_dir             = '.',      # Output directory.
    loss_kwargs         = {},       # Options for loss function.
    optimizer_kwargs    = {},       # Options for optimizer.
    seed                = 0,        # Global random seed.
    batch_size          = None,     # Total batch size for one training iteration.
    batch_gpu           = None,     # Limit batch size per GPU, None = no limit.
    total_kimg          = 20,       # Training duration, measured in thousands of training images.
    kimg_per_tick       = 1,        # Interval of progress prints.
    snapshot_ticks      = 1,        # How often to save network snapshots, None = disable.
    state_dump_ticks    = 99,       # How often to dump training state, None = disable.
    cudnn_benchmark     = True,     # Enable torch.backends.cudnn.benchmark?
    dataset_name        = None,
    model_path          = None,
    guidance_type       = None,
    guidance_rate       = 0.,
    device              = torch.device('cuda'),
    is_second_stage     = False,
    freeze_net          = False,
    rank                = 3,        # LoRA rank
    iter                = 160,      # Train iter nums
    scale               = 1,        # LoRA Scale
    **kwargs,
):
    """Train LoRA adapters and a plugin network against stored teacher trajectories.

    The pre-trained network is loaded with ``create_model``, LoRA layers are
    attached with ``add_lora_to_custom_model`` and a ``Plugin_UNetBlock`` is
    created. Teacher trajectories, start latents and conditions are read from
    pickles under ``tea_traj/<dataset_name>/``. Each epoch runs one pass over
    the training trajectories followed by a gradient-free pass over the test
    trajectories, and training stops early when the test loss stops improving.

    ``loss_kwargs`` must support attribute access (``num_steps``,
    ``use_step_condition``, ``sampler_tea``) and is updated in place; it is
    also stored on the network as ``net.training_kwargs``. The number of epochs
    is ``iter // total_kimg``. Requires an initialised ``torch.distributed``
    process group.

    Outputs written to ``run_dir``: TensorBoard events, ``loss.svg``,
    ``stats.jsonl``, periodic ``network-*``/``plugin-*`` snapshots and the
    ``*-snapshot-best.pkl`` pair for the best test loss.
    """
    # Initialize.
    start_time = time.time()
    np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
    torch.manual_seed(np.random.randint(1 << 31))
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    writer = SummaryWriter(log_dir=run_dir)
    epoch_loss_means = []
    epoch_loss_stds = []
    test_epoch_loss_means = []
    test_epoch_loss_stds = []

    # Early-stopping state; patience and tolerance depend on the dataset.
    best_test_loss = float('inf')
    early_stop_counter = 0
    if dataset_name == 'lsun_bedroom_ldm':
        early_stop_patience = 3
        early_stop_min_delta = 0.01
    elif dataset_name == 'ms_coco':
        early_stop_patience = 2
        early_stop_min_delta = 0.1
    else:
        early_stop_patience = 5
        early_stop_min_delta = 0.0001
    # -------------------------------------------------------------------

    # Select batch size per GPU.
    batch_gpu_total = batch_size // dist.get_world_size()
    if batch_gpu is None or batch_gpu > batch_gpu_total:
        batch_gpu = batch_gpu_total
    num_acc_rounds = batch_gpu_total // batch_gpu
    # e.g. batch_gpu_total = 2, batch_size = 4, batch_gpu = 2, num_acc_rounds = 1
    dist.print0(f'batch_gpu_total = {batch_gpu_total}, batch_size = {batch_size}, batch_gpu = {batch_gpu}, num_acc_rounds = {num_acc_rounds}')
    assert batch_size == batch_gpu * num_acc_rounds * dist.get_world_size()

    if dataset_name in ['ms_coco']:
        prompt_path, _ = check_file_by_key('prompts')
        with open(prompt_path, 'r') as file:
            sample_captions = [row['text'] for row in csv.DictReader(file)]

    # Load pre-trained diffusion model.
    if dist.get_rank() != 0:
        torch.distributed.barrier()         # rank 0 goes first

    net, model_source = create_model(dataset_name, model_path, guidance_type, guidance_rate, device, is_second_stage)
    if dataset_name in ['ms_coco']:
        net.guidance_rate = 1.0             # training with guidance_rate=1.0, sampling with specified guidance_rate
    net.use_fp16 = True                     # use half precision to accelerate training
    if freeze_net == False:                 # train the base network as well
        net.train().requires_grad_(True)
        dist.print0('net.train().requires_grad_(True)')
    else:                                   # freeze the base network
        net.eval().requires_grad_(False)
        dist.print0('net.train().requires_grad_(False)  freeze_net...')
    if dist.get_rank() == 0:
        torch.distributed.barrier()         # other ranks follow

    # Check model structure
    # print_network_layers(net)

    # Base-network parameters are collected *before* LoRA layers are attached.
    net_params = list(net.parameters())
    total_params_unet = sum(param.numel() for param in net_params)
    dist.print0("Total parameters in U-Net:     ", total_params_unet)
    alpha = rank * scale
    dist.print0(f"LoRA rank: {rank}, alpha: {alpha}")
    net = add_lora_to_custom_model(net, rank=rank, alpha=alpha, dataset_name=dataset_name)
    # LoRA parameters are identified by name.
    lora_params = [param for name, param in net.named_parameters() if "lora_A" in name or "lora_B" in name]
    total_params_lora = sum(param.numel() for param in lora_params)
    dist.print0("Total parameters in lora:     ", total_params_lora)

    # Load the trajectory-information model (kept frozen).
    customVectorModel = CustomVectorModel()
    customVectorModel = load_model(model=customVectorModel, checkpoint_path=f"./checkpoint_traj/traj_fit_model_T-1_T_dpm_{dataset_name}.pth" , device=device)
    customVectorModel.eval().requires_grad_(False)
    customVectorModel = customVectorModel.to(device)
    total_params_customVectorModel = sum(param.numel() for param in customVectorModel.parameters())
    dist.print0("Total parameters in customVectorModel:     ", total_params_customVectorModel)

    # Build the trainable plugin.
    plugin = Plugin_UNetBlock(step=loss_kwargs.num_steps - 1, use_step_condition=loss_kwargs.use_step_condition, dataset_name=dataset_name)
    plugin.train()
    plugin = plugin.to(device)
    total_params_plugin = sum(param.numel() for param in plugin.parameters())
    dist.print0("Total parameters in plugin:     ", total_params_plugin)

    # Setup optimizer.
    dist.print0('Setting up optimizer...')
    loss_kwargs.update(sigma_min=net.sigma_min, sigma_max=net.sigma_max, model_source=model_source)
    loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs)
    if freeze_net:
        all_params = list(plugin.parameters()) + lora_params
    else:
        all_params = net_params + list(plugin.parameters()) + lora_params
    optimizer = dnnlib.util.construct_class_by_name(params=all_params, **optimizer_kwargs)

    # Record args for sampling
    net.training_kwargs = loss_kwargs
    net.training_kwargs['dataset_name'] = dataset_name
    net.training_kwargs['guidance_type'] = guidance_type
    net.training_kwargs['guidance_rate'] = guidance_rate

    # The base network is only wrapped in DDP when it is being trained.
    if freeze_net:
        ddp = net
    else:
        ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False, find_unused_parameters=True)
    ddp_plugin = torch.nn.parallel.DistributedDataParallel(plugin, device_ids=[device], broadcast_buffers=False, find_unused_parameters=True)
    ddp_traj = customVectorModel

    # Train.
    dist.print0(f'Training for {total_kimg} kimg...')
    dist.print0()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    dist.update_progress(cur_nimg // 1000, total_kimg)
    stats_jsonl = None
    # Kept for its side effect: the constructor seeds the global `random`
    # module, which drives the random.sample() calls below.
    rig = RandomIntGenerator()
    num_acc_rounds = 128 // batch_size if dataset_name == 'ms_coco' else 1      # number of accumulation rounds, force 128 for stable diffusion
    batch_gpu_total = num_acc_rounds * batch_gpu
    dist.print0(f'num_acc_rounds = {num_acc_rounds} \t batch_gpu_total = {batch_gpu_total}')

    # Placeholder per-round labels / conditions; the three names share one list
    # that is only read, and each is rebound below when real values exist.
    labels = c = uc = [None for k in range(num_acc_rounds)]

    if guidance_type == 'cfg' and dataset_name in ['ms_coco']:
        with torch.no_grad():
            uc = net.model.get_learned_conditioning(batch_gpu * [""])
    loss = torch.zeros(1,)

    tea_traj_dir = os.path.join('tea_traj', dataset_name, f'{loss_kwargs.sampler_tea}-{loss_kwargs.num_steps}-{total_kimg}k')
    test_tea_traj_dir = os.path.join('tea_traj', dataset_name, f'{loss_kwargs.sampler_tea}-{loss_kwargs.num_steps}-test-1k')

    teacher_traj_list = load_and_move_to_cpu(f'{tea_traj_dir}/teacher_traj_list.pkl')
    latents_list = load_and_move_to_cpu(f'{tea_traj_dir}/latents_list.pkl')

    test_teacher_traj_list = load_and_move_to_cpu(f'{test_tea_traj_dir}/teacher_traj_list.pkl')
    test_latents_list = load_and_move_to_cpu(f'{test_tea_traj_dir}/latents_list.pkl')

    if dataset_name in ['ms_coco', 'imagenet64']:
        with open(f'{tea_traj_dir}/labels_list.pkl', 'rb') as f:
            labels_list = pickle.load(f)
        with open(f'{test_tea_traj_dir}/labels_list.pkl', 'rb') as f:
            test_labels_list = pickle.load(f)
    else:
        labels_list = [[None for k in range(num_acc_rounds)] for i in range(len(teacher_traj_list))]
        test_labels_list = [[None for k in range(num_acc_rounds)] for i in range(len(test_teacher_traj_list))]
    with open(f'{tea_traj_dir}/c_list.pkl', 'rb') as f:
        c_list = pickle.load(f)
    with open(f'{test_tea_traj_dir}/c_list.pkl', 'rb') as f:
        test_c_list = pickle.load(f)

    Epochs = iter // total_kimg

    for epoch in tqdm(range(Epochs), desc=f"Training Epochs"):
        net.train()
        plugin.train()
        step_loss_means = []
        step_loss_stds = []

        for idx in range(len(teacher_traj_list)):

            # A NaN loss from the previous iteration switches fp16 off.
            if torch.isnan(loss).any().item():
                net.use_fp16 = False
                dist.print0('Meet nan, disable fp16!')

            latents = copy.deepcopy(latents_list[idx])
            labels = copy.deepcopy(labels_list[idx])
            if net.label_dim:
                if guidance_type == 'cfg' and dataset_name in ['ms_coco']:      # For Stable Diffusion
                    # NOTE(review): `prompts` is never consumed; the sampling is
                    # kept so the global `random` stream advances as before.
                    prompts = [random.sample(sample_captions, batch_gpu) for k in range(num_acc_rounds)]
                    with torch.no_grad():
                        if isinstance(prompts[0], tuple):
                            prompts = [list(p) for p in prompts]
                        c = c_list[idx]
                else:                                                           # EDM models
                    labels = labels_list[idx]

            teacher_traj = teacher_traj_list[idx]
            # One zero buffer per round, extended by one slot along dim 0;
            # slot 0 holds the starting latent.
            buffer_d = [torch.zeros_like(t) for t in teacher_traj]
            new_slice = torch.zeros_like(buffer_d[0][0]).unsqueeze(0)

            for i in range(len(buffer_d)):
                buffer_d[i] = torch.cat([buffer_d[i], new_slice], dim=0)
                buffer_d[i][0] = latents[i]

            # Learning-rate schedule: linear warm-up over the first 10% of
            # training, lr/5 from 90% and lr/10 from 95%. The 95% test must come
            # before the 90% test, otherwise it can never be reached.
            cur_kimg = cur_nimg / 1000
            if cur_kimg <= 0.1 * total_kimg:
                for g in optimizer.param_groups:
                    g['lr'] = optimizer_kwargs['lr'] * (cur_kimg / (0.1 * total_kimg) + 1e-10)
            elif cur_kimg >= 0.95 * total_kimg:
                for g in optimizer.param_groups:
                    g['lr'] = optimizer_kwargs['lr'] / 10
            elif cur_kimg >= 0.9 * total_kimg:
                for g in optimizer.param_groups:
                    g['lr'] = optimizer_kwargs['lr'] / 5

            for step_idx in range(loss_fn.num_steps - 1):
                optimizer.zero_grad(set_to_none=True)
                for round_idx in range(num_acc_rounds):

                    with misc.ddp_sync(ddp, (round_idx == num_acc_rounds - 1)):
                        if guidance_type in ['uncond', 'cfg']:      # LDM and SD models
                            with autocast("cuda"):
                                loss, stu_out, buffer_d[round_idx] = loss_fn(
                                    net=ddp, plugin=ddp_plugin, traj=ddp_traj,
                                    buffer_d=buffer_d[round_idx].to(device),
                                    tensor_in=latents[round_idx].to(device),
                                    labels=labels[round_idx], step_idx=step_idx,
                                    teacher_out=teacher_traj[round_idx][step_idx].to(device),
                                    condition=c[round_idx], unconditional_condition=uc,
                                )
                        else:
                            loss, stu_out, buffer_d[round_idx] = loss_fn(
                                net=ddp, plugin=ddp_plugin, traj=ddp_traj,
                                buffer_d=buffer_d[round_idx].to(device),
                                tensor_in=latents[round_idx].to(device),
                                labels=labels[round_idx], step_idx=step_idx,
                                teacher_out=teacher_traj[round_idx][step_idx].to(device),
                            )
                        latents[round_idx] = stu_out                # start point in next loop
                        training_stats.report('Loss/loss', loss)
                        # No backward pass / optimizer step for step 0 when loss_fn.afs is set.
                        if not (loss_fn.afs and step_idx == 0):
                            loss.sum().mul(1 / batch_gpu_total).backward()
                with torch.no_grad():
                    loss_norm = torch.norm(loss, p=2, dim=(1,2,3))
                    loss_mean, loss_std = loss_norm.mean().item(), loss_norm.std().item()

                dist.print0("Step: {} | Loss-mean: {:12.8f} | loss-std: {:12.8f}".format(step_idx, loss_mean, loss_std))

                if not (loss_fn.afs and step_idx == 0):
                    # Sanitise gradients in place before stepping.
                    for param in net.parameters():
                        if param.grad is not None:
                            torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                    for param in plugin.parameters():
                        if param.grad is not None:
                            torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                    optimizer.step()
            # Only the final step's statistics are recorded per trajectory.
            step_loss_means.append(loss_mean)
            step_loss_stds.append(loss_std)
            cur_nimg += batch_size * num_acc_rounds / Epochs
            done = (cur_nimg >= total_kimg * 1000)

            if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
                continue

            # Print status line, accumulating the same information in training_stats.
            tick_end_time = time.time()
            fields = []
            fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
            fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
            fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
            fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
            fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
            fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
            fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
            fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
            torch.cuda.reset_peak_memory_stats()
            dist.print0(' '.join(fields))

            # Check for abort.
            if (not done) and dist.should_stop():
                done = True
                dist.print0()
                dist.print0('Aborting...')

            # Save network and plugin snapshots.
            if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0) and (cur_tick != 0):
                snapshot_tag = f'{freeze_net}-{total_kimg}-snapshot-{int(cur_nimg)//1000:06d}.pkl'
                _save_module_snapshot(net, os.path.join(run_dir, f'network-{snapshot_tag}'))
                _save_module_snapshot(plugin, os.path.join(run_dir, f'plugin-{snapshot_tag}'))

            # Update logs.
            training_stats.default_collector.update()
            if dist.get_rank() == 0:
                if stats_jsonl is None:
                    stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at')
                stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n')
                stats_jsonl.flush()
            dist.update_progress(cur_nimg // 1000, total_kimg)

            # Update state.
            cur_tick += 1
            tick_start_nimg = cur_nimg
            tick_start_time = time.time()
            maintenance_time = tick_start_time - tick_end_time
            if done:
                break

        epoch_mean_loss = np.mean(step_loss_means)
        epoch_std_loss = np.mean(step_loss_stds)
        epoch_loss_means.append(epoch_mean_loss)
        epoch_loss_stds.append(epoch_std_loss)
        writer.add_scalars('Epoch_Loss', {'train_mean': epoch_mean_loss,'train_std': epoch_std_loss}, epoch)

        # Evaluate on the held-out test trajectories.
        net.eval()
        plugin.eval()
        test_step_loss_means = []
        test_step_loss_stds = []
        for idx in range(len(test_teacher_traj_list)):
            latents = copy.deepcopy(test_latents_list[idx])
            labels = copy.deepcopy(test_labels_list[idx])
            if net.label_dim:
                if guidance_type == 'cfg' and dataset_name in ['ms_coco']:      # For Stable Diffusion
                    # NOTE(review): unused, see the training loop above.
                    prompts = [random.sample(sample_captions, batch_gpu) for k in range(num_acc_rounds)]
                    with torch.no_grad():
                        if isinstance(prompts[0], tuple):
                            prompts = [list(p) for p in prompts]
                        c = test_c_list[idx]
                else:                                                           # EDM models
                    labels = test_labels_list[idx]

            teacher_traj = test_teacher_traj_list[idx]
            buffer_d = [torch.zeros_like(t) for t in teacher_traj]
            new_slice = torch.zeros_like(buffer_d[0][0]).unsqueeze(0)
            for i in range(len(buffer_d)):
                buffer_d[i] = torch.cat([buffer_d[i], new_slice], dim=0)  # add a new slice for the first step
                buffer_d[i][0] = latents[i]  # the starting point
            # Roll out step by step without gradients.
            with torch.no_grad():
                for step_idx in range(loss_fn.num_steps - 1):
                    for round_idx in range(num_acc_rounds):
                        with misc.ddp_sync(ddp, (round_idx == num_acc_rounds - 1)):
                            if guidance_type in ['uncond', 'cfg']:      # LDM and SD models
                                with autocast("cuda"):
                                    loss, stu_out, buffer_d[round_idx] = loss_fn(
                                        net=ddp, plugin=ddp_plugin, traj=ddp_traj,
                                        buffer_d=buffer_d[round_idx].to(device),
                                        tensor_in=latents[round_idx].to(device),
                                        labels=labels[round_idx], step_idx=step_idx,
                                        teacher_out=teacher_traj[round_idx][step_idx].to(device),
                                        condition=c[round_idx], unconditional_condition=uc,
                                    )
                            else:
                                loss, stu_out, buffer_d[round_idx] = loss_fn(
                                    net=ddp, plugin=ddp_plugin, traj=ddp_traj,
                                    buffer_d=buffer_d[round_idx].to(device),
                                    tensor_in=latents[round_idx].to(device),
                                    labels=labels[round_idx], step_idx=step_idx,
                                    teacher_out=teacher_traj[round_idx][step_idx].to(device),
                                )
                            latents[round_idx] = stu_out
                    loss_norm = torch.norm(loss, p=2, dim=(1,2,3))
                    loss_mean, loss_std = loss_norm.mean().item(), loss_norm.std().item()
                    dist.print0("Test Step: {} | Loss-mean: {:12.8f} | loss-std: {:12.8f}".format(step_idx, loss_mean, loss_std))
                test_step_loss_means.append(loss_mean)
                test_step_loss_stds.append(loss_std)
                torch.cuda.reset_peak_memory_stats()
        test_epoch_mean_loss = np.mean(test_step_loss_means)
        test_epoch_std_loss = np.mean(test_step_loss_stds)
        test_epoch_loss_means.append(test_epoch_mean_loss)
        test_epoch_loss_stds.append(test_epoch_std_loss)
        writer.add_scalars('Epoch_Loss', {'test_mean': test_epoch_mean_loss,'test_std': test_epoch_std_loss}, epoch)
        plot_loss_curves(epoch_loss_means, epoch_loss_stds, test_epoch_loss_means, test_epoch_loss_stds, os.path.join(run_dir, 'loss.svg'))

        # Early stopping. Rank 0 makes the decision and broadcasts it so that
        # every rank saves, decays its learning rate and leaves the epoch loop
        # together; acting on rank 0 alone would desynchronise the ranks.
        improved = test_epoch_mean_loss < best_test_loss - early_stop_min_delta
        improved_flag = torch.tensor([int(improved)], dtype=torch.int32, device=device)
        torch.distributed.broadcast(improved_flag, src=0)
        improved = bool(improved_flag.item())

        if improved:
            best_test_loss = test_epoch_mean_loss
            early_stop_counter = 0
            _save_module_snapshot(net, os.path.join(run_dir, f'network-{freeze_net}-{total_kimg}-snapshot-best.pkl'))
            _save_module_snapshot(plugin, os.path.join(run_dir, f'plugin-{freeze_net}-{total_kimg}-snapshot-best.pkl'))
            dist.print0(f"Epoch {epoch}: Test loss improved to {best_test_loss:.8f}. Saved best model.")
        else:
            for g in optimizer.param_groups:
                g['lr'] = optimizer_kwargs['lr'] / 5
            early_stop_counter += 1
            dist.print0(f"Epoch {epoch}: Test loss did not improve. Early stop counter: {early_stop_counter}/{early_stop_patience}")
            if early_stop_counter >= early_stop_patience:
                dist.print0(f"Early stop triggered at epoch {epoch}! No improvement for {early_stop_patience} epochs.")
                break

    # Done.
    if stats_jsonl is not None:
        stats_jsonl.close()
    dist.print0()
    dist.print0('Exiting...')
    writer.close()
