import torch 
from PIL import Image
import torchvision.transforms as T
import numpy as np
import os 



# Default post-processing settings; generate_mask passes these to blend_mask.
enhance_coe = 5  # multiplier applied before the sigmoid in enhance (larger = sharper mask)
kernel_size = 3  # side length of the averaging kernel in revise_mask; that function asserts it is odd
interpolate_method_name = "LANCZOS"
# Look up the attribute with that name on PIL's Image module; it is later
# handed to mask.resize() through invtrans.
interpolate_method = getattr(Image, interpolate_method_name)


def readImg(p):
    """Open the image at ``p`` with PIL and return the resulting image object."""
    opened = Image.open(p)
    return opened

def toImg(t):
    """Convert ``t`` to a PIL image using torchvision's ``ToPILImage`` transform."""
    to_pil = T.ToPILImage()
    return to_pil(t)

def invtrans(mask, image, method = Image.BICUBIC):
    """Resize ``mask`` so that it has the same size as ``image``.

    Args:
        mask: Object exposing ``resize(size, method)`` (a PIL image in this file).
        image: Object whose ``size`` attribute gives the target size.
        method: Resampling filter forwarded to ``mask.resize``.

    Returns:
        Whatever ``mask.resize`` returns for the target size.
    """
    target_size = image.size
    resized = mask.resize(target_size, method)
    return resized


def normalize(mat, method = "max"):
    """Min-max scale ``mat`` into the range [0, 1].

    Args:
        mat: Tensor (or array-like exposing ``max()`` / ``min()``) to rescale.
        method: ``"min"`` maps the smallest value to 0 and the largest to 1;
            ``"max"`` is the inverted scaling, mapping the largest value to 0
            and the smallest to 1.

    Returns:
        The rescaled values, same shape as ``mat``. A constant input (zero
        value range) yields all zeros rather than NaN.

    Raises:
        NotImplementedError: If ``method`` is neither ``"max"`` nor ``"min"``.
    """
    if method not in ("max", "min"):
        raise NotImplementedError(f"unknown normalization method: {method!r}")

    hi = mat.max()
    lo = mat.min()
    span = hi - lo
    # A constant input has zero range; dividing by it would produce 0/0 = NaN
    # everywhere. The numerator is already all zeros in that case, so divide
    # by 1 to keep the usual dtype promotion while returning zeros.
    if span == 0:
        span = 1

    numerator = hi - mat if method == "max" else mat - lo
    return numerator / span

def enhance(mat, coe=10):
    """Sharpen the contrast of ``mat`` by standardizing it and applying a sigmoid.

    The input is shifted to zero mean, divided by its standard deviation,
    multiplied by ``coe`` and passed through ``torch.sigmoid``.

    Args:
        mat: Tensor to enhance.
        coe: Multiplier applied to the standardized values before the sigmoid;
            larger values push the output closer to 0 or 1.

    Returns:
        A tensor of the same shape as ``mat`` with values in [0, 1]. An input
        with no spread (standard deviation of zero) yields 0.5 everywhere
        instead of NaN.
    """
    centered = mat - mat.mean()
    spread = centered.std()
    # A flat input has zero standard deviation; dividing would give 0/0 = NaN.
    # Skip the division in that case so the centered zeros map to sigmoid(0).
    if spread > 0:
        centered = centered / spread
    # sigmoid already maps into [0, 1], so no extra clamping is required.
    return torch.sigmoid(centered * coe)

def merge(mask, image, grap_scale = 200):
    """Blend ``image`` over a flat gray backdrop, weighted per pixel by ``mask``.

    Mask values are divided by 255 to form the blend weight: a weight of 1
    keeps the image pixel, a weight of 0 shows the gray backdrop.

    Args:
        mask: Image-like object convertible with ``np.array``; indexed as a
            2-D array, so it is presumably single-channel and the same size
            as ``image`` (confirm against the caller).
        image: Image whose ``size`` is ``(width, height)``; only the first
            three entries of its last axis are used.
        grap_scale: Intensity of the gray backdrop.

    Returns:
        A PIL image built from the blended values cast to ``uint8``.
    """
    width, height = image.size
    backdrop = np.full((height, width, 3), grap_scale, dtype=np.float64)

    rgb = np.array(image).astype(np.float32)[..., :3]
    weights = np.array(mask).astype(np.float32) / 255.0
    alpha = weights[:, :, None]  # trailing axis so the weight broadcasts over channels

    composite = rgb * alpha + (1 - alpha) * backdrop
    return Image.fromarray(composite.astype(np.uint8))

