import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.utils as vutils
import math
from functools import partial
import random
import os
from tqdm import tqdm

def set_random_seed(seed=42):
    """Seed every RNG this file relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy, PyTorch's CPU generator and the CUDA
    generators (current device and all devices), disables cuDNN autotuning and
    requests deterministic cuDNN kernels. ``PYTHONHASHSEED`` is exported too;
    note that setting it at runtime only affects child processes, not the
    current interpreter's string hashing.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

class ExponentialMovingAverage:
    """Exponential moving average (EMA) of a model's trainable parameters.

    Only parameters with ``requires_grad=True`` are tracked. Every method
    filters its ``parameters`` argument the same way and pairs the result
    with ``shadow_params`` by position, so callers must always pass the same
    parameters in the same order (e.g. ``model.parameters()``).
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of tensors/parameters to track.
            decay: EMA decay in [0, 1].
            use_num_updates: If true, the effective decay is warmed up as
                ``min(decay, (1 + n) / (10 + n))`` where ``n`` counts updates.

        Raises:
            ValueError: If ``decay`` is outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        self.collected_params = []

    @staticmethod
    def _trainable(parameters):
        """Materialize the trainable (``requires_grad``) subset of ``parameters``."""
        return [p for p in parameters if p.requires_grad]

    def update(self, parameters):
        """Move every shadow parameter toward the current parameter value."""
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            for s_param, param in zip(self.shadow_params, self._trainable(parameters)):
                # shadow <- shadow - (1 - decay) * (shadow - param)
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """Overwrite the trainable ``parameters`` in place with the EMA values."""
        for s_param, param in zip(self.shadow_params, self._trainable(parameters)):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """Save detached copies of the trainable ``parameters`` for :meth:`restore`."""
        # detach() so the saved copies are plain tensors, not autograd graph nodes.
        self.collected_params = [param.detach().clone() for param in self._trainable(parameters)]

    def restore(self, parameters):
        """Copy the values saved by :meth:`store` back into ``parameters``."""
        for c_param, param in zip(self.collected_params, self._trainable(parameters)):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return ``decay``, ``num_updates`` and the (live) ``shadow_params`` list."""
        return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict):
        """Load a dict produced by :meth:`state_dict`.

        The incoming shadow parameters are checked against the current ones
        (count and shapes) before anything is assigned. A mismatched checkpoint
        is therefore rejected up front, instead of being silently truncated by
        ``zip`` or failing later inside :meth:`copy_to`.

        Raises:
            KeyError: If a required key is missing.
            ValueError: On a parameter-count or shape mismatch.
        """
        shadow_params = state_dict['shadow_params']
        decay = state_dict['decay']
        num_updates = state_dict['num_updates']

        if len(shadow_params) != len(self.shadow_params):
            raise ValueError(
                f'EMA state has {len(shadow_params)} shadow params, '
                f'expected {len(self.shadow_params)}')
        for idx, (new, cur) in enumerate(zip(shadow_params, self.shadow_params)):
            if tuple(new.shape) != tuple(cur.shape):
                raise ValueError(
                    f'EMA shadow param {idx} has shape {tuple(new.shape)}, '
                    f'expected {tuple(cur.shape)}')

        self.decay = decay
        self.num_updates = num_updates
        self.shadow_params = shadow_params

class Config:
    def __init__(self):
        self.data = type('data', (), {
            'dataset': 'CIFAR10',
            'image_size': 32,
            'num_channels': 3,
            'centered': True,
            'random_flip': True,
            'uniform_dequantization': False
        })()
        
        self.model = type('model', (), {
            'name': 'ncsnpp',
            'scale_by_sigma': False,
            'ema_rate': 0.999999,
            'dropout': 0.15,
            'normalization': 'GroupNorm',
            'nonlinearity': 'swish',
            'nf': 128,
            'ch_mult': (1, 2, 2, 2),
            'num_res_blocks': 4,
            'attn_resolutions': (16,),
            'resamp_with_conv': True,
            'conditional': True,
            'fir': False,
            'fir_kernel': [1, 3, 3, 1],
            'skip_rescale': True,
            'resblock_type': 'biggan',
            'progressive': 'none',
            'progressive_input': 'none',
            'progressive_combine': 'sum',
            'attention_type': 'ddpm',
            'init_scale': 0.,
            'embedding_type': 'positional',
            'fourier_scale': 16,
            'conv_size': 3,
            'sigma_min': 0.01,
            'sigma_max': 50,
            'num_scales': 1000,
            'beta_min': 0.1,
            'beta_max': 20.
        })()
        
        self.training = type('training', (), {
            'sde': 'rectified_flow',
            'continuous': False,
            'reduce_mean': True
        })()
        
        self.sampling = type('sampling', (), {
            'method': 'rectified_flow',
            'init_type': 'gaussian',
            'init_noise_scale': 1.0,
            'use_ode_sampler': 'rk45',
            'ode_tol': 1e-5,
            'sample_N': 1000
        })()

def get_sigmas(config):
    """Return ``config.model.num_scales`` noise levels as a NumPy array.

    The levels are evenly spaced in log space, running from
    ``config.model.sigma_max`` down to ``config.model.sigma_min``.
    """
    log_high = np.log(config.model.sigma_max)
    log_low = np.log(config.model.sigma_min)
    log_levels = np.linspace(log_high, log_low, config.model.num_scales)
    return np.exp(log_levels)

def get_act():
    """Return a fresh ``nn.SiLU`` module, the nonlinearity used throughout the network."""
    activation = nn.SiLU()
    return activation

def variance_scaling(scale, mode, distribution, dtype=torch.float32, device='cpu',
                     in_axis=1, out_axis=0):
    """Return a variance-scaling initializer ``init(shape, dtype=..., device=...)``.

    Fans follow the JAX / score_sde convention. ``shape[in_axis]`` and
    ``shape[out_axis]`` hold the input and output feature counts, and every
    remaining axis is a receptive field that multiplies both fans:

    * conv weight ``(out, in, kh, kw)``: ``fan_in = in*kh*kw``,
      ``fan_out = out*kh*kw``;
    * 2-D weight ``(out, in)``: ``fan_in = in``, ``fan_out = out``;
    * shapes with fewer than two dimensions: ``fan_in = fan_out = 1``.

    The target variance is ``scale / denominator``, where the denominator is
    the fan selected by ``mode``.

    Args:
        scale: Variance multiplier.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``.
        distribution: ``"normal"`` or ``"uniform"``. Uniform samples lie in
            ``[-sqrt(3 * variance), sqrt(3 * variance)]``.
        dtype: Default dtype of the generated tensor.
        device: Default device of the generated tensor.
        in_axis: Axis holding the input feature count.
        out_axis: Axis holding the output feature count.

    Raises:
        ValueError: From the returned ``init`` for an unknown ``mode`` or
            ``distribution``.
    """

    def _compute_fans(shape):
        if len(shape) < 2:
            return 1, 1
        receptive_field = np.prod(shape) / shape[in_axis] / shape[out_axis]
        return shape[in_axis] * receptive_field, shape[out_axis] * receptive_field

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape)

        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(f"invalid mode: {mode}")

        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
        else:
            raise ValueError(f"invalid distribution: {distribution}")
    return init

def default_init(scale=1.):
    """Return the file's default initializer: fan-average, uniform variance scaling.

    A ``scale`` of exactly zero is replaced by 1e-10. The resulting weights
    are therefore tiny but not identically zero.
    """
    effective_scale = scale if scale != 0 else 1e-10
    return variance_scaling(effective_scale, 'fan_avg', 'uniform')

def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """Build a 3x3 ``nn.Conv2d``.

    The weight is re-initialized with ``default_init(init_scale)``, and the
    bias (when present) is zeroed.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
    )
    initializer = default_init(init_scale)
    layer.weight.data = initializer(layer.weight.data.shape)
    if bias:
        nn.init.zeros_(layer.bias)
    return layer

