import copy
import functools
import os

import blobfile as bf
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import RAdam
import torch.nn.functional as F

from . import dist_util, logger
from .fp16_util import MixedPrecisionTrainer
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler

from .fp16_util import (
    get_param_groups_and_shapes,
    make_master_params,
    state_dict_to_master_params,
    master_params_to_model_params,
    opt_master_params_to_state_dict,
)
import numpy as np
from cm.karras_diffusion import karras_sample
from cm.random_util import get_generator
from torchvision.utils import make_grid, save_image
import datetime
import dnnlib
import pickle
import glob
import scipy

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0


class TrainLoop:
    def __init__(
        self,
        #*,
        model,
        discriminator,
        diffusion,
        data,
        batch_size,
        args=None,
        pretrained_classifier=None,
        classifier_vpsde=None,
    ):
        """Set up models, optimizers, EMA copies and DDP wrappers.

        Args:
            model: Network to train. Checkpoint weights (if any) are loaded
                into it and then synchronised across ranks.
            discriminator: Optional discriminator network, or ``None``.
            diffusion: Diffusion object later handed to the sampler.
            data: Iterator yielding ``(batch, cond)`` pairs.
            batch_size: Per-process batch size.
            args: Namespace of hyper-parameters. Required even though the
                signature keeps a ``None`` default for compatibility.
            pretrained_classifier: Stored unchanged on the instance.
            classifier_vpsde: Stored unchanged on the instance.

        Raises:
            ValueError: If ``args`` is ``None``.
        """
        if args is None:
            # Every setting below is read from ``args``; fail with a clear
            # message instead of an AttributeError on ``None``.
            raise ValueError("TrainLoop requires an `args` namespace")
        self.args = args
        self.model = model
        for name, param in self.model.named_parameters():
            print("check how consistency model overrides model parameters")
            print("model parameter before overriding: ", param.data.cpu().detach().reshape(-1)[:3])
            break
        self.discriminator = discriminator
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        # A non-positive microbatch setting means "use the full batch".
        self.microbatch = args.microbatch if args.microbatch > 0 else batch_size
        self.lr = args.lr
        # ``ema_rate`` may be a single float or a comma-separated string.
        self.ema_rate = (
            [args.ema_rate]
            if isinstance(args.ema_rate, float)
            else [float(x) for x in args.ema_rate.split(",")]
        )
        self.pretrained_classifier = pretrained_classifier
        self.classifier_vpsde = classifier_vpsde
        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()
        # Fixed noise and class labels, drawn once on CPU, reused for
        # preview sampling.
        self.generator = get_generator('determ', 10000, 42)
        self.x_T = self.generator.randn(*(self.args.sampling_batch, self.args.in_channels, self.args.image_size, self.args.image_size),
                                        device='cpu') * self.args.sigma_max #.to(dist_util.dev())
        self.classes = self.generator.randint(0, 1000, (self.args.sampling_batch,), device='cpu')
        self.sync_cuda = th.cuda.is_available()
        self._load_and_sync_parameters()
        for name, param in self.model.named_parameters():
            print("model parameter after overriding: ", param.data.cpu().detach().reshape(-1)[:3])
            break
        if self.discriminator is not None:
            for name, param in self.discriminator.named_parameters():
                print("discriminator parameter before overriding: ", param.data.cpu().detach().reshape(-1)[:3])
                break
            self._load_and_sync_discriminator_parameters()
            for name, param in self.discriminator.named_parameters():
                print("discriminator parameter after overriding: ", param.data.cpu().detach().reshape(-1)[:3])
                break

        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=args.use_fp16,
            fp16_scale_growth=args.fp16_scale_growth,
        )
        print("mp trainer master parameter (should same to the model parameter): ", self.mp_trainer.master_params[1].reshape(-1)[:3])

        self.opt = RAdam(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.args.weight_decay
        )
        print("opt state dict before overriding: ", self.opt.state_dict())

        if self.discriminator is not None:
            self.d_mp_trainer = MixedPrecisionTrainer(
                model=self.discriminator,
                use_fp16=args.use_d_fp16,
                fp16_scale_growth=args.fp16_scale_growth,
            )
            self.d_opt = RAdam(
                self.d_mp_trainer.master_params, lr=args.d_lr, weight_decay=self.args.weight_decay, betas=(0.5, 0.9)
            )
        if self.resume_step:
            if args.load_optimizer:
                self._load_optimizer_state()
            else:
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!! warning !!!!!!!!!!!!!!!!!!!!!!!!!!!! model optimizer not loaded successfully")
            if self.discriminator is not None:
                try:
                    if self.args.large_log:
                        print("discriminator opt state dict before overriding: ", self.d_opt.state_dict())
                    self._load_d_optimizer_state()
                    if self.args.large_log:
                        print("discriminator opt state dict after overriding: ", self.d_opt.state_dict())
                except Exception as exc:
                    # Best effort: training continues with a fresh
                    # discriminator optimizer, but the cause is reported and
                    # KeyboardInterrupt/SystemExit are no longer swallowed.
                    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!! warning !!!!!!!!!!!!!!!!!!!!!!!!!!!! discriminator optimizer not loaded successfully")
                    print("discriminator optimizer load error: ", repr(exc))
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
            self.ddp_discriminator = None
            if self.args.discriminator_training:
                if self.args.discriminator_fix:
                    # A fixed discriminator is used as-is, without DDP.
                    self.ddp_discriminator = self.discriminator
                else:
                    self.ddp_discriminator = DDP(
                        self.discriminator,
                        device_ids=[dist_util.dev()],
                        output_device=dist_util.dev(),
                        broadcast_buffers=False,
                        bucket_cap_mb=128,
                        find_unused_parameters=False,
                    )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model
            # Mirror the CUDA branch so the attribute always exists; without
            # DDP the raw discriminator stands in for the wrapped one.
            self.ddp_discriminator = (
                self.discriminator if self.args.discriminator_training else None
            )

        self.step = self.resume_step

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint and sync all ranks.

        The checkpoint path comes from ``find_resume_checkpoint()`` or, if
        that is falsy, ``args.resume_checkpoint``. When a path is found,
        ``self.resume_step`` is parsed from its filename on every rank, but
        only rank 0 reads the file. Parameters and buffers are then passed to
        ``dist_util.sync_params``.
        """
        ckpt_path = find_resume_checkpoint() or self.args.resume_checkpoint

        if ckpt_path:
            self.resume_step = parse_resume_step_from_filename(ckpt_path)

        if ckpt_path and dist.get_rank() == 0:
            logger.log(f"loading pretrained model from checkpoint: {ckpt_path}...")
            if dist.get_world_size() > 1:
                weights = th.load(ckpt_path, map_location=dist_util.dev())
            else:
                weights = dist_util.load_state_dict(ckpt_path, map_location='cpu')
            # strict=False: missing/unexpected keys in the checkpoint are tolerated.
            self.model.load_state_dict(weights, strict=False)

        for tensors in (self.model.parameters(), self.model.buffers()):
            dist_util.sync_params(tensors)

    def _load_and_sync_discriminator_parameters(self):
        """Load discriminator weights next to the resume checkpoint and sync ranks.

        The discriminator file is expected in the same directory as the main
        checkpoint, named ``d_model{resume_step:06}.pt``. Only rank 0 reads
        it. Parameters and buffers are then passed to
        ``dist_util.sync_params``. If the file is absent the current weights
        are kept and a message is logged.
        """
        resume_checkpoint = find_resume_checkpoint() or self.args.resume_checkpoint
        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            d_checkpoint = bf.join(
                bf.dirname(resume_checkpoint), f"d_model{self.resume_step:06}.pt"
            )
            if dist.get_rank() == 0:
                # The path is built with blobfile, so check it with blobfile
                # too, matching the optimizer/EMA loaders in this class.
                if bf.exists(d_checkpoint):
                    logger.log(f"loading discriminator model from checkpoint: {d_checkpoint}...")
                    if dist.get_world_size() > 1:
                        state_dict = th.load(d_checkpoint, map_location="cpu")
                    else:
                        state_dict = dist_util.load_state_dict(
                            d_checkpoint, map_location=dist_util.dev()
                        )
                    self.discriminator.load_state_dict(state_dict)
                else:
                    logger.log(
                        f"discriminator checkpoint not found, keeping current weights: {d_checkpoint}"
                    )

        dist_util.sync_params(self.discriminator.parameters())
        dist_util.sync_params(self.discriminator.buffers())

    def _load_ema_parameters(self, rate):
        """Return EMA parameters for ``rate``, loaded from a checkpoint if found.

        Starts from a deep copy of the trainer's master parameters. If
        ``find_ema_checkpoint`` returns a path, rank 0 replaces the copy with
        the checkpoint contents. The result is passed to
        ``dist_util.sync_params`` before being returned.

        Args:
            rate: EMA decay rate identifying which checkpoint to look up.

        Returns:
            The list of EMA master parameters.
        """
        averaged = copy.deepcopy(self.mp_trainer.master_params)
        print(f"{rate} ema param before overriding: ", averaged[1].reshape(-1)[:3])

        base_ckpt = find_resume_checkpoint() or self.args.resume_checkpoint
        ema_path = find_ema_checkpoint(base_ckpt, self.resume_step, rate)
        if ema_path and dist.get_rank() == 0:
            logger.log(f"loading EMA from checkpoint: {ema_path}...")
            # Both loaders are called with the same arguments; only the
            # function differs between multi- and single-process runs.
            reader = th.load if dist.get_world_size() > 1 else dist_util.load_state_dict
            loaded = reader(ema_path, map_location=dist_util.dev())
            averaged = self.mp_trainer.state_dict_to_master_params(loaded)

        dist_util.sync_params(averaged)
        print(f"{rate} ema param after overriding: ", averaged[1].reshape(-1)[:3])
        return averaged

    def _load_optimizer_state(self):
        """Restore ``self.opt`` from ``opt{resume_step:06}.pt`` if it exists.

        The file is looked up in the directory of the main resume checkpoint.
        When it is missing the optimizer keeps its fresh state and a message
        is logged. The optimizer's ``state`` entry is printed afterwards in
        either case.
        """
        main_checkpoint = find_resume_checkpoint() or self.args.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            if dist.get_world_size() > 1:
                state_dict = th.load(opt_checkpoint, map_location="cpu")
            else:
                state_dict = dist_util.load_state_dict(
                    opt_checkpoint, map_location=dist_util.dev()
                )
            # The fp16 and fp32 paths load the state dict the same way, so no
            # branch on ``args.use_fp16`` is needed.
            self.opt.load_state_dict(state_dict)
        else:
            logger.log(f"optimizer checkpoint not found, starting fresh: {opt_checkpoint}")
        print("opt state dict after overriding: ", self.opt.state_dict()['state'])

    def _load_d_optimizer_state(self):
        """Restore ``self.d_opt`` from ``d_opt{resume_step:06}.pt`` if it exists.

        The file is looked up in the directory of the main resume checkpoint.
        Nothing happens when the file is missing.
        """
        main_checkpoint = find_resume_checkpoint() or self.args.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"d_opt{self.resume_step:06}.pt"
        )
        # A single blobfile existence check is enough. The former extra
        # ``os.path.exists`` could veto the load after it had been logged.
        if bf.exists(opt_checkpoint):
            logger.log(f"loading d_optimizer state from checkpoint: {opt_checkpoint}")
            if dist.get_world_size() > 1:
                state_dict = th.load(opt_checkpoint, map_location="cpu")
            else:
                state_dict = dist_util.load_state_dict(
                    opt_checkpoint, map_location=dist_util.dev()
                )
            self.d_opt.load_state_dict(state_dict)

    def run_loop(self):
        while not self.args.lr_anneal_steps or self.step < self.args.lr_anneal_steps:
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.args.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.args.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.args.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self.step += 1
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond):
        """Compute the loss for ``batch`` and backpropagate it.

        This base class provides no implementation. ``run_step`` calls this
        hook before invoking the optimizer, so a subclass that trains must
        override it.

        Args:
            batch: Training batch as yielded by ``self.data``.
            cond: Conditioning information yielded alongside ``batch``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward_backward()"
        )

    def _update_ema(self):
        """Update every EMA parameter set from the current master parameters."""
        for decay, averaged in zip(self.ema_rate, self.ema_params):
            update_ema(averaged, self.mp_trainer.master_params, rate=decay)

    def _anneal_lr(self):
        if not self.args.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.args.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Record the current step and the cumulative sample count."""
        total_step = self.step + self.resume_step
        logger.logkv("step", total_step)
        logger.logkv("samples", (total_step + 1) * self.global_batch)

    def save(self):
        """Write EMA, optimizer and model checkpoints to the blob log directory.

        State dicts are built on every rank, but only rank 0 writes files.
        All ranks meet at a barrier before returning.
        """
        def write_blob(payload, filename):
            # Serialise ``payload`` with torch into the log directory.
            with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as handle:
                th.save(payload, handle)

        def export_weights(rate, params):
            # A falsy ``rate`` marks the raw model; otherwise an EMA copy.
            weights = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() != 0:
                return
            logger.log(f"saving model {rate}...")
            if rate:
                filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
            else:
                filename = f"model{(self.step+self.resume_step):06d}.pt"
            write_blob(weights, filename)

        for decay, averaged in zip(self.ema_rate, self.ema_params):
            export_weights(decay, averaged)

        if dist.get_rank() == 0:
            write_blob(
                self.opt.state_dict(),
                f"opt{(self.step+self.resume_step):06d}.pt",
            )

        # The model file goes last: a restart must never find model weights
        # for step N without the matching opt/ema files.
        export_weights(0, self.mp_trainer.master_params)
        dist.barrier()

    def sampling(self, model, sampler, ctm=None, teacher=False, step=-1, num_samples=1, batch_size=-1, rate=0.999, png=False, resize=True):
        """Draw samples with ``karras_sample`` and write them to the log directory.

        Args:
            model: Network passed to ``karras_sample``. It is switched to
                eval mode for the call and back to train mode afterwards,
                unless ``teacher`` is true.
            sampler: Sampler name forwarded to ``karras_sample`` and used in
                output file names.
            ctm: Forwarded to ``karras_sample``. When ``None``, it is derived
                from ``args.training_mode`` (true iff the mode is 'ctm').
            teacher: Forwarded to ``karras_sample``. Also disables the
                eval/train toggling and adds a ``teacher_`` file-name prefix.
            step: Number of sampling steps; ``-1`` means ``args.sampling_steps``.
            num_samples: ``1`` saves a single PNG grid generated from the
                fixed ``self.x_T``. Any other value saves ``.npz`` batches.
            batch_size: Samples per batch; ``-1`` means ``args.sampling_batch``.
            rate: Only used as part of the output directory name.
            png: When saving ``.npz`` batches, also save a PNG grid for the
                first batch.
            resize: If true, samples are bilinearly resized to 224 before
                conversion to uint8.

        Returns:
            None.

        NOTE(review): ``self.latent_decoder`` is read below but is not
        assigned in ``TrainLoop.__init__`` as visible here. ``CMTrainLoop``
        assigns it; confirm the base class is never used on its own.
        NOTE(review): ``number`` is only advanced inside the rank-0 branch,
        so the loop appears to rely on being called from rank 0 only.
        """
        if not teacher:
            model.eval()
        # Resolve the "-1 means use the configured default" sentinels.
        if step == -1:
            step = self.args.sampling_steps
        if batch_size == -1:
            batch_size = self.args.sampling_batch
        number = 0
        # NOTE(review): the bound subtracts args.sampling_batch, so at least
        # one batch beyond ``num_samples`` seems to be generated — confirm
        # this is intended.
        while num_samples > number - self.args.sampling_batch:
            with th.no_grad():
                model_kwargs = {}
                if self.args.class_cond:
                    if self.args.train_classes >= 0:
                        # Single fixed class for every sample in the batch.
                        classes = th.ones(size=(batch_size,), device=dist_util.dev(), dtype=int) * self.args.train_classes
                        model_kwargs["y"] = classes
                    elif self.args.train_classes == -2:
                        # Hard-coded class subset, each repeated equally often.
                        classes = [0, 1, 9, 11, 29, 31, 33, 55, 76, 89, 90, 130, 207, 250, 279, 281, 291, 323, 386, 387,
                                   388, 417, 562, 614, 759, 789, 800, 812, 848, 933, 973, 980]
                        assert batch_size % len(classes) == 0
                        model_kwargs["y"] = th.tensor([x for x in classes for _ in range(batch_size // len(classes))], device=dist_util.dev())
                    else:
                        # Fixed labels drawn once in __init__.
                        model_kwargs["y"] = self.classes.to(dist_util.dev())
                sample = karras_sample(
                    diffusion=self.diffusion,
                    model=model,
                    shape=(batch_size, self.args.in_channels, self.args.image_size, self.args.image_size),
                    steps=step,
                    model_kwargs=model_kwargs,
                    device=dist_util.dev(),
                    clip_denoised=False if self.args.data_name in ['church'] else True if teacher else self.args.clip_denoised,
                    sampler=sampler,
                    generator=self.generator,
                    teacher=teacher,
                    ctm=ctm if ctm != None else True if self.args.training_mode.lower() == 'ctm' else False,
                    x_T=self.x_T.to(dist_util.dev()) if num_samples == 1 else None,
                    clip_output=self.args.clip_output,
                    sigma_min=self.args.sigma_min,
                    sigma_max=self.args.sigma_max,
                    train=True,
                )
                if self.latent_decoder != None:
                    sample = self.latent_decoder(sample, teacher=True)
                if resize:
                    sample = F.interpolate(sample, size=224, mode="bilinear")

                # Affine map (x + 1) * 127.5, clamped to [0, 255], as uint8.
                sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
                # Presumably NCHW -> NHWC for saving — TODO confirm layout.
                sample = sample.permute(0, 2, 3, 1)
                sample = sample.contiguous()

                #gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
                #dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
                #all_images = []
                #all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
                all_images = sample.cpu().numpy()

                #arr = np.concatenate(all_images, axis=0)
                arr = all_images
                logger.log(f"created {arr.shape[0]} {sampler} samples")
                if dist.get_rank() == 0:
                    # shape_str = "x".join([str(x) for x in arr.shape])
                    # out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
                    os.makedirs(get_blob_logdir(), exist_ok=True)
                    logger.log(f"saving to {get_blob_logdir()}")
                    # Roughly square grid: floor(sqrt(batch)) images per row.
                    nrow = int(np.sqrt(arr.shape[0]))
                    image_grid = make_grid(th.tensor(arr).permute(0, 3, 1, 2) / 255., nrow, padding=2)
                    if num_samples == 1:
                        # Preview mode: one PNG named after sampler/steps/train step.
                        with bf.BlobFile(bf.join(get_blob_logdir(), f"{'teacher_' if teacher else ''}sample_{sampler}_sampling_step_{step}_step_{self.step}.png"), "wb") as fout:
                            save_image(image_grid, fout)
                    else:
                        # Bulk mode: each batch goes to an .npz with a random suffix.
                        r = np.random.randint(1000000)
                        os.makedirs(bf.join(get_blob_logdir(), f"{self.step}_{sampler}_{step}_{rate}"), exist_ok=True)
                        np.savez(bf.join(get_blob_logdir(), f"{self.step}_{sampler}_{step}_{rate}/sample_{r}.npz"), arr)
                        if png and number == 0:
                            with bf.BlobFile(bf.join(get_blob_logdir(),
                                                     f"{self.step}_{sampler}_{step}_{rate}/sample_{r}.png"), "wb") as fout:
                                save_image(image_grid, fout)
                    number += arr.shape[0]
        if not teacher:
            model.train()

class CMTrainLoop(TrainLoop):
    def __init__(
        self,
        *,
        target_model,
        teacher_model,
        latent_decoder,
        ema_scale_fn,
        **kwargs,
    ):
        """Extend ``TrainLoop`` with target/teacher models and a latent decoder.

        Args:
            target_model: Optional network initialised from the 0.999 EMA
                parameters. It is frozen and kept in train mode.
            teacher_model: Optional network that is frozen and put in eval mode.
            latent_decoder: Stored on the instance for use when sampling.
            ema_scale_fn: Stored on the instance.
            **kwargs: Forwarded unchanged to ``TrainLoop.__init__``.
        """
        super().__init__(**kwargs)
        self.training_mode = self.args.training_mode
        self.ema_scale_fn = ema_scale_fn
        self.target_model = target_model
        self.teacher_model = teacher_model
        self.latent_decoder = latent_decoder
        self.total_training_steps = self.args.total_training_steps

        def show_target_head(label):
            # Debug aid: print the first three values of the first parameter.
            first = next(iter(self.target_model.named_parameters()), None)
            if first is not None:
                print(label, first[1].data.cpu().detach().reshape(-1)[:3])

        if target_model:
            show_target_head("target model parameter before overriding: ")
            #self._load_and_sync_target_parameters()
            self._load_and_sync_ema_parameters_to_target_parameters()
            show_target_head("target model parameter after overriding: ")
            self.target_model.requires_grad_(False)
            self.target_model.train()

            self.target_model_param_groups_and_shapes = get_param_groups_and_shapes(
                self.target_model.named_parameters()
            )
            self.target_model_master_params = make_master_params(
                self.target_model_param_groups_and_shapes
            )
            for decay, averaged in zip(self.ema_rate, self.ema_params):
                if decay != 0.999:
                    continue
                #print(f"{rate} ema params: ", params)
                logger.log(f"loading target model from 0.999 ema...")
                # rate=0.0 is passed to update_ema, then the master copy is
                # written back into the target model's parameters.
                update_ema(
                    self.target_model_master_params,
                    averaged,
                    rate=0.0,
                )
                master_params_to_model_params(
                    self.target_model_param_groups_and_shapes,
                    self.target_model_master_params,
                )

            show_target_head("target model parameter after all: ")

        if teacher_model:
            #self._load_and_sync_teacher_parameters()
            self.teacher_model.requires_grad_(False)
            self.teacher_model.eval()

        self.global_step = self.step
        self.initial_step = copy.deepcopy(self.step)

    def _load_and_sync_ema_parameters_to_target_parameters(self):
        """Copy the 0.999 EMA weights into the target model, then sync ranks.

        Only rank 0 performs the copy. Parameters and buffers are then passed
        to ``dist_util.sync_params``.
        """
        if dist.get_rank() == 0:
            for decay, averaged in zip(self.ema_rate, self.ema_params):
                if decay != 0.999:
                    continue
                logger.log(f"loading target model from 0.999 ema...")
                self.target_model.load_state_dict(
                    self.mp_trainer.master_params_to_state_dict(averaged)
                )

        for tensors in (self.target_model.parameters(), self.target_model.buffers()):
            dist_util.sync_params(tensors)

    def _load_and_sync_target_parameters(self):
        """Load target-model weights from a sibling checkpoint and sync ranks.

        The file name is the resume checkpoint's name with ``"model"``
        replaced by ``"target_model"``. Only rank 0 reads it, non-strictly.
        """
        base_ckpt = find_resume_checkpoint() or self.args.resume_checkpoint
        if base_ckpt:
            folder, filename = os.path.split(base_ckpt)
            target_ckpt = os.path.join(
                folder, filename.replace("model", "target_model")
            )
            if bf.exists(target_ckpt) and dist.get_rank() == 0:
                logger.log(
                    f"loading target model from checkpoint: {target_ckpt}..."
                )
                if dist.get_world_size() > 1:
                    weights = th.load(target_ckpt, map_location="cpu")
                else:
                    weights = dist_util.load_state_dict(
                        target_ckpt, map_location=dist_util.dev()
                    )
                self.target_model.load_state_dict(weights, strict=False)

        for tensors in (self.target_model.parameters(), self.target_model.buffers()):
            dist_util.sync_params(tensors)

    def _load_and_sync_teacher_parameters(self):
        """Load teacher-model weights from a sibling checkpoint and sync ranks.

        The file name is the resume checkpoint's name with ``"model"``
        replaced by ``"teacher_model"``. Only rank 0 reads it, with strict
        key matching.
        """
        base_ckpt = find_resume_checkpoint() or self.args.resume_checkpoint
        if base_ckpt:
            folder, filename = os.path.split(base_ckpt)
            teacher_ckpt = os.path.join(
                folder, filename.replace("model", "teacher_model")
            )

            if bf.exists(teacher_ckpt) and dist.get_rank() == 0:
                logger.log(
                    f"loading teacher model from checkpoint: {teacher_ckpt}..."
                )
                if dist.get_world_size() > 1:
                    weights = th.load(teacher_ckpt, map_location="cpu")
                else:
                    weights = dist_util.load_state_dict(
                        teacher_ckpt, map_location=dist_util.dev()
                    )
                self.teacher_model.load_state_dict(weights)

        for tensors in (self.teacher_model.parameters(), self.teacher_model.buffers()):
            dist_util.sync_params(tensors)

    def run_loop(self):
        """Main training loop with preview sampling, evaluation and checkpoints.

        Runs until both ``self.step >= args.lr_anneal_steps`` and
        ``self.global_step >= self.total_training_steps``. Each iteration:

        1. On rank 0, optionally draws preview samples.
        2. Runs one training step.
        3. On rank 0, optionally evaluates the raw model and EMA weights,
           restoring the raw weights afterwards.
        4. Waits at a barrier, then optionally saves checkpoints and dumps logs.

        A final checkpoint is written after the loop unless the last
        iteration already saved one.

        NOTE(review): the ``large_log`` branches call
        ``self.target_model.named_parameters()`` unconditionally, although
        ``__init__`` treats ``target_model`` as optional — confirm
        ``large_log`` is only used with a target model.
        """
        saved = False
        while (
            self.step < self.args.lr_anneal_steps
            or self.global_step < self.total_training_steps
        ):
            #print("!!!!!!!!!!!!!!!!!!!!!!!!!!: ", th.rand(1))
            batch, cond = next(self.data)
            if dist.get_rank() == 0:
                # Preview samples: 10 steps after start, and on the last step
                # of every sample_interval window.
                if self.step == self.initial_step + 10 or (self.step % self.args.sample_interval == self.args.sample_interval - 1):
                    if self.args.training_mode.lower() == 'ctm':
                        if self.args.consistency_weight > 0.:
                            self.sampling(model=self.ddp_model, sampler='exact')
                            self.sampling(model=self.ddp_model, sampler='exact', step=2)
                            self.sampling(model=self.ddp_model, sampler='exact', step=1)
                        else:
                            self.sampling(model=self.ddp_model, sampler='heun', ctm=True, teacher=True)
                    elif self.args.training_mode.lower() == 'cd':
                        self.sampling(model=self.ddp_model, sampler='onestep', step=1)
                # One-off reference samples from the teacher, if there is one.
                if self.step == self.initial_step + 10 and self.teacher_model != None:
                    self.sampling(model=self.teacher_model, sampler='heun', ctm=False, teacher=True)
                #dist.barrier()
            self.run_step(batch, cond)
            if self.args.large_log:
                print("mp trainer master parameter after one step update: ", self.mp_trainer.master_params[1].reshape(-1)[:3])

                for name, param in self.model.named_parameters():
                    print("model parameter after one step update: ", param.data.cpu().detach().reshape(-1)[:3])
                    break
                for name, param in self.target_model.named_parameters():
                    print("target model parameter after one step update: ", param.data.cpu().detach().reshape(-1)[:3])
                    break
            # Python precedence groups this as
            # (global_step and eval_interval != -1 and on-interval and >10
            # steps since start) or (last lr-anneal step) or (last training
            # step).
            if (
                self.global_step
                and self.args.eval_interval != -1
                and self.global_step % self.args.eval_interval == 0
                and self.step - self.initial_step > 10
                or self.step == self.args.lr_anneal_steps - 1
                or self.global_step == self.total_training_steps - 1
            ):
                if dist.get_rank() == 0:
                    #model_state_dict = self.mp_trainer.master_params_to_state_dict(self.mp_trainer.master_params)
                    # Keep the raw weights so they can be restored after the
                    # EMA weights have been swapped in for evaluation.
                    model_state_dict = self.model.state_dict()
                    self.evaluation(0.0)
                    for rate, params in zip(self.ema_rate, self.ema_params):
                        # Unless compute_ema_fids is set, evaluate only the
                        # 0.999 EMA.
                        if not self.args.compute_ema_fids:
                            if rate != 0.999:
                                continue
                        #if dist.get_rank() == 0:
                        state_dict = self.mp_trainer.master_params_to_state_dict(params)
                        self.model.load_state_dict(state_dict)
                        #dist_util.sync_params(self.model.parameters())
                        #dist_util.sync_params(self.model.buffers())
                        #dist.barrier()
                        self.evaluation(rate)
                    #if dist.get_rank() == 0:
                    self.model.load_state_dict(model_state_dict)
                    #dist_util.sync_params(self.model.parameters())
                    #dist_util.sync_params(self.model.buffers())
            if self.args.large_log:
                print("mp trainer master parameter after sampling: ",
                      self.mp_trainer.master_params[1].reshape(-1)[:3])
                for name, param in self.model.named_parameters():
                    print("model parameter after sampling: ", param.data.cpu().detach().reshape(-1)[:3])
                    break
                for name, param in self.target_model.named_parameters():
                    print("target model parameter after sampling: ", param.data.cpu().detach().reshape(-1)[:3])
                    break
            # All ranks wait here while rank 0 finishes sampling/evaluation.
            dist.barrier()

            # Reset every iteration so ``saved`` reflects the latest step only.
            saved = False
            if (
                self.global_step
                and self.args.save_interval != -1
                and self.global_step % self.args.save_interval == 0
            ):
                self.save()
                if self.discriminator != None:
                    self.d_save()
                saved = True
                th.cuda.empty_cache()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            if self.global_step % self.args.log_interval == 0:
                logger.dumpkvs()
                logger.log(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
            if self.args.large_log:
                print("mp trainer master parameter after saving: ",
                      self.mp_trainer.master_params[1].reshape(-1)[:3])
                for name, param in self.model.named_parameters():
                    print("model parameter after saving: ", param.data.cpu().detach().reshape(-1)[:3])
                    break
                for name, param in self.target_model.named_parameters():
                    print("target model parameter after saving: ", param.data.cpu().detach().reshape(-1)[:3])
                    break
                # NOTE(review): this prints ema_params[0], which is the 0.999
                # EMA only if 0.999 is the first entry of ema_rate.
                print(f"0.999 ema param after overriding (should be same to the target parameter): ",
                      self.ema_params[0][1].reshape(-1)[:3])

        # Save the last checkpoint if it wasn't already saved.
        if not saved:
            self.save()
            if self.discriminator != None:
                self.d_save()

    def evaluation(self, rate):
        """Run the evaluations that belong to the configured training mode.

        * ``ctm`` with ``args.consistency_weight > 0``: non-teacher
          evaluations at 1 step, optionally at 2 and 4 steps (only when
          ``args.compute_ema_fids`` is truthy) and at 18 steps, followed by
          an 18-step ``heun`` evaluation with ``teacher=True``.
        * ``ctm`` otherwise: only the 18-step ``heun`` teacher evaluation.
        * ``cd``: a single 1-step ``onestep`` evaluation with ``ctm=False``.
        * any other mode: nothing is evaluated.

        Args:
            rate: forwarded unchanged to ``self.eval`` as ``rate``.
        """
        mode = self.args.training_mode.lower()
        if mode == 'cd':
            self.eval(step=1, sampler='onestep', rate=rate, ctm=False)
            return
        if mode != 'ctm':
            return

        if self.args.consistency_weight > 0.:
            self.eval(step=1, rate=rate, ctm=True)
            # The intermediate step counts are opt-in; 18 steps always runs.
            optional_steps = (2, 4) if self.args.compute_ema_fids else ()
            for num_steps in (*optional_steps, 18):
                self.eval(step=num_steps, rate=rate, ctm=True)
        # Both ctm branches finish with the teacher evaluation.
        self.eval(step=18, sampler='heun', teacher=True, ctm=True, rate=rate)

    def run_step(self, batch, cond):
        """Run one training iteration: backward pass, optimizer step, EMA, logging.

        After ``forward_backward`` has accumulated gradients, exactly one
        optimizer is stepped:

        * the generator (``self.mp_trainer`` / ``self.opt``) when there is no
          discriminator, or when ``self.step`` is a multiple of
          ``args.g_learning_period``;
        * the discriminator (``self.d_mp_trainer`` / ``self.d_opt``) otherwise.

        If the optimizer reports that it took a step, ``self.step`` and
        ``self.global_step`` are incremented; the EMA parameters (and the
        target model, when one is set) are refreshed only on generator steps.
        ``_anneal_lr`` and ``log_step`` run unconditionally at the end.

        Args:
            batch: forwarded to ``forward_backward``.
            cond: forwarded to ``forward_backward``.
        """
        if self.args.large_log:
            print("mp trainer master parameter before update: ", self.mp_trainer.master_params[1].reshape(-1)[:3])
            for name, param in self.model.named_parameters():
                print("model parameter before update: ", param.data.cpu().detach().reshape(-1)[:3])
                break
            for name, param in self.target_model.named_parameters():
                print("target model parameter before update: ", param.data.cpu().detach().reshape(-1)[:3])
                break
        self.forward_backward(batch, cond)

        # Decide once which network this iteration trains; `self.step` only
        # changes further down, so the flag is valid for the whole method.
        generator_step = (
            self.discriminator is None
            or self.step % self.args.g_learning_period == 0
        )
        if generator_step:
            took_step = self.mp_trainer.optimize(self.opt)
        else:
            took_step = self.d_mp_trainer.optimize(self.d_opt)

        if took_step:
            if generator_step:
                # Only generator updates change the weights tracked by the EMAs.
                self._update_ema()
                if self.target_model:
                    self._update_target_ema()
            self.step += 1
            self.global_step += 1
        self._anneal_lr()
        self.log_step()

    '''def _update_target_ema(self):
        target_ema, scales = self.ema_scale_fn(self.global_step)
        update_ema(self.target_model.parameters(), self.model.parameters(), rate=target_ema)'''

    def _update_target_ema(self):
        """Refresh the target model from the trainer's master parameters.

        The EMA rate comes from ``self.ema_scale_fn(self.global_step)``, which
        returns a pair; only its first element is used here. ``update_ema`` is
        called with the target master params, the trainer master params and
        that rate, after which the target master params are pushed into the
        target model via ``master_params_to_model_params``. Everything runs
        under ``th.no_grad()``.
        """
        # The second element returned by the schedule is not needed here.
        target_ema, _unused_scales = self.ema_scale_fn(self.global_step)
        target_master = self.target_model_master_params
        online_master = self.mp_trainer.master_params
        with th.no_grad():
            update_ema(target_master, online_master, rate=target_ema)
            master_params_to_model_params(
                self.target_model_param_groups_and_shapes,
                target_master,
            )

    def forward_backward(self, batch, cond):
        """Zero the gradients and accumulate new ones over all microbatches.

        ``batch`` is sliced along its first dimension in chunks of
        ``self.microbatch``. Each chunk, and the matching slice of every value
        in ``cond``, is moved to ``dist_util.dev()`` and handed to
        ``self.diffusion.ctm_losses``. The keys of the returned dict select
        the backward pass:

        * ``consistency_loss`` present: ``args.consistency_weight`` times its
          mean, plus the weighted means of ``d_loss`` and ``denoising_loss``
          when present; backward through ``self.mp_trainer``.
        * otherwise ``d_loss`` present: its mean, backward through
          ``self.d_mp_trainer`` (asserted to be a non-generator step).
        * otherwise ``denoising_loss`` present: its mean, backward through
          ``self.mp_trainer``.

        When ``self.use_ddp`` is set, every microbatch except the last runs
        inside ``no_sync()`` of the network being trained this step.

        Args:
            batch: sliceable object exposing ``shape[0]`` and ``.to()`` on its
                slices (presumably a tensor -- confirm against the data loader).
            cond: dict whose values are sliced and moved like ``batch``.
        """
        self.mp_trainer.zero_grad()
        has_discriminator = self.discriminator is not None
        if has_discriminator:
            self.d_mp_trainer.zero_grad()

        # Same rule as run_step: with no discriminator the generator is the
        # only network ever optimised, so it is the one whose gradient sync
        # must be deferred. Previously this case could select
        # self.ddp_discriminator whenever step % g_learning_period != 0.
        generator_step = (
            not has_discriminator
            or self.step % self.args.g_learning_period == 0
        )

        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            compute_losses = functools.partial(
                self.diffusion.ctm_losses,
                step=self.step,
                model=self.ddp_model,
                x_start=micro,
                model_kwargs=micro_cond,
                target_model=self.target_model,
                teacher_model=self.teacher_model,
                pretrained_classifier=self.pretrained_classifier,
                classifier_vpsde=self.classifier_vpsde,
                discriminator=self.ddp_discriminator,
                init_step=self.initial_step,
                ctm=self.training_mode.lower() == 'ctm',
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                # Gradients are only all-reduced on the last microbatch.
                trained_ddp = self.ddp_model if generator_step else self.ddp_discriminator
                with trained_ddp.no_sync():
                    losses = compute_losses()

            if 'consistency_loss' in losses:
                loss = self.args.consistency_weight * losses["consistency_loss"].mean()

                if 'd_loss' in losses:
                    print("GAN learning, ", self.args.discriminator_weight, losses['d_loss'].mean())
                    loss = loss + self.args.discriminator_weight * losses['d_loss'].mean()
                if 'denoising_loss' in losses:
                    loss = loss + self.args.denoising_weight * losses['denoising_loss'].mean()
                log_loss_dict({k: v.view(-1) for k, v in losses.items()})
                self.mp_trainer.backward(loss)

            elif 'd_loss' in losses:
                assert self.step % self.args.g_learning_period != 0
                loss = (losses["d_loss"]).mean()
                self.d_mp_trainer.backward(loss)
                if self.args.large_log:
                    # Debug output for the first discriminator parameter only.
                    for param in self.discriminator.parameters():
                        grad = param.grad
                        if grad is None:
                            # No gradient has been populated for this parameter.
                            print("discriminator param grad: ", grad)
                        else:
                            print("discriminator param data, grad: ", grad.reshape(-1)[:3])
                        break
            elif 'denoising_loss' in losses:
                loss = losses['denoising_loss'].mean()
                log_loss_dict({k: v.view(-1) for k, v in losses.items()})
                self.mp_trainer.backward(loss)

    @th.no_grad()
    def eval(self, step=1, sampler='exact', teacher=False, ctm=False, rate=0.999):
        """Generate evaluation samples and, on rank 0, score them.

        Only runs when ``args.data_name`` is ``cifar10`` (case-insensitive);
        for any other dataset a "not implemented" message is logged and the
        method returns. Otherwise ``self.sampling`` is called with
        ``self.model`` and the given settings, and rank 0 then calls
        ``calculate_inception_stats_npz`` on the directory
        ``<blob logdir>/<self.step>_<sampler>_<step>_<rate>``.
        NOTE(review): this presumes ``self.sampling`` writes its samples into
        that same directory -- confirm against ``sampling``.

        Args:
            step: number of sampling steps, forwarded to ``sampling`` and to
                the statistics routine.
            sampler: sampler name, forwarded to ``sampling``.
            teacher: forwarded to ``sampling``.
            ctm: forwarded to ``sampling``.
            rate: EMA rate label, forwarded to both calls.
        """
        if self.args.data_name.lower() != 'cifar10':
            logger.log('Not implemented yet for FID computation other than CIFAR10')
            return

        self.sampling(
            model=self.model,
            sampler=sampler,
            teacher=teacher,
            step=step,
            num_samples=self.args.eval_num_samples,
            batch_size=self.args.eval_batch,
            rate=rate,
            ctm=ctm,
            png=False,
            resize=False,
        )
        if dist.get_rank() != 0:
            return

        sample_dir = os.path.join(get_blob_logdir(), f"{self.step}_{sampler}_{step}_{rate}")
        self.calculate_inception_stats_npz(
            sample_dir,
            num_samples=self.args.eval_num_samples,
            step=step,
            device=dist_util.dev(),
            rate=rate,
        )

    def calculate_inception_stats_npz(self, image_path, num_samples=50000, step=1, batch_size=250, device=th.device('cuda'),
                                      rate=0.999):
        """Compute the FID of saved samples against reference statistics and log it.

        Reads ``sample*.npz`` archives under ``image_path`` (array ``arr_0``),
        feeds the images through a pickled Inception detector in batches and
        accumulates the feature mean and covariance over the first
        ``num_samples`` images. The result is compared with the ``mu`` and
        ``sigma`` arrays loaded from ``self.args.ref_path``:
        ``FID = |mu - mu_ref|^2 + tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref))``.

        Args:
            image_path: directory holding the ``sample*.npz`` archives.
            num_samples: number of images to use; must be a multiple of 1000
                and must be available on disk.
            step: number of function evaluations, used only in the log line.
            batch_size: images per detector forward pass.
            device: device the detector and the statistics live on.
            rate: EMA rate, used only in the log line.

        Raises:
            AssertionError: if ``num_samples`` is not a multiple of 1000 or
                fewer than ``num_samples`` images were found.
        """
        # `import scipy` alone does not guarantee that the `linalg` submodule
        # is loaded on older SciPy releases, so import it explicitly here.
        import scipy.linalg

        # The log label is "FID-<N>k"; reject unsupported sizes before doing
        # any expensive work instead of after it.
        assert num_samples % 1000 == 0, "num_samples must be a multiple of 1000"

        print('Loading Inception-v3 model...')
        detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
        detector_kwargs = dict(return_features=True)
        feature_dim = 2048
        # NOTE(security): unpickling executes arbitrary code from the
        # downloaded file; this is only acceptable while the URL is trusted.
        with dnnlib.util.open_url(detector_url, verbose=True) as f:
            detector_net = pickle.load(f).to(device)

        print(f'Loading images from "{image_path}"...')
        mu = th.zeros([feature_dim], dtype=th.float64, device=device)
        sigma = th.zeros([feature_dim, feature_dim], dtype=th.float64, device=device)

        # glob order is arbitrary; sort so the subset of samples used is
        # reproducible when more than `num_samples` images are on disk.
        files = sorted(glob.glob(os.path.join(image_path, 'sample*.npz')))
        count = 0
        for file in files:
            # Close the archive as soon as the array has been read.
            with np.load(file) as archive:
                images = archive['arr_0']
            for start in range(0, images.shape[0], batch_size):
                # assumes images are stored channels-last (N, H, W, C) -- TODO confirm
                mic_img = th.tensor(images[start:start + batch_size]).permute(0, 3, 1, 2).to(device)
                features = detector_net(mic_img, **detector_kwargs).to(th.float64)
                # Never use more than `num_samples` images in total.
                take = min(mic_img.shape[0], num_samples - count)
                kept = features[:take]
                mu += kept.sum(0)
                sigma += kept.T @ kept
                count += take
                print(count)
                if count >= num_samples:
                    # Enough samples: skip the remaining forward passes.
                    break
            if count >= num_samples:
                break
        assert count == num_samples, (
            f"expected {num_samples} samples under {image_path}, found {count}"
        )
        print(count)
        # Turn the accumulated sums into a mean and an unbiased covariance.
        mu /= num_samples
        sigma -= th.outer(mu, mu) * num_samples
        sigma /= num_samples - 1
        mu = mu.cpu().numpy()
        sigma = sigma.cpu().numpy()
        with dnnlib.util.open_url(self.args.ref_path) as f:
            ref = dict(np.load(f))
        mu_ref = ref['mu']
        sigma_ref = ref['sigma']

        m = np.square(mu - mu_ref).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
        fid = m + np.trace(sigma + sigma_ref - s * 2)
        # sqrtm may return a complex matrix; only the real part is reported.
        fid = float(np.real(fid))
        # NOTE(review): the label always says "exact sampler" because the
        # sampler name is not passed to this method, although `eval` may have
        # sampled with a different one.
        logger.log(f"{self.step}-th step exact sampler (NFE {step}) EMA {rate} FID-{num_samples // 1000}k: {fid}")

    def save(self):
        """Write all checkpoint files for the current ``global_step``.

        In order: one ``ema_<rate>_<step>.pt`` per EMA rate, ``opt<step>.pt``,
        ``target_model<step>.pt`` when a target model is set,
        ``teacher_model<step>.pt`` when a teacher model is set and
        ``training_mode`` is ``"progdist"``, and finally ``model<step>.pt``.
        State dicts for the EMA/model checkpoints are built on every rank,
        but only rank 0 writes files; all ranks then meet at a barrier.
        """
        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() != 0:
                return
            logger.log(f"saving model {rate}...")
            # A falsy rate (0) denotes the raw, non-EMA model.
            filename = (
                f"ema_{rate}_{self.global_step:06d}.pt"
                if rate
                else f"model{self.global_step:06d}.pt"
            )
            with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                th.save(state_dict, f)

        for ema_rate, ema_params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(ema_rate, ema_params)

        logger.log("saving optimizer state...")
        if dist.get_rank() == 0:
            opt_path = bf.join(get_blob_logdir(), f"opt{self.global_step:06d}.pt")
            with bf.BlobFile(opt_path, "wb") as f:
                th.save(self.opt.state_dict(), f)

            if self.target_model:
                logger.log("saving target model state")
                filename = f"target_model{self.global_step:06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(self.target_model.state_dict(), f)

            if self.teacher_model and self.training_mode == "progdist":
                logger.log("saving teacher model state")
                filename = f"teacher_model{self.global_step:06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(self.teacher_model.state_dict(), f)

        # The model file goes last: a restart that finds model<step>.pt can
        # then rely on the opt/EMA files for that step already being there.
        save_checkpoint(0, self.mp_trainer.master_params)
        dist.barrier()

    def d_save(self):
        """Write the discriminator optimizer and weights for ``global_step``.

        Rank 0 writes ``d_opt<step>.pt`` (state of ``self.d_opt``) followed by
        ``d_model<step>.pt`` (state dict built from the discriminator
        trainer's master params). All ranks then meet at a barrier.
        """
        logger.log("saving d_optimizer state...")
        if dist.get_rank() == 0:
            opt_path = bf.join(get_blob_logdir(), f"d_opt{self.global_step:06d}.pt")
            with bf.BlobFile(opt_path, "wb") as f:
                th.save(self.d_opt.state_dict(), f)

            model_path = bf.join(get_blob_logdir(), f"d_model{self.global_step:06d}.pt")
            with bf.BlobFile(model_path, "wb") as f:
                d_trainer = self.d_mp_trainer
                th.save(d_trainer.master_params_to_state_dict(d_trainer.master_params), f)

        # Hold every rank here until rank 0 has finished writing.
        dist.barrier()

    def log_step(self):
        """Record the global step and the derived sample count in the logger.

        ``samples`` is logged as ``(global_step + 1) * global_batch``.
        """
        current_step = self.global_step
        logger.logkv("step", current_step)
        samples_seen = (current_step + 1) * self.global_batch
        logger.logkv("samples", samples_seen)


def parse_resume_step_from_filename(filename):
    """
    Extract the step count from a checkpoint path such as
    path/to/modelNNNNNN.pt.

    The text between the last occurrence of "model" and the next "." is
    parsed as an integer. Returns 0 when "model" does not occur in the
    name or when that text is not a valid integer.
    """
    _, marker, tail = filename.rpartition("model")
    if not marker:
        # "model" does not appear anywhere in the name.
        return 0
    digits, _, _ = tail.partition(".")
    try:
        return int(digits)
    except ValueError:
        return 0


def get_blob_logdir():
    """
    Return the directory that checkpoints are written to.

    This is the logger's directory; change the function to return another
    path to save checkpoints to a blobstore or some external drive.
    """
    checkpoint_dir = logger.get_dir()
    return checkpoint_dir


def find_resume_checkpoint():
    """
    Return the path of a checkpoint to resume from, or None.

    This default implementation never discovers anything and always
    returns None. On your infrastructure, you may want to override it to
    locate the latest checkpoint on your blob storage automatically.
    """
    discovered_checkpoint = None
    return discovered_checkpoint


def find_ema_checkpoint(main_checkpoint, step, rate):
    """
    Locate the EMA checkpoint stored next to a main checkpoint.

    Looks for ``ema_<rate>_<step>.pt`` (step zero-padded to six digits) in
    the directory of ``main_checkpoint``. Returns that path when the file
    exists, otherwise None. Also returns None when ``main_checkpoint`` is
    None.
    """
    if main_checkpoint is not None:
        candidate = bf.join(
            bf.dirname(main_checkpoint),
            f"ema_{rate}_{(step):06d}.pt",
        )
        if bf.exists(candidate):
            return candidate
    return None


def log_loss_dict(losses):
    """
    Log the mean and standard deviation of every entry in ``losses``.

    For each key, ``"<key> mean"`` and ``"<key> std"`` are reported through
    ``logger.logkv_mean``. No per-timestep quartile breakdown is logged.
    """
    for key in losses:
        loss_values = losses[key]
        logger.logkv_mean(f"{key} mean", loss_values.mean().item())
        logger.logkv_mean(f"{key} std", loss_values.std().item())
