import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

import numpy as np
import math

# from module import *
# from data_utils import *

from einops import rearrange, repeat
import models
from models import register

from models.mlp import MLP, get_activation_function

"""
A set of band position encoders and band-aware decoders
"""


def _cal_freq_list(freq_init, frequency_num, max_radius, min_radius):
    """Compute the sinusoid frequencies used by the band position encoder.

    Args:
        freq_init: frequency initialisation scheme.
            "random":    ``frequency_num`` values drawn uniformly from
                         ``[0, max_radius)`` with numpy's global RNG.
            "geometric": timescales ``1 / f`` spaced geometrically from
                         ``min_radius`` up to ``max_radius``.
        frequency_num: number of frequencies to produce.
        max_radius: upper scale (see ``freq_init``).
        min_radius: smallest timescale for "geometric".

    Returns:
        np.ndarray of shape (frequency_num,).

    Raises:
        NotImplementedError: if ``freq_init`` is not a supported scheme.
    """
    if freq_init == "random":
        # the frequency used for each block (alpha in the ICLR paper)
        return np.random.random(size=[frequency_num]) * max_radius

    if freq_init == "geometric":
        # Compute the timescales 1/f_s as a geometric sequence. With a single
        # frequency there is no ratio to spread, so use min_radius alone
        # instead of dividing by zero.
        if frequency_num > 1:
            log_timescale_increment = (math.log(float(max_radius) / float(min_radius)) /
                                       (frequency_num - 1.0))
        else:
            log_timescale_increment = 0.0

        timescales = min_radius * np.exp(
            np.arange(frequency_num).astype(float) * log_timescale_increment)
        return 1.0 / timescales

    # Previously fell through to an UnboundLocalError on `freq_list`.
    raise NotImplementedError(f"Unknown freq_init: {freq_init}")


def generate_random_pt_within_interval(band_coord, K=1, band_int_sample_type="uniform"):
    '''
    Sample K points for each (start, end) band interval.

    Args:
        band_coord: shape (B = batch_size, num_b = C or num_band_sample, 2), band interval coordinates, [-1, 1]
                        B: batch size
                        num_b: number of bands/channels we consider to super-res
                        2: band interval start and end endpoints -> (S, E)
                    The tensor must be rank 3; it is unpacked as (B, num_b, _).
        K: the number of points to sample per interval
        band_int_sample_type: the method to sample these points
            "uniform":  K i.i.d. draws from U(start, end)
            "gaussian": K i.i.d. draws from N(mean, std^2) with mean = (start + end) / 2
                        and std = (end - start) / 6; samples are NOT clipped to the interval
            "fix":      K evenly spaced points from start to end, both endpoints included
    Return:
        rand_coord: shape (B = batch_size, num_b = C or num_band_sample, K), float32,
                    on the same device as band_coord
    Raises:
        NotImplementedError: unknown band_int_sample_type
    '''
    B, num_b, _ = band_coord.shape
    device = band_coord.device

    if band_int_sample_type == "uniform":
        # Randomness comes from numpy's global RNG, so np.random.seed controls it.
        unit = np.random.uniform(low=0.0, high=1.0, size=(B, num_b, K))
        unit = torch.from_numpy(unit).float().to(device)
        # start / end: shape (B, num_b, 1); kept as tensors so gradients can flow
        start = band_coord[:, :, 0].unsqueeze(-1)
        end = band_coord[:, :, 1].unsqueeze(-1)
        return unit * (end - start) + start

    if band_int_sample_type == "gaussian":
        # detach(): numpy cannot hold tensors that carry autograd history
        coords = band_coord.detach()
        # start / end / mean: shape (B, num_b, 1)
        start = coords[:, :, 0].unsqueeze(-1).cpu().numpy()
        end = coords[:, :, 1].unsqueeze(-1).cpu().numpy()
        mean = torch.mean(coords, dim=-1, keepdim=True).cpu().numpy()
        # +-3 std spans the interval
        std = (end - start) / 6

        # shape (B, num_b, K); loc/scale broadcast over the last axis
        samples = np.random.normal(loc=mean, scale=std, size=(B, num_b, K))
        return torch.from_numpy(samples).float().to(device)

    if band_int_sample_type == "fix":
        coords = band_coord.detach()
        # start / end: shape (B, num_b, 1)
        start = coords[:, :, 0].unsqueeze(-1).cpu().numpy()
        end = coords[:, :, 1].unsqueeze(-1).cpu().numpy()

        # grid: shape (K, B, num_b, 1)
        grid = np.linspace(start, end, num=K, endpoint=True)
        # grid: shape (B, num_b, K)
        grid = np.transpose(grid, axes=(1, 2, 3, 0)).squeeze(-2)
        return torch.from_numpy(grid).float().to(device)

    raise NotImplementedError(f"Unknown band_int_sample_type: {band_int_sample_type}")


