import torch
import numpy as np
import torchvision.transforms as T
from tqdm import tqdm
from PIL import Image
# Module-level tensor -> PIL image converter (torchvision ToPILImage);
# not referenced by the code in this section.
topil = T.ToPILImage()

from immunization.pytorch_nested_unet import NestedUNet 
from utils.utils import set_seed_lib

# Module-level GradScaler used by DiffVax.train_immunization to scale the loss
# before backward() and to step/update the optimizer. Because it is global,
# its loss-scale state is shared by every DiffVax instance in the process.
# NOTE(review): newer PyTorch releases deprecate torch.cuda.amp.GradScaler in
# favor of torch.amp.GradScaler("cuda") — confirm against the pinned version.
scaler = torch.cuda.amp.GradScaler()

class DiffVaxDataset(torch.utils.data.Dataset):
    """Map-style dataset over three parallel sequences: images, masks, prompts.

    Item ``i`` is ``(img_list[i], img_mask_list[i], prompt_list[i])``; the
    dataset length is taken from ``img_list`` alone.
    """

    def __init__(self, img_list, img_mask_list, prompt_list):
        # Store references to the caller's sequences; nothing is copied.
        self.img_list, self.img_mask_list, self.prompt_list = (
            img_list,
            img_mask_list,
            prompt_list,
        )

    def __getitem__(self, index):
        """Return the ``(image, mask, prompt)`` triple at ``index``."""
        sources = (self.img_list, self.img_mask_list, self.prompt_list)
        return tuple(seq[index] for seq in sources)

    def __len__(self):
        """Number of samples, i.e. ``len(img_list)``."""
        return len(self.img_list)

