import json
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
from diffusers import AutoencoderKL, ControlNetModel, UniPCMultistepScheduler
# from controlnet_aux import CannyDetector 
from transformers import AutoModel, AutoImageProcessor
import numpy as np
from random import random, shuffle

from PIL import Image
# from CtrlRegen.custom_i2i_pipeline import CustomStableDiffusionControlNetImg2ImgPipeline
# from color_matcher import ColorMatcher
# from color_matcher.normalizer import Normalizer
from tqdm import tqdm

def color_match(ref_img, src_img):
    """Return ``src_img`` recoloured to match the colour distribution of ``ref_img``.

    Both inputs are converted with ``np.asarray`` and normalised via
    ``Normalizer(...).type_norm()``. They are matched with
    ``ColorMatcher.transfer`` using the ``'hm-mkl-hm'`` method (``'hm-mvgd-hm'``
    is an alternative). The result is rescaled to uint8 and returned as a PIL
    image.

    NOTE(review): ``ColorMatcher`` and ``Normalizer`` come from the
    ``color_matcher`` package. Its imports are commented out at the top of this
    file, so calling this function raises NameError until they are restored.
    """
    matcher = ColorMatcher()
    reference = Normalizer(np.asarray(ref_img)).type_norm()
    source = Normalizer(np.asarray(src_img)).type_norm()

    transferred = matcher.transfer(src=source, ref=reference, method='hm-mkl-hm')
    as_uint8 = Normalizer(transferred).uint8_norm()
    return Image.fromarray(as_uint8)



import os

import torchvision.transforms.v2


class ImageFolderDataset(Dataset):
    """Flat directory of images, optionally paired with pre-generated token indices.

    Each item is ``(image_tensor, tokens)``:

    * ``image_tensor`` -- the file opened as RGB and resized to
      ``image_size`` x ``image_size``. It is then passed through
      ``self.transform`` (when set) and converted with ``transforms.ToTensor``,
      unless the transform already produced a tensor.
    * ``tokens`` -- when ``load_token_map`` is True, a list holding row ``idx``
      of every tensor loaded from ``generated_token_indices.pt``. Otherwise an
      empty list.

    Args:
        root_dir: Directory holding the images (subdirectories are not scanned).
        num_samples: Maximum number of images (and token rows) to keep.
        transform: Optional callable applied to the resized PIL image. It may
            also be assigned after construction via ``dataset.transform``.
        load_token_map: If True, read ``root_dir/generated_token_indices.pt``
            and derive image names from the row index. Names are
            ``00000.png``-style when the directory name ends with ``_50000``,
            and ``000.png``-style otherwise.
        image_size: Side length images are resized to (default 256).
    """

    # Suffixes accepted in plain-folder mode (matched case-insensitively).
    _EXTENSIONS = ('png', 'jpg', 'jpeg')

    def __init__(self, root_dir, num_samples=500000, transform=None, load_token_map=False, image_size=256):
        self.root_dir = root_dir
        self.transform = transform
        self.load_token_map = load_token_map
        self.image_size = image_size

        if self.load_token_map:
            # NOTE(review): assumes the .pt file holds a sequence of tensors whose
            # first dimension indexes images -- confirm against the generator.
            self.token_indices = torch.load(os.path.join(root_dir, 'generated_token_indices.pt'), map_location='cpu')
            self.length = self.token_indices[0].shape[0]

            # Normalise first so a trailing path separator does not defeat the suffix check.
            dir_name = os.path.basename(os.path.normpath(root_dir))
            width = 5 if dir_name.endswith("_50000") else 3
            self.image_paths = [os.path.join(root_dir, f"{i:0{width}d}.png") for i in range(self.length)]

            self.token_indices = [t[:num_samples] for t in self.token_indices]
            self.image_paths = self.image_paths[:num_samples]
            assert len(self.image_paths) == len(self.token_indices[0]), "Image paths and token indices length mismatch"
        else:
            # os.listdir order is arbitrary. Sort so that truncating to
            # num_samples selects a reproducible subset.
            self.image_paths = [
                os.path.join(root_dir, name)
                for name in sorted(os.listdir(root_dir))
                if name.lower().endswith(self._EXTENSIONS)
            ][:num_samples]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Close the file handle promptly. convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        image = image.resize((self.image_size, self.image_size))

        if self.transform:
            image = self.transform(image)

        # ToTensor rejects tensors, so only convert when the transform did not already.
        image_tensor = image if torch.is_tensor(image) else transforms.ToTensor()(image)

        if self.load_token_map:
            tokens = [t[idx].clone() for t in self.token_indices]
        else:
            tokens = []

        return image_tensor, tokens