def get_band_sample_prob_with_response_func(band_coord, band_coord_samples,
                                            resp_func_type='gaussian',
                                            resp_func_norm_type='softmax'):
    '''
    Weight the K wavelength samples of each band by a response function.

    Args:
        band_coord: shape (B = batch_size, num_b = C or num_band_sample, 2)
            band intervals (start, end)
        band_coord_samples: shape (B = batch_size, num_b = C or num_band_sample, K)
            sampled K wavelengths within each band interval
        resp_func_type: the response function type
            "gaussian": standard normal density of (sample - mean) / std, with
                        mean = (start + end) / 2 and std = (end - start) / 6
            "uniform":  every sample gets weight 1
        resp_func_norm_type: how the weights are normalised along K
            "softmax": softmax of the weights themselves (note: applied to the
                       density values, not to log-densities, so the result is
                       much flatter than the densities)
            "invsum":  weights divided by their sum
    Return:
        probs_norm: np.ndarray of shape (B, num_b, K)
            the normalized weight of each wavelength sample; sums to 1 along K
    Raises:
        NotImplementedError: unknown resp_func_type or resp_func_norm_type
    '''
    # Import the submodules explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` / `scipy.special` are loaded.
    import scipy.special
    import scipy.stats

    # detach(): numpy cannot hold tensors that carry autograd history
    # samples: shape (B, num_b, K)
    samples = band_coord_samples.detach().cpu().numpy()

    if resp_func_type == 'gaussian':
        coords = band_coord.detach()
        # start / end / mean: shape (B, num_b, 1)
        start = coords[:, :, 0].unsqueeze(-1).cpu().numpy()
        end = coords[:, :, 1].unsqueeze(-1).cpu().numpy()
        mean = torch.mean(coords, dim=-1, keepdim=True).cpu().numpy()
        # +-3 std spans the interval
        std = (end - start) / 6

        # probs: shape (B, num_b, K)
        probs = scipy.stats.norm(loc=0, scale=1).pdf((samples - mean) / std)
    elif resp_func_type == 'uniform':
        # probs: shape (B, num_b, K)
        probs = np.ones(shape=band_coord_samples.shape, dtype=np.float32)
    else:
        raise NotImplementedError(f"Unknown resp_func_type: {resp_func_type}")

    if resp_func_norm_type == 'softmax':
        return scipy.special.softmax(probs, axis=-1)
    if resp_func_norm_type == 'invsum':
        return probs / probs.sum(axis=-1, keepdims=True)
    raise NotImplementedError(f"Unknown resp_func_norm_type: {resp_func_norm_type}")


