import json
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from random import random, shuffle

from PIL import Image
from tqdm import tqdm





import os

import torchvision.transforms.v2


class ImageFolderDataset(Dataset):
    """Images from a single flat directory, optionally paired with token sequences.

    Only names directly inside ``root_dir`` ending in ``.png``, ``.jpg`` or
    ``.jpeg`` (case-insensitive) are used; subdirectories are not scanned.

    Args:
        root_dir: Directory holding the images (and ``token_map.jsonl`` when
            ``load_token_map`` is True).
        num_samples: When positive, keep at most this many images, taken from a
            shuffle driven by the global ``random`` state. Otherwise all matching
            images are kept.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
        load_token_map: When True, read ``<root_dir>/token_map.jsonl`` (one JSON
            object per line with ``"image"`` and ``"tokens"`` keys) and return each
            image's tokens as a ``torch.long`` tensor.
    """

    # Compared against the lower-cased file name. The leading dot keeps
    # extension-less names such as "notapng" out.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, num_samples=-1, transform=None, load_token_map=False):
        self.root_dir = root_dir
        self.transform = transform
        self.load_token_map = load_token_map

        # Sort first so that, for a given random seed, the shuffled order (and the
        # subset kept below) does not depend on os.listdir's arbitrary ordering.
        self.image_paths = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        shuffle(self.image_paths)
        if num_samples > 0:
            # Slicing past the end is harmless, so no separate length check is needed.
            self.image_paths = self.image_paths[:num_samples]

        if self.load_token_map:
            token_map_path = os.path.join(root_dir, "token_map.jsonl")
            with open(token_map_path, "r", encoding="utf-8") as f:
                # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
                records = [json.loads(line) for line in f if line.strip()]
            # Keyed by the "image" field, which __getitem__ matches against the
            # file's base name.
            self.token_map_dict = {record["image"]: record["tokens"] for record in records}

    def __len__(self):
        """Return the number of selected image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, tokens)`` for the ``idx``-th selected file.

        ``image`` is the PIL image in its stored mode (no RGB conversion), passed
        through ``self.transform`` when one is set. ``tokens`` is a ``torch.long``
        tensor when token maps are loaded and the file has an entry, else None.
        """
        img_path = self.image_paths[idx]
        image = Image.open(img_path)
        # Image.open is lazy and keeps the file open. load() decodes the pixels now,
        # and for single-frame images opened by path Pillow then closes the file,
        # so handles do not accumulate over many lookups.
        image.load()

        if self.transform:
            image = self.transform(image)

        tokens = None
        if self.load_token_map:
            raw_tokens = self.token_map_dict.get(os.path.basename(img_path))
            if raw_tokens is not None:
                tokens = torch.tensor(raw_tokens, dtype=torch.long)
        return image, tokens

def _gaussian_noise_uint8(pixels, mean=0.0, std=10.0, rng=None):
    """Return ``pixels`` plus i.i.d. Gaussian noise, rounded and clipped to uint8.

    Args:
        pixels: Array-like of pixel values (any shape).
        mean: Mean of the noise, in pixel units.
        std: Standard deviation of the noise, in pixel units (must be >= 0).
        rng: Optional ``numpy.random.Generator``. When None the global
            ``np.random`` state is used, so ``np.random.seed`` still applies.

    Returns:
        numpy.ndarray of dtype uint8 with the same shape as ``pixels``.
    """
    source = np.random if rng is None else rng
    values = np.asarray(pixels, dtype=np.float64)
    perturbed = values + source.normal(mean, std, values.shape)
    # Round to nearest before casting. astype(uint8) alone truncates, which
    # would bias every pixel downward by ~0.5 levels on average.
    return np.clip(np.rint(perturbed), 0, 255).astype(np.uint8)


def add_gaussian_noise(image, mean=0.0, std=10.0, rng=None):
    """
    Add Gaussian noise to a PIL image.

    Args:
        image (PIL.Image): Input image.
        mean (float): Mean of Gaussian noise, in pixel units.
        std (float): Standard deviation of Gaussian noise, in pixel units.
        rng (numpy.random.Generator | None): Optional generator for reproducible
            noise; defaults to the global ``np.random`` state.

    Returns:
        PIL.Image: Image with added Gaussian noise, values rounded and clipped
        to [0, 255] and stored as uint8.
    """
    noisy = _gaussian_noise_uint8(np.array(image), mean=mean, std=std, rng=rng)
    return Image.fromarray(noisy)

