r""" Hypercorrelation Squeeze Network """
from functools import reduce
from operator import add

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from torchvision.models import vgg

from src.model.hsnet_helpers.base.feature import extract_feat_vgg, extract_feat_res
from src.model.hsnet_helpers.base.correlation import Correlation
from src.model.hsnet_helpers.learner import HPNLearner


class HypercorrSqueezeNetwork(nn.Module):
    """Hypercorrelation Squeeze Network for few-shot (video) segmentation.

    A ResNet backbone (run without gradients in ``forward``) extracts
    multi-layer features from query and support images. Support features are
    masked by the support mask and correlated with the query features via
    ``Correlation.multilayer_correlation``. ``HPNLearner`` then decodes the
    correlation pyramid into 2-channel logits.
    """

    def __init__(self, args, backbone, use_original_imgsize, transformer_decoder_type='None'):
        """Build the backbone and the hypercorrelation learner.

        Args:
            args: config namespace. Must provide ``vis_corr_tensors``. It is
                mutated in place: ``transformer_decoder_opts`` is created as
                ``{}`` when missing; otherwise ``n_queries`` defaults to 2.
            backbone: ResNet depth, ``50`` or ``101``.
            use_original_imgsize: if true, ``forward`` bilinearly resizes the
                logits to the support images' spatial size.
            transformer_decoder_type: forwarded to ``HPNLearner``.

        Raises:
            ValueError: if ``backbone`` is not 50 or 101.
        """
        super().__init__()
        self.backbone_type = backbone
        self.use_original_imgsize = use_original_imgsize

        # 1. Backbone network initialization.
        # A VGG16 branch is currently disabled:
        #   backbone == 'vgg16': vgg.vgg16(pretrained=True),
        #   feat_ids [17, 19, 21, 24, 26, 28, 30], extract_feat_vgg,
        #   nbottlenecks [2, 2, 3, 3, 3, 1]
        if backbone == 50:
            self.backbone = resnet.resnet50(pretrained=True)
            self.feat_ids = list(range(4, 17))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 6, 3]
        elif backbone == 101:
            self.backbone = resnet.resnet101(pretrained=True)
            self.feat_ids = list(range(4, 34))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 23, 3]
        else:
            raise ValueError(f'Unavailable backbone: {backbone}')

        # Index of each bottleneck inside its ResNet layer, e.g. [0, 1, 2, 0, 1, 2, 3, ...].
        self.bottleneck_ids = [idx for count in nbottlenecks for idx in range(count)]
        # 1-based ResNet layer id of each bottleneck, e.g. [1, 1, 1, 2, 2, 2, 2, ...].
        self.lids = [layer + 1 for layer, count in enumerate(nbottlenecks) for _ in range(count)]
        # Cumulative bottleneck counts of the three deepest layers, deepest first
        # (e.g. [3, 9, 13] for ResNet-50); passed to Correlation.multilayer_correlation.
        self.stack_ids = torch.tensor(self.lids).bincount().flip(0).cumsum(dim=0)[:3]
        self.backbone.eval()

        # NOTE(review): n_queries is only defaulted when the opts dict already
        # exists; a freshly created {} is passed on without it -- confirm intended.
        if not hasattr(args, 'transformer_decoder_opts'):
            args.transformer_decoder_opts = {}
        elif 'n_queries' not in args.transformer_decoder_opts:
            args.transformer_decoder_opts['n_queries'] = 2

        self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:])),
                                      transformer_decoder_type=transformer_decoder_type,
                                      transformer_decoder_opts=args.transformer_decoder_opts,
                                      vis_corr_tensors=args.vis_corr_tensors)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, query_img, support_img, support_mask, invalid_pixels=None, allow_interm=False, seq_name=''):
        """Predict query logits conditioned on a support set.

        query_img: [B x T x C x H x W]
        support_img: [B x K x C x H x W] (or [B x C x H x W] for a single shot)
        support_mask: [B x K x H x W] (or [B x H x W])
        invalid_pixels: optional; flattened like the inputs but not used further here.
        allow_interm: forwarded to ``HPNLearner``.
        seq_name: unused here; kept for interface compatibility.

        Returns the 5-tuple ``(logit_mask, feats, None, None, output_vis)``;
        ``logit_mask``, ``feats`` and ``output_vis`` come from ``HPNLearner``.
        """
        num_frames = query_img.shape[1]
        # Fold the leading batch/time/shot dims into a single batch dim.
        # reshape (unlike view) also accepts non-contiguous inputs.
        query_img = query_img.reshape(-1, *query_img.shape[-3:])
        support_img = support_img.reshape(-1, *support_img.shape[-3:])
        support_mask = support_mask.reshape(-1, *support_mask.shape[-2:])
        if invalid_pixels is not None:
            # TODO(review): confirm the intended trailing dims of invalid_pixels.
            invalid_pixels = invalid_pixels.reshape(-1, *invalid_pixels.shape[-3:])

        # Backbone features and correlations are computed without gradients.
        with torch.no_grad():
            query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.mask_feature(support_feats, support_mask.clone())
            if support_feats[0].shape[0] != query_feats[0].shape[0]:
                # Batch sizes differ because of the extra temporal dim: tile each
                # support feature num_frames times to line up with the B*T queries.
                support_feats = [
                    sfeats.unsqueeze(1).repeat(1, num_frames, 1, 1, 1).view(-1, *sfeats.shape[-3:])
                    for sfeats in support_feats
                ]

            corr = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids)

        logit_mask, feats, output_vis = self.hpn_learner(corr, allow_interm=allow_interm)
        if self.use_original_imgsize:
            logit_mask = F.interpolate(logit_mask, support_img.size()[2:], mode='bilinear', align_corners=True)
        return logit_mask, feats, None, None, output_vis

    def mask_feature(self, features, mask):
        """Multiply every feature map by the mask resized to its spatial size.

        features: list of [BxCxHxW] tensors; the list is updated in place and returned.
        mask: BxHxW or BxTxHxW (leading dims are folded into the batch dim).
        """
        if len(mask.shape) > 3:
            # BTxHxW
            mask = mask.view(-1, *mask.shape[-2:])

        for idx, feature in enumerate(features):
            mask_ds = F.interpolate(mask.unsqueeze(1).float(), feature.size()[2:], mode='bilinear', align_corners=True)
            features[idx] = features[idx] * mask_ds
        return features

    def extract_hsnet_feats(self, batch, nshot):
        """Run ``forward`` once per shot and concatenate the learner ``feats`` along dim 0."""
        all_feats = []
        for s_idx in range(nshot):
            # forward() returns a 5-tuple; only the learner features are kept.
            _, feats, _, _, _ = self(batch['query_img'], batch['support_imgs'][:, s_idx],
                                     batch['support_masks'][:, s_idx])
            all_feats.append(feats)
        return torch.cat(all_feats, dim=0)

    def predict_mask_nshot_perframe(self, batch, nshot, seq_name=''):
        """Predict a mask for every query frame, voting over ``nshot`` support shots.

        Each shot goes through ``forward`` separately and its per-pixel argmax is
        summed. With ``nshot == 1`` that argmax (an integer tensor) is returned
        directly. Otherwise the vote count is divided by the per-sample maximum
        (floored at 1) and thresholded at 0.5, giving a float {0, 1} mask.
        """
        logit_mask_agg = 0
        for s_idx in range(nshot):
            logit_mask, _, _, _, _ = self(batch['query_img'], batch['support_imgs'][:, s_idx],
                                          batch['support_masks'][:, s_idx], seq_name=seq_name)

            if self.use_original_imgsize:
                # forward() already resizes to the support images' (H, W); kept so
                # this method does not rely on that.
                org_qry_imsize = batch['support_imgs'].size()[-2:]
                logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)

            # `0 + tensor` on the first shot yields a new tensor, so no copy is needed.
            logit_mask_agg += logit_mask.argmax(dim=1)
            if nshot == 1:
                return logit_mask_agg

        # Average & quantize predictions given threshold (=0.5).
        bsz = logit_mask_agg.size(0)
        # Highest vote count per sample, floored at 1 to avoid dividing by zero.
        max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0].clamp(min=1).view(bsz, 1, 1)
        return (logit_mask_agg.float() / max_vote >= 0.5).float()

    def predict_mask_nshot(self, qry_img, spprt_imgs, s_label, seq_name=''):
        """Return (background, foreground) scores for the query frames.

        The binary prediction ``m`` from ``predict_mask_nshot_perframe`` is
        stacked as ``[1 - m, m]`` along a new dim 1 (class dim). A singleton
        dim 1 (shot dim) is then inserted and the result is cast to float.
        """
        shot = spprt_imgs.shape[1]
        batch = {'query_img': qry_img, 'support_imgs': spprt_imgs, 'support_masks': s_label}
        fg = self.predict_mask_nshot_perframe(batch, shot, seq_name=seq_name).unsqueeze(1)  # nclasses dim
        return torch.cat([1 - fg, fg], dim=1).unsqueeze(1).float()  # shot dim

    def get_backbone_modules(self):
        """Return the backbone module(s)."""
        return [self.backbone]

    def get_new_modules(self):
        """Return the modules that are not part of the backbone."""
        return [self.hpn_learner]

    def compute_objective(self, input_tuple, original_gt_mask):
        """Cross-entropy between 2-class logits and the ground-truth mask.

        ``input_tuple[0]`` holds the logits and is reshaped to [bsz x 2 x -1].
        ``original_gt_mask`` is flattened per sample and cast to long.
        Returns ``(loss, logits, {})``.
        """
        original_logit_mask = input_tuple[0]
        bsz = original_logit_mask.size(0)
        logit_mask = original_logit_mask.view(bsz, 2, -1)
        gt_mask = original_gt_mask.view(bsz, -1).long()
        return self.cross_entropy_loss(logit_mask, gt_mask), original_logit_mask, {}

    def train_mode(self):
        """Enter training mode but keep the backbone in eval mode."""
        self.train()
        self.backbone.eval()  # to prevent BN from learning data statistics with exponential averaging
