import os
from typing import Optional, Tuple, List, Union, Callable
from tqdm import tqdm

import math
import argparse

import time
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from tqdm import trange
import cv2 
import onnx
import itertools
from collections import defaultdict

from cifar10_resnet.resnet import resnet2b, resnet4b

# All tensors created without an explicit dtype are double precision.
torch.set_default_dtype(torch.float64)
# Reseed torch's default generator with a non-deterministic seed.
torch.seed()

# Per-channel CIFAR-10 statistics, computed from the training set as
#   mean: np.mean(train_set.train_data, axis=(0,1,2))/255
#   std:  np.std(train_set.train_data, axis=(0,1,2))/255
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
mu = torch.tensor(cifar10_mean)
std = torch.tensor(cifar10_std)


def normalize(X):
    """Standardise ``X`` channel-wise with the CIFAR-10 mean and std.

    ``mu`` and ``std`` have shape (3,), so they broadcast against the LAST
    dimension of ``X``: the channel axis must be the trailing one.
    """
    centered = X - mu
    return centered / std

# import warnings

# warnings.filterwarnings("ignore")
# os.environ["PYTHONWARNINGS"] = "ignore"


#from torch.profiler import profile, record_function, ProfilerActivity

from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm, PerturbationLinear

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Absolute directory containing this source file, with symlinks resolved.
script_dir = os.path.split(os.path.realpath(__file__))[0]





class ResNetModel(nn.Module):
    """Wrap a CIFAR-10 ResNet so it consumes flat, un-normalised pixel vectors.

    The input is standardised with the module-level CIFAR-10 ``mu``/``std``
    (expanded to one value per flattened pixel, channel-major: all values of
    channel 0, then channel 1, then channel 2) and reshaped to
    ``(batch, 3, tile_size, tile_size)`` before being fed to the network.
    """

    def __init__(self, model_name: str, tile_size: int = 32):
        super().__init__()
        builders = {"resnet2b": resnet2b, "resnet4b": resnet4b}
        if model_name not in builders:
            raise ValueError("Invalid model name. Choose 'resnet2b' or 'resnet4b'.")
        self.model = builders[model_name]()
        self.tile_size = tile_size
        self.model.eval()

        # Each channel statistic is repeated once per pixel of that channel so
        # it lines up with the flattened (channel-major) input.
        pixels_per_channel = tile_size ** 2
        self.register_buffer("mu", mu.repeat_interleave(pixels_per_channel))
        self.register_buffer("std", std.repeat_interleave(pixels_per_channel))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, reshape to image tiles and run the wrapped ResNet."""
        standardized = (x - self.mu) / self.std
        side = self.tile_size
        images = standardized.reshape(x.shape[0], 3, side, side)
        return self.model(images)


class NeRF(nn.Module):
    r"""
    Neural radiance fields module.

    An MLP of ``n_layers`` ReLU-activated linear layers. After every layer whose
    index is listed in ``skip``, the original input is concatenated back onto
    the activations, so the next layer takes ``d_filter + d_input`` features.

    Output head:
      * ``d_viewdirs is None``: one linear layer producing 4 channels.
      * otherwise: ``alpha_out`` produces one channel from the trunk features;
        a branch that also sees ``viewdirs`` produces 3 channels; the result is
        ``[branch_output (3), alpha (1)]`` along the last axis.
    """

    def __init__(
        self,
        d_input: int = 3,
        n_layers: int = 8,
        d_filter: int = 256,
        skip: Tuple[int, ...] = (4,),
        d_viewdirs: Optional[int] = None,
    ):
        super().__init__()
        self.d_input = d_input
        self.skip = skip
        self.act = nn.functional.relu
        self.d_viewdirs = d_viewdirs

        # layers[0] maps the input to d_filter. layers[i + 1] follows layers[i]
        # and is widened by d_input when layers[i] is a skip layer (forward()
        # concatenates the input after layers[i] in that case).
        self.layers = nn.ModuleList(
            [nn.Linear(self.d_input, d_filter)]
            + [
                (
                    nn.Linear(d_filter + self.d_input, d_filter)
                    if i in skip
                    else nn.Linear(d_filter, d_filter)
                )
                for i in range(n_layers - 1)
            ]
        )

        # Bottleneck layers
        if self.d_viewdirs is not None:
            # If using viewdirs, split alpha and RGB
            self.alpha_out = nn.Linear(d_filter, 1)
            self.rgb_filters = nn.Linear(d_filter, d_filter)
            self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
            self.output = nn.Linear(d_filter // 2, 3)
        else:
            # If no viewdirs, use simpler output
            self.output = nn.Linear(d_filter, 4)

    def forward(
        self, x: torch.Tensor, viewdirs: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""
        Forward pass with optional view direction.

        Args:
            x: encoded sample features; last dimension must be ``d_input``.
            viewdirs: encoded view directions (last dimension ``d_viewdirs``).
                Required iff the model was built with ``d_viewdirs``.

        Returns:
            Tensor with the same leading shape as ``x`` and 4 trailing channels.

        Raises:
            ValueError: if ``viewdirs`` is given without ``d_viewdirs`` or is
                missing when ``d_viewdirs`` was set.
        """

        # Cannot use viewdirs if instantiated with d_viewdirs = None
        if self.d_viewdirs is None and viewdirs is not None:
            raise ValueError("Cannot input x_direction if d_viewdirs was not given.")
        # The view-direction branch concatenates viewdirs; fail clearly instead
        # of deep inside torch.concat.
        if self.d_viewdirs is not None and viewdirs is None:
            raise ValueError("viewdirs must be given when d_viewdirs was set.")

        # Apply forward pass up to bottleneck
        x_input = x
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x))
            if i in self.skip:
                x = torch.cat([x, x_input], dim=-1)

        # Apply bottleneck
        if self.d_viewdirs is not None:
            # Split alpha from network output
            alpha = self.alpha_out(x)

            # Pass through bottleneck to get RGB
            x = self.rgb_filters(x)
            x = torch.concat([x, viewdirs], dim=-1)
            x = self.act(self.branch(x))
            x = self.output(x)

            # Concatenate alphas to output
            x = torch.concat([x, alpha], dim=-1)
        else:
            # Simple output
            x = self.output(x)
        return x

class PositionalEncoderEnv(nn.Module):
    r"""
    Sine-cosine positional encoder for environmental parameters (e.g. hue and
    saturation).

    Maps ``x`` of shape (..., d_input) to (..., d_output) with
    ``d_output = d_input * (1 + 2 * n_freqs)``, laid out as
    ``[x, sin(f0*x), cos(f0*x), sin(f1*x), cos(f1*x), ...]``.
    """

    def __init__(self, d_input: int, n_freqs: int, log_space: bool = False):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.log_space = log_space
        # Identity term plus one sin and one cos term per frequency.
        self.d_output = d_input * (1 + 2 * n_freqs)

        if log_space:
            # Powers of two: 2^0, 2^1, ..., 2^(n_freqs-1).
            freq_bands = 2.0 ** torch.linspace(0.0, n_freqs - 1, n_freqs)
        else:
            # n_freqs evenly spaced values from 1 to 2^(n_freqs-1).
            freq_bands = torch.linspace(1.0, 2.0 ** (n_freqs - 1), n_freqs)
        self.register_buffer("freq_bands", freq_bands)

        # Per-term callables in output order (kept for callers that use them;
        # forward() computes the same result in vectorised form). The default
        # argument binds each frequency at definition time.
        self.embed_fns = [lambda x: x]
        for freq in freq_bands:
            self.embed_fns.extend(
                (
                    lambda x, freq=freq: torch.sin(x * freq),
                    lambda x, freq=freq: torch.cos(x * freq),
                )
            )

    def forward(self, x) -> torch.Tensor:
        r"""
        Apply positional encoding to input.
        """
        angles = x[..., None] * self.freq_bands  # (..., d, F)
        # (..., d, F, 2): sin and cos side by side for each channel/frequency.
        sincos = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        per_channel = sincos.reshape(*angles.shape[:-1], -1)  # (..., d, 2F)
        # Move the term axis in front of the channel axis so that flattening
        # yields sin(f0*x), cos(f0*x), sin(f1*x), ... each spanning all d channels.
        encoded = per_channel.transpose(-1, -2).reshape(*x.shape[:-1], -1)
        return torch.cat([x, encoded], dim=-1)


