import os
import cv2
import math
import numpy as np
from tqdm import tqdm
import imageio
import torch
import torchvision.utils as tvu
from torchvision import transforms

class RandomMaskingGenerator:
    """Generate random binary patch masks with a fixed number of masked patches.

    Each call returns a flat float64 array of length ``height * width`` that
    contains exactly ``num_mask`` ones (masked patches) and zeros elsewhere,
    shuffled with NumPy's global RNG (``np.random.shuffle``).
    """

    def __init__(self, input_size, mask_ratio):
        """
        Args:
            input_size: Grid size in patches. A scalar ``n`` means an ``n x n``
                grid; otherwise a ``(height, width)`` pair (tuple or list).
            mask_ratio: Fraction of patches to mask. The count is truncated
                with ``int()``, so products that land just below an integer
                because of floating-point error round down.
        """
        if not isinstance(input_size, (tuple, list)):
            input_size = (input_size,) * 2

        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_mask = int(mask_ratio * self.num_patches)

    def __repr__(self):
        return "Mask: total patches {}, mask patches {}".format(
            self.num_patches, self.num_mask
        )

    def __call__(self):
        """Return a shuffled 0/1 float64 array of shape ``(num_patches,)``."""
        mask = np.hstack([
            np.zeros(self.num_patches - self.num_mask),
            np.ones(self.num_mask),
        ])
        np.random.shuffle(mask)
        return mask  # shape: (num_patches,)

def pad_image_to_size(image, target_width=224, target_height=224, fill_value=255):
    """Pad ``image`` with a constant so it is at least ``target_height x target_width``.

    Padding is split as evenly as possible between the two sides of each
    spatial axis; when the amount is odd, the extra row or column goes to the
    bottom or right. Axes already at least as large as the target are left
    untouched. The image is never cropped.

    Args:
        image: Array whose first two axes are (height, width). Any trailing
            axes (e.g. channels) are not padded. Both 2-D (grayscale) and
            higher-rank arrays are supported.
        target_width: Minimum output width.
        target_height: Minimum output height.
        fill_value: Constant written into the padded region.

    Returns:
        The padded array (a new array from ``np.pad``).
    """
    height, width = image.shape[:2]

    pad_height = max(target_height - height, 0)
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top

    pad_width = max(target_width - width, 0)
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left

    # Only the two spatial axes are padded; trailing axes get (0, 0).
    pad_spec = ((pad_top, pad_bottom), (pad_left, pad_right)) + ((0, 0),) * (image.ndim - 2)

    padded_image = np.pad(
        image,
        pad_spec,
        mode="constant",
        constant_values=fill_value,
    )

    return padded_image

def center_crop(image, crop_width, crop_height):
    """Center-crop ``image`` to ``crop_height x crop_width``, padding if too small.

    Each spatial axis larger than the requested size is cropped around its
    center; an extra odd pixel is dropped from the far side. If the result is
    still smaller than requested on either axis, it is padded with white
    (``255``) via ``pad_image_to_size``.

    Args:
        image: Array whose first two axes are (height, width).
        crop_width: Output width.
        crop_height: Output height.

    Returns:
        Array of spatial shape ``(crop_height, crop_width)``.
    """
    height, width = image.shape[:2]

    if width > crop_width:
        start_x = (width - crop_width) // 2
        end_x = start_x + crop_width
    else:
        start_x, end_x = 0, width
    if height > crop_height:
        start_y = (height - crop_height) // 2
        end_y = start_y + crop_height
    else:
        start_y, end_y = 0, height

    cropped_image = image[start_y:end_y, start_x:end_x]
    if cropped_image.shape[0] < crop_height or cropped_image.shape[1] < crop_width:
        # Pad to the requested size: height must use crop_height, not crop_width.
        cropped_image = pad_image_to_size(cropped_image, target_width=crop_width,
                                          target_height=crop_height, fill_value=255)

    return cropped_image