@register('bandnerf')
class BandNeRFDecoder(nn.Module):
    """
    Decode one value per sampled wavelength at each image location.

    Given the band embeddings produced by BandPositionEncoder (one per sampled
    wavelength) and an image embedding for each location, combine the two
    according to ``bandnerf_type`` and predict a value for every
    (location, band, wavelength sample).
    """

    def __init__(self, fedec_out_dim, band_embed_dim, out_dim=1,
                 bandnerf_type="img_band_cat", hidden_list=None, act="sigmoid", **kwargs):
        """
        Args:
            fedec_out_dim: image embedding dimension
            band_embed_dim: band embedding dimension (output dim of BandPositionEncoder)
            out_dim: number of values predicted per wavelength sample; only used by
                "img_band_cat" (the dot-product types always produce 1)
            bandnerf_type: how image and band embeddings are combined:
                "img_band_cat":  concatenate both and feed them through an MLP, then ``act``
                "img_band_dot":  dot product of both, then ``act``
                "img_band_dot1": dot product of both, returned WITHOUT ``act``
            hidden_list: hidden layer sizes of the MLP ("img_band_cat" only; must be
                empty for the dot-product types). None means no hidden layers.
            act: name of the final activation, resolved by get_activation_function
            **kwargs: accepted and ignored
        Raises:
            ValueError: unknown bandnerf_type, or hidden layers requested for a dot type
        """
        super(BandNeRFDecoder, self).__init__()

        # None instead of a shared mutable [] default
        hidden_list = [] if hidden_list is None else hidden_list

        self.fedec_out_dim = fedec_out_dim
        self.band_embed_dim = band_embed_dim
        self.out_dim = out_dim
        self.bandnerf_type = bandnerf_type
        self.hidden_list = hidden_list

        if bandnerf_type == "img_band_cat":
            self.bandnerf = MLP(in_dim=fedec_out_dim + band_embed_dim,
                                out_dim=self.out_dim,
                                hidden_list=self.hidden_list)
        elif bandnerf_type in ("img_band_dot", "img_band_dot1"):
            # the dot-product variants have no learnable layers
            if len(hidden_list) != 0:
                raise ValueError(f"{bandnerf_type} does not use hidden layers, got hidden_list={hidden_list}")
        else:
            raise ValueError(f"Unknown bandnerf_type: {bandnerf_type}")

        self.act = get_activation_function(activation=act, context_str="BandNeRFDecoder")
        self.act_name = act

    def forward(self, img_embeds, band_embeds):
        '''
        Args:
            img_embeds:  shape (B, X, fedec_out_dim)
            band_embeds: shape (B = batch_size, num_b = C or num_band_sample, K, band_embed_dim), band embedding for K wavelength in each band interval
                            B: batch size
                            num_b: number of bands/channels we consider to super-res
                            K: number of wavelengths sampled in each band interval
                            band_embed_dim: band embedding dimension
        Return:
            out_embeds: shape (B, X, num_b, K, out_dim) for "img_band_cat",
                        shape (B, X, num_b, K, 1) for the dot-product types
        '''
        B, X, fedec_out_dim = img_embeds.shape
        assert fedec_out_dim == self.fedec_out_dim
        B1, num_b, K, band_embed_dim = band_embeds.shape
        assert B == B1
        assert band_embed_dim == self.band_embed_dim

        if self.bandnerf_type == "img_band_dot1":
            # img_embeds: shape (B, fedec_out_dim, X)
            img_embeds = img_embeds.permute(0, 2, 1)
            # out_embeds: shape (B, num_b, K, X)
            out_embeds = torch.einsum('bckd,bdx->bckx', band_embeds, img_embeds)
            # out_embeds: shape (B, X, num_b, K, 1); no activation for this type
            return out_embeds.permute(0, 3, 1, 2).unsqueeze(-1)

        if not self.bandnerf_type.startswith("img_band_"):
            raise NotImplementedError

        # Insert singleton axes and rely on broadcasting instead of copying:
        # img_embeds:  shape (B, X, 1,     1, fedec_out_dim)
        img_embeds = img_embeds[:, :, None, None, :]
        # band_embeds: shape (B, 1, num_b, K, band_embed_dim)
        band_embeds = band_embeds[:, None]

        if self.bandnerf_type == "img_band_cat":
            # torch.cat needs equal non-concatenated sizes, so expand (views, no copy)
            # embeds: shape (B, X, num_b, K, fedec_out_dim + band_embed_dim)
            embeds = torch.cat([img_embeds.expand(B, X, num_b, K, fedec_out_dim),
                                band_embeds.expand(B, X, num_b, K, band_embed_dim)], dim=-1)
            # out_embeds: shape (B, X, num_b, K, out_dim)
            out_embeds = self.bandnerf(embeds)
        elif self.bandnerf_type == "img_band_dot":
            assert fedec_out_dim == band_embed_dim
            # broadcast product, summed over the embedding dim
            # out_embeds: shape (B, X, num_b, K, 1)
            out_embeds = torch.sum(img_embeds * band_embeds, dim=-1, keepdim=True)
        else:
            raise NotImplementedError

        out_embeds = self.act(out_embeds)
        if self.act_name == 'tanh':
            # NOTE(review): presumably widens the range slightly past [-1, 1] so
            # targets at exactly +-1 stay reachable; confirm intent.
            out_embeds = out_embeds * (1 + 0.001)
        return out_embeds