def revise_mask(patch_mask, kernel_size = 3, enhance_coe = 10):
    """Rescale, sharpen and then smooth a patch-level mask.

    The mask is min-max scaled with ``normalize(..., "min")``, passed through
    ``enhance`` and finally averaged with a ``kernel_size`` x ``kernel_size``
    uniform kernel using replicate padding, so the spatial size is unchanged.

    Args:
        patch_mask: Float tensor holding the mask; a leading axis is added
            before the convolution and removed afterwards.
        kernel_size: Side length of the averaging kernel; must be odd.
        enhance_coe: Coefficient forwarded to ``enhance``.

    Returns:
        The smoothed mask tensor. It is produced by a ``Conv2d`` module, so it
        is attached to the autograd graph of that module's weight.

    Raises:
        AssertionError: If ``kernel_size`` is even.
    """
    scaled = normalize(patch_mask, "min")
    sharpened = enhance(scaled, coe = enhance_coe)

    assert kernel_size % 2 == 1
    half = (kernel_size - 1) // 2
    box_filter = torch.nn.Conv2d(
        1,
        1,
        kernel_size = kernel_size,
        stride = 1,
        padding = half,
        padding_mode = "replicate",
        bias = False,
    )
    # Overwrite the random initial weights with a uniform averaging kernel.
    box_filter.weight.data = torch.ones_like(box_filter.weight.data) / (kernel_size * kernel_size)
    box_filter.to(sharpened.device)

    blurred = box_filter(sharpened.unsqueeze(0))
    return blurred[0]


def blend_mask(image_path_or_pil_image, mask, enhance_coe, kernel_size, interpolate_method):
    """Overlay a patch-level mask on an image, graying out low-mask regions.

    The mask is post-processed with ``revise_mask``, converted to a PIL image,
    resized to the image size and used as the blend weight in ``merge``.

    Args:
        image_path_or_pil_image: A PIL image, or a filesystem path
            (``str`` / ``os.PathLike``) that is opened with ``readImg``.
        mask: Tensor holding the patch-level mask; it is cast to float and its
            last two dimensions are used as the mask grid.
        enhance_coe: Coefficient forwarded to ``revise_mask``.
        kernel_size: Smoothing kernel size forwarded to ``revise_mask``.
        interpolate_method: Resampling filter used when resizing the mask to
            the image size.

    Returns:
        The blended image as an RGB PIL image.
    """
    image = image_path_or_pil_image
    # The parameter name promises path support, so open paths here; anything
    # else is used as an already-loaded image.
    if isinstance(image, (str, os.PathLike)):
        image = readImg(image)

    mask = revise_mask(mask.float(), kernel_size = kernel_size, enhance_coe = enhance_coe)
    mask = mask.detach().cpu()
    # Add a leading channel axis for the PIL conversion while keeping the
    # mask's own grid size rather than assuming a fixed 24x24 layout.
    mask = toImg(mask.reshape(1, *mask.shape[-2:]))

    mask = invtrans(mask, image, method = interpolate_method)
    merged_image = merge(mask.convert("L"), image.convert("RGB")).convert("RGB")

    return merged_image

def convert_to_image(image, image_path = None):
    """Return ``image`` as a PIL image, optionally saving it to ``image_path``.

    Torch tensors are detached, moved to the CPU and turned into NumPy arrays
    first; anything that is not already a PIL image is then passed to
    ``toImg``.

    NOTE(review): because tensors reach ``toImg`` as NumPy arrays, they are
    presumably interpreted with the array layout convention of ``ToPILImage``
    rather than the tensor one -- confirm that callers pass data in the
    layout this expects.

    Args:
        image: A torch tensor, a PIL image, or another input accepted by ``toImg``.
        image_path: When not ``None``, the resulting image is saved here.

    Returns:
        The PIL image.
    """
    source = image.detach().cpu().numpy() if isinstance(image, torch.Tensor) else image
    pil_image = source if isinstance(source, Image.Image) else toImg(source)

    if image_path is not None:
        pil_image.save(image_path)
    return pil_image


def generate_mask(attention, image):
    """Blend an attention map over ``image`` using the module-level defaults.

    Args:
        attention: Tensor that can be viewed as 24 x 24, i.e. holding exactly
            576 values.
        image: Image forwarded unchanged to ``blend_mask``.

    Returns:
        The blended image produced by ``blend_mask`` with the module-level
        ``enhance_coe``, ``kernel_size`` and ``interpolate_method`` settings.
    """
    patch_grid = attention.view(24, 24)
    return blend_mask(image, patch_grid, enhance_coe, kernel_size, interpolate_method)



# images shape:  torch.Size([8, 3, 336, 336])

# if __name__ == "__main__":
#     image_path = "/ROOT_DIR/EncyclopedicVQA/inat/val/01875_Animalia_Arthropoda_Insecta_Lepidoptera_Nymphalidae_Hipparchia_statilinus/122e973f-0efb-43e4-a6a6-3e620f855a43.jpg"
#     merged_image = blend_mask(image_path, mask, "new", enhance_coe, kernel_size, interpolate_method, "./")



