
import argparse
import os
# Switch into the EML-NET checkout before anything else runs: the relative
# './backbone/...' and './clic_dataset/...' paths used in main() are resolved
# against the current working directory.
# NOTE(review): presumably this also has to precede the `import resnet` /
# `import decoder` lines below so those local modules are importable - confirm.
os.chdir("./EML-NET-Saliency")
# Echo the resulting working directory as a quick sanity check.
print("Current location:", os.getcwd())
import pathlib as pl
from pathlib import Path
import cv2

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from PIL import Image

import numpy as np
from skimage import filters
import skimage.io as sio

import resnet
import decoder


def normalize(x):
    """Contrast-stretch ``x`` to the range [0, 1] using robust percentiles.

    Values are first clipped to the band between the 2nd and 98th percentile
    so that a handful of extreme values cannot dominate the output range; the
    clipped data is then shifted to start at 0 and divided by its maximum.

    Args:
        x: Array-like of numeric values. The input itself is not modified,
            because ``np.clip`` returns a new array that the in-place
            arithmetic below operates on.

    Returns:
        A new array with the same shape as ``x`` and values in [0, 1]. When
        the clipped data is constant (for example a uniform map) an all-zero
        array is returned instead of dividing zero by zero.
    """
    p2, p98 = np.percentile(x, (2, 98))
    x = np.clip(x, p2, p98)
    x -= x.min()
    peak = x.max()
    if peak == 0:
        # Every value collapsed onto the same level: 0 / 0 would yield NaNs.
        return np.zeros_like(x)
    x /= peak
    return x

def post_process(pred):
    """Smooth a raw saliency prediction and rescale it with ``normalize``.

    Args:
        pred: Saliency map to post-process.

    Returns:
        The output of ``normalize`` applied to the Gaussian-filtered map.
    """
    # The second positional argument of skimage's gaussian filter is sigma.
    smoothed = filters.gaussian(pred, 1)
    return normalize(smoothed)

def compute_eccentricity(cnt, H, W):
    """Return how far a contour's centroid lies from the image centre.

    The centroid comes from the contour's image moments. Its offset from the
    centre ``(W / 2, H / 2)`` is divided per axis by half the image width and
    half the image height, and the Euclidean norm of that scaled offset is
    returned, so a centroid exactly at the centre gives 0.

    Args:
        cnt: Contour in the format accepted by ``cv2.moments``.
        H: Image height.
        W: Image width.

    Returns:
        The scaled centre distance, or ``1.0`` when the zeroth moment
        ``m00`` is zero and no centroid can be computed.
    """
    moments = cv2.moments(cnt)
    m00 = moments["m00"]
    if m00 == 0:
        return 1.0

    half_w = W / 2
    half_h = H / 2
    offset_x = (moments["m10"] / m00 - half_w) / half_w
    offset_y = (moments["m01"] / m00 - half_h) / half_h
    return np.sqrt(offset_x ** 2 + offset_y ** 2)

