import os
import shutil
import random
import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import numpy as np
import scipy.stats as st
from PIL import Image, ImageDraw, ImageFont


def crop_resize(image, image_size):
    """Center-crop ``image`` to a square and resize it to ``image_size``.

    The square's side is the shorter of the image's width and height. The
    crop is resized to ``(image_size, image_size)`` with PIL's default
    resampling filter.

    Args:
        image: A PIL image.
        image_size: Side length, in pixels, of the returned square image.

    Returns:
        A new ``image_size`` x ``image_size`` image.
    """
    width, height = image.size
    side = min(width, height)
    # Integer offsets keep the box exactly side x side. Half-pixel float
    # offsets would be rounded per edge by PIL and can yield a non-square box
    # (e.g. a 4x3 image gave a 4x3 crop).
    left = (width - side) // 2
    top = (height - side) // 2
    box = (left, top, left + side, top + side)
    return image.crop(box).resize((image_size, image_size))


def DIM(x, resize_rate = 1.3):
    x = transforms.ToTensor()(x)
    
    img_size = x.shape[-1]
    img_resize = int(img_size * resize_rate)
    
    rnd = torch.randint(low=min(img_size, img_resize), high=max(img_size, img_resize), size=(1,), dtype=torch.int32)
    rescaled = F.interpolate(x.unsqueeze(0), size=[rnd, rnd], mode='bilinear', align_corners=False)
    
    h_rem = img_resize - rnd
    w_rem = img_resize - rnd
    pad_top = torch.randint(low=0, high=h_rem.item(), size=(1,), dtype=torch.int32)
    pad_bottom = h_rem - pad_top
    pad_left = torch.randint(low=0, high=w_rem.item(), size=(1,), dtype=torch.int32)
    pad_right = w_rem - pad_left

    padded = F.pad(rescaled, [pad_left.item(), pad_right.item(), pad_top.item(), pad_bottom.item()], value=0)
    return transforms.ToPILImage()(F.interpolate(padded, size=[img_size, img_size], mode='bilinear', align_corners=False).squeeze(0))


# def SIM(x, num_scale=5):
#     x = transforms.ToTensor()(x)
#     scaled_images = [x / (2**i) for i in range(num_scale)]
#     selected_image = random.choice(scaled_images)
#     return transforms.ToPILImage()(selected_image)

def SIM(x, brightness_range=(0.2, 1.8)):
    """Randomly rescale the brightness of an image.

    ``transforms.ColorJitter`` samples a brightness factor from
    ``brightness_range`` (min, max). The factor is applied to the tensor form
    of ``x``, and the result is converted back to a PIL image.

    Args:
        x: Image accepted by ``transforms.ToTensor`` (e.g. a PIL image).
        brightness_range: (min, max) brightness factor range.

    Returns:
        The brightness-adjusted PIL image.
    """
    jitter = transforms.ColorJitter(brightness=brightness_range)
    as_tensor = transforms.ToTensor()(x)
    return transforms.ToPILImage()(jitter(as_tensor))


def SGA(image, min_scale_ratio=0.5, max_scale_ratio=1.5, noise_level=0.02):
    image_np = np.array(image)
    original_height, original_width, channels = image_np.shape

    random_scale_ratio = np.random.uniform(min_scale_ratio, max_scale_ratio)
    new_width = int(original_width * random_scale_ratio)
    new_height = int(original_height * random_scale_ratio)

    resized_image = Image.fromarray(image_np).resize((new_width, new_height), Image.BICUBIC)
    resized_image_np = np.array(resized_image)

    noise = np.random.normal(0, noise_level * 255, resized_image_np.shape).astype(np.uint8)
    noisy_image_np = resized_image_np + noise
    noisy_image_np = np.clip(noisy_image_np, 0, 255)
    noisy_image = Image.fromarray(noisy_image_np)

    restored_image = noisy_image.resize((original_width, original_height), Image.BICUBIC)
    return restored_image


