import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F

class RobustPGDAttacker():
    def __init__(self, radius, steps, step_size, random_start, trans, sample_num, attacker, norm_type='l-infty', ascending=True, args=None, x_range=[0, 255], target_weight=1.0):        
        self.noattack = radius == 0. or steps == 0 or step_size == 0.
        self.radius = radius
        self.step_size = step_size
        self.steps = steps # how many step to conduct pgd
        self.random_start = random_start
        self.norm_type = norm_type # which norm of your noise
        self.ascending = ascending # perform gradient ascending, i.e, to maximum the loss
        self.args=args
        self.transform = trans
        self.sample_num = sample_num
        self.attacker = attacker
        self.left, self.right = x_range
        self.pattern = np.random.randint( 0, 72, size=(16, 16, 3), dtype=np.uint8)
        print(
            "summary of the attacker: \n"
            f"radius: {self.radius}\n"
            f"steps: {self.steps}\n"
            f"step_size: {self.step_size}\n"
            f"random_start: {self.random_start}\n"
            f"norm_type: {self.norm_type}\n"
            f"ascending: {self.ascending}\n"
            f"sample_num: {self.sample_num}\n"
            # f"attacker: {self.attacker}\n", 
            f"x_range: {self.left} ~ {self.right}\n"
        )
        self.target_weight = target_weight
        from piq import BRISQUELoss
        self.brisque = BRISQUELoss(data_range=255)
    
    def certi(self, models, adv_x,vae,  noise_scheduler, input_ids, device=torch.device("cuda"), weight_dtype=torch.float32, target_tensor=None):
        # args=self.args
        unet, text_encoder = models
        adv_latens = vae.encode(adv_x.to(device, dtype=weight_dtype)).latent_dist.sample()
        adv_latens = adv_latens * vae.config.scaling_factor
        
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(adv_latens)
        bsz = adv_latens.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=adv_latens.device)
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(adv_latens, noise, timesteps)
        encoder_hidden_states = text_encoder(input_ids.to(device))[0]
        
        args=self.args
        
        if "robust_instance_conditioning_vector" in vars(args).keys() and args.robust_instance_conditioning_vector:
            condition_vector = args.robust_instance_conditioning_vector_data
            encoder_hidden_states[0,:7,:] = condition_vector.to(device, dtype=weight_dtype)

        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(adv_latens, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        
        if target_tensor is not None:
            timesteps = timesteps.to('cpu')
            noisy_latents = noisy_latents.to('cpu')
            xtm1_pred = torch.cat(
                [
                    noise_scheduler.step(
                        model_pred[idx : idx + 1].to('cpu'),
                        timesteps[idx : idx + 1].to('cpu'),
                        noisy_latents[idx : idx + 1].to('cpu'),
                    ).prev_sample
                    for idx in range(len(model_pred))
                ]
            )
            xtm1_target = noise_scheduler.add_noise(target_tensor, noise, (timesteps - 1))
            loss = loss - self.target_weight*F.mse_loss(xtm1_pred.to(device), xtm1_target.to(device))
            
        if self.args.quality_loss:
            timesteps = timesteps.to('cpu')
            noisy_latents = noisy_latents.to('cpu')
            x_ori_pred = torch.cat(
                [
                    noise_scheduler.step(
                        model_pred[idx : idx + 1].to('cpu'),
                        timesteps[idx : idx + 1].to('cpu'),
                        noisy_latents[idx : idx + 1].to('cpu'),
                    ).pred_original_sample
                    for idx in range(len(model_pred))
                ]
            )
            # x_ori_pred
            # x_ori_pred = decode_latents(vae, x_ori_pred)
            latents=x_ori_pred
            latents = 1 / vae.config.scaling_factor * latents
            latents = latents.to(device, dtype=weight_dtype)
            latents_sub = latents[0].unsqueeze(0)
            image = vae.decode(latents_sub, return_dict=False)[0]
            # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image = ((image / 2 + 0.5).clamp(0, 1) * 255).clamp(0, 255).to(device, dtype=weight_dtype)
            # resize this image down to 3*256*256
            resize_trans = transforms.Compose([transforms.Resize((256, 256)), ])
            image = resize_trans(image)
            # print(
            #     "range of image: ", 'min: ', image.min().item(), 'max: ', image.max().item(), 'mean: ', image.mean().item(), 'std: ', image.std().item(), '\n'
            # )
            qloss = self.args.quality_weight*self.brisque(image.float())
            loss = loss + qloss
        return loss

    def perturb(self, models, x, ori_x, vae, tokenizer, noise_scheduler, target_tensor=None, adaptive_target=False):
        args=self.args
        unet, text_encoder = models
        weight_dtype = torch.bfloat16
        if args.mixed_precision == "fp32":
            weight_dtype = torch.float32
        elif args.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif args.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        
        device = torch.device("cuda")
        vae.to(device, dtype=weight_dtype)
        text_encoder.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)
        
        # delta = torch.zeros_like(x.data, device=device, dtype=weight_dtype)
        # x = x.to(device, dtype=weight_dtype)
        ori_x = ori_x.detach().clone().to(device, dtype=weight_dtype)
        x = x.detach().clone().to(device, dtype=weight_dtype)
        
        
        if self.noattack:
            return x
        
        # ''' temporarily shutdown autograd of model to improve pgd efficiency '''
        # vae.eval()
        text_encoder.eval()
        unet.eval()
        for mi in [text_encoder, unet]:
           for pp in mi.parameters():
               pp.requires_grad = False
               
        
        if self.random_start:
            r=self.radius
            r_noise = torch.zeros_like(x).uniform_(-r, r)
            r_x=x+r_noise
            x=self._clip_(r_x, x)

        # if adaptive_target:
        #     target_tensor = self._get_target_tensor(ori_x, self.pattern, vae, weight_dtype=weight_dtype)
            
        input_ids = tokenizer(
            args.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.repeat(len(x), 1)
        
        x.requires_grad_(True)
        
        for _step in range(self.steps):
            x.requires_grad = True
            for _sample in range(self.sample_num):
                # noise = delta.clamp(-self.radius, self.radius)*255
                def_x_trans = self.transform(x).to(device, dtype=weight_dtype)
                adv_x = self.attacker.perturb(
                    models, def_x_trans, self.transform(ori_x), vae, tokenizer, noise_scheduler, 
                )
                loss = self.certi(models, adv_x, vae, noise_scheduler, input_ids, device, weight_dtype, target_tensor)
                loss.backward()
            with torch.no_grad():
                grad = x.grad.data
                if not self.ascending: 
                    grad.mul_(-1)
                    # print('descending')
                if self.norm_type == 'l-infty':
                    x.add_(torch.sign(grad), alpha=self.step_size)
                else:
                    raise NotImplementedError
                x = self._clip_(x, ori_x, ).detach_()
        ''' reopen autograd of model after pgd '''
        for mi in [text_encoder, unet]:
            for pp in mi.parameters():
                pp.requires_grad = True
        return x.cpu()
    def _clip_(self, adv_x, x, ):
        adv_x = adv_x - x
        if self.norm_type == 'l-infty':
            adv_x.clamp_(-self.radius, self.radius)
        else:
            raise NotImplementedError
        adv_x = adv_x + x
        adv_x.clamp_(self.left, self.right)
        return adv_x