#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import gc
import hashlib
import itertools
import logging
import math
import os
import shutil
import warnings
from pathlib import Path
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed, GradScalerKwargs
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import (
    LoraLoaderMixin,
    text_encoder_lora_state_dict,
    text_encoder_attn_modules,
    text_encoder_mlp_modules
)
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available












from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPModel, CLIPProcessor, CLIPVisionModel, CLIPVisionModelWithProjection
from diffusers.loaders import AttnProcsLayers
# from diffusers.utils import TEXT_ENCODER_ATTN_MODULE
import json
import pytz
import datetime

from torchvision.models.mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch import nn
import random
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from insightface.app import FaceAnalysis
import torchvision
from PIL import Image, ImageOps, ImageDraw, ImageFont
import accelerate
import scipy
import sys
from diffusers.training_utils import EMAModel
import copy

import pickle as pkl
from sentence_transformers import SentenceTransformer, util

from skimage import transform
import kornia

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.19.0.dev0")

# testing against diffusers == 0.19.3

# os.environ["gpu_ids"] = "1"

# NOTE(review): raises how long wandb waits for its background service to start;
# "300" is presumably seconds — confirm against the wandb environment-variable docs.
os.environ["WANDB__SERVICE_WAIT"] = "300"

class FaceFeatsModel(torch.nn.Module):
    """Frozen gallery of reference face embeddings, split by gender.

    The gallery is loaded from a pickle, L2-normalised, filtered to confidently
    gendered entries, and stored as two non-trainable parameters
    (``face_feats_female`` for label 0, ``face_feats_male`` for label 1).
    ``semantic_search`` looks up the top-1 dot-product neighbour of each query;
    ``add_face_feats`` appends new embeddings at run time.
    """

    def __init__(self, face_feats_path, face_gender_threshold):
        """
        Args:
            face_feats_path: path to a pickle holding a
                ``(face_feats, face_genders, face_logits)`` tuple.
                NOTE(review): loaded with ``pickle`` — only point this at trusted files.
            face_gender_threshold: minimum softmax probability of an entry's labelled
                gender for the entry to be kept (also used by ``add_face_feats``).
        """
        super().__init__()

        with open(face_feats_path, "rb") as f:
            face_feats, face_genders, face_logits = pkl.load(f)
        face_feats = torch.nn.functional.normalize(face_feats, dim=-1)
        face_probs = torch.softmax(face_logits, dim=-1)
        # Keep only entries whose labelled gender is predicted with high confidence.
        indicators = torch.tensor(
            [face_probs[i, i_gender] > face_gender_threshold for i, i_gender in enumerate(face_genders)]
        )

        face_feats = face_feats[indicators]
        face_genders = face_genders[indicators]

        # Stored as frozen Parameters so they follow the module across devices.
        self.face_feats_male = nn.Parameter(face_feats[face_genders == 1], requires_grad=False)
        self.face_feats_female = nn.Parameter(face_feats[face_genders == 0], requires_grad=False)

        self.face_gender_threshold = face_gender_threshold

        self.face_feats_male_add = []
        self.face_feats_female_add = []

    def forward(self, x):
        """
        no forward function
        """
        return None

    @staticmethod
    def _fill_nearest(query_embeddings, mask, corpus_embeddings, target_embeddings, similarities):
        """Write the top-1 dot-product neighbour (and its score) for the rows selected by ``mask``.

        ``target_embeddings`` and, if not None, ``similarities`` are modified in place.
        Nothing happens when ``mask`` selects no rows.
        """
        if mask.sum() == 0:
            return
        hits = util.semantic_search(
            query_embeddings[mask], corpus_embeddings, score_function=util.dot_score, top_k=1
        )
        target_embeddings[mask] = torch.cat(
            [corpus_embeddings[hit[0]["corpus_id"]].unsqueeze(dim=0) for hit in hits]
        )
        if similarities is not None:
            similarities[mask] = torch.tensor(
                [hit[0]["score"] for hit in hits],
                device=query_embeddings.device,
                dtype=query_embeddings.dtype,
            )

    @torch.no_grad()
    def semantic_search(self, query_embeddings, genders=None, selector=None, return_similarity=False):
        """Find the nearest gallery embedding for each query.

        Args:
            query_embeddings: 2-D tensor of queries, one per row.
            genders: optional per-row labels. When given, rows labelled 0 search the
                female gallery and rows labelled 1 the male gallery; other rows are skipped.
            selector: boolean row mask, used only when ``genders`` is None. Selected rows
                search both galleries together. Defaults to all rows.
            return_similarity: also return the dot-product score of each match.

        Returns:
            A detached copy of the matched embeddings, shaped like ``query_embeddings``.
            Rows without a match are filled with -1. When ``return_similarity`` is True,
            the per-row scores (-1 for unmatched rows) are returned as well.
        """
        target_embeddings = torch.full_like(query_embeddings, -1)
        similarities = None
        if return_similarity:
            similarities = torch.full(
                [query_embeddings.shape[0]], -1, device=query_embeddings.device, dtype=query_embeddings.dtype
            )

        if genders is None:
            if selector is None:
                selector = torch.ones(query_embeddings.shape[0], dtype=torch.bool, device=query_embeddings.device)
            if selector.sum() > 0:
                corpus_embeddings = torch.cat([self.face_feats_female, self.face_feats_male])
                self._fill_nearest(query_embeddings, selector, corpus_embeddings, target_embeddings, similarities)
        else:
            self._fill_nearest(query_embeddings, genders == 0, self.face_feats_female, target_embeddings, similarities)
            self._fill_nearest(query_embeddings, genders == 1, self.face_feats_male, target_embeddings, similarities)

        if return_similarity:
            return target_embeddings.data.detach().clone(), similarities
        return target_embeddings.data.detach().clone()

    @torch.no_grad()
    def add_face_feats(self, feats, logits):
        """Append confidently gendered embeddings to the gallery.

        Each row of ``feats`` is assigned the argmax gender of ``logits``. It is kept
        only if the softmax probability of that gender exceeds ``face_gender_threshold``.
        """
        probs = torch.softmax(logits, dim=-1)
        genders = probs.argmax(dim=-1)
        if (genders == 0).sum() > 0:
            indicators = (genders == 0) & (probs[:, 0] > self.face_gender_threshold)
            self.face_feats_female = nn.Parameter(
                torch.cat([self.face_feats_female, feats[indicators]]), requires_grad=False
            )
        if (genders == 1).sum() > 0:
            # Male rows must be judged by the male-class probability (column 1).
            indicators = (genders == 1) & (probs[:, 1] > self.face_gender_threshold)
            self.face_feats_male = nn.Parameter(
                torch.cat([self.face_feats_male, feats[indicators]]), requires_grad=False
            )