class DiffVax():
    """Learns an additive "immunization" perturbation with a NestedUNet.

    The UNet's output is added to the image only where ``img_mask`` is 0 and
    the sum is clamped to ``[clamp_min, clamp_max]``. Training pushes the
    ``attack_model``'s edit inside the mask toward an all-zero image while
    penalizing the perturbation outside the mask.
    """

    def __init__(self, attack_model, inference = False, load_path = "", learning_rate = 1e-5):
        """Build the UNet perturbation generator and its Adam optimizer.

        Args:
            attack_model: editing model; must provide ``attack(...)`` (used
                during training) and ``diffusion_model(...)`` (used by
                ``edit_image``).
            inference: if True, load UNet weights from ``load_path``.
            load_path: checkpoint path passed to ``torch.load`` when
                ``inference`` is True.
            learning_rate: Adam learning rate (default 1e-5, the value that
                was previously hard-coded).
        """
        self.model_name = "DiffVax"
        self.attack_model = attack_model
        # Immunized images are clamped to this range
        # (presumably inputs are normalized to [-1, 1] — TODO confirm).
        self.clamp_min = -1
        self.clamp_max = 1
        # nn.Module.to() moves in place and returns the same module, so the
        # optimizer below sees the CUDA parameters.
        self.unetmodel = NestedUNet(num_classes = 3).to("cuda")
        self.optimizer = torch.optim.Adam(self.unetmodel.parameters(), lr=learning_rate)

        if inference:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints. The model is also left in train mode here; if
            # NestedUNet contains BatchNorm/Dropout, callers may want
            # self.unetmodel.eval() for inference.
            self.unetmodel.load_state_dict(torch.load(load_path))

        for param in self.unetmodel.parameters():
            param.requires_grad = True

    def _perturb(self, img, img_mask):
        """Shared forward pass: ``clamp(img + UNet(img) * (1 - img_mask))``."""
        # The UNet runs in float32; its output is cast to half before being
        # masked and added (presumably to match a half-precision img).
        unet_out = self.unetmodel.forward(img.float().cuda())
        unet_out = unet_out.half().cuda() * (1 - img_mask)
        return torch.clamp(img + unet_out, self.clamp_min, self.clamp_max)

    def immunize_img(self, img, img_mask):
        """Return the immunized version of ``img``.

        The UNet perturbation is suppressed wherever ``img_mask`` is 1, and the
        result is clamped to ``[clamp_min, clamp_max]``. Gradients are not
        disabled here; wrap the call in ``torch.no_grad()`` for pure inference.
        """
        return self._perturb(img, img_mask)

    def immunization_step(self, img, img_mask, prompt_list = None, alpha = 1, batch_size = 1):
        """Run one forward pass and compute the immunization loss.

        Backward propagation and the optimizer step are left to the caller
        (``train_immunization``).

        Args:
            img: image batch; expected on the CUDA device, since it is added to
                the CUDA UNet output.
            img_mask: edit mask; 1 marks the region the attack edits.
            prompt_list: prompts for the attack; ``None`` means ``[""]``.
            alpha: weight of the edit loss.
            batch_size: forwarded to ``attack_model.attack``.

        Returns:
            ``(loss, img_imm)``: the scalar ``loss_edit + loss_noise`` and the
            half-precision immunized batch.
        """
        if prompt_list is None:
            prompt_list = [""]
        img_mask.requires_grad = False
        # NOTE(review): this also routes gradients into the input image; they
        # are not used by train_immunization.
        img.requires_grad_()

        img_imm = self._perturb(img, img_mask).half().cuda()

        # Edit the immunized image with the attack model (4 denoising steps).
        img_out = self.attack_model.attack(prompt = prompt_list, masked_image = img_imm, mask = img_mask, num_inference_steps = 4, batch_size = batch_size)

        # The /512 factors cancel between numerator and denominator; presumably
        # they keep the fp16 mask sums from overflowing.
        in_w = img_mask / 512
        out_w = (1 - img_mask) / 512

        # Edit loss: L1 distance of the edited output from an all-zero target
        # inside the mask, normalized by the mask area.
        target_image = torch.zeros_like(img_out).cuda()
        loss_edit = alpha * ((img_out - target_image) * in_w).norm(p=1) / in_w.sum()
        # Noise loss: L1 size of the applied perturbation outside the mask,
        # normalized by the unmasked area.
        loss_noise = ((img_imm - img) * out_w).norm(p=1) / out_w.sum()
        loss = loss_edit + loss_noise

        return loss, img_imm

    def train_immunization(self, img_list, img_mask_list, prompt_list, iter_num=2000, SEED=5, batch_size=2, alpha=1,
                           save_path="../checkpoints/diffvax_unetmodel.pth"):
        """Train the UNet on (image, mask, prompt) triples and save its weights.

        Args:
            img_list, img_mask_list, prompt_list: parallel sequences wrapped in
                a ``DiffVaxDataset``.
            iter_num: number of epochs over the dataset.
            SEED: seed passed to ``set_seed_lib`` before the DataLoader is built.
            batch_size: DataLoader batch size.
            alpha: edit-loss weight passed to ``immunization_step``.
            save_path: where the UNet ``state_dict`` is written after training
                (defaults to the previously hard-coded path).
        """
        import os

        set_seed_lib(SEED)

        # Create the checkpoint directory up front, so a bad path fails
        # before training rather than after it.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        dataset = DiffVaxDataset(img_list, img_mask_list, prompt_list)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for _ in range(iter_num):
            # Per-epoch loss history for the running average in the progress bar.
            losses = []
            with tqdm(dataloader, total=len(dataloader)) as pbar:
                for img_batch, mask_batch, prompt_batch in pbar:
                    self.optimizer.zero_grad()

                    # The last batch can be smaller than batch_size
                    # (drop_last=False), so forward the real size.
                    loss, _ = self.immunization_step(img_batch, mask_batch, prompt_list = prompt_batch, alpha = alpha, batch_size = len(img_batch))
                    losses.append(loss.item())

                    # Update the immunization UNet through the shared GradScaler.
                    scaler.scale(loss).backward()
                    scaler.step(self.optimizer)
                    scaler.update()

                    pbar.set_description_str(f'AVG Loss: {np.mean(losses):.3f}')

        torch.save(self.unetmodel.state_dict(), save_path)

    def edit_image(self, prompt, img, img_mask, num_inf = 30, SEED = 5, strength = 1.0, guidance_scale = 7.5):
        """Edit ``img`` inside ``img_mask`` with the attack model's diffusion pipeline.

        Args:
            prompt: edit prompt.
            img, img_mask: image and mask passed as ``image`` / ``mask_image``.
            num_inf: number of inference steps.
            SEED: seed passed to ``set_seed_lib`` before the call.
            strength, guidance_scale: forwarded to the pipeline (the defaults
                are the previously hard-coded values).

        Returns:
            The ``.images`` attribute of the pipeline output.
        """
        set_seed_lib(SEED)
        edited_image = self.attack_model.diffusion_model(prompt=prompt,
                                        image=img,
                                        mask_image=img_mask,
                                        eta=1,
                                        num_inference_steps=num_inf,
                                        guidance_scale=guidance_scale, strength=strength
                                        ).images

        return edited_image
    