import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.utils import save_image
from torchvision.transforms import Normalize
import math
from tqdm import tqdm
import random
import argparse
import open_clip
import os
import json


# Process-wide cache for the open_clip model and tokenizer. Both start as None
# and are populated on the first call to alignment_eval_with_grad, which then
# reuses them on every later call.
_clip_model = None
_clip_tokenizer = None


def setup_kernel(k):
    """Build a normalized, square 2-D FIR kernel.

    Args:
        k: Either 1-D filter taps (turned into a separable 2-D kernel via an
            outer product) or a square 2-D kernel. Any array-like that can be
            converted to float32.

    Returns:
        A new float32 ``np.ndarray`` of shape (K, K) whose entries sum to 1.
        The caller's array is never modified.

    Raises:
        AssertionError: if the resulting kernel is not 2-D and square.
    """
    # np.array (not np.asarray) always copies: asarray returns the caller's own
    # object for a float32 ndarray, which the normalization would then mutate.
    k = np.array(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k = k / np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """Upsample, pad, FIR-filter and downsample a 4-D tensor in pure PyTorch.

    Pipeline:
      1. Insert ``up - 1`` zeros after every sample along each spatial axis.
      2. Pad the borders (positive pads) or crop them (negative pads).
      3. Convolve every channel independently with ``kernel``. The kernel is
         flipped, so this is a true convolution rather than a cross-correlation.
      4. Keep every ``down``-th sample.

    Args:
        input: 4-D tensor unpacked as (N, channel, H, W); each channel is
            filtered separately.
        kernel: 2-D filter tensor of shape (kernel_h, kernel_w).
        up_x, up_y: Integer upsampling factors along width / height.
        down_x, down_y: Integer downsampling strides along width / height.
        pad_x0, pad_x1: Left / right padding applied after upsampling
            (negative values crop).
        pad_y0, pad_y1: Top / bottom padding applied after upsampling
            (negative values crop).

    Returns:
        Tensor of shape (N, channel, out_h, out_w), where
        ``out_h = (H * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1`` and
        ``out_w`` is computed analogously.
    """
    _, channel, in_h, in_w = input.shape
    # Fold channels into the batch axis so a single-channel conv filters every
    # channel independently; a trailing singleton "minor" axis is added.
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-insertion upsampling: pad a singleton axis after each row/column
    # with up-1 zeros, then merge it back into the spatial axis.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive pads are applied with F.pad; negative pads are realised by
    # cropping in the slice below.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    # conv2d computes a cross-correlation; flipping the kernel makes it a
    # convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)

    # Downsample by striding over the filtered result.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    # Unfold the channel axis back out of the batch axis.
    return out.view(-1, channel, out_h, out_w)


def upsample_2d(x, k, factor=2, gain=1):
    """Upsample a 4-D tensor by ``factor`` along both spatial axes with an FIR filter.

    Args:
        x: Tensor passed to ``upfirdn2d_native`` (unpacked there as N, C, H, W).
        k: FIR taps (1-D or square 2-D); ``None`` means a box filter of
            length ``factor``.
        factor: Positive integer upsampling factor.
        gain: Overall gain applied to the filter.

    Returns:
        The upsampled tensor.
    """
    assert isinstance(factor, int) and factor >= 1
    taps = [1] * factor if k is None else k
    # Zero insertion leaves only 1 in factor**2 samples non-zero, so the
    # kernel is scaled up by factor**2 (times the requested gain).
    kernel = setup_kernel(taps) * (gain * (factor ** 2))
    excess = kernel.shape[0] - factor
    pad_lo = (excess + 1) // 2 + factor - 1
    pad_hi = excess // 2
    kernel_t = torch.tensor(kernel, dtype=torch.float32, device=x.device)
    return upfirdn2d_native(x, kernel_t, factor, factor, 1, 1, pad_lo, pad_hi, pad_lo, pad_hi)


def downsample_2d(x, k, factor=2, gain=1):
    """Downsample a 4-D tensor by ``factor`` along both spatial axes with an FIR filter.

    Args:
        x: Tensor passed to ``upfirdn2d_native`` (unpacked there as N, C, H, W).
        k: FIR taps (1-D or square 2-D); ``None`` means a box filter of
            length ``factor``.
        factor: Positive integer downsampling factor.
        gain: Overall gain applied to the filter.

    Returns:
        The filtered, strided tensor.
    """
    assert isinstance(factor, int) and factor >= 1
    taps = [1] * factor if k is None else k
    kernel = setup_kernel(taps) * gain
    excess = kernel.shape[0] - factor
    pad_lo = (excess + 1) // 2
    pad_hi = excess // 2
    kernel_t = torch.tensor(kernel, dtype=torch.float32, device=x.device)
    return upfirdn2d_native(x, kernel_t, 1, 1, factor, factor, pad_lo, pad_hi, pad_lo, pad_hi)


def naive_upsample_2d(x, factor=2):
    """Nearest-neighbour upsampling: repeat every pixel ``factor`` x ``factor`` times.

    Args:
        x: Tensor of shape (N, C, H, W).
        factor: Integer replication factor.

    Returns:
        Tensor of shape (N, C, H * factor, W * factor).
    """
    _, channels, height, width = x.shape
    tiled = x[:, :, :, None, :, None].expand(-1, -1, -1, factor, -1, factor)
    return tiled.reshape(-1, channels, height * factor, width * factor)


def naive_downsample_2d(x, factor=2):
    """Average-pool ``x`` over non-overlapping ``factor`` x ``factor`` blocks.

    Args:
        x: Tensor of shape (N, C, H, W) with H and W divisible by ``factor``
            (otherwise the reshape fails).
        factor: Integer pooling factor.

    Returns:
        Tensor of shape (N, C, H // factor, W // factor).
    """
    _, channels, height, width = x.shape
    blocks = x.reshape(-1, channels, height // factor, factor, width // factor, factor)
    return blocks.mean(dim=(3, 5))


def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device='cpu'):
    """Return an initializer that samples tensors with variance ``scale / fan``.

    ``fan`` is taken from the sampled shape: ``fan_in`` uses ``shape[in_axis]``,
    ``fan_out`` uses ``shape[out_axis]``, each multiplied by the receptive-field
    size (product of all remaining dims); ``fan_avg`` is their mean.

    Args:
        scale: Target variance multiplier.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``.
        distribution: ``"normal"`` or ``"uniform"`` (uniform on
            ``[-sqrt(3 * var), sqrt(3 * var))``).
        in_axis, out_axis: Axes of the shape holding input / output units.
        dtype, device: Defaults for the returned initializer.

    Returns:
        ``init(shape, dtype=dtype, device=device)`` producing a tensor. Invalid
        ``mode``/``distribution`` values raise ``ValueError`` when ``init`` is
        called, not here.
    """

    def _pick_denominator(fan_in, fan_out):
        # Chain of equality tests keeps error semantics for any `mode` value.
        if mode == "fan_in":
            return fan_in
        if mode == "fan_out":
            return fan_out
        if mode == "fan_avg":
            return (fan_in + fan_out) / 2
        raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))

    def init(shape, dtype=dtype, device=device):
        receptive = np.prod(shape) / shape[in_axis] / shape[out_axis]
        variance = scale / _pick_denominator(shape[in_axis] * receptive,
                                             shape[out_axis] * receptive)
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        if distribution == "uniform":
            symmetric = torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.
            return symmetric * np.sqrt(3 * variance)
        raise ValueError("invalid distribution for variance scaling initializer")

    return init


