import os.path as osp
import torch
import torch.nn as nn

import math


class GaussianField(nn.Module):
    """Per-point Gaussian feature field.

    Each input feature vector is projected to a hidden representation from
    which three heads predict, for ``num_gaussians`` Gaussians:

    * ``mu``     -- centres, ``pos_dim`` values per Gaussian,
    * ``sigma``  -- per-axis scales, made positive by a ``Softplus``,
    * ``weight`` -- mixture weights, normalised by a ``Softmax`` over Gaussians.

    The (optionally weighted) Gaussian responses evaluated at a query position
    form a ``num_gaussians``-dimensional code ``z`` that a small MLP decodes
    to ``out_dim`` values.

    All methods accept any number of leading dimensions, e.g. ``(B, N, d)`` or
    an unbatched ``(N, d)``; only the trailing feature axis is interpreted.

    Args:
        in_dim: Size of the last axis of the input features.
        out_dim: Size of the last axis of the decoded output.
        hidden_dim: Width of the hidden layers of every MLP.
        pos_dim: Number of positional coordinates; when no explicit position
            is supplied, the first ``pos_dim`` input features are used.
        num_gaussians: Number of Gaussians predicted per input point.
        weight: If true, the Gaussian responses are multiplied by the
            predicted mixture weights and the mu loss is computed on the
            weighted mean of the centres.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=64, pos_dim=3, num_gaussians=128, weight=False):
        super().__init__()
        # Absolute path of the source file defining this module.
        self.__file__ = osp.abspath(__file__)
        self.pos_dim = pos_dim
        self.num_gaussians = num_gaussians
        self.weight = weight

        # NOTE: submodule names and creation order are kept stable so that
        # existing checkpoints (state_dict keys) and seeded initialisation
        # remain compatible.
        self.proj_in = nn.Linear(in_dim, hidden_dim)
        self.to_mu = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pos_dim * num_gaussians)
        )
        self.to_sigma = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pos_dim * num_gaussians),
            nn.Softplus()  # Ensure positive output for sigma
        )
        self.to_weight = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_gaussians),
            nn.Softmax(dim=-1)  # Ensure weights sum to 1
        )

        # Decoder MLP
        self.proj_out = nn.Sequential(
            nn.Linear(num_gaussians, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def encode(self, x):
        """Predict Gaussian parameters for every input point.

        Args:
            x: Tensor of shape ``(..., in_dim)``.

        Returns:
            Tuple ``(mu, sigma, weight)`` with shapes
            ``(..., num_gaussians, pos_dim)``,
            ``(..., num_gaussians, pos_dim)`` and ``(..., num_gaussians)``.
        """
        # Keep whatever leading (batch / point) dimensions the caller passed.
        lead_shape = x.shape[:-1]
        param_shape = (*lead_shape, self.num_gaussians, self.pos_dim)

        x_in = self.proj_in(x)  # (..., hidden_dim)

        mu = self.to_mu(x_in).view(param_shape)
        sigma = self.to_sigma(x_in).view(param_shape)
        weight = self.to_weight(x_in)  # already (..., num_gaussians)

        return mu, sigma, weight

    def encode_z(self, x):
        """Encode ``x`` to its Gaussian response code.

        The query position is taken from the first ``pos_dim`` features of
        ``x``.

        Args:
            x: Tensor of shape ``(..., in_dim)``.

        Returns:
            Tensor of shape ``(..., num_gaussians)``.
        """
        mu, sigma, weight = self.encode(x)
        x_pos = x[..., :self.pos_dim]
        return self.compute_gaussian(x_pos, mu, sigma, weight=weight)

    def decode_z(self, z):
        """Decode a Gaussian response code with the output MLP.

        Args:
            z: Tensor of shape ``(..., num_gaussians)``.

        Returns:
            Tensor of shape ``(..., out_dim)``.
        """
        return self.proj_out(z)

    def compute_gaussian(self, x_pos, mu, sigma, weight):
        """Evaluate the unnormalised Gaussian responses at ``x_pos``.

        Computes ``exp(-0.5 * sum(((x_pos - mu) / (sigma + 1e-6)) ** 2))``
        per Gaussian, multiplied by ``weight`` when ``self.weight`` is true.

        Args:
            x_pos: Query positions, shape ``(..., pos_dim)``.
            mu: Gaussian centres, shape ``(..., num_gaussians, pos_dim)``.
            sigma: Gaussian scales, same shape as ``mu``.
            weight: Mixture weights, shape ``(..., num_gaussians)``; only
                read when ``self.weight`` is true.

        Returns:
            Tensor of shape ``(..., num_gaussians)``.
        """
        # Insert a Gaussian axis and let broadcasting pair every position
        # with each of its Gaussians.
        x_exp = x_pos.unsqueeze(-2)  # (..., 1, pos_dim)

        # The small epsilon guards against division by a vanishing sigma.
        diff = (x_exp - mu) / (sigma + 1e-6)  # (..., G, pos_dim)
        dist_sq = (diff ** 2).sum(dim=-1)  # (..., G)

        gaussian_response = torch.exp(-0.5 * dist_sq)  # (..., G)
        if self.weight:
            return gaussian_response * weight
        return gaussian_response

    def decode(self, mu, sigma, weight, x_pos):
        """Evaluate the Gaussians at ``x_pos`` and decode the responses.

        Args:
            mu: Gaussian centres, shape ``(..., num_gaussians, pos_dim)``.
            sigma: Gaussian scales, same shape as ``mu``.
            weight: Mixture weights, shape ``(..., num_gaussians)``.
            x_pos: Query positions, shape ``(..., pos_dim)``.

        Returns:
            Tensor of shape ``(..., out_dim)``.
        """
        z = self.compute_gaussian(x_pos, mu, sigma, weight)
        return self.decode_z(z)

    def forward(self, x, x_pos=None):
        """Encode ``x`` and decode it at the query positions.

        Args:
            x: Tensor of shape ``(..., in_dim)``.
            x_pos: Optional query positions of shape ``(..., pos_dim)``.
                Defaults to the first ``pos_dim`` features of ``x``.

        Returns:
            Tensor of shape ``(..., out_dim)``.
        """
        mu, sigma, weight = self.encode(x)
        if x_pos is None:
            x_pos = x[..., :self.pos_dim]
        return self.decode(mu, sigma, weight, x_pos)

    def compute_loss(self, x, sigma_min=0.01, sigma_max=0.5):
        """Compute the regularisation losses on the predicted Gaussians.

        Args:
            x: Tensor of shape ``(..., in_dim)``; its first ``pos_dim``
                features are used as the target positions.
            sigma_min: Lower bound of the unpenalised sigma range.
            sigma_max: Upper bound of the unpenalised sigma range.

        Returns:
            Tuple ``(mu_loss, sigma_range_loss)`` of scalar tensors.
        """
        x_pos = x[..., :self.pos_dim]
        mu, sigma, weight = self.encode(x)

        mu_loss = self.compute_mu_loss(x_pos, mu, weight)
        sigma_range_loss = self.compute_sigma_loss(sigma, sigma_min, sigma_max)

        return mu_loss, sigma_range_loss

    def compute_mu_loss(self, x_pos, mu, weight):
        """Mean squared distance between Gaussian centres and positions.

        With ``self.weight`` true, the weighted mean of the centres is
        compared to ``x_pos``; otherwise every centre is compared to
        ``x_pos`` individually.

        Args:
            x_pos: Target positions, shape ``(..., pos_dim)``.
            mu: Gaussian centres, shape ``(..., num_gaussians, pos_dim)``.
            weight: Mixture weights, shape ``(..., num_gaussians)``.

        Returns:
            Scalar tensor.
        """
        if self.weight:
            # Collapse the Gaussian axis into a single weighted centre.
            mean_mu = (mu * weight.unsqueeze(-1)).sum(dim=-2)
            return ((mean_mu - x_pos) ** 2).mean()
        return ((mu - x_pos.unsqueeze(-2)) ** 2).mean()

    def compute_sigma_loss(self, sigma, sigma_min=0.01, sigma_max=0.5):
        """Penalise sigma values outside ``[sigma_min, sigma_max]``.

        The penalty grows linearly with the distance to the violated bound
        and is zero inside the range.

        Args:
            sigma: Gaussian scales (any shape).
            sigma_min: Lower bound of the unpenalised range.
            sigma_max: Upper bound of the unpenalised range.

        Returns:
            Scalar tensor, the mean penalty over all entries of ``sigma``.
        """
        sigma_penalty = torch.relu(sigma - sigma_max) + torch.relu(sigma_min - sigma)
        return sigma_penalty.mean()
