import torch
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from diffusers import DDPMPipeline
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
import random
import os
import argparse

# Target image used as the attack's desired output, keyed by the --target name.
_TARGET_IMAGE_PATHS = {
    "HAT": "./static/fedora-hat.png",
    "CAT": "./static/cat_wo_bg.png",
    "STOP_SIGN": "./static/stop_sign_wo_bg.png",
}


def _denoise_until(model, scheduler, x, t_stop):
    """Run the reverse process over ``scheduler.timesteps`` until ``t_stop`` is reached.

    The step at ``t_stop`` itself is not executed. The UNet is evaluated under
    ``torch.no_grad()``, so gradients w.r.t. ``x`` flow only through the
    scheduler's update arithmetic, not through the network.
    NOTE(review): if ``t_stop`` is not among ``scheduler.timesteps`` the loop
    runs every timestep without breaking (same as before).
    """
    # Progress-bar length only; assumes timesteps are visited in descending order.
    num_steps = sum(1 for t in scheduler.timesteps if t > t_stop)
    for t in tqdm(scheduler.timesteps, total=num_steps, desc="Denoising Progress"):
        if t == t_stop:
            break
        with torch.no_grad():
            noise_pred = model(x, t).sample
        x = scheduler.step(noise_pred, t, x).prev_sample
    return x


def _project_trigger_(trigger, epsilon, sparsity_ratio):
    """Project ``trigger`` in place onto the bounded, sparse feasible set."""
    with torch.no_grad():
        # Invisibility constraint: clip every entry to [-epsilon, epsilon].
        trigger.clamp_(-epsilon, epsilon)
        # Sparsity constraint: zero every entry whose magnitude is strictly below
        # the k-th smallest magnitude (at most k - 1 entries are zeroed).
        k = int(sparsity_ratio * trigger.numel())
        if k > 0:  # topk(0)[-1] would raise IndexError
            magnitudes = trigger.abs()
            threshold = magnitudes.flatten().topk(k, largest=False).values[-1]
            trigger.masked_fill_(magnitudes < threshold, 0)


def generate_trigger(model, scheduler, target, num_epoch, batch_size, learning_rate, dataset_name, img_size,
                     epsilon=0.15, sparsity_ratio=0.2, max_proportion_forward_step=0.5):
    """Optimize an additive input trigger for ``model`` and save it to disk.

    Each epoch samples a timestep ``t`` in ``[0, max_step_gen]``. It noises the
    target image to ``t`` with the forward process, and denoises
    ``noise + trigger`` from the first scheduler timestep down to ``t`` with
    the reverse process. Both processes use the same noise draw. The trigger is
    updated by SGD on the L1 distance between the two results and then
    projected onto the bounded/sparse constraint set.

    Args:
        model: UNet called as ``model(x, t).sample``; its parameters' device is
            used for all tensors.
        scheduler: diffusers-style scheduler providing ``add_noise``, ``step``,
            ``timesteps`` and ``config.num_train_timesteps``.
        target: one of ``"HAT"``, ``"CAT"``, ``"STOP_SIGN"``.
        num_epoch: number of optimization steps.
        batch_size: number of noise samples per step.
        learning_rate: SGD learning rate.
        dataset_name: used only in the output file name.
        img_size: side length the target image is resized to.
        epsilon: L-infinity bound applied to the trigger after every step.
        sparsity_ratio: fraction used to pick the sparsity threshold.
        max_proportion_forward_step: the largest sampled timestep, as a
            fraction of ``num_train_timesteps``.

    Returns:
        The final trigger as a detached CPU tensor. It is also saved to
        ``./generated_triggers/TooBad_DDPM_{dataset_name}_{target}.pt``.

    Raises:
        ValueError: if ``target`` is not a known target name.
    """
    try:
        target_path = _TARGET_IMAGE_PATHS[target]
    except KeyError:
        raise ValueError(
            f"Unknown target {target!r}; expected one of {sorted(_TARGET_IMAGE_PATHS)}"
        ) from None

    # Use the model's own device so inputs always match the weights.
    device = next(model.parameters()).device
    image_shape = [model.config.in_channels, model.config.sample_size, model.config.sample_size]
    # random.randint requires ints (floats raise TypeError on Python >= 3.12).
    max_step_gen = int(scheduler.config.num_train_timesteps * max_proportion_forward_step)

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # map [0, 1] to [-1, 1]
    ])
    with Image.open(target_path) as img:
        x_target = transform(img.convert("RGB")).unsqueeze(0).to(device)

    trigger = torch.zeros([1, *image_shape], device=device, requires_grad=True)
    optimizer = torch.optim.SGD([trigger], lr=learning_rate, weight_decay=0)

    for epoch in range(num_epoch):
        t_sample = random.randint(0, max_step_gen)
        noise_t = torch.randn([batch_size, *image_shape], device=device)

        # forward process: using reparameterization
        x_t_forward = scheduler.add_noise(x_target, noise_t, torch.tensor([t_sample], device=device))

        # backward process: using unet to denoise (trigger broadcasts over the batch)
        x_t_backward = _denoise_until(model, scheduler, noise_t + trigger, t_sample)

        loss = F.l1_loss(x_t_forward, x_t_backward)
        optimizer.zero_grad()  # otherwise gradients accumulate across epochs
        loss.backward()
        optimizer.step()

        _project_trigger_(trigger, epsilon, sparsity_ratio)

        print(f'{epoch} -- t: {t_sample}, loss: {loss.item()}')

    trigger_filename = f'./generated_triggers/TooBad_DDPM_{dataset_name}_{target}.pt'
    os.makedirs(os.path.dirname(trigger_filename), exist_ok=True)
    result = trigger.detach().cpu()
    torch.save(result, trigger_filename)
    return result
    

# Supported checkpoints mapped to (dataset_name, img_size); also the argparse choices.
# NOTE(review): "CompVis/ldm-celebahq-256" looks like a latent-diffusion checkpoint;
# confirm DDPMPipeline.from_pretrained can actually load it.
CKPT_DATASETS = {
    "google/ddpm-cifar10-32": ("CIFAR_10", 32),
    "CompVis/ldm-celebahq-256": ("CELEBA_HQ", 256),
    "google/ddpm-ema-celebahq-256": ("CELEBA_HQ", 256),
}


def dataset_config_for_ckpt(ckpt):
    """Return ``(dataset_name, img_size)`` for a supported checkpoint id.

    Raises:
        ValueError: if ``ckpt`` is not in ``CKPT_DATASETS``.
    """
    try:
        return CKPT_DATASETS[ckpt]
    except KeyError:
        raise ValueError(
            f"Unsupported checkpoint {ckpt!r}; expected one of {list(CKPT_DATASETS)}"
        ) from None


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--target', choices=['HAT', 'CAT', 'STOP_SIGN'], default='HAT')
    parser.add_argument('--ckpt', type=str, default="google/ddpm-cifar10-32", choices=list(CKPT_DATASETS))
    parser.add_argument('--num_epoch', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=0.3)
    args = parser.parse_args()

    # Table lookup replaces an `elif a == x or "y":` test that was always true.
    dataset_name, img_size = dataset_config_for_ckpt(args.ckpt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = DDPMPipeline.from_pretrained(args.ckpt).to(device)

    generate_trigger(pipeline.unet, pipeline.scheduler, args.target, args.num_epoch,
                     args.batch_size, args.learning_rate, dataset_name, img_size)


