import sys
import torch.multiprocessing as mp
import os
project_root = os.path.join(os.path.dirname(__file__), "..")
project_root = os.path.abspath(project_root)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import matplotlib
matplotlib.use('Agg')
from main.utils import prepare_images_for_saving, draw_valued_array, draw_probability_histogram
from main.sd_image_dataset import SDImageDatasetLMDB
from transformers import CLIPTokenizer, AutoTokenizer
from accelerate.utils import ProjectConfiguration
from diffusers.optimization import get_scheduler
from main.utils import SDTextDataset, SDTextImgDataset, cycle 
from dmd_timestepaware_pixel_adv.sd_unified_model import SDUniModel
from accelerate.utils import set_seed,FullyShardedDataParallelPlugin
from accelerate import Accelerator
import argparse 
import shutil 
import wandb 
import torch 
import time 
import os
from copy import deepcopy

import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.configuration_utils")

torch.multiprocessing.set_sharing_strategy('file_system')

class Trainer:
    def __init__(self, args):
        self.args = args

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True 

        # fsdp_plugin = FullyShardedDataParallelPlugin(
        #     use_orig_params=True
        # )
        accelerator_project_config = ProjectConfiguration(logging_dir=args.log_path)
        accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision="no",
            # log_with="wandb",
            project_config=accelerator_project_config,
            kwargs_handlers=None,
            # fsdp_plugin=fsdp_plugin
        )
        set_seed(args.seed + accelerator.process_index)

        print(accelerator.state)

        if accelerator.is_main_process:
            output_path = os.path.join(args.output_path, f"time_{int(time.time())}_seed{args.seed}")
            os.makedirs(output_path, exist_ok=False)

            self.cache_dir = os.path.join(args.cache_dir, f"time_{int(time.time())}_seed{args.seed}")
            os.makedirs(self.cache_dir, exist_ok=False)

            self.output_path = output_path

            os.makedirs(args.log_path, exist_ok=True)

            print('name: ',args.wandb_project)
            
            os.makedirs('~/.cache/wandb', exist_ok=True)
            run = wandb.init(project=args.wandb_project, dir='/root/.cache', entity='2936127220-nanjing-university')
            wandb.run.name = args.wandb_name
            print(f"run dir: {run.dir}")
            self.wandb_folder = run.dir
            os.makedirs(self.wandb_folder, exist_ok=True)
        accelerator.wait_for_everyone()
        start_time = time.time()
        self.model = SDUniModel(args, accelerator)
        print(f'\033[91m model loaded, cost time {time.time()-start_time}s to load the model. \033[0m')
        self.max_grad_norm = args.max_grad_norm
        self.denoising = args.denoising
        self.step = 0 

        if self.denoising: assert args.sdxl, "denoising only supported for sdxl." 

        if args.ckpt_only_path is not None:
            if accelerator.is_main_process:
                print(f"loading ckpt only from {args.ckpt_only_path}")
            generator_path = os.path.join(args.ckpt_only_path, "pytorch_model.bin")
            guidance_path = os.path.join(args.ckpt_only_path, "pytorch_model_1.bin")
            print(self.model.feedforward_model.load_state_dict(torch.load(generator_path, map_location="cpu"), strict=False))
            guidance_state_dict = torch.load(guidance_path, map_location="cpu")
            fake_unet_state_dict = {
                k[len("fake_unet."):]: v for k, v in guidance_state_dict.items() if k.startswith("fake_unet.")
            }
            print(self.model.guidance_model.fake_unet.load_state_dict(fake_unet_state_dict, strict=True))
            
        self.sdxl = args.sdxl 
        
        
        if self.sdxl:
            tokenizer_one = AutoTokenizer.from_pretrained(
                args.model_id, subfolder="tokenizer", revision=args.revision, use_fast=False
            )

            tokenizer_two = AutoTokenizer.from_pretrained(
                args.model_id, subfolder="tokenizer_2", revision=args.revision, use_fast=False
            )
            print('Loading Dataset')
            dataset = SDTextDataset(
                args.train_prompt_path, 
                is_sdxl=True,
                tokenizer_one=tokenizer_one,
                tokenizer_two=tokenizer_two
            )  
            # teacher_datast = SDTextImgDataset(
            #     args.train_prompt_path, 
            #     args.train_img_path,
            #     is_sdxl=True,
            #     tokenizer_one=tokenizer_one,
            #     tokenizer_two=tokenizer_two
            # )
            real_dataset = SDImageDatasetLMDB(
                args.real_image_path, 
                is_sdxl=True,
                tokenizer_one=tokenizer_one,
                tokenizer_two=tokenizer_two
            )
            
        else:        
            tokenizer = CLIPTokenizer.from_pretrained(
                args.model_id, subfolder="tokenizer"
            )
            uncond_input_ids = tokenizer(
                [""], max_length=tokenizer.model_max_length, 
                return_tensors="pt", padding="max_length", truncation=True
            ).input_ids.to(accelerator.device)

            dataset = SDTextDataset(args.train_prompt_path, tokenizer, is_sdxl=False)
            self.uncond_embedding = self.model.text_encoder(uncond_input_ids)[0]
            
            # teacher_datast = SDTextImgDataset(
            #     args.train_prompt_path, 
            #     args.train_img_path,
            #     is_sdxl=False,
            #     tokenizer_one=tokenizer,
            # )
            # also load the training dataset images, this will be useful for GAN loss 
            
            real_dataset = SDImageDatasetLMDB(
                args.real_image_path, 
                is_sdxl=False,
                tokenizer_one=tokenizer
            )

        dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
        dataloader = accelerator.prepare(dataloader)
        self.dataloader = cycle(dataloader)

        teacher_dataloader = torch.utils.data.DataLoader(real_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
        teacher_dataloader = accelerator.prepare(teacher_dataloader)
        self.teacher_dataloader = cycle(teacher_dataloader)

        # real_dataloader = torch.utils.data.DataLoader(
        #     real_dataset, num_workers=8, 
        #     batch_size=args.batch_size, shuffle=True, 
        #     drop_last=True,
        #     multiprocessing_context=torch.multiprocessing.get_context('spawn')
        # )
        # real_dataloader = accelerator.prepare(real_dataloader)
        # self.real_dataloader = cycle(real_dataloader)

        # use two dataloader 
        # as the generator and guidance model are trained at different paces 
        guidance_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
        guidance_dataloader = accelerator.prepare(guidance_dataloader)
        self.guidance_dataloader = cycle(guidance_dataloader)

        self.guidance_cls_loss_weight = args.guidance_cls_loss_weight 

        self.cls_on_clean_image = args.cls_on_clean_image 
        self.gen_cls_loss = args.gen_cls_loss 
        self.gen_cls_loss_weight = args.gen_cls_loss_weight 
        self.previous_time = None 

        if self.denoising:
            # denoising_dataloader = torch.utils.data.DataLoader(
            #     real_dataset, num_workers=8, 
            #     batch_size=args.batch_size, shuffle=True, 
            #     drop_last=True,
            #     multiprocessing_context=torch.multiprocessing.get_context('spawn')
            # )
            denoising_dataloader = torch.utils.data.DataLoader(real_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
            denoising_dataloader = accelerator.prepare(denoising_dataloader)
            self.denoising_dataloader = cycle(denoising_dataloader)

        if args.ckpt_only_path is None:
            generator_path = os.path.join(args.output_path, f"checkpoint_model_{self.step:06d}", "pytorch_model.bin")
            guidance_path = os.path.join(args.output_path, f"checkpoint_model_{self.step:06d}", "pytorch_model_1.bin")

            if accelerator.is_main_process:
                print("reloading done")


        self.optimizer_generator = torch.optim.AdamW(
            [param for param in self.model.feedforward_model.parameters() if param.requires_grad], 
            lr=args.generator_lr, 
            betas=(0.9, 0.999),  # pytorch's default 
            weight_decay=0.01  # pytorch's default 
        )
        self.optimizer_guidance = torch.optim.AdamW(
            [param for param in self.model.guidance_model.parameters() if param.requires_grad], 
            lr=args.guidance_lr, 
            betas=(0.9, 0.999),  # pytorch's default 
            weight_decay=0.01  # pytorch's default 
        )

        # actually this scheduler is not very useful (it warms up from 0 to max_lr in 500 / num_gpu steps), but we keep it here for consistency
        self.scheduler_generator = get_scheduler(
            "constant_with_warmup",
            optimizer=self.optimizer_generator,
            num_warmup_steps=args.warmup_step,
            num_training_steps=args.train_iters 
        )
        self.scheduler_guidance = get_scheduler(
            "constant_with_warmup",
            optimizer=self.optimizer_guidance,
            num_warmup_steps=args.warmup_step,
            num_training_steps=args.train_iters 
        )

        (
            self.model.feedforward_model, self.model.guidance_model, self.optimizer_generator,
            self.optimizer_guidance, self.scheduler_generator, self.scheduler_guidance 
        ) = accelerator.prepare(
            self.model.feedforward_model, self.model.guidance_model, self.optimizer_generator, 
            self.optimizer_guidance, self.scheduler_generator, self.scheduler_guidance
        ) 

        self.accelerator = accelerator
        self.train_iters = args.train_iters
        self.batch_size = args.batch_size
        self.resolution = args.resolution 
        self.log_iters = args.log_iters
        self.wandb_iters = args.wandb_iters
        self.latent_resolution = args.latent_resolution
        self.grid_size = args.grid_size
        self.log_loss = args.log_loss
        self.latent_channel = args.latent_channel

        self.no_save = args.no_save
        self.max_checkpoint = args.max_checkpoint

        self.dfake_gen_update_ratio = args.dfake_gen_update_ratio

        if args.checkpoint_path is not None:
            self.load(args.checkpoint_path)


    def load(self, checkpoint_path):
        self.step = int(checkpoint_path.replace("/", "").split("_")[-1])
        print(self.accelerator.load_state(checkpoint_path, strict=False))
        self.accelerator.print(f"Loaded checkpoint from {checkpoint_path}")

    def save(self):
        # NOTE: we save the checkpoints to two places 
        # 1. output_path: save the latest one, this is assumed to be a permanent storage
        # 2. cache_dir: save all checkpoints, this is assumed to be a temporary storage
        # training states 
        
        if self.accelerator.is_main_process:
            output_path = os.path.join(self.output_path, f"checkpoint_model_{self.step:06d}")
            os.makedirs(output_path, exist_ok=True)
            print(f"start saving checkpoint to {output_path}")
            feedforward_state_dict = self.accelerator.get_state_dict(self.model.feedforward_model)
            self.accelerator.save(feedforward_state_dict, os.path.join(output_path, f"pytorch_model.bin"))
            del feedforward_state_dict
            guidance_model_state_dict =self.accelerator.get_state_dict(self.model.guidance_model)
            self.accelerator.save(guidance_model_state_dict,os.path.join(output_path, f"pytorch_model_1.bin"))
            del guidance_model_state_dict

            print("done saving")
        torch.cuda.empty_cache()
    
    def train_one_step(self):
        self.model.train()

        accelerator = self.accelerator

        # 4 channel for SD-VAE, please adapt for other autoencoders 
        noise = torch.randn(self.batch_size, self.latent_channel, self.latent_resolution, self.latent_resolution, device=accelerator.device)
        visual = self.step % self.wandb_iters == 0
        COMPUTE_GENERATOR_GRADIENT = self.step % self.dfake_gen_update_ratio == 0

        
        cache_dict = {
                'generator_log_dict': [],
                'noise': [],
                'text_embedding': [],
                'uncond_embedding': [],
                'real_train_dict': [],
                # 'denoising_dict': []
            }
            
        for i in range(args.gradient_accumulation_steps):
            if COMPUTE_GENERATOR_GRADIENT:
                text_embedding = next(self.dataloader)
            else:
                text_embedding = next(self.guidance_dataloader) 

            if self.sdxl:
                # SDXL uses zero as the uncond_embedding
                uncond_embedding = None 
            else:
                text_embedding = text_embedding['text_input_ids_one'].squeeze(1) # actually it is tokenized text prompts
                uncond_embedding = self.uncond_embedding.repeat(len(text_embedding), 1, 1)

            if self.denoising:
                denoising_dict = next(self.denoising_dataloader)
            else:
                denoising_dict = None

            if self.cls_on_clean_image:
                real_train_dict = next(self.teacher_dataloader) 
            else:
                real_train_dict = None
            
            generator_loss_dict, gen_adv_loss_dict = {}, {}
            # generate images and optionaly compute the generator gradient
            generator_loss_dict, generator_log_dict = self.model(
                noise, text_embedding, uncond_embedding, 
                visual=visual, denoising_dict=denoising_dict,
                compute_generator_gradient=COMPUTE_GENERATOR_GRADIENT,
                real_train_dict=real_train_dict,
                generator_turn=True,
                guidance_turn=False
            )
            
            cache_dict['generator_log_dict'].append(generator_log_dict)
            cache_dict['noise'].append(noise)
            cache_dict['text_embedding'].append(text_embedding)
            cache_dict['uncond_embedding'].append(uncond_embedding)
            cache_dict['real_train_dict'].append(real_train_dict)
            # cache_dict['denoising_dict'].append(denoising_dict)

            # first update the generator if the current step is a multiple of dfake_gen_update_ratio
            generator_loss = 0.0 

            if COMPUTE_GENERATOR_GRADIENT:
                if not self.args.gan_alone:
                    generator_loss += generator_loss_dict["loss_dm"] * self.args.dm_loss_weight
                    self.accelerator.backward(generator_loss/args.gradient_accumulation_steps)
                    generator_grad_norm = accelerator.clip_grad_norm_(self.model.feedforward_model.parameters(), self.max_grad_norm)
                
            
            gen_adv_loss_dict, _ = self.model(
                noise, text_embedding, uncond_embedding, 
                visual=visual, denoising_dict=denoising_dict,
                compute_generator_gradient=COMPUTE_GENERATOR_GRADIENT,
                real_train_dict=real_train_dict,
                generator_turn=False,
                guidance_turn=False,
                gen_adv_turn=True,
                generator_data_dict=generator_log_dict
            )
            if COMPUTE_GENERATOR_GRADIENT:
                gen_adv_loss = gen_adv_loss_dict["gen_cls_loss"] * self.gen_cls_loss_weight
                self.accelerator.backward(gen_adv_loss/args.gradient_accumulation_steps)
                generator_grad_norm = accelerator.clip_grad_norm_(self.model.feedforward_model.parameters(), self.max_grad_norm)
                
                
        if COMPUTE_GENERATOR_GRADIENT:      
            self.optimizer_generator.step()
            self.optimizer_generator.zero_grad() 
            self.optimizer_guidance.zero_grad()
            self.scheduler_generator.step()
            
            if args.ema:
                guidance_model_raw = accelerator.unwrap_model(self.model.guidance_model)
                fake_unet_raw = guidance_model_raw.fake_unet
                gen_unet_raw = accelerator.unwrap_model(self.model.feedforward_model)
                ema_decay = args.ema_decay
                with torch.no_grad():
                    for param_gen, param_fake in zip(gen_unet_raw.parameters(), fake_unet_raw.parameters()):
                        param_fake.data.mul_(ema_decay).add_(param_gen.data, alpha=1 - ema_decay)

                    for buf_gen, buf_fake in zip(gen_unet_raw.buffers(), fake_unet_raw.buffers()):
                        if buf_fake.dtype in [torch.float32, torch.float16, torch.bfloat16]:
                            buf_fake.data.mul_(ema_decay).add_(buf_gen.data, alpha=1 - ema_decay)
            
        # update the guidance model (dfake and classifier)
        for i in range(args.gradient_accumulation_steps):
            if self.denoising:
                denoising_dict = next(self.denoising_dataloader)
            else:
                denoising_dict = None
            generator_log_dict = cache_dict['generator_log_dict'][i]
            noise = cache_dict['noise'][i]
            text_embedding = cache_dict['text_embedding'][i]
            uncond_embedding = cache_dict['uncond_embedding'][i]
            real_train_dict = cache_dict['real_train_dict'][i]
            # denoising_dict = cache_dict['denoising_dict'][i]
            
            guidance_loss_dict, guidance_log_dict = self.model(
                noise, text_embedding, uncond_embedding, 
                visual=visual, denoising_dict=denoising_dict,
                real_train_dict=real_train_dict,
                compute_generator_gradient=COMPUTE_GENERATOR_GRADIENT,
                generator_turn=False,
                guidance_turn=True,
                guidance_data_dict=generator_log_dict['guidance_data_dict']
            )

            guidance_loss = 0 

            guidance_loss += guidance_loss_dict["loss_fake_mean"]

            if self.cls_on_clean_image:
                # discriminator and generator should avoid same sample
                if COMPUTE_GENERATOR_GRADIENT:
                    guidance_loss += guidance_loss_dict["guidance_cls_loss"] * self.guidance_cls_loss_weight * 1e-5
                else:
                    guidance_loss += guidance_loss_dict["guidance_cls_loss"] * self.guidance_cls_loss_weight

            self.accelerator.backward(guidance_loss/args.gradient_accumulation_steps)
            guidance_grad_norm = accelerator.clip_grad_norm_(self.model.guidance_model.parameters(), self.max_grad_norm)
            
        self.optimizer_guidance.step()
        self.optimizer_guidance.zero_grad()
        self.optimizer_generator.zero_grad() # zero out the generator's gradient as well
        self.scheduler_guidance.step()

        # combine the two dictionaries 
        loss_dict = {**generator_loss_dict, **gen_adv_loss_dict, **guidance_loss_dict}
        log_dict = {**generator_log_dict, **guidance_log_dict}

        generated_image_mean = log_dict["guidance_data_dict"]['image'].mean()
        generated_image_std = log_dict["guidance_data_dict"]['image'].std()

        generated_image_mean = accelerator.gather(generated_image_mean).mean()
        generated_image_std = accelerator.gather(generated_image_std).mean()

        if COMPUTE_GENERATOR_GRADIENT:
            if not self.args.gan_alone:
                dmtrain_pred_real_image_mean = log_dict['dmtrain_pred_real_image'].mean()
                dmtrain_pred_real_iamge_std = log_dict['dmtrain_pred_real_image'].std()

                dmtrain_pred_real_image_mean = accelerator.gather(dmtrain_pred_real_image_mean).mean()
                dmtrain_pred_real_iamge_std = accelerator.gather(dmtrain_pred_real_iamge_std).mean()

                dmtrain_pred_fake_image_mean = log_dict['dmtrain_pred_fake_image'].mean()
                dmtrain_pred_fake_image_std = log_dict['dmtrain_pred_fake_image'].std()

                dmtrain_pred_fake_image_mean = accelerator.gather(dmtrain_pred_fake_image_mean).mean()
                dmtrain_pred_fake_image_std = accelerator.gather(dmtrain_pred_fake_image_std).mean()

        # if self.denoising:
        #     original_image_mean = denoising_dict["images"].mean()
        #     original_image_std = denoising_dict["images"].std()

        #     original_image_mean = accelerator.gather(original_image_mean).mean()
        #     original_image_std = accelerator.gather(original_image_std).mean()

        if accelerator.is_main_process and self.log_loss and (not visual):
            wandb_loss_dict = {
                "loss_fake_mean": guidance_loss_dict['loss_fake_mean'].item(),
                "guidance_grad_norm": guidance_grad_norm.item(),
                "generated_image_mean": generated_image_mean.item(),
                "generated_image_std": generated_image_std.item(),
                "batch_size": len(noise)
            }

            if COMPUTE_GENERATOR_GRADIENT and (not self.args.gan_alone):
                wandb_loss_dict.update(
                    {
                        "dmtrain_pred_real_image_mean": dmtrain_pred_real_image_mean.item(),
                        "dmtrain_pred_real_image_std": dmtrain_pred_real_iamge_std.item(),
                        "dmtrain_pred_fake_image_mean": dmtrain_pred_fake_image_mean.item(),
                        "dmtrain_pred_fake_image_std": dmtrain_pred_fake_image_std.item()
                    }
                )

            # if self.denoising:
            #     wandb_loss_dict.update({
            #         "original_image_mean": original_image_mean.item(),
            #         "original_image_std": original_image_std.item()
            #     })

            if COMPUTE_GENERATOR_GRADIENT:
                wandb_loss_dict["generator_grad_norm"] = generator_grad_norm.item()

                if not self.args.gan_alone:
                    wandb_loss_dict["loss_dm"] = loss_dict['loss_dm'].item()
                    wandb_loss_dict["dmtrain_gradient_norm"] = log_dict['dmtrain_gradient_norm']

                if self.gen_cls_loss:
                    wandb_loss_dict.update({
                        "gen_cls_loss": loss_dict['gen_cls_loss'].item()
                    })

            if self.cls_on_clean_image:
                wandb_loss_dict.update({
                    "guidance_cls_loss": loss_dict['guidance_cls_loss'].item()
                })

            wandb.log(
                wandb_loss_dict,
                step=self.step
            )

        if visual:
            if not self.args.gan_alone:
                log_dict['dmtrain_pred_real_image_decoded'] = accelerator.gather(log_dict['dmtrain_pred_real_image_decoded'])
                log_dict['dmtrain_pred_fake_image_decoded'] = accelerator.gather(log_dict['dmtrain_pred_fake_image_decoded'])
        
            log_dict['generated_image'] = accelerator.gather(log_dict['generated_image'])
            log_dict['generated_image_t0'] = accelerator.gather(log_dict['generated_image_t0'])

            if self.denoising:
                log_dict["original_clean_image"] = accelerator.gather(log_dict["original_clean_image"])
                log_dict['denoising_timestep'] = accelerator.gather(log_dict['denoising_timestep'])

            # if self.cls_on_clean_image and self.args.convert_real:
            #     log_dict['converted_real_image'] = accelerator.gather(real_train_dict['pixel_image'])
            #     log_dict['original_real_image'] = accelerator.gather(real_train_dict['origin_pixel_image'])
                

        if accelerator.is_main_process and visual:
            with torch.no_grad():
                if not self.args.gan_alone:
                    (
                        dmtrain_pred_real_image, dmtrain_pred_fake_image, 
                        dmtrain_gradient_norm
                    ) = (
                        log_dict['dmtrain_pred_real_image_decoded'], log_dict['dmtrain_pred_fake_image_decoded'], 
                        log_dict['dmtrain_gradient_norm']
                    )

                    dmtrain_pred_real_image_grid = prepare_images_for_saving(dmtrain_pred_real_image, resolution=self.resolution, grid_size=self.grid_size)
                    dmtrain_pred_fake_image_grid = prepare_images_for_saving(dmtrain_pred_fake_image, resolution=self.resolution, grid_size=self.grid_size)

                    difference_scale_grid = draw_valued_array(
                        (dmtrain_pred_real_image - dmtrain_pred_fake_image).abs().mean(dim=[1, 2, 3]).cpu().numpy(), 
                        output_dir=self.wandb_folder, grid_size=self.grid_size
                    )

                    difference = (dmtrain_pred_real_image - dmtrain_pred_fake_image)
                    difference = (difference - difference.min()) / (difference.max() - difference.min())
                    difference = (difference - 0.5)/0.5
                    difference = prepare_images_for_saving(difference, resolution=self.resolution, grid_size=self.grid_size)

                    data_dict = {
                        "dmtrain_pred_real_image": wandb.Image(dmtrain_pred_real_image_grid),
                        "dmtrain_pred_fake_image": wandb.Image(dmtrain_pred_fake_image_grid),
                        "loss_dm": loss_dict['loss_dm'].item(),
                        "dmtrain_gradient_norm": dmtrain_gradient_norm,
                        "difference": wandb.Image(difference),
                        "difference_norm_grid": wandb.Image(difference_scale_grid),
                    }
                else:
                    data_dict = {} 

                generated_image = log_dict['generated_image']
                generated_image_grid = prepare_images_for_saving(generated_image, resolution=self.resolution, grid_size=self.grid_size)

                generated_image_t0 = log_dict['generated_image_t0']
                generated_image_grid_t0 = prepare_images_for_saving(generated_image_t0, resolution=self.resolution, grid_size=self.grid_size)

                
                generated_image_mean = generated_image_t0.mean()
                generated_image_std = generated_image_t0.std()

                data_dict.update({
                    "generated_image": wandb.Image(generated_image_grid),
                    "generated_image_t0": wandb.Image(generated_image_grid_t0),
                    "loss_fake_mean": loss_dict['loss_fake_mean'].item(),
                    "generator_grad_norm": generator_grad_norm.item(),
                    "guidance_grad_norm": guidance_grad_norm.item(),
                })

                if self.denoising:
                    origianl_clean_image = log_dict["original_clean_image"]
                    origianl_clean_image_grid = prepare_images_for_saving(origianl_clean_image, resolution=self.resolution, grid_size=self.grid_size)

                    denoising_timestep = log_dict["denoising_timestep"]
                    denoising_timestep_grid = draw_valued_array(denoising_timestep.cpu().numpy(), output_dir=self.wandb_folder, grid_size=self.grid_size)

                    data_dict.update(
                        {
                            "original_clean_image": wandb.Image(origianl_clean_image_grid),
                            # "original_image_mean": original_image_mean.item(),
                            # "original_image_std": original_image_std.item(),
                            "denoising_timestep": wandb.Image(denoising_timestep_grid)
                        }
                    )

                if self.cls_on_clean_image:
                    data_dict['guidance_cls_loss'] = loss_dict['guidance_cls_loss'].item()

                    if self.gen_cls_loss:
                        data_dict['gen_cls_loss'] = loss_dict['gen_cls_loss'].item()

                    pred_realism_on_fake = log_dict["pred_realism_on_fake"]
                    pred_realism_on_real = log_dict["pred_realism_on_real"]

                    hist_pred_realism_on_fake = draw_probability_histogram(pred_realism_on_fake.float().cpu().numpy())
                    hist_pred_realism_on_real = draw_probability_histogram(pred_realism_on_real.float().cpu().numpy())

                    if self.args.convert_real:
                        converted_real_image = log_dict['converted_real_image']
                        converted_real_image = prepare_images_for_saving(converted_real_image, resolution=self.resolution, grid_size=self.grid_size)
                        original_real_image = log_dict['original_real_image']
                        original_real_image = prepare_images_for_saving(original_real_image, resolution=self.resolution, grid_size=self.grid_size)
                        
                        data_dict.update(
                            {
                                "hist_pred_realism_on_fake": wandb.Image(hist_pred_realism_on_fake),
                                "hist_pred_realism_on_real": wandb.Image(hist_pred_realism_on_real),
                                "converted_real_image": wandb.Image(converted_real_image),
                                "original_real_image": wandb.Image(original_real_image),
                            }
                        )
                    else:
                        data_dict.update(
                            {
                                "hist_pred_realism_on_fake": wandb.Image(hist_pred_realism_on_fake),
                                "hist_pred_realism_on_real": wandb.Image(hist_pred_realism_on_real),
                            }
                        )


                wandb.log(
                    data_dict,
                    step=self.step
                )
        
        self.accelerator.wait_for_everyone()

    def train(self):
        from tqdm import tqdm
        for index in tqdm(range(self.step, self.train_iters)):          
            self.train_one_step()
            if (not self.no_save)  and (self.step + 1) % self.log_iters == 0:
                self.save()

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                current_time = time.time()
                if self.previous_time is None:
                    self.previous_time = current_time
                else:
                    wandb.log({"per iteration time": current_time-self.previous_time}, step=self.step)
                    self.previous_time = current_time

            self.step += 1

def parse_args():
    mp.set_start_method('spawn', force=True)
    print("Start method:", mp.get_start_method())
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="runwayml/stable-diffusion-v1-5")
    parser.add_argument("--vae_id", type=str)
    parser.add_argument("--ref_model_id", type=str, default="")
    parser.add_argument("--output_path", type=str, default="/mnt/localssd/test_stable_diffusion_coco")
    parser.add_argument("--log_path", type=str, default="/mnt/localssd/log_stable_diffusion_coco")
    parser.add_argument("--train_iters", type=int, default=1000000)
    parser.add_argument("--log_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--resolution", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--initialie_generator", action="store_true")
    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_iters", type=int, default=100)
    parser.add_argument("--wandb_name", type=str, required=True)
    parser.add_argument("--max_grad_norm", type=float, default=10.0, help="max grad norm for network")
    parser.add_argument("--warmup_step", type=int, default=500, help="warmup step for network")
    parser.add_argument("--min_step_percent", type=float, default=0.02, help="minimum step percent for training")
    parser.add_argument("--max_step_percent", type=float, default=0.98, help="maximum step percent for training")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--use_fp16", action="store_true")
    parser.add_argument("--num_train_timesteps", type=int, default=1000)
    parser.add_argument("--ckpt_only_path", type=str, default=None, help="checkpoint (no optimizer state) only path")
    parser.add_argument("--train_prompt_path", type=str)
    parser.add_argument("--train_img_path", type=str)
    parser.add_argument("--latent_resolution", type=int, default=64)
    parser.add_argument("--real_guidance_scale", type=float, default=6.0)
    parser.add_argument("--fake_guidance_scale", type=float, default=1.0)
    parser.add_argument("--grid_size", type=int, default=2)
    parser.add_argument("--no_save", action="store_true", help="don't save ckpt for debugging only")
    parser.add_argument("--cache_dir", type=str, default="/mnt/localssd/cache")
    parser.add_argument("--log_loss", action="store_true", help="log loss at every iteration")
    parser.add_argument("--num_workers", type=int, default=12)
    parser.add_argument("--latent_channel", type=int, default=4)
    parser.add_argument("--max_checkpoint", type=int, default=150)
    parser.add_argument("--dfake_gen_update_ratio", type=int, default=1)
    parser.add_argument("--generator_lr", type=float)
    parser.add_argument("--guidance_lr", type=float)
    parser.add_argument("--cls_on_clean_image", action="store_true")
    parser.add_argument("--gen_cls_loss", action="store_true")
    parser.add_argument("--gen_cls_loss_weight", type=float, default=0)
    parser.add_argument("--guidance_cls_loss_weight", type=float, default=0)
    parser.add_argument("--sdxl", action="store_true")
    parser.add_argument("--fsdp", action="store_true")
    parser.add_argument("--generator_ckpt_path", type=str)
    parser.add_argument("--conditioning_timestep", type=int, default=999)
    parser.add_argument("--tiny_vae", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="apply gradient checkpointing for dfake and generator. this might be a better option than FSDP")
    parser.add_argument("--dm_loss_weight", type=float, default=1.0)

    parser.add_argument("--denoising", action="store_true", help="train the generator for denoising")
    parser.add_argument("--denoising_timestep", type=int, default=1000)
    parser.add_argument("--num_denoising_step", type=int, default=1)
    parser.add_argument("--denoising_loss_weight", type=float, default=1.0)

    parser.add_argument("--diffusion_gan", action="store_true")
    parser.add_argument("--convert_real", action="store_true")
    parser.add_argument("--diffusion_gan_max_timestep", type=int, default=0)
    parser.add_argument("--revision", type=str)

    parser.add_argument("--real_image_path", type=str)
    parser.add_argument("--gan_alone", action="store_true", help="only use the gan loss without dmd")
    parser.add_argument("--backward_simulation", action="store_true")

    parser.add_argument("--generator_lora", action="store_true")
    parser.add_argument("--ema", action="store_true")
    parser.add_argument("--ema_decay", type=float, default=0.998)
    parser.add_argument("--lora_rank", type=int, default=64)
    parser.add_argument("--lora_alpha", type=float, default=64)
    parser.add_argument("--lora_dropout", type=float, default=0.0)
    
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # assert args.gradient_accumulation_steps == 1, "grad accumulation not supported yet"

    assert not (args.fsdp and args.gradient_checkpointing), "currently, we don't support both options. open an issue for details."

    assert args.wandb_iters % args.dfake_gen_update_ratio == 0, "wandb_iters should be a multiple of dfake_gen_update_ratio"

    return args 

if __name__ == "__main__":
    args = parse_args()

    trainer = Trainer(args)

    trainer.train()