import torch
import torch.nn as nn
import torch.nn.functional as F
from .depth_transformer import TransformerEncoder, TransformerEncoderLayer
from einops.einops import rearrange

class DepthPredictor(nn.Module):
    """Predict a per-pixel categorical depth distribution from two feature levels.

    The two input feature maps are fused (the first is downsampled by a
    stride-2 conv, the second is projected by a 1x1 conv, and the results are
    averaged), refined by a small conv head, and classified into
    ``num_depth_bins + 1`` depth intervals (the extra one acts as background).
    The expected depth is the softmax-weighted sum of the per-bin depth values.
    """

    def __init__(self, config):
        """
        Initialize the depth predictor.

        Args:
            config: Mapping read with ``[]`` for the keys ``'num_depth_bins'``,
                ``'depth_min'``, ``'depth_max'``, ``'middle_dim'`` (channels of
                the first input level) and ``'hidden_dim'`` (channels of the
                second input level and of all internal features; must be
                divisible by 32 because of ``GroupNorm(32, ...)``).
        """
        super().__init__()
        # Note: the bin count and depth range heavily depend on the dataset.
        # Values previously noted as typical: 80 bins, depth_min=0.1, depth_max=8.0.
        # depth_ratio is only used by interpolate_depth_embed (embedding slots
        # per unit of depth).
        self.depth_ratio = 3.0
        depth_num_bins = int(config['num_depth_bins'])
        depth_min = float(config['depth_min'])
        depth_max = float(config['depth_max'])
        self.depth_max = depth_max

        # Bins with linearly increasing widths:
        #   bin_value[i] = depth_min + bin_size * i * (i + 1) / 2,  i = 0..N-1
        # so bin_value[0] == depth_min and bin i has width bin_size * (i + 1);
        # bin_size is chosen so the N widths sum to (depth_max - depth_min).
        bin_size = 2 * (depth_max - depth_min) / (depth_num_bins * (1 + depth_num_bins))
        bin_indice = torch.linspace(0, depth_num_bins - 1, depth_num_bins)
        bin_value = (bin_indice + 0.5).pow(2) * bin_size / 2 - bin_size / 8 + depth_min
        # Extra entry (depth_max) paired with the extra background class below.
        bin_value = torch.cat([bin_value, torch.tensor([depth_max])], dim=0)
        # Frozen Parameter (rather than a buffer) keeps the existing
        # parameter / state_dict layout while still following .to()/.cuda().
        self.depth_bin_values = nn.Parameter(bin_value, requires_grad=False)

        # Create modules
        feat_m_dim = config['middle_dim']
        d_model = config['hidden_dim']
        self.downsample = nn.Sequential(
            nn.Conv2d(feat_m_dim, d_model, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.GroupNorm(32, d_model)
        )
        self.proj = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=(1, 1)),
            nn.GroupNorm(32, d_model)
        )

        # Depth feature head
        self.depth_head = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=(3, 3), padding=1),
            nn.GroupNorm(32, num_channels=d_model),
            nn.ReLU(),
            nn.Conv2d(d_model, d_model, kernel_size=(3, 3), padding=1),
            nn.GroupNorm(32, num_channels=d_model),
            nn.ReLU()
        )

        # Depth-interval classifier; one additional interval acts as background.
        self.depth_classifier = nn.Conv2d(d_model, depth_num_bins + 1, kernel_size=(1, 1))

        # The depth positional embedding is currently disabled. To use
        # interpolate_depth_embed(), create it here, e.g.:
        # self.depth_pos_embed = nn.Embedding(int(self.depth_max * self.depth_ratio) + 1, d_model)

    def forward(self, mlvl_feats, pos_encoder):
        """
        Predict depth-interval logits and the expected depth map.

        Args:
            mlvl_feats: Sequence of exactly two ``[N, C, H, W]`` feature maps.
                ``mlvl_feats[0]`` has ``middle_dim`` channels and is downsampled
                by a stride-2 conv; ``mlvl_feats[1]`` has ``hidden_dim``
                channels. After downsampling, both must share the same spatial
                size.
            pos_encoder: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(depth_logits, None, weighted_depth)``:
            ``depth_logits`` is ``[N, num_depth_bins + 1, H, W]``; the middle
            slot is a placeholder for the (disabled) depth positional encoding;
            ``weighted_depth`` is ``[N, H, W]``.

        Raises:
            ValueError: If ``mlvl_feats`` does not contain exactly two levels.
        """
        if len(mlvl_feats) != 2:
            raise ValueError(
                f"DepthPredictor expects exactly 2 feature levels, got {len(mlvl_feats)}"
            )

        # Fuse both levels at the second level's resolution.
        src_4 = self.downsample(mlvl_feats[0])
        src_8 = self.proj(mlvl_feats[1])
        src = (src_4 + src_8) / 2

        src = self.depth_head(src)
        depth_logits = self.depth_classifier(src)  # [N, bins + 1, H, W]

        # Expected depth under the predicted categorical distribution.
        depth_probs = F.softmax(depth_logits, dim=1)
        weighted_depth = (depth_probs * self.depth_bin_values.reshape(1, -1, 1, 1)).sum(dim=1)  # [N, H, W]

        return depth_logits, None, weighted_depth

    def interpolate_depth_embed(self, depth):
        """
        Convert a depth map into a linearly interpolated positional embedding.

        Requires ``self.depth_pos_embed`` (an ``nn.Embedding``), which
        ``__init__`` does not create by default.

        Args:
            depth: Depth map ``[N, H, W]``; values are clamped to
                ``[0, depth_max]`` and scaled by ``depth_ratio`` to obtain
                fractional embedding indices.

        Returns:
            Tensor ``[N, C, H, W]`` with ``C = depth_pos_embed.embedding_dim``.

        Raises:
            AttributeError: If ``self.depth_pos_embed`` has not been set.
        """
        embed = getattr(self, 'depth_pos_embed', None)
        if embed is None:
            raise AttributeError(
                "DepthPredictor.depth_pos_embed is not initialized; create it "
                "(e.g. nn.Embedding(int(depth_max * depth_ratio) + 1, d_model)) "
                "before calling interpolate_depth_embed()"
            )
        depth = depth.clamp(min=0, max=self.depth_max)
        pos = self.interpolate_1d(depth * self.depth_ratio, embed)  # [N, H, W, C]
        return pos.permute(0, 3, 1, 2)  # [N, C, H, W]

    def interpolate_1d(self, coord, embed):
        """
        Linearly interpolate an ``nn.Embedding`` table at fractional indices.

        Args:
            coord: Float tensor of fractional indices with any shape ``S``.
                Values are clamped to ``[0, embed.num_embeddings - 1]`` so every
                lookup stays in range.
            embed: ``nn.Embedding`` table to interpolate.

        Returns:
            Tensor of shape ``S + (embed.embedding_dim,)``.
        """
        max_index = embed.num_embeddings - 1
        # Out-of-range indices would otherwise crash the embedding lookup.
        coord = coord.clamp(min=0, max=max_index)
        floor_coord = coord.floor()
        delta = (coord - floor_coord).unsqueeze(-1)  # weight of the upper entry
        floor_idx = floor_coord.long()
        ceil_idx = (floor_idx + 1).clamp(max=max_index)
        return embed(floor_idx) * (1 - delta) + embed(ceil_idx) * delta