import torch
import torch.nn as nn
import torch.nn.functional as F
    
class GaussianWeightPred(nn.Module):
    """MLP that maps each input row to one value squashed into (0, 1).

    The network is a chain of ``Linear -> ReLU`` stages, one per entry of
    ``hidden_dims``, followed by a final ``Linear`` down to a single unit
    and a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input.
        perturb_range: Stored on the module as ``self.perturb_range``;
            ``forward`` does not read it.
        hidden_dims: Widths of the hidden layers, in order.

    Attributes:
        epsilon: Learnable vector of length ``input_dim`` initialised to
            zeros. It is registered as a parameter (so it appears in the
            state dict) but ``forward`` does not use it.
    """

    def __init__(self, input_dim, perturb_range=0, hidden_dims=[32, 64, 32]):
        super().__init__()

        # Consecutive (fan_in, fan_out) pairs describe the hidden stages.
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head: one logit per input row, no activation here.
        stack.append(nn.Linear(widths[-1], 1))

        self.mlp = nn.Sequential(*stack)
        self.sigmoid = nn.Sigmoid()
        self.epsilon = nn.Parameter(torch.zeros(input_dim))
        self.perturb_range = perturb_range

    def forward(self, x):
        """Return ``sigmoid(mlp(x))``; the last dimension of the result is 1."""
        logits = self.mlp(x)
        return self.sigmoid(logits)
   
class PositionalEncoding(nn.Module):
    """Sinusoidal expansion of the last dimension of a tensor.

    For every frequency ``f = pi * 2**k`` with ``k = 0 .. num_freqs - 1``
    the input is mapped to ``sin(f * x)`` and ``cos(f * x)``. The pieces are
    concatenated along the last dimension in the order
    ``[x (optional), sin(f0 x), cos(f0 x), sin(f1 x), cos(f1 x), ...]``.

    Args:
        num_freqs: Number of frequency bands.
        include_input: When true, the unmodified input is placed first in
            the concatenated output.
    """

    def __init__(self, num_freqs=6, include_input=True):
        super().__init__()
        self.num_freqs = num_freqs
        self.include_input = include_input
        # pi * [1, 2, 4, ..., 2^(num_freqs - 1)], kept as a plain tensor
        # attribute (it is not registered as a buffer or parameter).
        octaves = torch.arange(num_freqs).float()
        self.freq_bands = 2 ** octaves * torch.pi

    def forward(self, x):
        """Encode ``x``.

        With an input whose last dimension is ``D`` the output's last
        dimension is ``D * (2 * num_freqs + 1)`` when ``include_input`` is
        true and ``D * 2 * num_freqs`` otherwise.
        """
        features = [x] if self.include_input else []
        for band in self.freq_bands:
            phase = x * band
            features.extend((torch.sin(phase), torch.cos(phase)))
        return torch.cat(features, dim=-1)
     
class gaussian_weight_pred_pe(nn.Module):
    """Positional-encoded MLP that maps each input row to a value in (0, 1).

    The input is first expanded by ``PositionalEncoding`` (raw input kept,
    ``pe_feq`` sin/cos frequency pairs per feature), then passed through a
    chain of ``Linear -> ReLU`` stages, a final ``Linear`` to one unit, and
    a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input to ``forward``.
            It determines both the width of the first ``Linear`` layer and
            the length of ``epsilon``.
        perturb_range: Stored on the module as ``self.perturb_range``;
            ``forward`` does not read it.
        pe_feq: Number of frequency bands handed to ``PositionalEncoding``.
        hidden_dims: Widths of the hidden layers, in order.

    Attributes:
        epsilon: Learnable vector of length ``input_dim`` initialised to
            zeros. It is registered as a parameter but ``forward`` does not
            use it.
    """

    def __init__(self, input_dim, perturb_range=0, pe_feq=6, hidden_dims=(32, 64, 32)):
        super().__init__()

        # With include_input=True the encoder emits, per input feature, the
        # raw value plus one sin and one cos for each of the pe_feq bands.
        # The width is derived from input_dim (not a fixed feature count) so
        # the first Linear always matches the encoder output.
        prev_dim = (pe_feq * 2 + 1) * input_dim

        layers = []
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.ReLU())
            prev_dim = h_dim
        # Output head: one logit per input row.
        layers.append(nn.Linear(prev_dim, 1))

        self.mlp = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()
        self.epsilon = nn.Parameter(torch.zeros(input_dim))
        self.perturb_range = perturb_range
        self.pe = PositionalEncoding(pe_feq, include_input=True)

    def forward(self, x):
        """Return ``sigmoid(mlp(pe(x)))``.

        ``x`` must have ``input_dim`` entries in its last dimension; the
        result has 1 entry in its last dimension.
        """
        encoded = self.pe(x)
        logits = self.mlp(encoded)
        return self.sigmoid(logits)
    