def default_init(scale=1.):
    """Return the default (fan_avg, uniform) variance-scaling initializer.

    A ``scale`` of exactly 0 is replaced by 1e-10, so "zero-initialised"
    layers still receive tiny non-zero weights.
    """
    effective_scale = scale if scale != 0 else 1e-10
    return variance_scaling(effective_scale, 'fan_avg', 'uniform')


def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """Create a 3x3 ``nn.Conv2d`` whose weight comes from ``default_init``.

    Args:
        in_planes, out_planes: Input / output channel counts.
        stride, dilation, padding: Forwarded to ``nn.Conv2d``.
        bias: Whether the conv has a bias; when present it is zero-initialised.
        init_scale: Scale passed to ``default_init`` for the weight.

    Returns:
        The configured ``nn.Conv2d``.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                     dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    # With bias=False, conv.bias is None and nn.init.zeros_ would raise.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv


def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """Create a 1x1 ``nn.Conv2d`` whose weight comes from ``default_init``.

    Args:
        in_planes, out_planes: Input / output channel counts.
        stride, padding: Forwarded to ``nn.Conv2d``.
        bias: Whether the conv has a bias; when present it is zero-initialised.
        init_scale: Scale passed to ``default_init`` for the weight.

    Returns:
        The configured ``nn.Conv2d``.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    # With bias=False, conv.bias is None and nn.init.zeros_ would raise.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv


