"""
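UViT-based T2I-Adapter implementation that stays as close as possible to the original. Main characteristics:
- the adapter runs as a separate forward pass
- Stable Diffusion basic transformer blocks replace the conv blocks used in the original paper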
"""

import torch
import torch.nn as nn
import math
import os
from typing import List, Tuple, Union
from .timm import trunc_normal_, DropPath, Mlp
import einops
import torch.utils.checkpoint
import torch.nn.functional as F
from libs.preprocess_modules import get_preprocess_module
from torch.nn.parallel import DistributedDataParallel as DDP
from libs.sd_attention import CrossAttention, FeedForward

# Pick the attention kernel used by ``Attention.forward``:
# 'flash'    -> torch.nn.functional.scaled_dot_product_attention (added in PyTorch 2.0)
# 'xformers' -> xformers.ops.memory_efficient_attention, if it can be imported
# 'math'     -> explicit softmax attention fallback
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    except Exception:
        # Best-effort: any failure importing xformers (missing package, broken
        # native extension, ...) falls back to the math path. Unlike a bare
        # ``except:`` this no longer swallows KeyboardInterrupt/SystemExit.
        ATTENTION_MODE = 'math'
print(f't2iadapter attention mode is {ATTENTION_MODE}')


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    :param timesteps: 1-D tensor holding one (possibly fractional) timestep per
                      batch element.
    :param dim: width of each embedding.
    :param max_period: controls the lowest frequency of the sinusoids.
    :return: float32 tensor of shape [N, dim] laid out as [cos | sin]; when
             ``dim`` is odd a zero column is appended at the end.
    """
    half = dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=timesteps.device)
    # Outer product: one row of phases per timestep.
    phases = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat((phases.cos(), phases.sin()), dim=-1)
    if dim % 2:
        # Pad odd widths with a zero column.
        emb = torch.cat((emb, torch.zeros_like(emb[:, :1])), dim=-1)
    return emb


def patchify(imgs, patch_size):
    """
    Split a (B, C, H, W) image batch into flattened, non-overlapping patches.

    :param imgs: image tensor whose H and W are multiples of ``patch_size``.
    :param patch_size: side length of each square patch.
    :return: tensor of shape (B, num_patches, patch_size * patch_size * C); each
             patch vector is ordered (row in patch, column in patch, channel).
    """
    pattern = 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)'
    return einops.rearrange(imgs, pattern, p1=patch_size, p2=patch_size)


def unpatchify(x, in_chans):
    """
    Inverse of ``patchify`` for a square patch grid made of square patches.

    :param x: tensor of shape (B, L, patch_size**2 * in_chans), where L is the
              number of patches and must be a perfect square.
    :param in_chans: number of image channels C.
    :return: tensor of shape (B, C, h * patch_size, w * patch_size) with h = w = sqrt(L).
    :raises AssertionError: if L is not a perfect square or the last dimension
                            is not ``patch_size**2 * in_chans``.
    """
    # Exact integer square roots: ``int(n ** 0.5)`` goes through float pow and
    # can truncate a perfect square to n - 1 if pow rounds slightly low.
    patch_size = math.isqrt(x.shape[2] // in_chans)
    h = w = math.isqrt(x.shape[1])
    assert h * w == x.shape[1] and patch_size ** 2 * in_chans == x.shape[2], (
        f"cannot unpatchify tensor of shape {tuple(x.shape)} with in_chans={in_chans}: "
        "expected a square number of square patches"
    )
    return einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)


def interpolate_pos_emb(pos_emb, old_shape, new_shape):
    """
    Resize a flattened 2-D positional embedding onto a new grid.

    :param pos_emb: tensor of shape (B, H_old * W_old, C).
    :param old_shape: (H_old, W_old) grid the embedding is currently laid out on.
    :param new_shape: target size passed to ``F.interpolate``.
    :return: tensor of shape (B, H_new * W_new, C), resized with bilinear interpolation.
    """
    old_h, old_w = old_shape[0], old_shape[1]
    grid = einops.rearrange(pos_emb, 'B (H W) C -> B C H W', H=old_h, W=old_w)
    resized = F.interpolate(grid, new_shape, mode='bilinear')
    return einops.rearrange(resized, 'B C H W -> B (H W) C')


    
        
class Attention(nn.Module):
    """
    Multi-head self-attention over a (B, L, C) token sequence.

    The kernel is selected by the module-level ``ATTENTION_MODE``:
    'flash' uses ``torch.nn.functional.scaled_dot_product_attention``,
    'xformers' uses ``xformers.ops.memory_efficient_attention`` and 'math'
    computes softmax attention explicitly with CUDA autocast disabled.

    NOTE(review): only the 'math' path applies ``qk_scale`` (``self.scale``) and
    ``attn_drop``; the fused paths use their default 1/sqrt(head_dim) scale and
    no attention dropout.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.dim = dim

        # Single projection producing q, k and v, split in forward().
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        :param x: tensor of shape (B, L, C) with C == dim.
        :return: tensor of shape (B, L, C).
        :raises NotImplementedError: if ``ATTENTION_MODE`` is not one of
            'flash', 'xformers' or 'math'.
        """
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            with torch.amp.autocast(device_type='cuda', enabled=False):
                qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
                q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
                attn = (q @ k.transpose(-2, -1)) * self.scale
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # ``raise NotImplemented`` raised a TypeError (NotImplemented is not
            # an exception); raise the proper exception with a useful message.
            raise NotImplementedError(f"unsupported ATTENTION_MODE: {ATTENTION_MODE!r}")

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """
    U-ViT transformer block with post-normalisation.

    When built with ``skip=True`` the block takes a second input ``skip`` with the
    same shape as ``x``. The two are concatenated on the last axis, projected back
    to ``dim`` by ``skip_linear`` and normalised by ``norm1``. Next come an
    attention residual branch followed by ``norm2`` and an MLP residual branch
    followed by ``norm3``. ``use_checkpoint`` runs the block through
    ``torch.utils.checkpoint``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        # Keep this construction order: it fixes both parameter-initialisation
        # order and state_dict key order.
        self.norm1 = norm_layer(dim) if skip else None
        self.norm2 = norm_layer(dim)

        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm3 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_features, act_layer=act_layer, drop=drop)
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if not self.use_checkpoint:
            return self._forward(x, skip)
        return torch.utils.checkpoint.checkpoint(self._forward, x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            fused = self.skip_linear(torch.cat([x, skip], dim=-1))
            x = self.norm1(fused)
        # Post-norm: residual add first, then normalise.
        x = self.norm2(x + self.drop_path(self.attn(x)))
        x = self.norm3(x + self.drop_path(self.mlp(x)))
        return x


class PatchEmbed(nn.Module):
    """
    Image-to-patch embedding.

    A convolution with kernel size and stride ``patch_size`` projects every
    non-overlapping patch to ``embed_dim`` channels. The result is returned as a
    token sequence of shape (B, num_patches, embed_dim).
    """
    def __init__(self, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        _, _, height, width = x.shape
        assert height % self.patch_size == 0 and width % self.patch_size == 0
        feature_map = self.proj(x)  # B, embed_dim, H / p, W / p
        return feature_map.flatten(2).transpose(1, 2)



class AdpManagedModule1(nn.Module):
    """
    Out-out adapter wrapper that uses every output token as the adapter query.

    Runs ``origin_module`` (whose forward must return ``(x_out, attn_map)``), then
    feeds the whole ``x_out`` sequence to each adapter whose ``multiplier`` is
    non-zero. Only the text and image segments of each adapter's output are added
    back onto ``x_out``; all other token segments pass through unchanged.

    Token layout along dim 1, as implied by the split sizes:
    t_img (1), t_text (1), token_embed (1), text (num_text_tokens),
    clip_img (1), img (num_img_tokens).

    Adapters are expected to provide ``multiplier``, ``get_context()`` and
    ``forward(x, context=...)``.
    """
    def __init__(self, origin_module: nn.Module, name, adps: List[nn.Module], *,
                 num_text_tokens: int = 77, num_img_tokens: int = 1024):
        """
        :param origin_module: the module being wrapped.
        :param name: dotted name of the wrapped module (stored for reference).
        :param adps: adapter blocks to apply.
        :param num_text_tokens: length of the text segment (previously hard-coded 77).
        :param num_img_tokens: length of the image segment (previously hard-coded
            1024); must match the image token count of the main network.
        """
        super().__init__()
        self.origin_module = origin_module
        self.name = name
        # Plain list, not nn.ModuleList: the adapters are not registered as
        # submodules of this wrapper, so their parameters stay out of it.
        self.adps = adps
        self.num_text_tokens = num_text_tokens
        self.num_img_tokens = num_img_tokens

    def _split_tokens(self, x):
        """Split a token sequence into its six segments along dim 1."""
        return x.split((1, 1, 1, self.num_text_tokens, 1, self.num_img_tokens), dim=1)

    def forward(self, x, **kwargs):
        x_out, attn_map = self.origin_module.forward(x, **kwargs)
        t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out = self._split_tokens(x_out)
        for adp in self.adps:
            if adp.multiplier != 0.:
                # Every adapter sees the unmodified origin output.
                x_adp = adp.forward(x_out, context=adp.get_context())
                _, _, _, adp_text_out, _, adp_img_out = self._split_tokens(x_adp)
                text_out = text_out + adp_text_out
                img_out = img_out + adp_img_out

        x_out = torch.cat((t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out), dim=1)
        return x_out, attn_map

class AdpManagedModule2(nn.Module):
    """
    Out-out adapter wrapper that uses the image tokens as the adapter query.

    Runs ``origin_module`` (whose forward must return ``(x_out, attn_map)``), then
    passes the image segment of ``x_out`` through each adapter whose
    ``multiplier`` is non-zero and adds the result onto that segment. Adapters are
    chained: each one sees the image tokens already updated by the previous ones.
    All other token segments pass through unchanged.

    Token layout along dim 1, as implied by the split sizes:
    t_img (1), t_text (1), token_embed (1), text (num_text_tokens),
    clip_img (1), img (num_img_tokens).
    """
    def __init__(self, origin_module: nn.Module, name, adps: List[nn.Module], *,
                 num_text_tokens: int = 77, num_img_tokens: int = 1024):
        """
        :param origin_module: the module being wrapped.
        :param name: dotted name of the wrapped module (stored for reference).
        :param adps: adapter blocks to apply.
        :param num_text_tokens: length of the text segment (previously hard-coded 77).
        :param num_img_tokens: length of the image segment (previously hard-coded 1024).
        """
        super().__init__()
        self.origin_module = origin_module
        self.name = name
        # Plain list, not nn.ModuleList: adapters are not registered as submodules here.
        self.adps = adps
        self.num_text_tokens = num_text_tokens
        self.num_img_tokens = num_img_tokens

    def _split_tokens(self, x):
        """Split a token sequence into its six segments along dim 1."""
        return x.split((1, 1, 1, self.num_text_tokens, 1, self.num_img_tokens), dim=1)

    def forward(self, x, **kwargs):
        x_out, attn_map = self.origin_module.forward(x, **kwargs)
        t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out = self._split_tokens(x_out)
        for adp in self.adps:
            if adp.multiplier != 0.:
                img_adp = adp.forward(img_out, context=adp.get_context())
                img_out = img_out + img_adp

        x_out = torch.cat((t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out), dim=1)
        return x_out, attn_map
    
class AdpManagedModule3(nn.Module):
    """
    In-out adapter wrapper that uses the input image tokens as the adapter query.

    Before running ``origin_module``, the image segment of the input ``x`` is fed
    to every adapter whose ``multiplier`` is non-zero. Their outputs are then added
    onto the image segment of the origin module's output. All other token segments
    pass through unchanged.

    Token layout along dim 1, as implied by the split sizes:
    t_img (1), t_text (1), token_embed (1), text (num_text_tokens),
    clip_img (1), img (num_img_tokens).
    """
    def __init__(self, origin_module: nn.Module, name, adps: List[nn.Module], *,
                 num_text_tokens: int = 77, num_img_tokens: int = 1024):
        """
        :param origin_module: the module being wrapped; its forward must return
            ``(x_out, attn_map)``.
        :param name: dotted name of the wrapped module (stored for reference).
        :param adps: adapter blocks to apply.
        :param num_text_tokens: length of the text segment (previously hard-coded 77).
        :param num_img_tokens: length of the image segment (previously hard-coded 1024).
        """
        super().__init__()
        self.origin_module = origin_module
        self.name = name
        # Plain list, not nn.ModuleList: adapters are not registered as submodules here.
        self.adps = adps
        self.num_text_tokens = num_text_tokens
        self.num_img_tokens = num_img_tokens

    def _split_tokens(self, x):
        """Split a token sequence into its six segments along dim 1."""
        return x.split((1, 1, 1, self.num_text_tokens, 1, self.num_img_tokens), dim=1)

    def forward(self, x, **kwargs):
        img_in = self._split_tokens(x)[5]
        # Adapters run on the module *input*, before the wrapped module.
        img_adp_ls = [
            adp.forward(img_in, context=adp.get_context())
            for adp in self.adps
            if adp.multiplier != 0.
        ]

        x_out, attn_map = self.origin_module.forward(x, **kwargs)
        t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out = self._split_tokens(x_out)

        for img_adp in img_adp_ls:
            img_out = img_out + img_adp

        x_out = torch.cat((t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out), dim=1)
        return x_out, attn_map
    
class AdpManagedModule4(nn.Module):
    """
    In-out adapter wrapper that uses the input text tokens as the adapter query.

    Before running ``origin_module``, the text segment of the input ``x`` is fed
    to every adapter whose ``multiplier`` is non-zero. Their outputs are then added
    onto the text segment of the origin module's output. All other token segments
    pass through unchanged.

    Token layout along dim 1, as implied by the split sizes:
    t_img (1), t_text (1), token_embed (1), text (num_text_tokens),
    clip_img (1), img (num_img_tokens).
    """
    def __init__(self, origin_module: nn.Module, name, adps: List[nn.Module], *,
                 num_text_tokens: int = 77, num_img_tokens: int = 1024):
        """
        :param origin_module: the module being wrapped; its forward must return
            ``(x_out, attn_map)``.
        :param name: dotted name of the wrapped module (stored for reference).
        :param adps: adapter blocks to apply.
        :param num_text_tokens: length of the text segment (previously hard-coded 77).
        :param num_img_tokens: length of the image segment (previously hard-coded 1024).
        """
        super().__init__()
        self.origin_module = origin_module
        self.name = name
        # Plain list, not nn.ModuleList: adapters are not registered as submodules here.
        self.adps = adps
        self.num_text_tokens = num_text_tokens
        self.num_img_tokens = num_img_tokens

    def _split_tokens(self, x):
        """Split a token sequence into its six segments along dim 1."""
        return x.split((1, 1, 1, self.num_text_tokens, 1, self.num_img_tokens), dim=1)

    def forward(self, x, **kwargs):
        text_in = self._split_tokens(x)[3]
        # Adapters run on the module *input*, before the wrapped module.
        text_adp_ls = [
            adp.forward(text_in, context=adp.get_context())
            for adp in self.adps
            if adp.multiplier != 0.
        ]

        x_out, attn_map = self.origin_module.forward(x, **kwargs)
        t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out = self._split_tokens(x_out)

        for text_adp in text_adp_ls:
            text_out = text_out + text_adp

        x_out = torch.cat((t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out), dim=1)
        return x_out, attn_map

class AdpManagedModule21(nn.Module):
    """
    Adapter wrapper that queries with the adapter's own context.

    After running ``origin_module`` (whose forward must return
    ``(x_out, attn_map)``), each adapter whose ``multiplier`` is non-zero is run
    on its context (``adp.get_context()``) as both query and context. The result
    is added onto the image segment of ``x_out``. The adapter output must
    therefore have the image segment's shape; the original note says the detect
    image must be the same size as the image (512*512).

    Token layout along dim 1, as implied by the split sizes:
    t_img (1), t_text (1), token_embed (1), text (num_text_tokens),
    clip_img (1), img (num_img_tokens).
    """
    def __init__(self, origin_module: nn.Module, name, adps: List[nn.Module], *,
                 num_text_tokens: int = 77, num_img_tokens: int = 1024):
        """
        :param origin_module: the module being wrapped.
        :param name: dotted name of the wrapped module (stored for reference).
        :param adps: adapter blocks to apply.
        :param num_text_tokens: length of the text segment (previously hard-coded 77).
        :param num_img_tokens: length of the image segment (previously hard-coded 1024).
        """
        super().__init__()
        self.origin_module = origin_module
        self.name = name
        # Plain list, not nn.ModuleList: adapters are not registered as submodules here.
        self.adps = adps
        self.num_text_tokens = num_text_tokens
        self.num_img_tokens = num_img_tokens

    def _split_tokens(self, x):
        """Split a token sequence into its six segments along dim 1."""
        return x.split((1, 1, 1, self.num_text_tokens, 1, self.num_img_tokens), dim=1)

    def forward(self, x, **kwargs):
        x_out, attn_map = self.origin_module.forward(x, **kwargs)
        t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out = self._split_tokens(x_out)
        for adp in self.adps:
            if adp.multiplier != 0.:
                context = adp.get_context()
                adp_img_out = adp.forward(context, context=context)
                img_out = img_out + adp_img_out

        x_out = torch.cat((t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out), dim=1)
        return x_out, attn_map
    
def replace_module_by_name(origin_module, name:str, replace_module):
    """
    Replace the submodule at dotted path ``name`` inside ``origin_module``.

    Python has no reference to an attribute slot, so the function walks down to
    the parent of the target (every path component except the last) and rebinds
    the final component on that parent with ``setattr``. Purely numeric
    intermediate components are used as integer indices (e.g. into an
    ``nn.ModuleList`` / ``nn.Sequential``); all others go through ``getattr``.
    """
    *parent_path, leaf = name.split(".")
    holder = origin_module
    for part in parent_path:
        holder = holder[int(part)] if part.isdigit() else getattr(holder, part)
    setattr(holder, leaf, replace_module)
    

    
class BasicTransformerBlock(nn.Module):
    """
    Stable Diffusion style pre-norm transformer block.

    The block runs three residual branches in sequence, each on its own
    LayerNorm: self-attention, cross-attention on ``context`` (self-attention
    when ``context`` is None) and a feed-forward network. Extra keyword
    arguments are accepted and ignored.
    """
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=False, **kwargs):
        super().__init__()
        shared_attn_kwargs = dict(heads=n_heads, dim_head=d_head, dropout=dropout)
        # Construction order kept as attn1, ff, attn2, norm1..3.
        self.attn1 = CrossAttention(query_dim=dim, **shared_attn_kwargs)  # self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, **shared_attn_kwargs)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        if not self.checkpoint:
            return self._forward(x, context)
        return torch.utils.checkpoint.checkpoint(self._forward, x, context)

    def _forward(self, x, context=None):
        h = self.attn1(self.norm1(x)) + x
        h = self.attn2(self.norm2(h), context=context) + h
        return self.ff(self.norm3(h)) + h

class BasicTransformerBlockSeq(nn.Module):
    """
    A stack of ``num_block`` ``BasicTransformerBlock`` s used as one adapter block.

    All keyword arguments go to every block. If ``zero_linear`` is truthy, a
    bias-free ``dim x dim`` linear layer with zero-initialised weights is applied
    to the stack output, so the adapter starts out contributing nothing. The
    final output is scaled by ``self.multiplier`` (1.0 initially).
    """
    def __init__(self, num_block=1, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList([BasicTransformerBlock(**kwargs) for _ in range(num_block)])
        self.multiplier = 1.
        self.zero_linear = None
        if kwargs.get("zero_linear"):
            width = kwargs.get("dim")
            self.zero_linear = nn.Linear(width, width, bias=False)
            torch.nn.init.zeros_(self.zero_linear.weight)

    def forward(self, x, context=None):
        h = x
        for block in self.blocks:
            h = block(h, context)
        if self.zero_linear is not None:
            h = self.zero_linear(h)
        return h * self.multiplier



from libs.sd_attention import CrossAttention, FeedForward
class BasicTransformerBlockCross(nn.Module):
    """
    Cross-attention-only variant of the Stable Diffusion transformer block.

    The block has two pre-norm residual branches: attention on ``context``
    (self-attention when ``context`` is None), then a feed-forward network.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=False, **kwargs):
        super().__init__()
        # Construction order kept as ff, attn2, norm2, norm3.
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        if not self.checkpoint:
            return self._forward(x, context)
        return torch.utils.checkpoint.checkpoint(self._forward, x, context)

    def _forward(self, x, context=None):
        h = self.attn2(self.norm2(x), context=context) + x
        return self.ff(self.norm3(h)) + h

class BasicTransformerBlockCrossSeq(nn.Module):
    """
    A stack of ``num_block`` ``BasicTransformerBlockCross`` s used as one adapter block.

    All keyword arguments go to every block. If ``zero_linear`` is truthy, a
    bias-free ``dim x dim`` linear layer with zero-initialised weights is applied
    to the stack output. The final output is scaled by ``self.multiplier``
    (1.0 initially).
    """
    def __init__(self, num_block=1, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList([BasicTransformerBlockCross(**kwargs) for _ in range(num_block)])
        self.multiplier = 1.
        self.zero_linear = None
        if kwargs.get("zero_linear"):
            width = kwargs.get("dim")
            self.zero_linear = nn.Linear(width, width, bias=False)
            torch.nn.init.zeros_(self.zero_linear.weight)

    def forward(self, x, context=None):
        h = x
        for block in self.blocks:
            h = block(h, context)
        if self.zero_linear is not None:
            h = self.zero_linear(h)
        return h * self.multiplier
    

class T2IAdapter(torch.nn.Module):
    """
    T2I-Adapter: builds the adapter blocks and runs the condition preprocessing.

    One adapter block is built per entry of ``block_names``. ``forward`` only
    preprocesses the condition input and stores the result on every block as
    ``block.preprocessed``, which ``block.get_context()`` returns. The blocks
    themselves are executed by the ``AdpManagedModule*`` wrappers that
    :meth:`apply_to` installs around the named modules of the main network.
    """
    def __init__(
        self,
        block_names: List[str],
        block_args: dict,
        preprocess_module: str,
        multiplier = 1.,
        type_ = 1,
        verbose=True
    ) -> None:
        """
        Adapter network.

        :param block_names: dotted names of the modules in the main network to wrap.
        :param block_args: adapter block config, see :meth:`make_block`.
        :param preprocess_module: identifier resolved by ``get_preprocess_module``.
        :param multiplier: output scale of the adapter blocks.
        :param type_: which ``AdpManagedModule*`` wrapper :meth:`apply_to` uses
                      (1, 2, 3, 4 or 21).
        :param verbose: stored as ``self.verbose``.
        """
        super().__init__()
        self.multiplier = multiplier
        self.block_args = block_args
        self.block_names = block_names
        self.preprocess_module = get_preprocess_module(preprocess_module)
        self.init()
        # Push the initial multiplier down to the freshly built blocks as well
        # (they default to 1.0); previously only set_multiplier() did this, so a
        # non-default constructor value was ignored by the blocks.
        self.set_multiplier(multiplier)
        # NOTE: this shadows nn.Module.type() on the instance; kept for compatibility.
        self.type = type_

        self.verbose = verbose
        self.enable = True
        self.feature_dict = dict()

    def make_block(self, block_args: dict):
        """
        Build one adapter block.

        :param block_args: ``name`` is "SDBT" (Stable Diffusion basic transformer
            block, the default) or "SDBTC" (cross-attention only); ``num_block`` is
            the number of stacked blocks (default 1); ``block_para`` holds the
            keyword arguments for each block.
        :return: the block, with a bound ``get_context()`` that returns the
            tensor most recently stored on it by :meth:`forward`.
        :raises KeyError: if ``block_para`` is missing.
        :raises ValueError: if ``name`` is not "SDBT" or "SDBTC" (this used to
            fall through to an UnboundLocalError).
        """
        num_block = block_args.get("num_block", 1)
        block_name = block_args.get("name", "SDBT")  # basic transformer block of stable diffusion
        block_para = block_args["block_para"]

        if block_name == "SDBT":
            block_cls = BasicTransformerBlockSeq
        elif block_name == "SDBTC":  # cross attention only
            block_cls = BasicTransformerBlockCrossSeq
        else:
            raise ValueError(f"unknown adapter block name {block_name!r}; expected 'SDBT' or 'SDBTC'")

        def get_context(self):
            return self.preprocessed

        b = block_cls(num_block, **block_para)
        b.get_context = get_context.__get__(b, type(b))
        return b

    def init(self):
        """(Re)build ``self.blocks``: one adapter block per entry of ``block_names``."""
        self.blocks = nn.ModuleList([self.make_block(self.block_args) for _ in self.block_names])

    def query(self, name):
        """Return the cached feature ``name`` scaled by ``multiplier``, or None if disabled or absent."""
        if self.enable and self.feature_dict.get(name, None) is not None:
            return self.multiplier * self.feature_dict[name]
        return None

    def forward(self, x, **kwargs):
        """
        Only run the preprocess module; the rest of the forward pass happens in
        the main network through the AdpManagedModule* wrappers. Zero linears,
        if used, live in the adapter blocks.

        :param x: condition input passed to the preprocess module.
        :raises RuntimeError: if no preprocess module is configured.
        """
        self.clear_feature()
        if self.preprocess_module is None:
            # Previously a bare ``raise`` ("No active exception to reraise").
            raise RuntimeError("T2IAdapter.forward requires a preprocess module, but none is configured")
        x = self.preprocess_module(x)
        for b in self.blocks:
            b.preprocessed = x

    def clear_feature(self):
        """Drop all cached features."""
        # The old ``v.cpu()`` loop returned copies that were discarded, so it only
        # cost a device->host transfer; clearing the dict is all that matters.
        self.feature_dict.clear()

    def dump_meta(self)-> dict:
        """Return the config needed to rebuild the blocks (``block_names``, ``block_args``)."""
        return {key: getattr(self, key) for key in ("block_names", "block_args")}

    def load_meta(self, d):
        """Restore ``block_names`` and ``block_args`` from a dict produced by :meth:`dump_meta`."""
        for key in ("block_names", "block_args"):
            setattr(self, key, d[key])

    def init_from_file(self, file_path, te, nnet):
        """
        Rebuild the blocks from a torch checkpoint written by :meth:`save_weights`
        and load its weights (non-strict).

        ``te`` and ``nnet`` are accepted for interface compatibility but unused.
        :return: the result of ``load_state_dict``.
        """
        weights_sd = torch.load(file_path, map_location="cpu")
        self.load_meta(weights_sd["meta"])
        # ``init`` takes no arguments; the old ``self.init(te, nnet)`` raised TypeError.
        self.init()
        self.set_multiplier(self.multiplier)
        info = self.load_state_dict(weights_sd, False)
        return info

    def set_multiplier(self, multiplier):
        """Set the output scale of the adapter and of every adapter block."""
        self.multiplier = multiplier
        for b in self.blocks:
            b.multiplier = multiplier

    def set_enable(self, en=True):
        self.enable = en

    def save_weights(self, file, metadata:dict=None, dtype=None,):
        """
        Save the adapter weights together with its meta (``block_names``/``block_args``).

        :param file: output path. ``.safetensors`` files are written with
            safetensors and the meta is stored as JSON under ``metadata["meta"]``.
            Any other extension uses ``torch.save``, with the meta under the
            ``"meta"`` key of the saved dict.
        :param metadata: extra str->str metadata for safetensors files, updated in
            place with model hashes; ignored for torch files.
        :param dtype: if given, every tensor is detached, copied to CPU and cast to it.
        """
        state_dict = self.state_dict()
        if dtype is not None:
            # Convert tensors only; the (non-tensor) meta is attached afterwards.
            # Converting after adding it raised AttributeError on dict.detach().
            for key in list(state_dict.keys()):
                state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)
        meta = self.dump_meta()

        if os.path.splitext(file)[1] == ".safetensors":
            import json
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash
            # safetensors only stores tensors plus str->str metadata, so the meta
            # dict is serialised into the metadata instead of the state dict.
            metadata["meta"] = json.dumps(meta)

            save_file(state_dict, file, metadata)
        else:
            state_dict["meta"] = meta
            torch.save(state_dict, file)

    def load_weights(self, file):
        """Load weights (non-strict) from a ``.safetensors`` or torch file; return ``load_state_dict``'s result."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file
            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def _wrapper_class(self):
        """
        Return the AdpManagedModule* class selected by ``self.type``.

        :raises RuntimeError: for an unsupported type.
        """
        candidates = (
            (1, AdpManagedModule1),
            (2, AdpManagedModule2),
            (3, AdpManagedModule3),
            (4, AdpManagedModule4),
            (21, AdpManagedModule21),
        )
        for type_id, wrapper_cls in candidates:
            if self.type == type_id:
                return wrapper_cls
        raise RuntimeError(f"unsupported T2IAdapter type: {self.type!r}")

    @staticmethod
    def _find_module(nnet, target_name):
        """Return the submodule of ``nnet`` named ``target_name``, or None if there is none."""
        for name, module in nnet.named_modules():
            if name == target_name:
                return module
        return None

    def apply_to(self, nnet):
        """
        Attach every adapter block to the module of ``nnet`` named in ``block_names``.

        A module that is already wrapped by the selected wrapper type gets the block
        appended, at most once, so re-applying is a no-op. Any other module is
        replaced in ``nnet`` by a new wrapper around it. Names not found in ``nnet``
        are skipped.

        :raises RuntimeError: if ``self.type`` is unsupported.
        """
        print("T2iAdapter type>>>>>>>:", self.type)
        wrapper_cls = self._wrapper_class()
        for b_name, block in zip(self.block_names, self.blocks):
            # Look the module up afresh each time and do not replace modules while
            # iterating named_modules().
            module = self._find_module(nnet, b_name)
            if module is None:
                continue
            if type(module) is wrapper_cls:
                # Previously an already-attached block caused the wrapper itself
                # to be wrapped again, applying the adapter twice.
                if block not in module.adps:
                    module.adps.append(block)
            else:
                replace_module_by_name(nnet, b_name, wrapper_cls(module, b_name, [block]))

    def detach(self, nnet):
        """
        Remove every adapter block from its wrapper in ``nnet`` (the wrappers stay in place).

        :raises RuntimeError: if ``self.type`` is unsupported, or if a named module
            exists but is not a wrapper of that type holding the block.
        """
        wrapper_cls = self._wrapper_class()
        for b_name, block in zip(self.block_names, self.blocks):
            module = self._find_module(nnet, b_name)
            if module is None:
                continue
            if type(module) is not wrapper_cls or block not in module.adps:
                raise RuntimeError(
                    f"cannot detach adapter block from {b_name!r}: it is not attached "
                    f"through a {wrapper_cls.__name__}"
                )
            module.adps.remove(block)

    def on_epoch_start(self, text_encoder, unet):
        """Epoch hook: switch the adapter to training mode (arguments unused)."""
        self.train()


import random
def T2i_compute_mask_loss_no_encode_unidiffuser(feed_model, nnet, clip_text_model, caption_decoder, autoencoder, clip_img_model,
                                schedule, device, iter_dict, config, **kwargs):
    """
    Adapter training loss with an optional masked image term.

    Steps:
    - Encode the image latent, the CLIP image embedding and the caption without gradients.
    - With probability ``config.cfg_p``, zero the ``detect`` condition, then pass it to ``feed_model``.
    - Prefix-encode the text and noise the image inputs via ``schedule.sample``.
    - Run ``nnet`` and take the mean-of-squares error between the sampled and
      predicted image noise. With probability ``config.mask_p`` the error is
      multiplied by ``mask`` first.

    :return: ``(loss, loss_img, loss_clip_img)``. The CLIP-image and text terms
             enter ``loss`` with weight 0.
    """
    def mos(a, start_dim=1):  # mean of square
        return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)

    img = iter_dict["image1"].to(device)
    caption = iter_dict["caption"]
    img4clip = iter_dict["img4clip"].to(device)
    detect = iter_dict["detect"].to(device)
    mask = iter_dict["mask"].to(device)
    data_type = iter_dict["data_type"].to(device)

    with torch.no_grad():
        img_latent = autoencoder.encode(img)
        clip_img = clip_img_model.encode_image(img4clip).unsqueeze(1)
        text = clip_text_model.encode(caption)

    # Randomly drop the condition (unconditional sample).
    if random.random() < config.cfg_p:
        detect = torch.zeros_like(detect)
    feed_model(detect)
    text = caption_decoder.encode_prefix(text)

    # add noise; original note: n in {1, ..., 1000}
    n, (img_eps, clip_img_eps), (img_n, clip_img_n) = schedule.sample([img_latent, clip_img])
    n = n.to(device)
    t_text = torch.zeros_like(n, device=device)
    predictions = nnet(img_n, clip_img_n, text, t_img=n, t_text=t_text, data_type=data_type)

    img_out = predictions["img_out"]
    clip_img_out = predictions["clip_img_out"]
    text_out = predictions["text_out"]
    diff = img_eps - img_out
    use_mask = random.random() < config.mask_p
    loss_img = mos(diff * mask) if use_mask else mos(diff)
    loss_clip_img = mos(clip_img_eps - clip_img_out)
    # Zero-weighted terms presumably keep those outputs in the autograd graph.
    loss = loss_img + 0. * loss_clip_img + 0. * mos(text_out)
    return loss, loss_img, loss_clip_img


def T2i_compute_mask_loss_no_encode_addition_unidiffuser(feed_model, nnet, clip_text_model, caption_decoder, autoencoder, clip_img_model,
                                schedule, device, iter_dict, config, **kwargs):
    """
    Adapter training loss with a masked image term plus a "stay close to the
    original network" term outside the mask.

    Like ``T2i_compute_mask_loss_no_encode_unidiffuser``, but it also runs
    ``nnet`` a second time, without gradients and with the adapter multiplier set
    to 0, to get a reference prediction. When the mask branch is taken
    (probability ``config.mask_p``), the loss adds the mean-of-squares
    difference between the adapted and reference predictions outside ``mask``.

    ``feed_model`` may be wrapped in ``DistributedDataParallel``. Its multiplier
    is always restored, even if the reference forward raises.

    :return: ``(loss, loss_img, loss_clip_img)``. The CLIP-image and text terms
             enter ``loss`` with weight 0.
    """
    img = iter_dict["image1"].to(device)
    text = iter_dict["caption"]
    img4clip = iter_dict["img4clip"].to(device)
    detect = iter_dict["detect"].to(device)
    mask = iter_dict["mask"].to(device)
    data_type = iter_dict["data_type"].to(device)

    with torch.no_grad():
        img = autoencoder.encode(img)
        clip_img = clip_img_model.encode_image(img4clip).unsqueeze(1)
        text = clip_text_model.encode(text)

    if random.random() < config.cfg_p:
        detect = torch.zeros_like(detect)
    feed_model(detect)
    text = caption_decoder.encode_prefix(text)

    def mos(a, start_dim=1):  # mean of square
        return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
    # add noise
    n, (img_eps, clip_img_eps), (img_n, clip_img_n) = schedule.sample([img, clip_img])  # n in {1, ..., 1000}
    n = n.to(device)
    dict_out = nnet(img_n, clip_img_n, text, t_img=n, t_text=torch.zeros_like(n, device=device), data_type=data_type)

    # Reference prediction with the adapter switched off: multiplier 0 makes the
    # adapter wrappers skip the adapter blocks. Previously ``origin_multiplier``
    # was only read on the DDP path, so the non-DDP path raised NameError.
    adapter = feed_model.module if isinstance(feed_model, DDP) else feed_model
    origin_multiplier = adapter.multiplier
    adapter.set_multiplier(0.)
    try:
        with torch.no_grad():
            origin_dict_out = nnet(img_n, clip_img_n, text, t_img=n, t_text=torch.zeros_like(n, device=device), data_type=data_type)
    finally:
        adapter.set_multiplier(origin_multiplier)

    img_out, clip_img_out, text_out = dict_out["img_out"], dict_out["clip_img_out"], dict_out["text_out"]
    origin_img_out = origin_dict_out["img_out"]
    diff = img_eps - img_out
    if random.random() < config.mask_p:
        loss_img = mos(diff * mask) + mos((origin_img_out - img_out) * (1 - mask))  # make no mask area sim to origin nnet
    else:
        loss_img = mos(diff)
    loss_clip_img = mos(clip_img_eps - clip_img_out)
    # Zero-weighted terms presumably keep those outputs in the autograd graph.
    loss = loss_img + 0. * loss_clip_img + 0. * mos(text_out)
    return loss, loss_img, loss_clip_img