class PositionalEncoder(nn.Module):
    r"""
    Sine-cosine positional encoder for input points.

    Maps ``x`` of shape (..., d_input) to (..., d_output) with
    ``d_output = d_input * (1 + 2 * n_freqs)``, laid out as
    ``[x, sin(f0*x), cos(f0*x), sin(f1*x), cos(f1*x), ...]``.
    """

    def __init__(self, d_input: int, n_freqs: int, log_space: bool = False):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.log_space = log_space
        # Identity term plus one sin and one cos term per frequency.
        self.d_output = d_input * (1 + 2 * n_freqs)

        if log_space:
            # Powers of two: 2^0, 2^1, ..., 2^(n_freqs-1).
            freq_bands = 2.0 ** torch.linspace(0.0, n_freqs - 1, n_freqs)
        else:
            # n_freqs evenly spaced values from 1 to 2^(n_freqs-1).
            freq_bands = torch.linspace(1.0, 2.0 ** (n_freqs - 1), n_freqs)
        # Registered as a buffer so it follows the module across .to(device).
        self.register_buffer("freq_bands", freq_bands)

        # Per-term callables in output order (kept for callers that use them;
        # forward() computes the same result in vectorised form). The default
        # argument binds each frequency at definition time.
        self.embed_fns = [lambda x: x]
        for freq in freq_bands:
            self.embed_fns.extend(
                (
                    lambda x, freq=freq: torch.sin(x * freq),
                    lambda x, freq=freq: torch.cos(x * freq),
                )
            )

    def forward(self, x) -> torch.Tensor:
        r"""
        Apply positional encoding to input.
        """
        angles = x[..., None] * self.freq_bands  # (..., d, F)
        # (..., d, F, 2): sin and cos side by side for each channel/frequency.
        sincos = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        per_channel = sincos.reshape(*angles.shape[:-1], -1)  # (..., d, 2F)
        # Move the term axis in front of the channel axis so that flattening
        # yields sin(f0*x), cos(f0*x), sin(f1*x), ... each spanning all d channels.
        encoded = per_channel.transpose(-1, -2).reshape(*x.shape[:-1], -1)
        return torch.cat([x, encoded], dim=-1)


