import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from random import random, shuffle

from PIL import Image
from tqdm import tqdm

import os

import torchvision.transforms.v2


class ImageFolderDataset(Dataset):
    """Flat directory of images exposed as a PyTorch ``Dataset``.

    Only regular files located directly inside ``root_dir`` whose names end in
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are used; sub-directories
    are not searched. The file list is shuffled once at construction time and
    then optionally truncated.

    Args:
        root_dir: Directory to scan for images.
        num_samples: Keep at most this many (randomly chosen) images; any value
            <= 0 keeps all of them.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``.
    """

    # Leading dots matter: without them a file named e.g. "fakepng" would match.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, num_samples=-1, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        candidates = (
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )
        # Skip directories (e.g. "x.png/") that would otherwise crash Image.open.
        self.image_paths = [path for path in candidates if os.path.isfile(path)]

        shuffle(self.image_paths)
        if num_samples > 0:
            # Slicing past the end is harmless, so no explicit length check.
            self.image_paths = self.image_paths[:num_samples]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``self.transform`` if set."""
        # The context manager closes the file handle as soon as the pixels are
        # decoded; convert() loads the data, so the result outlives the file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image

def add_gaussian_noise(image, mean=0.0, std=10.0, rng=None):
    """Return a copy of ``image`` with additive per-pixel Gaussian noise.

    Args:
        image: Image convertible with ``numpy.asarray`` (e.g. a ``PIL.Image``);
            pixel values are assumed to be 8-bit (0-255).
        mean: Mean of the noise distribution, in 0-255 intensity units.
        std: Standard deviation of the noise, in 0-255 intensity units.
        rng: Optional ``numpy.random.Generator`` for reproducible noise. When
            ``None``, the legacy global ``numpy.random`` state is used.

    Returns:
        A new ``PIL.Image`` built from the noisy, clipped ``uint8`` array.
    """
    pixels = np.asarray(image, dtype=np.float32)
    sampler = np.random if rng is None else rng
    noisy = pixels + sampler.normal(mean, std, pixels.shape)
    # Round before casting: astype(uint8) truncates toward zero, which would
    # shift zero-mean noise by roughly -0.5 intensity levels (darkening).
    noisy = np.clip(np.rint(noisy), 0, 255).astype(np.uint8)

    return Image.fromarray(noisy)


def noise(dataset, args, mean=0.0, std=25):
    """Attach an additive-Gaussian-noise attack transform to ``dataset``.

    Args:
        dataset: Object with a writable ``transform`` attribute; it is modified
            in place and also returned.
        args: Unused; accepted for a uniform attack signature.
        mean: Mean of the pixel noise (0-255 intensity units).
        std: Standard deviation of the pixel noise (0-255 intensity units).

    Returns:
        The same ``dataset`` object, with its transform replaced.
    """
    dataset.transform = transforms.Compose([
        transforms.Lambda(lambda img: add_gaussian_noise(img, mean=mean, std=std)),
    ])

    return dataset

def blur(dataset, args, variance=25, kernel_size=7):
    """Attach a noise-then-Gaussian-blur attack transform to ``dataset``.

    NOTE(review): despite the name, this first adds Gaussian noise whose
    *standard deviation* (not variance) is ``variance``, and only then applies
    ``transforms.GaussianBlur(kernel_size)`` with torchvision's default sigma.
    Confirm that the noise step is intended.

    Args:
        dataset: Object with a writable ``transform`` attribute; modified in place.
        args: Unused; accepted for a uniform attack signature.
        variance: Passed as ``std`` to ``add_gaussian_noise``.
        kernel_size: Blur kernel size forwarded to ``GaussianBlur``.

    Returns:
        The same ``dataset`` object, with its transform replaced.
    """
    def _noisy(img):
        return add_gaussian_noise(img, mean=0.0, std=variance)

    pipeline = [transforms.Lambda(_noisy), transforms.GaussianBlur(kernel_size)]
    dataset.transform = transforms.Compose(pipeline)
    return dataset

def color(dataset, args, hue=0.3, saturation=3.0, contrast=3.0):
    """Attach a random colour-jitter attack transform to ``dataset``.

    Args:
        dataset: Object with a writable ``transform`` attribute; modified in place.
        args: Unused; accepted for a uniform attack signature.
        hue: Forwarded to ``transforms.ColorJitter(hue=...)``.
        saturation: Forwarded to ``transforms.ColorJitter(saturation=...)``.
        contrast: Forwarded to ``transforms.ColorJitter(contrast=...)``.

    Returns:
        The same ``dataset`` object, with its transform replaced.
    """
    # Use the caller's parameters rather than hard-coded copies of the defaults.
    dataset.transform = transforms.Compose([
        transforms.ColorJitter(hue=hue, saturation=saturation, contrast=contrast),
    ])

    return dataset

def crop(dataset, args, crop=0.7, rotation=(0, 180)):
    """Attach a random-crop-and-resize-back attack transform to ``dataset``.

    Each image is randomly cropped to ``crop`` times its own height and width
    (at least 1 pixel per side) and then resized back to its original size.
    Sizes are computed per image, so non-square images and folders with mixed
    sizes are handled. Nothing is loaded at build time.

    Args:
        dataset: Object with a writable ``transform`` attribute; modified in place.
        args: Unused; accepted for a uniform attack signature.
        crop: Fraction of each side to keep.
        rotation: Unused; kept so existing call sites keep working.

    Returns:
        The same ``dataset`` object, with its transform replaced.
    """
    def _random_crop_and_restore(image):
        width, height = image.size
        crop_size = (max(1, int(crop * height)), max(1, int(crop * width)))
        cropped = transforms.RandomCrop(crop_size)(image)
        # Resize takes (height, width).
        return transforms.Resize((height, width))(cropped)

    dataset.transform = transforms.Compose([
        transforms.Lambda(_random_crop_and_restore),
    ])

    return dataset

def rotate(dataset, args):
    """Attach a random-rotation attack transform to ``dataset``.

    The rotation angle range is read from ``args.mini`` and ``args.maxi`` and
    passed to ``transforms.RandomRotation``.

    Returns:
        The same ``dataset`` object, with its transform replaced.
    """
    angle_range = (args.mini, args.maxi)
    dataset.transform = transforms.Compose([transforms.RandomRotation(angle_range)])
    return dataset

def flip(dataset, args):
    """Attach an always-on vertical (top-bottom) flip to ``dataset``.

    ``args`` is unused. The dataset is modified in place and returned.
    """
    # p=1 makes the "random" flip fire on every image.
    always_flip = transforms.RandomVerticalFlip(p=1)
    dataset.transform = always_flip
    return dataset

def jpeg(dataset, args, compression=0.25):
    """Attach a JPEG re-compression attack transform to ``dataset``.

    Args:
        dataset: Object with a writable ``transform`` attribute; modified in place.
        args: Unused; accepted for a uniform attack signature.
        compression: Mapped to JPEG quality as ``round((1 - compression) * 100)``.
            The default 0.25 gives quality 75.

    Returns:
        The same ``dataset`` object, with its transform replaced. The transform
        takes and returns ``PIL`` images.

    Raises:
        ValueError: If ``compression`` does not map to a quality in [1, 100].
    """
    quality = int(round((1 - compression) * 100))
    if not 1 <= quality <= 100:
        raise ValueError(
            f"compression must map to a JPEG quality in [1, 100], got {compression}"
        )

    dataset.transform = transforms.Compose([
        # v2.JPEG is documented for uint8 tensor input, so convert the PIL
        # image to a tensor, compress, and hand a PIL image back like the
        # other attacks do.
        transforms.PILToTensor(),
        torchvision.transforms.v2.JPEG(quality),
        transforms.ToPILImage(),
    ])

    return dataset


def none(dataset, args):
    """Identity attack: return ``dataset`` unchanged (its transform is left as-is)."""
    untouched = dataset
    return untouched


def apply_attack(img_path, attack, args):
    """Build an ``ImageFolderDataset`` from ``img_path`` and attach an attack.

    Args:
        img_path: Directory of images passed to ``ImageFolderDataset``.
        attack: One of ``'noise'``, ``'blur'``, ``'color'``, ``'crop'``,
            ``'rotate'``, ``'flip'``, ``'jpeg'`` or ``'none'``.
        args: Namespace-like object; ``args.num_samples`` limits the dataset
            size, and it is forwarded to the attack function.

    Returns:
        The dataset returned by the selected attack function.

    Raises:
        ValueError: If ``attack`` is not supported. This is checked before the
            image directory is scanned.
    """
    attack_map = {
        'noise': noise,
        'blur': blur,
        'color': color,
        'crop': crop,
        'rotate': rotate,
        'flip': flip,
        'jpeg': jpeg,
        'none': none,
    }

    # Validate first so a typo fails fast without listing/shuffling the folder.
    if attack not in attack_map:
        raise ValueError(f"Unsupported attack: {attack}")

    dataset = ImageFolderDataset(img_path, num_samples=args.num_samples)
    return attack_map[attack](dataset, args)