class NegativeLinear(nn.Module):
    """Linear layer whose weights for selected input columns are negative.

    The learnable ``raw_weight`` has shape ``(out_features, in_features)``.
    At forward time the columns selected by ``neg_idx`` are replaced with
    ``-softplus(raw_weight[:, neg_idx])``; all other columns are used as-is.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        neg_idx: Index expression applied to the input-feature axis of the
            weight matrix to pick the columns that are forced negative.
        bias: When true, a learnable bias initialised to zeros is added;
            otherwise ``self.bias`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features, neg_idx, bias=True):
        super().__init__()
        self.neg_idx = neg_idx
        self.raw_weight = nn.Parameter(torch.randn(out_features, in_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Apply the linear map using the sign-constrained weight matrix."""
        cols = self.neg_idx
        # softplus is strictly positive, so its negation is strictly negative.
        constrained = F.softplus(self.raw_weight[:, cols]).neg()
        # Work on a copy so the parameter itself is never modified in place.
        weight = self.raw_weight.clone()
        weight[:, cols] = constrained
        return F.linear(x, weight, self.bias)

class PositiveLinear(nn.Module):
    """Linear layer whose effective weights are all strictly positive.

    The learnable ``raw_weight`` has shape ``(out_features, in_features)``
    and is passed through softplus on every forward call, so the matrix
    actually applied to the input has only positive entries.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: When true, a learnable bias initialised to zeros is added;
            otherwise ``self.bias`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.raw_weight = nn.Parameter(torch.randn(out_features, in_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Apply the linear map with ``softplus(raw_weight)`` as the weight."""
        return F.linear(x, F.softplus(self.raw_weight), self.bias)

class DeepReverseMonotonicNorm(nn.Module):
    """Sign-constrained MLP producing one sigmoid-squashed value per row.

    The first layer is a ``NegativeLinear`` whose weights for the input
    column(s) ``neg_idx`` are forced negative; every later layer is a
    ``PositiveLinear``. ReLU sits between layers and a sigmoid is applied to
    the final single-unit output. Because ReLU and sigmoid are
    non-decreasing and all later weights are positive, the output cannot
    increase when a ``neg_idx`` input increases (as long as
    ``perturb_range < 1`` so the input scale factor stays positive).

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Widths of the hidden layers, in order. Must contain at
            least one entry.
        perturb_range: Bound applied to ``epsilon`` before it is used; the
            input is multiplied feature-wise by ``1 + clamp(epsilon)``.
        neg_idx: Input column(s) given negative first-layer weights. The
            default ``3`` requires ``input_dim >= 4``; pass another index
            for other input widths.

    Attributes:
        epsilon: Learnable vector of length ``input_dim`` initialised to
            zeros.
    """

    def __init__(self,
                 input_dim=4,
                 hidden_dims=(64, 64, 32, 16),
                 perturb_range=0.00001,
                 neg_idx=3):
        super().__init__()
        self.perturb_range = perturb_range
        self.epsilon = nn.Parameter(torch.zeros(input_dim))

        # First stage carries the negative-weight constraint.
        layers = [
            NegativeLinear(input_dim, hidden_dims[0], neg_idx=neg_idx),
            nn.ReLU(inplace=True),
        ]
        prev = hidden_dims[0]

        # Remaining hidden stages use strictly positive weights.
        for h in hidden_dims[1:]:
            layers.append(PositiveLinear(prev, h))
            layers.append(nn.ReLU(inplace=True))
            prev = h

        # Output head: one logit per row, squashed in forward().
        layers.append(PositiveLinear(prev, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Scale ``x`` by ``1 + clamp(epsilon)``, run the net, apply sigmoid.

        ``epsilon`` is clamped to ``[-perturb_range, perturb_range]`` and
        broadcast over the leading dimension of ``x``. The result has 1
        entry in its last dimension.
        """
        eps = torch.clamp(self.epsilon, -self.perturb_range, self.perturb_range)
        x_scaled = x * (1 + eps.unsqueeze(0))
        logits = self.net(x_scaled)
        return torch.sigmoid(logits)