import torch
import numpy as np
import onnxruntime as ort
import cv2
from time import time
import os
import sys
from scipy.stats import beta
import matlab.engine
import scipy.io
import pathlib
import gc
import gurobipy as gp
from gurobipy import GRB
from concurrent.futures import ProcessPoolExecutor, as_completed


# Make the ``Training_Factory`` directory importable. ``root_dir`` is the
# parent of the directory that contains this file, and ``Training_Factory`` is
# expected directly under it. This must run before ``Direction_training`` is
# imported below.
root_dir = pathlib.Path(__file__).resolve().parents[1]
training_factory_path = os.path.join(root_dir, 'Training_Factory')
sys.path.append(training_factory_path)

from Direction_training import compute_directions


# ----------------- Parallel LP worker (module scope) -----------------


def _solve_clp_batch_worker(H, V_batch, params):
    """
    Solve the CLP for a single batch in its own process.

    For each row v_i of ``V_batch`` the LP finds convex weights alpha_i
    (alpha_i >= 0, sum(alpha_i) == 1) and a slack s_i >= 0 such that
    |(H^T alpha_i)_k - v_ik| <= s_i for every direction k, minimising s_i.
    This is the Chebyshev (L-inf) projection of v_i onto the convex hull of
    the rows of H. The B independent LPs are stacked into ONE Gurobi model
    whose objective is sum_i s_i.

    Inputs:
      - H: np.ndarray [t, Ndir]
      - V_batch: np.ndarray [B, Ndir]
      - params: dict with Gurobi settings (gb_threads, gb_presolve, gb_method,
        gb_opt_tol, gb_feas_tol). Missing keys use the defaults read below.
    Returns:
      - Yhat_batch: np.ndarray [B, Ndir]. Row i is H^T alpha_i*, or the whole
        array is a copy of V_batch when the model did not finish OPTIMAL.
      - status_ok: bool, True iff Gurobi reported status OPTIMAL.
    Raises:
      - ValueError: if H and V_batch differ in their second dimension.
    """
    t, Ndir = H.shape
    B, Ndir2 = V_batch.shape
    if Ndir != Ndir2:
        raise ValueError("H and V_batch must match in directional dimension.")

    threads = int(params.get('gb_threads', 1))
    preslv = int(params.get('gb_presolve', 2))
    method = int(params.get('gb_method', 1))
    opttol = float(params.get('gb_opt_tol', 1e-9))
    feastol = float(params.get('gb_feas_tol', 1e-9))

    # Per point the decision vector is z = [alpha (t entries), s]. The
    # constraint matrices are built once, so each point costs two matrix
    # constraints instead of 2 * Ndir scalar ones:
    #    H^T alpha - s <=  v
    #   -H^T alpha - s <= -v        (together: |H^T alpha - v| <= s)
    minus_ones = -np.ones((Ndir, 1))
    A_pos = np.hstack([H.T, minus_ones])
    A_neg = np.hstack([-H.T, minus_ones])
    # Objective coefficient 1 on s only, so the model minimises sum_i s_i.
    obj = np.zeros(t + 1)
    obj[-1] = 1.0

    # The context manager disposes the native Gurobi model once we are done.
    with gp.Model() as model:
        model.Params.OutputFlag = 0
        model.Params.Method = method
        model.Params.OptimalityTol = opttol
        model.Params.FeasibilityTol = feastol
        model.Params.Threads = threads
        model.Params.Presolve = preslv
        model.ModelSense = GRB.MINIMIZE

        z_vars = []
        for i in range(B):
            # lb=0 covers both alpha >= 0 and s >= 0.
            z = model.addMVar(shape=t + 1, lb=0.0, obj=obj, name=f"z_{i}")
            model.addConstr(z[:t].sum() == 1.0, name=f"sum_alpha_{i}")
            model.addConstr(A_pos @ z <= V_batch[i], name=f"pos_{i}")
            model.addConstr(A_neg @ z <= -V_batch[i], name=f"neg_{i}")
            z_vars.append(z)

        model.optimize()

        if model.Status != GRB.OPTIMAL:
            # Fallback: identity. Return a copy so callers never alias the input.
            return V_batch.copy(), False

        Yhat_batch = np.empty_like(V_batch)
        for i, z in enumerate(z_vars):
            Yhat_batch[i, :] = H.T @ z.X[:t]
        return Yhat_batch, True




