from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import numpy as np

class StructureModel(nn.Module):
    """Coordinate MLP that predicts a scalar field ``fmean`` plus a feature vector.

    The trunk is ``num_layers`` linear layers of width ``hidden_dim``. At each
    layer index listed in ``skip_connections`` the (optionally positionally
    encoded) input is concatenated back onto the hidden state before that
    layer. Every trunk layer is followed by ReLU except the last, which uses
    softplus. Two linear heads read the final hidden state: ``fmean_layer``
    (1 output) and ``feature_layer`` (``feature_dim`` outputs).
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        num_layers: int = 8,
        skip_connections: Optional[list] = None,
        geometric_init: bool = True,
        use_positional_encoding: bool = False,
        encoding_dims: int = 6,
        feature_dim: int = 128,
    ):
        """Build the trunk and heads, then initialise the weights.

        Args:
            hidden_dim: Width of every trunk layer.
            num_layers: Number of linear layers in the trunk.
            skip_connections: Trunk layer indices that also receive the encoded
                input. Defaults to ``[4]``. Index 0 is not supported because the
                first layer is always sized for the input alone. The list is
                copied, so later mutation by the caller has no effect.
            geometric_init: If True use ``_geometric_initialization``,
                otherwise ``_initialize_weights``.
            use_positional_encoding: Feed sin/cos features of the points
                instead of the raw coordinates.
            encoding_dims: Number of frequency bands used by the encoding.
            feature_dim: Size of the ``local_features`` output. The default of
                128 matches the feature width ``MaterialModel`` expects by default.
        """
        super().__init__()
        # A fresh list per instance: a literal ``[4]`` default would be a single
        # list object shared by every StructureModel ever constructed.
        self.skip_connections = [4] if skip_connections is None else list(skip_connections)
        self.use_positional_encoding = use_positional_encoding
        self.encoding_dims = encoding_dims

        # Encoding emits sin and cos of each of the 3 coordinates per band.
        input_dim = 3 * 2 * encoding_dims if use_positional_encoding else 3

        self.layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim)])
        for i in range(1, num_layers):
            in_features = input_dim + hidden_dim if i in self.skip_connections else hidden_dim
            self.layers.append(nn.Linear(in_features, hidden_dim))

        self.fmean_layer = nn.Linear(hidden_dim, 1)
        self.feature_layer = nn.Linear(hidden_dim, feature_dim)

        if geometric_init:
            self._geometric_initialization()
        else:
            self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform trunk (ReLU gain), tiny-normal fmean head, Xavier feature head."""
        for layer in self.layers:
            nn.init.xavier_uniform_(layer.weight, gain=np.sqrt(2))
            nn.init.zeros_(layer.bias)

        # Tiny weights keep fmean close to zero at the start of training.
        nn.init.normal_(self.fmean_layer.weight, 0, 0.001)
        nn.init.constant_(self.fmean_layer.bias, 0.0)

        nn.init.xavier_uniform_(self.feature_layer.weight)
        nn.init.zeros_(self.feature_layer.bias)

    def _geometric_initialization(self):
        """He-style normal init for the trunk; near-zero heads.

        The first layer uses std ``sqrt(2 / out_features)``; later layers use
        ``sqrt(2 / in_features)``. The fmean head starts at ~0 everywhere (no
        sphere bias). The feature head starts at ~0.5 (weights ~1e-5, bias 0.5).
        """
        for i, layer in enumerate(self.layers):
            fan = layer.out_features if i == 0 else layer.in_features
            nn.init.normal_(layer.weight, 0.0, np.sqrt(2) / np.sqrt(fan))
            nn.init.zeros_(layer.bias)

        nn.init.normal_(self.fmean_layer.weight, mean=0.0, std=1e-5)
        nn.init.constant_(self.fmean_layer.bias, 0.0)
        nn.init.normal_(self.feature_layer.weight, 0.0, 1e-5)
        nn.init.constant_(self.feature_layer.bias, 0.5)

    def positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[sin(2^0 pi x), cos(2^0 pi x), ..., cos(2^(L-1) pi x)]`` on the last dim.

        ``L`` is ``self.encoding_dims``; the output's last dimension is
        ``2 * L * x.shape[-1]``.
        """
        encodings = []
        for i in range(self.encoding_dims):
            freq = 2**i * np.pi
            encodings.append(torch.sin(freq * x))
            encodings.append(torch.cos(freq * x))
        return torch.cat(encodings, dim=-1)

    def forward(self, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate the model at ``points`` (last dimension must be 3).

        Returns:
            ``(fmean, local_features)``, with last dimensions 1 and
            ``feature_dim`` respectively and the same leading dims as ``points``.
        """
        x = self.positional_encoding(points) if self.use_positional_encoding else points
        input_encoding = x

        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if i in self.skip_connections:
                x = torch.cat([x, input_encoding], dim=-1)
            x = layer(x)
            # In-place ReLU is safe here: nn.Linear does not save its output for backward.
            x = F.relu(x, inplace=True) if i < last else F.softplus(x)

        fmean = self.fmean_layer(x)
        local_features = self.feature_layer(x)
        return fmean, local_features

    def gradient(self, points: torch.Tensor) -> torch.Tensor:
        """Return d(fmean)/d(points), keeping the graph for higher-order losses.

        ``points`` is flagged ``requires_grad`` in place, and the caller sees
        that change. Gradient tracking is forced on, so this also works when
        called inside a ``torch.no_grad()`` block.

        Returns:
            Tensor with the same shape as ``points``.
        """
        with torch.enable_grad():
            points.requires_grad_(True)
            fmean, _ = self.forward(points)
            # Each fmean entry depends only on its own point (row-wise MLP), so
            # back-propagating ones yields per-point gradients.
            (gradients,) = torch.autograd.grad(
                outputs=fmean,
                inputs=points,
                grad_outputs=torch.ones_like(fmean),
                create_graph=True,
                retain_graph=True,
            )
        return gradients

