import torch
from training.sd_util_init import sid_sd_denoise, sid_sd_sampler, load_sd15


# Smoke test for the SiD Stable Diffusion helpers: generate one-step samples
# with sid_sd_sampler, run sid_sd_denoise on them at random timesteps, and
# print how far the predicted noise is from the noise that was injected.

latent_img_channels = 4
latent_resolution = 32
# NOTE(review): passed as `resolution=` to both helpers below; kept at the
# original value. Confirm whether the helpers expect a latent or pixel size.
resolution = 32
contexts_forget = ['Brad Pitt on a red carpet.']
device = 'cuda'
init_timestep = 650
dtype = torch.float16
# Fixed seed so repeated runs print comparable diagnostics.
seed = 0

unet, vae, noise_scheduler, text_encoder, tokenizer, unet_cpu_copy = load_sd15(
    pretrained_model_name_or_path='runwayml/stable-diffusion-v1-5',
    pretrained_vae_model_name_or_path=None,
    device=device, weight_dtype=dtype, variant="fp16", enable_xformers=False,
    lora_config=None)

torch.manual_seed(seed)
batch_size = len(contexts_forget)

# Starting latents (float32) and the noise later passed to sid_sd_denoise.
z = torch.randn([batch_size, latent_img_channels, latent_resolution, latent_resolution],
                device=device, dtype=torch.float32)
noise = torch.randn_like(z)

# Every sample starts from the same timestep.
init_timesteps = torch.full((batch_size,), init_timestep, device=device, dtype=torch.long)

# This script never calls backward(), so disable autograd for the whole pass;
# this is what actually "stops the generator gradient" and avoids keeping
# activation graphs for two UNet forwards in GPU memory.
with torch.no_grad():
    # One-step generation. With return_images=False and vae=None the result is
    # presumably latents rather than decoded pixels — TODO confirm.
    images = sid_sd_sampler(unet=unet, latents=z, contexts=contexts_forget,
                            init_timesteps=init_timesteps,
                            noise_scheduler=noise_scheduler,
                            text_encoder=text_encoder, tokenizer=tokenizer,
                            resolution=resolution, dtype=dtype,
                            return_images=False, vae=None,
                            num_steps=1)

    # Random per-sample timesteps in [20, 980).
    timesteps = torch.randint(20, 980, (batch_size,), device=device, dtype=torch.long)

    # Denoise the generated samples with the same UNet. predict_x0=False is
    # requested so the output is compared against `noise` below; classifier-free
    # guidance scale is 4.5 and no negative prompts are supplied.
    noise_fake = sid_sd_denoise(unet=unet, images=images, noise=noise,
                                contexts=contexts_forget,
                                timesteps=timesteps,
                                noise_scheduler=noise_scheduler,
                                text_encoder=text_encoder, tokenizer=tokenizer,
                                resolution=resolution, dtype=dtype, predict_x0=False,
                                guidance_scale=4.5,
                                contexts_neg=None,
                                )
    # Compute the residual once, explicitly in float32, since noise_fake may
    # come back in the fp16 working dtype.
    diff = noise_fake.float() - noise

print(noise_fake.shape)
print((diff ** 2).sum())
print(torch.flatten(diff))
