import torch
import numpy as np
from time import time
import os
import sys
from scipy.stats import beta
import matlab.engine
import scipy.io
import pathlib



class Segmentor:
    """Turn conformal reachability bounds into per-pixel robustness verdicts.

    ``Mask_titles`` reads the ``CI_provider.pt`` checkpoint from the current
    working directory. That checkpoint presumably comes from an earlier
    pipeline stage, since it carries training/conformal timings. The method
    then:

    1. maps the stored bounds back to logit space;
    2. decides the set of possible classes for every output pixel;
    3. counts robust / non-robust / unknown pixels;
    4. saves a ``CI_result_*.pt`` summary.
    """

    def __init__(self, device, src_dir, nnv_dir, params):
        self.device = device
        # src_dir / nnv_dir are stored here but not read by any method below.
        self.src_dir = src_dir
        self.nnv_dir = nnv_dir
        # Keys read by this class:
        #   'guarantee'          (required)
        #   'projection_batch'   (required)
        #   'mapback_row_batch'  (optional, default 4096)
        self.params = params

    @staticmethod
    def _to_pixels(flat, output_dim):
        """Reshape NCHW-ordered logit bounds into a ``[n_batch*H*W, n_class]`` table.

        Row ``t`` of the result is pixel ``(i, j)`` with ``t = i * W + j``
        (for the first batch element).
        """
        n_batch, n_class, out_height, out_width = tuple(output_dim)[:4]
        nhwc = flat.reshape(n_batch, n_class, out_height, out_width).permute(0, 2, 3, 1)
        return nhwc.reshape(-1, n_class)

    @torch.no_grad()
    def Verify_with_surrogate(self, CH, Conf, Directions, C, output_dim):
        """Map-back approach, but memory-safe.

        For row chunks of CH:
            mapped = CH_chunk @ Directions   # [rows, D]
            update CHmax/CHmin over rows
        Then:
            Lb = CHmin + C - Conf
            Ub = CHmax + C + Conf
        No [M, D] tensor is ever materialized.

        Args:
            CH: ``[M, N_dir]`` tensor of points to map back.
            Conf: length-``D`` tensor, added to / subtracted from the bounds.
            Directions: ``[N_dir, D]`` tensor.
            C: length-``D`` offset tensor.
            output_dim: ``(n_batch, n_class, H, W)``; ``D`` must equal
                their product.

        Returns:
            ``(Lb_pixels, Ub_pixels, projection_time)``. The bounds are CPU
            tensors of shape ``[n_batch*H*W, n_class]``; the time is in
            seconds.

        Raises:
            AssertionError: on mismatched shapes.
            ValueError: if CH has no rows, or 'mapback_row_batch' is not
                positive. Either case would otherwise yield inverted
                (+inf/-inf) bounds.
        """
        proj_start = time()

        M, N_dir = CH.shape
        N_dir_d, D = Directions.shape
        assert N_dir == N_dir_d, "CH is [M, N_dir] and Directions is [N_dir, D]"
        assert Conf.shape[0] == D and C.shape[0] == D, "C and Conf must be length D"
        if M == 0:
            raise ValueError("CH must contain at least one row")

        # Tuning knob: rows of CH projected per chunk (bounds peak memory).
        row_bs = int(self.params.get('mapback_row_batch', 4096))
        if row_bs <= 0:
            raise ValueError(f"mapback_row_batch must be positive, got {row_bs}")

        # Running per-logit extrema, kept on the same device as CH.
        CHmax = torch.full((D,), -torch.inf, device=CH.device)
        CHmin = torch.full((D,), torch.inf, device=CH.device)

        for r0 in range(0, M, row_bs):
            mapped = CH[r0:r0 + row_bs] @ Directions          # [rows, D]
            CHmax = torch.maximum(CHmax, mapped.max(dim=0).values)
            CHmin = torch.minimum(CHmin, mapped.min(dim=0).values)
            del mapped  # release the chunk before projecting the next one

        Lb = (CHmin + C - Conf).cpu()
        Ub = (CHmax + C + Conf).cpu()

        projection_time = time() - proj_start

        return self._to_pixels(Lb, output_dim), self._to_pixels(Ub, output_dim), projection_time

    def Verify_with_Naive(self, Conf, C, output_dim):
        """Bound every logit as ``C ± Conf``, with no map-back projection.

        Returns:
            ``(Lb_pixels, Ub_pixels, projection_time)``, with the bounds
            shaped ``[n_batch*H*W, n_class]`` and left on the inputs' device.
        """
        proj_start = time()
        Lb = C.unsqueeze(1) - Conf.unsqueeze(1)
        Ub = C.unsqueeze(1) + Conf.unsqueeze(1)
        Lb_pixels = self._to_pixels(Lb, output_dim)
        Ub_pixels = self._to_pixels(Ub, output_dim)
        projection_time = time() - proj_start
        return Lb_pixels, Ub_pixels, projection_time

    @staticmethod
    def _pixel_classes(Lb_pixels, Ub_pixels, outHeight, outWidth, mask_dim, class_threshold):
        """Return an ``outHeight x outWidth`` grid of candidate-class lists.

        ``classes[i][j]`` lists the class indices that pixel ``(i, j)`` may
        take, given the logit bounds.
        """
        n_pix = outHeight * outWidth
        if mask_dim == 1:
            # Some SSNs with two classes have one-dimensional logits where
            # the class is found using a threshold on this logit.
            # One host transfer instead of a per-pixel .item() call.
            lbs = Lb_pixels.reshape(-1)[:n_pix].tolist()
            ubs = Ub_pixels.reshape(-1)[:n_pix].tolist()
            flat = []
            for lb, ub in zip(lbs, ubs):
                if lb > class_threshold:
                    flat.append([1])
                elif ub <= class_threshold:
                    flat.append([0])
                else:
                    flat.append([0, 1])
        else:
            # Class k is excluded only when some class's lower bound strictly
            # exceeds k's upper bound, i.e. max_j Lb_j > Ub_k.
            Lb_max = Lb_pixels.max(dim=1, keepdim=True).values
            mask_np = (Lb_max <= Ub_pixels)[:n_pix].cpu().numpy()
            flat = [np.flatnonzero(row).tolist() for row in mask_np]
        # Row-major split: flat index t maps to (t // outWidth, t % outWidth).
        return [flat[i * outWidth:(i + 1) * outWidth] for i in range(outHeight)]

    @staticmethod
    def _count_pixel_status(classes, True_class, outHeight, outWidth):
        """Count ``(robust, nonrobust, unknown)`` pixels against ``True_class``.

        The categories are:

        - robust: the only candidate class is the true class.
        - unknown: several candidates, including the true class.
        - nonrobust: everything else.
        """
        robust = nonrobust = unknown = 0
        for i in range(outHeight):
            for j in range(outWidth):
                members = classes[i][j]
                truth = True_class[i][j]
                if len(members) == 1:
                    if members == [truth]:
                        robust += 1
                    else:
                        nonrobust += 1
                elif truth in members:
                    unknown += 1
                else:
                    nonrobust += 1
        return robust, nonrobust, unknown

    def Mask_titles(self):
        """Run the final verification step and save the results.

        The method:

        1. Loads ``CI_provider.pt`` from the current working directory.
        2. Computes per-pixel logit bounds, using the naive or surrogate
           projection according to ``Data['mode']``.
        3. Classifies every pixel and prints summary statistics.
        4. Saves ``CI_result_<mode>_eps_..._<image>.pt`` into the current
           directory.
        5. Deletes the input checkpoint.
        """
        current_dir = os.getcwd()
        file_path = os.path.join(current_dir, 'CI_provider.pt')
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints produced by a trusted run of this pipeline.
        Data = torch.load(file_path, weights_only=False)

        N_perturbed = 'ALL' if Data["radii_mode"] else len(Data['indices'])

        Ns = Data['Ns']
        guarantee = self.params['guarantee']
        ell = Data['rank']
        Failure_chance_of_guarantee = beta.cdf(guarantee, ell, Ns + 1 - ell)
        mode = Data['mode']
        output_dim = Data['output_dim']

        Conf = Data['Conf'].to(self.device)
        C = Data['C'].to(self.device)

        if mode == 'Naive':
            Lb_pixels, Ub_pixels, projection_time = self.Verify_with_Naive(Conf, C, output_dim)
        else:
            CH = Data['CH'].to(self.device)
            Directions = Data['Directions'].to(self.device)
            print('Reachability started...')
            Lb_pixels, Ub_pixels, projection_time = self.Verify_with_surrogate(CH, Conf, Directions, C, output_dim)
            print('Reachability is finished and projection is done!!')

        start_time = time()

        mask_dim = output_dim[1]
        outHeight = output_dim[2]
        outWidth = output_dim[3]

        classes = self._pixel_classes(Lb_pixels, Ub_pixels, outHeight, outWidth,
                                      mask_dim, Data['class_threshold'])

        attacked = outHeight * outWidth if N_perturbed == 'ALL' else N_perturbed
        True_class = Data['True_class']
        robust, nonrobust, unknown = self._count_pixel_status(classes, True_class, outHeight, outWidth)

        Pixel_status_time = time() - start_time

        # Robustness verification percentage over all output pixels.
        dim_pic = outHeight * outWidth
        RV = 100 * robust / dim_pic

        print(f"Number of Robust pixels: {robust}")
        print(f"Number of non-Robust pixels: {nonrobust}")
        print(f"Number of unknown pixels: {unknown}")
        print(f"RV value: {RV}")

        print(f"Pr[ Pr[ RV value = {RV} ] > {guarantee}%] > {1 - Failure_chance_of_guarantee}")

        verification_runtime = Data['train_data_run_1'] + Data['trn_time1'] + sum(Data['test_data_run']) + \
                               sum(Data['res_test_time']) + Data['conformal_time'] + \
                               projection_time + Data['Direction_Training_time'] + Data['CLP_time'] + \
                               Pixel_status_time

        print(f"The verification runtime is: {verification_runtime / 60:.2f} minutes.")

        save_dict = {
            "robust": robust,
            "nonrobust": nonrobust,
            "attacked": attacked,
            "unknown": unknown,
            "True_class": True_class,
            "classes": classes,
            "class_threshold": Data['class_threshold'],
            "image_name": Data['image_name'],
            "Conf": Conf,
            "Nt": Data['Nt'],
            "N_dir": Data['N_dir'],
            "de": Data['de'],
            "ell": ell,
            "Lb_pixels": Lb_pixels,
            "Ub_pixels": Ub_pixels,
            "Ns": Ns,
            "Nsp": Data['Nsp'],
            "R_star": Data['R_star'],
            "res_max": Data['res_max'],
            "RV": RV,
            "guarantee": guarantee,
            "verification_runtime": verification_runtime,
            "threshold_normal": Data['threshold_normal'],
            "train_data_run_1": Data['train_data_run_1'],
            "trn_time1": Data['trn_time1'],
            "test_data_run": Data['test_data_run'],
            "res_test_time": Data['res_test_time'],
            "conformal_time": Data['conformal_time'],
            "projection_time": projection_time,
            "Direction_Training_time": Data['Direction_Training_time'],
            "CLP_time": Data['CLP_time'],
            "Pixel_status_time": Pixel_status_time,
            "projection_batch": self.params['projection_batch'],
            "trn_batch": Data['trn_batch'],
            "sim_batch": Data['sim_batch'],
            "perturbation": Data['perturbation'],
            "device": self.device,
            "mode": mode,
            "original_dim": Data['original_dim'],
            "output_dim": output_dim,
            "indices": Data['indices'],
        }

        # Move tensors (top level, or one level inside lists/dicts) to the CPU
        # so the saved file loads on machines without the original device.
        for key, val in save_dict.items():
            if isinstance(val, torch.Tensor):
                save_dict[key] = val.cpu()
            elif isinstance(val, list):
                save_dict[key] = [v.cpu() if isinstance(v, torch.Tensor) else v for v in val]
            elif isinstance(val, dict):
                save_dict[key] = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in val.items()}

        base_name = os.path.splitext(Data['image_name'])[0]
        tag = 'Naive' if mode == 'Naive' else 'CLP'
        save_name = f"CI_result_{tag}_eps_{Data['perturbation']}_Npertubed_{N_perturbed}_{base_name}.pt"

        torch.save(save_dict, save_name)
        # The provider checkpoint is consumed: remove it once results are saved.
        os.remove(file_path)

        print('All the details are saved')