class NIN(nn.Module):
    """Per-pixel linear projection over the channel axis ("network in network").

    Mathematically a 1x1 convolution: the channel vector at every spatial
    location is multiplied by ``W`` (in_dim x num_units) and offset by ``b``.
    """

    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        weight_init = default_init(scale=init_scale)
        self.W = nn.Parameter(weight_init((in_dim, num_units)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        """Map a (B, in_dim, H, W) tensor to (B, num_units, H, W)."""
        channels_last = x.permute(0, 2, 3, 1)
        projected = torch.einsum('bhwc,cd->bhwd', channels_last, self.W) + self.b
        return projected.permute(0, 3, 1, 2)


class GaussianFourierProjection(nn.Module):
    """Random Fourier features of a per-sample scalar input.

    A fixed, non-trainable frequency vector ``W ~ N(0, scale^2)`` of length
    ``embedding_size`` is drawn at construction. ``forward`` returns
    ``[sin(2*pi*x*W), cos(2*pi*x*W)]`` concatenated on the last axis.
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        frequencies = torch.randn(embedding_size) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Embed a (B,) tensor into shape (B, 2 * embedding_size)."""
        angles = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Combine(nn.Module):
    """Project ``x`` with a 1x1 conv and merge it into ``y``.

    ``method='cat'`` concatenates along dim 1 (projected ``x`` first);
    ``method='sum'`` adds the two tensors. Any other method raises
    ``ValueError`` in ``forward``.
    """

    def __init__(self, dim1, dim2, method='cat'):
        super().__init__()
        self.Conv_0 = ddpm_conv1x1(dim1, dim2)
        self.method = method

    def forward(self, x, y):
        projected = self.Conv_0(x)
        if self.method == 'sum':
            return projected + y
        if self.method == 'cat':
            return torch.cat([projected, y], dim=1)
        raise ValueError(f'Method {self.method} not recognized.')


class AttnBlockpp(nn.Module):
    """Single-head spatial self-attention block with a residual connection.

    Queries, keys and values are per-pixel projections (``NIN``) of the
    group-normalised input. Each pixel attends over all H*W pixels with scaled
    dot-product attention. The output projection is added back to the input,
    and the sum is divided by sqrt(2) when ``skip_rescale`` is set.
    """

    def __init__(self, channels, skip_rescale=False, init_scale=0.):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        """Apply attention to a (B, C, H, W) tensor; returns the same shape."""
        batch, channels, height, width = x.shape
        normed = self.GroupNorm_0(x)
        query = self.NIN_0(normed)
        key = self.NIN_1(normed)
        value = self.NIN_2(normed)

        # logits[b, h, w, i, j] = <q[b, :, h, w], k[b, :, i, j]> / sqrt(C)
        logits = torch.einsum('bchw,bcij->bhwij', query, key) * (int(channels) ** (-0.5))
        # Softmax over all key positions, done on a flattened H*W axis.
        attn = F.softmax(logits.reshape(batch, height, width, height * width), dim=-1)
        attn = attn.reshape(batch, height, width, height, width)
        attended = self.NIN_3(torch.einsum('bhwij,bcij->bchw', attn, value))

        combined = x + attended
        return combined / np.sqrt(2.) if self.skip_rescale else combined


class Upsample(nn.Module):
    """2x spatial upsampling, optionally followed by a 3x3 convolution.

    With ``fir=False`` nearest-neighbour interpolation is used; with
    ``fir=True`` the FIR kernel ``fir_kernel`` is applied via ``upsample_2d``.
    When ``with_conv`` is set, a 3x3 ``ddpm_conv3x3`` (in_ch -> out_ch) runs on
    the upsampled tensor in either mode.
    """

    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        # The FIR and non-FIR paths use the same post-resampling conv, so it is
        # created once regardless of ``fir``.
        if with_conv:
            self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
        self.fir = fir
        self.with_conv = with_conv
        self.fir_kernel = fir_kernel
        self.out_ch = out_ch

    def forward(self, x):
        """Upsample a (B, C, H, W) tensor to (B, C', 2H, 2W)."""
        _, _, H, W = x.shape
        if self.fir:
            h = upsample_2d(x, self.fir_kernel, factor=2)
        else:
            # `mode` must be passed by keyword: F.interpolate's third positional
            # parameter is `scale_factor`, and giving it together with `size`
            # raises ValueError.
            h = F.interpolate(x, size=(H * 2, W * 2), mode='nearest')
        if self.with_conv:
            h = self.Conv_0(h)
        return h


class Downsample(nn.Module):
    """2x spatial downsampling, optionally with a 3x3 convolution.

    * ``fir=False, with_conv=False``: 2x2 average pooling.
    * ``fir=False, with_conv=True``: pad right/bottom by one pixel, then a
      stride-2, unpadded 3x3 convolution.
    * ``fir=True``: FIR downsampling with ``fir_kernel`` via ``downsample_2d``,
      followed (if ``with_conv``) by a stride-1, padding-1 3x3 convolution.
    """

    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch or in_ch
        if with_conv:
            # Without FIR the conv itself does the striding; with FIR the
            # kernel resamples first and the conv keeps the resolution.
            self.Conv_0 = ddpm_conv3x3(in_ch, out_ch,
                                       stride=1 if fir else 2,
                                       padding=1 if fir else 0)
        self.fir = fir
        self.fir_kernel = fir_kernel
        self.with_conv = with_conv
        self.out_ch = out_ch

    def forward(self, x):
        """Downsample a (B, C, H, W) tensor by a factor of two per spatial axis."""
        _n, _c, _h, _w = x.shape  # unpacking rejects non-4-D input
        if self.fir:
            pooled = downsample_2d(x, self.fir_kernel, factor=2)
            return self.Conv_0(pooled) if self.with_conv else pooled
        if self.with_conv:
            return self.Conv_0(F.pad(x, (0, 1, 0, 1)))
        return F.avg_pool2d(x, 2, stride=2)


class ResnetBlockDDPMpp(nn.Module):
    """DDPM-style residual block with optional time-embedding injection.

    Main path: GroupNorm -> act -> 3x3 conv -> (+ Dense(act(temb))) ->
    GroupNorm -> act -> dropout -> 3x3 conv. When the channel count changes,
    the skip path uses a 3x3 conv (``conv_shortcut=True``) or a ``NIN``
    projection. The sum is divided by sqrt(2) when ``skip_rescale`` is set.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
                 dropout=0.1, skip_rescale=False, init_scale=0.):
        super().__init__()
        out_ch = out_ch or in_ch
        # Sub-modules are created in a fixed order; this determines both
        # random-init consumption and state_dict key order.
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.data.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        """Apply the block to ``x``; ``temb`` (if given) is broadcast over H, W."""
        h = self.Conv_0(self.act(self.GroupNorm_0(x)))
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))

        if x.shape[1] != self.out_ch:
            shortcut = self.Conv_2 if self.conv_shortcut else self.NIN_0
            x = shortcut(x)

        merged = x + h
        return merged / np.sqrt(2.) if self.skip_rescale else merged


class ResnetBlockBigGANpp(nn.Module):
    """BigGAN-style residual block that can also up- or downsample by 2.

    Resampling, when requested, is applied to both the main path (after the
    first norm/activation) and the skip path. FIR resampling uses
    ``fir_kernel``; otherwise nearest repeat / block averaging is used. The skip
    path gets a 1x1 conv whenever channels change or resampling occurs.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
                 dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
                 skip_rescale=True, init_scale=0.):
        super().__init__()

        out_ch = out_ch or in_ch
        # Sub-modules are created in a fixed order; this determines both
        # random-init consumption and state_dict key order.
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = ddpm_conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        """Apply the block to ``x``; ``temb`` (if given) is broadcast over H, W."""
        h = self.act(self.GroupNorm_0(x))

        if self.up or self.down:
            # `up` takes precedence when both flags are set.
            if self.up:
                resample = upsample_2d if self.fir else naive_upsample_2d
            else:
                resample = downsample_2d if self.fir else naive_downsample_2d
            kernel_args = (self.fir_kernel,) if self.fir else ()
            h = resample(h, *kernel_args, factor=2)
            x = resample(x, *kernel_args, factor=2)

        h = self.Conv_0(h)
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        merged = x + h
        return merged / np.sqrt(2.) if self.skip_rescale else merged


class NCSNpp(nn.Module):
    """NCSN++ U-Net with a Fourier time embedding and multi-scale skip paths.

    All sub-modules live in one flat ``nn.ModuleList`` (``self.all_modules``).
    ``forward`` consumes them with a running index in exactly the order
    ``__init__`` appended them, so a change to one method must be mirrored in
    the other. ``forward`` asserts that every module was consumed.

    The string options ``resblock_type``, ``progressive``,
    ``progressive_input``, ``progressive_combine`` and ``embedding_type`` are
    compared lower-cased in both ``__init__`` and ``forward``.

    Args:
        image_size: Input height/width (assumed square). Level ``i`` has
            resolution ``image_size // 2**i`` for the attention placement.
        num_channels: Number of input and output image channels.
        nf: Base channel width; level ``i`` uses ``nf * ch_mult[i]`` channels.
        ch_mult: Per-level channel multipliers; its length is the number of
            resolution levels.
        num_res_blocks: Residual blocks per level on the downsampling path
            (``num_res_blocks + 1`` per level on the upsampling path).
        attn_resolutions: Resolutions at which an attention block is inserted.
        dropout: Dropout rate inside residual blocks.
        resamp_with_conv: Whether the ``Upsample``/``Downsample`` layers (used
            with ``resblock_type='ddpm'``) include a 3x3 convolution.
        fir: Use FIR resampling.
        fir_kernel: 1-D FIR taps used when ``fir`` is set.
        skip_rescale: Divide residual sums by sqrt(2).
        resblock_type: ``'ddpm'`` or ``'biggan'``.
        progressive: ``'output_skip'`` builds an output image pyramid. Any
            other value uses a single output head.
        progressive_input: ``'input_skip'`` merges a downsampled copy of the
            input into every coarser level.
        progressive_combine: ``'sum'`` or ``'cat'`` for the input-skip merge.
        embedding_type: Only ``'fourier'`` is supported.
        fourier_scale: Std of the random Fourier frequencies.
        init_scale: Init scale for output-side layers.
        centered: If False, inputs are mapped with ``2 * x - 1`` first.
        scale_by_sigma: Divide the output per sample by ``time_cond``.
    """

    def __init__(self, image_size=256, num_channels=3, nf=128, ch_mult=(1, 1, 2, 2, 2, 2, 2),
                 num_res_blocks=2, attn_resolutions=(16,), dropout=0.0, resamp_with_conv=True,
                 fir=True, fir_kernel=(1, 3, 3, 1), skip_rescale=True,
                 resblock_type='biggan', progressive='output_skip',
                 progressive_input='input_skip', progressive_combine='sum',
                 embedding_type='fourier', fourier_scale=16, init_scale=0.,
                 centered=True, scale_by_sigma=True):
        super().__init__()

        self.act = nn.SiLU()
        self.nf = nf
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.num_resolutions = len(ch_mult)
        self.all_resolutions = [image_size // (2 ** i) for i in range(self.num_resolutions)]
        self.conditional = True
        self.skip_rescale = skip_rescale
        self.resblock_type = resblock_type.lower()
        self.progressive = progressive.lower()
        self.progressive_input = progressive_input.lower()
        self.embedding_type = embedding_type.lower()
        self.centered = centered
        self.scale_by_sigma = scale_by_sigma

        # Build from the same lower-cased strings that forward() dispatches on,
        # so the module list and its traversal cannot disagree.
        resblock_type = self.resblock_type
        progressive = self.progressive
        progressive_input = self.progressive_input
        progressive_combine = progressive_combine.lower()

        modules = []

        # Time embedding: Fourier features, then a 2-layer MLP.
        if self.embedding_type == 'fourier':
            modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
            embed_dim = 2 * nf
        else:
            raise ValueError(f'embedding type {embedding_type} unknown.')

        if self.conditional:
            modules.append(nn.Linear(embed_dim, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

        AttnBlock = lambda channels: AttnBlockpp(channels, init_scale=init_scale, skip_rescale=skip_rescale)
        Upsample_func = lambda in_ch: Upsample(in_ch=in_ch, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        # Image-pyramid resamplers are attributes, not entries of all_modules;
        # with with_conv=False they hold no parameters.
        if progressive == 'output_skip':
            self.pyramid_upsample = Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)

        Downsample_func = lambda in_ch: Downsample(in_ch=in_ch, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive_input == 'input_skip':
            self.pyramid_downsample = Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)

        if resblock_type == 'ddpm':
            ResnetBlock = lambda in_ch, out_ch: ResnetBlockDDPMpp(
                self.act, in_ch, out_ch, temb_dim=nf * 4, dropout=dropout,
                init_scale=init_scale, skip_rescale=skip_rescale)
        elif resblock_type == 'biggan':
            ResnetBlock = lambda in_ch, out_ch, down=False, up=False: ResnetBlockBigGANpp(
                self.act, in_ch, out_ch, temb_dim=nf * 4, dropout=dropout,
                fir=fir, fir_kernel=fir_kernel, init_scale=init_scale,
                skip_rescale=skip_rescale, down=down, up=up)
        else:
            raise ValueError(f'resblock type {resblock_type} unrecognized.')

        if progressive_input != 'none':
            input_pyramid_ch = num_channels

        # Input stem; hs_c tracks the channel count of every skip activation
        # the downsampling path will push, so the upsampling path can pop them.
        modules.append(ddpm_conv3x3(num_channels, nf))
        hs_c = [nf]

        # Downsampling path.
        in_ch = nf
        for i_level in range(self.num_resolutions):
            for _ in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                if resblock_type == 'ddpm':
                    modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                else:
                    modules.append(ResnetBlock(in_ch, out_ch))
                in_ch = out_ch

                if self.all_resolutions[i_level] in attn_resolutions:
                    modules.append(AttnBlock(channels=in_ch))
                hs_c.append(in_ch)

            if i_level != self.num_resolutions - 1:
                if resblock_type == 'ddpm':
                    modules.append(Downsample_func(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(in_ch, in_ch, down=True))

                if progressive_input == 'input_skip':
                    modules.append(Combine(dim1=input_pyramid_ch, dim2=in_ch, method=progressive_combine))
                    if progressive_combine == 'cat':
                        in_ch *= 2

                hs_c.append(in_ch)

        # Bottleneck: ResBlock -> Attention -> ResBlock.
        in_ch = hs_c[-1]
        if resblock_type == 'ddpm':
            modules.append(ResnetBlock(in_ch=in_ch, out_ch=in_ch))
        else:
            modules.append(ResnetBlock(in_ch, in_ch))
        modules.append(AttnBlock(channels=in_ch))
        if resblock_type == 'ddpm':
            modules.append(ResnetBlock(in_ch=in_ch, out_ch=in_ch))
        else:
            modules.append(ResnetBlock(in_ch, in_ch))

        # Upsampling path: each ResBlock consumes one skip activation.
        for i_level in reversed(range(self.num_resolutions)):
            for _ in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[i_level]
                if resblock_type == 'ddpm':
                    modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                else:
                    modules.append(ResnetBlock(in_ch + hs_c.pop(), out_ch))
                in_ch = out_ch

            if self.all_resolutions[i_level] in attn_resolutions:
                modules.append(AttnBlock(channels=in_ch))

            # Output-skip head (GroupNorm + 3x3 conv) at every level.
            if progressive != 'none':
                if i_level == self.num_resolutions - 1:
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                                    num_channels=in_ch, eps=1e-6))
                        modules.append(ddpm_conv3x3(in_ch, num_channels, init_scale=init_scale))
                else:
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                                    num_channels=in_ch, eps=1e-6))
                        modules.append(ddpm_conv3x3(in_ch, num_channels, bias=True, init_scale=init_scale))

            if i_level != 0:
                if resblock_type == 'ddpm':
                    modules.append(Upsample_func(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(in_ch, in_ch, up=True))

        assert not hs_c

        # Single output head when no output pyramid is used.
        if progressive != 'output_skip':
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                        num_channels=in_ch, eps=1e-6))
            modules.append(ddpm_conv3x3(in_ch, num_channels, init_scale=init_scale))

        self.all_modules = nn.ModuleList(modules)
        # hs_c is fully consumed (asserted above), so this is always 0; the
        # attribute is kept for compatibility.
        self.hs_c_len = len(hs_c)

    def forward(self, x, time_cond):
        """Run the network.

        Args:
            x: Image batch of shape (B, num_channels, H, W).
            time_cond: 1-D tensor of B strictly positive conditioning values;
                the embedding is computed from ``log(time_cond)``.

        Returns:
            Tensor with the same shape as ``x``. It is divided per sample by
            ``time_cond`` when ``scale_by_sigma`` is set.
        """
        modules = self.all_modules
        m_idx = 0

        if self.embedding_type == 'fourier':
            used_sigmas = time_cond
            temb = modules[m_idx](torch.log(used_sigmas))
            m_idx += 1
        else:
            raise ValueError(f'embedding type {self.embedding_type} unknown.')

        if self.conditional:
            temb = modules[m_idx](temb)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
        else:
            temb = None

        if not self.centered:
            x = 2 * x - 1.

        input_pyramid = None
        if self.progressive_input != 'none':
            input_pyramid = x

        # Downsampling path; every activation pushed to hs is a skip input.
        hs = [modules[m_idx](x)]
        m_idx += 1
        for i_level in range(self.num_resolutions):
            for _ in range(self.num_res_blocks):
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                if h.shape[-1] in self.attn_resolutions:
                    h = modules[m_idx](h)
                    m_idx += 1
                hs.append(h)

            if i_level != self.num_resolutions - 1:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](hs[-1])
                    m_idx += 1
                else:
                    h = modules[m_idx](hs[-1], temb)
                    m_idx += 1

                if self.progressive_input == 'input_skip':
                    input_pyramid = self.pyramid_downsample(input_pyramid)
                    h = modules[m_idx](input_pyramid, h)
                    m_idx += 1

                hs.append(h)

        # Bottleneck.
        h = hs[-1]
        h = modules[m_idx](h, temb)
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1
        h = modules[m_idx](h, temb)
        m_idx += 1

        pyramid = None

        # Upsampling path.
        for i_level in reversed(range(self.num_resolutions)):
            for _ in range(self.num_res_blocks + 1):
                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

            if h.shape[-1] in self.attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1

            if self.progressive != 'none':
                if i_level == self.num_resolutions - 1:
                    if self.progressive == 'output_skip':
                        pyramid = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                else:
                    if self.progressive == 'output_skip':
                        # Upsample the coarser output and add this level's head.
                        pyramid = self.pyramid_upsample(pyramid)
                        pyramid_h = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid_h = modules[m_idx](pyramid_h)
                        m_idx += 1
                        pyramid = pyramid + pyramid_h

            if i_level != 0:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](h)
                    m_idx += 1
                else:
                    h = modules[m_idx](h, temb)
                    m_idx += 1

        assert not hs

        if self.progressive == 'output_skip':
            h = pyramid
        else:
            h = self.act(modules[m_idx](h))
            m_idx += 1
            h = modules[m_idx](h)
            m_idx += 1

        # Every module built in __init__ must have been used exactly once.
        assert m_idx == len(modules), f'used {m_idx} of {len(modules)} modules'

        if self.scale_by_sigma:
            used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
            h = h / used_sigmas

        return h


