import torch
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from model_score_based import DiffuserModelSched
from diffusers import DiffusionPipeline
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
import random
import os
import argparse


# Target images the trigger is optimized toward, keyed by the accepted ``target`` names.
_TARGET_PATHS = {
    "HAT": "./static/fedora-hat.png",
    "CAT": "./static/cat_wo_bg.png",
    "STOP_SIGN": "./static/stop_sign_wo_bg.png",
}


def _resolve_target_path(target: str) -> str:
    """Return the image path for ``target``; raise ``ValueError`` for unknown names."""
    try:
        return _TARGET_PATHS[target]
    except KeyError:
        raise ValueError(
            f"unknown target {target!r}; expected one of {sorted(_TARGET_PATHS)}"
        ) from None


def _load_target_image(target_path: str, sample_size: int, device: torch.device) -> torch.Tensor:
    """Load an RGB image as a (1, 3, sample_size, sample_size) tensor on ``device``."""
    transform = transforms.Compose([
        transforms.Resize((sample_size, sample_size)),
        transforms.ToTensor(),
    ])
    # The context manager closes the underlying file once the pixels are converted.
    with Image.open(target_path) as img:
        image = img.convert("RGB")
    return transform(image).unsqueeze(0).to(device)


def _denoise_until(model, scheduler, x: torch.Tensor, stop_step: int, num_inference_steps: int) -> torch.Tensor:
    """Apply ``scheduler.step_pred`` updates to ``x`` until discrete step ``stop_step`` is reached."""
    with tqdm(enumerate(scheduler.timesteps), total=num_inference_steps - stop_step,
              desc="Denoising Progress") as progress:
        for i, continuous_t in progress:
            sigma_t = scheduler.sigmas[i]
            # Iteration i is treated as discrete step num_inference_steps - 1 - i
            # (assumes scheduler.timesteps is ordered from the last step down).
            if num_inference_steps - 1 - i == stop_step:
                break
            # The network runs without autograd, so any gradient reaching the
            # trigger flows only through scheduler.step_pred's use of x.
            with torch.no_grad():
                noise_pred = model(x, sigma_t).sample
            x = scheduler.step_pred(noise_pred, continuous_t, x).prev_sample
    return x


def _apply_constraints(trigger: torch.Tensor, epsilon: float, sparsity_ratio: float) -> None:
    """Clamp ``trigger`` in place to [-epsilon, epsilon], then zero its smallest-magnitude entries."""
    with torch.no_grad():
        # Invisibility constraint.
        trigger.clamp_(-epsilon, epsilon)
        # Sparsity constraint: zero every entry strictly below the k-th smallest magnitude.
        k = int(sparsity_ratio * trigger.numel())
        if k > 0:  # topk(0) yields an empty tensor, so there would be no threshold
            magnitudes = trigger.abs()
            smallest, _ = magnitudes.flatten().topk(k, largest=False)
            threshold = smallest[-1]
            trigger[magnitudes < threshold] = 0


def _save_trigger(trigger: torch.Tensor, target: str) -> torch.Tensor:
    """Save a detached CPU copy of ``trigger`` under ./generated_triggers/ and return it."""
    trigger_filename = f'./generated_triggers/TooBad_NCSN_{target}.pt'
    os.makedirs(os.path.dirname(trigger_filename), exist_ok=True)
    saved = trigger.detach().cpu()
    torch.save(saved, trigger_filename)
    return saved