def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """Build a 1x1 ``nn.Conv2d``.

    The weight is re-initialized with ``default_init(init_scale)``, and the
    bias (when present) is zeroed.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    initializer = default_init(init_scale)
    layer.weight.data = initializer(layer.weight.data.shape)
    if bias:
        nn.init.zeros_(layer.bias)
    return layer

def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Sinusoidal embedding of ``timesteps`` with ``embedding_dim`` features.

    The first half of the features holds ``sin(t * f_k)`` and the second half
    holds ``cos(t * f_k)``. The frequencies ``f_k`` decay geometrically from 1
    to ``1 / max_positions``. For an odd ``embedding_dim``, one zero column is
    appended.

    ``timesteps`` is presumably 1-D of shape ``(B,)``, giving a
    ``(B, embedding_dim)`` result. ``embedding_dim`` must be at least 4, or
    the frequency step divides by zero.
    """
    half = embedding_dim // 2
    log_step = math.log(max_positions) / (half - 1)
    positions = torch.arange(half, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(positions * -log_step)
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
    if embedding_dim % 2:
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb

class NIN(nn.Module):
    """Per-pixel linear map over the channel axis (equivalent to a 1x1 convolution).

    ``W`` has shape ``(in_dim, num_units)`` and is initialized with
    ``default_init(init_scale)``. ``b`` is a zero-initialized bias of length
    ``num_units``.
    """

    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        weight = default_init(scale=init_scale)((in_dim, num_units))
        self.W = nn.Parameter(weight, requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        # Move channels last so the einsum contracts them against W, then move them back.
        channels_last = x.permute(0, 2, 3, 1)
        projected = torch.einsum('bhwc,cd->bhwd', channels_last, self.W) + self.b
        return projected.permute(0, 3, 1, 2)

class AttnBlockpp(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual connection.

    Query, key and value are NIN projections of the GroupNorm-ed input. The
    scores are scaled by ``C ** -0.5`` and softmax-normalized over all
    ``H * W`` key positions. The output goes through a fourth NIN (initialized
    with ``init_scale``) and is added to the input. When ``skip_rescale`` is
    set, the sum is divided by ``sqrt(2)``.
    """

    def __init__(self, channels, skip_rescale=False, init_scale=0.):
        super().__init__()
        groups = min(channels // 4, 32)
        self.GroupNorm_0 = nn.GroupNorm(num_groups=groups, num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        batch, channels, height, width = x.shape
        normed = self.GroupNorm_0(x)
        query = self.NIN_0(normed)
        key = self.NIN_1(normed)
        value = self.NIN_2(normed)

        # Scores between every query position (h, w) and every key position (i, j).
        scores = torch.einsum('bchw,bcij->bhwij', query, key) * (int(channels) ** (-0.5))
        flat_scores = torch.reshape(scores, (batch, height, width, height * width))
        weights = torch.reshape(F.softmax(flat_scores, dim=-1), (batch, height, width, height, width))
        attended = self.NIN_3(torch.einsum('bhwij,bcij->bchw', weights, value))

        residual = x + attended
        if self.skip_rescale:
            return residual / np.sqrt(2.)
        return residual

def naive_upsample_2d(x, factor=2):
    """Nearest-neighbour upsample a 4-D ``(N, C, H, W)`` tensor by an integer ``factor``.

    Each pixel is replicated into a ``factor x factor`` patch, so the result
    has shape ``(N, C, H * factor, W * factor)``.
    """
    _, channels, height, width = x.shape
    grid = torch.reshape(x, (-1, channels, height, 1, width, 1))
    tiled = grid.expand(-1, -1, -1, factor, -1, factor)
    return torch.reshape(tiled, (-1, channels, height * factor, width * factor))

def naive_downsample_2d(x, factor=2):
    """Average-pool a 4-D ``(N, C, H, W)`` tensor over non-overlapping ``factor x factor`` windows.

    ``H`` and ``W`` must be divisible by ``factor``; otherwise the reshape fails.
    """
    _, channels, height, width = x.shape
    windows = torch.reshape(x, (-1, channels, height // factor, factor, width // factor, factor))
    return windows.mean(dim=(3, 5))

class ResnetBlockBigGANpp(nn.Module):
    """BigGAN-style residual block with optional 2x up/down-sampling and timestep conditioning.

    Main branch: GroupNorm -> act -> (resample) -> conv3x3
    -> (+ Dense(act(temb)) broadcast over space) -> GroupNorm -> act
    -> dropout -> conv3x3.

    The shortcut is resampled the same way and, when the channel count or
    resolution changes, projected with a 1x1 conv. With ``skip_rescale``, the
    sum is divided by ``sqrt(2)``.

    Submodules are registered in a fixed order (GroupNorm_0, Conv_0, Dense_0,
    GroupNorm_1, Dropout_0, Conv_1, Conv_2) so ``parameters()`` order stays
    stable.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
                 dropout=0.1, skip_rescale=True, init_scale=0.):
        super().__init__()
        out_ch = out_ch or in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up, self.down = up, down
        self.Conv_0 = conv3x3(in_ch, out_ch)

        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        # The final conv uses init_scale (0 -> near-zero init), so the block starts close to identity.
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)

        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        # Resample the main branch and the shortcut identically.
        if self.up:
            h, x = naive_upsample_2d(h, factor=2), naive_upsample_2d(x, factor=2)
        elif self.down:
            h, x = naive_downsample_2d(h, factor=2), naive_downsample_2d(x, factor=2)

        h = self.Conv_0(h)
        if temb is not None:
            h = h + self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))

        projects_shortcut = self.in_ch != self.out_ch or self.up or self.down
        shortcut = self.Conv_2(x) if projects_shortcut else x

        merged = shortcut + h
        return merged / np.sqrt(2.) if self.skip_rescale else merged

