import torch
import numpy as np
import onnxruntime as ort
import cv2
from PIL import Image
import os
import sys
import pathlib
import torchvision.transforms as transforms
import gc

# Project root is two directories above this file; the shared `utils` module
# and the `Reach_Factory` package directory live there.
root_dir = pathlib.Path(__file__).resolve().parents[2]

# sys.path entries must be strings (non-string entries are ignored by the
# import system), so append the str form once instead of also appending the
# Path object.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))
from utils import plot_binary_logits_to_mask


reach_factory_path = os.path.join(root_dir, 'Reach_Factory')
if reach_factory_path not in sys.path:
    sys.path.append(reach_factory_path)
from Provision_All import Reachability_provider
from Segmentation import Segmentor




def CheXpert_exp1( delta_rgb,  Nt, N_dir, image_name,
                  Ns, Nsp, rank, device,  threshold_normal,
                  sim_batch, trn_batch, surrogate_mode):
    """Run the reachability experiment for a single chest X-ray image.

    Loads ``<cwd>/images/<image_name>`` and runs the ONNX model
    ``<cwd>/models/lung_segmentation.onnx`` on it. The raw model output is
    thresholded at 0.0 into a per-pixel 0/1 class map. Everything is then
    handed to ``Reachability_provider`` and its ``Provider()`` step is run.

    Args:
        delta_rgb: Perturbation magnitude on the 0-255 intensity scale. It is
            divided by 255 before being passed as ``de``, matching the [0, 1]
            range produced by ``transforms.ToTensor``.
        Nt, N_dir, Ns, Nsp, rank, threshold_normal, sim_batch, trn_batch:
            Forwarded unchanged to the provider through ``params``.
        image_name: File name inside the ``images`` directory.
        device: torch device on which the input and output tensors are placed.
        surrogate_mode: Forwarded as ``mode`` to the provider.
    """
    model_name = 'lung_segmentation.onnx'
    # Single source of truth for the ONNX graph's input name: it is used both
    # for inference here and by the provider via params['input_name'].
    input_name = 'input'
    # Model outputs strictly above this value are labelled class 1.
    threshold = 0.0

    current_dir = os.getcwd()
    model_path = os.path.join(current_dir, 'models', model_name)
    image_path = os.path.join(current_dir, 'images', image_name)

    # Providers are tried in order. The CPU entry keeps the script usable on
    # machines without CUDA, which __main__ explicitly allows for.
    ort_session = ort.InferenceSession(
        model_path,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )

    eval_transforms = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])

    # `with` closes the underlying file handle once the pixels are decoded.
    # 'L' mode = single-channel grayscale.
    with Image.open(image_path) as raw_img:
        img = eval_transforms(raw_img.convert('L'))

    # Add a batch dimension; onnxruntime consumes a float32 numpy array.
    img_np = img.unsqueeze(0).numpy().astype(np.float32)

    output = ort_session.run(None, {input_name: img_np})
    output = torch.tensor(output[0]).to(device)

    plot_binary_logits_to_mask(output, threshold)

    output_dim = output.shape

    # Per-pixel 0/1 labels as nested Python lists of ints. This is vectorized
    # instead of a hard-coded 512x512 Python double loop.
    output_np = output.squeeze().cpu().numpy()
    True_class = (output_np > threshold).astype(int).tolist()

    Cent = torch.from_numpy(img_np).to(device)
    de = delta_rgb / 255

    # All parameters, including the Gurobi solver settings, are assembled
    # BEFORE the provider is constructed. That way they are visible even if
    # the provider copies or reads `params` in its constructor.
    params = {
        'sim_batch' : sim_batch,
        'Nt' : Nt,
        'N_dir' : N_dir,
        'trn_batch' : trn_batch,
        'threshold_normal' : threshold_normal,
        'Ns' : Ns,
        'Nsp' : Nsp,
        'rank' : rank,
        'perturbation' : delta_rgb,
        'True_class' : True_class,
        'class_threshold' : threshold,
        'image_name' : image_name,
        'input_name' : input_name,

        'gb_workers': 112,          # number of CPU workers (processes)
        'gb_threads': 1,            # threads per Gurobi model (keep small when many workers)
        'gb_presolve': 2,
        'gb_method': 1,
        'gb_opt_tol': 1e-9,
        'gb_feas_tol': 1e-9,

        # batch sizing knobs (optional)
        # 'gb_inner_batch': 16,      # hard override if you want a fixed size
        'gb_tasks_per_worker': 4,   # auto: at least 4 LPs queued per worker
        'gb_inner_batch_min': 4,    # don't go below this
        'gb_inner_batch_max': 64,   # keep LPs small; 32-64 works well
        'gb_cap_by_ndir': True,     # also bound by N_dir // 2 for safety
    }

    provide = Reachability_provider(
        de = de,
        device = device,
        model = ort_session,
        Cent = Cent,
        original_dim = (1, 512, 512),
        output_dim = output_dim,
        mode = surrogate_mode,
        radii_mode = "L2",
        params = params
        )

    provide.Provider()




