import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
import random
from PIL import Image
from tqdm import tqdm

# Module-wide default device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Config:
    """Static hyper-parameters shared by the generator and the sampler."""

    # Side length of the (square) images, in pixels.
    image_size: int = 28
    # Number of image channels.
    channels: int = 1
    # Base channel width handed to the U-Net as ``model_channels``.
    dim: int = 64
    # Per-resolution multipliers of ``dim`` handed to the U-Net as ``channel_mult``.
    dim_mults: tuple = (1, 2, 4)

# Shared configuration instance read by the sampler and the generator loader.
config = Config()

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    then switches cuDNN to deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def systematic_resample(weights, x, return_indices=False):
    """Resample the rows of ``x`` with systematic (stratified-offset) resampling.

    A single uniform offset is drawn and combined with an evenly spaced grid
    of ``len(weights)`` points; each grid point selects the first entry of the
    cumulative weight vector that is not smaller than it.

    Args:
        weights: 1-D tensor of particle weights. The grid spans [0, 1), so the
            weights are expected to sum to one (the function does not
            normalise them).
        x: Tensor indexed along its first dimension by the selected indices.
        return_indices: When true, also return the selected indices.

    Returns:
        ``x[indices]``, or ``(x[indices], indices)`` if ``return_indices``.
    """
    count = weights.shape[0]
    dev = weights.device

    cdf = weights.cumsum(dim=0)
    offset = torch.rand(1, device=dev) / count
    grid = offset + torch.arange(count, device=dev).float() / count

    # Rounding can leave the last cdf entry slightly below a grid point, in
    # which case searchsorted returns ``count``; clamp back into range.
    picks = torch.searchsorted(cdf, grid).clamp(0, count - 1)

    chosen = x[picks]
    return (chosen, picks) if return_indices else chosen

def compute_normal_log_density(x, mean, std):
    """Log-density of an isotropic Gaussian, evaluated per batch element.

    Args:
        x: Tensor whose first dimension is the batch; all remaining
            dimensions are flattened into one feature vector.
        mean: Tensor broadcastable against ``x``.
        std: Scalar standard deviation shared by every coordinate.

    Returns:
        1-D tensor of length ``x.shape[0]`` holding ``log N(x_i; mean_i, std^2 I)``.
    """
    n_batch = x.shape[0]
    n_dims = x[0].numel()

    residual = (x - mean).view(n_batch, -1)
    squared_dist = (residual * residual).sum(dim=1)
    quadratic = squared_dist / (2 * std * std)

    # Normalising constant: (d / 2) * log(2 * pi) + d * log(std).
    log_norm = 0.5 * n_dims * np.log(2 * np.pi) + n_dims * np.log(std)

    return -(quadratic + log_norm)

class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of a batch of scalar time values.

    The output concatenates ``sin`` and ``cos`` of the input scaled by
    ``dim // 2`` geometrically spaced frequencies between 1 and 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t`` of shape ``(batch,)`` into ``(batch, 2 * (dim // 2))``."""
        n_freqs = self.dim // 2
        # Log-spacing between consecutive frequencies.
        log_step = torch.log(torch.tensor(10000.0)) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=t.device) * -log_step)
        angles = t.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class ResidualBlock(nn.Module):
    """Two-convolution residual block conditioned on a time embedding.

    Layout: GroupNorm -> SiLU -> Conv, add the projected time embedding as a
    per-channel bias, then GroupNorm -> SiLU -> Dropout -> Conv, and finally
    add the (optionally 1x1-projected) input.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        time_channels: Width of the time-embedding vector.
        dropout: Dropout probability used in the second conv stack.
    """

    def __init__(self, in_channels, out_channels, time_channels, dropout=0.1):
        super().__init__()

        first_stack = [
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        ]
        self.conv1 = nn.Sequential(*first_stack)

        time_stack = [nn.SiLU(), nn.Linear(time_channels, out_channels)]
        self.time_mlp = nn.Sequential(*time_stack)

        second_stack = [
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        ]
        self.conv2 = nn.Sequential(*second_stack)

        # A 1x1 projection is only needed when the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, 1) if needs_projection else nn.Identity()
        )

    def forward(self, x, t):
        """Apply the block to feature map ``x`` given time embedding ``t``."""
        time_bias = self.time_mlp(t)[:, :, None, None]
        hidden = self.conv1(x) + time_bias
        hidden = self.conv2(hidden)
        return hidden + self.shortcut(x)

