###
### The code was taken from https://github.com/milesial/Pytorch-UNet
###
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable
import torch.nn.init as init
from src.hair_networks.strand_prior import SirenNet
from src.utils.util import param_to_buffer, positional_encoding

class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; falls back
            to ``out_channels`` when not given (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        # 3x3 kernels with padding 1 keep the spatial size unchanged.
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsample the coarse map 2x, align it to the skip map, concatenate, DoubleConv.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor.
        out_channels: channels produced by the final DoubleConv.
        bilinear: use fixed bilinear upsampling (DoubleConv then reduces the
            channels) instead of a learned transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            # Learned 2x upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are (N, C, H, W): pad the upsampled map so H and W match x2.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """Four-level 2-D U-Net: encoder/decoder with one skip connection per level.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the output logits map.
        bilinear: use bilinear upsampling in the decoder instead of transposed
            convolutions; this also halves the bottleneck/decoder widths.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        shrink = 2 if bilinear else 1

        # Encoder: 64 -> 128 -> 256 -> 512 -> 1024 // shrink channels.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)
        # Decoder mirrors the encoder.
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))
        out = self.down4(skips[-1])
        # Deepest skip is consumed first.
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = stage(out, skip)
        return self.outc(out)
    
class DUNet(nn.Module):
    """Five-level ("deeper") 2-D U-Net with one skip connection per level.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the output logits map.
        bilinear: use bilinear upsampling in the decoder instead of transposed
            convolutions; this also halves the bottleneck/decoder widths.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder: 64 -> 128 -> 256 -> 512 -> 1024 -> 2048 // shrink channels.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        shrink = 2 if bilinear else 1
        self.down5 = Down(1024, 2048 // shrink)
        # Decoder mirrors the encoder.
        self.up1 = Up(2048, 1024 // shrink, bilinear)
        self.up2 = Up(1024, 512 // shrink, bilinear)
        self.up3 = Up(512, 256 // shrink, bilinear)
        self.up4 = Up(256, 128 // shrink, bilinear)
        self.up5 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(stage(skips[-1]))
        out = self.down5(skips[-1])
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        for stage, skip in zip(decoder, reversed(skips)):
            out = stage(out, skip)
        return self.outc(out)
    
class DoubleMLP(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> BatchNorm1d -> ReLU feature block.

    Despite the name, only ONE projection is active: the second ``nn.Linear``
    is disabled, so the second BatchNorm/ReLU pair acts directly on the
    projected features. The layer indices inside ``self.double_mlp``
    (0: Linear, 1: BN, 2: ReLU, 3: BN, 4: ReLU) are unchanged, so existing
    ``state_dict`` keys still match.

    ``nn.Linear`` acts on the last axis while ``nn.BatchNorm1d`` normalises
    axis 1, so the two agree for 2-D ``(N, C)`` inputs.

    Args:
        in_channels: width of the input features.
        out_channels: width of the output features.
        mid_channels: optional hidden width. Without a second projection the
            output width equals the hidden width, so it must equal
            ``out_channels``.

    Raises:
        ValueError: if ``mid_channels`` is given and differs from
            ``out_channels`` (such a block would fail inside the second
            BatchNorm during ``forward``).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels
        elif mid_channels != out_channels:
            raise ValueError(
                f"DoubleMLP has no second projection, so mid_channels "
                f"({mid_channels}) must equal out_channels ({out_channels})"
            )
        self.double_mlp = nn.Sequential(
            nn.Linear(in_channels, mid_channels),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
            # A second nn.Linear(mid_channels, out_channels) is disabled here;
            # enabling it would shift the indices (state_dict keys) below.
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_mlp(x)

class Down1D(nn.Module):
    """Encoder stage of DUNet1D: a single DoubleMLP.

    Unlike the 2-D ``Down`` block no pooling happens here; the stage only
    changes the feature width from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_mlp = DoubleMLP(in_channels, out_channels)

    def forward(self, x):
        features = self.double_mlp(x)
        return features