def add_gaussian_noise(image, mean=0.0, std=10.0, rng=None):
    """
    Add Gaussian noise to a PIL image.

    Args:
        image (PIL.Image): Input image.
        mean (float): Mean of the Gaussian noise, in pixel-value units
            (0-255 for 8-bit images).
        std (float): Standard deviation of the Gaussian noise, in pixel-value units.
        rng (numpy.random.Generator, optional): Source of randomness for
            reproducible noise. When None (default) the global ``np.random``
            state is used.

    Returns:
        PIL.Image: Image with added Gaussian noise, clipped to [0, 255] and
        stored as uint8.
    """
    img_array = np.array(image).astype(np.float32)

    sampler = np.random if rng is None else rng
    noise = sampler.normal(mean, std, img_array.shape)

    # Round before the uint8 cast: astype() truncates, which would bias
    # every noisy pixel downward by ~0.5 on average.
    noisy_array = np.clip(np.rint(img_array + noise), 0, 255).astype(np.uint8)

    return Image.fromarray(noisy_array)

class TensorImageDataset(Dataset):
    """Dataset view over an in-memory sequence of image tensors.

    Args:
        tensor_array: A list or tensor of shape (N, C, H, W).
        transform: Optional callable applied to each sample on access. When
            None (default), ``transforms.ToPILImage()`` is used, so items are
            returned as PIL images. The argument used to be silently ignored.
    """

    def __init__(self, tensor_array, transform=None):
        self.tensor_array = tensor_array
        self.transform = transform if transform is not None else transforms.ToPILImage()

    def __len__(self):
        return len(self.tensor_array)

    def __getitem__(self, idx):
        sample = self.tensor_array[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample


#Define Attacks
#Conventional
def noise(dataset, args):
    """Configure ``dataset`` to add Gaussian noise to every image.

    ``args.variance`` must lie in [0, 1]. It is multiplied by 255 and passed as
    ``std`` (pixel units) to ``add_gaussian_noise``.
    NOTE(review): despite its name, ``args.variance`` is used as a normalised
    standard deviation, not a variance. Confirm this is intended.

    Returns the same dataset object with its ``transform`` replaced.
    """
    level = args.variance
    assert 0 <= level <= 1, "Variance must be between 0 and 1"

    sigma = 255 * level
    add_noise = transforms.Lambda(lambda img: add_gaussian_noise(img, mean=0.0, std=sigma))
    dataset.transform = transforms.Compose([add_noise])
    return dataset

def gauss(dataset, args):
    """Configure ``dataset`` to Gaussian-blur every image.

    ``args.kernel_size`` is forwarded to ``transforms.GaussianBlur`` (no sigma
    is given, so torchvision's default applies). ``args.variance`` is checked
    to lie in [0, 1] but is not otherwise used by this attack.

    Returns the same dataset object with its ``transform`` replaced.
    """
    blur_kernel = args.kernel_size
    level = args.variance
    assert 0 <= level <= 1, "Variance must be between 0 and 1"

    dataset.transform = transforms.Compose([transforms.GaussianBlur(blur_kernel)])
    return dataset

def color(dataset, args, hue=0.3, saturation=3.0, contrast=3.0):
    """Configure ``dataset`` to apply random colour jitter to every image.

    Args:
        dataset: Dataset whose ``transform`` attribute is replaced.
        args: Unused. Kept for a uniform attack signature.
        hue, saturation, contrast: Forwarded to ``transforms.ColorJitter``.
            These keywords were previously accepted but ignored in favour of
            identical hard-coded values.

    Returns:
        The same dataset object.
    """
    dataset.transform = transforms.Compose([
        transforms.ColorJitter(hue=hue, saturation=saturation, contrast=contrast),
    ])
    return dataset

def crop(dataset, args):
    """Configure ``dataset`` to randomly crop each image and resize it back.

    The crop side is ``args.crop_ratio`` times the side of the first image in
    ``dataset``. The crop is then resized to a square of the original side.

    Raises:
        ValueError: If ``args.crop_ratio`` is not in (0, 1].

    Returns:
        The same dataset object with its ``transform`` replaced.
    """
    crop_ratio = args.crop_ratio
    if not 0 < crop_ratio <= 1:
        raise ValueError(f"crop_ratio must be in (0, 1], got {crop_ratio}")

    sample = dataset[0]
    # ImageFolderDataset yields (tensor, tokens); other datasets yield the image itself.
    if isinstance(sample, (tuple, list)):
        sample = sample[0]
    # Tensors are (C, H, W) and expose .shape. PIL images report .size as (W, H).
    # Tensor.size is a method, so the old `.size[0]` raised TypeError.
    original_size = sample.shape[-1] if torch.is_tensor(sample) else sample.size[0]
    crop_size = max(1, int(crop_ratio * original_size))

    dataset.transform = transforms.Compose([
        transforms.RandomCrop(size=crop_size),
        transforms.Resize((original_size, original_size)),
    ])
    return dataset

def rotate(dataset, args):
    """Configure ``dataset`` to randomly rotate every image.

    ``args.degrees`` is forwarded unchanged to ``transforms.RandomRotation``.
    Returns the same dataset object with its ``transform`` replaced.
    """
    rotation = transforms.RandomRotation(args.degrees)
    dataset.transform = transforms.Compose([rotation])
    return dataset

def flip(dataset, args):
    """Configure ``dataset`` to flip every image vertically (``args`` is unused).

    Returns the same dataset object with its ``transform`` replaced.
    """
    # p=1 makes the "random" flip deterministic: every image is flipped.
    always_flip = transforms.RandomVerticalFlip(p=1)
    dataset.transform = always_flip
    return dataset

def jpeg(dataset, args, compression=0.25):
    """Configure ``dataset`` to JPEG re-encode every image.

    ``args.final_quality`` must lie in [1, 100]. It is forwarded as the quality
    to ``torchvision.transforms.v2.JPEG``.
    NOTE(review): the ``compression`` keyword is accepted for compatibility
    but is not used.

    Returns the same dataset object with its ``transform`` replaced.
    """
    quality = args.final_quality
    assert 1 <= quality <= 100, "Quality must be between 1 and 100"

    dataset.transform = transforms.Compose([torchvision.transforms.v2.JPEG(quality)])
    return dataset

def conventional_all(dataset, args):
    """Configure ``dataset`` to apply every conventional distortion in sequence.

    The distortions are applied in this order:

    1. Gaussian noise (std 25)
    2. Gaussian blur (kernel 7)
    3. Colour jitter (hue 0.3, saturation 3.0, contrast 3.0)
    4. Random crop to 70% of the side
    5. Random rotation in [0, 180] degrees
    6. Resize back to the original side
    7. JPEG re-encoding at quality 75

    ``args`` is unused. Returns the same dataset object with its
    ``transform`` replaced.
    """
    sample = dataset[0]
    # ImageFolderDataset yields (tensor, tokens). The old `dataset[0].size[0]`
    # therefore raised AttributeError on the tuple.
    if isinstance(sample, (tuple, list)):
        sample = sample[0]
    # Tensors are (C, H, W). PIL images report .size as (W, H).
    original_size = sample.shape[-1] if torch.is_tensor(sample) else sample.size[0]
    crop_size = max(1, int(0.7 * original_size))

    dataset.transform = transforms.Compose([
        transforms.Lambda(lambda x: add_gaussian_noise(x, mean=0.0, std=25)),
        transforms.GaussianBlur(7),
        transforms.ColorJitter(hue=0.3, saturation=3.0, contrast=3.0),
        transforms.RandomCrop(size=crop_size),
        transforms.RandomRotation((0, 180)),
        transforms.Resize((original_size, original_size)),
        torchvision.transforms.v2.JPEG(75),
    ])
    return dataset


#Regeneration
def VAE(dataset, args):
    """Regeneration attack: round-trip every image through a Stable Diffusion VAE.

    Loads ``AutoencoderKL`` from ``args.stable_diff_vae`` and encodes each image
    to a sampled latent. It then decodes the latent back to pixel space.

    Items of ``dataset`` may be ``(image, tokens)`` tuples or bare images, and
    images may be tensors in [0, 1] or PIL images. The previous implementation
    set ``dataset.transform = ToTensor()`` and iterated raw tuples. Both broke
    on ``ImageFolderDataset``: ToTensor was applied twice, and ``.to`` was
    called on a tuple.

    Returns:
        TensorImageDataset wrapping the reconstructed CPU tensors, clamped to [0, 1].
    """
    vae_path = args.stable_diff_vae
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vae = AutoencoderKL.from_pretrained(vae_path).to(device)
    vae.eval()

    to_tensor = transforms.ToTensor()
    latent_scale = 0.18215  # latent scaling factor used by the original code

    def as_tensor(sample):
        # Unpack (image, tokens) items and convert PIL images to [0, 1] tensors.
        if isinstance(sample, (tuple, list)):
            sample = sample[0]
        return sample if torch.is_tensor(sample) else to_tensor(sample)

    def encode_img(input_img):
        # A single image becomes a batch of one.
        if input_img.dim() < 4:
            input_img = input_img.unsqueeze(0)
        latent = vae.encode(input_img * 2 - 1)  # map [0, 1] -> [-1, 1]
        return latent_scale * latent.latent_dist.sample()

    def decode_img(latents):
        image = vae.decode(latents / latent_scale).sample
        return (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] -> [0, 1]

    processed = []
    with torch.no_grad():
        for sample in dataset:
            image = as_tensor(sample).to(device)
            processed.append(decode_img(encode_img(image)).squeeze(0).cpu())

    return TensorImageDataset(processed)

def DiffPure(dataset, args, t=0.15):
    """Placeholder for the DiffPure regeneration attack.

    Reference noted by the author: https://arxiv.org/abs/2408.11039
    Always raises NotImplementedError. ``dataset``, ``args`` and ``t`` are
    currently unused.
    """
    raise NotImplementedError("This is not yet implemented")

def CtrlRegen(dataset, args):
    """Regeneration attack using CtrlRegen+ (https://github.com/yepengliu/ctrlregen).

    Each image goes through these steps:

    1. Resize and centre-crop to 512x512.
    2. Regenerate with a ControlNet img2img pipeline. Its spatial condition is
       the Canny edge map and its semantic condition (IP-adapter) is the image
       itself.
    3. Colour-match the result back to the input with ``color_match``.
    4. Resize to 1024x1024 and convert to a tensor.

    At most ``args.num_samples`` images are processed. ``args.ctrl_regen_steps``
    is passed as the img2img ``strength``.

    NOTE(review): ``CustomStableDiffusionControlNetImg2ImgPipeline`` and
    ``CannyDetector`` appear only in commented-out imports at the top of this
    file. The ``[CTRLREGEN_PATH]`` placeholders must also be filled in before
    this can run.

    Returns:
        TensorImageDataset of the regenerated image tensors.
    """
    device = 'cuda'

    transform_size_to_512 = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(512),
    ])

    DIFFUSION_MODEL = 'SG161222/Realistic_Vision_V4.0_noVAE'
    SPATIAL_CONTROL_PATH = '[CTRLREGEN_PATH]/models/ctrlregen/spatialnet_ckp/spatial_control_ckp_14000'
    SEMANTIC_CONTROL_PATH = '[CTRLREGEN_PATH]/models/ctrlregen/semanticnet_ckp'
    SEMANTIC_CONTROL_NAME = '[CTRLREGEN_PATH]/models/ctrlregen/semanticnet_ckp/models/semantic_control_ckp_435000.bin'
    IMAGE_ENCODER = 'facebook/dinov2-giant'
    # Named VAE_MODEL so the local does not shadow the module-level VAE() attack.
    VAE_MODEL = 'stabilityai/sd-vae-ft-mse'
    CACHE_DIR = "[CTRLREGEN_PATH]/models/"

    spatialnet = [ControlNetModel.from_pretrained(SPATIAL_CONTROL_PATH, torch_dtype=torch.float16)]
    pipe = CustomStableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        DIFFUSION_MODEL,
        controlnet=spatialnet,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    # NOTE(review): the 'costum' spelling presumably matches the custom pipeline's method name -- confirm.
    pipe.costum_load_ip_adapter(SEMANTIC_CONTROL_PATH, subfolder='models', weight_name=SEMANTIC_CONTROL_NAME)
    pipe.image_encoder = AutoModel.from_pretrained(IMAGE_ENCODER, cache_dir=CACHE_DIR).to(device, dtype=torch.float16)
    pipe.feature_extractor = AutoImageProcessor.from_pretrained(IMAGE_ENCODER, cache_dir=CACHE_DIR)
    pipe.vae = AutoencoderKL.from_pretrained(VAE_MODEL, cache_dir=CACHE_DIR).to(dtype=torch.float16)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.set_ip_adapter_scale(1.0)
    pipe.set_progress_bar_config(disable=True)
    pipe.to(device)

    processor = CannyDetector()

    def ctrl_regen_plus(input_img, step, seed=0):
        generator = torch.manual_seed(seed)
        input_img = transform_size_to_512(input_img)
        processed_img = processor(input_img, low_threshold=100, high_threshold=150)
        prompt = 'best quality, high quality'
        negative_prompt = 'monochrome, lowres, bad anatomy, worst quality, low quality'
        output_img = pipe(prompt,
                          negative_prompt=negative_prompt,
                          image=[input_img],
                          control_image=[processed_img],  # spatial condition
                          ip_adapter_image=[input_img],   # semantic condition
                          strength=step,
                          generator=generator,
                          num_inference_steps=50,
                          controlnet_conditioning_scale=1.0,
                          guidance_scale=2.0,
                          control_guidance_start=0,
                          control_guidance_end=1,
                          ).images[0]
        output_img = color_match(input_img, output_img)
        return output_img

    t = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ])
    to_pil = transforms.ToPILImage()

    num_steps = args.ctrl_regen_steps
    # The previous `i <= args.num_samples` filter allowed one extra image and
    # still loaded every remaining item. Index only what is processed.
    num_to_process = min(len(dataset), args.num_samples)

    processed = []
    for idx in tqdm(range(num_to_process)):
        sample = dataset[idx]
        # ImageFolderDataset yields (tensor, tokens); other datasets yield the image itself.
        image = sample[0] if isinstance(sample, (tuple, list)) else sample
        # color_match uses np.asarray, which expects a channel-last image, not a
        # CHW tensor. Convert tensors to PIL first.
        if torch.is_tensor(image):
            image = to_pil(image)
        processed.append(t(ctrl_regen_plus(image, step=num_steps)))

    return TensorImageDataset(processed)