class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Args:
        channels: Number of input (and output) channels.
        num_heads: Number of attention heads; ``channels`` is split evenly
            across them.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Return ``x`` plus the attention output; the shape is unchanged."""
        batch, chans, height, width = x.shape
        head_dim = chans // self.num_heads

        # (B, 3, heads, head_dim, H*W) -> (3, B, heads, H*W, head_dim)
        packed = self.qkv(self.norm(x)).reshape(
            batch, 3, self.num_heads, head_dim, height * width
        )
        queries, keys, values = packed.permute(1, 0, 2, 4, 3).unbind(0)

        # Scaled dot-product attention over spatial positions.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) * head_dim ** -0.5
        probs = F.softmax(scores, dim=-1)

        mixed = torch.matmul(probs, values).permute(0, 1, 3, 2)
        mixed = mixed.reshape(batch, chans, height, width)
        return x + self.out(mixed)

class UNet(nn.Module):
    """Time-conditioned U-Net built from residual blocks with skip connections.

    The encoder applies ``num_res_blocks`` residual blocks per resolution and
    a stride-2 convolution between resolutions. The bottleneck is
    residual -> attention -> residual. The decoder mirrors the encoder with
    ``num_res_blocks + 1`` residual blocks per resolution, each consuming one
    stored encoder activation, and a transposed convolution between
    resolutions.

    Args:
        in_channels: Channels of the input image.
        model_channels: Base channel width; also the sinusoidal embedding size.
        out_channels: Channels of the predicted output.
        channel_mult: Per-resolution multipliers of ``model_channels``.
        num_res_blocks: Residual blocks per encoder resolution.

    Note:
        Every ``GroupNorm`` uses 8 groups, so each ``model_channels * mult``
        must be divisible by 8. The input height/width presumably need to be
        divisible by ``2 ** (len(channel_mult) - 1)`` for the skip
        concatenations to line up -- confirm before using other image sizes.
    """

    def __init__(self, in_channels=1, model_channels=64, out_channels=1,
                 channel_mult=(1, 2, 4), num_res_blocks=2):
        super().__init__()

        # Keep the requested output width under its own name: the original
        # code reused ``out_channels`` as a loop variable and then built the
        # final convolution with ``in_channels``, silently ignoring the
        # argument whenever it differed from ``in_channels``.
        final_out_channels = out_channels

        time_channels = model_channels * 4
        self.time_mlp = nn.Sequential(
            TimeEmbedding(model_channels),
            nn.Linear(model_channels, time_channels),
            nn.SiLU(),
            nn.Linear(time_channels, time_channels)
        )

        self.down_blocks = nn.ModuleList()
        # Channel count of every activation the encoder will push onto the
        # skip stack, in push order (starting with the init_conv output).
        skip_channels = [model_channels]
        now_channels = model_channels

        self.init_conv = nn.Conv2d(in_channels, now_channels, 3, padding=1)

        last_level = len(channel_mult) - 1
        for i, mult in enumerate(channel_mult):
            level_channels = model_channels * mult
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    ResidualBlock(now_channels, level_channels, time_channels)
                )
                now_channels = level_channels
                skip_channels.append(now_channels)
            if i != last_level:
                # Stride-2 convolution halves the spatial resolution.
                self.down_blocks.append(
                    nn.Conv2d(now_channels, now_channels, 3, stride=2, padding=1)
                )
                skip_channels.append(now_channels)

        self.middle_blocks = nn.ModuleList([
            ResidualBlock(now_channels, now_channels, time_channels),
            AttentionBlock(now_channels),
            ResidualBlock(now_channels, now_channels, time_channels)
        ])

        self.up_blocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(channel_mult))):
            level_channels = model_channels * mult
            for _ in range(num_res_blocks + 1):
                # Each decoder block sees the current features concatenated
                # with one popped encoder activation.
                self.up_blocks.append(
                    ResidualBlock(skip_channels.pop() + now_channels,
                                  level_channels, time_channels)
                )
                now_channels = level_channels
            if i != 0:
                # Transposed convolution doubles the spatial resolution.
                self.up_blocks.append(
                    nn.ConvTranspose2d(now_channels, now_channels, 4, stride=2, padding=1)
                )

        self.final_conv = nn.Sequential(
            nn.GroupNorm(8, now_channels),
            nn.SiLU(),
            nn.Conv2d(now_channels, final_out_channels, 3, padding=1)
        )

    def forward(self, x, t):
        """Run the network on image batch ``x`` at times ``t``.

        Args:
            x: Tensor of shape ``(batch, in_channels, H, W)``.
            t: Tensor of shape ``(batch,)`` with one time value per sample.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``.
        """
        t_emb = self.time_mlp(t)
        h = self.init_conv(x)

        # Stack of encoder activations consumed in reverse by the decoder.
        hs = [h]
        for block in self.down_blocks:
            if isinstance(block, ResidualBlock):
                h = block(h, t_emb)
            else:
                h = block(h)
            hs.append(h)

        for block in self.middle_blocks:
            if isinstance(block, ResidualBlock):
                h = block(h, t_emb)
            else:
                h = block(h)

        for block in self.up_blocks:
            if isinstance(block, ResidualBlock):
                h = torch.cat([h, hs.pop()], dim=1)
                h = block(h, t_emb)
            else:
                h = block(h)

        return self.final_conv(h)

class RectifiedFlow(nn.Module):
    """Thin container pairing a velocity network with its configuration.

    Attributes:
        unet: The wrapped network, registered as a sub-module so its
            parameters appear under the ``unet.`` prefix in the state dict.
        config: Configuration object stored unchanged.
    """

    def __init__(self, unet, config):
        super().__init__()
        self.config = config
        self.unet = unet

class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv + batch-norm layers with a shortcut.

    Args:
        in_planes: Channels of the input feature map.
        planes: Channels produced by both convolutions.
        stride: Stride of the first convolution (and of the shortcut
            projection when one is needed).
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential acts as the identity; a 1x1 projection is used
        # only when the spatial size or the channel count changes.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)

class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with a shortcut.

    Args:
        in_planes: Channels of the input feature map.
        planes: Channels of the two inner convolutions; the block outputs
            ``expansion * planes`` channels.
        stride: Stride of the 3x3 convolution (and of the shortcut projection
            when one is needed).
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # An empty Sequential acts as the identity; a 1x1 projection is used
        # only when the spatial size or the channel count changes.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Apply the three conv stages, add the shortcut and apply ReLU."""
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return F.relu(hidden + residual)

