import torch
import numpy as np
import onnxruntime as ort
import cv2
import os
import sys
import pathlib
from torchvision import transforms
import gc
# Repository root: two directories above this file. It is put on sys.path so
# the shared ``utils`` module can be imported. sys.path entries must be str --
# the import system ignores non-string entries such as pathlib.Path objects,
# so the root is appended once, as a string.
root_dir = pathlib.Path(__file__).resolve().parents[2]
sys.path.append(str(root_dir))
from utils import plot_logits_to_mask


reach_factory_path = os.path.join(root_dir, 'Reach_Factory')
sys.path.append(reach_factory_path)
from Provision_All import Reachability_provider
from Segmentation import Segmentor   # per the original note: provides the Naive method




def CS_exp1(tau, sigma, Nt, N_dir,
            Ns, Nsp, rank, device, threshold_normal,
            sim_batch, trn_batch, surrogate_mode, image_name=None):
    """Run the reachability experiment for one image and one (tau, sigma) pair.

    Loads the HRNet ONNX model from ``<cwd>/models`` and the image from
    ``<cwd>/images/<image_name>``. It runs the model once to get the reference
    per-pixel classes, builds an L2 perturbation set around the normalised
    image, and hands everything to ``Reachability_provider``, whose
    ``Provider()`` method is then called.

    Args:
        tau: Probability passed to the standard-normal inverse CDF. The
            perturbation radius is ``sigma * Phi^{-1}(tau)``.
        sigma: Scale multiplying the inverse-CDF value.
        Nt, N_dir, Ns, Nsp, rank, threshold_normal, sim_batch, trn_batch:
            Forwarded unchanged in ``params`` to ``Reachability_provider``.
        device: torch device for the model output, the centre image and the
            per-channel radius.
        surrogate_mode: Forwarded as ``mode`` to ``Reachability_provider``.
        image_name: File name inside ``<cwd>/images``. If omitted, a
            module-level ``image_name`` (as bound by this script's
            ``__main__`` loop) is used, for backward compatibility.

    Raises:
        ValueError: If no image name is given and none is set globally.
        FileNotFoundError: If OpenCV cannot read the image file.
    """
    if image_name is None:
        # Backward compatibility: earlier callers relied on the module global
        # bound by the ``__main__`` loop instead of passing the name.
        image_name = globals().get('image_name')
    if image_name is None:
        raise ValueError(
            "CS_exp1: no image_name given and no module-level image_name set")

    model_name = 'hrnet_model_SegCertify_025.onnx'

    current_dir = os.getcwd()
    model_path = os.path.join(current_dir, 'models', model_name)
    image_path = os.path.join(current_dir, 'images', image_name)

    # Providers are tried in priority order; CPU is listed as explicit fallback.
    ort_session = ort.InferenceSession(
        model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    img = cv2.imread(image_path)  # BGR format; None if the file is unreadable
    if img is None:
        # cv2.imread does not raise on a missing/corrupt file; fail here with
        # a clear message instead of an opaque error from cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # To RGB
    img = cv2.resize(img, (2048, 1024))  # cv2 size is (W, H): W=2048, H=1024

    img = img.astype(np.float32) / 255.0  # Scale to [0, 1]

    # HRNet normalization (ImageNet per-channel mean/std, RGB order)
    mean_vals = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
    std_vals = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)

    # mean/std are float64, so ``img`` is float64 from here on.
    img = (img - mean_vals) / std_vals

    img = np.transpose(img, (2, 0, 1))  # [3, H, W]
    img = img.reshape(1, 3, 1024, 2048)  # Add batch dim -> [1, 3, H, W]

    x_numpy = img.astype(np.float32)  # the model is fed float32

    output = ort_session.run(None, {'input': x_numpy})
    output = torch.tensor(output[0]).to(device)  # first model output

    # plot_logits_to_mask(output)

    output_dim = output.shape

    # Reference label per pixel: argmax over dim 1, batch dim dropped.
    _, True_class_tensor = torch.max(output, dim=1)
    True_class_tensor = True_class_tensor.squeeze(0)
    True_class = True_class_tensor.cpu().tolist()

    # radius = sigma * Phi^{-1}(tau); presumably in [0, 1] pixel units, given
    # the ``radius*255`` and ``radius/std`` conversions below.
    radius = sigma * torch.distributions.Normal(0, 1).icdf(torch.tensor(tau))
    # Per-channel radius in the normalised input space (each channel was
    # divided by its std above), shaped [1, 3, 1, 1] to broadcast over the
    # image. Computed with explicit tensor ops in float64, matching ``Cent``.
    de = (radius.double() / torch.from_numpy(std_vals)).view(1, 3, 1, 1).to(device)
    Cent = torch.from_numpy(img).to(device)

    # The solver (gb_*) settings are part of ``params`` before the provider is
    # constructed, so they are present whether Reachability_provider keeps a
    # reference to this dict or copies/reads it in __init__.
    params = {
        'sim_batch' : sim_batch,
        'Nt' : Nt,
        'N_dir' : N_dir,
        'trn_batch' : trn_batch,
        'threshold_normal' : threshold_normal,
        'Ns' : Ns,
        'Nsp' : Nsp,
        'rank' : rank,
        'perturbation' : torch.floor(radius*255),
        'True_class' : True_class,
        'class_threshold' : None,
        'image_name' : image_name,
        'input_name' : 'input',

        'gb_workers': 112,          # number of CPU workers (processes)
        'gb_threads': 1,            # threads per Gurobi model (keep small when many workers)
        'gb_presolve': 2,
        'gb_method': 1,
        'gb_opt_tol': 1e-9,
        'gb_feas_tol': 1e-9,

        # batch sizing knobs (optional)
        # 'gb_inner_batch': 16,      # hard override if you want a fixed size
        'gb_tasks_per_worker': 4,   # auto: at least 4 LPs queued per worker
        'gb_inner_batch_min': 4,    # don't go below this
        'gb_inner_batch_max': 64,   # keep LPs small; 32-64 works well
        'gb_cap_by_ndir': True,     # also bound by N_dir // 2 for safety
    }

    provide = Reachability_provider(
        de = de,
        device = device,
        model = ort_session,
        Cent = Cent,
        original_dim = (3, 1024, 2048),
        output_dim = output_dim,
        mode = surrogate_mode,
        radii_mode = "L2",
        params = params
        )

    provide.Provider()