def none(dataset, args):
    """Identity attack: hand back ``dataset`` untouched (``args`` is ignored)."""
    unchanged = dataset
    return unchanged


def apply_attack(img_path, attack, load_token_map, args):
    """Load images from ``img_path`` and apply the named attack to them.

    Args:
        img_path: Directory passed to ``ImageFolderDataset`` as ``root_dir``.
        attack: Key into ``attack_map`` below (e.g. ``'noise'``, ``'jpeg'``, ``'VAE'``).
        load_token_map: Forwarded to ``ImageFolderDataset``.
        args: Namespace-like object. ``args.num_samples`` limits the dataset,
            and each attack reads its own fields from it.

    Returns:
        Whatever the selected attack returns. This is either the same dataset
        with its ``transform`` set, or a new ``TensorImageDataset`` for the
        regeneration attacks.

    Raises:
        ValueError: If ``attack`` is not a supported key. This is checked
            before any images or token maps are loaded.
    """
    attack_map = {
        'noise': noise,
        'gauss': gauss,
        'color': color,
        'crop': crop,
        'rotate': rotate,
        'flip': flip,
        'jpeg': jpeg,
        'VAE': VAE,  # reads args.stable_diff_vae
        'DiffPure': DiffPure,
        'CtrlRegen': CtrlRegen,
        'conventional_all': conventional_all,
        'none': none,
    }

    # Validate first so a typo fails fast, without scanning the directory
    # or loading the token map.
    if attack not in attack_map:
        raise ValueError(f"Unsupported attack: {attack}")

    dataset = ImageFolderDataset(
        root_dir=img_path,
        num_samples=args.num_samples,
        transform=None,
        load_token_map=load_token_map,
    )

    return attack_map[attack](dataset, args)