class TensorImageDataset(Dataset):
    """Wrap an indexable collection of image tensors as a dataset.

    Args:
        tensor_array: Anything supporting ``len()`` and integer indexing, e.g. a
            list of tensors or a tensor whose first dimension indexes samples
            (presumably shaped (N, C, H, W) — confirm with callers).
        transform: Callable applied to each sample. When None, defaults to
            ``transforms.ToPILImage()`` so samples come out as PIL images.
    """

    def __init__(self, tensor_array, transform=None):
        self.tensor_array = tensor_array
        # Honor a caller-supplied transform; fall back to tensor -> PIL conversion.
        self.transform = transform if transform is not None else transforms.ToPILImage()

    def __len__(self):
        """Return the number of samples in ``tensor_array``."""
        return len(self.tensor_array)

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``self.transform`` if it is set."""
        sample = self.tensor_array[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


#Define Attacks
#Conventional
def noise(dataset, args):
    """Configure ``dataset`` to add zero-mean Gaussian noise to every image.

    ``args.variance`` must lie in [0, 1]. It is multiplied by 255 and handed to
    ``add_gaussian_noise`` as the noise *standard deviation* in pixel units,
    despite the attribute's name. The dataset is modified in place and returned.
    """
    var = args.variance
    assert 0 <= var <= 1, "Variance must be between 0 and 1"

    sigma = 255 * var
    dataset.transform = transforms.Compose(
        [transforms.Lambda(lambda img: add_gaussian_noise(img, mean=0.0, std=sigma))]
    )
    return dataset

def gauss(dataset, args):
    """Attach a Gaussian-blur attack to ``dataset`` and return it.

    Args:
        dataset: Object whose ``transform`` attribute is replaced.
        args: Namespace providing ``kernel_size`` and, optionally, ``blur_sigma``.
            ``kernel_size`` is forwarded to ``transforms.GaussianBlur``, which
            requires odd, positive sizes. ``blur_sigma`` is a float or a
            ``(min, max)`` pair. When ``blur_sigma`` is missing or None,
            GaussianBlur's default sigma range is used, so blur strength is drawn
            at random for each image.
    """
    kernel_size = args.kernel_size
    blur_sigma = getattr(args, "blur_sigma", None)
    if blur_sigma is None:
        blur = transforms.GaussianBlur(kernel_size)
    else:
        blur = transforms.GaussianBlur(kernel_size, sigma=blur_sigma)

    dataset.transform = transforms.Compose([blur])
    return dataset

def color(dataset, args, hue=0.3, saturation=3.0, contrast=3.0):
    """Attach a random color-jitter attack to ``dataset`` and return it.

    The keyword arguments are passed to ``transforms.ColorJitter``:

    - ``hue`` h: the hue shift is drawn from [-h, h].
    - ``saturation`` s and ``contrast`` c: each factor is drawn from
      [max(0, 1 - v), 1 + v].

    ``args`` is unused; it is accepted to match the common attack signature.
    """
    dataset.transform = transforms.Compose([
        transforms.ColorJitter(hue=hue, saturation=saturation, contrast=contrast),
    ])
    return dataset

def _fixed_color_jitter(dataset, **factors):
    """Set ``dataset.transform`` to a ColorJitter with each given factor pinned.

    Each keyword (``brightness``, ``saturation``, ``contrast``) becomes the
    degenerate range ``[value, value]``. ColorJitter therefore applies exactly
    ``value`` rather than a random factor. Returns ``dataset``.
    """
    pinned = {name: [value, value] for name, value in factors.items()}
    dataset.transform = transforms.Compose([transforms.ColorJitter(**pinned)])
    return dataset


def brightness(dataset, args):
    """Scale every image's brightness by exactly ``args.brightness``; return ``dataset``."""
    return _fixed_color_jitter(dataset, brightness=args.brightness)


def saturation(dataset, args):
    """Scale every image's saturation by exactly ``args.saturation``; return ``dataset``."""
    return _fixed_color_jitter(dataset, saturation=args.saturation)


def contrast(dataset, args):
    """Scale every image's contrast by exactly ``args.contrast``; return ``dataset``."""
    return _fixed_color_jitter(dataset, contrast=args.contrast)

def resize(dataset, args):
    """Attach a down-then-up resize attack to ``dataset`` and return it.

    Each image is resized by ``args.resize_ratio`` and then back to a reference
    size, which discards detail. The reference ``(width, height)`` comes from the
    first sample, as returned under the dataset's current transform. All outputs
    therefore share that size, and the aspect ratio of non-square references is
    preserved.
    """
    first = dataset[0]
    # ImageFolderDataset yields (image, tokens); a bare image sample is used as is.
    reference = first[0] if isinstance(first, tuple) else first
    width, height = reference.size  # PIL reports size as (width, height)

    ratio = args.resize_ratio
    # Resize takes (height, width). Clamp so a tiny ratio never requests a
    # zero-pixel image.
    shrunk = (max(1, int(ratio * height)), max(1, int(ratio * width)))

    dataset.transform = transforms.Compose([
        transforms.Resize(shrunk),
        transforms.Resize((height, width)),
    ])
    return dataset