def alignment_eval_with_grad(x,
                             positive_texts=("a photo of a woman",),
                             negative_texts=("a photo of a man",),
                             model_path='./open_clip_pytorch_model.bin'):
    """Score how strongly each image matches ``positive_texts`` over ``negative_texts``.

    The image encoder runs with autograd enabled, so the returned scores are
    differentiable with respect to ``x``. Text features are computed under
    ``torch.no_grad()``.

    The open_clip ``convnext_xxlarge`` model and tokenizer are loaded on the
    first call (from ``model_path``) and cached in module globals. Later calls
    reuse the cache and ignore ``model_path``. The cached model is moved to
    ``x.device`` whenever it lives on a different device.

    Args:
        x: Image batch; assumed to be (B, 3, H, W) in [-1, 1] at the CLIP
            input resolution (no resizing is done here) — TODO confirm with callers.
        positive_texts: Sequence of prompts for the target attribute.
        negative_texts: Sequence of prompts for the competing attribute.
        model_path: Checkpoint passed to ``open_clip`` as ``pretrained`` on
            first load.

    Returns:
        ``(align_logit, align_prob)``, each of shape (B,), where
        ``align_prob = sigmoid(100 * (max_pos_cos - max_neg_cos))`` and
        ``align_logit = log(align_prob + 1e-10)`` (a log-probability floored
        near log(1e-10)).
    """
    global _clip_model, _clip_tokenizer

    device = x.device

    if _clip_model is None:
        _clip_model, _, _ = open_clip.create_model_and_transforms(
            'convnext_xxlarge', pretrained=model_path)
        _clip_model.eval()
        for param in _clip_model.parameters():
            param.requires_grad = False
        _clip_tokenizer = open_clip.get_tokenizer('convnext_xxlarge')
    # Follow the input's device; otherwise the cached model stays on the first
    # call's device and fails for inputs living elsewhere.
    if next(_clip_model.parameters()).device != device:
        _clip_model = _clip_model.to(device)

    # Map [-1, 1] images to [0, 1], then apply CLIP's per-channel normalization.
    x_norm = (x + 1) / 2
    normalize = Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                          std=(0.26862954, 0.26130258, 0.27577711))
    x_normalized = normalize(x_norm)

    positive_tokens = _clip_tokenizer(list(positive_texts)).to(device)
    negative_tokens = _clip_tokenizer(list(negative_texts)).to(device)

    # Differentiable image path.
    image_features = _clip_model.encode_image(x_normalized)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        positive_text_features = _clip_model.encode_text(positive_tokens)
        positive_text_features = positive_text_features / positive_text_features.norm(dim=-1, keepdim=True)

        negative_text_features = _clip_model.encode_text(negative_tokens)
        negative_text_features = negative_text_features / negative_text_features.norm(dim=-1, keepdim=True)

    # Cosine similarity to each prompt; keep the best match per image.
    max_positive_cos = (image_features @ positive_text_features.T).max(dim=1)[0]
    max_negative_cos = (image_features @ negative_text_features.T).max(dim=1)[0]

    # Two-way softmax with temperature 100:
    #   exp(100a) / (exp(100a) + exp(100b)) == sigmoid(100 * (a - b)).
    # The sigmoid form cannot overflow, whereas exp(100 * cos) exceeds the
    # float32 range once cos > ~0.887 and would give inf / inf = NaN.
    align_prob = torch.sigmoid(100 * (max_positive_cos - max_negative_cos))

    align_logit = torch.log(align_prob + 1e-10)

    return align_logit, align_prob