def CheXpert_exp2( projection_batch, guarantee, device, src_dir, nnv_dir ):
    """Construct a ``Segmentor`` and run its ``Mask_titles`` step.

    ``projection_batch`` and ``guarantee`` are bundled into the segmentor's
    ``params`` dict. ``device``, ``src_dir`` and ``nnv_dir`` are passed
    through unchanged as keyword arguments.
    """
    segmentor_params = dict(
        projection_batch=projection_batch,
        guarantee=guarantee,
    )

    segmentor = Segmentor(
        device=device,
        src_dir=src_dir,
        nnv_dir=nnv_dir,
        params=segmentor_params,
    )
    segmentor.Mask_titles()





if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Sampling / reachability configuration forwarded to CheXpert_exp1.
    Ns = 8000
    Nsp = 1000
    rank = 7999
    Nt = 1500
    N_dir = 100
    threshold_normal = 1e-5
    sim_batch = 50
    trn_batch = 20
    surrogate_mode = 'CLP'

    # Configuration forwarded to CheXpert_exp2.
    guarantee = 0.999
    src_dir = None
    nnv_dir = None
    projection_batch = 512 * 512

    # Every entry carries a trailing comma. This means uncommenting more
    # images cannot silently concatenate two adjacent string literals.
    image_names = [
        'CHNCXR_0005_0.png',
        'MCUCXR_0258_1.png',
        'MCUCXR_0264_1.png',
        'MCUCXR_0266_1.png',
        'MCUCXR_0275_1.png',
        'MCUCXR_0282_1.png',
        # 'MCUCXR_0289_1.png',
        # 'MCUCXR_0294_1.png',
        # 'MCUCXR_0301_1.png',
        # 'MCUCXR_0309_1.png',
        # 'MCUCXR_0311_1.png',
        # 'MCUCXR_0313_1.png',
        # 'MCUCXR_0316_1.png',
        # 'MCUCXR_0331_1.png',
        # 'MCUCXR_0334_1.png',
        # 'MCUCXR_0338_1.png',
        # 'MCUCXR_0348_1.png',
        # 'MCUCXR_0350_1.png',
        # 'MCUCXR_0352_1.png',
        # 'MCUCXR_0354_1.png',
        ]

    # L2 perturbation magnitudes on the 0-255 intensity scale.
    delta_rgb_list = [50, 100, 150]

    for image_name in image_names:
        for delta_rgb in delta_rgb_list:

            # Include delta_rgb so the runs for one image are distinguishable.
            print(f"Running: {image_name} with N_perturbed = ALL, delta_rgb = {delta_rgb}")

            CheXpert_exp1( delta_rgb,  Nt, N_dir, image_name,
                              Ns, Nsp, rank, device,  threshold_normal,
                              sim_batch, trn_batch, surrogate_mode)

            CheXpert_exp2( projection_batch, guarantee, device, src_dir, nnv_dir )

            # Release memory from the finished run before starting the next
            # one. empty_cache returns PyTorch's cached CUDA blocks so other
            # allocators (e.g. onnxruntime's CUDA provider) can reuse them.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()