class MaterialModel(nn.Module):
    """MLP predicting per-point material quantities from position and features.

    The input is the (optionally positionally encoded) point concatenated with
    ``local_features`` of width ``feature_dim``. The trunk mirrors
    ``StructureModel``: skip connections re-inject the full input, and every
    layer uses ReLU except the last, which uses softplus. Three one-output
    heads follow:

    * ``material``: ReLU of a linear head, so always >= 0.
    * ``anisotropic``: sigmoid of a linear head, within [0, 1].
    * ``phase``: ``pi * (tanh(.) + 1)``, within [0, 2*pi].
    """

    def __init__(
        self,
        hidden_dim: int = 128,
        num_layers: int = 8,
        skip_connections: Optional[list] = None,
        use_positional_encoding: bool = False,
        encoding_dims: int = 6,
        feature_dim: int = 128,
    ):
        """Build the trunk and the three heads, then initialise the weights.

        Args:
            hidden_dim: Width of every trunk layer.
            num_layers: Number of linear layers in the trunk.
            skip_connections: Trunk layer indices (excluding 0) that also
                receive the full input. Defaults to ``[4]``. The list is copied.
            use_positional_encoding: Encode the spatial input with sin/cos features.
            encoding_dims: Number of frequency bands used by the encoding.
            feature_dim: Width of the ``local_features`` input. The default of
                128 matches the previously hard-coded value.
        """
        super().__init__()
        # A fresh list per instance: a literal ``[4]`` default would be shared by all instances.
        self.skip_connections = [4] if skip_connections is None else list(skip_connections)
        self.use_positional_encoding = use_positional_encoding
        self.encoding_dims = encoding_dims

        spatial_input_dim = 3 * 2 * encoding_dims if use_positional_encoding else 3
        input_dim = spatial_input_dim + feature_dim

        self.layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim)])
        for i in range(1, num_layers):
            in_features = input_dim + hidden_dim if i in self.skip_connections else hidden_dim
            self.layers.append(nn.Linear(in_features, hidden_dim))

        self.material_layer = nn.Linear(hidden_dim, 1)
        self.anisotropic_layer = nn.Linear(hidden_dim, 1)
        self.phase_layer = nn.Linear(hidden_dim, 1)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform trunk (ReLU gain) and Xavier-uniform heads, all biases zero."""
        for layer in self.layers:
            nn.init.xavier_uniform_(layer.weight, gain=np.sqrt(2))
            nn.init.zeros_(layer.bias)

        for head in (self.material_layer, self.anisotropic_layer, self.phase_layer):
            nn.init.xavier_uniform_(head.weight)
            nn.init.zeros_(head.bias)

    def positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[sin(2^0 pi x), cos(2^0 pi x), ..., cos(2^(L-1) pi x)]`` on the last dim."""
        encodings = []
        for i in range(self.encoding_dims):
            freq = 2**i * np.pi
            encodings.append(torch.sin(freq * x))
            encodings.append(torch.cos(freq * x))
        return torch.cat(encodings, dim=-1)

    def forward(self, spatial_inputs: torch.Tensor, local_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the material heads.

        Args:
            spatial_inputs: Points whose last dimension is 3.
            local_features: Features whose last dimension is ``feature_dim``
                and whose leading dims match ``spatial_inputs``.

        Returns:
            ``(material, anisotropic, phase)``, each with last dimension 1.
        """
        x = self.positional_encoding(spatial_inputs) if self.use_positional_encoding else spatial_inputs
        x = torch.cat([x, local_features], dim=-1)
        input_encoding = x

        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if i in self.skip_connections:
                x = torch.cat([x, input_encoding], dim=-1)
            x = layer(x)
            x = F.relu(x, inplace=True) if i < last else F.softplus(x)

        material = F.relu(self.material_layer(x))
        anisotropic = torch.sigmoid(self.anisotropic_layer(x))
        # tanh in [-1, 1] shifted and scaled to [0, 2*pi].
        phase = np.pi * (torch.tanh(self.phase_layer(x)) + 1)

        return material, anisotropic, phase

class WeightsModel(nn.Module):
    """MLP producing one unbounded scalar per (ray_direction, tx_position) pair.

    Both inputs are always positionally encoded and concatenated. The first
    layer's width ``6 * 2 * encoding_dims`` assumes the two inputs supply 6
    coordinates in total (3 each). Every trunk layer, including the last, is
    followed by ReLU. ``output_layer`` is a plain linear head with no final
    activation.
    """

    def __init__(self, hidden_dim: int = 128, num_layers: int = 8, encoding_dims: int = 6, skip_connections: Optional[list] = None):
        """Build the trunk and output head.

        Args:
            hidden_dim: Width of every trunk layer.
            num_layers: Number of linear layers in the trunk.
            encoding_dims: Number of frequency bands used by the encoding.
            skip_connections: Trunk layer indices (excluding 0) that also
                receive the encoded input. Defaults to ``[4]``. The list is copied.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.encoding_dims = encoding_dims
        # A fresh list per instance: a literal ``[4]`` default would be shared by all instances.
        self.skip_connections = [4] if skip_connections is None else list(skip_connections)

        input_dim = 6 * 2 * encoding_dims
        self.layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim)])
        for i in range(1, num_layers):
            in_features = input_dim + hidden_dim if i in self.skip_connections else hidden_dim
            self.layers.append(nn.Linear(in_features, hidden_dim))
        self.output_layer = nn.Linear(hidden_dim, 1)
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform trunk (ReLU gain) and Xavier-uniform head, biases zero."""
        for layer in self.layers:
            nn.init.xavier_uniform_(layer.weight, gain=np.sqrt(2))
            nn.init.zeros_(layer.bias)
        nn.init.xavier_uniform_(self.output_layer.weight)
        nn.init.zeros_(self.output_layer.bias)

    def positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[sin(2^0 pi x), cos(2^0 pi x), ..., cos(2^(L-1) pi x)]`` on the last dim."""
        encodings = []
        for i in range(self.encoding_dims):
            freq = 2**i * np.pi
            encodings.append(torch.sin(freq * x))
            encodings.append(torch.cos(freq * x))
        return torch.cat(encodings, dim=-1)

    def forward(self, ray_direction: torch.Tensor, tx_position: torch.Tensor) -> torch.Tensor:
        """Return the raw scalar output, with last dimension 1."""
        ray_encoded = self.positional_encoding(ray_direction)
        tx_encoded = self.positional_encoding(tx_position)
        x = torch.cat([ray_encoded, tx_encoded], dim=-1)
        input_encoding = x

        for i, layer in enumerate(self.layers):
            if i in self.skip_connections:
                x = torch.cat([x, input_encoding], dim=-1)
            x = F.relu(layer(x), inplace=True)
        return self.output_layer(x)