def systematic_resample(x, weights, device):
    """Systematically resample particles ``x`` in proportion to ``weights``.

    A single uniform offset ``r ~ U[0, 1)`` is drawn and the positions
    ``(i + r) / N`` for ``i = 0..N-1`` are located in the normalized weight
    CDF. This gives every particle ``floor(N w_i)`` or ``ceil(N w_i)`` copies.

    Args:
        x: Tensor whose first dimension indexes the N particles.
        weights: 1-D tensor of N non-negative, unnormalized weights.
        device: Device on which the sampling positions are created (must match
            ``weights``' device for ``torch.searchsorted``).

    Returns:
        ``(x_resampled, indices)``: ``x[indices]`` and the int64 ancestor
        indices of shape (N,).
    """
    num_samples = x.shape[0]

    weights_normalized = weights / weights.sum()
    cumsum = torch.cumsum(weights_normalized, dim=0)

    # Create positions in the CDF's dtype so searchsorted compares like with
    # like (e.g. float64 weights against float64 positions).
    positions = torch.arange(num_samples, device=device, dtype=cumsum.dtype)
    offset = torch.rand(1, device=device, dtype=cumsum.dtype)
    uniform_samples = (positions + offset) / num_samples

    indices = torch.searchsorted(cumsum, uniform_samples)
    # Round-off can leave cumsum[-1] slightly below 1; keep indices in range.
    indices = torch.clamp(indices, 0, num_samples - 1)

    return x[indices], indices