def clean_checkpoint(ckpts_save_dir, name, checkpoints_total_limit):
    """Delete the oldest ``<name>-<step>`` checkpoint directories before a new one is saved.

    After the call, at most ``checkpoints_total_limit - 1`` directories named
    ``<name>-<integer step>`` remain in ``ckpts_save_dir``. The smallest steps are
    removed first. Entries not named exactly ``<name>-<integer>`` are left alone, so
    cleaning ``checkpoint`` does not touch ``checkpoint_long-200``.

    Args:
        ckpts_save_dir: directory holding the checkpoint sub-directories.
        name: checkpoint name prefix (a trailing "-" is tolerated).
        checkpoints_total_limit: maximum number of ``name`` checkpoints, counting the
            one about to be saved. ``None`` disables cleaning.
    """
    if checkpoints_total_limit is None or not os.path.isdir(ckpts_save_dir):
        return

    base = name.rstrip("-")

    def _step(dirname):
        # Integer step of "<base>-<step>", or None if dirname is not such a checkpoint dir.
        prefix, sep, step = dirname.rpartition("-")
        if sep and prefix == base and step.isdigit() and os.path.isdir(os.path.join(ckpts_save_dir, dirname)):
            return int(step)
        return None

    checkpoints = sorted(
        (d for d in os.listdir(ckpts_save_dir) if _step(d) is not None),
        key=_step,
    )

    # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
    if len(checkpoints) < checkpoints_total_limit:
        return

    num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
    removing_checkpoints = checkpoints[0:num_to_remove]

    logger.info(
        f"chekpoint name:{name}, {len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

    for removing_checkpoint in removing_checkpoints:
        # Resolve against the directory we listed, not a global config object.
        shutil.rmtree(os.path.join(ckpts_save_dir, removing_checkpoint))


def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row-major into a ``rows`` x ``cols`` RGB canvas.

    Every cell takes the size of the first image, and ``len(imgs)`` must equal
    ``rows * cols``.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for position, tile in enumerate(imgs):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas

def _add_score_bar(img_pil, score, color, bar_width=50):
    """Return ``img_pil`` with a ``bar_width``-px ``color`` strip added on its left edge.

    The top ``(1 - score)`` fraction of the strip is painted white, so the visible
    coloured part grows with ``score`` (a Python float, expected in [0, 1]).
    """
    img_pil = ImageOps.expand(img_pil, border=(bar_width, 0, 0, 0), fill=color)
    if score < 1:
        draw = ImageDraw.Draw(img_pil)
        # Scale by the actual image height rather than assuming 512 px.
        draw.rectangle([(0, 0), (bar_width, (1 - score) * img_pil.size[1])], fill="white", outline=None)
    return img_pil


def plot_in_grid(images, save_to, face_indicators=None, face_bboxs=None, preds_gender=None, pred_class_probs_gender=None, face_real_scores=None):
    """Save a grid of annotated images, ordered by predicted gender.

    Tiles are ordered as follows: predicted male (most to least confident), then
    predicted female (most to least confident), then images with no detected face.
    Each tile gets:
      - the face bbox drawn in the gender colour (blue = male, red = female,
        white = undetected);
      - a left bar showing the gender confidence;
      - an inner magenta bar showing ``face_real_scores``, when given;
      - its index into ``images``;
      - a 10-px black frame.

    Args:
        images: torch tensor in shape of [N,3,H,W], in range [-1,1].
        save_to: output path; missing parent directories are created.
        face_indicators: unused; kept for call compatibility.
        face_bboxs: per-image boxes, drawn via ``ImageDraw.rectangle(box.tolist())``.
        preds_gender: per-image tensor label: 1 = male, 0 = female, -1 = no face.
        pred_class_probs_gender: per-image tensor confidence of the predicted gender.
        face_real_scores: optional per-image tensor of scores, drawn as an extra bar.
    """
    border_colors = {1: "blue", 0: "red", -1: "white"}

    def _ranked(label):
        # Indices with the given label, sorted by descending confidence.
        idxs = (preds_gender == label).nonzero(as_tuple=False).view([-1])
        return idxs[pred_class_probs_gender[idxs].argsort(descending=True)]

    idxs_no_face = (preds_gender == -1).nonzero(as_tuple=False).view([-1])
    idxs_reordered = torch.cat([_ranked(1), _ranked(0), idxs_no_face])

    images_to_plot = []
    for idx in idxs_reordered:
        border_color = border_colors[int(preds_gender[idx])]

        img_pil = transforms.ToPILImage()(images[idx] * 0.5 + 0.5)
        ImageDraw.Draw(img_pil).rectangle(face_bboxs[idx].tolist(), fill=None, outline=border_color, width=4)

        if face_real_scores is not None:
            img_pil = _add_score_bar(img_pil, face_real_scores[idx].item(), "magenta")
        img_pil = _add_score_bar(img_pil, pred_class_probs_gender[idx].item(), border_color)
        ImageDraw.Draw(img_pil).text((400, 400), f"{idx.item()}", align="left")

        images_to_plot.append(ImageOps.expand(img_pil, border=(10, 10, 10, 10), fill="black"))

    # Near-square layout; pad the last row with blank white tiles.
    N_imgs = len(images_to_plot)
    N1 = int(math.sqrt(N_imgs))
    N2 = math.ceil(N_imgs / N1)
    for _ in range(N1 * N2 - N_imgs):
        images_to_plot.append(Image.new('RGB', color="white", size=images_to_plot[0].size))
    grid = image_grid(images_to_plot, N1, N2)

    save_dir = os.path.dirname(save_to)
    if save_dir:  # dirname is "" for a bare filename; makedirs("") would raise
        os.makedirs(save_dir, exist_ok=True)
    grid.save(save_to, quality=25)

def make_grad_hook(coef):
    """Build a gradient hook that rescales the incoming gradient by ``coef``."""
    def _scale(grad):
        return coef * grad
    return _scale

def customized_all_gather(tensor, accelerator, return_tensor_others=False):
    """Gather ``tensor`` from every process and concatenate the copies along dim 0.

    Args:
        tensor: the local tensor. ``torch.distributed.all_gather`` requires the same
            shape on every rank.
        accelerator: object exposing ``state.num_processes`` and ``process_index``
            (e.g. ``accelerate.Accelerator``).
        return_tensor_others: if True, also return the concatenation of every other
            rank's tensor, i.e. everything except this process's own contribution.

    Returns:
        ``tensor_all``, or ``(tensor_all, tensor_others)`` when ``return_tensor_others``
        is set. The gather buffers are detached clones, so the results carry no
        autograd graph. ``tensor_others`` has zero rows when there is only one process.
    """
    world_size = accelerator.state.num_processes
    gathered = [tensor.detach().clone() for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor)
    tensor_all = torch.cat(gathered, dim=0)
    if not return_tensor_others:
        return tensor_all

    # all_gather puts the tensor of *global* rank i in slot i, so skip this process's
    # global index (local_process_index differs from it on multi-node runs).
    rank = accelerator.process_index
    others = [t for idx, t in enumerate(gathered) if idx != rank]
    tensor_others = torch.cat(others, dim=0) if others else gathered[0][:0].clone()
    return tensor_all, tensor_others

def expand_bbox(bbox, expand_coef, target_ratio):
    """Grow a box around its centre so its height/width ratio equals ``target_ratio``.

    bbox: [width_small, height_small, width_large, height_large] (x0, y0, x1, y1),
        the format returned by insightface.app.FaceAnalysis.
    expand_coef: relative growth of the dominant side; 0 means no expansion.
        The other side is then stretched to reach ``target_ratio``.
    target_ratio: target img height/width ratio.

    Returns a new 4-element integer list in the same layout. The result may extend
    past the original image (as can insightface boxes themselves).
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    ratio = height / width

    if ratio > target_ratio:
        # Too tall: grow height by the coefficient, then widen to match.
        grow_h = height * expand_coef
        grow_w = (height + grow_h) / target_ratio - width
    elif ratio <= target_ratio:
        # Too wide (or exact): grow width by the coefficient, then heighten to match.
        grow_w = width * expand_coef
        grow_h = (width + grow_w) * target_ratio - height

    half_w, half_h = grow_w * 0.5, grow_h * 0.5
    return [
        int(round(bbox[0] - half_w)),
        int(round(bbox[1] - half_h)),
        int(round(bbox[2] + half_w)),
        int(round(bbox[3] + half_h)),
    ]

def crop_face(img_tensor, bbox_new, target_size, fill_value):
    """Crop ``bbox_new`` out of ``img_tensor``, padding where the box leaves the image, then resize.

    img_tensor: [3,H,W]
    bbox_new: [x0, y0, x1, y1] in pixels; may extend past the image borders.
    target_size: passed straight to ``torchvision.transforms.Resize``.
        NOTE(review): Resize reads a sequence as (height, width); the old docstring
        said [width, height], which only agrees for square targets.
    fill_value: value used for the padded area.
    """
    height, width = img_tensor.shape[-2:]
    left, top, right, bottom = bbox_new[0], bbox_new[1], bbox_new[2], bbox_new[3]

    # Portion of the box that actually lies inside the image.
    region = img_tensor[:, max(top, 0):min(bottom, height), max(left, 0):min(right, width)]

    # How far the box sticks out past each edge, in Pad's [left, top, right, bottom] order.
    padding = [max(-left, 0), max(-top, 0), max(right - width, 0), max(bottom - height, 0)]
    if any(amount > 0 for amount in padding):
        region = torchvision.transforms.Pad(padding, fill=fill_value)(region)

    return torchvision.transforms.Resize(size=target_size)(region)

class soft_CELoss(nn.Module):
    """Cross-entropy that ignores samples already predicted with enough confidence.

    With ``margin == 0`` this is exactly ``nn.CrossEntropyLoss``. Otherwise:
      - samples whose softmax probability for the target class exceeds
        ``1 - margin`` get a loss of 0;
      - the remaining samples get ordinary cross-entropy.

    Args:
        margin: confidence margin in [0, 1).
        CEL_args: passed as the first positional argument of ``nn.CrossEntropyLoss``,
            which is its per-class ``weight``. NOTE(review): this differs from what
            the name suggests.
        reduction: forwarded to ``nn.CrossEntropyLoss``. The margin branch writes into
            a per-sample loss vector, so it is meant for ``"none"``.
    """

    def __init__(self, margin=0, CEL_args=None, reduction="none"):
        super().__init__()
        self.margin = margin
        self.CEL_args = CEL_args
        self.CELoss = nn.CrossEntropyLoss(CEL_args, reduction=reduction)

    def forward(self, logits, targets):
        """Return per-sample losses for ``logits`` (N, C) against class indices ``targets`` (N,)."""
        if self.margin == 0:
            return self.CELoss(logits, targets)
        with torch.no_grad():
            # Vectorised lookup of p(target); avoids one host sync per sample.
            target_probs = torch.softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
            mask = target_probs <= (1 - self.margin)
        loss = torch.zeros_like(logits[:, 0])
        loss[mask] = self.CELoss(logits[mask], targets[mask])
        return loss

# Copied from transformers.models.clip.modeling_clip.py
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

def text_model_forward(
    text_encoder,
    input_ids: Optional[torch.Tensor] = None,
    token_embeds: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
    """Run the CLIP text transformer on caller-supplied token embeddings.

    This is a copy of ``CLIPTextTransformer.forward`` (transformers 4.30.0.dev0) that
    bypasses the embedding layer. ``token_embeds`` is fed straight into
    ``text_encoder.text_model.encoder``.

    Args:
        text_encoder: model exposing ``text_model.config``, ``text_model.encoder`` and
            ``text_model.final_layer_norm``.
        input_ids: 2-D token ids (batch, seq_len). They are still required:
            - their shape sizes the causal mask;
            - their ``argmax`` selects the pooled position.
        token_embeds: encoder input, used in place of
            ``embeddings(input_ids, position_ids)``.
            NOTE(review): presumably (batch, seq_len, hidden), and presumably must
            already include position embeddings because the embedding module is
            skipped. Confirm with callers.
        attention_mask: optional 2-D (batch, seq_len) mask, expanded via ``_expand_mask``.
        position_ids: accepted for signature compatibility but not used.
        output_attentions, output_hidden_states, return_dict: when None, default to
            the values in ``text_encoder.text_model.config``.

    Returns:
        ``BaseModelOutputWithPooling``. If ``return_dict`` is False, a tuple
        ``(last_hidden_state, pooled_output, *encoder_outputs[1:])``.

    Raises:
        ValueError: if ``input_ids`` is None.
    """
    
    # copied from transformers 4.30.0.dev0, transformers/models/clip/modeling_clip.py/CLIPTextTransformer
    def _build_causal_attention_mask(bsz, seq_len, dtype, device=None):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
        mask.fill_(torch.finfo(dtype).min)
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask
    
    output_attentions = output_attentions if output_attentions is not None else text_encoder.text_model.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else text_encoder.text_model.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else text_encoder.text_model.config.use_return_dict

    if input_ids is None:
        raise ValueError("You have to specify input_ids")
    
    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_shape[-1])
    
    # only difference from upstream: the embedding layer is skipped and the caller's
    # token_embeds are used directly (position_ids is therefore never consumed).
    # hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
    hidden_states = token_embeds

    bsz, seq_len = input_shape
    # CLIP's text model uses causal mask, prepare it here.
    # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
    causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
        hidden_states.device
    )
    # expand attention_mask
    if attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=hidden_states,
        attention_mask=attention_mask,
        causal_attention_mask=causal_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    last_hidden_state = encoder_outputs[0]
    last_hidden_state = text_encoder.text_model.final_layer_norm(last_hidden_state)

    # text_embeds.shape = [batch_size, sequence_length, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
    # NOTE(review): the pooled position is argmax(input_ids). This picks the EOT token
    # only if it has the largest id in every sequence, as the comment above assumes.
    # TODO: the pooled_output is wrong
    pooled_output = last_hidden_state[
        torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
        input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
    ]

    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPooling(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )

# Module-level logger from accelerate's ``get_logger``; ``clean_checkpoint`` logs through it.
logger = get_logger(__name__)


def save_model_card(
    repo_id: str,
    images=None,
    base_model: str = None,
    train_text_encoder=False,
    prompt: str = None,
    repo_folder=None,
    pipeline: DiffusionPipeline = None,
):
    """Write example images and a README.md model card into ``repo_folder``.

    Args:
        repo_id: repository id shown in the card title.
        images: optional iterable of PIL images, saved as ``image_<i>.png`` and
            embedded in the card. ``None`` means no images.
        base_model: base model id written into the YAML front matter.
        train_text_encoder: recorded in the card text.
        prompt: instance prompt written into the YAML and the card text.
        repo_folder: local directory that receives the images and README.md.
        pipeline: used only to choose stable-diffusion vs. IF tags.
    """
    img_str = ""
    # Previously `images=None` crashed on iteration; treat it as "no images".
    for i, image in enumerate(images if images is not None else ()):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- lora
inference: true
---
    """
    model_card = f"""
# LoRA DreamBooth - {repo_id}

These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}

LoRA for the text encoder was enabled: {train_text_encoder}.
"""
    with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f:
        f.write(yaml + model_card)


def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "1"/"0", "yes"/"no", ...).

    ``type=bool`` is wrong for argparse because ``bool("False")`` is True.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(input_args=None):
    """Parse the training command line.

    Args:
        input_args: optional list of argument strings; defaults to ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when that variable is set.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        # default=None,
        default="runwayml/stable-diffusion-v1-5",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=1995, help="A seed for reproducible training.")
    parser.add_argument(
        "--size_small",
        type=int,
        default=224,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--size_face",
        type=int,
        default=224,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--size_aligned_face",
        type=int,
        default=112,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        default=True,
        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
    )
    parser.add_argument(
        "--train_unet",
        action="store_true",
        default=False,
        help="Whether to train the UNet.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=999999,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps_long",
        type=int,
        default=200,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=20,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=2,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=100.0, type=float, help="Max gradient norm.")
    # parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        default=True,
        help="Whether or not to use xformers."
    )
    # parser.add_argument(
    #     "--pre_compute_text_embeddings",
    #     action="store_true",
    #     help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
    # )
    parser.add_argument(
        "--tokenizer_max_length",
        type=int,
        default=None,
        required=False,
        help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
    )
    parser.add_argument(
        "--text_encoder_use_attention_mask",
        action="store_true",
        required=False,
        help="Whether to use attention mask for the text encoder",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=50,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--occupation_dataset_path",
        type=str,
        required=True,
        help=("train and test data file"),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1, # this train_batch_size must be 1, in our implementation
        help=("Training batch size; must be 1 in this implementation."),
    )
    parser.add_argument(
        '--val_batch_size',
        default=9999,
        help="train, val, test batch size",
        type=int,
        required=False,
    )
    parser.add_argument(
        '--recon_loss_weight',
        default=5,
        help="train, val, test batch size",
        type=float,
        required=False,
    )
    parser.add_argument(
        '--real_loss_weight',
        default=0.1,
        help="train, val, test batch size",
        type=float,
        required=False,
    )
    parser.add_argument(
        '--proj_name',
        default="DAL_finetuning",
        help="train, val, test batch size",
        type=str,
        required=False,
    )
    parser.add_argument(
        '--classifier_weight_path',
        default=None,
        # `required` was previously passed twice (a SyntaxError); required=True matches
        # the other model-path arguments below.
        required=True,
        help="train, val, test batch size",
        type=str,
    )
    parser.add_argument(
        '--CEL_margin',
        default=0,
        help="train, val, test batch size",
        type=float,
        required=False,
    )
    parser.add_argument('--train_images_per_prompt', help="train, val, test batch size", type=int, required=False, default=6)
    parser.add_argument('--train_GPU_batch_size', help="train, val, test batch size", type=int, required=False, default=3)
    parser.add_argument('--val_images_per_prompt', help="train, val, test batch size", type=int, required=False, default=10)
    parser.add_argument('--val_GPU_batch_size', help="train, val, test batch size", type=int, required=False, default=10)
    parser.add_argument('--guidance_scale', help="train, val, test batch size", type=float, required=False, default=7.5)
    parser.add_argument('--skip_uncertainty_threshold', help="train, val, test batch size", type=float, required=False, default=0.2)
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument('--scale_grad', help="train, val, test batch size", type=_str2bool, required=False, default=True)
    parser.add_argument('--train_plot_every_n_iter', help="train, val, test batch size", type=int, required=False, default=20)
    parser.add_argument('--evaluate_every_n_iter', help="train, val, test batch size", type=int, required=False, default=200)
    parser.add_argument('--EMA_decay', help="train, val, test batch size", type=float, required=False, default=0.996)
    parser.add_argument('--face_feats_path', help="train, val, test batch size", type=str, default=None, required=True)
    parser.add_argument('--aligned_face_gender_model_path', help="train, val, test batch size", type=str, default=None, required=True)
    parser.add_argument('--aligned_face_gender_threshold', help="train, val, test batch size", type=float, default=0.99)
    parser.add_argument('--opensphere_config', help="train, val, test batch size", type=str, default=None, required=True)
    parser.add_argument('--opensphere_model_path', help="train, val, test batch size", type=str, default=None, required=True)
    parser.add_argument('--face_gender_confidence_level', help="train, val, test batch size", type=float, default=0.9)
    parser.add_argument('--factor1', help="train, val, test batch size", type=float, default=0.5)
    parser.add_argument('--factor2', help="train, val, test batch size", type=float, default=0.5)

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Launchers such as torchrun / accelerate export LOCAL_RANK; prefer it over the flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


class FairDataset(Dataset):
    """Map-style dataset of text prompts built by filling templates with occupations.

    Every template is formatted once per occupation through its ``{occupation}``
    placeholder. Prompts are stored template-major: all occupations for the first
    template come first, then all occupations for the second template, and so on.

    Args:
        prompt_templates: format strings containing an ``{occupation}`` field.
        occupations: occupation names substituted into every template.
        key: accepted for call-site compatibility; not used by this class.
    """

    def __init__(self, prompt_templates, occupations, key=None):
        self.prompt_templates = prompt_templates
        self.occupations = occupations

        filled = []
        for template in prompt_templates:
            for occupation_name in occupations:
                filled.append(template.format(occupation=occupation_name))
        self.prompts = filled

    def __len__(self):
        # One item per (template, occupation) pair.
        return len(self.prompts)

    def __getitem__(self, i):
        # Items are plain prompt strings.
        return self.prompts[i]


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    """Encode token ids with ``text_encoder`` and return the first element of its output.

    ``input_ids`` is moved to ``text_encoder.device`` before the call. The attention
    mask is moved and forwarded only when ``text_encoder_use_attention_mask`` is
    truthy. Otherwise the encoder receives ``attention_mask=None`` and the given
    mask is ignored.
    """
    target_device = text_encoder.device
    ids_on_device = input_ids.to(target_device)
    mask_on_device = attention_mask.to(target_device) if text_encoder_use_attention_mask else None

    encoder_outputs = text_encoder(ids_on_device, attention_mask=mask_on_device)
    return encoder_outputs[0]


def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
    r"""
    Flatten the parameters of all attention processors of ``unet`` into one state dict.

    Args:
        unet: a model exposing an ``attn_processors`` mapping from processor name to a
            module (anything with a ``state_dict()`` method).

    Returns:
        A dict keyed by ``"{processor_name}.{parameter_name}"`` whose values are the
        tensors returned by each processor's ``state_dict()``.
    """
    # Note: the return annotation uses ``torch.Tensor`` (the type); ``torch.tensor``
    # is a factory function and is not a valid type annotation.
    return {
        f"{processor_name}.{param_name}": tensor
        for processor_name, processor in unet.attn_processors.items()
        for param_name, tensor in processor.state_dict().items()
    }


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    kwargs = GradScalerKwargs(
        init_scale = 2.**0,
        growth_interval=99999999, 
        backoff_factor=0.5,
        growth_factor=2,
        )

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs]
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )
    
    if args.train_text_encoder and args.train_unet and args.gradient_accumulation_steps > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training both text encoder and unet! This feature might be supported in the future. See https://github.com/huggingface/accelerate/issues/668"
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        # use the device_specific flag so each device generates different images during training
        set_seed(args.seed, device_specific=True)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            
    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    now = datetime.datetime.now()
    timestring = f"{now.month:02}{now.day:02}{now.hour:02}{now.minute:02}"
    folder_name = f"MLP_BS{args.train_images_per_prompt*accelerator.num_processes}_MSE{args.recon_loss_weight}_real{args.real_loss_weight}_skip{args.skip_uncertainty_threshold}_factor{args.factor1}_{args.factor2}_lr{args.learning_rate}_loraR{args.rank}_{timestring}"
    
    args.imgs_save_dir = os.path.join(args.output_dir, folder_name, "imgs")
    args.ckpts_save_dir = os.path.join(args.output_dir, folder_name, "ckpts")

    if accelerator.is_main_process:
        os.makedirs(args.imgs_save_dir, exist_ok=True)
        os.makedirs(args.ckpts_save_dir, exist_ok=True)
        
        accelerator.init_trackers(
            args.proj_name, 
            init_kwargs = {
                "wandb": {
                    "name": folder_name, 
                    "dir": args.output_dir
                        }
                }
            )

        # if args.push_to_hub:
        #     repo_id = create_repo(
        #         repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
        #     ).repo_id

    # Load the tokenizer
    # if args.tokenizer_name:
    #     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer"
        )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, 
        subfolder="text_encoder"
        )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, 
        subfolder="unet",
        )
    noise_scheduler = DPMSolverMultistepScheduler.from_config(
        args.pretrained_model_name_or_path, 
        subfolder="scheduler",
        )
    

    # We only train the additional adapter LoRA layers
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    unet.enable_gradient_checkpointing()
    vae.enable_gradient_checkpointing()

    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype_high_precision = torch.float32
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    
    if args.train_text_encoder:
        eval_text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, 
            subfolder="text_encoder", 
            )
        eval_text_encoder.requires_grad_(False)
        eval_text_encoder.to(accelerator.device, dtype=weight_dtype)
        
    if args.train_unet:        
        eval_unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, 
        subfolder="unet",
        )
        eval_unet.requires_grad_(False)
        eval_unet.to(accelerator.device, dtype=weight_dtype)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            vae.enable_xformers_memory_efficient_attention()
            
            if args.train_unet:
                eval_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # now we will add new LoRA weights to the attention layers
    # It's important to realize here how many attention weights will be added and of which sizes
    # The sizes of the attention layers consist only of two different variables:
    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

    # Let's first see how many attention processors we will have to set.
    # For Stable Diffusion, it should be equal to:
    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
    # => 32 layers

    if args.train_unet:
        unet_lora_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            unet_lora_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=args.rank,
            ).to(accelerator.device)
            
        unet.set_attn_processor(unet_lora_procs)
        unet_lora_layers = AttnProcsLayers(unet.attn_processors)
        
        for p in unet_lora_layers.parameters():
            torch.distributed.broadcast(p, src=0)
        
        unet_lora_ema = EMAModel(unet_lora_layers.parameters(), decay=args.EMA_decay)
        unet_lora_ema.to(accelerator.device)
        # Set correct lora layers
        

    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
        text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank, patch_mlp=True)
        
        for p in text_encoder_lora_params:
            torch.distributed.broadcast(p, src=0)
                    
        text_encoder_lora_dict = {}
        text_encoder_lora_params_name_order = []
        for lora_param in text_encoder_lora_params:
            for name, param in text_encoder.named_parameters():
                if param is lora_param:
                    text_encoder_lora_dict[name] = lora_param
                    text_encoder_lora_params_name_order.append(name)
                    break
        assert text_encoder_lora_dict.__len__() == len(text_encoder_lora_params), "length does not match! something wrong happened while converting lora params to a state dict."

        # text_encoder_lora_params is randomly initiazed w/ different values at different devices
        # a hacky way to broadcast from main_process
        for name in text_encoder_lora_params_name_order:
            if accelerator.is_main_process:
                lora_param = text_encoder_lora_dict[name].detach().clone()
            else:
                lora_param = torch.zeros_like(text_encoder_lora_dict[name])
            torch.distributed.broadcast(lora_param, src=0)
            text_encoder_lora_dict[name].data = lora_param

        class CustomModel(torch.nn.Module):
            def __init__(self, dict):
                """
                In the constructor we instantiate four parameters and assign them as
                member parameters.
                """
                super().__init__()
                self.param_names = list(dict.keys())
                self.params = nn.ParameterList()
                for name in self.param_names:
                    self.params.append( dict[name] )

            def forward(self, x):
                """
                no forward function
                """
                return None
        text_encoder_lora_model = CustomModel(text_encoder_lora_dict)

        text_encoder_lora_ema = EMAModel(text_encoder_lora_params, decay=args.EMA_decay)
        text_encoder_lora_ema.to(accelerator.device)

        text_encoder_lora_ema_dict = {}
        for name, shadow_param in itertools.zip_longest(text_encoder_lora_params_name_order, text_encoder_lora_ema.shadow_params):
            text_encoder_lora_ema_dict[name] = shadow_param
        assert text_encoder_lora_ema_dict.__len__() == text_encoder_lora_dict.__len__(), "length does not match! something wrong happened while converting lora params to a state dict."

    if args.train_text_encoder:
        print(f"{accelerator.device}, text_encoder, {list(text_encoder_lora_model.parameters())[0].flatten()[1]:.6f}, {text_encoder_lora_ema.shadow_params[0].flatten()[1]:.6f}")
    if args.train_unet:
        print(f"{accelerator.device}, unet, {list(unet_lora_layers.parameters())[0].flatten()[1]:.6f}, {unet_lora_ema.shadow_params[0].flatten()[1]:.6f}")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    # params_to_optimize = text_encoder_lora_params
    # if args.train_text_encoder and args.train_unet:
    #     params_to_optimize = itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
    # elif args.train_text_encoder and not args.train_unet:
    #     params_to_optimize = text_encoder_lora_params
    # elif not args.train_text_encoder and args.train_unet:
    #     params_to_optimize = unet_lora_layers.parameters()
    # else:
    #     raise ValueError("Both args.train_text_encoder and args.train_unet are False. At least one must be True.")

    if args.train_text_encoder and args.train_unet:
        params_to_optimize = itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_model.parameters())
    elif args.train_text_encoder and not args.train_unet:
        params_to_optimize = text_encoder_lora_model.parameters()
    else:
        raise ValueError(f"Not implemented: args.train_text_encoder={args.train_text_encoder}, args.train_unet={args.train_unet}.")
        
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    with open(args.occupation_dataset_path, 'r') as f:
        experiment_data = json.load(f)
    
    train_dataset = FairDataset(
        prompt_templates=experiment_data["prompt_templates_train"],
        occupations= experiment_data["occupations_train_full"]
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, num_workers=0, shuffle=True)
    # self-make a simple dataloader
    random.seed(1997)
    train_dataloader_idxs = []
    for epoch in range(100):
        idxs = list(range(train_dataset.__len__()))
        random.shuffle(idxs)
        train_dataloader_idxs.append(idxs)
    val_dataset = FairDataset(
        prompt_templates=experiment_data["prompt_templates_test"],
        occupations=experiment_data["occupations_test_small"],
        key="TestSmall"
        )
    val_dataloader_idxs = [list(range(val_dataset.__len__()))]
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=0)
    
    
    #######################################################
    # set up things needed for fair finetuning        
    gender_classifier = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT, width_mult=1.0, reduced_tail=False, dilated=False)
    gender_classifier._modules['classifier'][3] = nn.Linear(1280, 80, bias=True)
    
    gender_classifier.load_state_dict(torch.load(args.classifier_weight_path))
    gender_classifier.to(accelerator.device, dtype=weight_dtype)
    gender_classifier.requires_grad_(False)
    gender_classifier.eval()
    
    # get the explicit device_id
    device_id = [p for p in gender_classifier.parameters()][0].device.index
    face_app = FaceAnalysis(
        name="buffalo_l",
        allowed_modules=['detection'], 
        providers=['CUDAExecutionProvider'], 
        provider_options=[{'device_id': device_id}]
        )
    face_app.prepare(ctx_id=0, det_size=(640, 640))
    
    
    clip_image_processoor = CLIPImageProcessor.from_pretrained(
        "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    )
    clip_vision_model_w_proj = CLIPVisionModelWithProjection.from_pretrained(
        "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        # pytorch_dtype=pytorch_dtype
    )
    clip_vision_model_w_proj.vision_model.to(accelerator.device, dtype=weight_dtype)
    clip_vision_model_w_proj.visual_projection.to(accelerator.device, dtype=weight_dtype)
    clip_vision_model_w_proj.requires_grad_(False)
    clip_vision_model_w_proj.gradient_checkpointing_enable()
    clip_img_mean = torch.tensor(clip_image_processoor.image_mean).reshape([-1,1,1]).to(accelerator.device, dtype=weight_dtype) # mean is based on range [0,1]
    clip_img_std = torch.tensor(clip_image_processoor.image_std).reshape([-1,1,1]).to(accelerator.device, dtype=weight_dtype) # std is based on range [0,1]
    
    
    dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
    dinov2.to(accelerator.device, dtype=weight_dtype)
    dinov2.requires_grad_(False)
    dinov2_img_mean = torch.tensor([0.485, 0.456, 0.406]).reshape([-1,1,1]).to(accelerator.device, dtype=weight_dtype)
    dinov2_img_std = torch.tensor([0.229, 0.224, 0.225]).reshape([-1,1,1]).to(accelerator.device, dtype=weight_dtype)
    
    CELoss = soft_CELoss(margin=args.CEL_margin)
    
    # set dlib to usefor later use of face_recognition function
    import face_recognition
    
    import pathlib
    import yaml
    sys.path.append(pathlib.Path(__file__).parent.resolve().__str__())
    sys.path.append(pathlib.Path(__file__).parent.resolve().joinpath("opensphere").__str__())
    from opensphere.builder import build_dataloader, build_from_cfg
    from opensphere.utils import fill_config
    # build model
    with open(args.opensphere_config, 'r') as f:
        opensphere_config = yaml.load(f, yaml.SafeLoader)
    opensphere_config['data'] = fill_config(opensphere_config['data'])
    face_feats_net = build_from_cfg(
        opensphere_config['model']['backbone']['net'],
        'model.backbone',
    )
    face_feats_net = nn.DataParallel(face_feats_net)
    face_feats_net.load_state_dict(torch.load(args.opensphere_model_path))
    face_feats_net = face_feats_net.module
    face_feats_net.to(accelerator.device)
    face_feats_net.requires_grad_(False)
    face_feats_net.to(weight_dtype)
    face_feats_net.eval()
    
    algined_face_gender_classifier = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT, width_mult=1.0, reduced_tail=False, dilated=False)
    algined_face_gender_classifier._modules['classifier'][3] = nn.Linear(1280, 80, bias=True)
    algined_face_gender_classifier.load_state_dict(torch.load(args.aligned_face_gender_model_path))
    algined_face_gender_classifier.requires_grad_(False)
    algined_face_gender_classifier.to(weight_dtype)
    algined_face_gender_classifier.to(accelerator.device)
    algined_face_gender_classifier.eval()
    
    face_feats_model = FaceFeatsModel(args.face_feats_path, args.aligned_face_gender_threshold)
    face_feats_model.to(weight_dtype_high_precision)
    face_feats_model.to(accelerator.device)
    face_feats_model.eval()
    #######################################################
    
    @torch.no_grad()
    def generate_image_no_gradient(prompt, noises, num_denoising_steps, which_text_encoder, which_unet):
        """Sample images for a single prompt with classifier-free guidance, without autograd.

        Args:
            prompt: one text prompt (str); it is repeated once per row of ``noises``.
            noises: initial latents, one per image to generate (the original note says
                [N,4,64,64]; confirm against the UNet/VAE config).
            num_denoising_steps: number of scheduler timesteps to run.
            which_text_encoder: text encoder used for both the prompt and the empty prompt.
            which_unet: UNet that predicts the noise at every timestep.

        Returns:
            Images decoded by ``vae`` and clamped to [-1, 1].
        """
        def embed(texts, **tokenizer_kwargs):
            # Tokenize, move ids/mask to the training device, and keep the encoder's first output.
            tokens = tokenizer(texts, return_tensors="pt", **tokenizer_kwargs)
            ids_on_device = tokens["input_ids"].to(accelerator.device)
            mask_on_device = tokens["attention_mask"].to(accelerator.device)
            return which_text_encoder(ids_on_device, mask_on_device)[0]

        cond_embeds = embed([prompt] * noises.shape[0], padding=True)
        # The unconditional (empty) prompt is padded to the same sequence length as the prompt.
        uncond_embeds = embed(
            [""] * cond_embeds.shape[0],
            padding="max_length",
            max_length=cond_embeds.shape[1],
            truncation=True,
        )
        # Unconditional half first, conditional half second; chunk(2) below relies on this order.
        text_embeds = torch.cat([uncond_embeds, cond_embeds]).to(weight_dtype)

        noise_scheduler.set_timesteps(num_denoising_steps)
        current = noises
        for timestep in noise_scheduler.timesteps:
            unet_input = noise_scheduler.scale_model_input(torch.cat([current.to(weight_dtype)] * 2), timestep)
            prediction = which_unet(unet_input, timestep, encoder_hidden_states=text_embeds).sample
            prediction = prediction.to(weight_dtype_high_precision)

            pred_uncond, pred_cond = prediction.chunk(2)
            guided = pred_uncond + args.guidance_scale * (pred_cond - pred_uncond)

            current = noise_scheduler.step(guided, timestep, current).prev_sample

        current = 1 / vae.config.scaling_factor * current
        return vae.decode(current.to(vae.dtype)).sample.clamp(-1, 1)  # in range [-1,1]
    
    def generate_image_w_gradient(prompt, noises, num_denoising_steps, which_text_encoder, which_unet):
        """Sample images for a single prompt with classifier-free guidance, keeping the autograd graph.

        Args:
            prompt: one text prompt (str); it is repeated once per row of ``noises``.
            noises: initial latents, one per image to generate (the original note says
                [N,4,64,64]; confirm against the UNet/VAE config).
            num_denoising_steps: number of scheduler timesteps to run.
            which_text_encoder: text encoder used for both the prompt and the empty prompt.
            which_unet: UNet that predicts the noise at every timestep.

        Returns:
            Images decoded by ``vae`` and clamped to [-1, 1].

        The UNet input at each step is built from ``latents.detach()``. Gradients therefore
        reach earlier steps only through ``noise_scheduler.step``'s use of ``latents``, not
        through the UNet input. When ``args.scale_grad`` is set, the gradient flowing into
        each step's guided noise prediction is rescaled by a per-timestep coefficient
        (via ``make_grad_hook``). The coefficients are normalized to geometric mean 1.
        """
        # to enable gradient_checkpointing, unet must be set to train()
        # NOTE(review): this switches the trained ``unet``, not ``which_unet``; confirm this is
        # intended when a different UNet is passed in.
        unet.train()

        N = noises.shape[0]
        prompts = [prompt] * N

        prompts_token = tokenizer(prompts, return_tensors="pt", padding=True)
        prompts_token["input_ids"] = prompts_token["input_ids"].to(accelerator.device)
        prompts_token["attention_mask"] = prompts_token["attention_mask"].to(accelerator.device)

        prompt_embeds = which_text_encoder(
            prompts_token["input_ids"],
            prompts_token["attention_mask"],
        )
        prompt_embeds = prompt_embeds[0]

        # The unconditional (empty) prompt is padded to the same sequence length as the prompt.
        batch_size = prompt_embeds.shape[0]
        uncond_tokens = [""] * batch_size
        max_length = prompt_embeds.shape[1]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_input["input_ids"] = uncond_input["input_ids"].to(accelerator.device)
        uncond_input["attention_mask"] = uncond_input["attention_mask"].to(accelerator.device)
        negative_prompt_embeds = which_text_encoder(
            uncond_input["input_ids"],
            uncond_input["attention_mask"],
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

        # Unconditional half first, conditional half second; chunk(2) below relies on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]).to(weight_dtype)

        noise_scheduler.set_timesteps(num_denoising_steps)
        if args.scale_grad:
            alphas_cumprod = noise_scheduler.alphas_cumprod
            grad_coefs = np.array([
                alphas_cumprod[t].sqrt().item()
                * (1 - alphas_cumprod[t]).sqrt().item()
                / (1 - noise_scheduler.alphas[t].item())
                for t in noise_scheduler.timesteps
            ])
            # Normalize to geometric mean 1, computing the mean in log space. The previous
            # ``math.prod(grad_coefs) ** (1 / n)`` overflows to inf (or underflows to 0) when
            # there are many timesteps, which silently zeroed or blew up every coefficient.
            grad_coefs /= np.exp(np.mean(np.log(grad_coefs)))

        latents = noises
        for i, t in enumerate(noise_scheduler.timesteps):
            # scale model input; the UNet sees detached latents (see docstring)
            latent_model_input = torch.cat([latents.detach().to(weight_dtype)] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)

            noises_pred = which_unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
            ).sample
            noises_pred = noises_pred.to(weight_dtype_high_precision)

            noises_pred_uncond, noises_pred_text = noises_pred.chunk(2)
            noises_pred = noises_pred_uncond + args.guidance_scale * (noises_pred_text - noises_pred_uncond)

            if args.scale_grad:
                # Rescale the gradient w.r.t. this step's guided prediction during backward.
                noises_pred.register_hook(make_grad_hook(grad_coefs[i]))

            latents = noise_scheduler.step(noises_pred, t, latents).prev_sample

        latents = 1 / vae.config.scaling_factor * latents
        images = vae.decode(latents.to(vae.dtype)).sample.clamp(-1, 1)  # in range [-1,1]

        return images
    
    
    def get_clip_feat(images, normalize=True, to_high_precision=True):
        """Return CLIP projected image embeddings for a batch of images.

        Args:
            images: tensor of shape [N,3,H,W] with values in [-1,1].
            normalize: if True, L2-normalize each embedding along the last dimension.
            to_high_precision: if True, cast the embeddings to ``torch.float`` first.
        """
        # Map [-1,1] -> [0,1], the range the CLIP mean/std constants refer to.
        pixels_01 = (images + 1) * 0.5
        standardized = (pixels_01 - clip_img_mean) / clip_img_std
        features = clip_vision_model_w_proj(standardized).image_embeds

        if to_high_precision:
            features = features.to(torch.float)
        return torch.nn.functional.normalize(features, dim=-1) if normalize else features
    
    def get_dino_feat(images, normalize=True, to_high_precision=True):
        """Return DINOv2 features for a batch of images.

        Args:
            images: tensor of shape [N,3,H,W] with values in [-1,1].
            normalize: if True, L2-normalize each feature vector along the last dimension.
            to_high_precision: if True, cast the features to ``torch.float`` first.
        """
        # Map [-1,1] -> [0,1], then standardize with the DINOv2 mean/std constants.
        pixels_01 = (images + 1) * 0.5
        standardized = (pixels_01 - dinov2_img_mean) / dinov2_img_std
        features = dinov2(standardized)

        if to_high_precision:
            features = features.to(torch.float)
        return torch.nn.functional.normalize(features, dim=-1) if normalize else features
    
    def get_face_feats(net, data, flip=True, normalize=True, to_high_precision=True):
        """Compute face features with ``net``, optionally summed with those of the mirrored input.

        Args:
            net: feature extractor, called as ``net(batch)``.
            data: input batch; when ``flip`` is True it is also flipped along dim 3
                (the width axis for an [N,C,H,W] layout) and passed through ``net`` again.
            flip: add the features of the flipped batch to the original features.
            normalize: if True, L2-normalize the result along the last dimension.
            to_high_precision: if True, cast the result to ``torch.float``.
        """
        features = net(data)
        if flip:
            mirrored = torch.flip(data, [3])
            features += net(mirrored)  # in-place sum, as before

        if to_high_precision:
            features = features.to(torch.float)
        if not normalize:
            return features
        return torch.nn.functional.normalize(features, dim=-1)
    
    def image_pipeline(img, tgz_landmark):
        """Warp one face image so its five landmarks align with a fixed 112x112 template.

        Args:
            img: image tensor of shape [C,H,W] (it is unsqueezed to a batch of one
                before warping) with values in [-1,1].
            tgz_landmark: five detected (x, y) landmarks, in the same order as the
                template points below.

        Returns:
            The aligned 112x112 crop (``squeeze()``d), mapped back to [-1,1].
        """
        pixels = (img + 1) / 2.0 * 255  # [-1,1] -> [0,255]

        output_size = (112, 112)
        # Reference landmark positions inside the 112x112 output crop.
        template_landmarks = np.array(
            [[38.2946, 51.6963],  # left eye
             [73.5318, 51.5014],  # right eye
             [56.0252, 71.7366],  # nose
             [41.5493, 92.3655],  # left mouth corner
             [70.7299, 92.2041]]  # right mouth corner
        )

        # Fit a similarity transform from the detected landmarks onto the template.
        similarity = transform.SimilarityTransform()
        similarity.estimate(tgz_landmark, template_landmarks)

        # Keep the top 2x3 affine part, as a batch of one, matching the scaled image's dtype/device.
        affine = torch.tensor(similarity.params[0:2, :]).unsqueeze(dim=0).to(pixels.dtype).to(pixels.device)
        warped = kornia.geometry.transform.warp_affine(
            pixels.unsqueeze(dim=0),
            affine,
            output_size,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False,
        )
        warped = warped.squeeze()

        return (warped / 255.0) * 2 - 1  # [0,255] -> [-1,1]
        
    def get_face(images, fill_value=-1):
        """Detect faces with two detectors and merge their results per image.

        The results of ``get_face_app`` are used by default. For every image where it
        reports no face, the corresponding entries from ``get_face_FR`` are used instead.

        Args:
            images: tensor of shape [N,3,H,W] with values in [-1,1].
            fill_value: placeholder value forwarded to both detectors for images without a face.

        Returns:
            Tuple ``(face_indicators, face_bboxs, face_chips, face_landmarks, aligned_face_chips)``,
            each indexed by image along dim 0:
                face_indicators: shape [N], True where a face was detected.
                face_bboxs: shape [N,4]; all ``fill_value`` where no face was detected.
                face_chips: per-image face crops; all ``fill_value`` where no face was detected.
                face_landmarks, aligned_face_chips: as produced by the detectors.
        """
        ind_app, bbox_app, chip_app, lmk_app, aligned_app = get_face_app(images, fill_value=fill_value)
        ind_fr, bbox_fr, chip_fr, lmk_fr, aligned_fr = get_face_FR(images, fill_value=fill_value)

        # debug visualization code; do not remove
        # for idx in range(images.shape[0]):
        #     img = images[idx]
        #     face_chip = chip_fr[idx]
        #     aligned_face_chip = aligned_fr[idx]

        # Images where the primary detector found nothing fall back to the secondary one.
        needs_fallback = ind_app.logical_not()

        def take_fallback(primary, secondary):
            # Copy the primary result and overwrite the rows that need the fallback.
            merged = primary.clone()
            merged[needs_fallback] = secondary[needs_fallback]
            return merged

        paired = (
            (ind_app, ind_fr),
            (bbox_app, bbox_fr),
            (chip_app, chip_fr),
            (lmk_app, lmk_fr),
            (aligned_app, aligned_fr),
        )
        return tuple(take_fallback(primary, secondary) for primary, secondary in paired)
    
    def get_face_FR(images, fill_value=-1):
        """Detect the largest face per image with face_recognition's CNN detector.

        Args:
            images: tensor of shape [N, 3, H, W] with values in [-1, 1].
            fill_value: value used for every output of an image in which no face is found.

        Returns:
            face_indicators: bool tensor [N]; True where a face was detected.
            face_bboxs: tensor [N, 4]; the detection reordered to (left, top, right, bottom)
                and expanded by ``expand_bbox`` (rows of ``fill_value`` when no face).
            face_chips: tensor [N, 3, args.size_face, args.size_face] produced by ``crop_face``.
            face_landmarks: tensor [N, 5, 2] built from face_recognition landmarks: the mean
                of ``left_eye``, the mean of ``right_eye``, the last ``nose_bridge`` point and
                ``top_lip`` points 0 and 6.
            aligned_face_chips: ``image_pipeline`` outputs stacked along dim 0; the no-face
                fill has shape [3, args.size_aligned_face, args.size_aligned_face]
                (assumes image_pipeline returns that shape -- confirm).
        """
        def get_largest_face(faces_from_FR, width, height):
            """Return the detection covering the largest area inside the image.

            face_recognition boxes are (top, right, bottom, left) tuples.
            """
            if len(faces_from_FR) == 1:
                return faces_from_FR[0]

            def visible_area(box):
                top, right, bottom, left = box
                # x extents are clipped to the width, y extents to the height
                return (min(right, width) - max(left, 0)) * (min(bottom, height) - max(top, 0))

            return max(faces_from_FR, key=visible_area)

        def fill(*shape):
            return torch.full(list(shape), fill_value, dtype=images.dtype, device=images.device)

        # Clamp before the uint8 cast: NumPy float -> uint8 casts of out-of-range values are not well-defined.
        images_np = ((images.detach() * 0.5 + 0.5) * 255).clamp(0, 255).permute(0, 2, 3, 1).float().cpu().numpy().astype(np.uint8)
        height, width = images_np.shape[1:3]

        face_indicators_FR = []
        face_bboxs_FR = []
        face_chips_FR = []
        face_landmarks_FR = []
        aligned_face_chips_FR = []
        for idx, image_np in enumerate(images_np):
            faces_from_FR = face_recognition.face_locations(image_np, model="cnn")
            if len(faces_from_FR) == 0:
                face_indicators_FR.append(False)
                face_bboxs_FR.append([fill_value] * 4)
                face_chips_FR.append(fill(1, 3, args.size_face, args.size_face))
                face_landmarks_FR.append(fill(1, 5, 2))
                aligned_face_chips_FR.append(fill(1, 3, args.size_aligned_face, args.size_aligned_face))
                continue

            face_from_FR = get_largest_face(faces_from_FR, width=width, height=height)
            top, right, bottom, left = face_from_FR
            # Reorder to (left, top, right, bottom); FR boxes need a larger expand_coef than insightface ones.
            bbox = expand_bbox(np.array((left, top, right, bottom)), expand_coef=1.1, target_ratio=1)
            face_chip = crop_face(images[idx], bbox, target_size=[args.size_face, args.size_face], fill_value=fill_value)

            landmarks = face_recognition.face_landmarks(image_np, face_locations=[face_from_FR], model="large")[0]
            five_points = np.stack([
                np.array(landmarks["left_eye"]).mean(axis=0),
                np.array(landmarks["right_eye"]).mean(axis=0),
                np.array(landmarks["nose_bridge"][-1]),
                np.array(landmarks["top_lip"][0]),
                np.array(landmarks["top_lip"][6]),
            ])
            aligned_face_chip = image_pipeline(images[idx], five_points)

            face_indicators_FR.append(True)
            face_bboxs_FR.append(bbox)
            face_chips_FR.append(face_chip.unsqueeze(dim=0))
            face_landmarks_FR.append(torch.tensor(five_points).to(device=images.device, dtype=images.dtype).unsqueeze(dim=0))
            aligned_face_chips_FR.append(aligned_face_chip.unsqueeze(dim=0))

        face_indicators_FR = torch.tensor(face_indicators_FR).to(device=images.device)
        face_bboxs_FR = torch.tensor(face_bboxs_FR).to(device=images.device)
        face_chips_FR = torch.cat(face_chips_FR, dim=0)
        face_landmarks_FR = torch.cat(face_landmarks_FR, dim=0)
        aligned_face_chips_FR = torch.cat(aligned_face_chips_FR, dim=0)

        return face_indicators_FR, face_bboxs_FR, face_chips_FR, face_landmarks_FR, aligned_face_chips_FR
    
    def get_face_app(images, fill_value=-1):
        """Detect the largest face per image with the insightface ``face_app`` detector.

        Args:
            images: tensor of shape [N, 3, H, W] with values in [-1, 1].
            fill_value: value used for every output of an image in which no face is found.

        Returns:
            face_indicators: bool tensor [N]; True where a face was detected.
            face_bboxs: tensor [N, 4]; the detector's ``bbox`` (x0, y0, x1, y1) expanded by
                ``expand_bbox`` (rows of ``fill_value`` when no face).
            face_chips: tensor [N, 3, args.size_face, args.size_face] produced by ``crop_face``.
            face_landmarks: tensor [N, 5, 2] holding the detector's ``kps`` keypoints.
            aligned_face_chips: ``image_pipeline`` outputs stacked along dim 0; the no-face
                fill has shape [3, args.size_aligned_face, args.size_aligned_face]
                (assumes image_pipeline returns that shape -- confirm).
        """
        def get_largest_face(faces_from_app, width, height):
            """Return the detection whose ``bbox`` covers the largest area inside the image."""
            if len(faces_from_app) == 1:
                return faces_from_app[0]

            def visible_area(face):
                bbox = face["bbox"]
                # x extents are clipped to the width, y extents to the height
                return (min(bbox[2], width) - max(bbox[0], 0)) * (min(bbox[3], height) - max(bbox[1], 0))

            return max(faces_from_app, key=visible_area)

        def fill(*shape):
            return torch.full(list(shape), fill_value, dtype=images.dtype, device=images.device)

        # Clamp before the uint8 cast: NumPy float -> uint8 casts of out-of-range values are not well-defined.
        images_np = ((images.detach() * 0.5 + 0.5) * 255).clamp(0, 255).permute(0, 2, 3, 1).float().cpu().numpy().astype(np.uint8)
        height, width = images_np.shape[1:3]

        face_indicators_app = []
        face_bboxs_app = []
        face_chips_app = []
        face_landmarks_app = []
        aligned_face_chips_app = []
        for idx, image_np in enumerate(images_np):
            # face_app.get input should be [BGR]
            faces_from_app = face_app.get(image_np[:, :, [2, 1, 0]])
            if len(faces_from_app) == 0:
                face_indicators_app.append(False)
                face_bboxs_app.append([fill_value] * 4)
                face_chips_app.append(fill(1, 3, args.size_face, args.size_face))
                face_landmarks_app.append(fill(1, 5, 2))
                aligned_face_chips_app.append(fill(1, 3, args.size_aligned_face, args.size_aligned_face))
                continue

            face_from_app = get_largest_face(faces_from_app, width=width, height=height)
            bbox = expand_bbox(face_from_app["bbox"], expand_coef=0.5, target_ratio=1)
            face_chip = crop_face(images[idx], bbox, target_size=[args.size_face, args.size_face], fill_value=fill_value)

            kps = np.array(face_from_app["kps"])
            aligned_face_chip = image_pipeline(images[idx], kps)

            face_indicators_app.append(True)
            face_bboxs_app.append(bbox)
            face_chips_app.append(face_chip.unsqueeze(dim=0))
            face_landmarks_app.append(torch.tensor(kps).to(device=images.device, dtype=images.dtype).unsqueeze(dim=0))
            aligned_face_chips_app.append(aligned_face_chip.unsqueeze(dim=0))

        face_indicators_app = torch.tensor(face_indicators_app).to(device=images.device)
        face_bboxs_app = torch.tensor(face_bboxs_app).to(device=images.device)
        face_chips_app = torch.cat(face_chips_app, dim=0)
        face_landmarks_app = torch.cat(face_landmarks_app, dim=0)
        aligned_face_chips_app = torch.cat(aligned_face_chips_app, dim=0)

        return face_indicators_app, face_bboxs_app, face_chips_app, face_landmarks_app, aligned_face_chips_app
                
    def get_face_gender(face_chips, selector=None, fill_value=-1):
        """Classify gender from face chips with the CelebA attribute classifier.

        Args:
            face_chips: batch of face crops [N, 3, H, W] fed to ``gender_classifier``.
            selector: optional boolean mask of shape [N]. When given, only the selected
                chips are classified, and every output is scattered back to length N with
                ``fill_value`` in the unselected rows.
            fill_value: value written into rows excluded by ``selector``.

        Returns:
            ``(preds_gender, probs_gender, logits_gender)`` with shapes [M], [M, 2], [M, 2].
            M is N when ``selector`` is given, otherwise the number of chips.
            ``preds_gender`` is the int64 index of the larger probability.
        """
        if selector is not None:
            face_chips_w_faces = face_chips[selector]
        else:
            face_chips_w_faces = face_chips

        if face_chips_w_faces.shape[0] == 0:
            # Nothing to classify: return empty tensors with the usual trailing shapes.
            logits_gender = torch.empty([0, 2], dtype=face_chips.dtype, device=face_chips.device)
            probs_gender = torch.empty([0, 2], dtype=face_chips.dtype, device=face_chips.device)
            preds_gender = torch.empty([0], dtype=torch.int64, device=face_chips.device)
        else:
            logits = gender_classifier(face_chips_w_faces)
            # The classifier emits 2 logits per attribute. Attribute 20 is presumably
            # CelebA's "Male" in the standard 40-attribute order -- confirm against the classifier.
            logits_gender = logits.view([logits.shape[0], -1, 2])[:, 20, :]
            probs_gender = torch.softmax(logits_gender, dim=-1)
            preds_gender = probs_gender.max(dim=-1).indices

        if selector is None:
            return preds_gender, probs_gender, logits_gender

        def scatter_to_full(values):
            # Rows not picked by ``selector`` keep ``fill_value``.
            full = torch.ones(
                [selector.shape[0]] + list(values.shape[1:]),
                dtype=values.dtype,
                device=values.device,
            ) * fill_value
            full[selector] = values
            return full

        return scatter_to_full(preds_gender), scatter_to_full(probs_gender), scatter_to_full(logits_gender)
        
    @torch.no_grad()
    def generate_dynamic_targets(probs, target_ratio=0.5, w_uncertainty=False):
        """Split the valid rows of ``probs`` into two halves by ranking ``probs[:, 1]``.

        A row that contains a -1 entry is treated as missing; its target (and
        uncertainty, if requested) is -1. Let M be the number of remaining rows. Those
        rows get ascending 0-based ranks of ``probs[:, 1]``. Rows with rank >= M / 2 get
        target 1 and the others get target 0.

        Args:
            probs: tensor [N, C]; only column 1 is used for ranking.
            target_ratio: binomial success probability used by the uncertainty scores.
                It does not move the 0/1 split, which is always at M / 2.
            w_uncertainty: if True, also return per-row uncertainty. For target 1 it is
                ``1 - binom.cdf(rank, M, 1 - target_ratio)``; for target 0 it is
                ``binom.cdf(rank, M, target_ratio)``.

        Returns:
            ``targets_all`` (long [N]), or ``(targets_all, uncertainty_all)`` with
            ``uncertainty_all`` of shape [N] in ``probs.dtype``.
        """
        is_valid = (probs != -1).all(dim=-1)
        valid_probs = probs[is_valid]
        n_valid = valid_probs.shape[0]

        # argsort of argsort maps every score to its 0-based ascending rank
        ranks = torch.argsort(torch.argsort(valid_probs[:, 1]))
        targets = (ranks >= (n_valid / 2)).long()

        targets_all = torch.full((probs.shape[0],), -1, dtype=torch.long, device=probs.device)
        targets_all[is_valid] = targets

        if not w_uncertainty:
            return targets_all

        def binom_cdf(ks, p):
            return scipy.stats.binom.cdf(ks.cpu().numpy(), n_valid, p)

        is_pos = targets == 1
        is_neg = targets == 0
        uncertainty = torch.full((n_valid,), -1, dtype=probs.dtype, device=probs.device)
        uncertainty[is_pos] = torch.tensor(1 - binom_cdf(ranks[is_pos], 1 - target_ratio)).to(probs.dtype).to(probs.device)
        uncertainty[is_neg] = torch.tensor(binom_cdf(ranks[is_neg], target_ratio)).to(probs.dtype).to(probs.device)

        uncertainty_all = torch.full((probs.shape[0],), -1, dtype=probs.dtype, device=probs.device)
        uncertainty_all[is_valid] = uncertainty
        return targets_all, uncertainty_all

    def _eval_generate_images(prompt, noises, num_denoising_steps, which_text_encoder, which_unet):
        """Generate images for one prompt, passing ``noises`` through in chunks of ``args.val_GPU_batch_size``."""
        batch_size = args.val_GPU_batch_size
        chunks = [
            generate_image_no_gradient(
                prompt,
                noises[start:start + batch_size],
                num_denoising_steps,
                which_text_encoder=which_text_encoder,
                which_unet=which_unet,
            )
            for start in range(0, noises.shape[0], batch_size)
        ]
        return torch.cat(chunks)

    def _eval_detect_and_gather(images):
        """Detect faces, classify gender and score realism for this rank's images, then gather across ranks.

        ``customized_all_gather`` is presumably a collective op, so every rank must call
        this helper in the same order. The fields are gathered in the order listed.
        """
        face_indicators, face_bboxs, face_chips, _, aligned_face_chips = get_face(images)
        preds_gender, probs_gender, _ = get_face_gender(face_chips, selector=face_indicators, fill_value=-1)

        face_feats = get_face_feats(face_feats_net, aligned_face_chips)
        _, face_real_scores = face_feats_model.semantic_search(face_feats, genders=None, selector=face_indicators, return_similarity=True)

        def gather(tensor):
            return customized_all_gather(tensor, accelerator, return_tensor_others=False)

        return {
            "images": gather(images),
            "face_indicators": gather(face_indicators),
            "face_bboxs": gather(face_bboxs),
            "preds_gender": gather(preds_gender),
            "probs_gender": gather(probs_gender),
            "face_real_scores": gather(face_real_scores),
        }

    def _eval_plot(gathered, save_to):
        """Save a grid of gathered images annotated with face boxes, gender predictions and realism scores."""
        plot_in_grid(
            gathered["images"],
            save_to,
            face_indicators=gathered["face_indicators"],
            face_bboxs=gathered["face_bboxs"],
            preds_gender=gathered["preds_gender"],
            pred_class_probs_gender=gathered["probs_gender"].max(dim=-1).values,
            face_real_scores=gathered["face_real_scores"],
        )

    def _eval_gender_metrics(probs_gender_all):
        """Summarize gender balance over images with a detected face.

        Rows containing -1 (no face) are ignored, and ``p1 = probs[:, 1]``.
        ``gap_t`` is the fraction with ``p1 >= 1 - t`` minus the fraction with ``p1 <= t``.
        ``mid_t`` is the fraction with ``t < p1 < 1 - t``.
        Means over an empty selection are NaN.
        """
        p1 = probs_gender_all[(probs_gender_all != -1).all(dim=-1)][:, 1]

        def gap(hi, lo):
            return (((p1 >= hi) & (p1 <= 1)).float().mean() - ((p1 >= 0) & (p1 <= lo)).float().mean()).item()

        def mid(lo, hi):
            return ((p1 > lo) & (p1 < hi)).float().mean().item()

        gap_05, gap_03, gap_01 = gap(0.5, 0.5), gap(0.7, 0.3), gap(0.9, 0.1)
        return {
            "gender_gap_0.5": [gap_05],
            "gender_gap_0.3": [gap_03],
            "gender_gap_0.1": [gap_01],
            "gender_gap_abs_0.5": [abs(gap_05)],
            "gender_gap_abs_0.3": [abs(gap_03)],
            "gender_gap_abs_0.1": [abs(gap_01)],
            "gender_mid_0.3": [mid(0.3, 0.7)],
            "gender_mid_0.1": [mid(0.1, 0.9)],
        }

    @torch.no_grad()
    def evaluate_process(which_text_encoder, which_unet, name, prompts, noises, current_global_step):
        """Evaluate a (text encoder, UNet) pair on validation prompts and log results to wandb.

        For each prompt, images are generated twice from the same noises:
          1. with the reference models: the ``eval_*`` copy of every module being
             trained, and the live module otherwise;
          2. with ``which_text_encoder`` / ``which_unet``.
        Both grids are saved under ``args.imgs_save_dir``, and gender-balance metrics of
        the second set are logged per prompt and averaged over prompts, all at
        ``current_global_step``.

        Returns:
            ``(logs, log_imgs)``: per-prompt metric dicts and saved image paths. Both are
            empty lists on non-main processes.
        """
        num_denoising_steps = 25
        ori_text_encoder = eval_text_encoder if args.train_text_encoder else text_encoder
        ori_unet = eval_unet if args.train_unet else unet

        logs = []
        log_imgs = []
        for prompt_i, noises_i in itertools.zip_longest(prompts, noises):
            images_ori = _eval_generate_images(prompt_i, noises_i, num_denoising_steps, ori_text_encoder, ori_unet)
            gathered_ori = _eval_detect_and_gather(images_ori)
            if accelerator.is_main_process:
                save_to_ori = os.path.join(args.imgs_save_dir, f"eval_{name}_{current_global_step}_{prompt_i}_ori.jpg")
                _eval_plot(gathered_ori, save_to_ori)

            images = _eval_generate_images(prompt_i, noises_i, num_denoising_steps, which_text_encoder, which_unet)
            gathered = _eval_detect_and_gather(images)
            if accelerator.is_main_process:
                save_to_generated = os.path.join(args.imgs_save_dir, f"eval_{name}_{current_global_step}_{prompt_i}_generated.jpg")
                _eval_plot(gathered, save_to_generated)
                log_imgs.append({"img_ori": [save_to_ori], "img_generated": [save_to_generated]})
                logs.append(_eval_gender_metrics(gathered["probs_gender"]))

        if accelerator.is_main_process:
            for prompt_i, logs_i in itertools.zip_longest(prompts, logs):
                for key, values in logs_i.items():
                    value = np.mean(values) if isinstance(values, list) else values.mean().item()
                    wandb_tracker.log({f"eval_{name}_{key}_{prompt_i}": value}, step=current_global_step)

            # Prompt-averaged metrics are logged once, after all per-prompt values.
            if logs:
                for key in logs[0]:
                    avg = np.array([log[key] for log in logs]).mean()
                    wandb_tracker.log({f"eval_{name}_{key}": avg}, step=current_global_step)

            imgs_dict = {}
            for prompt_i, log_imgs_i in itertools.zip_longest(prompts, log_imgs):
                for key, values in log_imgs_i.items():
                    imgs_dict.setdefault(key, []).append(wandb.Image(
                        data_or_path=values[0],
                        caption=prompt_i,
                    ))
            for key, imgs in imgs_dict.items():
                wandb_tracker.log({f"eval_{name}_{key}": imgs}, step=current_global_step)

        return logs, log_imgs
    
    def patch_face_reduce_grad(images, face_bboxs, face_bboxs_ori, targets, preds_gender_ori, probs_gender_ori, factor=0.1, confidence_level=0.9):
        """Rescale gradients flowing back into the face region of each image; forward values are unchanged.

        Images whose ``face_bbox`` is all -1 pass through untouched. For the others, the
        region shared by ``face_bbox`` and ``face_bbox_ori`` (clipped to the image) is
        routed through a gradient hook. The hook multiplies gradients by ``factor`` when
        the image has no target (``target == -1``) or its target differs from
        ``pred_gender_ori``, and leaves them unscaled otherwise.

        If ``face_bbox_ori`` is all -1 (no reference face), only ``face_bbox`` is used, so
        the -1 sentinel never becomes a slice bound. A negative bound would wrap around
        to the far edge of the image.

        Args:
            images: tensor [N, C, H, W].
            face_bboxs, face_bboxs_ori: per-image boxes; indices 0/2 slice the width axis
                and 1/3 the height axis; all -1 when no face.
            targets, preds_gender_ori: per-image class indices (-1 for none).
            probs_gender_ori: accepted for interface compatibility; not used.
            factor: gradient multiplier for the reduced case.
            confidence_level: accepted for interface compatibility. Matching predictions
                are left unscaled regardless of confidence, so it does not affect the result.

        Returns:
            Tensor [N, C, H, W] with the same values as ``images``.
        """
        images_new = []
        for image, face_bbox, face_bbox_ori, target, pred_gender_ori in itertools.zip_longest(images, face_bboxs, face_bboxs_ori, targets, preds_gender_ori):
            if (face_bbox == -1).all():
                images_new.append(image.unsqueeze(dim=0))
                continue

            img_height, img_width = image.shape[1:]
            if (face_bbox_ori == -1).all():
                face_bbox_ori = face_bbox
            x0 = max(face_bbox[0], face_bbox_ori[0], 0)
            x1 = min(face_bbox[2], face_bbox_ori[2], img_width)
            y0 = max(face_bbox[1], face_bbox_ori[1], 0)
            y1 = min(face_bbox[3], face_bbox_ori[3], img_height)

            keep_full_grad = target != -1 and target == pred_gender_ori
            img_face = image[:, y0:y1, x0:x1].clone()
            img_face.register_hook(make_grad_hook(1 if keep_full_grad else factor))

            img_add = torch.zeros_like(image)
            img_add[:, y0:y1, x0:x1] = img_face

            mask = torch.zeros_like(image)
            mask[:, y0:y1, x0:x1] = 1

            # Same values as ``image``; gradients inside the box flow only through the hooked copy.
            image = mask * img_add + (1 - mask) * image
            images_new.append(image.unsqueeze(dim=0))

        return torch.cat(images_new)
    
    def gen_dynamic_weights(face_indicators, targets, preds_gender_ori, probs_gender_ori, factor=0.2, confidence_level=0.9):
        """Compute one loss weight per image from its dynamic target.

        The weight is:
          * 1 when no face was detected;
          * ``factor`` when the image has no target (``target == -1``) or its target
            differs from ``pred_gender_ori``;
          * 1 otherwise (the target agrees with the reference prediction).
        Every image gets exactly one weight, even when its probabilities are NaN.

        Args:
            face_indicators: per-image detection flags (bool tensor [N]).
            targets, preds_gender_ori: per-image class indices (-1 for none).
            probs_gender_ori: only supplies the dtype/device of the result.
            factor: weight for images whose target is missing or disagrees.
            confidence_level: accepted for interface compatibility. Agreeing predictions
                get weight 1 regardless of confidence, so it does not affect the result.

        Returns:
            Tensor [N] with the dtype and device of ``probs_gender_ori``.
        """
        weights = []
        for face_indicator, target, pred_gender_ori in itertools.zip_longest(face_indicators, targets, preds_gender_ori):
            if not torch.as_tensor(face_indicator).any():
                weights.append(1)
            elif target == -1 or target != pred_gender_ori:
                weights.append(factor)
            else:
                weights.append(1)

        return torch.tensor(weights, dtype=probs_gender_ori.dtype, device=probs_gender_ori.device)

    def model_sanity_print(model, state):
        """Print the first entry of the first two parameters of ``model`` and of their gradients.

        Useful for checking that processes hold identical weights and gradients. A
        parameter whose ``.grad`` is None is printed as ``None`` instead of raising.

        Args:
            model: module with at least two parameters.
            state: free-form label included in the printed line.
        """
        def first_value(tensor):
            return "None" if tensor is None else f"{tensor.flatten()[0].item():.12f}"

        # Only the first two parameters are needed; avoid materializing the full list.
        param0, param1 = itertools.islice(model.parameters(), 2)
        print(
            f"\t{accelerator.device}; {state};\n"
            f"\t\tparam[0]: {first_value(param0)};\tparam[0].grad: {first_value(param0.grad)};\n"
            f"\t\tparam[1]: {first_value(param1)};\tparam[1].grad: {first_value(param1.grad)}"
        )
        

    # Scheduler and math around the number of training steps.
    # If --max_train_steps is not given, it is derived from the epoch count; the flag
    # remembers that so the value can be recomputed after accelerator.prepare below.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # Warmup/total steps are scaled by num_processes, presumably because accelerate's
    # prepared scheduler advances once per process on every optimizer step -- confirm
    # against the accelerate version in use.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )
    
    # print(f"{accelerator.device}, text_encoder, {list(text_encoder_lora_model.parameters())[0].flatten()[1]:.6f}, {text_encoder_lora_ema.shadow_params[0].flatten()[1]:.6f}")
    # print(f"{accelerator.device}, unet, {list(unet_lora_layers.parameters())[0].flatten()[1]:.6f}, {unet_lora_ema.shadow_params[0].flatten()[1]:.6f}")
    
    optimizer, lr_scheduler = accelerator.prepare(
            optimizer, lr_scheduler
        )
    # accelerator.register_for_checkpointing(face_feats_model)
    
    # Wrap the trainable LoRA modules and their EMA copies. Registering the EMA objects
    # makes accelerator.save_state/load_state include them in checkpoints.
    if args.train_text_encoder:
        text_encoder_lora_model, text_encoder_lora_ema = accelerator.prepare(text_encoder_lora_model, text_encoder_lora_ema)
        accelerator.register_for_checkpointing(text_encoder_lora_ema)
    if args.train_unet:
        unet_lora_layers, unet_lora_ema = accelerator.prepare(unet_lora_layers, unet_lora_ema)
        accelerator.register_for_checkpointing(unet_lora_ema)
        
    # print(f"{accelerator.device}, text_encoder, {list(text_encoder_lora_model.parameters())[0].flatten()[1]:.6f}, {text_encoder_lora_ema.shadow_params[0].flatten()[1]:.6f}")
    # print(f"{accelerator.device}, unet, {list(unet_lora_layers.parameters())[0].flatten()[1]:.6f}, {unet_lora_ema.shadow_params[0].flatten()[1]:.6f}")
    
    
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    # NOTE(review): train_dataloader is not passed to accelerator.prepare in this section,
    # so this presumably yields the same value as above -- confirm it isn't prepared elsewhere.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    # (resume_step is only defined on the successful-resume path; later uses are
    # guarded by args.resume_from_checkpoint, which is reset to None on failure).
    if args.resume_from_checkpoint:

        if not os.path.exists(args.resume_from_checkpoint):
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            # Assumes the checkpoint directory basename looks like "<prefix>-<global_step>".
            global_step = int(os.path.basename(args.resume_from_checkpoint).split("-")[1])

            # resume_global_step / resume_step count dataloader steps (micro-batches),
            # i.e. optimizer steps multiplied by gradient_accumulation_steps.
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
            
            if args.train_text_encoder:
                text_encoder_lora_ema.to(accelerator.device)
                
                # need to recreate text_encoder_lora_ema_dict
                # (maps each LoRA parameter name to its EMA shadow tensor, pairing names and
                # shadow_params by position -- assumes both use the same ordering)
                text_encoder_lora_ema_dict = {}
                for name, shadow_param in itertools.zip_longest(text_encoder_lora_params_name_order, text_encoder_lora_ema.shadow_params):
                    text_encoder_lora_ema_dict[name] = shadow_param
                assert text_encoder_lora_ema_dict.__len__() == text_encoder_lora_dict.__len__(), "length does not match! something wrong happened while converting lora params to a state dict."
            
            if args.train_unet:
                unet_lora_ema.to(accelerator.device)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    # unwrap=True presumably returns the underlying wandb run object rather than accelerate's tracker wrapper.
    wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)

    for epoch in range(first_epoch, args.num_train_epochs):
        for step, dataset_idx in enumerate(train_dataloader_idxs[epoch]):            
            # broadcast batch from main_process to all others
            batch = [train_dataset.__getitem__(dataset_idx)]
            torch.distributed.broadcast_object_list(batch, src=0)
            prompts = batch
            
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue
            
            noises = torch.randn(
                [len(prompts), args.train_images_per_prompt,4,64,64],
                dtype=weight_dtype_high_precision
                ).to(accelerator.device)
            
            # if global_step == 0:
            #     prompts_val = [val_dataset.__getitem__(i) for i in val_dataloader_idxs[0]]
            #     torch.distributed.broadcast_object_list(prompts_val, src=0)
            #     noises_val = torch.randn(
            #     [len(prompts_val), args.val_images_per_prompt,4,64,64],
            #     dtype=weight_dtype_high_precision
            #     ).to(accelerator.device)
            #     evaluate_process_out = evaluate_process(text_encoder, unet, "main", prompts_val, noises_val, global_step)

            #     if accelerator.is_main_process:
            #         # report the same results for EMA as well
            #         logs, log_imgs = evaluate_process_out
            #         name = "EMA"
            #         for prompt_i, logs_i in itertools.zip_longest(prompts_val, logs):
            #             for key, values in logs_i.items():
            #                 if isinstance(values, list):
            #                     wandb_tracker.log({f"eval_{name}_{key}_{prompt_i}": np.mean(values)}, step=global_step)
            #                 else:
            #                     wandb_tracker.log({f"eval_{name}_{key}_{prompt_i}": values.mean().item()}, step=global_step)
                    
            #         for key in list(logs[0].keys()):
            #             avg = np.array([log[key] for log in logs]).mean()
            #             wandb_tracker.log({f"eval_{name}_{key}": avg}, step=global_step)

            #         imgs_dict = {}
            #         for prompt_i, log_imgs_i in itertools.zip_longest(prompts_val, log_imgs):
            #             for key, values in log_imgs_i.items():
            #                 if key not in imgs_dict.keys():
            #                     imgs_dict[key] = [wandb.Image(
            #                         data_or_path=values[0],
            #                         caption=prompt_i,
            #                     )]
            #                 else:
            #                     imgs_dict[key].append(wandb.Image(
            #                         data_or_path=values[0],
            #                         caption=prompt_i,
            #                     ))
            #         for key, imgs in imgs_dict.items():
            #             wandb_tracker.log(
            #                 {f"eval_{name}_{key}": imgs},
            #                 step=global_step
            #                 )

            accelerator.wait_for_everyone()
            optimizer.zero_grad()
            logs = []
            log_imgs = []
            with accelerator.accumulate(text_encoder_lora_model):

                for prompt_i, noises_i in itertools.zip_longest(prompts, noises):

                    # print some info for easy monitor
                    noises_i_all = [noises_i.detach().clone() for i in range(accelerator.state.num_processes)]
                    torch.distributed.all_gather(noises_i_all, noises_i)
                    if accelerator.is_main_process:
                        now = datetime.datetime.now(my_timezone)
                        accelerator.print(
                            f"{now.strftime('%Y/%m/%d - %H:%M:%S')} --- epoch: {epoch}, step: {step}, prompt: {prompt_i}\n" +
                            " ".join([f"\tprocess idx: {idx}; noise: {noises_i_all[idx].flatten()[-1].item():.4f};" for idx in range(len(noises_i_all))])
                            )
                    
                    if accelerator.is_main_process:
                        logs_i = {
                            "loss_CEL": [],
                            "loss_real": [],
                            "loss_CLIP": [],
                            "loss_DINO": [],
                            "loss": [],
                            "gender_gap_0.5": [],
                            "gender_gap_0.3": [],
                            "gender_gap_0.1": [],
                            "gender_gap_abs_0.5": [],
                            "gender_gap_abs_0.3": [],
                            "gender_gap_abs_0.1": [],
                            "gender_mid_0.3": [],
                            "gender_mid_0.1": [],
                        }
                        log_imgs_i = {}

                    num_denoising_steps = random.choices(range(19,24), k=1)
                    torch.distributed.broadcast_object_list(num_denoising_steps, src=0)
                    num_denoising_steps = num_denoising_steps[0]
                    with torch.no_grad():
                        ################################################
                        # step 1: generate all images, check if there are many faces
                        images = []
                        N = math.ceil(noises_i.shape[0] / args.val_GPU_batch_size)
                        for j in range(N):
                            noises_ij = noises_i[args.val_GPU_batch_size*j:args.val_GPU_batch_size*(j+1)]
                            images_ij = generate_image_no_gradient(prompt_i, noises_ij, num_denoising_steps, which_text_encoder=text_encoder, which_unet=unet)
                            images.append(images_ij)
                        images = torch.cat(images)
                        
                        face_indicators, face_bboxs, face_chips, face_landmarks, aligned_face_chips = get_face(images)
                        preds_gender, probs_gender, logits_gender = get_face_gender(face_chips, selector=face_indicators, fill_value=-1)
                        
                        face_feats = torch.ones([aligned_face_chips.shape[0],512], dtype=weight_dtype_high_precision, device=aligned_face_chips.device) * (-1)
                        if sum(face_indicators)>0:
                            face_feats_ = get_face_feats(face_feats_net, aligned_face_chips[face_indicators])
                            face_feats[face_indicators] = face_feats_
                        
                        _, face_real_scores = face_feats_model.semantic_search(face_feats, genders=None, selector=face_indicators, return_similarity=True)
                        
                        accelerator.wait_for_everyone()
                        face_indicators_all, face_indicators_others = customized_all_gather(face_indicators, accelerator, return_tensor_others=True)
                        accelerator.print(f"\tFind faces {face_indicators_all.float().sum().item()}/{face_indicators_all.shape[0]}!")
                        
                        images_all = customized_all_gather(images, accelerator, return_tensor_others=False)
                        face_bboxs_all = customized_all_gather(face_bboxs, accelerator, return_tensor_others=False)
                        preds_gender_all = customized_all_gather(preds_gender, accelerator, return_tensor_others=False)
                        probs_gender_all = customized_all_gather(probs_gender, accelerator, return_tensor_others=False)
                        face_real_scores_all = customized_all_gather(face_real_scores, accelerator, return_tensor_others=False)
                        if accelerator.is_main_process:
                            if step % args.train_plot_every_n_iter == 0:
                                save_to = os.path.join(args.imgs_save_dir, f"train-{global_step}_generated.jpg")
                                plot_in_grid(images_all, save_to, face_indicators=face_indicators_all, face_bboxs=face_bboxs_all, preds_gender=preds_gender_all, pred_class_probs_gender=probs_gender_all.max(dim=-1).values,
                                face_real_scores=face_real_scores_all)

                                log_imgs_i["img_generated"] = [save_to]
                    
                        ################################################
                        # Step 2: if many faces, generate dynamic targets 
                        probs_gender_all, probs_gender_others = customized_all_gather(probs_gender, accelerator, return_tensor_others=True)
                        if accelerator.is_main_process:
                            probs_tmp = probs_gender_all[(probs_gender_all!=-1).all(dim=-1)]
                            gender_gap_05 = (((probs_tmp[:,1]>=0.5)*(probs_tmp[:,1]<=1)).float().mean() - ((probs_tmp[:,1]>=0)*(probs_tmp[:,1]<=0.5)).float().mean()).item()
                            gender_gap_03 = (((probs_tmp[:,1]>=0.7)*(probs_tmp[:,1]<=1)).float().mean() - ((probs_tmp[:,1]>=0)*(probs_tmp[:,1]<=0.3)).float().mean()).item()
                            gender_gap_01 = (((probs_tmp[:,1]>=0.9)*(probs_tmp[:,1]<=1)).float().mean() - ((probs_tmp[:,1]>=0)*(probs_tmp[:,1]<=0.1)).float().mean()).item()
                            gender_mid_03 = ((probs_tmp[:,1]>0.3)*(probs_tmp[:,1]<0.7)).float().mean().item()
                            gender_mid_01 = ((probs_tmp[:,1]>0.1)*(probs_tmp[:,1]<0.9)).float().mean().item()
                            logs_i["gender_gap_0.5"].append(gender_gap_05)
                            logs_i["gender_gap_0.3"].append(gender_gap_03)
                            logs_i["gender_gap_0.1"].append(gender_gap_01)
                            logs_i["gender_gap_abs_0.5"].append(abs(gender_gap_05))
                            logs_i["gender_gap_abs_0.3"].append(abs(gender_gap_03))
                            logs_i["gender_gap_abs_0.1"].append(abs(gender_gap_01))
                            logs_i["gender_mid_0.3"].append(abs(gender_mid_03))
                            logs_i["gender_mid_0.1"].append(abs(gender_mid_01))
                            
                        
                        # broadcast, just in case targets_all computed might be different on different processes
                        targets_all, uncertainty_all = generate_dynamic_targets(probs_gender_all, w_uncertainty=True)
                        torch.distributed.broadcast(targets_all, src=0)
                        torch.distributed.broadcast(uncertainty_all, src=0)

                        targets_all[uncertainty_all>args.skip_uncertainty_threshold] = -1
                        targets = targets_all[probs_gender.shape[0]*(accelerator.local_process_index):probs_gender.shape[0]*(accelerator.local_process_index+1)]
                        uncertainty = uncertainty_all[probs_gender.shape[0]*(accelerator.local_process_index):probs_gender.shape[0]*(accelerator.local_process_index+1)]
                        accelerator.print(f"\tFaces w/ grads: {(targets_all!=-1).sum().item()}/{targets_all.shape[0]}")
                        
                        
                        
                        ################################################
                        # Step 3: generate all original images
                        # note that only targets from above will be used to compute loss
                        # all other variables will not be used below
                        images_ori = []
                        N = math.ceil(noises_i.shape[0] / args.val_GPU_batch_size)
                        for j in range(N):
                            noises_ij = noises_i[args.val_GPU_batch_size*j:args.val_GPU_batch_size*(j+1)]
                            if args.train_text_encoder and args.train_unet:
                                images_ij = generate_image_no_gradient(prompt_i, noises_ij, num_denoising_steps, which_text_encoder=eval_text_encoder, which_unet=eval_unet)
                            elif args.train_text_encoder and not args.train_unet:
                                images_ij = generate_image_no_gradient(prompt_i, noises_ij, num_denoising_steps, which_text_encoder=eval_text_encoder, which_unet=unet)
                            images_ori.append(images_ij)
                        images_ori = torch.cat(images_ori)
                        face_indicators_ori, face_bboxs_ori, face_chips_ori, face_landmarks_ori, aligned_face_chips_ori = get_face(images_ori)
                        preds_gender_ori, probs_gender_ori, logits_gender_ori = get_face_gender(face_chips_ori, selector=face_indicators_ori, fill_value=-1)
                        
                        images_small_ori = transforms.Resize(args.size_small)(images_ori)
                        clip_feats_ori = get_clip_feat(images_small_ori, normalize=True, to_high_precision=True)
                        DINO_feats_ori = get_dino_feat(images_small_ori, normalize=True, to_high_precision=True)

                        images_ori_all = customized_all_gather(images_ori, accelerator, return_tensor_others=False)
                        face_indicators_ori_all = customized_all_gather(face_indicators_ori, accelerator, return_tensor_others=False)
                        face_bboxs_ori_all = customized_all_gather(face_bboxs_ori, accelerator, return_tensor_others=False)
                        preds_gender_ori_all = customized_all_gather(preds_gender_ori, accelerator, return_tensor_others=False)
                        probs_gender_ori_all = customized_all_gather(probs_gender_ori, accelerator, return_tensor_others=False)
                        
                        face_feats_ori = get_face_feats(face_feats_net, aligned_face_chips_ori)
                        # aligned_face_logits_ori = algined_face_gender_classifier(aligned_face_chips_ori)
                        # aligned_face_logits_gender_ori = aligned_face_logits_ori.view([aligned_face_logits_ori.shape[0],-1,2])[:,20,:].clone()
                        # face_feats_ori_all = customized_all_gather(face_feats_ori, accelerator, return_tensor_others=False)
                        # aligned_face_logits_gender_ori_all = customized_all_gather(aligned_face_logits_gender_ori, accelerator, return_tensor_others=False)
                        
                        # if sum(face_indicators_ori_all)>0:
                        #     face_feats_model.add_face_feats(face_feats_ori_all[face_indicators_ori_all], aligned_face_logits_gender_ori_all[face_indicators_ori_all])
                        
                        if accelerator.is_main_process:
                            if step % args.train_plot_every_n_iter == 0:
                                save_to = os.path.join(args.imgs_save_dir, f"train-{global_step}_ori.jpg")
                                plot_in_grid(images_ori_all, save_to, face_indicators=face_indicators_ori_all, face_bboxs=face_bboxs_ori_all, preds_gender=preds_gender_ori_all, pred_class_probs_gender=probs_gender_ori_all.max(dim=-1).values)

                                log_imgs_i["img_ori"] = [save_to]
                    
                    ################################################
                    # Step 4: compute loss
                    loss_CEL_i = torch.ones(targets.shape, dtype=weight_dtype, device=accelerator.device) *(-1)
                    loss_real_i = torch.ones(targets.shape, dtype=weight_dtype, device=accelerator.device) *(-1)
                    loss_CLIP_i = torch.ones(targets.shape, dtype=weight_dtype, device=accelerator.device) *(-1)
                    loss_DINO_i = torch.ones(targets.shape, dtype=weight_dtype, device=accelerator.device) *(-1)
                    loss_i = torch.ones(targets.shape, dtype=weight_dtype, device=accelerator.device) *(-1)
                    
                    idxs_i = list(range(targets.shape[0]))
                    N = math.ceil(targets.shape[0] / args.train_GPU_batch_size)
                    for j in range(N):
                        idxs_ij = idxs_i[j*args.train_GPU_batch_size:(j+1)*args.train_GPU_batch_size]
                        noises_ij = noises_i[idxs_ij]
                        targets_ij = targets[idxs_ij]
                        clip_feats_ori_ij = clip_feats_ori[idxs_ij]
                        DINO_feats_ori_ij = DINO_feats_ori[idxs_ij]
                        preds_gender_ori_ij = preds_gender_ori[idxs_ij]
                        probs_gender_ori_ij = probs_gender_ori[idxs_ij]
                        face_bboxs_ori_ij = face_bboxs_ori[idxs_ij]
                        face_feats_ori_ij = face_feats_ori[idxs_ij]
                        
                        images_ij = generate_image_w_gradient(prompt_i, noises_ij, num_denoising_steps, which_text_encoder=text_encoder, which_unet=unet)
                        face_indicators_ij, face_bboxs_ij, face_chips_ij, face_landmarks_ij, aligned_face_chips_ij = get_face(images_ij)
                        preds_gender_ij, probs_gender_ij, logits_gender_ij = get_face_gender(face_chips_ij, selector=face_indicators_ij, fill_value=-1)
                        
                        images_ij = patch_face_reduce_grad(images_ij, face_bboxs_ij, face_bboxs_ori_ij, targets_ij, preds_gender_ori_ij, probs_gender_ori_ij, factor=args.factor2, confidence_level=0.9)
                        images_small_ij = transforms.Resize(args.size_small)(images_ij)
                        clip_feats_ij = get_clip_feat(images_small_ij, normalize=True, to_high_precision=True)
                        DINO_feats_ij = get_dino_feat(images_small_ij, normalize=True, to_high_precision=True)
                        
                        loss_CLIP_ij = - (clip_feats_ij * clip_feats_ori_ij).sum(dim=-1) + 1
                        loss_DINO_ij = - (DINO_feats_ij * DINO_feats_ori_ij).sum(dim=-1) + 1
                        
                        loss_CEL_ij = torch.ones(len(idxs_ij), dtype=weight_dtype, device=accelerator.device) *(-1)
                        idxs_w_face_loss = ((face_indicators_ij == True) * (targets_ij != -1)).nonzero().view([-1])
                        loss_CEL_ij_w_face_loss = CELoss(logits_gender_ij[idxs_w_face_loss], targets_ij[idxs_w_face_loss])
                        loss_CEL_ij[idxs_w_face_loss] = loss_CEL_ij_w_face_loss
                        
                        loss_real_ij = torch.ones(len(idxs_ij), dtype=weight_dtype, device=accelerator.device) *(-1)
                        
                        idxs_w_face_feats_from_ori = ((face_indicators_ij==True) * (targets_ij!=-1) * (targets_ij==preds_gender_ori_ij) * (probs_gender_ori_ij.max(dim=-1).values>=args.face_gender_confidence_level)).nonzero().view([-1]).tolist()
                        if len(idxs_w_face_feats_from_ori)>0:
                            face_feats_1 = get_face_feats(face_feats_net, aligned_face_chips_ij[idxs_w_face_feats_from_ori])
                            face_feats_target_1 = face_feats_ori_ij[idxs_w_face_feats_from_ori]
                            loss_real_ij[idxs_w_face_feats_from_ori] = (1 - (face_feats_1*face_feats_target_1).sum(dim=-1)).to(loss_real_ij.dtype)
                        
                        idxs_w_face_feats_from_search = list(set(((face_indicators_ij==True) * (targets_ij!=-1) ).nonzero().view([-1]).tolist()) - set(idxs_w_face_feats_from_ori))
                        if len(idxs_w_face_feats_from_search)>0:
                            face_feats_2 = get_face_feats(face_feats_net, aligned_face_chips_ij[idxs_w_face_feats_from_search])
                            face_feats_target_2 = face_feats_model.semantic_search(face_feats_2, targets_ij[idxs_w_face_feats_from_search])
                            loss_real_ij[idxs_w_face_feats_from_search] = (1 - (face_feats_2*face_feats_target_2).sum(dim=-1)).to(loss_real_ij.dtype)
                        
                        dynamic_weights = gen_dynamic_weights(face_indicators_ij, targets_ij, preds_gender_ori_ij, probs_gender_ori_ij, factor=args.factor1, confidence_level=0.9)
                        loss_ij = loss_CEL_ij + args.recon_loss_weight * dynamic_weights * (loss_CLIP_ij + loss_DINO_ij) + args.real_loss_weight * loss_real_ij
                        accelerator.backward(loss_ij.mean())
                        
                        with torch.no_grad():
                            loss_CEL_i[idxs_ij] = loss_CEL_ij.to(loss_CEL_i.dtype)
                            loss_real_i[idxs_ij] = loss_real_ij.to(loss_CEL_i.dtype)
                            loss_CLIP_i[idxs_ij] = loss_CLIP_ij.to(loss_CLIP_i.dtype)
                            loss_DINO_i[idxs_ij] = loss_DINO_ij.to(loss_DINO_i.dtype)
                            loss_i[idxs_ij] = loss_ij.to(loss_i.dtype)
                            
                    # for logging purpose, gather all losses to main_process
                    accelerator.wait_for_everyone()
                    loss_CEL_all = customized_all_gather(loss_CEL_i, accelerator)
                    loss_real_all = customized_all_gather(loss_real_i, accelerator)
                    loss_CLIP_all = customized_all_gather(loss_CLIP_i, accelerator)
                    loss_DINO_all = customized_all_gather(loss_DINO_i, accelerator)
                    loss_all = customized_all_gather(loss_i, accelerator)
                    
                    loss_CLIP_all = loss_CLIP_all
                    loss_DINO_all = loss_DINO_all
                    loss_all = loss_all[loss_CEL_all!=-1]
                    loss_CEL_all = loss_CEL_all[loss_CEL_all!=-1]
                    loss_real_all = loss_real_all[loss_real_all!=-1]
                    accelerator.print(f"\tloss_all.shape: {loss_all.shape}")

                    if accelerator.is_main_process:
                        logs_i["loss_CLIP"].append(loss_CLIP_all)
                        logs_i["loss_DINO"].append(loss_DINO_all)
                        logs_i["loss"].append(loss_all)
                        logs_i["loss_CEL"].append(loss_CEL_all)
                        logs_i["loss_real"].append(loss_real_all)
                
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    for key in ["loss_CEL", "loss_real", "loss_CLIP", "loss_DINO", "loss"]:
                        if logs_i[key] == []:
                            logs_i.pop(key)
                        else:
                            logs_i[key] = torch.cat(logs_i[key])

                    for key in ["gender_gap_0.5", "gender_gap_0.3", "gender_gap_0.1", "gender_gap_abs_0.5", "gender_gap_abs_0.3", "gender_gap_abs_0.1", "gender_mid_0.3", "gender_mid_0.1"]:
                        if logs_i[key] == []:
                            logs_i.pop(key)
                    
                    logs.append(logs_i)

                    log_imgs.append(log_imgs_i)

                ##########################################################################
                # log process for training
                if accelerator.is_main_process:
                    # current prompts is only one prompt
                    for prompt_i, logs_i in itertools.zip_longest(prompts, logs):
                        for key, values in logs_i.items():
                            if isinstance(values, list):
                                wandb_tracker.log({f"train_{key}": np.mean(values)}, step=global_step)
                            else:
                                wandb_tracker.log({f"train_{key}": values.mean().item()}, step=global_step)

                    for prompt_i, log_imgs_i in itertools.zip_longest(prompts, log_imgs):
                        for key, values in log_imgs_i.items():
                            wandb_tracker.log({f"train_{key}":wandb.Image(
                                    data_or_path=values[0],
                                    caption=prompt_i,
                                )
                                },
                                step=global_step
                                )
                
                if args.train_text_encoder:
                    model_sanity_print(text_encoder_lora_model, "text_encoder: after accelerator.backward()")
                if args.train_unet:
                    model_sanity_print(unet_lora_layers, "unet: after accelerator.backward()")
                
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(text_encoder_lora_params, args.max_grad_norm)
                if args.train_text_encoder:
                    model_sanity_print(text_encoder_lora_model, "text_encoder: after accelerator.clip_grad_norm_()")
                if args.train_unet:
                    model_sanity_print(unet_lora_layers, "unet: after accelerator.clip_grad_norm_()")

                # note that up till now grads are not synced
                # we mannually sync grads
                accelerator.wait_for_everyone()
                grad_is_finite = True
                with torch.no_grad():
                    if args.train_text_encoder:
                        for p in text_encoder_lora_model.parameters():
                            if not torch.isfinite(p.grad).all():
                                grad_is_finite = False
                            torch.distributed.all_reduce(p.grad, torch.distributed.ReduceOp.SUM)
                            p.grad = p.grad / accelerator.num_processes
                    if args.train_unet:
                        for p in unet_lora_layers.parameters():
                            if not torch.isfinite(p.grad).all():
                                grad_is_finite = False
                            torch.distributed.all_reduce(p.grad, torch.distributed.ReduceOp.SUM)
                            p.grad = p.grad / accelerator.num_processes
                    
                if args.train_text_encoder:
                    model_sanity_print(text_encoder_lora_model, "text_encoder: after gradients allreduce & average")
                if args.train_unet:
                    model_sanity_print(unet_lora_layers, "unet: after gradients allreduce & average")

                if grad_is_finite:
                    optimizer.step()
                    if args.train_text_encoder:
                        model_sanity_print(text_encoder_lora_model, "text_encoder: after optimizer.step()")
                    if args.train_unet:
                        model_sanity_print(unet_lora_layers, "unet: after optimizer.step()")
                else:
                    accelerator.print(f"grads are not finite, skipped!")
                
                lr_scheduler.step()
            
            if grad_is_finite:
                if accelerator.sync_gradients:
                    if args.train_text_encoder:
                        text_encoder_lora_ema.step(  text_encoder_lora_params )
                    if args.train_unet:
                        unet_lora_ema.step(  unet_lora_layers.parameters() )

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                with torch.no_grad():
                    if args.train_text_encoder:
                        param_norm = np.mean([p.norm().item() for p in text_encoder_lora_params])
                        param_ema_norm = np.mean([p.norm().item() for p in text_encoder_lora_ema.shadow_params])
                        wandb_tracker.log({f"train_text_encoder_lora_norm_avg": param_norm}, step=global_step)
                        wandb_tracker.log({f"train_text_encoder_lora_ema_norm_avg": param_ema_norm}, step=global_step)
                    
                    if args.train_unet:
                        param_norm = np.mean([p.norm().item() for p in unet_lora_layers.parameters()])
                        param_ema_norm = np.mean([p.norm().item() for p in unet_lora_ema.shadow_params])
                        wandb_tracker.log({f"train_unet_lora_norm_avg": param_norm}, step=global_step)
                        wandb_tracker.log({f"train_unet_lora_ema_norm_avg": param_ema_norm}, step=global_step)
            
            accelerator.wait_for_everyone()
            if global_step % args.evaluate_every_n_iter == 0:
            # if True:
                prompts_val = [val_dataset.__getitem__(i) for i in val_dataloader_idxs[0]]
                torch.distributed.broadcast_object_list(prompts_val, src=0)
                noises_val = torch.randn(
                [len(prompts_val), args.val_images_per_prompt,4,64,64],
                dtype=weight_dtype_high_precision
                ).to(accelerator.device)
                evaluate_process(text_encoder, unet, "main", prompts_val, noises_val, global_step)

                # evaluate EMA as well
                if args.train_text_encoder:
                    text_encoder_lora_dict_copy = copy.deepcopy(text_encoder_lora_dict)
                    load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_ema_dict, strict=False)
                
                if args.train_unet:
                    with torch.no_grad():
                        unet_lora_layers_copy = copy.deepcopy(unet_lora_layers)
                        for p, p_from in itertools.zip_longest(list(unet_lora_layers.parameters()), unet_lora_ema.shadow_params):
                            p.data = p_from.data
                    
                evaluate_process(text_encoder, unet, "EMA", prompts_val, noises_val, global_step)
                
                if args.train_text_encoder:
                    load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_dict_copy, strict=False)
                
                if args.train_unet:
                    with torch.no_grad():
                        for p, p_from in itertools.zip_longest(list(unet_lora_layers.parameters()), list(unet_lora_layers_copy.parameters())):
                            p.data = p_from.data

            if accelerator.is_main_process:
                if global_step % args.checkpointing_steps == 0:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if args.checkpoints_total_limit is not None:
                        name = "checkpoint_tmp"
                        clean_checkpoint(args.ckpts_save_dir, name, args.checkpoints_total_limit)

                    save_path = os.path.join(args.ckpts_save_dir, f"checkpoint_tmp-{global_step}")
                    accelerator.save_state(save_path)
                
                    logger.info(f"Accelerator checkpoint saved to {save_path}")

                if global_step % args.checkpointing_steps_long == 0:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`

                    save_path = os.path.join(args.ckpts_save_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                
                    logger.info(f"Accelerator checkpoint saved to {save_path}")

            torch.cuda.empty_cache()
    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: build the run configuration via `parse_args()`
    # (presumably argparse-based, since `argparse` is imported at file top)
    # and hand it to `main`.
    # NOTE(review): `args` is intentionally bound at module scope here rather
    # than passed inline; keep it that way unless you have confirmed no
    # module-level helper in this file reads the global `args`.
    args = parse_args()
    main(args)