def SIA(x, num_scale=1):
    def vertical_shift(x):
        _, _, w, _ = x.shape
        step = np.random.randint(low = 0, high=w, dtype=np.int32)
        return x.roll(step, dims=2)

    def horizontal_shift(x):
        _, _, _, h = x.shape
        step = np.random.randint(low = 0, high=h, dtype=np.int32)
        return x.roll(step, dims=3)

    def vertical_flip(x):
        return x.flip(dims=(2,))

    def horizontal_flip(x):
        return x.flip(dims=(3,))

    def rotate180(x):
        return x.rot90(k=2, dims=(2,3))
    
    def scale(x):
        return torch.rand(1)[0] * x

    def add_noise(x):
        return torch.clip(x + torch.zeros_like(x).uniform_(-16/255,16/255), 0, 1)

    def gkern(kernel_size=3, nsig=3):
        x = np.linspace(-nsig, nsig, kernel_size)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        kernel = kernel_raw / kernel_raw.sum()
        stack_kernel = np.stack([kernel, kernel, kernel])
        stack_kernel = np.expand_dims(stack_kernel, 1)
        return torch.from_numpy(stack_kernel.astype(np.float32))

    def blur(x):
        return F.conv2d(x, gkern, stride=1, padding='same', groups=3)

    def blocktransform(x, num_block=3, choice=-1, op = [vertical_shift, horizontal_shift, vertical_flip, horizontal_flip, rotate180, scale, add_noise]):
        _, _, w, h = x.shape
        y_axis = [0,] + np.random.choice(list(range(1, h)), num_block-1, replace=False).tolist() + [h,]
        x_axis = [0,] + np.random.choice(list(range(1, w)), num_block-1, replace=False).tolist() + [w,]
        y_axis.sort()
        x_axis.sort()
        
        x_copy = x.clone()
        for i, idx_x in enumerate(x_axis[1:]):
            for j, idx_y in enumerate(y_axis[1:]):
                chosen = choice if choice >= 0 else np.random.randint(0, high=len(op), dtype=np.int32)
                x_copy[:, :, x_axis[i]:idx_x, y_axis[j]:idx_y] = op[chosen](x_copy[:, :, x_axis[i]:idx_x, y_axis[j]:idx_y])

        return x_copy

    x = transforms.ToTensor()(x).unsqueeze(0)
    return transforms.ToPILImage()(torch.cat([blocktransform(x) for _ in range(num_scale)]).squeeze(0))


def TIM(image, max_shift=15):
    width, height = image.size
    tx = np.random.randint(-max_shift, max_shift + 1)  # Translate shift on the x-axis
    ty = np.random.randint(-max_shift, max_shift + 1)  # Translate shift on the y-axis

    # Create a new image and apply wrap around translation
    translated_image = Image.new("RGB", (width, height))

    # Calculate wrap around parts
    x_part = tx if tx > 0 else width + tx
    y_part = ty if ty > 0 else height + ty

    # Handle the horizontal wrap
    part_left = image.crop((x_part, 0, width, height))
    part_right = image.crop((0, 0, x_part, height))
    translated_image.paste(part_left, (0, 0))
    translated_image.paste(part_right, (width - x_part, 0))

    # Create a new image for the vertical wrap
    final_translated = Image.new("RGB", (width, height))
    part_top = translated_image.crop((0, y_part, width, height))
    part_bottom = translated_image.crop((0, 0, width, y_part))
    final_translated.paste(part_top, (0, 0))
    final_translated.paste(part_bottom, (0, height - y_part))

    return final_translated


def Admix(x, y, num_scale=5, admix_strength=0.2):
    x = transforms.ToTensor()(x)
    y = transforms.ToTensor()(y)
    admix_image = x + admix_strength * y
    scaled_images = [admix_image / (2 ** i) for i in range(num_scale)]
    selected_image = random.choice(scaled_images)
    return transforms.ToPILImage()(selected_image)