@register('banddec')
class BandAwareDecoder(nn.Module):
    """
    Decode a value for every requested band at every image location.

    An image feature decoder (``fedec``) turns each location's feature into an
    embedding, and a band encoder (``bandenc``) embeds the requested bands.

    - For most band encoder types, the prediction is the dot product of the
      two embeddings.
    - For ``"band_rb_mlp"``, K wavelengths are sampled inside each band
      interval and a BandNeRF decoder predicts a value per sample. The samples
      are then combined with normalised response-function weights.
    """

    def __init__(self, fedec_spec, bandenc_spec, bandnerf_spec=None,
                 in_dim=64, num_band_int_sample=32,
                 resp_func_type="gaussian", resp_func_norm_type='softmax', band_int_sample_type='uniform'):
        """
        Args:
            fedec_spec: model spec of the image feature decoder (built with in_dim)
            bandenc_spec: model spec of the band encoder, BandPositionEncoder()
                (built with out_dim = the feature decoder's out_dim)
            bandnerf_spec: model spec of the BandNeRF decoder; needed for, and only
                allowed with, bandposenc_type == "band_rb_mlp"
            in_dim: the input image feature dim
            num_band_int_sample: K, the number of wavelengths sampled per band
                interval; only used when bandposenc_type == "band_rb_mlp"
            resp_func_type: response function weighting the K samples
                ("gaussian" or "uniform")
            resp_func_norm_type: normalisation of those weights ("softmax" or "invsum")
            band_int_sample_type: how the K samples are drawn ("uniform", "gaussian" or "fix")
        Raises:
            ValueError: bandnerf_spec given but the band encoder is not "band_rb_mlp"
        """
        super(BandAwareDecoder, self).__init__()

        # feature decoder; its out_dim fixes the shared embedding size
        self.fedec = models.make(fedec_spec, args={'in_dim': in_dim})
        fedec_out_dim = self.fedec.out_dim
        # band encoder; its embeddings must match the feature decoder's output dim
        self.bandenc = models.make(bandenc_spec, args={'out_dim': fedec_out_dim})

        self.fedec_spec = fedec_spec
        self.bandenc_spec = bandenc_spec

        self.in_dim = in_dim
        self.fedec_out_dim = fedec_out_dim

        self.bandposenc_type = self.bandenc.bandposenc_type
        self.num_band_int_sample = num_band_int_sample
        self.resp_func_type = resp_func_type
        self.resp_func_norm_type = resp_func_norm_type
        self.band_int_sample_type = band_int_sample_type

        if bandnerf_spec is not None:
            if self.bandposenc_type != "band_rb_mlp":
                raise ValueError("bandnerf_spec is only supported with bandposenc_type 'band_rb_mlp', "
                                 f"got {self.bandposenc_type!r}")
            self.bandnerf_dec = models.make(bandnerf_spec,
                                            args={
                                                'fedec_out_dim': fedec_out_dim,
                                                'band_embed_dim': fedec_out_dim
                                            })
        else:
            # defined explicitly so forward() can report a missing decoder clearly
            self.bandnerf_dec = None

    def forward(self, inp, band_coord):
        '''
        Args:
            inp: shape (B, X, in_dim), image features for X locations per batch item
            band_coord: shape (B = batch_size, num_b = C or num_band_sample, 2), band interval coordinates, [-1, 1]
                            B: batch size
                            num_b: number of bands/channels we consider to super-res
                            2: band interval start and end endpoints -> (S, E)
        Return:
            out_img: shape (B, X, num_b)
        Raises:
            ValueError: "band_rb_mlp" band encoder but no BandNeRF decoder was built
        '''
        B, _, _ = inp.shape
        assert band_coord.dim() == 3 and band_coord.shape[0] == B

        # fea_embed: shape (B, X, fedec_out_dim)
        fea_embed = self.fedec(inp)

        if self.bandposenc_type != "band_rb_mlp":
            # band_embed: shape (B, num_b, fedec_out_dim) -> (B, fedec_out_dim, num_b)
            band_embed = self.bandenc(band_coord).permute(0, 2, 1)
            # out_img: shape (B, X, num_b), dot product of the two embeddings
            return torch.einsum('bxo,boc->bxc', fea_embed, band_embed)

        if self.bandnerf_dec is None:
            raise ValueError("bandposenc_type 'band_rb_mlp' requires a bandnerf_spec")

        # band_coord_samples: shape (B, num_b, K = num_band_int_sample)
        band_coord_samples = generate_random_pt_within_interval(band_coord,
                                                                K=self.num_band_int_sample,
                                                                band_int_sample_type=self.band_int_sample_type)
        # band_embed: shape (B, num_b, K, fedec_out_dim)
        band_embed = self.bandenc(band_coord_samples)
        # band_pred_samples: shape (B, X, num_b, K, out_dim)
        band_pred_samples = self.bandnerf_dec(img_embeds=fea_embed, band_embeds=band_embed)

        # numpy weights, shape (B, num_b, K), normalised along K
        band_probs_norm = get_band_sample_prob_with_response_func(band_coord, band_coord_samples,
                                                                  resp_func_type=self.resp_func_type,
                                                                  resp_func_norm_type=self.resp_func_norm_type)
        # move to the predictions' device (not always the default CUDA device)
        band_probs_norm = torch.from_numpy(band_probs_norm).float().to(band_pred_samples.device)
        # shape (B, 1, num_b, K, 1): broadcasts over X and out_dim
        band_probs_norm = band_probs_norm[:, None, :, :, None]

        # weighted sum over the K samples, then sum over out_dim -> (B, X, num_b)
        return torch.sum(torch.sum(band_pred_samples * band_probs_norm, dim=-2), dim=-1)