def generate_trigger(model, scheduler, target, num_epoch=50, batch_size=32, learning_rate=0.1,
                     *, epsilon=0.15, sparsity_ratio=0.2, max_proportion_forward_step=0.2):
    """Optimize an additive trigger so that denoised trigger-stamped noise matches a noised target.

    Each epoch:

    1. Sample a discrete step ``t`` in ``[0, int(max_proportion_forward_step * N)]``,
       where ``N = scheduler.num_train_timesteps``.
    2. Noise the target image to ``t`` with ``scheduler.add_noise``.
    3. Starting from ``noise + trigger``, run the model/scheduler predictor steps
       down to ``t``.
    4. Take one SGD step on the L1 distance between the two results, both
       clipped to [0, 1].

    After every step the trigger is clamped to ``[-epsilon, epsilon]``. Entries
    whose magnitude is below the k-th smallest magnitude are then zeroed, with
    ``k = int(sparsity_ratio * trigger.numel())``.

    Args:
        model: Score network. It must expose ``config.in_channels`` and
            ``sample_size``, and ``model(x, sigma).sample`` must return a
            tensor shaped like ``x``.
        scheduler: Must provide ``num_train_timesteps``, ``timesteps``,
            ``sigmas``, ``add_noise`` and ``step_pred``.
        target: One of ``"HAT"``, ``"CAT"`` or ``"STOP_SIGN"``.
        num_epoch: Number of optimization steps.
        batch_size: Number of noise samples drawn per step.
        learning_rate: SGD learning rate.
        epsilon: Bound on each trigger element's magnitude.
        sparsity_ratio: Fraction used to pick the magnitude threshold for zeroing.
        max_proportion_forward_step: Upper bound on the sampled step, as a
            fraction of ``num_train_timesteps``.

    Returns:
        The detached CPU trigger, shape ``(1, in_channels, sample_size, sample_size)``.
        It is also saved to ``./generated_triggers/TooBad_NCSN_{target}.pt``.

    Raises:
        ValueError: If ``target`` is not a known target name.
    """
    # Validate first so a bad target fails fast, before any model or file access.
    target_path = _resolve_target_path(target)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_shape = [model.config.in_channels, model.sample_size, model.sample_size]
    num_inference_steps = scheduler.num_train_timesteps
    # random.randint requires integer bounds (a float raises TypeError on Python 3.12+).
    max_step_gen = int(num_inference_steps * max_proportion_forward_step)

    # Resize to the model's resolution so the target matches the trigger and noise shapes.
    x_target = _load_target_image(target_path, model.sample_size, device)

    trigger = torch.full([1, *image_shape], 0.4, device=device)
    trigger.requires_grad_(True)
    optimizer = torch.optim.SGD([trigger], lr=learning_rate, weight_decay=0)

    for epoch in range(num_epoch):
        timestep_sampled = random.randint(0, max_step_gen)
        # Drawn on the CPU generator and then moved, as the original .cuda() call did.
        noise_t = torch.randn([batch_size, *image_shape]).to(device)

        # Forward process: noise the target image to the sampled step.
        x_t_forward = scheduler.add_noise(
            x_target, noise_t, torch.tensor([timestep_sampled], device=device)
        ).clip(0, 1)

        # Backward process: denoise the trigger-stamped noise down to the same step.
        x_t_backward = _denoise_until(
            model, scheduler, noise_t + trigger, timestep_sampled, num_inference_steps
        ).clip(0, 1)

        loss = F.l1_loss(x_t_forward, x_t_backward)
        # Without this, trigger.grad would accumulate across epochs.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _apply_constraints(trigger, epsilon, sparsity_ratio)

        print(f'{epoch} -- t: {timestep_sampled}, loss: {loss.item()}')

    return _save_trigger(trigger, target)
    

def main(argv=None):
    """Parse CLI arguments, load the pretrained VE-SDE model and generate a trigger.

    Args:
        argv: Argument list to parse. ``None`` uses ``sys.argv[1:]``.
    """
    # For the first run, uncomment the line below to download the NCSN checkpoint:
    # pipe = DiffusionPipeline.from_pretrained("FrankCCCCC/NCSN_CIFAR10_my")

    parser = argparse.ArgumentParser(description="Generate a trigger for a pretrained NCSN (VE-SDE) model.")
    parser.add_argument('--target', choices=['HAT', 'CAT', 'STOP_SIGN'], default='HAT',
                        help="Target image the trigger should reveal.")
    parser.add_argument('--ckpt', type=str, default="FrankCCCCC/NCSN_CIFAR10_my",
                        help="Checkpoint passed to DiffuserModelSched.get_pretrained.")
    parser.add_argument('--num_epoch', type=int, default=50,
                        help="Number of optimization steps.")
    parser.add_argument('--batch_size', type=int, default=32,
                        help="Noise samples drawn per optimization step.")
    parser.add_argument('--learning_rate', type=float, default=0.2,
                        help="SGD learning rate.")
    args = parser.parse_args(argv)

    model, _, scheduler, _ = DiffuserModelSched.get_pretrained(
        ckpt=args.ckpt, clip_sample=False, sde_type=DiffuserModelSched.SDE_VE
    )
    # Use the same CUDA-if-available device that generate_trigger places its tensors on,
    # instead of requiring a GPU unconditionally.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    generate_trigger(model, scheduler, args.target, args.num_epoch, args.batch_size, args.learning_rate)


if __name__ == '__main__':
    main()