class Reachability_provider:
    """
    Sampling-based (conformal) reachability provider for a perturbed-input model.

    The provider works in three stages:

    1. Draw random perturbations ``LB + de * U[0, 1)`` on the pixels listed in
       ``indices`` (for every channel) and evaluate ``model`` on them.
    2. Fit a centre (mode ``'Naive'``) or a centre + directions + CLP
       projection surrogate (any other mode) and compute per-output maximal
       residuals (``Shape_residual``).
    3. Calibrate an inflation factor on held-out samples (``CI_surrogate``).

    ``Provider`` runs everything and saves the results to ``CI_provider.pt``
    in the current working directory.

    Constructor arguments are stored unchanged:
      model        -- object exposing ``run(None, {input_name: batch})``
                      (onnxruntime ``InferenceSession``-style API).
      LB           -- 4-D input tensor repeated along dim 0 per sample
                      (presumably shape (1, C, H, W) -- TODO confirm).
      de           -- scale multiplied into the random perturbation mask.
      indices      -- sequence of (row, col) pixel coordinates to perturb.
      original_dim -- input shape without the batch dim, (C, H, W).
      output_dim   -- only stored and saved; not used in computations here.
      device       -- torch device that model outputs are moved to.
      mode         -- ``'Naive'`` or any other value (directions + CLP).
      params       -- dict of settings (Nt, N_dir, Ns, Nsp, rank, ...).
    """

    def __init__(self, model, LB, de, indices, original_dim, output_dim, device, mode, params):
        self.de = de
        self.indices = indices
        self.device = device
        self.model = model
        self.LB = LB
        self.original_dim = original_dim
        self.output_dim = output_dim
        self.mode = mode
        self.params = params

    def mat_generator_no_third(self, repeat, values):
        """
        Scatter per-sample values into a zero tensor of shape (repeat, *original_dim).

        Column ``c * len(indices) + i`` of ``values`` is written to channel ``c``
        at pixel ``indices[i]``. Every other entry stays zero.

        Args:
            repeat: number of samples (rows of ``values``).
            values: tensor of shape (repeat, C * len(indices)).
        Returns:
            Tensor on ``values.device`` with ``values.dtype``.
        """
        N_perturbed = len(self.indices)
        Matrix = torch.zeros((repeat, *self.original_dim), device=values.device, dtype=values.dtype)
        for c in range(self.original_dim[0]):
            for i, (row, col) in enumerate(self.indices):
                Matrix[:, c, row, col] = values[:, c * N_perturbed + i]
        return Matrix

    def Func(self, x):
        """
        Evaluate the wrapped model on ``x`` in mini-batches of ``params['sim_batch']``.

        Args:
            x: input tensor, batch along dim 0.
        Returns:
            Concatenation (dim 0) of the first output of every ``run`` call,
            moved to ``self.device``.
        """
        name = self.params['input_name']
        batch_size = self.params['sim_batch']
        # NOTE(review): rounding through float16 quantises the inputs before
        # they are fed to the model as float32. This is kept from the original;
        # confirm the precision loss is intended.
        x_numpy = x.to(torch.float16).cpu().numpy().astype(np.float32)
        results = []
        for i in range(0, x_numpy.shape[0], batch_size):
            batch = x_numpy[i:i + batch_size]
            # Inference goes through the onnxruntime-style ``run`` API, so torch
            # autocast has no effect here and is deliberately not used.
            output = self.model.run(None, {name: batch})
            results.append(torch.tensor(output[0]).to(self.device))
        return torch.cat(results, dim=0)

    def generate_data_chunk(self, repeat, LBs):
        """
        Generate ``repeat`` perturbed inputs around ``LBs`` and evaluate the model.

        Uniform U[0, 1) values (drawn with ``torch.rand`` and then moved to
        ``self.device``) are scattered onto the perturbed pixels of every
        channel, scaled by ``de``, and added to ``LBs``.

        Returns:
            (out, Rand): the model outputs from ``Func`` and the
            [repeat, C * len(indices)] random values that were used.
        """
        N_perturbed = len(self.indices)
        nc = self.original_dim[0]

        Rand = torch.rand(repeat, nc * N_perturbed).to(self.device)
        Rand_matrix = self.mat_generator_no_third(repeat, Rand)
        Inp_tensor = (LBs + self.de * Rand_matrix).float()

        with torch.no_grad():
            out = self.Func(Inp_tensor)

        return out, Rand

    def generate_data(self, repeat, SEED):
        """
        Seed torch, draw ``repeat`` samples and time the simulation.

        Returns:
            (Y, X, runtime): outputs flattened to [repeat, -1], the random
            values behind them, and the wall-clock seconds spent.
        """
        torch.manual_seed(SEED)

        t0 = time()
        LBs = self.LB.repeat(repeat, 1, 1, 1)
        Y, X = self.generate_data_chunk(repeat, LBs)
        runtime = time() - t0

        Y = Y.view(Y.shape[0], -1)
        return Y, X, runtime

    def CLP(self, CH: torch.Tensor, dYV: torch.Tensor):
        """
        Parallel CLP: project each row of ``dYV`` onto the convex hull of the rows
        of ``CH`` (Chebyshev-distance LP, see ``_solve_clp_batch_worker``).

          - Keeps each (multi-point) batch as a SINGLE LP (same formulation).
          - Splits dYV rows into many small batches (to make each LP fast).
          - Distributes those LPs across multiple CPU workers (processes).
          - Reassembles predictions, preserving order and the expected chunking.

        Args:
            CH:  tensor [t, Ndir] of hull points.
            dYV: tensor [M, Ndir] of query points.
        Returns:
            List of tensors (dtype/device of ``dYV``) whose concatenation is the
            [M, Ndir] projection. When M equals the size of the training split
            derived from params Nt / N_dir, the list is split into those training
            chunks. Otherwise it holds a single [M, Ndir] tensor. Rows whose LP
            raised or did not reach OPTIMAL are returned unchanged (identity).
        Raises:
            ValueError: if CH and dYV differ in their second dimension.

        Tunables via self.params:
          gb_workers: int, number of processes (default: os.cpu_count()).
          gb_threads: int, threads per Gurobi model (default: 1).
          gb_presolve: int, default 2.
          gb_method: int, default 1.
          gb_opt_tol: float, default 1e-9.
          gb_feas_tol: float, default 1e-9.
          # Batch sizing knobs:
          gb_inner_batch: int, hard override for rows/LP (default: auto).
          gb_tasks_per_worker: int, default 4.
          gb_inner_batch_min: int, default 4.
          gb_inner_batch_max: int, default 64.
          gb_cap_by_ndir: bool, default True (caps to N_dir//2).
        """
        device = dYV.device
        dtype = dYV.dtype

        # Host copies for the workers.
        H = CH.detach().cpu().double().numpy()   # [t, Ndir]
        V = dYV.detach().cpu().double().numpy()  # [M, Ndir]
        t, Ndir = H.shape
        M, Ndir2 = V.shape
        if Ndir != Ndir2:
            raise ValueError("CH and dYV must have same directional dimension.")

        # ----- expected training chunking (mirrors Shape_residual) -----
        Nt = self.params['Nt']
        N_dir = self.params['N_dir']
        NumChunks = Nt // N_dir
        remainder = Nt % N_dir
        chunk_sizes_expected = [N_dir] * NumChunks + ([remainder] if remainder != 0 else [])
        limit = NumChunks // 2
        chunk_sizes_expected = chunk_sizes_expected[:limit + 1]

        # ----- decide the split plan for the final Pred_list -----
        if sum(chunk_sizes_expected) == M:
            chunk_sizes = chunk_sizes_expected            # training case
        else:
            chunk_sizes = [M]                             # single test batch case

        # ----- parallelism across LPs -----
        max_workers = int(self.params.get('gb_workers', 0))
        if max_workers <= 0:
            max_workers = os.cpu_count() or 8

        # ----- choose the inner batch size (rows per LP) -----
        # 1) explicit override
        inner_batch_size = int(self.params.get('gb_inner_batch', 0))
        if inner_batch_size <= 0:
            # 2) auto size: aim for >= k tasks per worker to keep the pool busy
            k_tasks_per_worker = max(1, int(self.params.get('gb_tasks_per_worker', 4)))
            target_batches = max(1, max_workers * k_tasks_per_worker)
            inner_batch_size = int(np.ceil(M / target_batches))

            # 3) clamp to avoid too tiny/huge LPs
            min_b = int(self.params.get('gb_inner_batch_min', 4))
            max_b = int(self.params.get('gb_inner_batch_max', 64))
            if inner_batch_size < min_b:
                inner_batch_size = min_b
            if inner_batch_size > max_b:
                inner_batch_size = max_b

        # Optional: cap by N_dir (applies to the explicit override too).
        if bool(self.params.get('gb_cap_by_ndir', True)):
            inner_batch_size = min(inner_batch_size, max(1, N_dir // 2))

        # One LP task per slice V[start:end].
        tasks = [(start, min(start + inner_batch_size, M))
                 for start in range(0, M, inner_batch_size)]

        threads_per_model = int(self.params.get('gb_threads', 1))
        worker_params = {
            'gb_threads': threads_per_model,
            'gb_presolve': int(self.params.get('gb_presolve', 2)),
            'gb_method': int(self.params.get('gb_method', 1)),
            'gb_opt_tol': float(self.params.get('gb_opt_tol', 1e-9)),
            'gb_feas_tol': float(self.params.get('gb_feas_tol', 1e-9)),
        }

        # ---- Solve all LP-batches (parallel or serial) ----
        Yhat = np.empty_like(V)
        statuses = []
        # Never start more processes than there are tasks (0 tasks -> no pool).
        n_workers = min(max_workers, len(tasks))

        def _fallback(start, end, err):
            # Same best-effort policy on both paths: identity for the failed
            # slice, but report why instead of swallowing the error silently.
            print(f"[CLP] LP-batch rows {start}:{end} failed ({err!r}); "
                  f"using identity fallback.")
            return V[start:end, :], False

        if n_workers <= 1:
            for start, end in tasks:
                try:
                    Yhat_batch, ok = _solve_clp_batch_worker(H, V[start:end, :], worker_params)
                except Exception as err:
                    Yhat_batch, ok = _fallback(start, end, err)
                Yhat[start:end, :] = Yhat_batch
                statuses.append(ok)
        else:
            with ProcessPoolExecutor(max_workers=n_workers) as ex:
                futures = {
                    ex.submit(_solve_clp_batch_worker, H, V[start:end, :], worker_params): (start, end)
                    for start, end in tasks
                }
                for fut in as_completed(futures):
                    start, end = futures[fut]
                    try:
                        Yhat_batch, ok = fut.result()
                    except Exception as err:
                        Yhat_batch, ok = _fallback(start, end, err)
                    Yhat[start:end, :] = Yhat_batch
                    statuses.append(ok)

        n_ok = sum(1 for s in statuses if s)
        print(f"[CLP] Solved {n_ok}/{len(tasks)} LP-batches "
              f"(workers={max(n_workers, 1)}, threads/model={threads_per_model}, "
              f"batch_size={inner_batch_size}).")

        # Back to torch, split according to the plan chosen above.
        Yhat_t = torch.from_numpy(Yhat).to(device=device, dtype=dtype)
        Pred_list = []
        idx = 0
        for clen in chunk_sizes:
            Pred_list.append(Yhat_t[idx:idx + clen, :])
            idx += clen
        return Pred_list

    def Shape_residual(self):
        """
        Fit the surrogate on ``params['Nt']`` training samples and compute residual bounds.

        Samples are drawn in chunks of ``N_dir`` (plus a remainder) with seeds
        0, 1, ... In non-'Naive' mode, directions come from ``compute_directions``
        on the first chunk. The first ``Nt // N_dir // 2 + 1`` chunks are projected
        with ``CLP`` onto the hull ``CH`` formed by the remaining chunks.

        Returns:
            (res_max, CH, C, Directions, trn_time1, train_data_run_1,
             Direction_Training_time, CLP_time). ``res_max`` is floored at
            ``params['threshold_normal']``. ``CH`` and ``Directions`` are None in
            'Naive' mode.
        Raises:
            ValueError: in non-'Naive' mode, if no chunks are left to build ``CH``.
        """
        Nt = self.params['Nt']
        N_dir = self.params['N_dir']

        Num_chunks = Nt // N_dir
        remainder = Nt % N_dir
        chunk_sizes = [N_dir] * Num_chunks
        if remainder != 0:
            chunk_sizes.append(remainder)
        limit = Num_chunks // 2

        # CH is built from the chunks after index ``limit``. Fail early with a
        # clear message instead of an empty torch.cat after all the sampling.
        if self.mode != 'Naive' and len(chunk_sizes) <= limit + 1:
            raise ValueError(
                f"Nt={Nt} is too small for N_dir={N_dir}: non-'Naive' mode needs "
                f"more than {limit + 1} training chunks to build CH."
            )

        Y_dir, _, first_run = self.generate_data(N_dir, 0)

        if self.mode == 'Naive':
            Directions = None
            Direction_Training_time = 0
        else:
            Directions, Direction_Training_time = compute_directions(Y_dir, self.device, self.params['trn_batch'])
            Directions = torch.stack([d.squeeze(-1) for d in Directions])

        train_data_run_1_l = [first_run]
        C_l = [Y_dir.mean(dim=0)]
        YV_l = []
        if self.mode != 'Naive':
            YV_l.append(Y_dir @ Directions.T)
        Y_l = [Y_dir.cpu()]
        del Y_dir
        gc.collect()
        torch.cuda.empty_cache()

        for nc, curr_len in enumerate(chunk_sizes[1:], start=1):
            Y_dir, _, chunk_run = self.generate_data(curr_len, nc)
            train_data_run_1_l.append(chunk_run)
            C_l.append(Y_dir.mean(dim=0))

            if self.mode == 'Naive':
                Y_l.append(Y_dir.cpu())
            else:
                YV_l.append(Y_dir @ Directions.T)
                # Only the chunks that get CLP-projected are kept for residuals.
                if nc <= limit:
                    Y_l.append(Y_dir.cpu())

            del Y_dir
            gc.collect()
            torch.cuda.empty_cache()

        # NOTE: unweighted mean of per-chunk means (a smaller remainder chunk
        # counts as much as a full one), kept as in the original.
        C = torch.stack(C_l, dim=0).mean(dim=0)

        CLP_time = 0
        CH = None
        if self.mode != 'Naive':
            CV = C @ Directions.T
            dYV = torch.cat(YV_l[:limit + 1], dim=0) - CV.unsqueeze(0)
            CH = torch.cat(YV_l[limit + 1:], dim=0) - CV.unsqueeze(0)
            start_time = time()
            Pred = self.CLP(CH, dYV)
            CLP_time = time() - start_time
            chunk_sizes = chunk_sizes[:limit + 1]
            del dYV, CV, C_l, YV_l
            gc.collect()
            torch.cuda.empty_cache()

        train_data_run_1 = sum(train_data_run_1_l)

        trn_time1_l = []
        res_max_l = []
        for nc in range(len(chunk_sizes)):
            Y = Y_l[nc].to(self.device)
            t0 = time()
            with torch.no_grad():
                if self.mode == 'Naive':
                    approx_Y = C.unsqueeze(0)
                else:
                    approx_Y = Pred[nc] @ Directions + C.unsqueeze(0)  # same shape as Y
            residuals = (Y - approx_Y).abs()
            res_max_l.append(residuals.max(dim=0).values)
            trn_time1_l.append(time() - t0)
            del Y, approx_Y, residuals
            gc.collect()
            torch.cuda.empty_cache()

        del Y_l
        t0 = time()
        res_max = torch.stack(res_max_l, dim=0).max(dim=0).values
        tn = self.params['threshold_normal']
        res_max[res_max < tn] = tn
        trn_time1_l.append(time() - t0)
        trn_time1 = sum(trn_time1_l)

        return res_max, CH, C, Directions, trn_time1, train_data_run_1, Direction_Training_time, CLP_time

    def CI_surrogate(self, CH, C, res_max, Directions):
        """
        Calibrate the inflation factor on ``params['Ns']`` fresh test samples.

        Each sample's score is max_k |Y_k - pred_k| / res_max_k. ``R_star`` is
        the ``params['rank']``-th smallest score, and ``Conf = R_star * res_max``.
        Test seeds start at ceil(Nt / N_dir) + 1, beyond the training seeds
        used by ``Shape_residual``. Samples are processed in chunks of at most
        ``params['Nsp']``.

        Returns:
            (Conf, R_star, conformal_time, res_test_time, test_data_run)
        Raises:
            ValueError: if ``rank`` is not in [1, Ns].
        """
        Ns = self.params['Ns']
        Nsp = self.params['Nsp']
        Nt = self.params['Nt']
        N_dir = self.params['N_dir']
        ell = self.params['rank']

        # ``rank`` is a 1-based position in the sorted scores. rank <= 0 would
        # silently wrap to the end of the tensor through negative indexing.
        if not 1 <= ell <= Ns:
            raise ValueError(f"params['rank'] must satisfy 1 <= rank <= Ns (got rank={ell}, Ns={Ns}).")

        seed_loc = Nt // N_dir + (1 if Nt % N_dir else 0)

        thelen = min(Ns, Nsp)
        if Ns > thelen:
            chunk_size = thelen
            Num_chunks = Ns // chunk_size
            remainder = Ns % chunk_size
        else:
            chunk_size = Ns
            Num_chunks = 1
            remainder = 0

        chunk_sizes = [chunk_size] * Num_chunks
        if remainder != 0:
            chunk_sizes.append(remainder)

        Rs = torch.zeros(Ns, requires_grad=False)
        ind = 0
        test_data_run = []
        res_test_time = []

        for nc, curr_len in enumerate(chunk_sizes):
            Y_test, _, tst_run = self.generate_data(curr_len, seed_loc + nc + 1)
            test_data_run.append(tst_run)

            t1 = time()
            with torch.no_grad():
                if self.mode == 'Naive':
                    pred = C.unsqueeze(0)
                else:
                    mapped = (Y_test - C.unsqueeze(0)) @ Directions.T
                    # CLP splits its output into training-sized chunks whenever
                    # the row count matches the training split; re-join all of
                    # them instead of keeping only the first.
                    proj = torch.cat(self.CLP(CH, mapped), dim=0)
                    pred = proj @ Directions + C.unsqueeze(0)
                    del mapped, proj
                res_tst = (Y_test - pred).abs()
                vals = torch.max(res_tst / res_max.unsqueeze(0), dim=1).values
            res_test_time.append(time() - t1)
            del Y_test, pred, res_tst
            gc.collect()
            torch.cuda.empty_cache()

            Rs[ind:ind + curr_len] = vals
            ind += curr_len

        t0 = time()
        with torch.no_grad():
            R_star = torch.sort(Rs).values[ell - 1]
            Conf = R_star * res_max
        conformal_time = time() - t0

        return Conf, R_star, conformal_time, res_test_time, test_data_run

    def Provider(self):
        """
        Run ``Shape_residual`` and ``CI_surrogate``, then save the results.

        All tensors (including those nested one level deep in lists or dicts)
        are moved to the CPU before being written with ``torch.save`` to
        ``CI_provider.pt`` in the current working directory.
        """
        (res_max, CH, C, Directions, trn_time1, train_data_run_1,
         Direction_Training_time, CLP_time) = self.Shape_residual()

        Conf, R_star, conformal_time, res_test_time, test_data_run = self.CI_surrogate(CH, C, res_max, Directions)

        save_dict = {
            "Conf": Conf,
            "Directions": Directions,
            "C": C,
            "CH": CH,
            "R_star": R_star,
            "res_max": res_max,
            "train_data_run_1": train_data_run_1,
            "trn_time1": trn_time1,
            "test_data_run": test_data_run,
            "res_test_time": res_test_time,
            "conformal_time": conformal_time,
            "Direction_Training_time": Direction_Training_time,
            "CLP_time": CLP_time,
            "radii_mode": None,
            "Ns": self.params['Ns'],
            "Nsp": self.params['Nsp'],
            "rank": self.params['rank'],
            "Nt": self.params['Nt'],
            "N_dir": self.params['N_dir'],
            "threshold_normal": self.params['threshold_normal'],
            "perturbation": self.params['perturbation'],
            "True_class": self.params['True_class'],
            "class_threshold": self.params['class_threshold'],
            "image_name": self.params['image_name'],
            "de": self.de,
            "indices": self.indices,
            "original_dim": self.original_dim,
            "output_dim": self.output_dim,
            "mode": self.mode,
            "trn_batch": self.params['trn_batch'],
            "sim_batch": self.params['sim_batch'],
        }

        def _cpu(v):
            return v.cpu() if isinstance(v, torch.Tensor) else v

        for key, val in save_dict.items():
            if isinstance(val, list):
                save_dict[key] = [_cpu(v) for v in val]
            elif isinstance(val, dict):
                save_dict[key] = {k: _cpu(v) for k, v in val.items()}
            else:
                save_dict[key] = _cpu(val)

        save_name = 'CI_provider.pt'
        torch.save(save_dict, save_name)
    
        
    
    
    
    
    
    