@register('bandposenc')
class BandPositionEncoder(nn.Module):
    """
    Encode band coordinates with multi-frequency sinusoids followed by an MLP.

    Depending on ``bandposenc_type``, the points encoded per band are the
    interval start / mean / end, the band width, a random point inside the
    interval, or K externally sampled wavelengths ("band_rb_mlp"). Each point
    p becomes [sin(p * f_k), cos(p * f_k)] for the ``freq`` frequencies f_k.
    The flattened result is mapped to ``out_dim`` by ``bandmlp``.
    """

    # Number of points encoded with sin/cos per band, for each bandposenc_type:
    #   band_rb_mlp : R       - one sampled wavelength (forward keeps the K samples on their own axis)
    #   band_m_mlp  : M       - interval mean
    #   band_mc1_mlp: M       - interval mean; the raw band width is appended un-encoded (+1 dim)
    #   band_mc_mlp : M, C    - interval mean and band width
    #   band_se_mlp : S, E    - interval start and end
    #   band_sme_mlp: S, M, E
    #   band_sre_mlp: S, R, E - R is a uniform random point while training, the mean in eval
    _NUM_ENCODED_POINTS = {
        "band_rb_mlp": 1,
        "band_m_mlp": 1,
        "band_mc1_mlp": 1,
        "band_mc_mlp": 2,
        "band_se_mlp": 2,
        "band_sme_mlp": 3,
        "band_sre_mlp": 3,
    }

    def __init__(self, bandposenc_type,
                 freq=16,
                 max_radius=1, min_radius=1e-4,
                 freq_init="geometric",
                 out_dim=256,
                 hidden_list=None):
        """
        Args:
            bandposenc_type: band position encoder type (see _NUM_ENCODED_POINTS)
            freq: F, the number of sinusoid frequencies (not the number of bands)
            max_radius: the largest radius this model can handle
            min_radius: the smallest radius this model can handle
            freq_init: "geometric" or "random", see _cal_freq_list
            out_dim: the output band embedding dim
            hidden_list: the hidden layer sizes of the band encoder's MLP;
                None means no hidden layers
        Raises:
            ValueError: unknown bandposenc_type
        """
        super(BandPositionEncoder, self).__init__()

        self.bandposenc_type = bandposenc_type

        self.frequency_num = freq
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.freq_init = freq_init

        self.out_dim = out_dim
        # None instead of a shared mutable [] default
        self.hidden_list = [] if hidden_list is None else hidden_list

        self.make_bandmlp()
        self.cal_freq_list()
        self.cal_freq_mat()

    def make_bandmlp(self):
        """Set ``self.in_dim`` for the configured type and build ``self.bandmlp``."""
        bandposenc_type = self.bandposenc_type
        if bandposenc_type not in self._NUM_ENCODED_POINTS:
            raise ValueError(f"Unknown bandposenc_type: {bandposenc_type}")

        # every encoded point contributes frequency_num sines and frequency_num cosines
        self.in_dim = self._NUM_ENCODED_POINTS[bandposenc_type] * self.frequency_num * 2
        if bandposenc_type == "band_mc1_mlp":
            # the raw, un-encoded band width
            self.in_dim += 1

        self.bandmlp = MLP(in_dim=self.in_dim,
                           out_dim=self.out_dim,
                           hidden_list=self.hidden_list)

    def cal_freq_list(self):
        # freq_list shape: (frequency_num)
        # NOTE(review): freq_list / freq_mat are plain numpy attributes, not
        # buffers, so they are not stored in state_dict. With freq_init="random"
        # a reloaded model draws fresh frequencies unless numpy's RNG is seeded
        # identically; confirm this is intended.
        self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        # freq_mat shape: (frequency_num, 1), broadcast against the trailing axis in forward
        self.freq_mat = np.expand_dims(self.freq_list, axis=1)

    def forward(self, band_coord):
        '''
        Args:
            band_coord: band coordinates, [-1, 1]
                if bandposenc_type == "band_rb_mlp":
                    shape (B, num_b, K), K wavelengths sampled within each band interval
                else:
                    shape (B, num_b, 2), band interval start and end endpoints -> (S, E)
                B: batch size
                num_b: number of bands/channels we consider to super-res
        Return:
            band_embed:
                if bandposenc_type == "band_rb_mlp": shape (B, num_b, K, out_dim)
                else:                                shape (B, num_b, out_dim)
        Raises:
            ValueError: unknown bandposenc_type
        '''
        bandposenc_type = self.bandposenc_type

        B, num_b, _ = band_coord.shape

        # freq_mat: shape (frequency_num, 1), on band_coord's device
        freq_mat = torch.from_numpy(self.freq_mat).float().to(band_coord.device)

        if bandposenc_type != "band_rb_mlp":
            assert band_coord.shape[-1] == 2

        # points: shape (B, num_b, P), the P coordinates to encode per band
        band_width = None
        if bandposenc_type in ("band_rb_mlp", "band_se_mlp"):
            # the K samples, or (S, E) itself
            points = band_coord
        elif bandposenc_type == "band_m_mlp":
            # interval mean M
            points = band_coord.mean(dim=-1, keepdim=True)
        elif bandposenc_type == "band_mc1_mlp":
            # band_width: shape (B, num_b, 1), appended un-encoded below
            band_width = (band_coord[:, :, 1] - band_coord[:, :, 0]).unsqueeze(-1)
            points = band_coord.mean(dim=-1, keepdim=True)
        elif bandposenc_type == "band_mc_mlp":
            # (M, C): interval mean and band width, both encoded
            band_width = (band_coord[:, :, 1] - band_coord[:, :, 0]).unsqueeze(-1)
            band_mean = band_coord.mean(dim=-1, keepdim=True)
            points = torch.cat([band_mean, band_width], dim=-1)
        elif bandposenc_type in ("band_sme_mlp", "band_sre_mlp"):
            if self.training and bandposenc_type == "band_sre_mlp":
                # training: a uniform random point inside each interval, shape (B, num_b, 1)
                middle = generate_random_pt_within_interval(band_coord, K=1, band_int_sample_type='uniform')
            else:
                # eval, or band_sme_mlp: the interval mean, shape (B, num_b, 1)
                middle = torch.mean(band_coord, dim=2, keepdim=True)
            # (S, M or R, E)
            points = torch.cat([band_coord[:, :, 0].unsqueeze(2),
                                middle, band_coord[:, :, 1].unsqueeze(2)], dim=2)
        else:
            raise ValueError(f"Unknown bandposenc_type: {bandposenc_type}")

        # scaled: shape (B, num_b, P, frequency_num, 1), every point times every frequency
        scaled = points[..., None, None] * freq_mat
        # band_embed: shape (B, num_b, P, frequency_num, 2)
        band_embed = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)

        if bandposenc_type == "band_rb_mlp":
            # band_embed: shape (B, num_b, K, in_dim = 1 * F * 2), one embedding per sample
            band_embed = band_embed.reshape(B, num_b, points.shape[2], -1)
        else:
            # band_embed: shape (B, num_b, in_dim = P * F * 2)
            band_embed = band_embed.reshape(B, num_b, -1)
            if bandposenc_type == "band_mc1_mlp":
                # band_embed: shape (B, num_b, in_dim = 1 * F * 2 + 1)
                band_embed = torch.cat([band_embed, band_width], dim=-1)

        return self.bandmlp(band_embed)