def crop(dataset, args):
    """Attach a center-crop-then-resize attack to ``dataset`` and return it.

    Each image is center-cropped to ``args.crop_ratio`` of the reference height
    and width, then resized back to the reference size. The reference
    ``(width, height)`` comes from the first sample, as returned under the
    dataset's current transform. The aspect ratio of non-square references is
    preserved.
    """
    first = dataset[0]
    # ImageFolderDataset yields (image, tokens); a bare image sample is used as is.
    reference = first[0] if isinstance(first, tuple) else first
    width, height = reference.size  # PIL reports size as (width, height)

    ratio = args.crop_ratio
    # CenterCrop/Resize take (height, width); int() matches how torchvision
    # truncates a scalar float size.
    crop_size = (int(ratio * height), int(ratio * width))

    dataset.transform = transforms.Compose([
        transforms.CenterCrop(size=crop_size),
        transforms.Resize((height, width)),
    ])
    return dataset

def rotate(dataset, args):
    """Attach a random-rotation attack to ``dataset`` and return it.

    ``args.degrees`` is forwarded unchanged to ``transforms.RandomRotation``. A
    single number d samples angles from (-d, d); a (min, max) pair samples from
    that range.
    """
    spin = transforms.RandomRotation(args.degrees)
    dataset.transform = transforms.Compose([spin])
    return dataset

def flip(dataset, args):
    """Make ``dataset`` flip every image top-to-bottom, then return it.

    ``p=1`` makes RandomVerticalFlip fire on every call, so the flip is
    deterministic. ``args`` is unused; it is accepted to match the common attack
    signature.
    """
    always_flip = transforms.RandomVerticalFlip(p=1)
    dataset.transform = always_flip
    return dataset

def jpeg(dataset, args, compression=0.25):
    """Attach a JPEG re-encoding attack to ``dataset`` and return it.

    The encoder quality is ``args.final_quality``, an integer in [1, 100] where
    lower means stronger compression. ``compression`` is currently unused: the
    quality comes solely from ``args``.

    NOTE(review): ImageFolderDataset does not convert images to RGB, so RGBA
    inputs may fail JPEG encoding — confirm with the expected data.
    """
    quality = args.final_quality
    assert 1 <= quality <= 100, "Quality must be between 1 and 100"

    codec = torchvision.transforms.v2.JPEG(quality)
    dataset.transform = transforms.Compose([codec])
    return dataset

def conventional_all(dataset, args):
    """Attach a fixed chain of all conventional distortions to ``dataset``.

    The chain runs, in order:

    1. Gaussian noise (std 25, pixel units).
    2. 7x7 Gaussian blur.
    3. Color jitter (hue 0.3, saturation 3.0, contrast 3.0).
    4. Random crop to 70% of the reference height and width.
    5. Random rotation in [0, 180] degrees.
    6. Resize back to the reference size.
    7. JPEG re-encoding at quality 75.

    The reference ``(width, height)`` comes from the first sample. ``args`` is
    unused. Returns ``dataset``.
    """
    first = dataset[0]
    # ImageFolderDataset yields (image, tokens), which has no .size, so unwrap it.
    # A bare image sample is used as is.
    reference = first[0] if isinstance(first, tuple) else first
    width, height = reference.size  # PIL reports size as (width, height)

    transform = transforms.Compose([
        transforms.Lambda(lambda x: add_gaussian_noise(x, mean=0.0, std=25)),
        transforms.GaussianBlur(7),
        transforms.ColorJitter(hue=0.3, saturation=3.0, contrast=3.0),
        # (height, width); int() matches torchvision's truncation of a scalar size.
        transforms.RandomCrop(size=(int(0.7 * height), int(0.7 * width))),
        transforms.RandomRotation((0, 180)),
        transforms.Resize((height, width)),
        torchvision.transforms.v2.JPEG(75),
    ])

    dataset.transform = transform
    return dataset



def none(dataset, args):
    """Identity attack: hand back ``dataset`` untouched (``args`` is ignored)."""
    unchanged = dataset
    return unchanged


def apply_attack(img_path, attack, load_token_map, args):
    """Load the images in ``img_path`` and configure the named attack on them.

    Args:
        img_path: Directory passed to ``ImageFolderDataset`` as ``root_dir``.
        attack: Name of the attack; must be a key of ``attack_map`` below.
        load_token_map: Forwarded to ``ImageFolderDataset``.
        args: Namespace providing ``num_samples`` plus whatever attributes the
            chosen attack function reads.

    Returns:
        The ``ImageFolderDataset`` with its ``transform`` set by the attack.

    Raises:
        ValueError: If ``attack`` is not supported. This is checked before the
            directory or token map is read.
    """
    attack_map = {
        'noise': noise,
        'gauss': gauss,
        'color': color,
        'brightness': brightness,
        'saturation': saturation,
        'contrast': contrast,
        'crop': crop,
        'resize': resize,
        'rotate': rotate,
        'flip': flip,
        'jpeg': jpeg,
        'conventional_all': conventional_all,
        'none': none,
    }

    # Fail fast on a bad name instead of listing files / parsing the token map first.
    if attack not in attack_map:
        raise ValueError(f"Unsupported attack: {attack}")

    dataset = ImageFolderDataset(
        root_dir=img_path,
        num_samples=args.num_samples,
        transform=None,
        load_token_map=load_token_map,
    )
    return attack_map[attack](dataset, args)

