
import torch.nn as nn
import torch.nn.functional as F

from .base.conv4d import CenterPivotConv4d as Conv4d
from src.model.hsnet_helpers.transformer_decoder import MultiscaleTransformerDecoder, HierarchMultiscaleTransformerDecoder, \
        BidirMultiscaleTransformerDecoder

class HPNLearner(nn.Module):
    """Squeeze, mix and decode a three-level pyramid of 6D hypercorrelation tensors.

    Each pyramid level, a tensor of shape ``(bsz, ch, ha, wa, hb, wb)``, is first squeezed by
    its own stack of center-pivot 4D convolutions. The squeezed levels are then merged in
    order pyramid[0] -> pyramid[1] -> pyramid[2]. Each merge step bilinearly resizes the
    running result's (ha, wa) dims to the next level's, adds the two, and refines the sum with
    another 4D conv stack. Finally, the merged tensor is decoded either by a small 2D
    convolutional head (2 output channels) or by one of the multiscale transformer decoders.

    Args:
        inch: Sequence of three input channel counts. ``inch[i]`` is the channel dim of
            ``hypercorr_pyramid[i]`` given to :meth:`forward`.
        transformer_decoder_type: ``'MultiscaleDecoder'``, ``'HierarchMultiscaleDecoder'`` or
            ``'BidirMultiscaleDecoder'`` selects a transformer decoder. Any other value
            (default ``'None'``) selects the convolutional decoder.
        transformer_decoder_opts: Options for the transformer decoder. Must provide
            ``'n_queries'`` and ``'use_sa'`` when a transformer decoder is selected
            (a ``KeyError`` is raised otherwise). ``None`` means "no options".
        vis_corr_tensors: If True, :meth:`forward` also returns intermediate tensors in its
            third output for visualisation.
    """

    def __init__(self, inch, transformer_decoder_type='None', transformer_decoder_opts=None,
                 vis_corr_tensors=False):
        super(HPNLearner, self).__init__()
        self.transformer_decoder_type = transformer_decoder_type
        self.vis_corr_tensors = vis_corr_tensors
        # Avoid a shared mutable default; an empty dict preserves the old default behaviour.
        if transformer_decoder_opts is None:
            transformer_decoder_opts = {}

        outch1, outch2, outch3 = 16, 64, 128
        block = self._make_building_block

        # Squeezing building blocks (one per pyramid level) followed by the two mixing refiners.
        self.encoder_layer4 = block(inch[0], [outch1, outch2, outch3], [3, 3, 3], [2, 2, 2])
        self.encoder_layer3 = block(inch[1], [outch1, outch2, outch3], [5, 3, 3], [4, 2, 2])
        self.encoder_layer2 = block(inch[2], [outch1, outch2, outch3], [5, 5, 3], [4, 4, 2])
        self.encoder_layer4to3 = block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])
        self.encoder_layer3to2 = block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])

        # Mixing building blocks / decoder head.
        print('===========> Running Transformer decoder Type', self.transformer_decoder_type)
        # Built here (not at class level) so the class body has no import-time dependency on
        # the decoder modules.
        transformer_decoders = {
            'MultiscaleDecoder': MultiscaleTransformerDecoder,
            'HierarchMultiscaleDecoder': HierarchMultiscaleTransformerDecoder,
            'BidirMultiscaleDecoder': BidirMultiscaleTransformerDecoder,
        }
        decoder_cls = transformer_decoders.get(self.transformer_decoder_type)
        if decoder_cls is not None:
            self.multiscale_transformer_decoder = decoder_cls(
                in_channels=outch3, hidden_dim=outch3,
                num_queries=transformer_decoder_opts['n_queries'],
                nheads=8,
                dim_feedforward=outch2, dec_layers=9,
                pre_norm=False, use_sa=transformer_decoder_opts['use_sa'])
        else:
            # Convolutional decoder: outch3 -> outch2 features, then 2 output channels.
            self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True),
                                          nn.ReLU(),
                                          nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
                                          nn.ReLU())

            self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
                                          nn.ReLU(),
                                          nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True))

    @staticmethod
    def _make_building_block(in_channel, out_channels, kernel_sizes, spt_strides, group=4):
        """Build a stack of ``Conv4d -> GroupNorm -> ReLU`` layers.

        Args:
            in_channel: Channels entering the first Conv4d.
            out_channels: Output channels of each layer; also the input channels of the next.
            kernel_sizes: Kernel size per layer, replicated over all four spatial dims.
            spt_strides: Stride per layer. It is applied only to the last two of the four
                spatial dims; the first two use stride 1 (presumably (hb, wb) under
                CenterPivotConv4d's dim order — confirm there).
            group: Number of GroupNorm groups (each out channel count must be divisible by it).

        Raises:
            ValueError: If the three per-layer lists differ in length.
        """
        if not len(out_channels) == len(kernel_sizes) == len(spt_strides):
            raise ValueError('out_channels, kernel_sizes and spt_strides must have equal length')

        layers = []
        prev_ch = in_channel
        for outch, ksz, stride in zip(out_channels, kernel_sizes, spt_strides):
            ksz4d = (ksz,) * 4
            str4d = (1, 1) + (stride,) * 2
            pad4d = (ksz // 2,) * 4  # "same"-style padding for odd kernels

            layers.append(Conv4d(prev_ch, outch, ksz4d, str4d, pad4d))
            layers.append(nn.GroupNorm(group, outch))
            layers.append(nn.ReLU(inplace=True))
            prev_ch = outch

        return nn.Sequential(*layers)

    def interpolate_support_dims(self, hypercorr, spatial_size=None):
        """Bilinearly resize dims 2-3 ``(ha, wa)`` of a 6D tensor, leaving ``(hb, wb)`` untouched.

        Despite the method name, the resized dims are the third and fourth ones. The trailing
        ``(hb, wb)`` pair is folded into the batch so every ``(hb, wb)`` slice is resized
        independently.

        Args:
            hypercorr: Tensor of shape ``(bsz, ch, ha, wa, hb, wb)``.
            spatial_size: Target ``(out_h, out_w)`` for the ``(ha, wa)`` dims.

        Returns:
            Contiguous tensor of shape ``(bsz, ch, out_h, out_w, hb, wb)``.

        Raises:
            ValueError: If ``spatial_size`` is not given.
        """
        if spatial_size is None:
            raise ValueError('spatial_size must be given as (height, width)')
        bsz, ch, ha, wa, hb, wb = hypercorr.size()
        hypercorr = hypercorr.permute(0, 4, 5, 1, 2, 3).contiguous().view(bsz * hb * wb, ch, ha, wa)
        hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True)
        out_h, out_w = spatial_size
        hypercorr = hypercorr.view(bsz, hb, wb, ch, out_h, out_w).permute(0, 3, 4, 5, 1, 2).contiguous()
        return hypercorr

    def forward(self, hypercorr_pyramid, invalid_pixels=None, num_frames=-1, allow_interm=False, seq_name=''):
        """Encode the hypercorrelation pyramid and decode it.

        Args:
            hypercorr_pyramid: Sequence of three 6D tensors ``(bsz, ch, ha, wa, hb, wb)``, fed to
                ``encoder_layer4``, ``encoder_layer3`` and ``encoder_layer2`` respectively.
            invalid_pixels: Forwarded to the transformer decoder; unused by the conv decoder.
            num_frames: Forwarded to the transformer decoder; unused by the conv decoder.
            allow_interm: Conv decoder only. If True, the second output is the activation
                entering ``decoder2``'s final conv; otherwise it is ``None``.
            seq_name: Forwarded to the transformer decoder; unused by the conv decoder.

        Returns:
            Tuple ``(logit_mask, interm_feats, output_vis)``.

            - Conv decoder: ``logit_mask`` has shape ``(bsz, 2, 2*ha, 2*wa)``, where ``ha, wa``
              are the last merged level's. ``interm_feats`` is as described under
              ``allow_interm``.
            - Transformer decoder: the first two outputs are whatever the decoder returns.
            - ``output_vis``: ``{}`` unless ``vis_corr_tensors`` is set, in which case it holds
              ``'original_corr'``, ``'mix43'`` and ``'mix432'``.
        """
        output_vis = {'original_corr': hypercorr_pyramid} if self.vis_corr_tensors else {}

        # Squeeze each pyramid level independently.
        hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0])
        hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1])
        hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2])

        # Per-level maps averaged over the trailing (hb, wb) dims, consumed by the transformer decoder.
        hsnet_pyramid_out = [hypercorr_sqz4.mean(dim=[-1, -2])]

        # Mix level 4 into level 3: match (ha, wa), add, refine.
        hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2])
        hypercorr_mix43 = self.encoder_layer4to3(hypercorr_sqz4 + hypercorr_sqz3)
        if self.vis_corr_tensors:
            output_vis['mix43'] = hypercorr_mix43.mean(dim=[-1, -2])
        hsnet_pyramid_out.append(hypercorr_mix43.mean(dim=[-1, -2]))

        # Mix the result into level 2 the same way.
        hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2])
        hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix43 + hypercorr_sqz2)
        if self.vis_corr_tensors:
            output_vis['mix432'] = hypercorr_mix432.mean(dim=[-1, -2])
        hsnet_pyramid_out.append(hypercorr_mix432.mean(dim=[-1, -2]))

        # Branch on what __init__ actually built, so every accepted decoder type works here too.
        transformer_decoder = getattr(self, 'multiscale_transformer_decoder', None)
        if transformer_decoder is not None:
            logit_mask, interm_feats = transformer_decoder(hsnet_pyramid_out, invalid_pixels,
                                                           num_frames=num_frames, seq_name=seq_name)
            return logit_mask, interm_feats, output_vis

        # Conv decoder: collapse (hb, wb) to get a 2D map of shape (bsz, ch, ha, wa).
        hypercorr_encoded = hypercorr_mix432.flatten(start_dim=-2).mean(dim=-1)
        hypercorr_decoded = self.decoder1(hypercorr_encoded)
        # Double height and width independently so non-square maps keep their aspect ratio.
        upsample_size = (hypercorr_decoded.size(-2) * 2, hypercorr_decoded.size(-1) * 2)
        hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True)
        # Run decoder2 stepwise so the pre-logit activation is computed once and reused.
        interm_feats = self.decoder2[1](self.decoder2[0](hypercorr_decoded))
        logit_mask = self.decoder2[2](interm_feats)
        if not allow_interm:
            interm_feats = None
        return logit_mask, interm_feats, output_vis
