import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .depth_transformer import TransformerEncoder, TransformerEncoderLayer, TransformerEncoderLayerGeneral
from src.loftr.utils.position_encoding import PositionEncodingSine
from einops.einops import rearrange

class DepthPredictor(nn.Module):
    """Predict a per-pixel depth distribution and encode it into depth-aware tokens.

    Two feature maps are fused and classified into discrete depth bins. The bins
    are reduced to an expected depth map (softmax-weighted bin values). The fused
    features are refined by a stack of depth-encoder layers and then offset by a
    learned depth embedding that is linearly interpolated at the predicted depth.
    """

    def __init__(self, config):
        """
        Initialize the depth predictor and the depth encoder.

        Args:
            config: Subscriptable config with the keys
                ``num_depth_bins`` (number of regular depth bins),
                ``depth_min`` / ``depth_max`` (depth range covered by the bins),
                ``middle_dim`` (channels of ``mlvl_feats[0]``),
                ``hidden_dim`` (model width; must be divisible by 32 because of
                ``nn.GroupNorm(32, ...)``),
                ``nhead`` and ``num_layers`` (depth encoder), and optionally
                ``depth_ratio`` (rows of the depth embedding per depth unit,
                default 3.0; read via ``config.get``).
        """
        super().__init__()
        # NOTE: depth range, bin count and depth_ratio heavily depend on the dataset.
        self.depth_ratio = float(config.get('depth_ratio', 3.0))
        depth_num_bins = int(config['num_depth_bins'])
        depth_min = float(config['depth_min'])
        depth_max = float(config['depth_max'])
        self.depth_max = depth_max

        # Bin values: depth_min + bin_size * i * (i + 1) for i in [0, num_bins).
        # The gap between consecutive values (2 * bin_size * (i + 1)) grows
        # linearly with i, so near depths are sampled more densely. One extra
        # value at depth_max is appended, giving num_bins + 1 values in total.
        bin_size = (depth_max - depth_min) / (depth_num_bins * (1 + depth_num_bins))
        bin_indice = torch.arange(start=0, end=depth_num_bins, step=1).float()
        bin_value = depth_min + bin_size * bin_indice * (bin_indice + 1)
        bin_value = torch.cat([bin_value, torch.tensor([depth_max])], dim=0)

        # Frozen (requires_grad=False). Because it is an nn.Parameter rather than
        # a buffer, it is still listed in self.parameters().
        self.depth_bin_values = nn.Parameter(bin_value, requires_grad=False)

        # Create modules
        feat_m_dim = config['middle_dim']
        d_model = config['hidden_dim']
        # mlvl_feats[0]: feat_m_dim -> d_model channels, spatially halved (stride 2).
        self.downsample = nn.Sequential(
            nn.Conv2d(feat_m_dim, d_model, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.GroupNorm(32, d_model)
        )
        # mlvl_feats[1]: d_model -> d_model channels, same spatial size.
        self.proj = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=(1, 1)),
            nn.GroupNorm(32, d_model)
        )

        # Shared trunk producing the features used for depth classification.
        self.depth_head = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=(3, 3), padding=1),
            nn.GroupNorm(32, num_channels=d_model),
            nn.ReLU(),
            nn.Conv2d(d_model, d_model, kernel_size=(3, 3), padding=1),
            nn.GroupNorm(32, num_channels=d_model),
            nn.ReLU()
        )

        # One logit per bin value, including the extra depth_max bin (which the
        # original author describes as a background class).
        self.depth_classifier = nn.Conv2d(d_model, depth_num_bins + 1, kernel_size=(1, 1))

        # Depth encoder: num_layers independent copies of the same layer config.
        depth_encoder_layer = TransformerEncoderLayerGeneral(d_model, nhead=config['nhead'])
        self.num_layers = config['num_layers']
        self.encoder_layers = nn.ModuleList(
            [copy.deepcopy(depth_encoder_layer) for _ in range(self.num_layers)])

        self.pos_encoding = PositionEncodingSine(d_model, temp_bug_fix=True)

        # Learned depth embedding with int(depth_max * depth_ratio) + 1 rows, so
        # row k corresponds to depth k / depth_ratio.
        self.depth_pos_embed = nn.Embedding(int(self.depth_max * self.depth_ratio) + 1, d_model)

    def forward(self, mlvl_feats):
        """
        Predict depth and produce depth-aware tokens.

        Args:
            mlvl_feats: Sequence of exactly two feature maps. ``mlvl_feats[0]`` has
                ``middle_dim`` channels and is downsampled by a stride-2 conv.
                ``mlvl_feats[1]`` has ``hidden_dim`` channels. After downsampling,
                both must have the same spatial size H x W, because they are
                averaged. (The local names suggest 1/4 and 1/8 input resolution;
                TODO confirm against the caller.)

        Returns:
            depth_logits: [N, num_depth_bins + 1, H, W] bin classification logits.
            depth_embed: [N, H * W, hidden_dim] depth-aware tokens.
            weighted_depth: [N, H, W] expected depth under the softmax over bins.
        """
        assert len(mlvl_feats) == 2

        # Fuse the two levels at the coarser resolution by averaging.
        src_4 = self.downsample(mlvl_feats[0])
        src_8 = self.proj(mlvl_feats[1])
        src = (src_4 + src_8) / 2

        src = self.depth_head(src)
        depth_logits = self.depth_classifier(src)  # [N, bins + 1, H, W]

        # Expected depth: bin values weighted by the per-pixel softmax probabilities.
        depth_probs = F.softmax(depth_logits, dim=1)  # [N, bins + 1, H, W]
        weighted_depth = (depth_probs * self.depth_bin_values.reshape(1, -1, 1, 1)).sum(dim=1)  # [N, H, W]

        # Flatten the position-encoded features into tokens and run the encoder.
        # Each layer receives the tokens as both of its inputs.
        H, W = src.shape[-2:]
        depth_embed = rearrange(self.pos_encoding(src), 'n c h w -> n (h w) c')
        for layer in self.encoder_layers:
            depth_embed = layer(depth_embed, depth_embed)
        depth_embed = rearrange(depth_embed, 'n (h w) c -> n c h w', h=H, w=W)

        # Add the learned depth embedding sampled at the predicted depth.
        depth_pos_embed_ip = self.interpolate_depth_embed(weighted_depth)  # [N, C, H, W]
        depth_embed = depth_embed + depth_pos_embed_ip
        depth_embed = rearrange(depth_embed, 'n c h w -> n (h w) c')

        return depth_logits, depth_embed, weighted_depth

    def interpolate_depth_embed(self, depth):
        """Sample ``self.depth_pos_embed`` at ``depth`` ([N, H, W]); returns [N, C, H, W].

        Depth is clamped to [0, depth_max] and scaled by ``depth_ratio`` to get a
        fractional row index into the embedding table.
        """
        depth = depth.clamp(min=0, max=self.depth_max)
        pos = self.interpolate_1d(depth * self.depth_ratio, self.depth_pos_embed)
        return rearrange(pos, 'n h w c -> n c h w')

    def interpolate_1d(self, coord, embed):
        """Linearly interpolate rows of ``embed`` at fractional indices ``coord``.

        Args:
            coord: Float tensor of fractional row indices, of any shape S.
            embed: ``nn.Embedding`` to sample from.

        Returns:
            Tensor of shape S + (embedding_dim,). Each entry is
            ``(1 - t) * embed[floor(c)] + t * embed[floor(c) + 1]``, where
            ``t = c - floor(c)``.

        Coordinates are first clamped into [0, num_embeddings - 1]. Any input,
        including a value nudged past the last index by floating-point rounding
        of ``depth * depth_ratio``, therefore maps to a valid row instead of
        raising an index error. In-range inputs are unaffected.
        """
        max_index = embed.num_embeddings - 1
        coord = coord.clamp(min=0, max=max_index)
        floor_coord = coord.floor()
        delta = (coord - floor_coord).unsqueeze(-1)
        floor_coord = floor_coord.long()
        ceil_coord = (floor_coord + 1).clamp(max=max_index)
        return embed(floor_coord) * (1 - delta) + embed(ceil_coord) * delta  # [..., C]