import os
import torch
import random
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
import os.path
from PIL import Image

class apply_trigger():
    """Callable transform that pastes a saved trigger image onto an image tensor.

    The trigger is read from a PNG under ``../large_scale/saved_triggers``
    (relative to the current working directory) every time the transform is
    called. Only ``patch_type='badnet'`` is supported.

    Args:
        img_size: Side length used to bound the random paste location.
        patch_size: Side length of the square region overwritten by the trigger.
        patch_type: Kind of trigger; must be ``'badnet'``.
        patch_location: ``'random'`` or ``'top_left_corner'``.
    """

    def __init__(self, img_size=224, patch_size = 16, patch_type = 'badnet', patch_location = 'top_left_corner'):
        self.patch_size = patch_size
        self.patch_type = patch_type
        self.patch_location = patch_location
        self.img_size = img_size

    def __call__(self, image):
        """Paste the trigger onto ``image`` in place and return ``image``.

        Args:
            image: Tensor indexed as ``[channel, row, column]``. It is
                modified in place.

        Returns:
            The same ``image`` object with the trigger region overwritten.

        Raises:
            Exception: If ``patch_type`` or ``patch_location`` is not supported.
        """
        # Validate the configuration up front so that a bad setting fails with
        # a clear message (and before any file I/O) instead of surfacing later
        # as a NameError on an unbound ``trigger``.
        if self.patch_type != "badnet":
            raise Exception('no matching patch type.')
        if self.patch_location not in ("random", "top_left_corner"):
            raise Exception('no matching patch location.')

        # The context manager closes the underlying file handle as soon as
        # the pixel data has been converted to a tensor.
        with Image.open('../large_scale/saved_triggers/trigger_97382661.png_patch_random_size_16.png') as trigger_img:
            trigger = transforms.ToTensor()(trigger_img)

        if self.patch_location == "random":
            # randint is inclusive on both ends, so the patch always fits
            # inside an img_size x img_size image.
            backdoor_loc_h = random.randint(0, self.img_size - self.patch_size)
            backdoor_loc_w = random.randint(0, self.img_size - self.patch_size)
            image[:, backdoor_loc_h:backdoor_loc_h + self.patch_size, backdoor_loc_w:backdoor_loc_w + self.patch_size] = trigger
        else:  # 'top_left_corner'
            image[:, : self.patch_size, : self.patch_size] = trigger

        return image
    

