import torch 
import torch.nn as nn
import itertools

class Points(nn.Module):
    """Learnable set of ``point_num`` points, each of dimension ``point_dim``.

    Args:
        point_num: Number of points.
        point_dim: Dimensionality of each point.
        init_scale: Standard deviation of the Gaussian used to initialise the
            points (``randn * init_scale``).

    Attributes:
        points: ``nn.Parameter`` of shape ``(point_num, point_dim)``.
    """

    def __init__(self, point_num, point_dim, init_scale = 0.001):
        super().__init__()
        self.point_num = point_num
        self.point_dim = point_dim
        # Previously ``init_scale`` was ignored and points had unit variance;
        # scale the Gaussian draw so the argument actually controls the init.
        self.points = nn.Parameter(
            torch.randn(point_num, point_dim) * init_scale, requires_grad=True
        )

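# Per-head config expected by GSHEAD (read via attribute access, e.g. config.in_dim):
#   in_dim      input feature dimension of the head
#   out_dim     output dimension of the head
#   hidden_dim  width of the hidden layers
#   num_layers  number of nn.Linear layers in the head
#   activation  intended hidden-layer activation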

class GSHEAD(nn.Module):
    """Collection of named MLP heads predicting Gaussian-splat attributes.

    ``configs`` maps a head name to a config object exposing ``in_dim``,
    ``out_dim``, ``hidden_dim``, ``num_layers`` and optionally ``activation``.
    ``forward(x, head_name)`` runs the named head and then applies the output
    activation for that head name (see ``_get_activation``).
    """

    # Hidden-layer activations selectable by name via ``config.activation``.
    _HIDDEN_ACTIVATIONS = {
        "relu": nn.ReLU,
        "leaky_relu": nn.LeakyReLU,
        "silu": nn.SiLU,
        "gelu": nn.GELU,
        "softplus": nn.Softplus,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "identity": nn.Identity,
        "none": nn.Identity,
    }

    def __init__(self, configs):
        super().__init__()
        self.heads = nn.ModuleDict()
        for k, config in configs.items():
            self.heads[k] = self._configure_head(config)

    def _make_hidden_activation(self, activation):
        """Build one hidden-layer activation module from a config value.

        Accepts ``None`` (no activation), a name from ``_HIDDEN_ACTIVATIONS``,
        an ``nn.Module`` instance (reused as-is), or a zero-argument callable
        such as an ``nn.Module`` subclass.

        Raises:
            ValueError: if a string name is not recognised.
            TypeError: if the value is none of the accepted kinds.
        """
        if activation is None:
            return nn.Identity()
        if isinstance(activation, str):
            try:
                return self._HIDDEN_ACTIVATIONS[activation.lower()]()
            except KeyError:
                raise ValueError(f"Unknown activation: {activation!r}") from None
        if isinstance(activation, nn.Module):
            return activation
        if callable(activation):
            return activation()
        raise TypeError(f"Unsupported activation spec: {activation!r}")

    def _configure_head(self, config):
        """Build ``Linear -> act -> ... -> Linear`` with ``config.num_layers`` Linears.

        Layer widths are ``in_dim -> hidden_dim x (num_layers - 1) -> out_dim``.
        The activation is inserted after every Linear except the last. The
        last Linear is zero-initialised.

        Raises:
            ValueError: if ``num_layers < 1``.
        """
        num_layers = config.num_layers
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        dims = [config.in_dim] + [config.hidden_dim] * (num_layers - 1) + [config.out_dim]
        # Default to ReLU, matching OSGDecoder's default hidden activation.
        activation = getattr(config, "activation", "relu")

        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < num_layers - 1:
                layers.append(self._make_hidden_activation(activation))

        # Zero-init the output layer so every head starts from a raw output of
        # 0, i.e. from _get_activation(0). NOTE(review): for "rotation" this
        # yields normalize(0) == 0 at init, which is a degenerate quaternion;
        # confirm this is intended.
        nn.init.zeros_(layers[-1].weight)
        nn.init.zeros_(layers[-1].bias)

        return nn.Sequential(*layers)

    def _get_activation(self, out, head_name):
        """Map a head's raw output to its attribute range.

        The mappings are: ``xyz`` is clipped to [-1, 1]; ``rgb`` and
        ``opacity`` go through sigmoid; ``scaling`` goes through exp;
        ``rotation`` is L2-normalised along the last dim.

        Raises:
            ValueError: for any other head name.
        """
        if head_name == "xyz":
            return out.clip(-1, 1)
        elif head_name == "rgb":
            return torch.sigmoid(out)
        elif head_name == "scaling":
            return torch.exp(out)
        elif head_name == "rotation":
            return torch.nn.functional.normalize(out, dim=-1)
        elif head_name == "opacity":
            return torch.sigmoid(out)
        else:
            raise ValueError(f"Unknown head name: {head_name!r}")

    def forward(self, x, head_name):
        """Run head ``head_name`` on ``x`` and apply its output activation."""
        out = self.heads[head_name](x)
        return self._get_activation(out, head_name)



class OSGDecoder(nn.Module):
    """
    Triplane decoder mapping sampled plane features to RGB and density (sigma).

    Per-plane features are concatenated (not averaged) and passed through an
    MLP with ``num_layers`` Linear layers and a ``1 + 3`` output
    (sigma, rgb). The default hidden activation is ReLU rather than the
    Softplus used by EG3D. All Linear biases are zero-initialised.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """
    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
        super().__init__()
        # Input width is 3 * n_features, i.e. three planes concatenated.
        modules = [nn.Linear(3 * n_features, hidden_dim), activation()]
        for _ in range(num_layers - 2):
            modules.append(nn.Linear(hidden_dim, hidden_dim))
            modules.append(activation())
        modules.append(nn.Linear(hidden_dim, 1 + 3))
        self.net = nn.Sequential(*modules)

        # Start every Linear layer from a zero bias.
        for module in self.net.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def forward(self, sampled_features):
        """Decode features of shape ``(N, n_planes, M, C)``.

        Returns a dict with ``'rgb'`` of shape ``(N, M, 3)`` and ``'sigma'`` of
        shape ``(N, M, 1)`` (raw, no activation).
        """
        batch, n_planes, n_points, n_channels = sampled_features.shape
        width = n_planes * n_channels

        # (N, planes, M, C) -> (N, M, planes * C): concatenate per-plane features.
        per_point = sampled_features.permute(0, 2, 1, 3).reshape(batch, n_points, width)
        flat = per_point.contiguous().view(batch * n_points, width)

        decoded = self.net(flat).view(batch, n_points, -1)
        sigma = decoded[..., 0:1]
        # Sigmoid widened slightly past [0, 1], as in MipNeRF.
        rgb = torch.sigmoid(decoded[..., 1:]) * (1 + 2 * 0.001) - 0.001

        return {'rgb': rgb, 'sigma': sigma}