class ResNet(nn.Module):
    """ResNet classifier with a stride-1 3x3 stem and four residual stages.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: Sequence of four ints, the block count of each stage.
        num_classes: Size of the output logit vector.
        num_channels: Channels of the input image.
    """

    def __init__(self, block, num_blocks, num_classes=10, num_channels=1):
        super().__init__()
        # Running input width for the next block; updated by _make_layer.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            num_channels, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # 4x4 average pooling, then flatten for the linear classifier head.
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

def get_resnet_model(arch, num_classes=10):
    """Build the ResNet variant identified by ``arch``.

    Supported names are ``resnet18_mnist`` / ``resnet50_mnist`` (1 input
    channel) and ``resnet18`` / ``resnet50`` (3 input channels).

    Args:
        arch: Architecture name.
        num_classes: Number of output classes.

    Returns:
        A freshly constructed ``ResNet``.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    # (name, depth, input channels)
    known_archs = (
        ("resnet18_mnist", 18, 1),
        ("resnet50_mnist", 50, 1),
        ("resnet18", 18, 3),
        ("resnet50", 50, 3),
    )
    for name, depth, in_channels in known_archs:
        if arch == name:
            if depth == 18:
                block, layout = BasicBlock, [2, 2, 2, 2]
            else:
                block, layout = Bottleneck, [3, 4, 6, 3]
            return ResNet(block, layout, num_classes=num_classes, num_channels=in_channels)

    raise ValueError(f"Unknown architecture: {arch}")

def _endpoint_logits(model, classifier, x, t, time_left):
    """Extrapolate ``x`` along the model velocity and classify the result.

    Evaluates ``v = model.unet(x, t)``, forms ``x + v * time_left``, maps it
    from [-1, 1] to [0, 1] via ``(x + 1) / 2`` with clamping, and returns
    ``(v, classifier logits)``.
    """
    v = model.unet(x, t)
    x_lkah = x + v * time_left.view(-1, 1, 1, 1)
    x_lkah_norm = ((x_lkah + 1) / 2).clamp(0, 1)
    return v, classifier(x_lkah_norm)


def _normalized_weights(log_weights):
    """Exponentiate log-weights (shifted by their max) and normalise to sum 1."""
    weights = torch.exp(log_weights - log_weights.max())
    return weights / weights.sum()


def _guided_sde_step(model, classifier, x, t, t_next, dt,
                     target_class, beta, guidance_scale):
    """Take one classifier-guided stochastic step and compute its log-weights.

    Returns ``(x_new, log_weights)`` where ``log_weights`` is
    ``log p_target(lookahead of x_new) + log N(x_new; unguided mean)
    - log p_target(lookahead of x) - log N(x_new; guided mean)``.
    """
    # Gradient of the summed target log-probability of the look-ahead image
    # with respect to the current state.
    x_cur_grad = x.clone().detach().requires_grad_(True)
    v, logits_cur = _endpoint_logits(model, classifier, x_cur_grad, t, 1 - t)
    prob_target_cur = F.softmax(logits_cur, dim=-1)[:, target_class]
    log_prob_target_cur = F.log_softmax(logits_cur, dim=-1)[:, target_class].sum()

    vadd = torch.autograd.grad(log_prob_target_cur, x_cur_grad)[0]

    # Per-sample norm clipping of the guidance gradient at 1e10.
    vadd_norm = torch.norm(vadd.view(vadd.shape[0], -1), dim=1, keepdim=True)
    vadd_norm = vadd_norm.view(-1, 1, 1, 1)
    vadd = torch.where(vadd_norm > 1e10, vadd * (1e10 / vadd_norm), vadd)

    with torch.no_grad():
        t_expanded = t.view(-1, 1, 1, 1)
        scale_factor = (1 - t_expanded) / t_expanded
        vc = v + guidance_scale * scale_factor * vadd

        # Means of the unguided ("org") and guided ("prop") transitions.
        drift_org = (-(beta**2 / t_expanded) * x + (1 + beta**2) * v) * dt
        mean_org = x + drift_org

        drift = (-(beta**2 / t_expanded) * x + (1 + beta**2) * vc) * dt
        mean_prop = x + drift

        diffusion_coeff = beta * torch.sqrt(2 * dt * (1 - t_expanded) / t_expanded)
        # NOTE(review): with float32 tensors 1e-100 underflows to 0, so this
        # offset has no numerical effect; kept to preserve the computation.
        diffusion_coeff_prop = diffusion_coeff + 1e-100
        epsilon = torch.randn_like(x)

        x = x + drift + diffusion_coeff_prop * epsilon

        # All samples share the same t, so the first entry is the common std.
        log_prob_prop = compute_normal_log_density(
            x, mean_prop, diffusion_coeff_prop.view(-1)[0].item())
        log_prob_org = compute_normal_log_density(
            x, mean_org, diffusion_coeff.view(-1)[0].item())

        _, logits_next = _endpoint_logits(model, classifier, x, t_next, 1 - (t + dt))
        prob_target_next = F.softmax(logits_next, dim=-1)[:, target_class]

        log_weights = (torch.log(prob_target_next + 1e-10) + log_prob_org -
                       torch.log(prob_target_cur.detach() + 1e-10) - log_prob_prop)

    return x, log_weights


def guided_sampling(model, classifier, num_samples, num_steps, eps,
                   target_class=3, beta=4.0, guidance_scale=1.0,
                   start_resample=150, start_flow=400):
    """Sample images steered towards ``target_class`` by a classifier.

    Time runs from ``eps`` to 1 in ``num_steps`` uniform steps. The loop over
    step index ``i`` has five phases:

    1. ``i < start_resample``: plain Euler steps ``x += v * dt``.
    2. ``i == start_resample``: an Euler step followed by one resampling of
       the particles by the target probability of their look-ahead image.
    3. ``start_resample < i < start_flow - 1``: guided stochastic steps, each
       followed by importance-weighted systematic resampling.
    4. ``i == start_flow - 1``: the same guided step, but the importance
       weights are stored in ``sample_weights`` instead of resampling.
    5. ``i >= start_flow``: plain Euler steps; on the last step the weights
       are multiplied by ``p_target(final) / p_target(look-ahead at start_flow)``.

    Args:
        model: Object exposing ``unet(x, t)`` and ``parameters()``.
        classifier: Module mapping images in [0, 1] to class logits.
        num_samples: Number of particles.
        num_steps: Number of integration steps.
        eps: Initial time value.
        target_class: Index of the class to steer towards.
        beta: Noise scale of the guided stochastic steps.
        guidance_scale: Multiplier of the classifier-gradient term.
        start_resample: Step index at which resampling begins.
        start_flow: Step index from which plain Euler steps are used again.

    Returns:
        ``(x, sample_weights)``: the final samples, shaped
        ``(num_samples, config.channels, config.image_size, config.image_size)``,
        and a ``(num_samples,)`` tensor of unnormalised weights.

    Raises:
        ValueError: If ``start_flow <= start_resample < num_steps - 1``. In
            that configuration the reference probability at ``start_flow`` is
            never recorded and the final reweighting cannot be computed
            (previously this failed with a ``TypeError`` on the last step).
    """
    if start_flow <= start_resample < num_steps - 1:
        raise ValueError(
            f"start_flow ({start_flow}) must be greater than start_resample "
            f"({start_resample}) when steps remain after start_resample"
        )

    device = next(model.parameters()).device
    model.eval()
    classifier.eval()

    # Drawn on the CPU and then moved, so that a given seed yields the same
    # initial noise regardless of the target device.
    x = torch.randn(num_samples, config.channels,
                   config.image_size, config.image_size).to(device)

    dt = (1.0 - eps) / num_steps
    sample_weights = torch.ones(num_samples, device=device)

    # Target probability of the look-ahead image at step start_flow.
    prob_target_start_flow = None

    progress_bar = tqdm(range(num_steps), desc="Sampling")

    for i in progress_bar:
        t = torch.full((num_samples,), eps + i * dt, device=device)
        t_next = torch.full((num_samples,), eps + (i + 1) * dt, device=device)

        if i < start_resample:
            with torch.no_grad():
                x = x + model.unet(x, t) * dt

        elif i == start_resample:
            with torch.no_grad():
                x = x + model.unet(x, t) * dt

                _, logits = _endpoint_logits(model, classifier, x, t_next, 1 - (t + dt))
                prob_lkah = F.softmax(logits, dim=-1)[:, target_class]

                weights = _normalized_weights(torch.log(prob_lkah + 1e-10))
                x = systematic_resample(weights, x)

        elif start_resample < i < start_flow - 1:
            x, log_weights = _guided_sde_step(
                model, classifier, x, t, t_next, dt,
                target_class, beta, guidance_scale)
            with torch.no_grad():
                x = systematic_resample(_normalized_weights(log_weights), x)

        elif i == start_flow - 1:
            # Last guided step: keep the weights instead of resampling.
            x, log_weights = _guided_sde_step(
                model, classifier, x, t, t_next, dt,
                target_class, beta, guidance_scale)
            with torch.no_grad():
                sample_weights = torch.exp(log_weights)

        else:
            with torch.no_grad():
                if i == start_flow:
                    # Here t equals eps + start_flow * dt.
                    _, logits_start_flow = _endpoint_logits(
                        model, classifier, x, t, 1 - t)
                    prob_target_start_flow = F.softmax(
                        logits_start_flow, dim=-1)[:, target_class]

                x = x + model.unet(x, t) * dt

                if i == num_steps - 1:
                    x_final_norm = ((x + 1) / 2).clamp(0, 1)
                    logits_final = classifier(x_final_norm)
                    prob_target_final = F.softmax(logits_final, dim=-1)[:, target_class]

                    log_weights = (torch.log(prob_target_final) -
                                   torch.log(prob_target_start_flow + 1e-10))

                    sample_weights = sample_weights * torch.exp(log_weights)

    return x, sample_weights

def load_classifier(weight_path):
    """Load the ResNet classifier checkpoint stored at ``weight_path``.

    The checkpoint must be a mapping with a ``"state_dict"`` entry and may
    carry an ``"arch"`` entry (default ``"resnet18_mnist"``). The network is
    wrapped in ``nn.DataParallel`` before the weights are loaded, so the
    state-dict keys are presumably prefixed with ``module.`` -- confirm
    against the script that saved the checkpoint.

    Returns:
        The classifier on the module-level ``device``, in eval mode.
    """
    ckpt = torch.load(weight_path, map_location=device)
    backbone = get_resnet_model(
        arch=ckpt.get("arch", "resnet18_mnist"),
        num_classes=10,
    )
    classifier = nn.DataParallel(backbone)
    classifier.load_state_dict(ckpt["state_dict"])
    classifier.to(device)
    classifier.eval()
    return classifier

def load_generator(model_path):
    """Build the generator and load its weights from ``model_path``.

    The file may contain either a bare state dict or a dictionary wrapping
    one under the ``'model_state_dict'`` key. The wrapper is detected from
    the loaded content; previously it was only recognised for files ending
    in ``.pt``, so a wrapped checkpoint with any other extension failed.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        The ``RectifiedFlow`` model on the module-level ``device``, or
        ``None`` if the checkpoint could not be read or applied (the error
        is printed).
    """
    unet = UNet(
        in_channels=config.channels,
        model_channels=config.dim,
        out_channels=config.channels,
        channel_mult=config.dim_mults
    ).to(device)

    model = RectifiedFlow(unet, config).to(device)

    # Best-effort load: the caller treats ``None`` as "abort".
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading generator: {e}")
        return None

    return model

def save_combined_image(samples, output_dir, target_class, seed, beta, guidance_scale, 
                       start_resample, start_flow):
    """Save up to 16 samples as a 4x4 grayscale grid PNG in ``output_dir``.

    ``samples`` is mapped from [-1, 1] to [0, 1] (with clamping) and only the
    first channel of each image is drawn. The file name encodes the run
    parameters passed in.
    """
    grid_size = 4
    shown = min(16, samples.shape[0])

    images = ((samples + 1) / 2).clamp(0, 1)
    panels = images[:shown, 0].cpu().numpy()

    plt.figure(figsize=(8, 8), facecolor='gray')
    for cell, panel in enumerate(panels, start=1):
        axis = plt.subplot(grid_size, grid_size, cell)
        axis.imshow(panel, cmap='gray')
        axis.axis('off')

    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    filename = f'combine_target{target_class}_seed{seed}_beta{beta:.2f}_gs{guidance_scale:.1f}_sr{start_resample}_sf{start_flow}.png'
    plt.savefig(
        os.path.join(output_dir, filename),
        dpi=150,
        bbox_inches='tight',
        pad_inches=0.1,
        facecolor='gray',
    )
    plt.close()

def main():
    """Parse the command line, run guided sampling and save the image grid."""
    # (flag, type, default) in the order the options are registered.
    cli_options = (
        ('--generator_path', str, './checkpoints/final_model.pt'),
        ('--classifier_path', str, 'resnet.pth.tar'),
        ('--num_samples', int, 16),
        ('--num_steps', int, 800),
        ('--eps', float, 1e-3),
        ('--target_class', int, 3),
        ('--beta', float, 4.0),
        ('--guidance_scale', float, 1.0),
        ('--start_resample', int, 150),
        ('--start_flow', int, 400),
        ('--output_dir', str, './output'),
        ('--seed', int, 42),
    )

    parser = argparse.ArgumentParser(description='TFTF')
    for flag, value_type, default in cli_options:
        parser.add_argument(flag, type=value_type, default=default)
    args = parser.parse_args()

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    # load_generator reports its own error and returns None on failure.
    generator = load_generator(args.generator_path)
    if generator is None:
        return

    classifier = load_classifier(args.classifier_path)

    guidance_kwargs = dict(
        target_class=args.target_class,
        beta=args.beta,
        guidance_scale=args.guidance_scale,
        start_resample=args.start_resample,
        start_flow=args.start_flow,
    )
    samples, sample_weights = guided_sampling(
        generator, classifier, args.num_samples,
        args.num_steps, args.eps,
        **guidance_kwargs
    )

    save_combined_image(
        samples, args.output_dir, args.target_class,
        args.seed, args.beta, args.guidance_scale,
        args.start_resample, args.start_flow,
    )

# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()