def render_typos(image, texts, font_path, font_size, font_color, max_attempts=1000, text_height=None):
    """Draw each string in ``texts`` at a random, non-overlapping position.

    ``image`` is drawn on directly (modified in place) and also returned.
    For each text, up to ``max_attempts`` positions are sampled uniformly
    such that the text box stays inside the image where possible. The first
    position whose box does not intersect an already placed text is used.
    A message is printed when no such position is found.

    Args:
        image: PIL image to draw on.
        texts: Iterable of strings to render.
        font_path: Path to a TrueType/OpenType font file.
        font_size: Font size passed to ``ImageFont.truetype``.
        font_color: Fill colour passed to ``ImageDraw.text``.
        max_attempts: Placement attempts per text.
        text_height: Fixed box height in pixels for overlap/fit checks. When
            None, it is measured per text from the font (bottom of the text's
            bounding box relative to the draw origin).

    Returns:
        The same ``image`` object.
    """
    draw = ImageDraw.Draw(image)
    image_width, image_height = image.size
    # The font does not depend on the text, so load it once.
    font = ImageFont.truetype(font_path, font_size)

    text_positions = []

    for text in texts:
        text_width = int(draw.textlength(text, font=font))
        if text_height is None:
            box_height = draw.textbbox((0, 0), text, font=font)[3]
        else:
            box_height = text_height

        for _ in range(max_attempts):
            text_x = random.randint(0, max(0, image_width - text_width))
            text_y = random.randint(0, max(0, image_height - box_height))

            overlap = any(
                text_x < px + pw and text_x + text_width > px
                and text_y < py + ph and text_y + box_height > py
                for px, py, pw, ph in text_positions
            )
            if not overlap:
                text_positions.append((text_x, text_y, text_width, box_height))
                draw.text((text_x, text_y), text, fill=font_color, font=font)
                break
        else:
            # Runs only when every attempt overlapped (or max_attempts == 0).
            print(f"Failed to place text '{text}' without overlap after {max_attempts} attempts.")

    return image


def AIP(image, added_image):
    """Paste a shrunken ``added_image`` at a random position on ``image``.

    ``added_image`` is resized to one fifth of the base image's width and
    height (floored, at least 1 px). It is pasted at a uniformly random
    offset that keeps it fully inside the image. The paste happens on a copy,
    so the caller's ``image`` is left untouched, like the other augmentations.

    Args:
        image: Base PIL image.
        added_image: PIL image to paste.

    Returns:
        A new PIL image the size of ``image``.
    """
    result = image.copy()
    width, height = result.size

    new_width = max(1, width // 5)
    new_height = max(1, height // 5)
    patch = added_image.resize((new_width, new_height))

    x = random.randint(0, width - new_width)
    y = random.randint(0, height - new_height)
    result.paste(patch, (x, y))

    return result


if __name__ == "__main__":
    # Demo: apply every augmentation to one test image and save the results.
    directory = 'transformed_images'
    # Start from an empty output directory on every run.
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

    image_size = 224
    img_path = '/home/ubuntu/LLaVA/dataset/transferable/test_images/cat.jpg'
    added_img_path = '/home/ubuntu/LLaVA/dataset/transferable/test_images/000000002149.jpg'

    # Convert to RGB so every transform sees a 3-channel image; the context
    # managers close the underlying file handles.
    with Image.open(img_path) as raw:
        image = crop_resize(raw.convert('RGB'), image_size)
    with Image.open(added_img_path) as raw:
        add_image = raw.convert('RGB')

    single_input_transforms = {
        'DIM': DIM,
        'SIM': SIM,
        'SGA': SGA,
        'SIA': SIA,
        'TIM': TIM,
    }
    for name, transform in single_input_transforms.items():
        transform(image).save(os.path.join(directory, f'transformed_image_{name}.jpg'))

    # Admix adds the two images element-wise, so the second image must have
    # the same size as the first.
    admix_source = crop_resize(add_image, image_size)
    transformed_image = Admix(image, admix_source)
    transformed_image.save(os.path.join(directory, 'transformed_image_Admix.jpg'))

    # Pass a copy so `image` is not altered by the paste.
    transformed_image = AIP(image.copy(), add_image)
    transformed_image.save(os.path.join(directory, 'transformed_image_AIP.jpg'))