class NCSNpp(nn.Module):
    """U-Net built from ResnetBlockBigGANpp and AttnBlockpp blocks.

    All submodules are stored in one flat ``nn.ModuleList`` (``all_modules``).
    ``forward`` consumes them in order using a running index ``m_idx``, so the
    order in which ``__init__`` appends modules must exactly match the order
    in which ``forward`` reads them.

    ``resblock_type``, ``progressive``, ``progressive_input`` and
    ``resamp_with_conv`` are read from the config but do not change the
    architecture built here.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.act = get_act()
        
        # Noise levels from get_sigmas (float64, since torch.tensor keeps the NumPy dtype).
        # Registered as a buffer, so they are part of the state_dict; forward indexes them
        # with time_cond.long().
        self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
        
        self.nf = nf = config.model.nf
        ch_mult = config.model.ch_mult
        self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
        self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        # NOTE(review): resamp_with_conv is read but never used below.
        resamp_with_conv = config.model.resamp_with_conv
        self.num_resolutions = num_resolutions = len(ch_mult)
        # Spatial size at each level: image_size halved once per level.
        self.all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]
        
        self.conditional = config.model.conditional
        self.skip_rescale = config.model.skip_rescale
        self.resblock_type = config.model.resblock_type.lower()
        self.progressive = config.model.progressive.lower()
        self.progressive_input = config.model.progressive_input.lower()
        self.embedding_type = config.model.embedding_type.lower()
        init_scale = config.model.init_scale
        
        modules = []
        
        # Only sinusoidal ('positional') timestep embeddings of width nf are supported.
        if self.embedding_type == 'positional':
            embed_dim = nf
        else:
            raise ValueError(f'embedding type {self.embedding_type} unknown.')

        # Timestep MLP: embed_dim -> 4*nf -> 4*nf (all_modules[0] and [1] when conditional).
        if self.conditional:
            modules.append(nn.Linear(embed_dim, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

        AttnBlock = partial(AttnBlockpp, init_scale=init_scale, skip_rescale=self.skip_rescale)
        # Every resnet block is conditioned on the 4*nf-wide timestep embedding.
        ResnetBlock = partial(ResnetBlockBigGANpp, act=self.act, dropout=dropout,
                            skip_rescale=self.skip_rescale, init_scale=init_scale,
                            temb_dim=nf * 4)

        # Input projection: image channels -> nf.
        channels = config.data.num_channels
        modules.append(conv3x3(channels, nf))
        # hs_c records the channel count of every feature map that forward pushes onto its
        # skip stack, so the upsampling path knows the width of each concatenation.
        hs_c = [nf]

        # Downsampling path: num_res_blocks resnet blocks per level (each followed by
        # attention at attn_resolutions), then a down=True resnet block on every level
        # except the last.
        in_ch = nf
        for i_level in range(num_resolutions):
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch
                if self.all_resolutions[i_level] in attn_resolutions:
                    modules.append(AttnBlock(channels=in_ch))
                hs_c.append(in_ch)

            if i_level != num_resolutions - 1:
                modules.append(ResnetBlock(down=True, in_ch=in_ch))
                hs_c.append(in_ch)

        # Bottleneck: resnet -> attention -> resnet at the lowest resolution.
        in_ch = hs_c[-1]
        modules.append(ResnetBlock(in_ch=in_ch))
        modules.append(AttnBlock(channels=in_ch))
        modules.append(ResnetBlock(in_ch=in_ch))

        # Upsampling path: num_res_blocks + 1 resnet blocks per level, each taking one
        # concatenated skip (width popped from hs_c). Then optional attention, then an
        # up=True resnet block on every level except the top.
        for i_level in reversed(range(num_resolutions)):
            for i_block in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch

            if self.all_resolutions[i_level] in attn_resolutions:
                modules.append(AttnBlock(channels=in_ch))

            if i_level != 0:
                modules.append(ResnetBlock(in_ch=in_ch, up=True))

        # Output head: GroupNorm (act is applied in forward) -> conv3x3 back to image channels.
        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 
                                  num_channels=in_ch, eps=1e-6))
        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))

        self.all_modules = nn.ModuleList(modules)

    def forward(self, x, time_cond):
        """Run the U-Net on ``x`` conditioned on ``time_cond``.

        Args:
            x: Image batch. Presumably ``(B, C, H, W)`` with ``H == W`` -- TODO confirm.
                Attention placement below compares only the last spatial dim
                against ``attn_resolutions``.
            time_cond: Per-sample time values. They are embedded sinusoidally and
                also cast to ``long`` to index ``self.sigmas``, so they must be
                valid indices into that buffer.

        Returns:
            Tensor with ``config.data.num_channels`` channels. It is divided by the
            indexed sigmas when ``config.model.scale_by_sigma`` is set.
        """
        modules = self.all_modules
        m_idx = 0
        
        if self.embedding_type == 'positional':
            timesteps = time_cond
            used_sigmas = self.sigmas[time_cond.long()]
            temb = get_timestep_embedding(timesteps, self.nf)
        else:
            raise ValueError(f'embedding type {self.embedding_type} unknown.')

        # Timestep MLP (the two Linear layers appended first in __init__).
        if self.conditional:
            temb = modules[m_idx](temb)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
        else:
            temb = None

        # Map inputs to [-1, 1] when the data is not already centered (presumably given in [0, 1]).
        if not self.config.data.centered:
            x = 2 * x - 1.

        # hs is the skip stack; its pushes mirror the hs_c bookkeeping in __init__.
        hs = [modules[m_idx](x)]
        m_idx += 1
        
        # Downsampling path.
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                if h.shape[-1] in self.attn_resolutions:
                    h = modules[m_idx](h)
                    m_idx += 1
                hs.append(h)

            if i_level != self.num_resolutions - 1:
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                hs.append(h)

        # Bottleneck: resnet -> attention -> resnet.
        h = hs[-1]
        h = modules[m_idx](h, temb)
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1
        h = modules[m_idx](h, temb)
        m_idx += 1

        # Upsampling path: every resnet block consumes one skip from hs along the channel axis.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

            if h.shape[-1] in self.attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1

            if i_level != 0:
                h = modules[m_idx](h, temb)
                m_idx += 1

        # Output head: GroupNorm -> act -> conv3x3.
        h = self.act(modules[m_idx](h))
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1

        if self.config.model.scale_by_sigma:
            # Broadcast one sigma per sample over all non-batch dims.
            used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
            h = h / used_sigmas

        return h

class VGG(nn.Module):
    """VGG-style classifier: ``features`` -> global average pool -> 3-layer MLP head.

    The head's first Linear expects ``features`` to emit 512 channels. With
    ``init_weights``, all weights are re-initialized: convs get Kaiming-normal
    (fan_out, ReLU), BatchNorm gets weight 1 and bias 0, and Linear layers get
    N(0, 0.01) weights with zero bias.
    """

    def __init__(self, features, num_classes=10, init_weights=True):
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, hidden),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

def make_layers(cfg, batch_norm=False):
    """Build a VGG feature extractor from a layer spec, starting from 3 input channels.

    Each integer entry adds a 3x3 conv (padding 1) with that many output
    channels, optionally followed by BatchNorm2d, then an in-place ReLU. The
    entry ``"M"`` adds a 2x2 max-pool with stride 2.
    """
    layers = []
    channels = 3
    for entry in cfg:
        if entry == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(entry))
        layers.append(nn.ReLU(inplace=True))
        channels = entry
    return nn.Sequential(*layers)

# VGG-13 layer spec for make_layers: conv widths, with "M" marking a 2x2 max-pool.
cfg_B = [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]

def vgg13_bn(pretrained=False, weights_path=None):
    """Build a 10-class VGG-13 with batch norm, optionally loading a state dict.

    Args:
        pretrained: If true, load weights from ``weights_path`` (on CPU).
        weights_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The ``VGG`` model on CPU.

    If ``pretrained`` is requested but ``weights_path`` is missing or points
    to a non-existent file, a ``UserWarning`` is emitted and the randomly
    initialized model is returned. Classifier guidance with such a model would
    be meaningless, so the warning makes the condition visible.
    """
    import warnings

    model = VGG(make_layers(cfg_B, batch_norm=True), num_classes=10)

    if pretrained:
        if not weights_path:
            warnings.warn("vgg13_bn: pretrained=True but no weights_path given; "
                          "returning randomly initialized weights.")
        elif not os.path.exists(weights_path):
            warnings.warn(f"vgg13_bn: weights file {weights_path!r} not found; "
                          "returning randomly initialized weights.")
        else:
            # NOTE(review): torch.load unpickles the file; only load weights from trusted sources.
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict)

    return model

class RectifiedFlow:
    """Rectified-flow sampler with classifier guidance and SMC-style reweighting.

    Time runs from ``eps`` (Gaussian noise from :meth:`get_z0`) toward
    ``T = 1.0``. The model is called as ``model(x, t_index)`` with
    ``t_index = clamp(t * 999, 0, 999).long()`` and is used as a velocity
    field. The one-step lookahead ``x + (1 - t) * v`` estimates the endpoint
    at ``T``.

    The optional ``classifier`` scores images for ``target_class``. Inputs are
    mapped with ``(x + 1) / 2`` and then normalized per channel with ``mean`` /
    ``std`` before being fed to it. The SMC phases call the classifier, so
    they require one.
    """

    def __init__(self, config, classifier=None, guidance_scale=1.0, beta=1.0, target_class=1):
        self.init_type = config.sampling.init_type
        self.noise_scale = config.sampling.init_noise_scale
        self.T = 1.0
        self.classifier = classifier
        self.guidance_scale = guidance_scale
        self.beta = beta

        # Per-channel normalization applied to (x + 1) / 2 before the classifier.
        self.mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
        self.std = torch.tensor([0.2471, 0.2435, 0.2616]).view(1, 3, 1, 1)

        self.target_class = target_class

    def get_z0(self, shape, device, generator=None):
        """Draw the initial state: standard normal noise scaled by ``init_noise_scale``.

        Raises:
            NotImplementedError: For any ``init_type`` other than ``'gaussian'``.
        """
        if self.init_type == 'gaussian':
            if generator is not None:
                return torch.randn(shape, device=device, generator=generator) * self.noise_scale
            else:
                return torch.randn(shape, device=device) * self.noise_scale
        else:
            raise NotImplementedError("INITIALIZATION TYPE NOT IMPLEMENTED")

    def _classifier_input(self, x, device):
        """Map model-space images (assumed in [-1, 1]) to the classifier's normalized input."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return ((x + 1) / 2 - self.mean) / self.std

    def compute_vadd(self, x, velocity, t_current, device):
        """Return the gradient of ``log p(target_class | x + (1 - t_current) * velocity)`` w.r.t. ``x``.

        ``velocity`` is treated as a constant: its dependence on ``x`` is not
        differentiated, so the gradient flows only through the identity term.
        Gradient tracking is enabled locally, so this also works when the
        caller runs under ``torch.no_grad()``.
        """
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            y = x + (1 - t_current) * velocity.detach()
            logits = self.classifier(self._classifier_input(y, device))
            target_prob = F.softmax(logits, dim=1)[:, self.target_class]
            log_prob = torch.log(target_prob + 1e-8)
            grad = torch.autograd.grad(log_prob.sum(), x, create_graph=False)[0]
        return grad

    def compute_target_prob(self, x, device):
        """Return the classifier probability of ``target_class`` for each image in ``x`` (no grad)."""
        x_normalized = self._classifier_input(x, device)
        with torch.no_grad():
            logits = self.classifier(x_normalized)
            probs = F.softmax(logits, dim=1)
            target_prob = probs[:, self.target_class]
        return target_prob

    def resample_particles(self, x, weights, device):
        """Systematically resample the particles ``x`` (batch axis 0) by ``weights``.

        One uniform offset is shared by ``num_samples`` evenly spaced
        positions. With ``right=True``, particle ``i`` is chosen when
        ``cumsum[i-1] <= u < cumsum[i]``, so a zero offset no longer lands on
        a leading zero-weight particle.

        Returns:
            ``(x_resampled, weights_normalized)``.
        """
        num_samples = x.shape[0]
        weights_normalized = weights / weights.sum()

        cumsum = torch.cumsum(weights_normalized, dim=0)
        positions = (torch.arange(num_samples, device=device).float()
                     + torch.rand(1, device=device)) / num_samples

        indices = torch.searchsorted(cumsum, positions, right=True)
        # Guard against cumsum[-1] rounding slightly below 1.
        indices = torch.clamp(indices, 0, num_samples - 1)

        return x[indices], weights_normalized

    def compute_gaussian_log_prob(self, x, mean, variance):
        """Return the log density of isotropic ``N(mean, variance * I)`` at each sample of ``x``."""
        batch_size = x.shape[0]
        dim = x[0].numel()

        diff_squared = ((x - mean) ** 2).view(batch_size, -1).sum(dim=1)
        log_prob = -0.5 * (dim * torch.log(2 * torch.tensor(np.pi) * variance) + diff_squared / variance)

        return log_prob

    @staticmethod
    def _clip_vadd_norm(vadd, max_norm=1e100):
        """Rescale each sample of ``vadd`` so its L2 norm is at most ``max_norm``."""
        flat = vadd.view(vadd.shape[0], -1)
        norms = torch.norm(flat, dim=1, keepdim=True)
        scale = torch.clamp(max_norm / (norms + 1e-8), max=1.0)
        return (flat * scale).view_as(vadd)

    @staticmethod
    def _timestep_indices(t, batch_size, device):
        """Return the integer model time index ``clamp(t * 999, 0, 999)`` repeated for the batch."""
        t_vec = torch.ones(batch_size, device=device) * t
        return torch.clamp(t_vec * 999, 0, 999).long()

    @staticmethod
    def _normalized_weights(log_weight):
        """Exponentiate log weights (shifted by their max for stability) and normalize them to sum to 1."""
        weights = torch.exp(log_weight - log_weight.max())
        return weights / weights.sum()

    def _guided_smc_step(self, model, x, t_current, dt, t_scaled_cur, t_scaled_next,
                         beta_squared, shape, device, generator, vadd_norm_max):
        """Take one classifier-guided stochastic step and return its SMC log importance weight.

        The proposal moves ``x`` with the guided drift plus Gaussian noise. The
        log weight is the log ratio of lookahead target probabilities
        (next / current) plus the log ratio of the unguided ("org") to the
        guided ("prop") transition densities at the new state.

        Returns:
            ``(x_new, log_weight)``, where ``log_weight`` has shape ``(batch,)``.
        """
        velocity = model(x, t_scaled_cur)

        if self.classifier is not None:
            vadd = self.compute_vadd(x, velocity, t_current, device)
            vadd = self._clip_vadd_norm(vadd, max_norm=vadd_norm_max)
            velocity_guided = velocity + ((1 - t_current) / t_current) * self.guidance_scale * vadd
        else:
            velocity_guided = velocity

        x_update = (1 + beta_squared) * velocity_guided * dt - (beta_squared / t_current) * x * dt

        noise_variance = max(2 * beta_squared * (1 - t_current) * dt / t_current, 1e-8)
        # NOTE(review): 1e-100 is below float64 resolution next to >= 1e-8, so prop == org variance.
        noise_variance_prop = noise_variance + 1e-100
        noise_std_prop = torch.sqrt(torch.tensor(noise_variance_prop, device=device))
        noise = torch.randn(shape, device=device, generator=generator) * noise_std_prop

        x_new = x + x_update + noise

        x_lkah_cur = x + (1 - t_current) * velocity
        velocity_next = model(x_new, t_scaled_next)
        x_lkah_next = x_new + (1 - (t_current + dt)) * velocity_next

        prob_auto_cur = self.compute_target_prob(x_lkah_cur, device)
        prob_auto_next = self.compute_target_prob(x_lkah_next, device)

        mean_prop = x + x_update
        mean_org = x + (1 + beta_squared) * velocity * dt - (beta_squared / t_current) * x * dt

        log_prob_prop = self.compute_gaussian_log_prob(x_new, mean_prop, noise_variance_prop)
        log_prob_org = self.compute_gaussian_log_prob(x_new, mean_org, noise_variance)

        log_weight = (torch.log(prob_auto_next + 1e-8) -
                      torch.log(prob_auto_cur + 1e-8) +
                      log_prob_org - log_prob_prop)
        return x_new, log_weight

    @torch.no_grad()
    def euler_ode_sampling(self, model, shape, device, N=50, generator=None, 
                          start_resample=300, start_flow=500, vadd_norm_max=1e100):
        """Integrate from noise toward data in ``N`` steps, with an SMC phase in the middle.

        Step ``i`` advances ``t = i / N * (T - eps) + eps`` by ``dt``:

        * ``i < start_resample``: plain Euler step of the flow ODE.
        * ``i == start_resample``: Euler step, then resample the particles by
          the target probability of their lookahead.
        * ``i < start_flow - 1``: guided stochastic step
          (:meth:`_guided_smc_step`) followed by systematic resampling.
        * ``i == start_flow - 1``: guided stochastic step whose unnormalized
          importance weights are kept as ``sample_weights``.
        * ``i < N - 1``: plain Euler step. At ``i == start_flow``, the
          lookahead target probability is recorded first.
        * last step: Euler step. If a lookahead was recorded, the weights
          are multiplied by ``p(final) / p(lookahead at start_flow)``. They
          start from ones when no weighted SMC step preceded it.

        Returns:
            ``(x, sample_weights)``. ``sample_weights`` is ``None`` when no
            weighting phase was reached (e.g. ``start_resample >= N``).
        """
        eps = 1e-3
        dt = (self.T - eps) / N
        batch = shape[0]

        x = self.get_z0(shape, device, generator)

        model.eval()
        if self.classifier is not None:
            self.classifier.eval()

        beta_squared = self.beta ** 2
        sample_weights = None
        prob_lkah_start_flow = None

        for i in tqdm(range(N), desc="Sampling", unit="step"):
            t_current = i / N * (self.T - eps) + eps
            t_scaled_cur = self._timestep_indices(t_current, batch, device)
            t_scaled_next = self._timestep_indices(t_current + dt, batch, device)

            if i < start_resample:
                x = x + model(x, t_scaled_cur) * dt

            elif i == start_resample:
                x = x + model(x, t_scaled_cur) * dt
                velocity_next = model(x, t_scaled_next)
                x_ahead = x + (1 - (t_current + dt)) * velocity_next
                log_weight = torch.log(self.compute_target_prob(x_ahead, device) + 1e-8)
                x, _ = self.resample_particles(x, self._normalized_weights(log_weight), device)

            elif i < start_flow - 1:
                x_new, log_weight = self._guided_smc_step(
                    model, x, t_current, dt, t_scaled_cur, t_scaled_next,
                    beta_squared, shape, device, generator, vadd_norm_max)
                x, _ = self.resample_particles(x_new, self._normalized_weights(log_weight), device)

            elif i == start_flow - 1:
                x, log_weight = self._guided_smc_step(
                    model, x, t_current, dt, t_scaled_cur, t_scaled_next,
                    beta_squared, shape, device, generator, vadd_norm_max)
                # Keep the unnormalized importance weights; they are returned to the caller.
                sample_weights = torch.exp(log_weight)

            elif i < N - 1:
                velocity = model(x, t_scaled_cur)
                if i == start_flow:
                    # Here t_current equals the start_flow time, so the lookahead can reuse `velocity`.
                    x_lkah_start_flow = x + (1 - t_current) * velocity
                    prob_lkah_start_flow = self.compute_target_prob(x_lkah_start_flow, device)
                x = x + velocity * dt

            else:
                x = x + model(x, t_scaled_cur) * dt

                # Only score the final samples when a start_flow lookahead exists; this also
                # lets unguided sampling (classifier=None) finish without calling the classifier.
                if prob_lkah_start_flow is not None:
                    prob_final = self.compute_target_prob(x, device)
                    log_weight_final = (torch.log(prob_final + 1e-8)
                                        - torch.log(prob_lkah_start_flow + 1e-8))
                    if sample_weights is None:
                        sample_weights = torch.ones_like(prob_final)
                    sample_weights = sample_weights * torch.exp(log_weight_final)

        return x, sample_weights