class apply_triggerV2():
    """Callable transform that applies a synthetic backdoor trigger to an image tensor.

    Supported ``patch_type`` values:

    * ``'badnet'``  - a ``patch_size`` x ``patch_size`` patch of Gaussian noise
      shifted by the per-channel image mean.
    * ``'blended'`` - uniform noise with the same spatial size as the image.
      The noise is full-size, so it only combines with
      ``patch_location='blended'``.
    * ``'SIG'``     - a horizontal sinusoidal signal added to every row; the
      result is clipped to ``[0, 1]`` and ``patch_location`` is ignored.
    * ``'warped'``  - resamples the image on an identity grid perturbed by a
      noise grid loaded from disk; ``patch_location`` is ignored.

    Args:
        img_size: Side length used to bound the random paste location.
        patch_size: Side length of the square ``'badnet'`` patch.
        patch_type: One of the trigger kinds listed above.
        patch_location: ``'random'``, ``'top_left_corner'`` or ``'blended'``.
    """

    def __init__(self, img_size=224, patch_size = 16, patch_type = 'badnet', patch_location = 'random'):
        self.patch_size = patch_size
        self.patch_type = patch_type
        self.patch_location = patch_location
        self.img_size = img_size

    def __call__(self, image):
        """Apply the configured trigger to ``image`` and return the result.

        Args:
            image: Tensor indexed as ``[channel, row, column]``. For the
                ``'random'`` and ``'top_left_corner'`` locations it is
                modified in place.

        Returns:
            The triggered image tensor.

        Raises:
            Exception: If ``patch_type`` or ``patch_location`` is not
                supported, or if the warp noise grid file is missing.
        """
        if self.patch_type == "badnet":
            # Centre the Gaussian patch on the per-channel image mean.
            mean = image.mean((1, 2), keepdim=True)
            noise = torch.randn((3, self.patch_size, self.patch_size)).to(image.device)
            noise = mean + noise
        elif self.patch_type == 'blended':
            # Full-size noise; spatial size follows the image rather than a
            # hard-coded 224 x 224.
            height, width = image.shape[-2:]
            noise = torch.rand((3, height, width)).to(image.device)
        elif self.patch_type == 'SIG':
            return self._apply_sig(image)
        elif self.patch_type == 'warped':
            return self._apply_warp(image)
        else:
            # Fail with a clear message instead of a NameError on ``noise``.
            raise Exception('no matching patch type.')

        if self.patch_location == "random":
            # randint is inclusive on both ends, so the patch always fits
            # inside an img_size x img_size image.
            backdoor_loc_h = random.randint(0, self.img_size - self.patch_size)
            backdoor_loc_w = random.randint(0, self.img_size - self.patch_size)
            image[:, backdoor_loc_h:backdoor_loc_h + self.patch_size, backdoor_loc_w:backdoor_loc_w + self.patch_size] = noise
        elif self.patch_location == 'top_left_corner':
            image[:, : self.patch_size, : self.patch_size] = noise
        elif self.patch_location == 'blended':
            # 20% noise / 80% image, kept inside the valid pixel range.
            image = (0.2 * noise) + (0.8 * image)
            image = torch.clip(image, 0, 1)
        else:
            raise Exception('no matching patch location.')

        return image

    def _apply_sig(self, image):
        """Add a 6-cycle sinusoid (amplitude 60/255) along the width and clip to [0, 1]."""
        width = image.shape[-1]
        row_noise = (60 / 255) * torch.sin(2 * torch.pi * torch.arange(width) * 6 / width)
        # The 1-D signal broadcasts over channels and rows.
        return torch.clip(row_noise.to(image.device) + image, 0, 1)

    def _apply_warp(self, image):
        """Resample ``image`` on an identity grid perturbed by the saved noise grid."""
        s = 1
        input_height = 224  # must match the resolution of the saved noise grid file
        grid_rescale = 1
        noise_grid_location = '../large_scale/saved_triggers/noise_grid_k=224_s=1_inputheight=224_gridrescale=1.pt'

        if not os.path.isfile(noise_grid_location):
            raise Exception(f"noise grid not found at {noise_grid_location}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load grid
        # files that come from a trusted source.
        noise_grid = torch.load(noise_grid_location, weights_only=False).to(image.device)

        array1d = torch.linspace(-1, 1, steps=input_height)
        # 'ij' is the historical default; passing it explicitly avoids the
        # deprecation warning emitted when ``indexing`` is omitted.
        x, y = torch.meshgrid(array1d, array1d, indexing='ij')
        identity_grid = torch.stack((y, x), 2)[None, ...].to(image.device)

        grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale
        grid_temps = torch.clamp(grid_temps, -1, 1)

        return F.grid_sample(torch.unsqueeze(image, 0), grid_temps, align_corners=True)[0]
    
def corner_mask_generation(patch=None, image_size=(3, 224, 224)):
    """Place ``patch`` in the bottom-right corner of a zero canvas and build its mask.

    Args:
        patch: Array indexed as ``[channel, row, column]``.
        image_size: Shape ``(channels, height, width)`` of the canvas.

    Returns:
        A tuple ``(applied_patch, mask, x_location, y_location)`` where
        ``applied_patch`` is the canvas holding the patch, ``mask`` is 1.0
        wherever ``applied_patch`` is non-zero and 0.0 elsewhere, and the two
        locations are the row and column offsets of the patch's top-left corner.
    """
    _, canvas_h, canvas_w = image_size
    patch_h, patch_w = patch.shape[1], patch.shape[2]

    # Offsets that align the patch with the bottom-right corner.
    x_location = canvas_h - patch_h
    y_location = canvas_w - patch_w

    applied_patch = np.zeros(image_size)
    applied_patch[:, x_location:canvas_h, y_location:canvas_w] = patch

    # Binary mask: zero-valued patch pixels stay unmasked.
    mask = (applied_patch != 0).astype(applied_patch.dtype)
    return applied_patch, mask, x_location, y_location
    