import itertools
import os
import pathlib
import sys

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# Make the directory four levels above this file importable, so that the
# project-local `utils` import that follows can be resolved.
# sys.path entries must be str: the import system ignores other object types,
# so the root is appended once, as a string, and only if it is not already present.
root_dir = pathlib.Path(__file__).resolve().parents[4]
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

from utils import plot_mask, plot_indices, plot_pixel_status

# Experiment configuration.
# Both the input image and the result file are located relative to the
# current working directory, so run this script from the expected folder.
current_dir = os.getcwd()
delta_rgb = 5  # perturbation size; used to pick the matching result filename
height = 304
width = 304
# Number of perturbed pixels: 1% of the image, rounded down.
# Derived from height/width rather than repeating the 304 literals,
# so the two cannot silently drift apart.
N_perturbed = int(np.floor(0.01 * height * width))

#################################
image_name = '10491'
# The image is expected at <cwd>/../../images/<image_name>.bmp.
parent_dir = os.path.dirname(os.path.dirname(current_dir))
image_path = os.path.join(parent_dir, 'images', image_name + '.bmp')

# Convert the image to a float tensor (ToTensor scales 8-bit pixels to [0, 1]).
# The `with` block closes the underlying file handle once the data is read.
to_tensor = transforms.ToTensor()
with Image.open(image_path) as pil_img:
    img = to_tensor(pil_img)
# Assumes a single-channel image of exactly height x width pixels;
# reshape raises if the element count does not match.
img_tensor = img.reshape(1, 1, height, width).to('cpu', dtype=torch.float32)
img_np = img_tensor.detach().numpy().astype(np.float32)

# Pick the first N_perturbed "bright" pixels (value > 150/255) in row-major
# order. Because the generator is lazy, islice stops the scan as soon as enough
# pixels have been found. This replaces the manual counter and double break.
bright_pixels = (
    [i, j]
    for i in range(height)
    for j in range(width)
    if np.min(img_np[:, :, i, j]) > 150 / 255.0
)
indices = list(itertools.islice(bright_pixels, N_perturbed))

if len(indices) == N_perturbed:
    print(f"{N_perturbed} pixels found.")
else:
    # Previously this shortfall went unreported.
    print(
        f"Warning: only {len(indices)} of {N_perturbed} requested pixels "
        f"exceed the brightness threshold.",
        file=sys.stderr,
    )

# (k, 2) integer array of [row, col] pairs; reshape keeps it 2-D even if k == 0.
indices = np.array(indices, dtype=int).reshape(-1, 2)
####################################





# Result file, presumably written by the companion experiment script. The
# name must match that script's output exactly (including "Npertubed").
load_path = os.path.join(current_dir, f'CI_result_CLP_eps_{delta_rgb}_Npertubed_{N_perturbed}_{image_name}.pt')

# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
# objects, which can execute code. Only load result files you produced yourself.
# map_location='cpu' lets a result saved during a GPU run load on a CPU-only
# machine; plotting happens on the CPU anyway.
D = torch.load(load_path, map_location='cpu', weights_only=False)

# Expected keys in the saved dict; their exact contents are defined by the
# producer script (TODO confirm shapes/types there).
classes = D['classes']
True_class = D['True_class']

plot_mask(True_class, save_image=True)

plot_indices(height, width, indices, save_image=True)

plot_pixel_status(classes, True_class, save_image=True)