import torch
import os
import sys
import pathlib
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# Make the directory four levels above this file importable so the
# project-level `utils` module (imported below) can be found.
# sys.path entries must be strings: the import machinery skips non-str
# entries, so appending the Path object itself would have no effect.
root_dir = pathlib.Path(__file__).resolve().parents[4]
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

from utils import plot_mask, plot_indices, plot_pixel_status

# All input/output paths below are resolved relative to the working directory.
current_dir = os.getcwd()
# Presumably the perturbation budget ("eps" in the result file name) — only
# used here to locate the saved result file.
delta_rgb = 5
# Image size after resizing; also the bounds used when selecting pixels.
height = 512
width = 512
# Fraction of all pixels to select; derived from height/width so the count
# stays consistent if the image size is changed.
perturb_fraction = 0.01
N_perturbed = np.floor(perturb_fraction * height * width).astype(int)

#################################
# Load the test image and select the first N_perturbed bright pixels.
image_name = 'CHNCXR_0005_0'
# Images are looked up in <CWD>/../../images.
parent_dir = os.path.dirname(os.path.dirname(current_dir))
image_path = os.path.join(parent_dir, 'images', image_name + '.png')
eval_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    # Use the same dimensions as the selection bounds so they cannot drift.
    transforms.Resize((height, width)),
    transforms.ToTensor(),
])

# Context manager releases the file handle; convert() loads the pixel data
# before the file is closed.
with Image.open(image_path) as pil_img:
    img = pil_img.convert('L')  # 'L' mode = single-channel grayscale
img = eval_transforms(img)
img_tensor = img.unsqueeze(0)
img_np = img_tensor.detach().numpy().astype(np.float32)

# ToTensor maps 8-bit pixel values into [0, 1], so this is 8-bit level 150.
intensity_threshold = 150 / 255.0
# img_np is indexed as [:, :, row, col]; a pixel qualifies when its minimum
# over the leading axes exceeds the threshold (same test as
# np.min(img_np[:, :, i, j]) for every (i, j)).
bright = img_np.min(axis=(0, 1)) > intensity_threshold
# np.argwhere lists coordinates in row-major order — the same order as a
# nested row/column scan — so slicing keeps the first N_perturbed hits.
# Result is an (k, 2) integer array of [row, col] pairs, k <= N_perturbed.
indices = np.argwhere(bright)[:N_perturbed]
ct = len(indices)
if ct == N_perturbed:
    print(f"{N_perturbed} pixels found.")
else:
    print(f"Warning: only {ct} of {N_perturbed} requested pixels exceed the threshold.")
####################################





# Locate the saved result for this image / perturbation setting. The name
# (including its 'Npertubed' spelling) presumably matches what the producing
# script wrote, so it is kept verbatim.
result_file = f'CI_result_CLP_eps_{delta_rgb}_Npertubed_{N_perturbed}_{image_name}.pt'
load_path = os.path.join(current_dir, result_file)

# NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
# objects (which can execute code) — only load result files you trust.
D = torch.load(load_path, weights_only=False)
classes, True_class = D['classes'], D['True_class']

# Plot helpers come from the project-level utils module; save_image=True is
# passed to each.
plot_mask(True_class, save_image=True)
plot_indices(height, width, indices, save_image=True)
plot_pixel_status(classes, True_class, save_image=True)