class NeuralRF(nn.Module):
    """Chains a ``StructureModel`` and a ``MaterialModel``.

    The structure model maps points to ``(fmean, local_features)``. The
    material model then consumes the same points together with those features.
    """

    def __init__(
        self,
        structure_hidden_dim: int = 64,
        material_hidden_dim: int = 128,
        num_layers: int = 8,
        skip_connections: Optional[list] = None,
        geometric_init: bool = True,
        use_positional_encoding: bool = True,
        encoding_dims: int = 6
    ):
        """Build both sub-models.

        ``skip_connections`` defaults to ``[4]``. Each sub-model receives its
        own copy, so neither aliases the caller's list nor the other model's.
        """
        super().__init__()
        if skip_connections is None:
            skip_connections = [4]

        self.structure_model = StructureModel(
            hidden_dim=structure_hidden_dim,
            num_layers=num_layers,
            skip_connections=list(skip_connections),
            geometric_init=geometric_init,
            use_positional_encoding=use_positional_encoding,
            encoding_dims=encoding_dims
        )

        self.material_model = MaterialModel(
            hidden_dim=material_hidden_dim,
            num_layers=num_layers,
            skip_connections=list(skip_connections),
            use_positional_encoding=use_positional_encoding,
            encoding_dims=encoding_dims
        )

    def forward(self, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(fmean, material, anisotropic, phase)`` for ``points``."""
        fmean, local_features = self.structure_model(points)
        material, anisotropic, phase = self.material_model(points, local_features)
        return fmean, material, anisotropic, phase

    def fmean_to_density(self, fmean: torch.Tensor, beta: float = 100.0) -> torch.Tensor:
        """Map ``fmean`` to ``sigmoid(-beta * fmean)``.

        The result tends to 1 where fmean < 0 and to 0 where fmean > 0. Larger
        ``beta`` makes the transition sharper.
        """
        return torch.sigmoid(-beta * fmean)

    def gradient(self, points: torch.Tensor) -> torch.Tensor:
        """Return d(fmean)/d(points), keeping the graph for higher-order losses.

        ``points`` is flagged ``requires_grad`` in place, and the caller sees
        that change. Gradient tracking is forced on, so this also works inside
        ``torch.no_grad()``. The full forward (including the material model)
        is evaluated, but only fmean is differentiated.
        """
        with torch.enable_grad():
            points.requires_grad_(True)
            fmean, _, _, _ = self.forward(points)
            (gradients,) = torch.autograd.grad(
                outputs=fmean,
                inputs=points,
                grad_outputs=torch.ones_like(fmean),
                create_graph=True,
                retain_graph=True,
            )
        return gradients

        
class CoarseToFineStructureModel(nn.Module):
    """Sum of a coordinate-only ``StructureModel`` and progressively enabled frequency branches.

    Branch ``k`` (0-based) is a ``StructureModel`` with positional encoding
    using ``k + 1`` frequency bands. ``forward`` adds the first
    ``current_freq_level`` branches to the base output, scaled by
    ``1 / 2**(k + 1)``.

    ``current_freq_level`` is a plain Python attribute, not a buffer, so it is
    not stored by ``state_dict()``. Restore it manually when resuming training.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        num_layers: int = 8,
        skip_connections: Optional[list] = None,
        geometric_init: bool = True,
        max_freq_levels: int = 4,
        start_freq_level: int = 0
    ):
        """Build the base model and ``max_freq_levels`` frequency branches.

        ``skip_connections`` defaults to ``[4]``. Every sub-model gets its own
        copy, so no two of them share one list object.
        """
        super().__init__()
        self.skip_connections = [4] if skip_connections is None else list(skip_connections)
        self.max_freq_levels = max_freq_levels
        self.current_freq_level = start_freq_level

        self.base_structure = StructureModel(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            skip_connections=list(self.skip_connections),
            geometric_init=geometric_init,
            use_positional_encoding=False,
            encoding_dims=0
        )

        self.freq_networks = nn.ModuleList()
        for freq_level in range(max_freq_levels):
            freq_net = StructureModel(
                hidden_dim=hidden_dim,
                num_layers=num_layers,
                skip_connections=list(self.skip_connections),
                geometric_init=False,
                use_positional_encoding=True,
                encoding_dims=freq_level + 1
            )
            self.freq_networks.append(freq_net)
            # Overrides the default init that StructureModel already applied.
            self._initialize_freq_network(freq_net)

    def _initialize_freq_network(self, network):
        """Re-initialise a branch with small weights and zero biases.

        Trunk weights ~ N(0, 0.01). Head weights ~ N(0, 0.001). Presumably this
        keeps a branch's contribution near zero when it is first enabled.
        """
        for layer in network.layers:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.zeros_(layer.bias)

        nn.init.normal_(network.fmean_layer.weight, 0, 0.001)
        nn.init.zeros_(network.fmean_layer.bias)
        nn.init.normal_(network.feature_layer.weight, 0, 0.001)
        nn.init.zeros_(network.feature_layer.bias)

    def increase_frequency(self):
        """Enable one more frequency branch, up to ``max_freq_levels``, and print the new level."""
        if self.current_freq_level < self.max_freq_levels:
            self.current_freq_level += 1
            print(f"Increased frequency level to {self.current_freq_level}/{self.max_freq_levels}")

    def forward(self, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(fmean, local_features)``: base output plus weighted active branches."""
        fmean, local_features = self.base_structure(points)

        # Clamp so a start level above max_freq_levels cannot index past the list.
        active = min(self.current_freq_level, len(self.freq_networks))
        for freq_level in range(active):
            freq_fmean, freq_features = self.freq_networks[freq_level](points)

            weight = 1.0 / (2 ** (freq_level + 1))
            fmean = fmean + weight * freq_fmean
            local_features = local_features + weight * freq_features

        return fmean, local_features


class CoarseToFineNeuralRF(nn.Module):
    """``NeuralRF`` variant whose structure part is a ``CoarseToFineStructureModel``.

    The material model always uses positional encoding, with
    ``material_encoding_dims`` bands (6 by default).
    """

    def __init__(
        self,
        structure_hidden_dim: int = 64,
        material_hidden_dim: int = 128,
        num_layers: int = 8,
        skip_connections: Optional[list] = None,
        geometric_init: bool = True,
        max_freq_levels: int = 4,
        start_freq_level: int = 0,
        material_encoding_dims: int = 6
    ):
        """Build the coarse-to-fine structure model and the material model.

        ``skip_connections`` defaults to ``[4]``. Each sub-model receives its
        own copy. ``material_encoding_dims`` replaces the previously hard-coded
        value of 6.
        """
        super().__init__()
        if skip_connections is None:
            skip_connections = [4]

        self.structure_model = CoarseToFineStructureModel(
            hidden_dim=structure_hidden_dim,
            num_layers=num_layers,
            skip_connections=list(skip_connections),
            geometric_init=geometric_init,
            max_freq_levels=max_freq_levels,
            start_freq_level=start_freq_level
        )

        self.material_model = MaterialModel(
            hidden_dim=material_hidden_dim,
            num_layers=num_layers,
            skip_connections=list(skip_connections),
            use_positional_encoding=True,
            encoding_dims=material_encoding_dims
        )

    def increase_frequency(self):
        """Enable the next frequency branch of the structure model."""
        self.structure_model.increase_frequency()

    def forward(self, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(fmean, material, anisotropic, phase)`` for ``points``."""
        fmean, local_features = self.structure_model(points)
        material, anisotropic, phase = self.material_model(points, local_features)
        return fmean, material, anisotropic, phase

    def fmean_to_density(self, fmean: torch.Tensor, beta: float = 100.0) -> torch.Tensor:
        """Map ``fmean`` to ``sigmoid(-beta * fmean)``, which is ~1 for fmean < 0 and ~0 for fmean > 0."""
        return torch.sigmoid(-beta * fmean)

    def gradient(self, points: torch.Tensor) -> torch.Tensor:
        """Return d(fmean)/d(points), keeping the graph for higher-order losses.

        ``points`` is flagged ``requires_grad`` in place, and the caller sees
        that change. Gradient tracking is forced on, so this also works inside
        ``torch.no_grad()``.
        """
        with torch.enable_grad():
            points.requires_grad_(True)
            fmean, _, _, _ = self.forward(points)
            (gradients,) = torch.autograd.grad(
                outputs=fmean,
                inputs=points,
                grad_outputs=torch.ones_like(fmean),
                create_graph=True,
                retain_graph=True,
            )
        return gradients