r""" Hypercorrelation Squeeze Network: transformer-decoder variant operating on multi-frame queries """
from functools import reduce
from operator import add

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from torchvision.models import vgg

from src.model.hsnet_helpers.base.feature import extract_feat_vgg, extract_feat_res
from src.model.hsnet_helpers.base.correlation import Correlation
from src.model.hsnet_helpers.learner import HPNLearner
from src.model.hsnet import HypercorrSqueezeNetwork

class TransformerHypercorrSqueezeNetwork(HypercorrSqueezeNetwork):
    """HypercorrSqueezeNetwork subclass that selects a transformer decoder type.

    The constructor resolves ``transformer_decoder_type`` from ``args`` and hands
    it to the base class; ``forward`` processes ``T`` query frames per batch item
    against the support set.

    NOTE(review): the attributes used in ``forward`` (``backbone``, ``feat_ids``,
    ``bottleneck_ids``, ``lids``, ``stack_ids``, ``extract_feats``,
    ``mask_feature``, ``hpn_learner``, ``use_original_imgsize``) are not set in
    this class; presumably the base class provides them -- confirm there.
    """

    def __init__(self, args, backbone, use_original_imgsize):
        """Build the network, defaulting the decoder type to 'MultiscaleDecoder'.

        Args:
            args: configuration object; its optional ``transformer_decoder_type``
                attribute overrides the default decoder type.
            backbone: forwarded unchanged to the base class constructor.
            use_original_imgsize: forwarded unchanged to the base class
                constructor; ``forward`` reads ``self.use_original_imgsize`` to
                decide whether to resize the output.
        """
        # Fall back to the default when args does not define the attribute.
        transformer_decoder_type = getattr(args, 'transformer_decoder_type', 'MultiscaleDecoder')

        super(TransformerHypercorrSqueezeNetwork, self).__init__(
            args, backbone, use_original_imgsize,
            transformer_decoder_type=transformer_decoder_type)

    def forward(self, query_img, support_img, support_mask, invalid_pixels=None, seq_name=''):
        """
        Predict mask logits for every query frame given the support set.

        query_img: [B x T x C x H x W]
        support_img: [B x K x C x H x W]
        support_mask: [B x K x H x W]
        invalid_pixels: passed through to ``self.hpn_learner`` unchanged.
        seq_name: passed through to ``self.hpn_learner`` unchanged.

        Returns a 5-tuple ``(logit_mask, feats, None, None, output_vis)`` where
        ``logit_mask``, ``feats`` and ``output_vis`` are the three outputs of
        ``self.hpn_learner``. The two ``None`` entries are fixed placeholders
        (presumably to match the arity callers expect -- confirm).
        """
        num_frames = query_img.shape[1]

        # Collapse the leading (batch, frame/shot) dims into a single batch dim.
        # reshape (rather than view) also accepts non-contiguous inputs, copying
        # only when a zero-copy view is impossible.
        query_img = query_img.reshape(-1, *query_img.shape[-3:])
        support_img = support_img.reshape(-1, *support_img.shape[-3:])
        support_mask = support_mask.reshape(-1, *support_mask.shape[-2:])

        # Feature extraction and correlation run without gradient tracking.
        with torch.no_grad():
            query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            # A clone of the mask is handed over so mask_feature gets its own copy.
            support_feats = self.mask_feature(support_feats, support_mask.clone())

            # Repeat every support feature map once per query frame so the
            # support batch dim matches the flattened query frames.
            # NOTE(review): the batch dims only line up when K == 1 (support is
            # B*K*T along dim 0, query is B*T) -- confirm K > 1 is never used.
            support_feats = [
                feats.unsqueeze(1).repeat(1, num_frames, 1, 1, 1).flatten(0, 1)
                for feats in support_feats
            ]

            # [BT x C x H x W x H x W] * L (shape as noted by the original author)
            corr_frames = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids)

        logit_mask, feats, output_vis = self.hpn_learner(
            corr_frames, invalid_pixels, num_frames=num_frames, seq_name=seq_name)

        if self.use_original_imgsize:
            # Resize the logits to the spatial size (H, W) of the support images.
            logit_mask = F.interpolate(logit_mask, support_img.size()[2:], mode='bilinear', align_corners=True)

        return logit_mask, feats, None, None, output_vis