def systematic_resample_final(samples, weights, device):
    """Systematically resample ``samples`` (batch axis 0) in proportion to ``weights``.

    One uniform offset is shared by ``N`` evenly spaced positions. Index ``i``
    is chosen when ``cumsum[i-1] <= u < cumsum[i]`` (``right=True``), so a zero
    offset cannot select a leading zero-weight sample.

    If the weights are unusable (any NaN/inf or negative entry, or a
    non-positive total, e.g. after ``exp`` overflow or underflow of log
    weights), a ``RuntimeWarning`` is emitted and uniform weights are used
    instead of producing garbage indices.

    Returns:
        Tensor with the same shape as ``samples``.
    """
    import warnings

    num_samples = samples.shape[0]
    total = weights.sum()
    usable = bool(torch.isfinite(weights).all()) and bool((weights >= 0).all()) and bool(total > 0)
    if not usable:
        warnings.warn("systematic_resample_final: non-finite, negative or zero weights; "
                      "falling back to uniform weights.", RuntimeWarning)
        weights = torch.ones(num_samples, device=device)
        total = weights.sum()
    weights_normalized = weights / total

    cumsum = torch.cumsum(weights_normalized, dim=0)
    positions = (torch.arange(num_samples, device=device).float()
                 + torch.rand(1, device=device)) / num_samples

    indices = torch.searchsorted(cumsum, positions, right=True)
    # Guard against cumsum[-1] rounding slightly below 1.
    indices = torch.clamp(indices, 0, num_samples - 1)

    return samples[indices]


