from operator import ge
from statistics import fmean
from turtle import forward
from sympy.logic.inference import valid
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import numpy as np
from Model import NeuralRF, WeightsModel
import torch.distributions as dist
import math

class DifferentiableRFrenderer(nn.Module):
    def __init__(
        self,
        fmean_network: Optional[NeuralRF] = None,
        max_steps: int = 128,
        max_distance: float = 20,
        epsilon: float = 0.05,
        structure_hidden_dim: int = 64,
        material_hidden_dim: int = 128,
        num_layers: int = 8,
        use_sphere_init: bool = False,
        use_box_init: bool = True,
        sphere_center: Optional[torch.Tensor] = None,
        box_size: Optional[torch.Tensor] = None,
        box_center: Optional[torch.Tensor] = None,
        num_spheres: int = 3,
        sphere_radius_range: tuple = (0.8, 1.0),
        space_boundaries: Optional[torch.Tensor] = None
    ):
        """Build the renderer and optionally pre-fit a freshly created structure network.

        Args:
            fmean_network: Existing ``NeuralRF`` to use. When ``None`` a new one is
                built from ``structure_hidden_dim`` / ``material_hidden_dim`` /
                ``num_layers`` (with ``geometric_init=True``).
            max_steps: Maximum marching iterations used by the ray marchers.
            max_distance: Rays marched further than this are stopped; rays that hit
                nothing report this depth.
            epsilon: Surface threshold for hard hits and scale of surface offsets.
            structure_hidden_dim: Forwarded to ``NeuralRF`` when it is built here.
            material_hidden_dim: Forwarded to ``NeuralRF`` when it is built here.
            num_layers: Forwarded to ``NeuralRF`` when it is built here.
            use_sphere_init: Pre-fit the structure model to a sphere. This takes
                precedence over ``use_box_init``. It applies only when the network
                is created here.
            use_box_init: Pre-fit the structure model to a box with random spheres.
                It applies only when the network is created here.
            sphere_center: Forwarded to ``_initialize_as_sphere``.
            box_size: Box extents for the box pre-fit; ``None`` means
                ``[17.0, 13.0, 3.0]``.
            box_center: Box center for the box pre-fit; ``None`` means
                ``[8.5, 8.0, 1.5]``.
            num_spheres: Number of random spheres carved into the box.
            sphere_radius_range: ``(min, max)`` radius of those spheres.
            space_boundaries: World extents used to map points into ``[-1, 1]``
                (``p / space_boundaries * 2 - 1``); ``None`` means
                ``[[20.0, 20.0, 3.0]]`` (shape ``(1, 3)``).
        """
        super().__init__()

        # Build default tensors per call. Tensor defaults written in the signature are
        # evaluated once at class-definition time and would be shared by every
        # renderer; space_boundaries would even be aliased on each instance.
        if box_size is None:
            box_size = torch.tensor([17.0, 13.0, 3.0], dtype=torch.float32)
        if box_center is None:
            box_center = torch.tensor([8.5, 8.0, 1.5], dtype=torch.float32)
        if space_boundaries is None:
            space_boundaries = torch.tensor([20.0, 20.0, 3.0], dtype=torch.float32).reshape(1, 3)

        if fmean_network is None:
            self.fmean_network = NeuralRF(
                structure_hidden_dim=structure_hidden_dim,
                material_hidden_dim=material_hidden_dim,
                num_layers=num_layers,
                geometric_init=True
            )
        else:
            self.fmean_network = fmean_network

        self.max_steps = max_steps
        self.max_distance = max_distance
        self.epsilon = epsilon
        # Plain attribute, not a registered buffer: `module.to(device)` does not move
        # it, so the methods move it explicitly with `.to(device)` before use.
        self.space_boundaries = space_boundaries

        self.weights_model = WeightsModel(
            hidden_dim=128,
            num_layers=8,
            encoding_dims=6,
            skip_connections=[4]
        )

        # Pre-fitting only touches a network created here; sphere init wins over box init.
        if use_sphere_init and fmean_network is None:
            self._initialize_as_sphere(sphere_center=sphere_center)
        elif use_box_init and fmean_network is None:
            self._initialize_as_box(size=box_size, box_center=box_center,
                                    num_spheres=num_spheres, sphere_radius_range=sphere_radius_range)
    
    def _initialize_as_sphere(self, radius: float = 3.0, sphere_center: torch.Tensor = None):
        num_samples = 100000
        device = next(self.fmean_network.parameters()).device
        
        if sphere_center is None:
            sphere_center = torch.tensor([10, 10, 0.5], device=device, dtype=torch.float32)
        else:
            sphere_center = sphere_center.to(device)
        
        space_boundaries = self.space_boundaries.to(device)
        points = torch.rand(num_samples, 3, device=device, dtype=torch.float32) * space_boundaries
        target_fmean = torch.norm(points - sphere_center, dim=-1, keepdim=True) - radius
        optimizer = torch.optim.Adam(self.fmean_network.structure_model.parameters(), lr=1e-3)
        for _ in range(1000):
            pred_fmean, _ = self.fmean_network.structure_model(points/space_boundaries*2-1)
            loss = F.mse_loss(pred_fmean, target_fmean)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if loss.item() < 1e-5:
                break

    def _initialize_as_box(self, size: torch.Tensor = torch.tensor([17.0, 13.0, 3.0], dtype=torch.float32), box_center: torch.Tensor = torch.tensor([8.5, 8.0, 1.5], dtype=torch.float32), 
                          num_spheres: int = 3, sphere_radius_range: tuple = (0.8, 1.0)):
        device = next(self.fmean_network.parameters()).device
        num_samples = 100000
        
        size = size.to(device)
        box_center = box_center.to(device)
        
        space_boundaries = self.space_boundaries.to(device)
        points = torch.rand(num_samples, 3, device=device, dtype=torch.float32) * space_boundaries
        
        box_min = box_center - size / 2
        box_max = box_center + size / 2
        
        spheres = []
        max_attempts = 1000
        
        for sphere_idx in range(num_spheres):
            placed = False
            attempts = 0
            
            while not placed and attempts < max_attempts:
                margin = torch.tensor(1.0, dtype=torch.float32, device=device)
                xy_center = torch.rand(2, device=device, dtype=torch.float32) * 0.6 + 0.2
                xy_center = box_min[:2] + margin + xy_center * (size[:2] - 2 * margin)
                z_center = torch.rand(1, device=device, dtype=torch.float32) * (2.0 - 1.0) + 1.0
                sphere_center = torch.cat([xy_center, z_center])
                
                sphere_radius = torch.rand(1, device=device, dtype=torch.float32) * (sphere_radius_range[1] - sphere_radius_range[0]) + sphere_radius_range[0]
                
                overlap = False
                for existing_center, existing_radius in spheres:
                    distance = torch.norm(sphere_center - existing_center)
                    min_distance = sphere_radius.item() + existing_radius + 0.5
                    
                    if distance < min_distance:
                        overlap = True
                        break
                
                if not overlap:
                    spheres.append((sphere_center, sphere_radius.item()))
                    placed = True
                
                attempts += 1
            
            if not placed:
                print(f"Warning: Could not place sphere {sphere_idx + 1} without overlap after {max_attempts} attempts")
                sphere_radius = torch.tensor(sphere_radius_range[0], device=device, dtype=torch.float32)
                spheres.append((sphere_center, sphere_radius.item()))
        
        print(f"Generated {num_spheres} random spheres inside box:")
        
        distances_to_box = torch.zeros(num_samples, device=device, dtype=torch.float32)
        
        for i in range(num_samples):
            point = points[i]
            
            dist_to_faces = torch.zeros(6, device=device, dtype=torch.float32)
            
            dist_to_faces[0] = torch.abs(point[0] - box_min[0])
            dist_to_faces[1] = torch.abs(point[0] - box_max[0])
            dist_to_faces[2] = torch.abs(point[1] - box_min[1])
            dist_to_faces[3] = torch.abs(point[1] - box_max[1])
            dist_to_faces[4] = torch.abs(point[2] - box_min[2])
            dist_to_faces[5] = torch.abs(point[2] - box_max[2])
           
            inside_box = ((point[0] >= box_min[0]) & (point[0] <= box_max[0]) &
                         (point[1] >= box_min[1]) & (point[1] <= box_max[1]) &
                         (point[2] >= box_min[2]) & (point[2] <= box_max[2]))
            
            if inside_box:
                inside_any_sphere = False
                min_sphere_distance = float('inf')
                
                for sphere_center, sphere_radius in spheres:
                    distance_to_sphere = torch.norm(point - sphere_center)
                    if distance_to_sphere <= sphere_radius:
                        inside_any_sphere = True
                        break
                    min_sphere_distance = min(min_sphere_distance, distance_to_sphere - sphere_radius)
                
                if inside_any_sphere:
                    distances_to_box[i] = -torch.min(dist_to_faces)
                else:
                    distances_to_box[i] = torch.min(dist_to_faces)
            else:
                distances_to_box[i] = -torch.min(dist_to_faces)
        
        optimizer = torch.optim.Adam(self.fmean_network.structure_model.parameters(), lr=1e-3)
        
        print("Initializing Structure Model as box with spheres...")
        for epoch in range(20000):
            pred_sdf, _ = self.fmean_network.structure_model(points/space_boundaries*2-1)
            pred_sdf = pred_sdf.squeeze(-1)
            sdf_loss = F.mse_loss(pred_sdf, distances_to_box.to(pred_sdf.dtype))
            
            eikonal_points = torch.rand(1024, 3, device=device, dtype=torch.float32) * space_boundaries
            normalized_eikonal_points = eikonal_points/space_boundaries*2-1
            normalized_eikonal_points.requires_grad_(True)
            
            eikonal_pred_sdf, _ = self.fmean_network.structure_model(normalized_eikonal_points)
            eikonal_pred_sdf = eikonal_pred_sdf.squeeze(-1)
            
            gradients = torch.autograd.grad(
                outputs=eikonal_pred_sdf,
                inputs=normalized_eikonal_points,
                grad_outputs=torch.ones_like(eikonal_pred_sdf, dtype=torch.float32),
                create_graph=False,
                retain_graph=False
            )[0]
            gradients = torch.clamp(gradients, min=-5.0, max=5.0)
            gradient_norm = gradients.norm(dim=-1)
            eikonal_loss = F.huber_loss(gradient_norm, torch.ones_like(gradient_norm, dtype=torch.float32), delta=0.1)
            
            eikonal_weight = 1.0
            total_loss = sdf_loss+ eikonal_weight * eikonal_loss
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            if sdf_loss.item() < 1e-4:
                break
        
        print("Box with spheres initialization completed!")

    def sphere_trace_differentiable(
        self,
        ray_origins: torch.Tensor,
        ray_directions: torch.Tensor,
        Rx_center: torch.Tensor,
        Rx_radius: float,
        beta: torch.Tensor,
        use_alpha_blending: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        B = ray_origins.shape[0]
        device = ray_origins.device
        space_boundaries = self.space_boundaries.to(device)

        t = torch.zeros(B, device=device, dtype=torch.float32)
        hit_mask = torch.zeros(B, dtype=torch.bool, device=device)
        Rx_mask = torch.zeros(B, dtype=torch.bool, device=device)

        if use_alpha_blending:
            weights = torch.zeros(B, device=device, dtype=torch.float32)
            accumulated_alpha = torch.zeros(B, device=device, dtype=torch.float32)

        min_step = self.epsilon * 0.5

        for step in range(self.max_steps):
            points = ray_origins + t.unsqueeze(-1) * ray_directions

            fmean_values, _ = self.fmean_network.structure_model(points/space_boundaries*2-1)
            fmean_values = fmean_values.squeeze(-1)

            if use_alpha_blending:
                alpha = torch.sigmoid(-beta * fmean_values)

                weights = weights + alpha * (1 - accumulated_alpha)
                accumulated_alpha = accumulated_alpha + alpha * (1 - accumulated_alpha)

                hit_mask = accumulated_alpha > 0.5
            else:
                new_hits = (torch.abs(fmean_values) < self.epsilon) & (~hit_mask)
                hit_mask = hit_mask | new_hits

            too_far = t > self.max_distance
            if (hit_mask | too_far).all():
                break

            step_size = torch.abs(fmean_values) * 0.9
            step_size = torch.clamp(step_size, min=min_step, max=0.5)

            active = ~hit_mask & ~too_far
            t = torch.where(active, t + step_size, t)

        depths = torch.where(hit_mask, t, torch.full_like(t, self.max_distance, dtype=torch.float32))
        Rx_mask, Rx_dist = self.compute_Rxray_intersection_mask(
            ray_origins,
            ray_directions,
            Rx_center,
            Rx_radius,
            depths
        )
        hit_mask = hit_mask & ~Rx_mask

        depths = torch.where(Rx_mask, Rx_dist, depths)


        return depths, hit_mask, Rx_mask


    def compute_normal(self, points: torch.Tensor) -> torch.Tensor:
        space_boundaries = self.space_boundaries.to(points.device)
        normals = self.fmean_network.structure_model.gradient(points/space_boundaries*2-1)
        normals_clamp = torch.clamp(normals, -5.0, 5.0)
        return normals_clamp
        eps = 1e-3
        device = points.device

        offsets = torch.tensor([
            [eps, 0, 0], [-eps, 0, 0],
            [0, eps, 0], [0, -eps, 0],
            [0, 0, eps], [0, 0, -eps]
        ], device=device)

        all_points = points.unsqueeze(1) + offsets.unsqueeze(0)
        all_points = all_points.reshape(-1, 3)

        all_fmean, _, _, _ = self.fmean_network(all_points)
        all_fmean = all_fmean.reshape(points.shape[0], 6, 1)

        grad_x = (all_fmean[:, 0] - all_fmean[:, 1]) / (2 * eps)
        grad_y = (all_fmean[:, 2] - all_fmean[:, 3]) / (2 * eps)
        grad_z = (all_fmean[:, 4] - all_fmean[:, 5]) / (2 * eps)

        normals = torch.stack([grad_x, grad_y, grad_z], dim=-1).squeeze(-2)
        normals = F.normalize(normals, dim=-1)

        return normals

    
    def indentify_valid_paths(
        self,
        num_rays_azimuth: int,
        num_rays_elevation: int,
        Rx_center: torch.Tensor,
        Rx_radius: float,
        Tx_position: torch.Tensor,
        radius: float = 0.3,
        max_bounces: int = 3,
        single_beta: float = 50.0,
        use_soft_rendering: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        device = next(self.fmean_network.parameters()).device
        space_boundaries = self.space_boundaries.to(device)
        
        ray_origins, ray_directions, ray_weights = self.generate_sphere_rays(
            num_rays_azimuth, num_rays_elevation, Tx_position, radius
        )
        B = num_rays_azimuth * num_rays_elevation * 8
        ray_origins = ray_origins.reshape(B, 3)
        ray_directions = ray_directions.reshape(B, 3)
        beta = torch.ones(B, dtype=torch.float32, device=device) * single_beta

        all_origins = torch.zeros(B, 3*max_bounces, device=device, dtype=torch.float32)
        all_directions = torch.zeros(B, 3*max_bounces, device=device, dtype=torch.float32)
        all_depths = torch.zeros(B, max_bounces, device=device, dtype=torch.float32)
        all_diffuse_origins = torch.zeros(B, 3*2, device=device, dtype=torch.float32)
        all_diffuse_directions = torch.zeros(B, 3*2, device=device, dtype=torch.float32)
        all_diffuse_depths = torch.zeros(B, 2, device=device, dtype=torch.float32)
        active_rays = torch.ones(B, dtype=torch.bool, device=device)
        final_Rx_mask = torch.zeros(B, dtype=torch.bool, device=device)
        final_diffuse_mask = torch.zeros(B, dtype=torch.bool, device=device)
        valid_penetration_mask = torch.zeros(B, dtype=torch.bool, device=device)
        final_penetration_mask = torch.ones(B, dtype=torch.bool, device=device)

        all_origins[:, 0:3] = ray_origins
        all_directions[:, 0:3] = ray_directions
        all_diffuse_origins[:, 0:3] = ray_origins
        all_diffuse_directions[:, 0:3] = ray_directions

        direct_mask, direct_direction, direct_distance = self.firsthit_RX_diffuse(Tx_position.unsqueeze(0), Rx_center)

        
        for bounce in range(max_bounces):
            if not active_rays.any():
                break

            current_origins = all_origins[:, bounce*3:(bounce+1)*3]
            current_directions = all_directions[:, bounce*3:(bounce+1)*3]

            active_indices = torch.where(active_rays)[0]

            if bounce == 1:
                diffuse_mask, diffuse_directions, distances = self.firsthit_RX_diffuse(current_origins[active_rays], Rx_center)
                diffuse_indices = active_indices[diffuse_mask]
                final_diffuse_mask[diffuse_indices] = True
                if diffuse_indices.numel() > 0:
                    all_diffuse_origins[diffuse_indices, 3:3*2] = current_origins[active_rays][diffuse_mask]
                    all_diffuse_directions[diffuse_indices, 3:3*2] = diffuse_directions[diffuse_mask]
                    all_diffuse_depths[diffuse_indices, 1] = distances[diffuse_mask]
                

            depths, hit_mask, Rx_mask = self.sphere_trace_differentiable(
                current_origins[active_rays],
                current_directions[active_rays],
                Rx_center,
                Rx_radius,
                beta[active_rays],
                use_alpha_blending=use_soft_rendering
            )

            hit_indices = active_indices[hit_mask]

            if hit_indices.numel() > 0 and bounce < max_bounces - 1:

                hit_points = current_origins[hit_indices] + \
                            depths[hit_mask].unsqueeze(-1) * current_directions[hit_indices]
                hit_points.requires_grad_(True)
                normals = self.compute_normal(hit_points)

                d_dot_n = (current_directions[hit_indices] * normals).sum(dim=-1, keepdim=True)
                reflected_dirs = current_directions[hit_indices] - 2 * d_dot_n * normals
                reflected_dirs = F.normalize(reflected_dirs, dim=-1)

                _, _, anisotropic_hit, _ = self.fmean_network(hit_points/space_boundaries*2-1)
                sigma_hit_reflection = anisotropic_hit * (torch.abs(torch.sum(reflected_dirs * normals, dim=-1, keepdim=True))/(torch.norm(normals, dim=-1, keepdim=True) + 1e-6)) + (1-anisotropic_hit)/2
                sigma_hit_penetration = anisotropic_hit * (torch.abs(torch.sum(current_directions[hit_indices] * normals, dim=-1, keepdim=True))/(torch.norm(normals, dim=-1, keepdim=True) + 1e-6)) + (1-anisotropic_hit)/2

                prob_penetration = sigma_hit_reflection / (sigma_hit_reflection + sigma_hit_penetration + 1e-6)
                rand_vals = torch.rand_like(prob_penetration, device=device, dtype=torch.float32)
                reflection_mask = (rand_vals >= prob_penetration).squeeze(-1)
                penetration_mask = ~reflection_mask
                reflection_indices = hit_indices[reflection_mask]
                penetration_indices = hit_indices[penetration_mask]
                if penetration_indices.numel() > 0:
                    valid_penetration_mask[penetration_indices] = True
                final_penetration_mask = final_penetration_mask & valid_penetration_mask

                if reflection_indices.numel() > 0:
                    all_origins[reflection_indices, 3*(bounce+1):3*(bounce+2)] = (hit_points[reflection_mask] + normals[reflection_mask] * self.epsilon * 2)
                    all_directions[reflection_indices, 3*(bounce+1):3*(bounce+2)] = reflected_dirs[reflection_mask]
                if penetration_indices.numel() > 0:
                    all_origins[penetration_indices, 3*(bounce+1):3*(bounce+2)] = (hit_points[penetration_mask] - normals[penetration_mask] * self.epsilon * 2)
                    all_directions[penetration_indices, 3*(bounce+1):3*(bounce+2)] = current_directions[hit_indices][penetration_mask]

                flag = torch.zeros_like(beta, dtype=torch.float32)
                if reflection_indices.numel() > 0:
                    flag[reflection_indices] = 1.0
                if penetration_indices.numel() > 0:
                    flag[penetration_indices] = -1.0
                beta = beta * flag

            if hit_indices.numel() > 0:
                all_depths[hit_indices, bounce] = depths[hit_mask]
            if Rx_mask.sum() > 0:
                all_depths[active_indices, bounce][Rx_mask] = depths[Rx_mask]
            if bounce == 0:
                all_diffuse_depths[hit_indices, 0] = depths[hit_mask]

            if active_indices.numel() > 0:
                final_Rx_mask[active_indices] = final_Rx_mask[active_indices] | Rx_mask
                
            miss_indices = active_indices[~hit_mask]
            if miss_indices.numel() > 0:
                new_active_rays = active_rays.clone()
                new_active_rays[miss_indices] = False
                active_rays = new_active_rays
        
        final_penetration_mask = final_penetration_mask & final_Rx_mask

        return all_origins, all_directions, all_depths, final_Rx_mask, final_diffuse_mask, all_diffuse_origins, all_diffuse_directions, all_diffuse_depths, direct_mask, direct_direction, direct_distance, final_penetration_mask, ray_weights
       

    def generate_sphere_rays(
        self,
        num_rays_azimuth: int,
        num_rays_elevation: int,
        Tx_position: torch.Tensor,
        radius: float = 0.3,
        device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if device is None:
            device = next(self.fmean_network.parameters()).device
            
        rays_per_area = 8

        azimuth_bins = torch.linspace(0, 360, num_rays_azimuth + 1, device=device, dtype=torch.float32)
        elevation_bins = torch.linspace(0, 180, num_rays_elevation + 1, device=device, dtype=torch.float32)
        
        total_areas = num_rays_azimuth * num_rays_elevation
        total_rays = total_areas * rays_per_area
        
        theta_mins = azimuth_bins[:-1].repeat_interleave(num_rays_elevation)
        theta_maxs = azimuth_bins[1:].repeat_interleave(num_rays_elevation)
        phi_mins = elevation_bins[:-1].repeat(num_rays_azimuth)
        phi_maxs = elevation_bins[1:].repeat(num_rays_azimuth)
        
        theta_ranges = theta_maxs - theta_mins
        phi_ranges = phi_maxs - phi_mins
        
        theta_rand = torch.rand(total_areas, rays_per_area, device=device, dtype=torch.float32)
        phi_rand = torch.rand(total_areas, rays_per_area, device=device, dtype=torch.float32)
        
        theta_samples = theta_rand * theta_ranges.unsqueeze(-1) + theta_mins.unsqueeze(-1)
        phi_samples = phi_rand * phi_ranges.unsqueeze(-1) + phi_mins.unsqueeze(-1)
        
        theta_flat = theta_samples.reshape(-1) * np.pi / 180.0
        phi_flat = phi_samples.reshape(-1) * np.pi / 180.0

        x = radius * torch.cos(phi_flat) * torch.cos(theta_flat)
        y = radius * torch.cos(phi_flat) * torch.sin(theta_flat)
        z = radius * torch.sin(phi_flat)

        sphere_points = torch.stack([x, y, z], dim=-1)
        ray_origins_world = sphere_points + Tx_position.unsqueeze(0)

        ray_directions_world = F.normalize(ray_origins_world - Tx_position.unsqueeze(0), dim=-1)

        space_boundaries = self.space_boundaries.to(device) 
        tx_position_expanded = Tx_position.unsqueeze(0).expand(ray_directions_world.shape[0], -1)
        tx_position_expanded = tx_position_expanded / space_boundaries * 2 - 1
        ray_weights_world = torch.abs(F.leaky_relu(self.weights_model(ray_directions_world, tx_position_expanded), negative_slope=0.01))
        nor_ray_weights = ray_weights_world / ray_weights_world.sum()

        return ray_origins_world, ray_directions_world, nor_ray_weights

    def firsthit_RX_diffuse(
        self,
        first_hit_points: torch.Tensor,
        Rx_center: torch.Tensor,
    ) -> torch.Tensor:
        B = first_hit_points.shape[0]
        device = first_hit_points.device
        space_boundaries = self.space_boundaries.to(device)
        Rx_center = Rx_center.expand(B, 3)
        t = torch.zeros(B, dtype=torch.float32, device=first_hit_points.device)
        diffuse_mask = torch.zeros(B, dtype=torch.bool, device=first_hit_points.device)
        diffuse_directions = F.normalize(Rx_center - first_hit_points, dim=-1)
        depth = torch.norm(Rx_center - first_hit_points, dim=-1)
        for step in range(self.max_steps):
            points = first_hit_points + diffuse_directions * t.unsqueeze(-1)
            distances = torch.norm(Rx_center - points, dim=-1)
            fmean_values, _ = self.fmean_network.structure_model(points/space_boundaries*2-1)
            fmean_values = fmean_values.squeeze(-1)
            diffuse_mask = diffuse_mask | (fmean_values > distances)
            step_size = fmean_values * 0.9
            step_size = torch.clamp(step_size, min=self.epsilon*0.02, max=0.5)
            active = ~diffuse_mask
            t = torch.where(active, t + step_size, t)

        return diffuse_mask, diffuse_directions, depth


    def compute_Rxray_intersection_mask(
        self,
        ray_origins: torch.Tensor,
        ray_directions: torch.Tensor,
        rx_center: torch.Tensor,
        rx_radius: float,
        step_size: float
    ) -> torch.Tensor:
        device = ray_origins.device
        B = ray_origins.shape[0]
        
        rx_center = rx_center.expand(B, 3)
        
        to_rx = rx_center - ray_origins
        
        projection = torch.sum(to_rx * ray_directions, dim=-1)
        
        if isinstance(step_size, (int, float)):
            step_size_tensor = torch.full_like(projection, step_size, dtype=torch.float32)
        else:
            step_size_tensor = step_size
        min_tensor = torch.zeros_like(projection, dtype=torch.float32)
        projection_clamped = torch.clamp(projection, min=min_tensor, max=step_size_tensor)
        closest_point = ray_origins + projection_clamped.unsqueeze(-1) * ray_directions
        
        distance_to_rx = torch.norm(closest_point - rx_center, dim=-1)
        
        rx_mask = distance_to_rx <= rx_radius
        rx_dist = torch.where(rx_mask, torch.norm(rx_center - ray_origins, dim=-1), torch.full_like(distance_to_rx, float('inf'), dtype=torch.float32))
        
        return rx_mask, rx_dist

    
    def render_RSSI(
        self,
        Tx_positions: torch.Tensor,
        Rx_positions: torch.Tensor,
        Tx_signals: torch.Tensor,
        valid_paths_origins: torch.Tensor,
        valid_paths_directions: torch.Tensor,
        valid_paths_depths: torch.Tensor,
        diffuse_paths_origins: torch.Tensor,
        diffuse_paths_directions: torch.Tensor,
        diffuse_paths_depths: torch.Tensor,
        direct_mask: torch.Tensor,
        direct_direction: torch.Tensor,
        direct_distance: torch.Tensor,
        penetration_paths_origins: torch.Tensor,
        penetration_paths_directions: torch.Tensor,
        penetration_paths_depths: torch.Tensor,
        sample_spacing: float,
        num_rays_azimuth: int,
        num_rays_elevation: int,
        ray_weights_valid: torch.Tensor,
        ray_weights_diffuse: torch.Tensor,
        ray_weights_penetration: torch.Tensor,
        max_bounces: int = 3,
    ) -> torch.Tensor:
        
        L = valid_paths_origins.shape[0]
        D = diffuse_paths_origins.shape[0]
        T = direct_mask.shape[0]
        P = penetration_paths_origins.shape[0]

        total_attenuation_direct = self.process_direct_paths(
            Tx_positions.unsqueeze(0),
            direct_direction,
            direct_distance,
        )
        total_attenuation_penetration = self.process_penetration_paths(
            penetration_paths_origins,
            penetration_paths_directions,
            penetration_paths_depths,
            sample_spacing,
            ray_weights_penetration,
        )
        total_attenuation_valid = self.process_valid_paths(
            valid_paths_origins,
            valid_paths_directions,
            valid_paths_depths,
            ray_weights_valid,
            max_bounces,
        )
        total_attenuation_diffuse = self.process_diffuse_paths(
            diffuse_paths_origins,
            diffuse_paths_directions,
            diffuse_paths_depths,
            ray_weights_diffuse,
            2,
        )

        total_rays = num_rays_azimuth * num_rays_elevation * 8
        received_signal = (total_attenuation_valid + total_attenuation_diffuse + total_attenuation_direct/total_rays + total_attenuation_penetration) * 1.0
        RSSI = -100 * (1 - torch.abs(received_signal))

        return RSSI
    
    def process_direct_paths(
        self,
        paths_origins: torch.Tensor,
        paths_directions: torch.Tensor,
        paths_depths: torch.Tensor,
    ) -> torch.Tensor:
        if paths_origins.shape[0] > 0:
            return torch.tensor(1.0, device=paths_origins.device, dtype=torch.float32)
        else:
            return torch.tensor(0.0, device=paths_origins.device, dtype=torch.float32)  
    
    def process_penetration_paths(
        self,
        paths_origins: torch.Tensor,
        paths_directions: torch.Tensor,
        paths_depths: torch.Tensor,
        sample_spacing: float,
        ray_weights: torch.Tensor,
    ) -> torch.Tensor:
        B = paths_origins.shape[0]
        device = paths_origins.device
        
        if B > 0:
            penetration_log_attenuation = self.process_paths(
                paths_origins[:, 3:6],
                paths_directions[:, 3:6],
                paths_depths[:, 1],
                1,
                sample_spacing,
            )
            penetration_total_attenuation = torch.sum(torch.exp(-penetration_log_attenuation).unsqueeze(-1) * ray_weights)
        else:
            penetration_total_attenuation = torch.tensor(0.0, device=device, dtype=torch.float32)

        return penetration_total_attenuation
    
    def process_diffuse_paths(
        self,
        paths_origins: torch.Tensor,
        paths_directions: torch.Tensor,
        paths_depths: torch.Tensor,
        ray_weights: torch.Tensor,
        diffuse_bounces: int = 2,
    ) -> torch.Tensor:
        B = paths_origins.shape[0]
        device = paths_origins.device
        epsilon = self.epsilon
        space_boundaries = self.space_boundaries.to(device)
        if B > 0:
            points = paths_origins[:, 3:6]
            directions = paths_directions[:, 3:6]
            fmean_values, materials, anistropics, _ = self.fmean_network(points/space_boundaries*2-1)
            normals = self.compute_normal(points)
            sigmas = self.calculate_attenuation(
                fmean_values, materials, anistropics, normals, directions, scaling=1.0
            )
            sigmas = sigmas.squeeze(-1)
            diffuse_attenuation = torch.sum(torch.exp(-sigmas * 2 * epsilon).unsqueeze(-1) * ray_weights)
        else:
            diffuse_attenuation = torch.tensor(0.0, device=device, dtype=torch.float32)


        return diffuse_attenuation
        

    def process_valid_paths(
        self,
        paths_origins: torch.Tensor,
        paths_directions: torch.Tensor,
        paths_depths: torch.Tensor,
        ray_weights: torch.Tensor,
        max_bounces: int = 3,
    ) -> torch.Tensor:

        B = paths_origins.shape[0]
        device = paths_origins.device
        epsilon = self.epsilon
        space_boundaries = self.space_boundaries.to(device)
        if B > 0:
            points = torch.cat([paths_origins[:, 3:6], paths_origins[:, 6:9]], dim=0)
            directions = torch.cat([paths_directions[:, 3:6], paths_directions[:, 6:9]], dim=0)
            fmean_values, materials, anistropics, _ = self.fmean_network(points/space_boundaries*2-1)
            normals = self.compute_normal(points)
            sigmas = self.calculate_attenuation(
                fmean_values, materials, anistropics, normals, directions, scaling=1.0
            )
            sigmas = sigmas.squeeze(-1)
            ray_weights_expanded = torch.cat([ray_weights, ray_weights], dim=0)
            valid_attenuation = torch.sum(torch.exp(-sigmas * 2 * epsilon).unsqueeze(-1) * ray_weights_expanded)/2
        else:
            valid_attenuation = torch.tensor(0.0, device=paths_origins.device, dtype=torch.float32)

        return valid_attenuation


    def process_paths(
        self,
        paths_origins: torch.Tensor,
        paths_directions: torch.Tensor,
        paths_depths: torch.Tensor,
        max_bounces: int = 3,
        sample_spacing: float = 0.02,
    ) -> torch.Tensor:
        """Integrate the attenuation coefficient along each multi-segment path.

        The inputs are reshaped to ``(N, max_bounces, 3)`` (origins, directions)
        and ``(N, max_bounces)`` (depths), where N is ``paths_origins.shape[0]``.
        Along every segment, sample distances ``t = 0, s, 2s, ...`` (``s`` is
        ``sample_spacing``) are kept while ``t < depth``. The field network is
        evaluated at ``origin + t * direction``, after the mapping
        ``p / space_boundaries * 2 - 1``. The attenuation from
        ``calculate_attenuation`` is summed per path and multiplied by ``s``.

        Args:
            paths_origins: segment start points, reshapeable to (N, max_bounces, 3).
            paths_directions: segment directions, reshapeable to (N, max_bounces, 3).
            paths_depths: segment lengths, reshapeable to (N, max_bounces).
            max_bounces: number of segments per path.
            sample_spacing: distance between consecutive samples along a segment.

        Returns:
            A float32 tensor of shape (N,) holding the summed attenuation per
            path. Returns the scalar tensor 50.0 when ``paths_depths`` is empty.
        """
        num_paths = paths_origins.shape[0]
        dev = paths_origins.device
        bounds = self.space_boundaries.to(dev)

        if paths_depths.numel() == 0:
            return torch.tensor(50.0, device=dev, dtype=torch.float32)

        seg_origins = paths_origins.reshape(num_paths, max_bounces, 3)
        seg_dirs = paths_directions.reshape(num_paths, max_bounces, 3)
        seg_depths = paths_depths.reshape(num_paths, max_bounces)

        # Number of grid points needed to cover the longest segment; the
        # ceiling is taken on a float32 value.
        ratio = torch.tensor(seg_depths.max().item() / sample_spacing, dtype=torch.float32)
        num_steps = int(torch.ceil(ratio).item())

        # One shared grid of distances, broadcast to every segment of every path.
        t_grid = torch.arange(0, num_steps, device=dev, dtype=torch.float32) * sample_spacing
        t_grid = t_grid.view(1, 1, -1).expand(num_paths, max_bounces, -1)
        inside = t_grid < seg_depths.unsqueeze(-1)  # (N, max_bounces, num_steps)

        grid_points = seg_origins.unsqueeze(2) + seg_dirs.unsqueeze(2) * t_grid.unsqueeze(-1)
        points = grid_points[inside]
        point_dirs = seg_dirs.unsqueeze(2).expand_as(grid_points)[inside]
        # NOTE(review): presumably required so compute_normal can differentiate
        # with respect to the sample points; confirm against compute_normal.
        points.requires_grad_(True)

        fmean, material, anisotropy, _ = self.fmean_network(points / bounds * 2 - 1)
        normals = self.compute_normal(points)
        sigma = self.calculate_attenuation(
            fmean, material, anisotropy, normals, point_dirs, scaling=1.0
        ).squeeze(-1)

        # Scatter the per-sample values back onto the dense grid and sum per path.
        dense_sigma = torch.zeros(num_paths, max_bounces, num_steps, device=dev, dtype=torch.float32)
        dense_sigma[inside] = sigma
        return dense_sigma.sum(dim=(1, 2)) * sample_spacing


    def compute_hemisphere_mask(
        self,
        paths_directions: torch.Tensor,
    ) -> torch.Tensor:
        """Flag direction vectors whose component 2 is non-negative.

        Args:
            paths_directions: direction vectors stored along the last axis,
                e.g. shape ``(N, 3)``; batched ``(..., 3)`` inputs are also
                accepted. Component 2 is presumably the z (up) axis.

        Returns:
            A float32 tensor with the leading shape of ``paths_directions``:
            1.0 where component 2 is ``>= 0`` and 0.0 elsewhere (NaN compares
            false and therefore maps to 0.0).
        """
        # Index the last axis so batched inputs work as well as (N, 3).
        return (paths_directions[..., 2] >= 0).float()
                



    def calculate_attenuation(
        self,
        fmean_values: torch.Tensor,
        materials: torch.Tensor,
        anisotropic: torch.Tensor,
        normals: torch.Tensor,
        dirs: torch.Tensor,
        scaling: float = 1.0
    ) -> torch.Tensor:
        """Convert field values and local geometry into an attenuation coefficient.

        With ``x = |scaling * fmean_values|`` this evaluates

        * ``pdf = scaling / sqrt(pi) * exp(-x**2)``
        * ``cdf = 0.5 + 0.5 * erf(x)``
        * ``coeff = ((materials * cdf) + 1) * scaling * pdf * |normals| / (cdf + 1e-6)``
        * ``sigma_D = |dirs . normals| / (|normals| + 1e-6) + (1 - anisotropic) / 2``

        and returns ``coeff * sigma_D``. Norms and dot products reduce the last
        axis with ``keepdim=True``. For ``(N, 3)`` normals/dirs and ``(N, 1)``
        field tensors the result is ``(N, 1)``; these shapes are assumed from
        the callers (TODO confirm).
        """
        inv_sqrt_pi = 1.0 / math.sqrt(math.pi)

        abs_scaled = torch.abs(scaling * fmean_values)
        density = scaling * inv_sqrt_pi * torch.exp(-torch.square(abs_scaled))
        cumulative = 0.5 + 0.5 * torch.erf(abs_scaled)

        # Norm of the (possibly unnormalised) normals, shared by both factors.
        normal_len = torch.norm(normals, dim=-1, keepdim=True)

        magnitude = ((materials * cumulative) + 1) * scaling * density * normal_len / (cumulative + 1e-6)

        abs_cosine = torch.abs(torch.sum(dirs * normals, dim=-1, keepdim=True)) / (normal_len + 1e-6)
        directional = abs_cosine + (1 - anisotropic) / 2

        return magnitude * directional


    def forward(
        self,
        Tx_positions: torch.Tensor,
        Rx_positions: torch.Tensor,
        Rx_radius: float,
        Tx_signals: torch.Tensor,
        num_rays_azimuth: int,
        num_rays_elevation: int,
        sample_spacing: float,
        use_soft_rendering: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict one RSSI value per Tx/Rx pair in the batch.

        For each batch index ``i``, ``indentify_valid_paths`` traces rays from
        ``Tx_positions[i]`` (3 bounces, ``single_beta=50.0``). The selected
        paths are then passed to ``render_RSSI`` together with
        ``Tx_signals[i].squeeze()``. The selection uses three masks:

        * ``final_Rx_mask & ~final_penetration_mask`` for the first path group;
        * ``final_diffuse_mask`` for the diffuse group;
        * ``final_penetration_mask`` for the penetration group.

        Returns:
            ``(pred_RSSI, diffuse_surface_origins)``.

            * ``pred_RSSI`` is a float32 tensor of shape ``(B,)``.
            * ``diffuse_surface_origins`` stacks columns ``3:6`` of the selected
              diffuse origins from every batch element, in batch order. It is
              an empty ``(0, 3)`` tensor when B is 0.
        """
        B = Tx_positions.shape[0]
        device = Tx_positions.device
        pred_RSSI = torch.zeros(B, device=device, dtype=torch.float32)
        # Collected per batch element so no batch element's points are dropped.
        diffuse_surface_chunks = []

        for i in range(B):
            Tx_position = Tx_positions[i]
            Rx_position = Rx_positions[i]
            Tx_signal = Tx_signals[i].squeeze()

            (
                all_origins, all_directions, all_depths,
                final_Rx_mask, final_diffuse_mask,
                all_diffuse_origins, all_diffuse_directions, all_diffuse_depths,
                direct_mask, direct_direction, direct_distance,
                final_penetration_mask, ray_weights,
            ) = self.indentify_valid_paths(
                num_rays_azimuth,
                num_rays_elevation,
                Rx_position,
                Rx_radius,
                Tx_position,
                max_bounces=3,
                single_beta=50.0,
                use_soft_rendering=use_soft_rendering
            )

            # Receiver-reaching paths that are not flagged as penetrating.
            rx_only_mask = final_Rx_mask & ~final_penetration_mask
            diffuse_origins = all_diffuse_origins[final_diffuse_mask]

            pred_RSSI[i] = self.render_RSSI(
                Tx_position,
                Rx_position,
                Tx_signal,
                all_origins[rx_only_mask],
                all_directions[rx_only_mask],
                all_depths[rx_only_mask],
                diffuse_origins,
                all_diffuse_directions[final_diffuse_mask],
                all_diffuse_depths[final_diffuse_mask],
                direct_mask,
                direct_direction,
                direct_distance,
                all_origins[final_penetration_mask],
                all_directions[final_penetration_mask],
                all_depths[final_penetration_mask],
                sample_spacing,
                num_rays_azimuth,
                num_rays_elevation,
                ray_weights[rx_only_mask],
                ray_weights[final_diffuse_mask],
                ray_weights[final_penetration_mask],
            )
            diffuse_surface_chunks.append(diffuse_origins[:, 3:6])

        if diffuse_surface_chunks:
            diffuse_surface_origins = torch.cat(diffuse_surface_chunks, dim=0)
        else:
            diffuse_surface_origins = torch.zeros(0, 3, device=device, dtype=torch.float32)

        return pred_RSSI, diffuse_surface_origins



class RSSILoss(nn.Module):
    """Scalar regression loss between predicted and measured RSSI.

    Args:
        loss_type: one of

            * ``'l1'`` (``F.l1_loss``);
            * ``'l2'`` (``F.mse_loss``);
            * ``'huber'`` (``F.smooth_l1_loss`` with its default ``beta``);
            * ``'relative'`` (mean of ``|pred - meas| / (|meas| + 1e-6)``).

            Any other value raises ``ValueError`` when ``forward`` is called.
    """

    def __init__(self, loss_type: str = 'huber'):
        super().__init__()
        self.loss_type = loss_type

    def forward(
        self,
        pred_RSSI: torch.Tensor,
        measured_RSSI: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean loss of ``pred_RSSI`` against ``measured_RSSI``.

        Raises:
            ValueError: if ``self.loss_type`` is not a supported loss name.
        """
        if self.loss_type == 'l1':
            return F.l1_loss(pred_RSSI, measured_RSSI)
        if self.loss_type == 'l2':
            return F.mse_loss(pred_RSSI, measured_RSSI)
        if self.loss_type == 'huber':
            return F.smooth_l1_loss(pred_RSSI, measured_RSSI)
        if self.loss_type == 'relative':
            eps = 1e-6
            # Normalise by the magnitude of the measurement. RSSI values can be
            # negative (e.g. in dBm); a signed denominator would make the loss
            # negative there, so minimising it would reward larger errors.
            relative_error = torch.abs(pred_RSSI - measured_RSSI) / (torch.abs(measured_RSSI) + eps)
            return relative_error.mean()
        raise ValueError(f"Unknown loss type: {self.loss_type}")



class NeuralRFOptimizer(nn.Module):
    """Training wrapper around ``DifferentiableRFrenderer``.

    Builds the renderer, forwarding the architecture and initialisation
    arguments, and owns an ``RSSILoss``. ``compute_total_loss`` adds an eikonal
    regulariser and a free-space regulariser to the RSSI data term.
    """

    # Fallback eikonal sampling box, equal to the constructor defaults.
    _DEFAULT_BOX_SIZE = (17.0, 13.0, 3.0)
    _DEFAULT_BOX_CENTER = (8.5, 8.0, 1.5)

    def __init__(
        self,
        structure_hidden_dim: int = 64,
        material_hidden_dim: int = 128,
        num_layers: int = 8,
        max_steps: int = 128,
        loss_type: str = 'huber',
        use_sphere_init: bool = False,
        use_box_init: bool = True,
        sphere_center: torch.Tensor = None,
        box_size: torch.Tensor = torch.tensor([17.0, 13.0, 3.0], dtype=torch.float32),
        box_center: torch.Tensor = torch.tensor([8.5, 8.0, 1.5], dtype=torch.float32),
        num_spheres: int = 3,
        sphere_radius_range: tuple = (0.8, 1.0),
        space_boundaries: torch.Tensor = torch.tensor([20.0, 20.0, 3.0], dtype=torch.float32).reshape(1, 3)
    ):
        super().__init__()
        self.renderer = DifferentiableRFrenderer(
            structure_hidden_dim=structure_hidden_dim,
            material_hidden_dim=material_hidden_dim,
            num_layers=num_layers,
            max_steps=max_steps,
            use_sphere_init=use_sphere_init,
            use_box_init=use_box_init,
            sphere_center=sphere_center,
            box_size=box_size,
            box_center=box_center,
            num_spheres=num_spheres,
            sphere_radius_range=sphere_radius_range,
            space_boundaries=space_boundaries
        )
        self.RSSI_loss = RSSILoss(loss_type=loss_type)
        # Private copies of the box used for uniform eikonal samples, so the
        # regulariser honours the constructor arguments. They are plain
        # attributes (not buffers), so existing state_dicts keep loading.
        self.eikonal_box_size = self._as_vector(box_size, self._DEFAULT_BOX_SIZE)
        self.eikonal_box_center = self._as_vector(box_center, self._DEFAULT_BOX_CENTER)

    @staticmethod
    def _as_vector(value, default) -> torch.Tensor:
        """Return ``value`` (``default`` when None) as a detached 1-D float32 CPU tensor."""
        if value is None:
            value = default
        return torch.as_tensor(value, dtype=torch.float32).detach().clone().cpu().reshape(-1)

    def _eikonal_box_bounds(self, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(box_min, box_max)`` of the uniform eikonal sampling box on ``device``."""
        # getattr fallbacks cover instances created without __init__
        # (e.g. objects pickled before these attributes existed).
        size = self._as_vector(getattr(self, 'eikonal_box_size', None), self._DEFAULT_BOX_SIZE).to(device)
        center = self._as_vector(getattr(self, 'eikonal_box_center', None), self._DEFAULT_BOX_CENTER).to(device)
        return center - size / 2, center + size / 2

    def to_half(self):
        """Return ``self`` unchanged.

        Despite its name, this does not convert anything to half precision.
        The previous implementation only re-assigned the structure and
        material sub-models to themselves, which is a no-op. The method is
        kept so existing callers continue to work.
        """
        return self

    def forward(
        self,
        Tx_positions: torch.Tensor,
        Rx_positions: torch.Tensor,
        Tx_signals: torch.Tensor,
        Rx_radius: float,
        num_rays_azimuth: int,
        num_rays_elevation: int,
        sample_spacing: float,
        use_soft_rendering: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the renderer and return its ``(pred_RSSI, diffuse_surface_origins)``.

        Note the argument order differs from the renderer's: ``Tx_signals``
        comes before ``Rx_radius`` here and is re-ordered on the call below.
        """
        return self.renderer(
            Tx_positions,
            Rx_positions,
            Rx_radius,
            Tx_signals,
            num_rays_azimuth,
            num_rays_elevation,
            sample_spacing,
            use_soft_rendering
        )

    def compute_total_loss(
        self,
        Tx_positions: torch.Tensor,
        Rx_positions: torch.Tensor,
        Tx_signals: torch.Tensor,
        Rx_radius: float,
        num_rays_azimuth: int,
        num_rays_elevation: int,
        sample_spacing: float,
        measured_RSSI: torch.Tensor,
        space_boundaries: torch.Tensor,
        use_soft_rendering: bool = True,
        eikonal_weight: float = 1.0,
        laplacian_weight_fmean: float = 0.001,
        laplacian_weight_material: float = 0.001,
        laplacian_weight_anistropic: float = 0.001,
        laplacian_weight_phase: float = 0.001,
        free_space_weight: float = 1.0,
        all_tx_positions: torch.Tensor = None,
        all_rx_positions: torch.Tensor = None,
        num_eikonal_points: int = 256,
        num_surface_eikonal_points: int = 64,
    ) -> dict:
        """Render a batch and combine the RSSI, eikonal and free-space losses.

        Args:
            measured_RSSI: targets compared with the rendered RSSI via ``self.RSSI_loss``.
            space_boundaries: accepted for API compatibility but ignored; the
                renderer's own ``space_boundaries`` is used.
            laplacian_weight_fmean, laplacian_weight_material,
            laplacian_weight_anistropic, laplacian_weight_phase: accepted for
                API compatibility; no Laplacian term is computed here.
            eikonal_weight, free_space_weight: weights of the regularisers.
            all_tx_positions, all_rx_positions: when both are given, the
                free-space term uses them instead of the batch positions.
            num_eikonal_points: uniform eikonal samples drawn in the box given
                by the constructor's ``box_size``/``box_center``.
            num_surface_eikonal_points: maximum number of diffuse surface
                points (second output of ``forward``) added to the eikonal
                samples. A random subset is taken when more are available.
            The remaining arguments are forwarded to ``forward``.

        Returns:
            dict with keys ``'total'``, ``'RSSI'``, ``'eikonal'``,
            ``'free_space'`` and ``'pred_RSSI'``.
        """
        pred_RSSI, diffuse_surface_origins = self.forward(
            Tx_positions,
            Rx_positions,
            Tx_signals,
            Rx_radius,
            num_rays_azimuth,
            num_rays_elevation,
            sample_spacing,
            use_soft_rendering
        )
        RSSI_loss = self.RSSI_loss(pred_RSSI, measured_RSSI)

        device = Tx_positions.device
        box_min, box_max = self._eikonal_box_bounds(device)
        uniform_points = torch.rand(num_eikonal_points, 3, device=device, dtype=torch.float32)
        uniform_points = box_min + uniform_points * (box_max - box_min)

        surface_points = diffuse_surface_origins
        if surface_points.shape[0] > num_surface_eikonal_points:
            keep = torch.randperm(surface_points.shape[0], device=surface_points.device)[:num_surface_eikonal_points]
            surface_points = surface_points[keep]
        eikonal_loss = self.compute_eikonal_loss(torch.cat([uniform_points, surface_points], dim=0))

        if all_tx_positions is not None and all_rx_positions is not None:
            free_space_loss = self.compute_free_space_loss(all_tx_positions, all_rx_positions)
        else:
            free_space_loss = self.compute_free_space_loss(Tx_positions, Rx_positions)

        total_loss = (
            RSSI_loss +
            eikonal_weight * eikonal_loss +
            free_space_weight * free_space_loss
        )

        return {
            'total': total_loss,
            'RSSI': RSSI_loss,
            'eikonal': eikonal_loss,
            'free_space': free_space_loss,
            'pred_RSSI': pred_RSSI,
        }

    def compute_eikonal_loss(self, eikonal_points):
        """Huber penalty pulling the structure field's gradient norm towards 1.

        The points are mapped with ``p / space_boundaries * 2 - 1`` before
        ``structure_model`` is evaluated, whose second output is unused. The
        gradient is taken with respect to the mapped points, each component is
        clamped to ``[-5, 5]``, and ``huber_loss(|grad|, 1, delta=0.1)`` is
        returned.

        The gradient is built with ``create_graph=True`` so that the loss is
        differentiable with respect to the network parameters. Otherwise the
        regulariser contributes nothing to training.
        """
        device = eikonal_points.device
        space_boundaries = self.renderer.space_boundaries.to(device)

        # Sample locations are constants for this regulariser; detaching stops
        # the second-order graph from flowing back into whatever produced them.
        normalized_points = (eikonal_points.detach() / space_boundaries * 2 - 1).requires_grad_(True)

        fmean_values, _ = self.renderer.fmean_network.structure_model(normalized_points)

        gradients = torch.autograd.grad(
            outputs=fmean_values,
            inputs=normalized_points,
            grad_outputs=torch.ones_like(fmean_values),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = torch.clamp(gradients, min=-5.0, max=5.0)
        gradient_norm = gradients.norm(dim=-1)
        return F.huber_loss(gradient_norm, torch.ones_like(gradient_norm), delta=0.1)

    def compute_free_space_loss(self, Tx_positions: torch.Tensor, Rx_positions: torch.Tensor,
                               radius: float = 0.25, num_samples: int = 8):
        """Penalise negative structure-field values near the transmitters and receivers.

        For each of the first ``B = Tx_positions.shape[0]`` Tx/Rx pairs,
        ``num_samples`` points are drawn uniformly in a cube of half-width
        ``radius`` around each position. The loss is
        ``mean(relu(-sdf_tx)) + mean(relu(-sdf_rx))``. An empty batch yields a
        zero scalar.
        """
        device = Tx_positions.device
        B = Tx_positions.shape[0]
        if B == 0:
            return torch.zeros((), device=device, dtype=torch.float32)
        space_boundaries = self.renderer.space_boundaries.to(device)

        tx_sample_points = self._generate_cube_samples(Tx_positions, radius, num_samples, device)
        # Only the first B receivers pair with the B transmitters.
        rx_sample_points = self._generate_cube_samples(Rx_positions[:B], radius, num_samples, device)

        tx_sdf_values, _ = self.renderer.fmean_network.structure_model(tx_sample_points/space_boundaries*2-1)
        tx_sdf_values = tx_sdf_values.squeeze(-1)

        rx_sdf_values, _ = self.renderer.fmean_network.structure_model(rx_sample_points/space_boundaries*2-1)
        rx_sdf_values = rx_sdf_values.squeeze(-1)

        tx_free_space_loss = torch.mean(F.relu(-tx_sdf_values))
        rx_free_space_loss = torch.mean(F.relu(-rx_sdf_values))
        return tx_free_space_loss + rx_free_space_loss

    def _generate_cube_samples(self, center: torch.Tensor, radius: float,
                              num_samples: int, device: torch.device) -> torch.Tensor:
        """Draw points uniformly in axis-aligned cubes of half-width ``radius``.

        Args:
            center: a single ``(3,)`` point or a stack of points ``(..., 3)``.
            radius: half edge length of each cube.
            num_samples: points drawn per center.
            device: device for the random offsets.

        Returns:
            A ``(num_centers * num_samples, 3)`` tensor, grouped by center in
            input order. For a single ``(3,)`` center this is
            ``(num_samples, 3)``.
        """
        centers = center.reshape(-1, 3)
        offsets = torch.rand(centers.shape[0], num_samples, 3, device=device, dtype=torch.float32) * (2 * radius) - radius
        return (offsets + centers.unsqueeze(1)).reshape(-1, 3)