import torch
from .segformer_depth_pattern import (
    SegformerForSemanticSegmentation,
    SegformerConfig
)
import torch.nn as nn 
from pattern_compensated_dom.dpt import DepthAnything
import math


def segformer(pretrained=True):
    """Build the single-class SegFormer segmentation model.

    The label space contains exactly one class, ``"others"`` (id 0).

    Args:
        pretrained: If truthy, load weights through ``from_pretrained`` from
            the local ``./weight/segformer-b2-finetuned-ade-512-512``
            directory. Otherwise only ``config.json`` from that directory is
            read and the model is constructed from the config alone.

    Returns:
        A ``SegformerForSemanticSegmentation`` instance configured with the
        single-class label mapping.
    """
    id2label = {0: "others"}
    label2id = {name: index for index, name in id2label.items()}
    num_labels = len(id2label)

    if not pretrained:
        # Config-only construction: override the label fields read from the
        # checkpoint's config before instantiating the model.
        config = SegformerConfig.from_json_file(
            "./weight/segformer-b2-finetuned-ade-512-512/config.json"
        )
        config.num_labels = num_labels
        config.id2label = id2label
        config.label2id = label2id
        return SegformerForSemanticSegmentation(config)

    # NOTE(review): ignore_mismatched_sizes=True presumably lets layers whose
    # shape depends on the label count differ from the checkpoint — confirm
    # against the from_pretrained implementation in segformer_depth_pattern.
    return SegformerForSemanticSegmentation.from_pretrained(
        "./weight/segformer-b2-finetuned-ade-512-512",
        ignore_mismatched_sizes=True,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )
    
def affinity_image_depth(fea_a, fea_b):
    """Rebuild ``fea_a``'s spatial grid as a soft mixture of ``fea_b`` vectors.

    For every spatial position of ``fea_a`` a score against every position of
    ``fea_b`` is computed as ``(2 * <a, b> - ||b||^2) / sqrt(C)``. This equals
    the negative squared Euclidean distance ``-||a - b||^2`` up to the term
    ``-||a||^2``, which is constant for a given ``fea_a`` position and hence
    cancels in the softmax taken over the ``fea_b`` positions. The normalised
    scores are then used as weights to average the ``fea_b`` vectors.

    Args:
        fea_a: Tensor of shape ``(B, C, H, W)`` providing the anchor features
            and the output shape.
        fea_b: Tensor that is reshaped to ``(B, C, N)``; its element count
            must therefore be divisible by ``B * C``.

    Returns:
        Tensor of shape ``(B, C, H, W)`` in which every spatial position is a
        convex combination of the ``N`` feature vectors of ``fea_b``.
    """
    B, C, H, W = fea_a.shape
    anchor_fea = fea_a.flatten(start_dim=2)  # (B, C, H*W)
    fea_b = fea_b.reshape((B, C, -1))  # (B, C, N)

    b_sq = fea_b.pow(2).sum(1).unsqueeze(2)  # (B, N, 1): squared norm per fea_b position
    ab = fea_b.transpose(1, 2) @ anchor_fea  # (B, N, H*W): pairwise dot products
    affinity = (2 * ab - b_sq) / math.sqrt(C)

    # Normalise over the fea_b positions (dim=1) so each anchor position's
    # weights sum to one. torch.softmax applies the max-subtraction trick
    # internally, so it is as stable as the manual exp/sum formulation.
    affinity = torch.softmax(affinity, dim=1)

    fea_a = torch.bmm(fea_b, affinity)  # (B, C, H*W)
    return fea_a.view(B, C, H, W)

class FTM_Net(nn.Module):
    """Segmentation network combining SegFormer with DepthAnything features.

    The module owns three sub-modules:

    * ``model`` — the SegFormer model returned by ``segformer``;
    * ``depth_anything_model`` — a ``DepthAnything`` model loaded from the
      ``LiheYoung/depth_anything_vits14`` checkpoint;
    * ``depth_align_conv`` — a 3x3, stride-1, padding-1 convolution mapping
      384 channels to 512 channels, applied to the depth features.

    Args:
        pretrianed: Forwarded to ``segformer`` to choose between loading
            pretrained weights and config-only construction. The misspelled
            name is kept because it is part of the public constructor
            signature.
        *args: Forwarded to ``nn.Module.__init__``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, pretrianed=True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = segformer(pretrianed)
        encoder = 'vits'  # can also be 'vitb' or 'vitl'
        self.depth_anything_model = DepthAnything.from_pretrained(
            'LiheYoung/depth_anything_{:}14'.format(encoder)
        )
        # NOTE(review): 384 is presumably the feature width of the 'vits'
        # encoder and 512 the channel count of the last SegFormer stage —
        # confirm before switching `encoder` or the SegFormer variant.
        self.depth_align_conv = nn.Conv2d(384, 512, 3, 1, 1)

    def forward(self, x):
        """Predict segmentation logits for the last frame of a clip.

        Args:
            x: 5-D tensor whose second dimension indexes frames. The original
                code unpacked the shape as ``(B, T, C, W, H)``; the order of
                the two spatial dimensions is not verified here.

        Returns:
            The decode-head output upsampled by a factor of 4 with bilinear
            interpolation.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        # Explicit check instead of `assert`, which is removed under
        # `python -O` and would let malformed input through.
        if len(x.shape) != 5:
            raise ValueError(
                "FTM_Net.forward expects a 5-D input, got shape {}".format(
                    tuple(x.shape)
                )
            )
        num_frames = x.shape[1]
        target_frame = -1  # the last frame along dim 1 is the one segmented

        # The depth head is told how many frames the clip contains before
        # the whole clip is passed through the depth model.
        self.depth_anything_model.depth_head_seg.num_frames = num_frames
        x_single = x[:, target_frame]

        depth, depth_features = self.depth_anything_model(x)
        depth_features = self.depth_align_conv(depth_features)

        # Append the predicted depth to the target frame along the channel
        # dimension before running the SegFormer encoder.
        x_single = torch.cat([x_single, depth], dim=1)

        features = list(self.model.forward_features(x_single))

        # Residual fusion on the last encoder feature map: add the depth
        # features re-sampled by their affinity to the image features.
        features[-1] = features[-1] + affinity_image_depth(features[-1], depth_features)

        logits = self.model.decode_head(features)

        upsampled_logits = nn.functional.interpolate(
            logits, scale_factor=4.0, mode="bilinear", align_corners=False
        )
        return upsampled_logits