class Up1D(nn.Module):
    """Decoder stage of DUNet1D: project coarse features, pad, concat skip, DoubleMLP.

    Args:
        in_channels: width of the concatenated (skip + projected) features fed
            to the DoubleMLP.
        out_channels: output width of the stage.
        bilinear: if True, the coarse features go through
            ``Linear(2 * out_channels, out_channels) + ReLU``; otherwise through
            ``ConvTranspose1d(in_channels, in_channels // 2, kernel_size=1)``.

    NOTE(review): for 2-D activations ``nn.ConvTranspose1d`` would treat the
    input as an unbatched ``(channels, length)`` tensor — confirm the
    ``bilinear=False`` path is actually exercised.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Sequential(nn.Linear(out_channels * 2, out_channels), nn.ReLU())
        else:
            self.up = nn.ConvTranspose1d(in_channels, in_channels // 2, kernel_size=1, stride=1)
        self.double_mlp = DoubleMLP(in_channels, out_channels)

    def forward(self, x1, x2):
        projected = self.up(x1)
        # The width gap is measured on dim 1 and closed by zero-padding the
        # LAST dim (they coincide for 2-D inputs).
        gap = x2.size(1) - projected.size(1)
        lead = gap // 2
        projected = F.pad(projected, [lead, gap - lead])
        merged = torch.cat([x2, projected], dim=1)
        return self.double_mlp(merged)

class OutMLP(nn.Module):
    """Final linear head mapping ``in_channels`` features to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.outmlp = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        return self.outmlp(x)
    