def euler_sampler(model, z0, N=400, device='cuda', verbose=True, 
                  guidance_scale=1.0, beta=2.0,
                  start_resample=30, start_flow=160, n_add=200,
                  positive_texts=["a photo of a woman"],
                  negative_texts=["a photo of a man"],
                  model_path='./open_clip_pytorch_model.bin'):
    eps = 1e-3
    dt = (1.0 - eps) / N
    
    x = z0.to(device)
    num_samples = x.shape[0]
    model.eval()
    
    sample_weights = None
    log_align_sf = None
    
    iterator = tqdm(range(N), desc="Sampling")
    for i in iterator:
        t = i / N * (1.0 - eps) + eps
        t_tensor = torch.ones(x.shape[0], device=device) * t
        t_next = (i + 1) / N * (1.0 - eps) + eps
        t_next_tensor = torch.ones(x.shape[0], device=device) * t_next
        
        if i < start_resample:
            with torch.no_grad():
                pred = model(x, t_tensor * 999)
                x = x + pred * dt
        
        elif i == start_resample:
            with torch.no_grad():
                pred = model(x, t_tensor * 999)
                x = x + pred * dt
                
                pred_next = model(x, t_next_tensor * 999)
                x_next_lkah = x + pred_next * (1 - t_next)
                
                align_logit, _ = alignment_eval_with_grad(x_next_lkah,
                                                         positive_texts=positive_texts,
                                                         negative_texts=negative_texts,
                                                         model_path=model_path)
                
                log_weights = align_logit
                weights = torch.exp(log_weights - log_weights.max())
                
                x, _ = systematic_resample(x, weights, device)
        
        elif start_resample < i < start_flow - 1:
            with torch.no_grad():
                pred = model(x, t_tensor * 999)
                x_cur_lkah = x + pred * (1 - t)
                drift_org = -(beta**2 / t) * x + (1 + beta**2) * pred
            
            if guidance_scale > 0:
                x_lkad = (x + pred * (1 - t)).detach().requires_grad_(True)
                
                align_logit_temp, _ = alignment_eval_with_grad(x_lkad,
                                                              positive_texts=positive_texts,
                                                              negative_texts=negative_texts,
                                                              model_path=model_path)
                
                vadd = torch.autograd.grad(outputs=align_logit_temp.sum(),
                                          inputs=x_lkad,
                                          create_graph=False,
                                          retain_graph=False)[0]
                
                with torch.no_grad():
                    v_mod = pred + ((1 - t) / t) * guidance_scale * vadd
            else:
                v_mod = pred
            
            with torch.no_grad():
                x_prev = x.clone()
                
                epsilon = torch.randn_like(x)
                drift = -(beta**2 / t) * x + (1 + beta**2) * v_mod
                diffusion_arg = torch.tensor(2 * (1 - t) * dt / t, device=device, dtype=x.dtype)
                diffusion_coeff = beta * torch.sqrt(diffusion_arg)
                
                x = x + drift * dt + diffusion_coeff * epsilon
                
                pred_next = model(x, t_next_tensor * 999)
                x_next_lkah = x + pred_next * (1 - t_next)
                
                mean_prop = x_prev + drift * dt
                log_prop = -0.5 * ((x - mean_prop) ** 2).sum(dim=(1,2,3)) / (diffusion_coeff ** 2 + 1e-8)
                
                mean_org = x_prev + drift_org * dt
                log_org = -0.5 * ((x - mean_org) ** 2).sum(dim=(1,2,3)) / (diffusion_coeff ** 2 + 1e-8)
                
                align_logit_next, _ = alignment_eval_with_grad(x_next_lkah,
                                                              positive_texts=positive_texts,
                                                              negative_texts=negative_texts,
                                                              model_path=model_path)
                
                align_logit_cur, _ = alignment_eval_with_grad(x_cur_lkah,
                                                             positive_texts=positive_texts,
                                                             negative_texts=negative_texts,
                                                             model_path=model_path)
                
                log_weights = align_logit_next + log_org - align_logit_cur - log_prop
                weights = torch.exp(log_weights - log_weights.max())
                
                x, _ = systematic_resample(x, weights, device)
        
        elif i == start_flow - 1:
            with torch.no_grad():
                pred = model(x, t_tensor * 999)
                x_cur_lkah = x + pred * (1 - t)
                drift_org = -(beta**2 / t) * x + (1 + beta**2) * pred
            
            if guidance_scale > 0:
                x_lkad = (x + pred * (1 - t)).detach().requires_grad_(True)
                
                align_logit_temp, _ = alignment_eval_with_grad(x_lkad,
                                                              positive_texts=positive_texts,
                                                              negative_texts=negative_texts,
                                                              model_path=model_path)
                
                vadd = torch.autograd.grad(outputs=align_logit_temp.sum(),
                                          inputs=x_lkad,
                                          create_graph=False,
                                          retain_graph=False)[0]
                
                with torch.no_grad():
                    v_mod = pred + ((1 - t) / t) * guidance_scale * vadd
            else:
                v_mod = pred
            
            with torch.no_grad():
                x_prev = x.clone()
                
                epsilon = torch.randn_like(x)
                drift = -(beta**2 / t) * x + (1 + beta**2) * v_mod
                diffusion_arg = torch.tensor(2 * (1 - t) * dt / t, device=device, dtype=x.dtype)
                diffusion_coeff = beta * torch.sqrt(diffusion_arg)
                
                x = x + drift * dt + diffusion_coeff * epsilon
                
                pred_next = model(x, t_next_tensor * 999)
                x_next_lkah = x + pred_next * (1 - t_next)
                
                mean_prop = x_prev + drift * dt
                log_prop = -0.5 * ((x - mean_prop) ** 2).sum(dim=(1,2,3)) / (diffusion_coeff ** 2 + 1e-8)
                
                mean_org = x_prev + drift_org * dt
                log_org = -0.5 * ((x - mean_org) ** 2).sum(dim=(1,2,3)) / (diffusion_coeff ** 2 + 1e-8)
                
                align_logit_next, _ = alignment_eval_with_grad(x_next_lkah,
                                                              positive_texts=positive_texts,
                                                              negative_texts=negative_texts,
                                                              model_path=model_path)
                
                align_logit_cur, _ = alignment_eval_with_grad(x_cur_lkah,
                                                             positive_texts=positive_texts,
                                                             negative_texts=negative_texts,
                                                             model_path=model_path)
                
                log_weights = align_logit_next + log_org - align_logit_cur - log_prop
                
                sample_weights = torch.exp(log_weights)
        
        else:
            with torch.no_grad():
                pred = model(x, t_tensor * 999)
                
                if i == start_flow:
                    x_sf_lkad = x + pred * (1 - t)
                    log_align_sf, _ = alignment_eval_with_grad(x_sf_lkad,
                                                              positive_texts=positive_texts,
                                                              negative_texts=negative_texts,
                                                              model_path=model_path)
                
                if i <= n_add:
                    drift_org = -(beta**2 / t) * x + (1 + beta**2) * pred
                    epsilon = torch.randn_like(x)
                    diffusion_arg = torch.tensor(2 * (1 - t) * dt / t, device=device, dtype=x.dtype)
                    diffusion_coeff = beta * torch.sqrt(diffusion_arg)
                    
                    x = x + drift_org * dt + diffusion_coeff * epsilon
                else:
                    x = x + pred * dt
                
                if i == N - 1:
                    log_align_final, _ = alignment_eval_with_grad(x,
                                                                 positive_texts=positive_texts,
                                                                 negative_texts=negative_texts,
                                                                 model_path=model_path)
                    
                    log_weights = log_align_final - log_align_sf
                    
                    sample_weights = sample_weights * torch.exp(log_weights)
    
    return x, sample_weights


