import torch
import os
import sys
import pathlib
import numpy as np


# Put an ancestor directory of this script on the module search path so the
# `utils` import that follows can be resolved when the script is run directly
# (presumably this is the repository root -- confirm against the project layout).
# parents[0] is the directory containing this file, so parents[4] is four
# levels above that directory.
root_dir = pathlib.Path(__file__).resolve().parents[4]

# sys.path entries must be strings: the import system ignores any other type,
# so the Path object is converted with str() and appended exactly once.
sys.path.append(str(root_dir))

from utils import plot_mask, plot_indices, plot_pixel_status



# --- Run configuration ------------------------------------------------------
# These values are embedded in the result file name built below, so they have
# to match the run that wrote that file.
image_name = '0001TP_008790'
delta_rgb = 5

# Dimensions passed to plot_indices (presumably the image size in pixels).
height = 720
width = 960

# Number of perturbed pixels: floor of 6% of height * width.
# NOTE(review): the product is evaluated in floating point before flooring, so
# the result may land one below the exact 6% figure. The expression is kept
# unchanged on purpose because its value is part of the file name.
N_perturbed = np.floor(0.06 * height * width).astype(int)

# --- Locate and load the saved result ---------------------------------------
# The file is looked up in the current working directory, not next to this
# script, so the script must be launched from the directory holding the file.
current_dir = os.getcwd()
result_filename = f'CI_result_CLP_eps_{delta_rgb}_Npertubed_{N_perturbed}_{image_name}.pt'
load_path = os.path.join(current_dir, result_filename)

# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load result files from a trusted source.
D = torch.load(load_path, weights_only=False)

# Pull out the three entries that the plots below consume.
classes, True_class, indices = (
    D[key] for key in ('classes', 'True_class', 'indices')
)

# --- Produce the figures -----------------------------------------------------
# Each helper is asked to save its figure (save_image=True); where the files
# end up is decided inside utils.
plot_mask(True_class, save_image=True)
plot_indices(height, width, indices, save_image=True)
plot_pixel_status(classes, True_class, save_image=True)