def generate_images_with_ema(model, config, ema_info, classifier=None, guidance_scale=1.0,
                           beta=4.0, num_samples=20, num_steps=800, device='cuda', 
                           seed=None, start_resample=300, start_flow=500, target_class=1,
                           output_dir='output'):
    """Sample images, optionally with EMA weights, and save a weight-resampled grid.

    Args:
        model: Velocity network passed to ``RectifiedFlow.euler_ode_sampling``.
        config: Config providing ``data`` (channels, size, centering) and ``sampling``.
        ema_info: ``(ema, ema_loaded)``. When ``ema_loaded`` is true, the EMA
            weights are swapped into ``model`` for sampling. The original
            weights are restored afterwards, even if sampling raises.
        classifier, guidance_scale, beta, target_class: Forwarded to ``RectifiedFlow``.
        num_samples: Batch size (number of particles).
        num_steps, start_resample, start_flow: Forwarded to the sampler.
        device: Device for the generator and the sampling tensors.
        seed: If given, seeds a dedicated ``torch.Generator`` on ``device``.
        output_dir: Created if missing. When the sampler returns weights, it
            receives ``resample_combine_<suffix>.png``.

    Returns:
        ``(samples, sample_weights)``. ``samples`` is clamped to [0, 1], after
        being mapped from [-1, 1] when ``config.data.centered`` is set.
    """
    ema, ema_loaded = ema_info
    
    os.makedirs(output_dir, exist_ok=True)
    
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)
    
    sampler = RectifiedFlow(config, classifier=classifier, guidance_scale=guidance_scale, 
                           beta=beta, target_class=target_class)
    shape = (num_samples, config.data.num_channels, 
             config.data.image_size, config.data.image_size)

    if ema_loaded:
        ema.store(model.parameters())
        ema.copy_to(model.parameters())
    try:
        samples, sample_weights = sampler.euler_ode_sampling(
            model, shape, device, N=num_steps, 
            generator=generator, start_resample=start_resample,
            start_flow=start_flow)
    finally:
        # Put the training weights back even if sampling fails.
        if ema_loaded:
            ema.restore(model.parameters())
    
    if config.data.centered:
        samples = (samples + 1.0) / 2.0
    
    samples = torch.clamp(samples, 0.0, 1.0)
    
    suffix = f"target{target_class}_seed{seed}_beta{beta:.2f}_gs{guidance_scale:.1f}_sr{start_resample}_sf{start_flow}"
    
    if sample_weights is not None:
        resampled_samples = systematic_resample_final(samples, sample_weights, device)
        resample_combine_path = os.path.join(output_dir, f'resample_combine_{suffix}.png')
        vutils.save_image(resampled_samples, resample_combine_path, nrow=5, normalize=False, pad_value=0.0)
    
    return samples, sample_weights


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``'module.'`` removed from each key."""
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model_with_ema(checkpoint_path, device):
    """Load an ``NCSNpp`` model (and, if available, its EMA) from a checkpoint.

    Accepted checkpoint layouts:
      * dict with ``'model'`` (and optionally ``'ema'``) entries;
      * dict with a ``'model_state_dict'`` entry;
      * a raw state dict;
      * a non-dict object, whose ``state_dict()`` is used when present.

    A leading ``'module.'`` (e.g. from ``DataParallel``) is stripped from all
    parameter names. Weights are loaded with ``strict=False``.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Passed as ``map_location`` to ``torch.load``; the model is
            moved to it and set to eval mode.

    Returns:
        ``(model, config, (ema, ema_loaded))``. ``ema_loaded`` is true only if
        the checkpoint had both ``'model'`` and ``'ema'`` entries and
        ``ema.load_state_dict`` succeeded. If the file does not exist or
        cannot be loaded, returns ``(None, None, (None, False))``.
    """
    import warnings

    if not os.path.exists(checkpoint_path):
        return None, None, (None, False)

    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except Exception as err:
        warnings.warn(f"Failed to load checkpoint {checkpoint_path!r}: {err}")
        return None, None, (None, False)

    # Build the model only once the checkpoint is known to be readable.
    config = Config()
    model = NCSNpp(config)
    ema = ExponentialMovingAverage(model.parameters(), decay=config.model.ema_rate)
    ema_loaded = False

    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(_strip_module_prefix(state_dict), strict=False)

        if 'model' in checkpoint and 'ema' in checkpoint:
            # Best effort: fall back to the raw model weights if the EMA state
            # cannot be restored (NOTE(review): relies on
            # ExponentialMovingAverage.load_state_dict existing — confirm).
            try:
                ema.load_state_dict(checkpoint['ema'])
                ema_loaded = True
            except Exception as err:
                warnings.warn(f"Failed to load EMA state from {checkpoint_path!r}: {err}")
    else:
        state_dict = checkpoint.state_dict() if hasattr(checkpoint, 'state_dict') else checkpoint
        model.load_state_dict(_strip_module_prefix(state_dict), strict=False)

    model.to(device)
    model.eval()

    return model, config, (ema, ema_loaded)


def main(checkpoint_path='checkpoint_8.pth', classifier_path='vgg13_bn.pt', 
         seed=217, num_samples=25, num_steps=800, guidance_scale=1.0, beta=4.0, 
         use_classifier=True, start_resample=300, start_flow=500, target_class=1,
         output_dir='output'):
    """Load a checkpoint (and optional classifier) and generate one batch of images.

    Seeds all RNGs via ``set_random_seed(seed)``, selects CUDA when available,
    and writes output to ``f"{output_dir}_target{target_class}_sr{start_resample}_sf{start_flow}"``.

    Returns:
        ``(samples, sample_weights)`` from ``generate_images_with_ema`` on
        success, or ``None`` if the checkpoint could not be loaded or
        generation raised. A generation error's traceback is printed to stderr.
    """
    import sys
    import traceback

    set_random_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model, config, ema_info = load_model_with_ema(checkpoint_path, device)

    if model is None:
        print(f"Could not load model checkpoint from {checkpoint_path!r}; aborting.", file=sys.stderr)
        return None

    classifier = None
    if use_classifier:
        classifier = vgg13_bn(pretrained=True, weights_path=classifier_path)
        classifier.to(device)
        classifier.eval()

    output_dir_with_params = f"{output_dir}_target{target_class}_sr{start_resample}_sf{start_flow}"

    try:
        return generate_images_with_ema(
            model, config, ema_info,
            classifier=classifier,
            guidance_scale=guidance_scale,
            beta=beta,
            num_samples=num_samples,
            num_steps=num_steps,
            device=device,
            seed=seed,
            start_resample=start_resample,
            start_flow=start_flow,
            target_class=target_class,
            output_dir=output_dir_with_params)
    except Exception:
        # Top-level boundary: report the failure and keep the None return.
        traceback.print_exc()
        return None


if __name__ == "__main__":
    # Command-line entry point: parse and validate arguments, then run main().
    import argparse
    import sys
    import traceback

    parser = argparse.ArgumentParser(description='SMC Sampling Image Generator')

    parser.add_argument('--checkpoint', type=str, default='checkpoint_8.pth')
    parser.add_argument('--classifier', type=str, default='vgg13_bn.pt')
    parser.add_argument('--seed', type=int, default=217)
    parser.add_argument('--num_samples', type=int, default=20)
    parser.add_argument('--num_steps', type=int, default=800)
    parser.add_argument('--guidance_scale', type=float, default=1.0)
    parser.add_argument('--beta', type=float, default=4.0)
    parser.add_argument('--start_resample', type=int, default=300)
    parser.add_argument('--start_flow', type=int, default=500)
    parser.add_argument('--target_class', type=int, default=1)
    parser.add_argument('--output_dir', type=str, default='output')
    parser.add_argument('--no_classifier', action='store_true')

    args = parser.parse_args()

    # sys.exit(<str>) prints the message to stderr and exits with status 1,
    # matching the original exit code while telling the user what was wrong.
    if args.num_samples <= 0:
        sys.exit("error: --num_samples must be a positive integer")

    if args.num_steps <= 0:
        sys.exit("error: --num_steps must be a positive integer")

    if args.start_resample >= args.start_flow:
        sys.exit("error: --start_resample must be smaller than --start_flow")

    if args.start_flow >= args.num_steps:
        sys.exit("error: --start_flow must be smaller than --num_steps")

    if args.target_class < 0 or args.target_class > 9:
        sys.exit("error: --target_class must be between 0 and 9")

    try:
        main(checkpoint_path=args.checkpoint,
             classifier_path=args.classifier,
             seed=args.seed,
             num_samples=args.num_samples,
             num_steps=args.num_steps,
             guidance_scale=args.guidance_scale,
             beta=args.beta,
             use_classifier=not args.no_classifier,
             start_resample=args.start_resample,
             start_flow=args.start_flow,
             target_class=args.target_class,
             output_dir=args.output_dir)
    except KeyboardInterrupt:
        pass
    except Exception:
        traceback.print_exc()
        # Signal failure to the shell instead of exiting with status 0.
        sys.exit(1)