class RenderModel(nn.Module):
    r"""
    Differentiable NeRF renderer mapping a (partial) camera pose / environment
    input to the composited RGB colour of each pixel in a rectangular window.

    ``input_type`` selects which of x, y, z, roll, pitch, yaw, hue, satur are
    read from the forward input; every other value is held at the constants
    passed as ``xyzrpy`` (x, y, z, roll, pitch, yaw) and ``env`` (hue, satur).
    """

    # input_type -> channel names read from successive trailing columns of the
    # forward input (column k feeds the k-th name via ``input[..., k:k+1]``).
    _VECTOR_INPUT_TYPES = {
        "ry": ("roll", "yaw"),
        "xyz": ("x", "y", "z"),
        "rot": ("x", "z", "pitch"),
        "xyzry": ("x", "y", "z", "roll", "yaw"),
        "env": ("hue", "satur"),
    }
    # input_types whose forward input is used as-is for the channel of that name.
    _SCALAR_INPUT_TYPES = ("x", "y", "z", "roll", "yaw", "hue", "satur")
    # Column order of the pose and environment tensors built from the input.
    _POSE_CHANNELS = ("x", "y", "z", "roll", "pitch", "yaw")
    _ENV_CHANNELS = ("hue", "satur")

    def __init__(
        self,
        input_type,
        total_height,
        total_width,
        start_height,
        start_width,
        end_height,
        end_width,
        focal_x,
        focal_y,
        xyzrpy,
        env,
        near,
        far,
        distance_to_infinity,
        n_samples,
        perturb,
        inverse_depth,
        kwargs_sample_stratified,
        n_samples_hierarchical,
        kwargs_sample_hierarchical,
        chunksize,
        encode,
        encode_env,
        encode_viewdirs,
        coarse_model,
        fine_model,
        raw_noise_std=0.0,
        print_flag=False,
    ):
        super().__init__()
        self.input_type = input_type
        # Full image size and the [start, end) pixel window that is rendered.
        self.total_height = total_height
        self.total_width = total_width
        self.start_height = start_height
        self.end_width = end_width
        self.end_height = end_height
        self.start_width = start_width
        self.focal_x = focal_x
        self.focal_y = focal_y

        self.init_xyzrpy(xyzrpy)
        self.init_env(env)
        self.near, self.far = near, far
        self.distance_to_infinity = distance_to_infinity
        self.n_samples = n_samples
        self.perturb = perturb
        self.inverse_depth = inverse_depth
        self.kwargs_sample_stratified = {} if kwargs_sample_stratified is None else kwargs_sample_stratified
        self.n_samples_hierarchical = n_samples_hierarchical
        self.kwargs_sample_hierarchical = {} if kwargs_sample_hierarchical is None else kwargs_sample_hierarchical
        self.chunksize = chunksize
        # Stratified-sampling jitter drawn once here, so every forward pass
        # with perturb=True reuses the same offsets.
        self.t_rand = torch.rand([n_samples])

        self.encode = encode
        self.encode_env = encode_env
        self.encode_viewdirs = encode_viewdirs
        self.model = coarse_model
        self.fine_model = fine_model

        self.raw_noise_std = raw_noise_std
        self._resample_noise()

        self.print_flag = print_flag

    def _resample_noise(self):
        """Redraw ``noise_rand`` for the current pixel window (None if the window is unset)."""
        bounds = (self.start_height, self.end_height, self.start_width, self.end_width)
        if any(bound is None for bound in bounds):
            self.noise_rand = None
        else:
            n_pixels = (self.end_height - self.start_height) * (self.end_width - self.start_width)
            self.noise_rand = torch.randn(n_pixels, self.n_samples) * self.raw_noise_std

    def update_height_and_width(self, start_height, end_height, start_width, end_width):
        """Move the rendered pixel window and redraw ``noise_rand`` to match its size."""
        self.start_height = start_height
        self.end_height = end_height
        self.start_width = start_width
        self.end_width = end_width
        # Uses self.n_samples / self.raw_noise_std (previously referenced as
        # undefined bare names, which raised NameError).
        self._resample_noise()

    def get_extrinsic_matrix(self, xyzrpy):
        """Build (N, 4, 4) camera-to-world matrices from (N, 6) [x, y, z, roll, pitch, yaw].

        The rotation block is R = Rz(yaw) @ Ry(pitch) @ Rx(roll); the last
        column is the translation (x, y, z) and the last row is [0, 0, 0, 1].
        """
        x = xyzrpy[:, 0:1]
        y = xyzrpy[:, 1:2]
        z = xyzrpy[:, 2:3]
        gamma = xyzrpy[:, 3:4]  # roll
        beta = xyzrpy[:, 4:5]  # pitch
        alpha = xyzrpy[:, 5:6]  # yaw

        R00 = torch.cos(alpha) * torch.cos(beta)
        R01 = torch.cos(alpha) * torch.sin(beta) * torch.sin(gamma) - torch.sin(alpha) * torch.cos(gamma)
        R02 = torch.cos(alpha) * torch.sin(beta) * torch.cos(gamma) + torch.sin(alpha) * torch.sin(gamma)

        R10 = torch.sin(alpha) * torch.cos(beta)
        R11 = torch.sin(alpha) * torch.sin(beta) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
        R12 = torch.sin(alpha) * torch.sin(beta) * torch.cos(gamma) - torch.cos(alpha) * torch.sin(gamma)

        R20 = -torch.sin(beta)
        R21 = torch.cos(beta) * torch.sin(gamma)
        R22 = torch.cos(beta) * torch.cos(gamma)

        # Each row is (N, 1, 4); stacking the four rows gives (N, 4, 4).
        R_row0 = torch.cat([R00, R01, R02, x], dim=1).unsqueeze(1)
        R_row1 = torch.cat([R10, R11, R12, y], dim=1).unsqueeze(1)
        R_row2 = torch.cat([R20, R21, R22, z], dim=1).unsqueeze(1)
        R_row3 = torch.cat([torch.zeros_like(x), torch.zeros_like(x), torch.zeros_like(x), torch.ones_like(x)], dim=1).unsqueeze(1)

        extrinsic_matrices = torch.cat([R_row0, R_row1, R_row2, R_row3], dim=1)

        return extrinsic_matrices

    def init_xyzrpy(self, xyzrpy):
        """Store the fixed pose and derive the orbit radius/angle in the x-y plane."""
        self.x = float(xyzrpy[0])
        self.y = float(xyzrpy[1])
        self.z = float(xyzrpy[2])
        self.roll = float(xyzrpy[3])
        self.pitch = float(xyzrpy[4])
        self.yaw = float(xyzrpy[5])

        self.dist_to_object = float(np.linalg.norm(xyzrpy[:2], ord=2))
        self.initial_angle = float(np.arctan2(xyzrpy[1], xyzrpy[0]))
        self.offset_yaw = float(xyzrpy[5])
        self.current_angle = self.initial_angle

    def init_env(self, env):
        """Store the fixed environment values [hue, satur]."""
        self.hue = float(env[0])
        self.satur = float(env[1])

    def update_angle(self, angle):
        """Set the orbit angle (relative to the initial angle) used for yaw."""
        self.current_angle = float(angle + self.initial_angle)

    def generate_camera_positions_around_object_torch(self, angle):
        """Return [x, y, z, roll, pitch, yaw] for a camera orbiting the origin.

        NOTE(review): x and y follow ``angle`` but yaw is taken from
        ``self.current_angle`` (set by update_angle), so the two only agree if
        update_angle was called with the same angle — confirm this is intended.
        """
        z = self.z * torch.ones_like(angle).to(angle.device)
        current_angle = self.current_angle

        # Calculate the new x and y coordinates based on the updated angle
        x = self.dist_to_object * torch.cos(angle + self.initial_angle).to(angle.device)
        y = self.dist_to_object * torch.sin(angle + self.initial_angle).to(angle.device)

        yaw = current_angle - self.initial_angle + self.offset_yaw
        yaw = yaw * torch.ones_like(angle).to(angle.device)
        pitch = torch.zeros_like(angle).to(angle.device)  # Keeping pitch at 0
        roll = self.roll * torch.ones_like(angle).to(angle.device)  # Roll is constant

        positions_xyzrpy = torch.cat([x, y, z, roll, pitch, yaw], dim=-1)
        return positions_xyzrpy

    def generate_xyzrpy_torch(self, input):
        """Split the forward input into a pose tensor and an environment tensor.

        Channels selected by ``input_type`` come from ``input``; all others are
        the stored constants broadcast to the input's leading shape.

        Returns:
            (positions_xyzrpy, env): pose with columns [x, y, z, roll, pitch, yaw]
            and environment with columns [hue, satur].

        Raises:
            ValueError: if ``input_type`` is not a supported value.
        """
        if self.input_type in self._VECTOR_INPUT_TYPES:
            reference = input[..., 0:1]
            varying = {
                name: input[..., k:k + 1]
                for k, name in enumerate(self._VECTOR_INPUT_TYPES[self.input_type])
            }
        elif self.input_type in self._SCALAR_INPUT_TYPES:
            reference = input
            varying = {self.input_type: input}
        else:
            raise ValueError(f"Unsupported input_type: {self.input_type!r}")

        def channel(name):
            # Either the slice of the input driving this channel, or the stored
            # constant (attribute of the same name) shaped like the reference.
            if name in varying:
                return varying[name]
            return getattr(self, name) * torch.ones_like(reference)

        positions_xyzrpy = torch.cat([channel(name) for name in self._POSE_CHANNELS], dim=-1)
        env = torch.cat([channel(name) for name in self._ENV_CHANNELS], dim=-1)
        return positions_xyzrpy, env

    def get_rays(self, c2w: torch.Tensor, directions: torch.Tensor):
        """Return ray origins (translation of c2w) and directions rotated by c2w's 3x3 block."""
        # Broadcast (P, 1, 3) directions against (P, 3, 3) rotations and sum the
        # last axis: rays_d[p] = R[p] @ directions[p].
        rays_d = torch.sum(directions * c2w[..., :3, :3], dim=-1)
        rays_o = c2w[..., :3, -1]

        return rays_o, rays_d

    def get_directions(self):
        r"""
        Camera-frame ray directions for every pixel of the current window.

        Pinhole model: ((i - W/2)/fx, -(j - H/2)/fy, -1), i.e. the camera looks
        along -z with image rows increasing downward. Returns shape (P, 1, 3)
        with P = window height * window width.
        """
        total_height = self.total_height
        total_width = self.total_width
        start_height = self.start_height
        start_width = self.start_width
        end_height = self.end_height
        end_width = self.end_width
        # Assumes focal_x / focal_y are tensors (.to is called on them).
        focal_x_length = self.focal_x.to(device)
        focal_y_length = self.focal_y.to(device)
        # NOTE(review): the pixel grid is float32 regardless of the default dtype.
        i, j = torch.meshgrid(
            torch.arange(start=start_width, end=end_width, dtype=torch.float32).to(device),
            torch.arange(start=start_height, end=end_height, dtype=torch.float32).to(device),
            indexing="ij",
        )

        i, j = i.transpose(-1, -2), j.transpose(-1, -2)
        directions = torch.stack(
            [
                (i - total_width * 0.5) / focal_x_length,
                -(j - total_height * 0.5) / focal_y_length,
                -torch.ones_like(i),
            ],
            dim=-1,
        )

        directions = directions.reshape([(end_height - start_height) * (end_width - start_width), 3])
        return directions[..., None, :]

    def sample_stratified(
        self,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Sample along ray from regularly-spaced bins.

        Returns ``(pts, z_vals, dists_vals)``: sample points (P, n_samples, 3),
        their ray parameters (P, n_samples), and the spacing between
        consecutive samples (P, n_samples), with ``distance_to_infinity`` for
        the last sample.
        """

        near, far = self.near, self.far
        distance_to_infinity = self.distance_to_infinity
        n_samples = self.n_samples
        perturb = self.perturb
        inverse_depth = self.inverse_depth

        # Grab samples for space integration along ray
        t_vals = torch.linspace(0.0, 1.0, n_samples, device=rays_o.device)
        if not inverse_depth:
            # Sample linearly between `near` and `far`
            z_vals = near * (1.0 - t_vals) + far * (t_vals)
        else:
            # Sample linearly in inverse depth (disparity)
            z_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * (t_vals))

        # Draw uniform samples from bins along ray (fixed jitter, see __init__)
        if perturb:
            mids = 0.5 * (z_vals[1:] + z_vals[:-1])
            upper = torch.concat([mids, z_vals[-1:]], dim=-1)
            lower = torch.concat([z_vals[:1], mids], dim=-1)
            t_rand = self.t_rand.to(z_vals.device)
            z_vals = lower + (upper - lower) * t_rand

        dists_vals = z_vals[..., 1:] - z_vals[..., :-1]
        dists_vals = torch.cat([dists_vals, distance_to_infinity * torch.ones_like(dists_vals[..., :1])], dim=-1)
        dists_vals = dists_vals.repeat(*rays_o.shape[:-1], 1)

        z_vals = z_vals.repeat(*rays_o.shape[:-1], 1)

        # Apply scale from `rays_d` and offset from `rays_o` to samples
        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

        return pts, z_vals, dists_vals

    def two_norm(self, inputs: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
        """Euclidean norm along ``dim`` written as sqrt(sum(x**2)).

        Presumably spelled out (instead of torch.norm) to keep the traced graph
        to elementary ops — confirm before replacing.
        """
        squared = torch.square(inputs)
        summed = torch.sum(squared, dim=dim, keepdim=keepdim)
        norm_manual = torch.sqrt(summed)
        return norm_manual

    def get_chunks(self, inputs: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Split ``inputs`` along dim 1 (the sample axis) into chunks of ``chunksize``.
        """
        chunksize = self.chunksize
        n_samples = self.n_samples

        return [inputs[:, i: i + chunksize] for i in range(0, n_samples, chunksize)]

    def prepare_chunks(
        self,
        points: torch.Tensor,
        encoding_function: Callable[[torch.Tensor], torch.Tensor],
        env: torch.Tensor,
        encoding_function_env: Callable[[torch.Tensor], torch.Tensor],
    ) -> List[torch.Tensor]:
        r"""
        Encode points and per-ray environment values, concatenate them per
        sample, and chunk along the sample axis for the NeRF model.
        """
        # Broadcast each ray's env vector to all of its samples.
        env = env[..., None, :].repeat([1, self.n_samples, 1])
        env = encoding_function_env(env)
        points = encoding_function(points)
        points = torch.cat((points, env), dim=-1)
        points = self.get_chunks(points)
        return points

    def prepare_viewdirs_chunks(
        self,
        rays_d: torch.Tensor,
        encoding_function: Callable[[torch.Tensor], torch.Tensor]
    ) -> List[torch.Tensor]:
        r"""
        Normalise ray directions, repeat them for every sample, encode and chunk.
        """
        norm_manual = self.two_norm(rays_d, dim=-1, keepdim=True)
        # Multiply by the reciprocal rather than dividing (kept as in the
        # original graph).
        tmp = 1 / norm_manual
        viewdirs = rays_d * tmp
        viewdirs = viewdirs[:, None, ...].repeat([1, self.n_samples, 1])
        viewdirs = encoding_function(viewdirs)
        viewdirs = self.get_chunks(viewdirs)
        return viewdirs

    def get_rgb_map(self, alpha: torch.Tensor, rgb: torch.Tensor) -> torch.Tensor:
        """Composite colours back to front: C_i = a_i * c_i + (1 - a_i) * C_{i+1}."""
        tmp = 0.0
        alpha_rgb = alpha[..., None] * rgb
        one_minus_alpha = 1 - alpha
        for i in reversed(range(self.n_samples)):
            tmp = alpha_rgb[:, i, :] + one_minus_alpha[:, i:i + 1] * tmp
        return tmp

    def get_depth_map(self, alpha: torch.Tensor, z_vals: torch.Tensor) -> torch.Tensor:
        """Alpha-composited expected depth, using the same recursion as get_rgb_map."""
        depth_map = torch.zeros_like(alpha[..., 0]).to(alpha.device)
        one_minus_alpha = 1 - alpha
        for i in reversed(range(self.n_samples)):
            depth_map = alpha[..., i] * z_vals[..., i] + one_minus_alpha[..., i] * depth_map

        return depth_map

    def get_acc_map(self, alpha: torch.Tensor) -> torch.Tensor:
        """Accumulated opacity along each ray (same recursion with colour = 1)."""
        acc_map = torch.zeros_like(alpha[..., 0]).to(alpha.device)
        one_minus_alpha = 1 - alpha
        for i in reversed(range(self.n_samples)):
            acc_map = alpha[..., i] + one_minus_alpha[..., i] * acc_map

        return acc_map

    def raw2outputs(
        self,
        raw: torch.Tensor,
        z_vals: torch.Tensor,
        dists_vals: torch.Tensor,
        rays_d: torch.Tensor,
        white_bkgd: bool = False,
    ) -> torch.Tensor:
        r"""
        Convert raw NeRF output (P, n_samples, 4) into a per-ray RGB map (P, 3).

        Channel 3 of ``raw`` is used as density and channels 0-2 (through a
        sigmoid) as colour. ``z_vals`` is accepted for interface compatibility
        but not used. With ``white_bkgd`` the remaining transmittance is added
        as white. ``raw_noise_std`` / ``noise_rand`` are not applied here.
        """
        # Scale parametric sample spacing by |rays_d| to get world distances.
        dists_vals = dists_vals * self.two_norm(rays_d[..., None, :], dim=-1)
        alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] * dists_vals))

        rgb = torch.sigmoid(raw[..., :3])  # [n_rays, n_samples, 3]
        rgb_map = self.get_rgb_map(alpha, rgb)
        if white_bkgd:
            rgb_map = rgb_map + (1.0 - self.get_acc_map(alpha)[..., None])
        return rgb_map

    def nerf_forward(
        self,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        rays_env: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Sample along the rays, run the coarse model chunk by chunk and render RGB.
        """
        query_points, z_vals, dists_vals = self.sample_stratified(rays_o, rays_d)

        batches = self.prepare_chunks(query_points, self.encode, rays_env, self.encode_env)
        if self.encode_viewdirs is not None:
            batches_viewdirs = self.prepare_viewdirs_chunks(rays_d, self.encode_viewdirs)
        else:
            batches_viewdirs = [None] * len(batches)
            print('batches_viewdirs is in composition of None.')

        # Run the coarse model on each chunk and concatenate along the sample
        # axis (chunking bounds peak memory).
        predictions = [
            self.model(batch, viewdirs=batch_viewdirs)
            for batch, batch_viewdirs in zip(batches, batches_viewdirs)
        ]

        raw = torch.cat(predictions, dim=1)
        if self.print_flag:
            print('raw.shape:', raw.shape)

        # Perform differentiable volume rendering to re-synthesize the RGB image.
        return self.raw2outputs(raw, z_vals, dists_vals, rays_d)

    def forward(self, x, directions):
        """Render RGB for pose/env input ``x`` and camera-frame ``directions`` (P, 1, 3)."""
        pose, rays_env = self.generate_xyzrpy_torch(x)
        c2w = self.get_extrinsic_matrix(pose)

        rays_o, rays_d = self.get_rays(c2w, directions)
        if self.print_flag:
            print('rays_o.shape:', rays_o.shape)
            print('rays_d.shape:', rays_d.shape)

        return self.nerf_forward(rays_o, rays_d, rays_env)
    
class RGBModel(nn.Module):
    """Alpha-composite per-sample colours along axis 1 into one RGB value per row.

    The forward input packs alpha in channel 0 of the last axis and colour in
    the remaining channels. Sample index 0 is the frontmost: its colour is not
    attenuated by any other sample.
    """

    def __init__(self, n_samples):
        super().__init__()
        self.n_samples = n_samples

    def get_rgb_map(self, alpha: torch.Tensor, rgb: torch.Tensor) -> torch.Tensor:
        """Back-to-front recursion C_i = a_i * c_i + (1 - a_i) * C_{i+1}, C_n = 0."""
        weighted = alpha[..., None] * rgb
        transmit = 1 - alpha
        composite = 0.0
        for idx in range(self.n_samples - 1, -1, -1):
            composite = weighted[:, idx, :] + transmit[:, idx:idx + 1] * composite
        return composite

    def forward(self, alpha_rgb):
        """Split ``alpha_rgb`` into alpha (channel 0) and colour (channels 1:) and composite."""
        alpha = alpha_rgb[..., 0]
        rgb = alpha_rgb[..., 1:]
        return self.get_rgb_map(alpha, rgb)
    
def extrinsic_matrix_to_xyzrpy(T):
    """Convert a homogeneous pose matrix to ``[x, y, z, roll, pitch, yaw]``.

    The translation is read from ``T[0:3, 3]``. The angles invert the Z-Y-X
    composition ``R = Rz(yaw) @ Ry(pitch) @ Rx(roll)`` of the upper-left 3x3 block.

    Args:
        T: array indexable as ``T[i, j]`` with at least 3 rows and 4 columns.

    Returns:
        A length-6 numpy array ``[x, y, z, roll, pitch, yaw]`` (angles in radians).
    """
    x, y, z = T[0, 3], T[1, 3], T[2, 3]
    R = T[:3, :3]

    def rotation_matrix_to_rpy(R):
        # Clip: round-off can push |R[2, 0]| just past 1, where arcsin is NaN.
        pitch = -np.arcsin(np.clip(R[2, 0], -1.0, 1.0))
        if np.abs(np.cos(pitch)) > np.finfo(float).eps:
            roll = np.arctan2(R[2, 1], R[2, 2])
            yaw = np.arctan2(R[1, 0], R[0, 0])
        else:
            # Gimbal lock (pitch = +-pi/2): roll and yaw are coupled, so fix
            # roll to 0 and recover yaw from the remaining entries.
            roll = 0.0
            yaw = np.arctan2(-R[0, 1], R[1, 1])
        return roll, pitch, yaw

    roll, pitch, yaw = rotation_matrix_to_rpy(R)
    return np.array([x, y, z, roll, pitch, yaw])

def compute_image_sampling(ray_model, directions, input, input_type, num_sampling, xyz_step, ry_step, hue_step, satur_step,
                           start_height, end_height, start_width, end_width):
    """Empirically bound a rendered tile by randomly perturbing the model input.

    The tile is rendered once at ``input`` and then ``num_sampling`` more times at
    ``input + noise``. Each noise component is drawn uniformly from
    ``[-step, step)`` for the step that matches ``input_type``. The element-wise
    min/max over all renders is returned.

    Args:
        ray_model: callable ``ray_model(inputs, directions)`` returning one RGB
            row per pixel (reshaped below to ``(h, w, 3)``).
        directions: ray directions, passed through unchanged.
        input: base input; one copy is repeated per pixel via ``repeat(n, 1)``.
        input_type: which input dimensions are perturbed (see ``noise_layout``).
        num_sampling: number of random renders in addition to the base render.
        xyz_step, ry_step, hue_step, satur_step: noise half-widths.
        start_height, end_height, start_width, end_width: tile extent.

    Returns:
        ``(image_ls, image_us)`` numpy arrays of shape ``(h, w, 3)``.

    Raises:
        ValueError: if ``num_sampling > 0`` and ``input_type`` has no noise layout.
    """
    # Noise layout per input type: (half-width, number of dims) segments,
    # concatenated along dim 1 in this order.
    noise_layout = {
        "x": [(xyz_step, 1)],
        "y": [(xyz_step, 1)],
        "z": [(xyz_step, 1)],
        "roll": [(ry_step, 1)],
        "yaw": [(ry_step, 1)],
        "hue": [(hue_step, 1)],
        "satur": [(satur_step, 1)],
        "ry": [(ry_step, 2)],
        "xyz": [(xyz_step, 3)],
        "xyzry": [(xyz_step, 3), (ry_step, 2)],
        "env": [(hue_step, 1), (satur_step, 1)],
    }
    if num_sampling > 0 and input_type not in noise_layout:
        raise ValueError(f"Unsupported input_type for sampling: {input_type!r}")

    n_pixels = (end_height - start_height) * (end_width - start_width)
    image_shape = [end_height - start_height, end_width - start_width, 3]

    predicted_image = ray_model(input.repeat(n_pixels, 1), directions)
    image_ls, image_us = predicted_image, predicted_image

    for _ in range(num_sampling):
        # Draw on the default generator, then move: keeps the CPU RNG stream.
        pieces = [step * (torch.rand(1, dim).to(input.device) * 2 - 1)
                  for step, dim in noise_layout[input_type]]
        random_tensor = torch.cat(pieces, dim=1)
        predicted_image = ray_model((input + random_tensor).repeat(n_pixels, 1), directions)

        image_ls = torch.min(image_ls, predicted_image)
        image_us = torch.max(image_us, predicted_image)

    image_ls = image_ls.reshape(image_shape).detach().cpu().numpy()
    image_us = image_us.reshape(image_shape).detach().cpu().numpy()
    return image_ls, image_us

def _pose_half_width(input, input_type, xyz_step, ry_step, hue_step, satur_step,
                     dist_to_object, theta_step, device):
    """Return the L-inf half-width of the input box around ``input`` for ``input_type``.

    Returns a Python float when every input dimension shares one radius (bounded
    with an ``eps``-style perturbation), or a ``(1, k)`` tensor with one radius
    per input dimension (bounded with an explicit ``x_L``/``x_U`` box).

    Raises:
        ValueError: for an unsupported ``input_type``, or for ``"rot"`` without
            ``dist_to_object`` and ``theta_step``.
    """
    if input_type in ("x", "y", "z", "xyz"):
        return xyz_step
    if input_type == "hue":
        return hue_step
    if input_type == "satur":
        return satur_step
    if input_type in ("roll", "yaw", "ry"):
        return ry_step
    if input_type == "rot":
        if dist_to_object is None or theta_step is None:
            raise ValueError("input_type 'rot' requires dist_to_object and theta_step")
        # NOTE(review): the caller stores pitch0 + theta in input[2], so this is
        # theta only when the base pitch is 0. The x/z radii are a first-order
        # (linearised) estimate over +-theta_step/2 — confirm both are intended.
        cur_val = input[2]
        x_per = dist_to_object * torch.abs(torch.sin(cur_val)) * theta_step / 2
        z_per = dist_to_object * torch.abs(torch.cos(cur_val)) * theta_step / 2
        pitch_per = torch.full_like(x_per, theta_step / 2)
        return torch.stack([x_per, z_per, pitch_per]).reshape(1, 3).to(
            device=device, dtype=torch.get_default_dtype())
    if input_type == "xyzry":
        return torch.tensor([[xyz_step, xyz_step, xyz_step, ry_step, ry_step]]).to(device)
    if input_type == "env":
        return torch.tensor([[hue_step, satur_step]]).to(device)
    raise ValueError(f"Unsupported input_type: {input_type!r}")


def compute_image_bound(ray_model, bounded_ray_model, input, input_type, num_sampling, xyz_step, ry_step, hue_step, satur_step,
                        start_vis_height, end_vis_height, tile_height, start_vis_width, end_vis_width, tile_width,
                        print_flag, visual_flag, device, dist_to_object=None, theta_step=None):
    """Bound a rendered tile over an L-inf box of pose/environment inputs.

    The tile rows ``[start_vis_height, end_vis_height)`` and columns
    ``[start_vis_width, end_vis_width)`` are rendered by ``bounded_ray_model``.
    One copy of ``input`` is fed per pixel. Bounds are computed first with IBP,
    then with ``forward+backward``, reusing the IBP intermediate bounds as
    reference bounds.

    Args:
        ray_model: the plain render model; its tile extent is updated and it
            supplies the ray directions.
        bounded_ray_model: auto_LiRPA wrapper of ``ray_model``.
        input: centre of the input box (repeated once per pixel).
        input_type: which inputs are perturbed; selects the box radii.
        xyz_step, ry_step, hue_step, satur_step: box half-widths.
        dist_to_object, theta_step: required only for ``input_type == "rot"``.
        print_flag: print progress and the resulting bounds.
        num_sampling, tile_height, tile_width, visual_flag: unused here; kept
            for call compatibility.

    Returns:
        ``(input_lb, input_ub, lb, ub, lower_A, lower_bias, upper_A, upper_bias)``:
        the per-pixel input box, the output bounds, and the lower/upper linear
        coefficients and biases with respect to the pose input.

    Raises:
        ValueError: for an unsupported ``input_type`` (see ``_pose_half_width``).
    """
    start_height, end_height = start_vis_height, end_vis_height
    start_width, end_width = start_vis_width, end_vis_width
    n_pixels = (end_height - start_height) * (end_width - start_width)

    # Validate input_type (and build the box radii) before touching ray_model.
    half_width = _pose_half_width(input, input_type, xyz_step, ry_step, hue_step, satur_step,
                                  dist_to_object, theta_step, device)

    if print_flag:
        print('\n cur_height,cur_width:',start_height,start_width)

    ray_model.update_height_and_width(start_height,end_height,start_width,end_width)
    directions = ray_model.get_directions()

    inputpose = input.repeat(n_pixels, 1)
    # The real box corners for every input_type (the box types do not carry
    # them in ptb.eps, so they must be computed explicitly).
    input_lb = inputpose - half_width
    input_ub = inputpose + half_width
    if isinstance(half_width, torch.Tensor):
        ptb = PerturbationLpNorm(x_L=input_lb, x_U=input_ub)
    else:
        ptb = PerturbationLpNorm(norm=np.inf, eps=half_width)

    inputpose_ptb = BoundedTensor(inputpose, ptb)
    model = bounded_ray_model

    if print_flag:
        print("Start IBP")
    lb_ibp, ub_ibp = model.compute_bounds(x=(inputpose_ptb, directions), method="ibp")
    if print_flag:
        print("IBP finished")

    # Reuse the IBP bounds of every perturbed node as reference intermediate
    # bounds for the tighter forward+backward pass.
    reference_interm_bounds = {}
    for node in model.nodes():
        if (node.perturbed
            and isinstance(node.lower, torch.Tensor)
            and isinstance(node.upper, torch.Tensor)):
            reference_interm_bounds[node.name] = (node.lower, node.upper)

    if print_flag:
        print("Start forward")
    # Request the output's linear coefficients w.r.t. both inputs; only the
    # pose input's (input_name[0]) are returned.
    required_A = defaultdict(set)
    required_A[model.output_name[0]].add(model.input_name[0])
    required_A[model.output_name[0]].add(model.input_name[1])
    lb, ub, A_dict = model.compute_bounds(
        x=(inputpose_ptb, directions),
        method="forward+backward",
        reference_bounds=reference_interm_bounds,
        return_A=True, needed_A_dict=required_A)
    pose_A = A_dict[model.output_name[0]][model.input_name[0]]
    lower_A, lower_bias = pose_A['lA'], pose_A['lbias']
    upper_A, upper_bias = pose_A['uA'], pose_A['ubias']

    if print_flag:
        print("Lower bounds: ", lb)
        print("Upper bounds: ", ub)

    # Drop large intermediates before releasing cached GPU memory.
    del ptb, inputpose_ptb, lb_ibp, ub_ibp, reference_interm_bounds
    torch.cuda.empty_cache()

    return input_lb, input_ub, lb, ub, lower_A, lower_bias, upper_A, upper_bias

def save_data(save_path,images_lb,images_ub,image_ls=None,image_us=None,pose=None,env=None):
    """Write every non-None array argument to an ``.npz`` archive at ``save_path``.

    Each array is stored under its parameter name, and ``None`` arguments are
    omitted. A confirmation line is printed afterwards.
    """
    candidates = (
        ('images_lb', images_lb),
        ('images_ub', images_ub),
        ('image_ls', image_ls),
        ('image_us', image_us),
        ('pose', pose),
        ('env', env),
    )
    payload = {name: value for name, value in candidates if value is not None}

    np.savez(save_path, **payload)
    print(f"Data saved to {save_path}")

def process_nerf_config():
    # Create an ArgumentParser object
    parser = argparse.ArgumentParser(description="NeRF configuration")

    # Add arguments with default values
    parser.add_argument('--dataname', type=str, default='airplane_grey', help='Dataset name')
    parser.add_argument('--n_samples', type=int, default=32, help='Number of samples')
    parser.add_argument('--n_layers', type=int, default=2, help='Number of layers')
    parser.add_argument('--d_filter', type=int, default=128, help='Filter size')

    parser.add_argument('--n_iters', type=int, default=300000, help='Number of iterations')
    parser.add_argument('--chunksize', type=int, default=2**5, help='Chunk size')

    parser.add_argument('--xyz_step', type=float, default=0.0002, help='XYZ step size')
    parser.add_argument('--ry_step', type=float, default=0.00010, help='RY step size')
    parser.add_argument('--xyz_eps', type=float, default=0.0006, help='XYZ epsilon')
    parser.add_argument('--ry_eps', type=float, default=0.0003, help='RY epsilon')

    parser.add_argument('--hue_step', type=float, default=0.00007, help='Hue step size')
    parser.add_argument('--satur_step', type=float, default=0.00012, help='Saturation step size')
    parser.add_argument('--hue_eps', type=float, default=0.00007, help='Hue epsilon')
    parser.add_argument('--satur_eps', type=float, default=0.00012, help='Saturation epsilon')

    parser.add_argument('--hue_offset', type=float, default=0.0, help='Hue offset')
    parser.add_argument('--satur_offset', type=float, default=0.0, help='Saturation offset')
    parser.add_argument('--input_type', type=str, default='rot', help='Input type')
    parser.add_argument('--num_sampling', type=int, default=0, help='Number of samples for testing')

    parser.add_argument('--testimgidx', type=int, default=236, help='Test image index')
    parser.add_argument('--visual_flag', type=bool, default=True, help='Visual flag')
    parser.add_argument('--bound_whole_flag', type=bool, default=True, help='Bounding flag for whole scene')
    parser.add_argument('--xdown_factor', type=float, default=5, help='X downscale factor')
    parser.add_argument('--ydown_factor', type=float, default=3.75, help='Y downscale factor')

    parser.add_argument('--tile_height', type=int, default=32, help='Tile height')
    parser.add_argument('--tile_width', type=int, default=32, help='Tile width')

    parser.add_argument('--print_flag', type=bool, default=False, help='Print flag')
    parser.add_argument('--save_npz_flag', type=bool, default=True, help='Save NPZ flag')
    parser.add_argument('--save_img_flag', type=bool, default=True, help='Save image flag')
    parser.add_argument('--save_img_sep_flag', type=bool, default=False, help='Save image separately flag')

    parser.add_argument('--hue_min', type=int, default=-30, help='Minimum hue')
    parser.add_argument('--hue_max', type=int, default=30, help='Maximum hue')
    parser.add_argument('--sat_min', type=float, default=-0.5, help='Minimum saturation')
    parser.add_argument('--sat_max', type=float, default=0.5, help='Maximum saturation')

    parser.add_argument('--near', type=float, default=2.0, help='Near plane')
    parser.add_argument('--far', type=float, default=6.0, help='Far plane')
    parser.add_argument('--distance_to_infinity', type=float, default=1e2, help='Distance to infinity')

    parser.add_argument('--perturb', type=bool, default=False, help='Perturbation flag')
    parser.add_argument('--inverse_depth', type=bool, default=False, help='Inverse depth flag')
    parser.add_argument('--n_samples_hierarchical', type=int, default=0, help='Number of hierarchical samples')

    parser.add_argument('--d_input', type=int, default=3, help='Input dimension')
    parser.add_argument('--env_input', type=int, default=2, help='Environment input dimension')
    parser.add_argument('--n_freqs', type=int, default=10, help='Number of frequencies')
    parser.add_argument('--log_space', type=bool, default=True, help='Logarithmic space flag')
    parser.add_argument('--n_freqs_views', type=int, default=4, help='Number of view frequencies')
    parser.add_argument('--skip', type=list, default=[], help='Skip list')
    parser.add_argument('--raw_noise_std', type=float, default=0.0, help='Raw noise standard deviation')

    parser.add_argument('--classification_batch_size', type=int, default=1, help='Classification batch size')

    # Parse the arguments
    args = parser.parse_args()

    # Access the arguments within the function
    print("Configurations:")
    print(f"Dataset name: {args.dataname}")
    print(f"Number of samples: {args.n_samples}")
    print(f"Number of layers: {args.n_layers}")
    print(f"Filter size: {args.d_filter}")
    print(f"Number of iterations: {args.n_iters}")
    print(f"Chunk size: {args.chunksize}")
    print(f"input_type: {args.input_type}")
    input_type=args.input_type
    if input_type=="xyz" or input_type=="x" or input_type=="y" or input_type=="z":
        print(f"Perturbation: {args.xyz_eps}")
    elif input_type=="ry" or input_type=="roll" or  input_type=="yaw":
        print(f"Perturbation: {args.ry_eps}")
    elif input_type=="xyzry":
        print(f"Perturbation: {args.xyz_eps,args.ry_eps}")
    elif input_type=="hue":
        print(f"Perturbation: {args.hue_eps}")
    elif input_type=="satur":
        print(f"Perturbation: {args.satur_eps}")
    elif input_type=="env":
        print(f"Perturbation: {args.hue_eps,args.satur_eps}")

    # print(f"XYZ Step: {args.xyz_step}, RY Step: {args.ry_step}")
    # print(f"Hue Step: {args.hue_step}, Saturation Step: {args.satur_step}")
    print(f"Visual flag: {args.visual_flag}")
    print(f"Save image flag: {args.save_img_flag}")
    # Add any further processing needed

    return args


def _stack_linear_bounds(A_list, bias_list):
    """Concatenate per-pose linear bounds into classifier-ready matrices.

    Returns ``(A, bias)`` with ``A`` of shape ``(B, 3*W*H, W*H)`` and ``bias``
    of shape ``(B, 3*W*H)``, both in channel-major pixel order.
    """
    # (B, W*H, 3, 1) -> (B, 3, W*H)
    A = torch.concat(A_list, dim=0).squeeze(-1).permute(0, 2, 1)
    # (B, 3, W*H) -> (B, 3*W*H, W*H): diag_embed makes pixel p of each channel
    # depend only on its own input copy p.
    A = torch.diag_embed(A).reshape(A.shape[0], -1, A.shape[2])
    # (B, W*H, 3) -> (B, 3*W*H)
    bias = torch.concat(bias_list, dim=0).permute(0, 2, 1)
    bias = bias.reshape(bias.shape[0], -1)
    return A, bias


def classification(classification_model, 
                   input_lb, input_ub,
                   images_lb, images_ub,
                   lower_A, lower_bias,
                   upper_A, upper_bias,
                   linear_perturbation=True):
    """Bound the classifier's outputs over a batch of rendered-image bounds.

    Every tensor argument is a list of per-pose tensors (the main loop appends
    ``x.unsqueeze(0)`` per pose), concatenated along dim 0 here.

    Args:
        classification_model: auto_LiRPA ``BoundedModule`` around the image
            classifier; its input is a flattened channel-major image.
        input_lb, input_ub: lists of input-box corners; used only when
            ``linear_perturbation`` is True (presumably ``(1, W*H, 1)`` each,
            i.e. a 1-D pose input — confirm).
        images_lb, images_ub: lists of per-pixel RGB bounds, assumed
            ``(1, W*H, 3)`` each — TODO confirm.
        lower_A, lower_bias, upper_A, upper_bias: lists of the per-pixel linear
            relaxation w.r.t. the input; used only when ``linear_perturbation``.
        linear_perturbation: True propagates the linear relaxation through
            ``PerturbationLinear``; False uses the plain interval box
            ``[images_lb, images_ub]``.

    Returns:
        The result of ``classification_model.compute_bounds(..., method="backward")``;
        the caller indexes it as ``(lower, upper)``.
    """
    if linear_perturbation:
        # (B, W*H, 1) -> (B, W*H)
        input_lb = torch.concat(input_lb, dim=0).squeeze(-1)
        input_ub = torch.concat(input_ub, dim=0).squeeze(-1)
        lower_A, lower_bias = _stack_linear_bounds(lower_A, lower_bias)
        upper_A, upper_bias = _stack_linear_bounds(upper_A, upper_bias)

    # (B, W*H, 3) -> (B, 3, W*H) -> (B, 3*W*H): channel-major, matching the
    # classifier's reshape to (B, 3, H, W). The batch size is taken from the
    # concatenated tensor, not the list length, so list items may hold >1 pose.
    images_lb = torch.concat(images_lb, dim=0).permute(0, 2, 1)
    images_lb = images_lb.reshape(images_lb.shape[0], -1)
    images_ub = torch.concat(images_ub, dim=0).permute(0, 2, 1)
    images_ub = images_ub.reshape(images_ub.shape[0], -1)

    if linear_perturbation:
        ptb = PerturbationLinear(lower_A, upper_A, lower_bias, upper_bias,
                                 input_lb=input_lb, input_ub=input_ub,
                                 x_L=images_lb, x_U=images_ub)
    else:
        ptb = PerturbationLpNorm(x_L=images_lb, x_U=images_ub)
    bounded_images = BoundedTensor((images_lb + images_ub) / 2, ptb)

    ret = classification_model.compute_bounds(bounded_images, method="backward")

    return ret



if __name__ == "__main__":

    # Call process_nerf_config to get the arguments
    args = process_nerf_config()

    # Assign the argument values to corresponding variables
    dataname = args.dataname
    n_samples = args.n_samples
    n_layers = args.n_layers
    d_filter = args.d_filter
    n_iters = args.n_iters
    chunksize = args.chunksize
    xyz_step = args.xyz_step
    ry_step = args.ry_step
    xyz_eps = args.xyz_eps
    ry_eps = args.ry_eps

    hue_step = args.hue_step
    satur_step = args.satur_step
    hue_eps = args.hue_eps
    satur_eps = args.satur_eps

    hue_offset = args.hue_offset
    satur_offset = args.satur_offset
    input_type = args.input_type
    num_sampling = args.num_sampling

    testimgidx = args.testimgidx
    visual_flag = args.visual_flag
    bound_whole_flag = args.bound_whole_flag
    xdown_factor = args.xdown_factor
    ydown_factor = args.ydown_factor
    tile_height = args.tile_height
    tile_width = args.tile_width

    print_flag = args.print_flag
    save_npz_flag = args.save_npz_flag
    save_img_flag = args.save_img_flag
    save_img_sep_flag = args.save_img_sep_flag

    hue_min = args.hue_min
    hue_max = args.hue_max
    sat_min = args.sat_min
    sat_max = args.sat_max

    near = args.near
    far = args.far
    distance_to_infinity = args.distance_to_infinity
    perturb = args.perturb
    inverse_depth = args.inverse_depth
    n_samples_hierarchical = args.n_samples_hierarchical

    d_input = args.d_input
    env_input = args.env_input
    n_freqs = args.n_freqs
    log_space = args.log_space
    n_freqs_views = args.n_freqs_views

    skip = args.skip
    raw_noise_std = args.raw_noise_std

    kwargs_sample_stratified = {
        "n_samples": n_samples,
        "perturb": perturb,
        "inverse_depth": inverse_depth,
    }
    kwargs_sample_hierarchical = {"perturb": perturb}

    images_lb=[]
    images_ub=[]
    images_ls=[]
    images_us=[]
    
    datapath='data/'+dataname+'_env_data.npz'

    data = np.load(os.path.join(script_dir,datapath))
    images = data["images"]
    poses = data["poses"]
    focal = data["focal"]
    envs = data["env"]

    # envs[:,0] = (envs[:,0]-hue_min)/(hue_max-hue_min)
    # envs[:,1] = (envs[:,1]-sat_min)/(sat_max-sat_min)

    envs=np.zeros_like(envs)


    
    
    testimg = images[testimgidx]
    testpose = poses[testimgidx]
    #print(testpose.shape)

    # cv2img = cv2.cvtColor(testimg, cv2.COLOR_RGB2BGR)
    # testimg = cv2.cvtColor(cv2img, cv2.COLOR_RGB2BGR)
    testimg = torch.Tensor(testimg).to(device)

    testenv = envs[testimgidx]
    testenv[0]-=hue_offset
    testenv[1]-=satur_offset
    if print_flag:
        print("env:",testenv)
    total_height, total_width = testimg.shape[:2]
    print('total_height, total_width:',total_height, total_width)
   
    start_vis_height, end_vis_height = 0, tile_height #0,0+tile_height*1#20,20+tile_height*1 #
    start_vis_width, end_vis_width = 0, tile_width #20,20+tile_width*3 # 0,0+tile_width*3 #

    xdown_factor = total_width / tile_width
    ydown_factor = total_height / tile_height

    total_height, total_width = tile_height, tile_width

    start_vis_height_org,end_vis_height_org=start_vis_height*ydown_factor,end_vis_height*ydown_factor
    start_vis_width_org,end_vis_width_org=start_vis_width*xdown_factor,end_vis_width*xdown_factor


    xyzrpy_np = extrinsic_matrix_to_xyzrpy(testpose)
    dist_to_object=float(np.linalg.norm(xyzrpy_np[:2],ord=2))
    xyzrpy=torch.Tensor(xyzrpy_np).to(device)
    extrinsic_matrix = torch.Tensor(testpose).to(device)

    env_np=testenv
    testenv=torch.Tensor(testenv).to(device)
    
    focal_x = torch.Tensor([focal/xdown_factor]).to(device)
    focal_y = torch.Tensor([focal/ydown_factor]).to(device)
    
    
    feature="env_"+str(dataname)+"_"+str(n_freqs)+"_"+str(n_freqs_views)+"_"+str(d_filter)+"_"+str(n_layers)+"_"+str(n_iters)

    encode = PositionalEncoder(d_input, n_freqs, log_space=log_space).to(device)
    encode_env = PositionalEncoderEnv(env_input , n_freqs, log_space=log_space).to(device)
    encode_viewdirs = PositionalEncoder(d_input, n_freqs_views, log_space=log_space).to(device)
    d_viewdirs = encode_viewdirs.d_output

    coarse_model = NeRF(
        encode.d_output+encode_env.d_output,
        n_layers=n_layers,
        d_filter=d_filter,
        skip=skip,
        d_viewdirs=d_viewdirs,
    )

    coarse_model.load_state_dict(torch.load(os.path.join(script_dir, 'pts/nerf-fine_'+feature+'.pt')))
    coarse_model.to(device)

    fine_model = NeRF(
        encode.d_output+encode_env.d_output,
        n_layers=n_layers,
        d_filter=d_filter,
        skip=skip,
        d_viewdirs=d_viewdirs,
    )
    fine_model.load_state_dict(torch.load(os.path.join(script_dir,'pts/nerf-fine_'+feature+'.pt')))
    fine_model.to(device)


    ray_model=RenderModel(input_type,total_height,total_width,None,None,None,None,focal_x,focal_y,\
                        xyzrpy_np,env_np,near,far,distance_to_infinity,n_samples,perturb,inverse_depth,\
                        kwargs_sample_stratified,n_samples_hierarchical,kwargs_sample_hierarchical,chunksize,\
                        encode,encode_env,encode_viewdirs,coarse_model,fine_model,\
                        raw_noise_std,print_flag
                        ).to(device)
    
    # torch.onnx.export(ray_model,(dummy_inputpos,h_w),'onnx_net.onnx')
    
    start_time=time.time()
    sampling_times=0

    

    x_start, x_end, x_step = xyzrpy_np[0], xyzrpy_np[0]+xyz_eps, xyz_step
    y_start, y_end, y_step = xyzrpy_np[1], xyzrpy_np[1]+xyz_eps, xyz_step
    z_start, z_end, z_step = xyzrpy_np[2], xyzrpy_np[2]+xyz_eps, xyz_step

    roll_start, roll_end, roll_step = xyzrpy_np[3], xyzrpy_np[3]+ry_eps, ry_step
    pitch_start, pitch_end, pitch_step = xyzrpy_np[4], xyzrpy_np[4]+ry_eps, ry_step
    yaw_start, yaw_end, yaw_step = xyzrpy_np[5], xyzrpy_np[5]+ry_eps, ry_step

    hue_start, hue_end, hue_step = env_np[0], env_np[0]+hue_eps, hue_step
    satur_start,  satur_end, satur_step = env_np[1], env_np[1]+satur_eps, satur_step


    x_vals = np.arange(x_start, x_end, x_step*2)
    y_vals = np.arange(y_start, y_end, y_step*2)
    z_vals = np.arange(z_start, z_end, z_step*2)

    roll_vals = np.arange(roll_start, roll_end, roll_step*2)
    pitch_vals = np.arange(pitch_start, pitch_end, pitch_step*2)
    yaw_vals = np.arange(yaw_start, yaw_end, yaw_step*2)

    hue_vals = np.arange(hue_start, hue_end, hue_step*2)
    satur_vals = np.arange(satur_start, satur_end, satur_step*2)


    # classification_model = ImageClassificationModel()
    # classification_model.load_state_dict(torch.load('model_cnn_weights_advtrain_epochs50_eps0.03_alpha0.007_iter10_l1reg0.pth'))
    # classification_model.to(device)
    # classification_model.eval()

    classification_model_name = 'resnet2b'
    classification_model = ResNetModel(classification_model_name, tile_size=tile_height)
    classification_model.model.load_state_dict(torch.load(f'cifar10_resnet/{classification_model_name}.pth')['state_dict'])
    classification_model.to(device)
    classification_model.eval()

    bounded_classification_model = BoundedModule(classification_model,
                                                 torch.zeros((1, 3 * total_height * total_width), device=device),
                                                 bound_opts={'sparse_intermediate_bounds': False,
                                                             'conv_mode': 'matrix'})
    
    dummy_inputpos = BoundedTensor(torch.zeros((total_height * total_width, 3), device=device))
    directions = torch.zeros(total_height * total_width, 1, 3, device=device)

    bounded_ray_model = BoundedModule(ray_model, (dummy_inputpos, directions),
                                      bound_opts={'mul': {'middle': False}})

    input_lb_all = []
    input_ub_all = []
    images_lb_all = []
    images_ub_all = []
    lower_A_all = []
    lower_bias_all = []
    upper_A_all = []
    upper_bias_all = []
    ret_all_lower = []
    ret_all_upper = []

    if input_type in ["x", "y", "z","roll", "yaw", "hue", "satur"]:
        vals = {"x": x_vals, "y": y_vals, "z": z_vals, "roll":roll_vals, "yaw":yaw_vals, "hue":hue_vals, "satur": satur_vals}[input_type]  # Select the corresponding value list
        iteration = 0
        for cur_val in tqdm(vals):
            cur_val = float(cur_val)
            input_tensor = torch.tensor([cur_val]).to(device)

            with torch.no_grad():
                input_lb, input_ub, lb, ub, lower_A, lower_bias, upper_A, upper_bias = compute_image_bound(ray_model, bounded_ray_model, input_tensor, input_type, num_sampling, xyz_step, ry_step,hue_step,satur_step,
                                                                    start_vis_height, end_vis_height, tile_height,
                                                                    start_vis_width, end_vis_width, tile_width, print_flag,
                                                                    visual_flag, device)

                input_lb_all.append(input_lb.detach().unsqueeze(0))
                input_ub_all.append(input_ub.detach().unsqueeze(0))
                images_lb_all.append(lb.detach().unsqueeze(0))
                images_ub_all.append(ub.detach().unsqueeze(0))
                lower_A_all.append(lower_A.detach().unsqueeze(0))
                lower_bias_all.append(lower_bias.detach().unsqueeze(0))
                upper_A_all.append(upper_A.detach().unsqueeze(0))
                upper_bias_all.append(upper_bias.detach().unsqueeze(0))

            iteration += 1
            if iteration % args.classification_batch_size == 0 or iteration == len(vals):
                with torch.no_grad():
                    ret = classification(bounded_classification_model,
                                        input_lb_all, input_ub_all,
                                        images_lb_all, images_ub_all,
                                        lower_A_all, lower_bias_all,
                                        upper_A_all, upper_bias_all,
                                        linear_perturbation=False)
                    ret_all_lower.append(ret[0].detach())
                    ret_all_upper.append(ret[1].detach())

                input_lb_all.clear()
                input_ub_all.clear()
                images_lb_all.clear()
                images_ub_all.clear()
                lower_A_all.clear()
                lower_bias_all.clear()
                upper_A_all.clear()
                upper_bias_all.clear()
                bounded_classification_model._clear_and_set_new(None)
                torch.cuda.empty_cache()

    elif input_type=="rot":
        if dataname =='airplane_grey':
            theta_start , theta_end, theta_step =  0.0 ,1.0, 0.003
            # theta_start , theta_end, theta_step =  0 ,0.01, 0.001
        elif dataname =='truck_america':
             theta_start , theta_end, theta_step = 0 ,6.2832, 0.000005
        elif dataname =='car_porsche_small' or dataname =='car_blue':
             theta_start , theta_end, theta_step = 0 ,6.2832, 0.001
        theta_vals = np.arange(theta_start , theta_end, theta_step*2)
        iteration = 0

        for cur_val in tqdm(theta_vals):
            cur_val = float(cur_val)
            dist_to_object=4
            cur_x = xyzrpy[0]+dist_to_object*(math.cos(cur_val)-math.cos(0))
            cur_z = xyzrpy[2]+dist_to_object*(math.sin(cur_val)-math.sin(0))
            cur_pitch = xyzrpy[4]+cur_val
            input_tensor = torch.tensor([cur_x,cur_z,cur_pitch]).to(device)
            #print('check:',xyzrpy, input_tensor)

            with torch.no_grad():
                input_lb, input_ub, lb, ub, lower_A, lower_bias, upper_A, upper_bias = compute_image_bound(ray_model, bounded_ray_model, input_tensor, input_type, num_sampling, xyz_step, ry_step,hue_step,satur_step,
                                                                    start_vis_height, end_vis_height, tile_height,
                                                                    start_vis_width, end_vis_width, tile_width, print_flag,
                                                                    visual_flag, device,dist_to_object,theta_step)

                input_lb_all.append(input_lb.detach().unsqueeze(0))
                input_ub_all.append(input_ub.detach().unsqueeze(0))
                images_lb_all.append(lb.detach().unsqueeze(0))
                images_ub_all.append(ub.detach().unsqueeze(0))
                lower_A_all.append(lower_A.detach().unsqueeze(0))
                lower_bias_all.append(lower_bias.detach().unsqueeze(0))
                upper_A_all.append(upper_A.detach().unsqueeze(0))
                upper_bias_all.append(upper_bias.detach().unsqueeze(0))

            iteration += 1
            if iteration % args.classification_batch_size == 0 or iteration == len(vals):
                # print('stop:',testimg.shape, testimg.dtype)
                # testimg=cv2.resize(testimg.detach().cpu().numpy(),(32,32))
                # testimg = torch.from_numpy(testimg).to(device)
                # testimg = testimg.permute(2,0,1)
                # testimg = testimg.reshape(1,-1).to(device)
                # print(bounded_classification_model(testimg))
                with torch.no_grad():
                    ret = classification(bounded_classification_model,
                                        input_lb_all, input_ub_all,
                                        images_lb_all, images_ub_all,
                                        lower_A_all, lower_bias_all,
                                        upper_A_all, upper_bias_all,
                                        linear_perturbation=False)
                    ret_all_lower.append(ret[0].detach())
                    ret_all_upper.append(ret[1].detach())

                input_lb_all.clear()
                input_ub_all.clear()
                images_lb_all.clear()
                images_ub_all.clear()
                lower_A_all.clear()
                lower_bias_all.clear()
                upper_A_all.clear()
                upper_bias_all.clear()
                bounded_classification_model._clear_and_set_new(None)
                torch.cuda.empty_cache()


    elif input_type=="ry":
        for cur_roll, cur_yaw in tqdm(itertools.product(roll_vals, yaw_vals)):
            cur_roll,cur_yaw=float(cur_roll),float(cur_yaw)
            input=torch.tensor([cur_roll,cur_yaw]).to(device)

            # lb, ub, lower_A, lower_bias, upper_A, upper_bias = compute_image_bound(ray_model,input,input_type,num_sampling, xyz_step,ry_step,hue_step,satur_step,\
            #                     start_vis_height,end_vis_height,tile_height,start_vis_width,end_vis_width,tile_width,\
            #                     print_flag,visual_flag,device)

            # # images_lb.append(image_lb)
            # # images_ub.append(image_ub)
            # # images_ls.append(image_ls)
            # # images_us.append(image_us)
            # # sampling_times+=sampling_time

            input_lb, input_ub, lb, ub, lower_A, lower_bias, upper_A, upper_bias = compute_image_bound(ray_model, input_tensor, input_type, num_sampling, xyz_step, ry_step,hue_step,satur_step,
                                                                start_vis_height, end_vis_height, tile_height,
                                                                start_vis_width, end_vis_width, tile_width, print_flag,
                                                                visual_flag, device)    

    elif input_type=="xyz": 
        for cur_x, cur_y, cur_z in tqdm(itertools.product(x_vals, y_vals, z_vals)):
            cur_x, cur_y, cur_z=float(cur_x), float(cur_y), float(cur_z)
            input=torch.tensor([cur_x, cur_y, cur_z]).to(device)

            lb, ub, lower_A, lower_bias, upper_A, upper_bias = compute_image_bound(ray_model,input,input_type,num_sampling, xyz_step,ry_step,hue_step,satur_step,\
                                start_vis_height,end_vis_height,tile_height,start_vis_width,end_vis_width,tile_width,\
                                print_flag,visual_flag,device)

            # images_lb.append(image_lb)
            # images_ub.append(image_ub)
            # images_ls.append(image_ls)
            # images_us.append(image_us)
            # sampling_times+=sampling_time

    elif input_type=="xyzry":
        for cur_x, cur_y, cur_z,cur_roll, cur_yaw in tqdm(itertools.product(x_vals, y_vals, z_vals,roll_vals, yaw_vals)):
            cur_x, cur_y, cur_z=float(cur_x), float(cur_y), float(cur_z)
            cur_roll,cur_yaw=float(cur_roll),float(cur_yaw)
            input=torch.tensor([cur_x, cur_y, cur_z, cur_roll,cur_yaw]).to(device)
            
            lb, ub, lower_A, lower_bias, upper_A, upper_bias = compute_image_bound(ray_model,input,input_type,num_sampling, xyz_step,ry_step,hue_step,satur_step,\
                                start_vis_height,end_vis_height,tile_height,start_vis_width,end_vis_width,tile_width,\
                                print_flag,visual_flag,device)

            # images_lb.append(image_lb)
            # images_ub.append(image_ub)
            # images_ls.append(image_ls)
            # images_us.append(image_us)
            # sampling_times+=sampling_time

    elif input_type=="env":
        for cur_hue, cur_satur in tqdm(itertools.product(hue_vals, satur_vals)):
            cur_hue, cur_satur=float(cur_hue), float(cur_satur)
            input=torch.tensor([cur_hue, cur_satur]).to(device)
            
            lb, ub, lower_A, lower_bias, upper_A, upper_bias = compute_image_bound(ray_model,input,input_type,num_sampling, xyz_step,ry_step,hue_step,satur_step,\
                                start_vis_height,end_vis_height,tile_height,start_vis_width,end_vis_width,tile_width,\
                                print_flag,visual_flag,device)

            # images_lb.append(image_lb)
            # images_ub.append(image_ub)
            # images_ls.append(image_ls)
            # images_us.append(image_us)
            # sampling_times+=sampling_time

    ret_all_lower = torch.concat(ret_all_lower, dim=0)
    ret_all_upper = torch.concat(ret_all_upper, dim=0)

    if dataname == "airplane" or dataname =="airplane_grey":
        # print("ret_all_lower:",ret_all_lower)
        # print("ret_all_upper:", ret_all_upper)
        target_lower = ret_all_lower[:, 0]
        other_upper = ret_all_upper[:, 4:].max(dim=1)[0]
        diff=target_lower-other_upper
        diff=diff.detach().cpu().numpy()
        diff_bool=[1 if val>0 else 0 for val in diff]
        # print('diff:',diff)
        # print('diff_bool:',diff_bool)
        print('ratio:',np.sum(diff_bool)/len(diff_bool))

        with open("classifier_airplane_result.txt", "w") as f:
            f.write(" ".join(str(val) for val in diff_bool))
    
    elif dataname == "truck" or dataname == "truck_america":
        # print("ret_all_lower:",ret_all_lower)
        # print("ret_all_upper:", ret_all_upper)
        target_lower = ret_all_lower[:, 9]
        other_upper = ret_all_upper[:, :9].max(dim=1)[0]
        diff=target_lower-other_upper
        diff=diff.detach().cpu().numpy()
        diff_bool=[1 if val>0 else 0 for val in diff]
        # print('diff:',diff)
        # print('diff_bool:',diff_bool)

        with open("classifier_truck_result.txt", "w") as f:
            f.write(" ".join(str(val) for val in diff_bool))

    elif dataname == "car" or dataname == "car_porsche_small" or dataname =='car_blue':
        # print("ret_all_lower:",ret_all_lower)
        # print("ret_all_upper:", ret_all_upper)
        target_lower = ret_all_lower[:, 0:2].max(dim=1)[0]
        other_upper = ret_all_upper[:, 4:9].max(dim=1)[0]
        diff=target_lower-other_upper
        diff=diff.detach().cpu().numpy()
        diff_bool=[1 if val>0 else 0 for val in diff]
        # print('diff:',diff)
        # print('diff_bool:',diff_bool)
        print('ratio:',np.sum(diff_bool)/len(diff_bool))

        with open("classifier_car_result.txt", "w") as f:
            f.write(" ".join(str(val) for val in diff_bool))
    
    
    plt.figure(figsize=(10, 5))
    plt.plot(range(len(target_lower)), (target_lower - other_upper).detach().cpu().numpy(), label='Verified Gap')
    plt.axhline(0, color='red', linestyle='--', label='y = 0')
    plt.savefig(f'verified_data_{dataname}_gap_testimgidx{testimgidx}_input_type_{input_type}_xyzeps{xyz_eps}_xyz_step{xyz_step}.png')