def draw_contours_on_image(img, regions, color=(0, 255, 0), thickness=2, show_index=True):
    """Overlay region contours (and optional 1-based labels) on an image.

    Args:
        img: Either a ``PIL.Image.Image`` (non-RGB modes are converted with
            ``Image.convert("RGB")``) or a NumPy array that is 2-D grayscale
            or ``(H, W, C)`` with ``C`` in {1, 3, 4} in RGB(A) channel order.
        regions: Iterable of dicts that each hold a ``'contour'`` entry in
            the format accepted by ``cv2.drawContours``.
        color: Contour colour. It is interpreted in BGR order because the
            drawing happens on a BGR copy of the image.
        thickness: Line thickness forwarded to ``cv2.drawContours``.
        show_index: When true, write ``#1``, ``#2``, ... at the centroid of
            every contour whose zeroth moment is non-zero.

    Returns:
        A new RGB ``PIL.Image.Image``; ``img`` itself is never modified.
    """
    if isinstance(img, Image.Image):
        # Grayscale / palette / RGBA images would otherwise break or scramble
        # the channel flip below, so force three RGB channels first.
        rgb = np.array(img.convert("RGB"))
    else:
        rgb = np.asarray(img)
        if rgb.ndim == 2:
            rgb = rgb[:, :, np.newaxis]
        if rgb.shape[2] == 1:
            # Replicate the single channel so OpenCV can draw in colour.
            rgb = np.repeat(rgb, 3, axis=2)
        elif rgb.shape[2] == 4:
            # Drop alpha; the overlay is returned as plain RGB.
            rgb = rgb[:, :, :3]

    # Work on a C-ordered BGR copy: OpenCV draws in place, and the copy keeps
    # the caller's data untouched.
    canvas = rgb[:, :, ::-1].copy()

    for number, region in enumerate(regions, start=1):
        contour = region['contour']
        cv2.drawContours(canvas, [contour], -1, color, thickness)

        if not show_index:
            continue
        moments = cv2.moments(contour)
        if moments["m00"] == 0:
            # Degenerate contour: no centroid to anchor the label to.
            continue
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        # (0, 0, 255) is red on the BGR canvas.
        cv2.putText(canvas, f"#{number}", (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

    # Flip back to RGB for PIL.
    return Image.fromarray(np.ascontiguousarray(canvas[:, :, ::-1]))

def extract_top_saliency_regions_perceptual(
    pred,
    N=5,
    dilate_kernel_size=5,
    threshold=0.3,
    min_area_ratio=0.001,
    area_penalty_weight=1,
    center_bias_weight=1.0,
):
    """Return the ``N`` highest-scoring connected regions of a saliency map.

    The map is thresholded, dilated and morphologically opened with an
    elliptical kernel; each external contour of the result is one candidate
    region. A candidate's score is ``area * mean_saliency``, both measured
    over the filled contour on ``pred``.

    Args:
        pred: 2-D saliency map.
        N: Maximum number of regions to return.
        dilate_kernel_size: Side length of the elliptical structuring
            element used for both the dilation and the opening.
        threshold: Pixels with ``pred > threshold`` form the initial
            foreground.
        min_area_ratio: Candidates smaller than this fraction of the image
            area are discarded.
        area_penalty_weight: Accepted but not used by the current scoring.
        center_bias_weight: Accepted but not used by the current scoring.

    Returns:
        A list of at most ``N`` dicts, best score first, with the keys
        ``'score'``, ``'mean_saliency'``, ``'area'``, ``'contour'`` and
        ``'mask'`` (a uint8 array that is 1 inside the filled contour).
    """
    H, W = pred.shape
    min_area = H * W * min_area_ratio

    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (dilate_kernel_size, dilate_kernel_size)
    )
    foreground = (pred > threshold).astype(np.uint8)
    foreground = cv2.dilate(foreground, kernel, iterations=1)
    foreground = cv2.morphologyEx(foreground, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(foreground, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    candidates = []
    for contour in contours:
        # Rasterise the filled contour so its pixels can be read from pred.
        region_mask = np.zeros_like(pred, dtype=np.uint8)
        cv2.drawContours(region_mask, [contour], -1, 1, -1)

        values = pred[region_mask == 1]
        pixel_count = values.size
        if pixel_count < min_area:
            continue

        mean_sal = values.mean()
        candidates.append({
            'score': pixel_count * mean_sal,
            'mean_saliency': mean_sal,
            'area': pixel_count,
            'contour': contour,
            'mask': region_mask,
        })

    ranked = sorted(candidates, key=lambda region: region['score'], reverse=True)
    return ranked[:N]

def multiple_regions_to_binary_mask(regions, shape):
    """Rasterise the contours of several regions into one mask.

    Args:
        regions: Iterable of dicts that each hold a ``'contour'`` entry.
        shape: Shape of the mask to create, passed to ``np.zeros``.

    Returns:
        A uint8 array of the given shape that is 255 inside any of the
        filled contours and 0 everywhere else.
    """
    combined = np.zeros(shape, dtype=np.uint8)
    for entry in regions:
        # A negative thickness makes OpenCV fill the contour interior.
        cv2.drawContours(combined, [entry['contour']], -1, color=255, thickness=-1)
    return combined
    
def main():
    """Predict saliency for the configured CLIC images and export the results.

    For every image index that exists on disk the function runs the two
    ResNet backbones and the decoder, smooths and rescales the saliency map,
    selects the top-scoring salient regions and writes three files: the
    contour overlay, the saliency map and the binary mask of the selected
    regions. All paths are relative to the current working directory.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])
    input_dir_template = './clic_dataset/clic{:02d}/data/clic{:02d}.png'
    output_dir_contour_vis = './clic_dataset/clic_vis_contour/'
    os.makedirs(output_dir_contour_vis, exist_ok=True)
    output_dir_saliency_vis = './clic_dataset/clic_vis_saliency/'
    os.makedirs(output_dir_saliency_vis, exist_ok=True)
    output_dir_mask = './clic_dataset/clic_new_mask/'
    os.makedirs(output_dir_mask, exist_ok=True)

    # The backbones do not depend on the image, so they are loaded once, on
    # the first image that actually exists, and reused for later images.
    img_model = None
    pla_model = None

    for idx in [32]:
        img_path = input_dir_template.format(idx, idx)
        if not os.path.exists(img_path):
            print(f"[Image {idx:02d}] Skipping, file not found: {img_path}")
            continue

        # Close the file handle as soon as the pixels have been converted.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert('RGB')
        # PIL reports size as (width, height).
        W, H = pil_img.size

        if img_model is None:
            img_model = resnet.resnet50('./backbone/res_imagenet.pth').cuda().eval()
            pla_model = resnet.resnet50('./backbone/res_places.pth').cuda().eval()
        # The decoder is built with the image size, so it is created per image.
        decoder_model = decoder.build_decoder('./backbone/res_decoder.pth', (H, W), 5, 5).cuda().eval()

        processed = preprocess(pil_img).unsqueeze(0).cuda()

        with torch.no_grad():
            img_feat = img_model(processed, decode=True)
            pla_feat = pla_model(processed, decode=True)
            saliency_map = decoder_model([img_feat, pla_feat])

        saliency_map = saliency_map.squeeze().detach().cpu().numpy()
        saliency_map_smoothed = post_process(saliency_map)
        saliency_map_smoothed_255 = cv2.cvtColor((saliency_map_smoothed * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)

        # The threshold is the 90th percentile of the map, i.e. roughly the
        # top 10% of pixels seed the regions. area_penalty_weight is accepted
        # by the extractor but does not influence its score.
        regions = extract_top_saliency_regions_perceptual(
            saliency_map_smoothed,
            N=2,
            dilate_kernel_size=3,
            threshold=np.percentile(saliency_map_smoothed, 90),
            min_area_ratio=0.005,
            area_penalty_weight=0.5,
        )

        for i, region in enumerate(regions):
            print(f"[Image {idx:02d}] Region #{i+1}:")
            print(f"  Mean saliency: {region['mean_saliency']:.4f}")
            print(f"  Area (pixels): {region['area']}")
            # The score is area * mean saliency; no centre weighting is applied.
            print(f"  Score (area × mean saliency): {region['score']:.2f}")

        img_contour = draw_contours_on_image(pil_img, regions)
        # NOTE(review): outputs are named "kodim.." although the inputs are
        # "clic.." images - presumably inherited from a Kodak run; confirm
        # before renaming, since downstream consumers may rely on the names.
        pred_path_vis = os.path.join(output_dir_contour_vis, f"kodim{idx:02d}_contour.png")
        pred_path = os.path.join(output_dir_saliency_vis, f"kodim{idx:02d}_saliency.png")

        mask = multiple_regions_to_binary_mask(regions, shape=(H, W))
        mask_path = os.path.join(output_dir_mask, f"kodim{idx:02d}.png")
        cv2.imwrite(mask_path, mask)

        img_contour.save(pred_path_vis)
        sio.imsave(pred_path, saliency_map_smoothed_255)

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