def set_seed(seed=42):
    """Seed the torch (CPU and all CUDA devices), NumPy and stdlib RNGs.

    Also switches cuDNN to deterministic kernels and disables its
    autotuning benchmark mode, trading speed for reproducibility.

    Args:
        seed: integer seed applied to every generator.
    """
    # Same order as before: torch CPU, torch CUDA, NumPy, Python `random`.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Text prompts for the alignment score. `main` passes "positive_text" and
# "negative_texts" to `euler_sampler` as `positive_texts` / `negative_texts`.
# "attribute" is not read in this part of the file (presumably a label only —
# confirm before relying on it).
PROMPT_CONFIG = dict(
    attribute="Female",
    positive_text=["a photo of a woman"],
    negative_texts=[
        "a photo of a man",
        "a photo of a masculine person",
        "a photo of a male individual",
        "a photo of a gentleman",
    ],
)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from keys.

    Such prefixes typically come from saving a ``DataParallel``/DDP-wrapped
    model; the bare ``NCSNpp`` instance expects un-prefixed keys.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_model(config, checkpoint_path, device):
    """Build an ``NCSNpp`` from *config*, load *checkpoint_path* into it, set eval mode.

    The checkpoint may be either a raw state dict or a dict holding the state
    dict under ``'model'``. Loading is non-strict, but any missing/unexpected
    keys are reported as a warning instead of being silently ignored, so a
    mismatched checkpoint does not quietly produce a partially untrained model.
    """
    import warnings

    model = NCSNpp(**config).to(device)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point --checkpoint at files from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    incompatible = model.load_state_dict(_strip_module_prefix(state_dict), strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint '{checkpoint_path}' does not fully match the model: "
            f"{len(incompatible.missing_keys)} missing key(s) "
            f"(e.g. {incompatible.missing_keys[:3]}), "
            f"{len(incompatible.unexpected_keys)} unexpected key(s) "
            f"(e.g. {incompatible.unexpected_keys[:3]})."
        )

    model.eval()
    return model