def create_half_masks(size=16):
    """Build four flattened half-grid masks on a ``size x size`` patch grid.

    Each mask is a flat float64 array of length ``size * size``. Ones mark the
    masked half and zeros the kept half. Each mask is named after the half it
    keeps, e.g. ``upper_mask`` is 0 on the upper rows and 1 on the lower rows.
    For odd ``size`` the masked region includes the middle row/column in
    ``upper_mask`` / ``left_mask``.

    Args:
        size: Side length of the square patch grid (default 16).

    Returns:
        Tuple ``(upper_mask, lower_mask, left_mask, right_mask)``.
    """
    half = size // 2

    upper_half = np.zeros((size, size))
    upper_half[half:, :] = 1
    upper_mask = upper_half.flatten()

    lower_half = np.zeros((size, size))
    lower_half[:half, :] = 1
    lower_mask = lower_half.flatten()

    left_half = np.zeros((size, size))
    left_half[:, half:] = 1
    left_mask = left_half.flatten()

    right_half = np.zeros((size, size))
    right_half[:, :half] = 1
    right_mask = right_half.flatten()

    return upper_mask, lower_mask, left_mask, right_mask

def sample_inpaint(image, rec_image, ratio, size=224, multiple=14, is_train=True):
    """Composite ``image`` and ``rec_image`` through a random patch mask.

    A ``(size // multiple) x (size // multiple)`` patch grid is masked at
    random with ``RandomMaskingGenerator``. Each patch is upsampled to a
    ``multiple x multiple`` pixel block and replicated over 3 channels. Pixels
    where the mask is 1 come from ``rec_image``; the rest come from ``image``.

    Args:
        image: Source image. In eval mode it is used as a NumPy (H, W, 3)
            array. In train mode it is blended with a (3, H, W) torch tensor,
            so presumably a (3, H, W) tensor too; TODO confirm against callers.
        rec_image: Replacement image with the same layout as ``image``.
        ratio: Fraction of patches to take from ``rec_image``.
        size: Spatial side length the mask is built for.
        multiple: Patch side length in pixels.
        is_train: Selects the tensor (train) or NumPy/uint8 (eval) path.

    Returns:
        Train mode: ``(blended_tensor, patch_mask)``, where ``patch_mask`` is
        the flat 0/1 float64 array of length ``(size // multiple) ** 2``.
        Eval mode: ``(blended_uint8, pixel_mask_uint8)``, where the second
        value is the (H, W, 3) mask scaled to 0/255 for visualization.
    """
    grid = size // multiple
    patch_mask = RandomMaskingGenerator((grid, grid), ratio)()

    # Upsample the patch grid to pixels (each patch becomes a multiple x
    # multiple block), then stack three copies to get an (H, W, 3) mask.
    pixel_mask = np.kron(np.reshape(patch_mask, (grid, grid)),
                         np.ones(shape=(multiple, multiple), dtype=np.uint8))
    pixel_mask = np.dstack([pixel_mask] * 3)

    if is_train:
        # Channel-first float mask to match torch tensors.
        mask_t = torch.from_numpy(pixel_mask).permute(2, 0, 1).float()
        new_image = mask_t * rec_image + (1 - mask_t) * image
        return new_image, patch_mask

    new_image = pixel_mask * rec_image + (1 - pixel_mask) * image
    gt = pixel_mask * 255
    return new_image.astype(np.uint8), gt.astype(np.uint8)

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    def _load_rgb(path):
        """Read an image with OpenCV and convert it from BGR to RGB."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread returns None instead of raising on missing/unreadable files.
            raise FileNotFoundError(f"could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    filename = "filename"
    image1 = _load_rgb(f'/path/to/GenImage/imagenet_ai_0419_sdv4/train/nature/crop/{filename}.png')
    image2 = _load_rgb(f'/path/to/GenImage/imagenet_ai_0419_sdv4/train/nature/inpainting/{filename}.png')
    image1 = center_crop(image1, 224, 224)
    image2 = center_crop(image2, 224, 224)
    # Eval mode returns uint8 arrays: the composite and a 0/255 mask.
    image, mask = sample_inpaint(image1, image2, 0.25, 224, 56, is_train=False)
    print(image)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(image)
    axes[0].axis('off')

    axes[1].imshow(mask)
    axes[1].axis('off')

    plt.tight_layout()

    plt.savefig('side_by_side_images.jpg', bbox_inches='tight', pad_inches=0)

    plt.show()
    