def CS_exp2(projection_batch, guarantee, device, src_dir, nnv_dir):
    """Build a ``Segmentor`` for the given settings and run ``Mask_titles()``.

    ``projection_batch`` and ``guarantee`` are packed into the ``params``
    dict. ``device``, ``src_dir`` and ``nnv_dir`` are passed through unchanged
    as keyword arguments. Returns None.
    """
    segmentor_params = dict(projection_batch=projection_batch,
                            guarantee=guarantee)
    Segmentor(device=device,
              src_dir=src_dir,
              nnv_dir=nnv_dir,
              params=segmentor_params).Mask_titles()





if __name__ == '__main__':
    import itertools

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    Ns = 920  # alternative value noted in the original: 90
    Nsp = 100
    rank = 919
    guarantee = 0.99
    Nt = 800
    N_dir = 100
    threshold_normal = 1e-5
    sim_batch = 5
    trn_batch = 10
    surrogate_mode = 'Naive'
    src_dir = None
    nnv_dir = None
    projection_batch = 256*512

    image_names = [
        'frankfurt_000000_000294_leftImg8bit.png',
        'frankfurt_000000_000576_leftImg8bit.png',
        'frankfurt_000000_001016_leftImg8bit.png',
        'frankfurt_000000_001236_leftImg8bit.png',
        'frankfurt_000000_001751_leftImg8bit.png',
        'frankfurt_000000_002196_leftImg8bit.png',
        # The trailing comma above matters: without it, uncommenting the next
        # entry would silently concatenate the two file names into one.
        # NOTE: 004617 (x3) and 013240 (x2) are listed more than once below.
        # 'frankfurt_000000_002963_leftImg8bit.png',
        # 'frankfurt_000000_003025_leftImg8bit.png',
        # 'frankfurt_000000_003920_leftImg8bit.png',
        # 'frankfurt_000000_003357_leftImg8bit.png',
        # 'frankfurt_000000_007365_leftImg8bit.png',
        # 'frankfurt_000000_006589_leftImg8bit.png',
        # 'frankfurt_000000_005898_leftImg8bit.png',
        # 'frankfurt_000000_005543_leftImg8bit.png',
        # 'frankfurt_000000_004617_leftImg8bit.png',
        # 'frankfurt_000000_004617_leftImg8bit.png',
        # 'frankfurt_000000_004617_leftImg8bit.png',
        # 'frankfurt_000000_009291_leftImg8bit.png',
        # 'frankfurt_000000_008451_leftImg8bit.png',
        # 'frankfurt_000000_008206_leftImg8bit.png',
        # 'frankfurt_000000_011074_leftImg8bit.png',
        # 'frankfurt_000000_011007_leftImg8bit.png',
        # 'frankfurt_000000_010763_leftImg8bit.png',
        # 'frankfurt_000000_010351_leftImg8bit.png',
        # 'frankfurt_000000_009969_leftImg8bit.png',
        # 'frankfurt_000000_013067_leftImg8bit.png',
        # 'frankfurt_000000_013240_leftImg8bit.png',
        # 'frankfurt_000000_013240_leftImg8bit.png',
        # 'frankfurt_000000_013942_leftImg8bit.png',
        # 'frankfurt_000000_014480_leftImg8bit.png'
        ]

    sigmas = [0.25, 0.33, 0.5]
    taus = [0.75, 0.95]

    # Same order as nested loops (image outermost, sigma innermost). Kept at
    # module scope on purpose: CS_exp1 is called without an image name and
    # looks up ``image_name`` as a module-level global, which this loop binds.
    for image_name, tau, sigma in itertools.product(image_names, taus, sigmas):
        print(f"Running on: {image_name} for tau= {tau} and sigma= {sigma}")

        CS_exp1(tau, sigma, Nt, N_dir,
                Ns, Nsp, rank, device, threshold_normal,
                sim_batch, trn_batch, surrogate_mode)

        CS_exp2(projection_batch, guarantee, device, src_dir, nnv_dir)

        # Free Python garbage and cached CUDA allocations between runs.
        gc.collect()
        torch.cuda.empty_cache()