def main():
    """CLI entry point: sample images, resample them by weight, save a grid.

    Parses command-line options, loads the ``NCSNpp`` checkpoint, runs
    ``euler_sampler`` in batches with the prompts from ``PROMPT_CONFIG``,
    pools the returned samples and weights from all batches, applies
    ``systematic_resample`` on CPU, and writes one image grid to
    ``output/female_<settings>.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_steps', type=int, default=400)
    parser.add_argument('--num_samples', type=int, default=25)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--checkpoint', type=str, default='checkpoint_10.pth')
    parser.add_argument('--guidance_scale', type=float, default=1.0)
    parser.add_argument('--beta', type=float, default=2.0)
    parser.add_argument('--start_resample', type=int, default=30)
    parser.add_argument('--start_flow', type=int, default=160)
    parser.add_argument('--n_add', type=int, default=200)
    parser.add_argument('--clip_model_path', type=str, default='./open_clip_pytorch_model.bin')
    parser.add_argument('--gpu', type=int, default=0)

    args = parser.parse_args()

    # Fail fast with a clear message: otherwise a zero batch size raises
    # ZeroDivisionError in the batch count, and a non-positive sample count
    # (or negative batch size) makes torch.cat fail on an empty list.
    if args.num_samples < 1:
        parser.error('--num_samples must be a positive integer')
    if args.batch_size is not None and args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')

    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    config = {
        'image_size': 256,
        'num_channels': 3,
        'nf': 128,
        'ch_mult': (1, 1, 2, 2, 2, 2, 2),
        'num_res_blocks': 2,
        'attn_resolutions': (16,),
        'dropout': 0.0,
        'resamp_with_conv': True,
        'fir': True,
        'fir_kernel': [1, 3, 3, 1],
        'skip_rescale': True,
        'resblock_type': 'biggan',
        'progressive': 'output_skip',
        'progressive_input': 'input_skip',
        'progressive_combine': 'sum',
        'embedding_type': 'fourier',
        'fourier_scale': 16,
        'init_scale': 0.,
        'centered': True,
        'scale_by_sigma': True,
    }

    model = _load_model(config, args.checkpoint, device)

    positive_texts = PROMPT_CONFIG["positive_text"]
    negative_texts = PROMPT_CONFIG["negative_texts"]

    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)

    base_filename = f"seed{args.seed}_beta{args.beta:.1f}_gs{args.guidance_scale:.1f}_sr{args.start_resample}_sf{args.start_flow}_nadd{args.n_add}"
    combine_resample_filename = os.path.join(output_dir, f"female_{base_filename}.png")

    batch_size = args.batch_size if args.batch_size is not None else args.num_samples

    set_seed(args.seed)

    all_samples = []
    all_sample_weights = []
    num_batches = (args.num_samples + batch_size - 1) // batch_size  # ceil division

    for batch_idx in range(num_batches):
        # The last batch may be smaller when num_samples is not a multiple of batch_size.
        current_batch_size = min(batch_size, args.num_samples - batch_idx * batch_size)

        # Created on the default (CPU) device; presumably euler_sampler moves it
        # to `device` — confirm in its implementation.
        z0 = torch.randn(current_batch_size, config['num_channels'],
                         config['image_size'], config['image_size'])

        samples, sample_weights = euler_sampler(
            model, z0, N=args.num_steps, device=device,
            verbose=True,
            guidance_scale=args.guidance_scale,
            beta=args.beta,
            start_resample=args.start_resample,
            start_flow=args.start_flow,
            n_add=args.n_add,
            positive_texts=positive_texts,
            negative_texts=negative_texts,
            model_path=args.clip_model_path
        )

        all_samples.append(samples.cpu())
        all_sample_weights.append(sample_weights.cpu())

    all_samples = torch.cat(all_samples, dim=0)
    all_sample_weights = torch.cat(all_sample_weights, dim=0)

    # Resample across the pooled set of all batches, not per batch.
    resampled_samples, _ = systematic_resample(
        all_samples, all_sample_weights, device='cpu'
    )

    # Map the presumed [-1, 1] model output range to [0, 1] for saving.
    resampled_samples_display = (resampled_samples + 1.0) / 2.0
    resampled_samples_display = torch.clamp(resampled_samples_display, 0.0, 1.0)

    nrow = int(np.ceil(np.sqrt(args.num_samples)))
    save_image(resampled_samples_display, combine_resample_filename, nrow=nrow, padding=2)


if __name__ == '__main__':
    main()