from time import time
import torch
import numpy as np
import onnxruntime as ort
import cv2
from PIL import Image
import os
import sys
import pathlib
from torchvision import transforms

sys.path.append(str(pathlib.Path(__file__).resolve().parents[2]))
root_dir = pathlib.Path(__file__).resolve().parents[2]
reach_factory_path = os.path.join(root_dir, 'Reach_Factory')
sys.path.append(reach_factory_path)
from Provision_All import Reachability_provider
from Segmentation import Segmentor




def OCTA_max_exp1(delta_rgb,  Nt, N_dir, image_name,
                  Ns, Nsp, rank, device,  threshold_normal,
                  sim_batch, trn_batch, surrogate_mode):
    """Run the reachability experiment for one image and one perturbation size.

    Loads ``images/<image_name>`` and ``models/betti_best.onnx`` relative to
    the current working directory, evaluates the model once on the
    unperturbed image to obtain the reference per-pixel classification, and
    then hands everything to ``Reachability_provider`` and calls its
    ``Provider()`` method.

    Args:
        delta_rgb: Perturbation magnitude. It is divided by 255 before being
            passed as ``de`` (presumably intensity levels on a 0-255 scale,
            matching an image normalised to [0, 1] -- confirm against the
            provider).
        Nt, N_dir, Ns, Nsp, rank, threshold_normal, sim_batch, trn_batch:
            Forwarded unchanged inside ``params``; their meaning is defined by
            ``Reachability_provider``.
        image_name: File name of the image inside the ``images`` directory.
        device: torch device used for the tensors created here.
        surrogate_mode: Forwarded to ``Reachability_provider`` as ``mode``.

    Returns:
        None. Whatever ``Provider()`` produces is not returned here.
    """
    model_name = 'betti_best.onnx'

    # Both paths are resolved against the working directory, not this file.
    current_dir = os.getcwd()
    model_path = os.path.join(current_dir, 'models', model_name)
    image_path = os.path.join(current_dir, 'images', image_name)

    # Convert inside the ``with`` block so the pixel data is read before the
    # file handle is released.
    to_tensor = transforms.ToTensor()
    with Image.open(image_path) as pil_img:
        img = to_tensor(pil_img)

    ort_session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider'])

    # The reshape fixes the input to one single-channel 304x304 image.
    img_tensor = img.reshape(1, 1, 304, 304).to(device, dtype=torch.float32)
    img_np = img_tensor.detach().cpu().numpy().astype(np.float32)

    output = ort_session.run(None, {'input': img_np})
    output = torch.tensor(output[0]).to(device)
    output_dim = output.shape

    # The raw output is thresholded at log(0.45 / 0.55), i.e. the logit of
    # 0.45, so no sigmoid is applied first.
    output_np = output.squeeze().cpu().numpy()
    threshold = np.log(45/55)
    # One vectorised comparison instead of a 304x304 Python loop; the result
    # is still a nested list of Python ints (0 or 1).
    # NOTE(review): assumes the squeezed output is a 2-D 304x304 map, which is
    # the range the previous explicit loop indexed -- confirm for new models.
    True_class = (output_np > threshold).astype(int).tolist()

    Cent = torch.from_numpy(img_np).to(device)
    de = delta_rgb / 255

    # The dictionary is completed *before* the provider is constructed so the
    # solver settings are visible even if the provider copies or reads
    # ``params`` in its constructor.
    params = {
        'sim_batch' : sim_batch,
        'Nt' : Nt,
        'N_dir' : N_dir,
        'trn_batch' : trn_batch,
        'threshold_normal' : threshold_normal,
        'Ns' : Ns,
        'Nsp' : Nsp,
        'rank' : rank,
        'perturbation' : delta_rgb,
        'True_class' : True_class,
        'class_threshold' : threshold,
        'image_name' : image_name,
        'input_name' : 'input',

        'gb_workers': 112,          # number of CPU workers (processes)
        'gb_threads': 1,            # threads per Gurobi model (keep small when many workers)
        'gb_presolve': 2,
        'gb_method': 1,
        'gb_opt_tol': 1e-9,
        'gb_feas_tol': 1e-9,

        # batch sizing knobs (optional)
        # 'gb_inner_batch': 16,      # hard override if you want a fixed size
        'gb_tasks_per_worker': 4,   # auto: at least 4 LPs queued per worker
        'gb_inner_batch_min': 4,    # don't go below this
        'gb_inner_batch_max': 64,   # keep LPs small; 32-64 works well
        'gb_cap_by_ndir': True,     # also bound by N_dir // 2 for safety
    }

    provide = Reachability_provider(
        de = de,
        device = device,
        model = ort_session,
        Cent = Cent,
        original_dim = (1, 304, 304),
        output_dim = output_dim,
        mode = surrogate_mode,
        radii_mode = "Linf",
        params = params
        )

    provide.Provider()