class DUNet1D(nn.Module):
    """U-Net-shaped MLP over feature vectors, built from Down1D / Up1D stages.

    No resampling happens: encoder stages widen the features
    (64 -> 128 -> 256 -> 512 -> 1024 // shrink) and decoder stages fuse the
    matching skip features.

    Args:
        n_channels: input feature width.
        n_classes: output feature width.
        bilinear: selects the Up1D variant and halves the bottleneck and
            decoder widths.
        hidden_size, num_layers: accepted for signature compatibility with the
            other 1-D variants; unused here.
    """

    def __init__(self, n_channels, n_classes, bilinear=False, hidden_size=256, num_layers=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.inc = DoubleMLP(n_channels, 64)
        self.down1 = Down1D(64, 128)
        self.down2 = Down1D(128, 256)
        self.down3 = Down1D(256, 512)
        shrink = 2 if bilinear else 1
        self.down4 = Down1D(512, 1024 // shrink)
        self.up1 = Up1D(1024, 512 // shrink, bilinear)
        self.up2 = Up1D(512, 256 // shrink, bilinear)
        self.up3 = Up1D(256, 128 // shrink, bilinear)
        self.up4 = Up1D(128, 64 // shrink, bilinear)
        self.outc = OutMLP(64 // shrink, n_classes)

    def forward(self, x):
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))
        out = self.down4(skips[-1])
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = stage(out, skip)
        return self.outc(out)

class DoubleMLP_LSTM(nn.Module):
    """(Linear -> LayerNorm -> ReLU) x 2 acting on the last axis.

    LayerNorm (instead of BatchNorm) normalises the last axis, so any number of
    leading axes, e.g. ``(batch, seq, features)``, is accepted.

    Args:
        in_channels: input feature width.
        out_channels: output feature width.
        mid_channels: hidden width; defaults to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = out_channels if mid_channels is None else mid_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.extend([nn.Linear(c_in, c_out), nn.LayerNorm(c_out), nn.ReLU(inplace=True)])
        self.double_mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_mlp(x)

class LSTMEncoder(nn.Module):
    """Batch-first stacked LSTM followed by LayerNorm over its per-step outputs.

    Args:
        in_channels: feature width of each time step.
        hidden_size: LSTM hidden width, which is also the output width.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, in_channels, hidden_size=512, num_layers=4):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        # Keep only the per-step outputs; the final (h_n, c_n) are discarded.
        outputs = self.lstm(x)[0]
        return self.layer_norm(outputs)
    
class Down1DWithLSTM(nn.Module):
    """Encoder stage: DoubleMLP_LSTM projection, optionally followed by an LSTM.

    Args:
        in_channels: input feature width.
        out_channels: width after the projection (and LSTM input width).
        hidden_size: LSTM hidden width; the stage output width when the LSTM
            is enabled.
        num_layers: number of stacked LSTM layers; 0 disables the LSTM
            (``self.lstm_encoder`` is then None).
    """

    def __init__(self, in_channels, out_channels, hidden_size=8, num_layers=4):
        super().__init__()
        self.double_mlp = DoubleMLP_LSTM(in_channels, out_channels)
        self.num_layers = num_layers
        self.lstm_encoder = (
            LSTMEncoder(in_channels=out_channels, hidden_size=hidden_size, num_layers=num_layers)
            if num_layers != 0
            else None
        )

    def forward(self, x):
        features = self.double_mlp(x)
        if self.num_layers != 0:
            features = self.lstm_encoder(features)
        return features

class Up1DWithLSTM(nn.Module):
    """Decoder stage: project coarse features, pad, concat skip, DoubleMLP_LSTM, optional LSTM.

    Concatenation and padding happen on dim 2, which presumably is the feature
    axis of ``(batch, seq, features)`` inputs (the LSTM is batch-first) —
    TODO confirm.

    Args:
        in_channels: width of the concatenated (skip + projected) features.
        out_channels: width after the DoubleMLP_LSTM.
        bilinear: if True, project with ``Linear(2 * out_channels, out_channels)
            + ReLU``; otherwise with ``ConvTranspose1d(in_channels,
            in_channels // 2, kernel_size=1)``.
        hidden_size: LSTM hidden width.
        num_layers: stacked LSTM layers; 0 disables the LSTM.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, hidden_size=8, num_layers=4):
        super().__init__()
        if bilinear:
            self.up = nn.Sequential(nn.Linear(out_channels * 2, out_channels), nn.ReLU())
        else:
            self.up = nn.ConvTranspose1d(in_channels, in_channels // 2, kernel_size=1, stride=1)
        self.num_layers = num_layers
        self.double_mlp = DoubleMLP_LSTM(in_channels, out_channels)
        self.lstm_encoder = (
            LSTMEncoder(in_channels=out_channels, hidden_size=hidden_size, num_layers=num_layers)
            if num_layers != 0
            else None
        )

    def forward(self, x1, x2):
        coarse = self.up(x1)
        # Zero-pad the last axis so it matches the skip tensor's dim 2.
        gap = x2.size(2) - coarse.size(2)
        lead = gap // 2
        coarse = F.pad(coarse, [lead, gap - lead])
        merged = self.double_mlp(torch.cat([x2, coarse], dim=2))
        if self.num_layers == 0:
            return merged
        return self.lstm_encoder(merged)
    
class DUNet1D_LSTM(nn.Module):
    """U-Net-shaped feature network whose stages end in optional LSTMs.

    Each Down1DWithLSTM / Up1DWithLSTM stage uses its own output width as the
    LSTM hidden size, so the LSTMs do not change stage widths.

    Args:
        n_channels: input feature width.
        n_classes: output feature width.
        bilinear: selects the Up1DWithLSTM variant and halves the bottleneck
            and decoder widths.
        hidden_size: accepted for signature compatibility; unused.
        num_layers: LSTM depth of every stage (0 disables all LSTMs).
    """

    def __init__(self, n_channels, n_classes, bilinear=False, hidden_size=256, num_layers=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleMLP_LSTM(n_channels, 64)
        # Down1DWithLSTM(in, out, lstm_hidden, lstm_layers)
        self.down1 = Down1DWithLSTM(64, 128, 128, num_layers)
        self.down2 = Down1DWithLSTM(128, 256, 256, num_layers)
        self.down3 = Down1DWithLSTM(256, 512, 512, num_layers)
        shrink = 2 if bilinear else 1
        self.down4 = Down1DWithLSTM(512, 1024 // shrink, 1024 // shrink, num_layers)
        # Up1DWithLSTM(in, out, bilinear, lstm_hidden, lstm_layers)
        self.up1 = Up1DWithLSTM(1024, 512 // shrink, bilinear, 512 // shrink, num_layers)
        self.up2 = Up1DWithLSTM(512, 256 // shrink, bilinear, 256 // shrink, num_layers)
        self.up3 = Up1DWithLSTM(256, 128 // shrink, bilinear, 128 // shrink, num_layers)
        self.up4 = Up1DWithLSTM(128, 64 // shrink, bilinear, 64 // shrink, num_layers)
        self.outc = OutMLP(64 // shrink, n_classes)

    def forward(self, x):
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))
        out = self.down4(skips[-1])
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = stage(out, skip)
        return self.outc(out)


class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` (usable inside ``nn.Sequential``).

    Args:
        *order: the target dimension order, as passed to ``Tensor.permute``.
    """

    def __init__(self, *order):
        super().__init__()
        self.order = order

    def forward(self, x):
        return x.permute(self.order)
      
class Conv1DEncoder(nn.Module):
    """Two Conv1d -> BatchNorm1d -> ReLU stages applied along the sequence axis.

    Inputs and outputs are channels-last ``(batch, length, channels)``: the
    tensor is permuted to ``(batch, channels, length)`` for ``nn.Conv1d`` and
    permuted back afterwards.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        mid_channels: channels between the two convolutions; defaults to
            ``out_channels`` when falsy.
        kernel_size: kernel size of both active convolutions.
        stride: only used by the (unused) ``self.conv1d`` layer; the active
            convolutions use stride 1.
        padding: padding of both active convolutions (the default 0 shortens
            the sequence by ``kernel_size - 1`` per convolution).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None,  kernel_size=3, stride=1, padding=0):
        super().__init__()

        # NOTE: conv1d / layer_norm / relu are never used in forward(), but
        # conv1d and layer_norm own parameters that appear in the state_dict,
        # so they are kept to let existing checkpoints load with strict=True.
        self.conv1d = nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding)
        self.layer_norm = nn.LayerNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)

        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv1d(in_channels, mid_channels, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
            # Normalises the second conv's output, whose width is out_channels
            # (sizing this with mid_channels broke whenever the two differed).
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L) for Conv1d
        x = self.double_conv(x)
        return x.permute(0, 2, 1)  # back to (B, L, C)

def MLP(channels, activation=None, bn_momentum=0.1, bias=True, eps=1e-5):
    """Build a stack of Linear -> BatchNorm1d -> activation stages.

    Args:
        channels: layer widths; ``len(channels) - 1`` stages are created and
            stage ``i`` maps ``channels[i - 1] -> channels[i]``.
        activation: activation module appended to every stage. When ``None``
            (the default) a fresh ``nn.ReLU()`` is created for each stage. A
            module passed explicitly is reused by every stage (as before), so
            a parametric activation would share its weights across stages.
        bn_momentum: ``momentum`` of each BatchNorm1d.
        bias: whether each Linear has a bias.
        eps: ``eps`` of each BatchNorm1d.

    Returns:
        ``nn.Sequential`` whose children are ``nn.Sequential(Linear,
        BatchNorm1d, activation)`` stages. Since BatchNorm1d normalises axis 1
        and Linear acts on the last axis, 2-D ``(N, C)`` inputs are expected.
    """

    def _activation():
        # The previous default `activation=nn.ReLU()` was evaluated once at
        # import time, so one module instance was shared by every MLP built.
        return nn.ReLU() if activation is None else activation

    return nn.Sequential(
        *[
            nn.Sequential(
                nn.Linear(c_in, c_out, bias=bias),
                nn.BatchNorm1d(c_out, momentum=bn_momentum, eps=eps),
                _activation(),
            )
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
    )
    
class MLP1D(nn.Module):
    """Plain MLP baseline with the same constructor signature as the DUNet1D variants.

    Args:
        n_channels: input feature width.
        n_classes: output feature width.
        bilinear, num_layers: accepted for signature compatibility; unused.
        hidden_size: width of every hidden layer.
    """

    def __init__(self, n_channels, n_classes, bilinear=False, hidden_size=256, num_layers=1):
        super().__init__()
        width = hidden_size
        self.inp = MLP([n_channels, width])
        self.mlp = MLP([width] * 4)
        self.out = nn.Linear(width, n_classes)

    def forward(self, x):
        hidden = self.mlp(self.inp(x))
        return self.out(hidden)

class Down1DWithConv1d(nn.Module):
    """Encoder stage: MaxPool1d(1), DoubleMLP_LSTM projection, then Conv1DEncoder.

    ``nn.MaxPool1d(1)`` has kernel size 1 (and therefore stride 1), so it
    leaves its input unchanged; no downsampling happens in this stage.

    Args:
        in_channels: input feature width.
        out_channels: output feature width.
        kernel_size, stride, padding: forwarded to Conv1DEncoder.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.maxpool_conv = nn.MaxPool1d(1)
        self.double_mlp = DoubleMLP_LSTM(in_channels, out_channels)
        self.Conv1D = Conv1DEncoder(
            out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        for layer in (self.maxpool_conv, self.double_mlp, self.Conv1D):
            x = layer(x)
        return x
    
class Up1DWithConv1d(nn.Module):
    """Decoder stage: project coarse features, pad, concat skip, DoubleMLP_LSTM, Conv1DEncoder.

    Args:
        in_channels: width of the concatenated (skip + projected) features.
        out_channels: output feature width.
        bilinear: if True, project with ``Linear(2 * out_channels, out_channels)
            + ReLU``; otherwise with ``ConvTranspose1d(in_channels,
            in_channels // 2, kernel_size=1)``.
        kernel_size, stride, padding: forwarded to Conv1DEncoder.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, kernel_size=3, stride=1, padding=1):
        super().__init__()
        if bilinear:
            self.up = nn.Sequential(nn.Linear(out_channels * 2, out_channels), nn.ReLU())
        else:
            self.up = nn.ConvTranspose1d(in_channels, in_channels // 2, kernel_size=1, stride=1)
        self.double_mlp = DoubleMLP_LSTM(in_channels, out_channels)
        self.Conv1D = Conv1DEncoder(
            out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x1, x2):
        coarse = self.up(x1)
        # Zero-pad the last axis so it matches the skip tensor's dim 2.
        gap = x2.size(2) - coarse.size(2)
        lead = gap // 2
        coarse = F.pad(coarse, [lead, gap - lead])
        merged = torch.cat([x2, coarse], dim=2)
        return self.Conv1D(self.double_mlp(merged))
      
class DUNet1D_1DConv(nn.Module):
    """U-Net-shaped feature network whose stages end in Conv1d encoders.

    Args:
        n_channels: input feature width.
        n_classes: output feature width.
        bilinear: selects the Up1DWithConv1d variant and halves the bottleneck
            and decoder widths.
        hidden_size, num_layers: accepted for signature compatibility; unused.
    """

    def __init__(self, n_channels, n_classes, bilinear=False, hidden_size=256, num_layers=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleMLP_LSTM(n_channels, 64)
        self.down1 = Down1DWithConv1d(64, 128)
        self.down2 = Down1DWithConv1d(128, 256)
        self.down3 = Down1DWithConv1d(256, 512)
        shrink = 2 if bilinear else 1
        self.down4 = Down1DWithConv1d(512, 1024 // shrink)
        self.up1 = Up1DWithConv1d(1024, 512 // shrink, bilinear)
        self.up2 = Up1DWithConv1d(512, 256 // shrink, bilinear)
        self.up3 = Up1DWithConv1d(256, 128 // shrink, bilinear)
        self.up4 = Up1DWithConv1d(128, 64 // shrink, bilinear)
        self.outc = OutMLP(64 // shrink, n_classes)

    def forward(self, x):
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))
        out = self.down4(skips[-1])
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = stage(out, skip)
        return self.outc(out)

class Texture_Decoder(nn.Module):
    """Thin wrapper around a ``SirenNet`` mapping ``dim_in`` inputs to ``dim_out`` outputs.

    Args:
        dim_in: input width.
        dim_hidden: hidden width of the SirenNet.
        num_layers: number of SirenNet layers.
        dim_out: output width (also stored as ``self.dim_out``).
    """

    def __init__(self, dim_in=1, dim_hidden=256, num_layers=8, dim_out=3):
        super().__init__()
        self.dim_out = dim_out
        siren_kwargs = dict(
            dim_in=dim_in,
            dim_hidden=dim_hidden,
            dim_out=dim_out,
            num_layers=num_layers,
            w0_initial=30.,
        )
        self.net = SirenNet(**siren_kwargs)

    def forward(self, z):
        return self.net(z)
    
def get_normalized_directions(directions):
    """Affinely remap values from [-1, 1] to [0, 1].

    The SH encoding expects its inputs in [0, 1].

    Args:
        directions: batch of directions (tensor or number).
    """
    return 0.5 * (directions + 1.0)


def normalize_aabb(pts, aabb):
    """Linearly map ``pts`` so that ``aabb[0]`` goes to -1 and ``aabb[1]`` to +1.

    Whichever corner is stored first maps to -1; if ``aabb[0]`` is the
    maximum corner, the coordinates are reflected.
    """
    first, second = aabb[0], aabb[1]
    scale = 2.0 / (second - first)
    return (pts - first) * scale - 1.0


def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Bilinearly sample ``grid`` at ``coords`` using border padding.

    Args:
        grid: ``[B, C, *reso]``, or ``[C, *reso]`` (a batch axis is then added).
        coords: normalised sample locations (``F.grid_sample`` convention),
            ``[B, n, D]`` or ``[n, D]``, where ``D`` (2 or 3) is the number of
            grid axes.
        align_corners: forwarded to ``F.grid_sample``.

    Returns:
        Samples laid out as ``[B, n, C]`` and then ``squeeze()``-d, so every
        size-1 axis (e.g. ``B == 1``) is dropped.

    Raises:
        NotImplementedError: if ``D`` is not 2 or 3.
    """
    n_axes = coords.shape[-1]
    if n_axes not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {n_axes}D data but is only "
                                  f"implemented for 2 and 3D data.")
    if grid.dim() == n_axes + 1:
        grid = grid[None]  # add the missing batch axis
    if coords.dim() == 2:
        coords = coords[None]

    batch, channels = grid.shape[:2]
    # F.grid_sample wants coords shaped [B, 1, ..., n, D] with D - 1 singleton axes.
    sample_shape = [coords.shape[0], *([1] * (n_axes - 1)), *coords.shape[1:]]
    coords = coords.view(sample_shape)
    n_points = coords.shape[-2]

    sampled = F.grid_sample(grid, coords, align_corners=align_corners,
                            mode='bilinear', padding_mode='border')
    return sampled.view(batch, channels, n_points).transpose(-1, -2).squeeze()

def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.3,
        b: float = 0.6):
    """Create one learnable grid per ``grid_nd``-combination of the input axes.

    Each grid has shape ``[1, out_dim, *(reso[ax] for ax in reversed(axes))]``
    and is initialised deterministically with ``positional_encoding(mesh, 6)``
    of a coordinate mesh instead of random values.

    Args:
        grid_nd: number of axes spanned by each grid. The mesh built below is
            always 2-D, so this appears to assume ``grid_nd == 2`` — TODO confirm.
        in_dim: number of input coordinates; must equal ``len(reso)`` and, since
            ``reso[2]`` is read, be at least 3.
        out_dim: feature channels per grid; the positional encoding must be
            broadcastable to this for ``copy_`` to succeed.
        reso: resolution of every input axis.
        a, b: unused; kept for interface compatibility.

    Returns:
        ``nn.ParameterList`` of grids in ``itertools.combinations(range(in_dim),
        grid_nd)`` order.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    assert grid_nd <= in_dim
    # Build the init mesh on the GPU when available (as before), but fall back
    # to the CPU instead of hard-requiring CUDA. The parameters are allocated
    # on the CPU and receive a copy of the encoding either way.
    mesh_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    grid_coefs = nn.ParameterList()
    for coo_comb in itertools.combinations(range(in_dim), grid_nd):
        new_grid_coef = nn.Parameter(torch.empty(
            [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]
        ))
        first_extent = new_grid_coef.shape[2]
        if first_extent == reso[2]:
            # Heuristic: treat this grid as involving axis 2 (presumably time)
            # and encode its first axis on [0, 1] instead of [-1, 1].
            # NOTE(review): this also fires for grids not containing axis 2
            # whenever reso[coo_comb[-1]] == reso[2]; confirm intended.
            time_size = new_grid_coef.shape[2]
            grid_size = new_grid_coef.shape[3]
            axes = [torch.linspace(0, 1, time_size), torch.linspace(-1, 1, grid_size)]
        else:
            # Square mesh on [-1, 1]; assumes both grid extents are equal.
            axes = [torch.linspace(-1, 1, first_extent)] * 2
        mgrid = torch.stack(torch.meshgrid(axes, indexing='ij'))[None].to(mesh_device)
        mgrid_pe = positional_encoding(mgrid, 6)
        with torch.no_grad():
            new_grid_coef.copy_(mgrid_pe)
        grid_coefs.append(new_grid_coef)

    return grid_coefs


def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Query every scale of a multi-scale plane field at ``pts``.

    Within one scale the per-plane samples are multiplied together (one plane
    per ``grid_dimensions``-combination of the coordinate axes of ``pts``).
    Across scales the results are concatenated along the last axis
    (``concat_features=True``) or summed.

    Args:
        pts: query points whose last axis holds the coordinates.
        ms_grids: one entry per scale, each indexable by plane id in
            ``itertools.combinations`` order.
        grid_dimensions: number of coordinate axes per plane.
        concat_features: concatenate (True) or sum (False) across scales.
        num_levels: use only the first ``num_levels`` scales; ``None`` = all.

    Returns:
        ``[N, total feature dim]`` when concatenating, ``[N, feature_dim]`` when
        summing (or the float ``0.`` if no scale is used).
    """
    axis_combos = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))
    per_scale = []
    summed = 0.
    grid: nn.ParameterList
    # Slicing with None keeps every scale, matching num_levels=None.
    for grid in ms_grids[:num_levels]:
        scale_feat = 1.
        for plane_id, axes in enumerate(axis_combos):
            plane = grid[plane_id]
            width = plane.shape[1]  # plane is [1, feature_dim, *reso]
            sampled = grid_sample_wrapper(plane, pts[..., axes]).view(-1, width)
            scale_feat = scale_feat * sampled
        if concat_features:
            per_scale.append(scale_feat)
        else:
            summed = summed + scale_feat

    if concat_features:
        return torch.cat(per_scale, dim=-1)
    return summed


class HexPlaneField(nn.Module):
    """Multi-scale set of learnable feature grids queried by bilinear sampling.

    For every multiplier in ``multires`` a ``nn.ParameterList`` of grids is
    created by ``init_grid_param`` (one grid per combination of
    ``grid_dimensions`` input axes). ``forward`` normalises the points with
    ``self.aabb``, appends ``timestamps`` as extra coordinates and returns the
    features from ``interpolate_ms_features``, concatenated across scales.

    Args:
        bounds: half-extent of the initial box. The box is stored as
            ``[[+bounds] * 3, [-bounds] * 3]``, i.e. the maximum corner first.
        planeconfig: dict with ``"grid_dimensions"``,
            ``"input_coordinate_dim"``, ``"output_coordinate_dim"`` and
            ``"resolution"``. Only the first three resolution entries are
            scaled by each multiplier.
        multires: resolution multipliers, one grid set per entry.
    """

    def __init__(
        self,

        bounds,
        planeconfig,
        multires
    ) -> None:
        super().__init__()
        aabb = torch.tensor([[bounds,bounds,bounds],
                             [-bounds,-bounds,-bounds]])
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config =  [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True
        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # Shallow copy: only "resolution" is replaced below, never mutated.
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]

            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)

    @property
    def get_aabb(self):
        """Return the two stored box corners ``(aabb[0], aabb[1])``."""
        return self.aabb[0], self.aabb[1]

    def set_aabb(self, xyz_max, xyz_min):
        """Replace the bounding box with ``[xyz_max, xyz_min]`` (float32, frozen).

        The new box is created on the device of the current one, so calling
        this after the module was moved to a GPU does not leave ``aabb`` on
        the CPU (which would make ``normalize_aabb`` mix devices).
        """
        aabb = torch.tensor([
            xyz_max,
            xyz_min
        ], dtype=torch.float32, device=self.aabb.device)
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        print("Voxel Plane: set aabb=",self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Return the multi-scale grid features at ``pts`` / ``timestamps``.

        ``timestamps`` is required despite the Optional hint: it is
        concatenated onto the normalised points along the last axis.
        """
        pts = normalize_aabb(pts, self.aabb)
        pts = torch.cat((pts, timestamps), dim=-1)  # [n_rays, n_samples, 4]

        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)
        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Alias for :meth:`get_density`."""
        features = self.get_density(pts, timestamps)

        return features

    @staticmethod
    def initialize_weights(m):
        """Xavier-uniform init for ``nn.Linear`` weights and zero biases.

        Other module types are left untouched, so this can be used as
        ``module.apply(HexPlaneField.initialize_weights)``. It is a
        staticmethod so calling it through an instance also works.
        """
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight, gain=1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
                
        