def OCTA_max_exp2( projection_batch, guarantee, device, src_dir, nnv_dir ):
    """Create a ``Segmentor`` and run its ``Mask_titles`` step.

    Args:
        projection_batch: Stored under ``params['projection_batch']``.
        guarantee: Stored under ``params['guarantee']``.
        device: Forwarded to ``Segmentor`` as ``device``.
        src_dir: Forwarded to ``Segmentor`` as ``src_dir``.
        nnv_dir: Forwarded to ``Segmentor`` as ``nnv_dir``.

    Returns:
        None. The result of ``Mask_titles()`` is discarded.
    """
    # The two numeric settings travel in a dict; the rest are direct keywords.
    segmentor_settings = dict(
        projection_batch=projection_batch,
        guarantee=guarantee,
    )

    segmentor = Segmentor(
        device=device,
        src_dir=src_dir,
        nnv_dir=nnv_dir,
        params=segmentor_settings,
    )
    segmentor.Mask_titles()





if __name__ == '__main__':
    # Script entry point: for every (image, perturbation) pair run the
    # reachability experiment, report its wall-clock time, then run the
    # segmentation step.

    # Use CUDA when torch can see a GPU, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Settings forwarded to OCTA_max_exp1.
    Nt, N_dir = 1500, 100
    Ns, Nsp, rank = 8000, 200, 7999
    threshold_normal = 1e-5
    sim_batch, trn_batch = 50, 20
    surrogate_mode = 'CLP'

    # Settings forwarded to OCTA_max_exp2.
    guarantee = 0.999
    projection_batch = 304*304
    src_dir = None
    nnv_dir = None

    # Images to process, looked up by OCTA_max_exp1 under ./images.
    image_names = [
        '10491.bmp',
        '10305.bmp',
        '10395.bmp',
        '10495.bmp',
        '10301.bmp',
        '10401.bmp',
        # Currently disabled:
        # '10372.bmp', '10425.bmp', '10439.bmp', '10418.bmp', '10399.bmp',
        # '10469.bmp', '10323.bmp', '10382.bmp', '10486.bmp', '10302.bmp',
        # '10499.bmp', '10444.bmp', '10343.bmp', '10367.bmp',
    ]

    # Perturbation sizes tried for every image.
    delta_rgb_list = [5, 10, 15]

    for image_name in image_names:
        for delta_rgb in delta_rgb_list:
            print(f"Running: {image_name} with N_perturbed = ALL")

            # Only the reachability step is timed; the segmentation step
            # below runs after the runtime has been printed.
            start_time = time()
            OCTA_max_exp1(
                delta_rgb=delta_rgb,
                Nt=Nt,
                N_dir=N_dir,
                image_name=image_name,
                Ns=Ns,
                Nsp=Nsp,
                rank=rank,
                device=device,
                threshold_normal=threshold_normal,
                sim_batch=sim_batch,
                trn_batch=trn_batch,
                surrogate_mode=surrogate_mode,
            )
            time_end = time() - start_time
            print(f"The runtime is {time_end}.")

            OCTA_max_exp2(
                projection_batch=projection_batch,
                guarantee=guarantee,
                device=device,
                src_dir=src_dir,
                nnv_dir=nnv_dir,
            )