"""R3D"""
import math
from collections import OrderedDict
import numpy as np
import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _triple
from lib.normalize import Normalize


class SpatioTemporalConv(nn.Module):
    r"""Applies a factored 3D convolution over an input signal composed of several input
    planes with distinct spatial and time axes, by performing a 2D convolution over the
    spatial axes to an intermediate subspace, followed by a 1D convolution over the time
    axis to produce the final output.
    Args:
        in_channels (int): Number of channels in the input tensor
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to the sides of the input during their respective convolutions. Default: 0
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(SpatioTemporalConv, self).__init__()

        # if ints are entered, convert them to iterables, 1 -> [1, 1, 1]
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)


        self.temporal_spatial_conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                                    stride=stride, padding=padding, bias=bias)

    def forward(self, x):
        x = self.temporal_spatial_conv(x)
        return x


class SpatioTemporalResBlock(nn.Module):
    r"""Single block for the ResNet network. Uses SpatioTemporalConv in
        the standard ResNet block layout (conv->batchnorm->ReLU->conv->batchnorm->sum->ReLU)
        Args:
            in_channels (int): Number of channels in the input tensor.
            out_channels (int): Number of channels in the output produced by the block.
            kernel_size (int or tuple): Size of the convolving kernels.
            downsample (bool, optional): If ``True``, the output size is to be smaller than the input. Default: ``False``
        """

    def __init__(self, in_channels, out_channels, kernel_size, downsample=False):
        super(SpatioTemporalResBlock, self).__init__()

        # If downsample == True, the first conv of the layer has stride = 2
        # to halve the residual output size, and the input x is passed
        # through a seperate 1x1x1 conv with stride = 2 to also halve it.

        # no pooling layers are used inside ResNet
        self.downsample = downsample

        # to allow for SAME padding
        padding = kernel_size // 2

        if self.downsample:
            # downsample with stride = 2 the input x
            self.downsampleconv = SpatioTemporalConv(in_channels, out_channels, 1, stride=2)
            self.downsamplebn = nn.BatchNorm3d(out_channels)

            # downsample with stride = 2 when producing the residual
            self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding, stride=2)
        else:
            self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding)

        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU()

        # standard conv->batchnorm->ReLU
        self.conv2 = SpatioTemporalConv(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.outrelu = nn.ReLU()

    def forward(self, x):
        res = self.relu1(self.bn1(self.conv1(x)))
        res = self.bn2(self.conv2(res))

        if self.downsample:
            x = self.downsamplebn(self.downsampleconv(x))

        return self.outrelu(x + res)


class SpatioTemporalResLayer(nn.Module):
    r"""Forms a single layer of the ResNet network, with a number of repeating
    blocks of same output size stacked on top of each other
        Args:
            in_channels (int): Number of channels in the input tensor.
            out_channels (int): Number of channels in the output produced by the layer.
            kernel_size (int or tuple): Size of the convolving kernels.
            layer_size (int): Number of blocks to be stacked to form the layer
            block_type (Module, optional): Type of block that is to be used to form the layer. Default: SpatioTemporalResBlock.
            downsample (bool, optional): If ``True``, the first block in layer will implement downsampling. Default: ``False``
        """

    def __init__(self, in_channels, out_channels, kernel_size, layer_size, block_type=SpatioTemporalResBlock,
                 downsample=False):

        super(SpatioTemporalResLayer, self).__init__()

        # implement the first block
        self.block1 = block_type(in_channels, out_channels, kernel_size, downsample)

        # prepare module list to hold all (layer_size - 1) blocks
        self.blocks = nn.ModuleList([])
        for i in range(layer_size - 1):
            # all these blocks are identical, and have downsample = False by default
            self.blocks += [block_type(out_channels, out_channels, kernel_size)]

    def forward(self, x):
        x = self.block1(x)
        for block in self.blocks:
            x = block(x)

        return x


class R3DNet(nn.Module):
    r"""Forms the overall ResNet feature extractor by initializng 5 layers, with the number of blocks in
    each layer set by layer_sizes, and by performing a global average pool at the end producing a
    512-dimensional vector for each element in the batch.
        Args:
            layer_sizes (tuple): An iterable containing the number of blocks in each layer
            block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock.
    """

    def __init__(self, layer_sizes, block_type=SpatioTemporalResBlock, with_classifier=False, return_conv=False, num_classes=101, modality='rgb',
                 use_dropout=True, dropout=0.5, use_l2_norm=False, use_final_bn=False):
        super(R3DNet, self).__init__()
        self.with_classifier = with_classifier
        self.return_conv = return_conv
        self.num_classes = num_classes
        self.use_l2_norm = use_l2_norm
        self.use_final_bn = use_final_bn

        message = 'Classifier to %d classes with r3d backbone;' % num_classes
        if use_dropout: message += ' + dropout %f' % dropout
        if use_l2_norm: message += ' + L2Norm'
        if use_final_bn: message += ' + final BN'
        print(message)

        # first conv, with stride 1x2x2 and kernel size 3x7x7
        if modality == 'uv':
            self.conv1 = SpatioTemporalConv(2, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])
            print('[Warning]: using optical flow 3D models')
        else:
            self.conv1 = SpatioTemporalConv(3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])

        self.bn1 = nn.BatchNorm3d(64)
        self.relu1 = nn.ReLU()
        # output of conv2 is same size as of conv1, no downsampling needed. kernel_size 3x3x3
        self.conv2 = SpatioTemporalResLayer(64, 64, 3, layer_sizes[0], block_type=block_type)
        # each of the final three layers doubles num_channels, while performing downsampling
        # inside the first block
        self.conv3 = SpatioTemporalResLayer(64, 128, 3, layer_sizes[1], block_type=block_type, downsample=True)
        self.conv4 = SpatioTemporalResLayer(128, 256, 3, layer_sizes[2], block_type=block_type, downsample=True)
        self.conv5 = SpatioTemporalResLayer(256, 512, 3, layer_sizes[3], block_type=block_type, downsample=True)

        # if self.return_conv:
        #     self.feature_pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))   # 9216
        #     # self.feature_pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 4182
        self.feature_pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))  # 9216

        # global average pooling of the output
        self.pool = nn.AdaptiveAvgPool3d(1)

        if self.with_classifier:
            if use_dropout:
                self.linear = nn.Sequential(
                    nn.Dropout(dropout),  # dropout=0.5, but args.dropout=0.9
                    nn.Linear(512, self.num_classes))
            else:
                self.linear = nn.Linear(512, self.num_classes)
            # for name, param in self.linear.named_parameters():
            #     if 'bias' in name:
            #         nn.init.constant_(param, 0.0)
            #     elif 'weight' in name:
            #         nn.init.normal_(param, mean=0.0, std=0.01)

            if use_final_bn:
                # self.final_bn = nn.BatchNorm1d(self.param['feature_size'])
                self.final_bn = nn.BatchNorm1d(512)
                self.final_bn.weight.data.fill_(1)
                self.final_bn.bias.data.zero_()

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        if self.return_conv:
            x = self.feature_pool(x)
            # print(x.shape)
            return x.view(x.shape[0], -1)

        x = self.pool(x)
        x = x.view(-1, 512)

        if self.with_classifier:
            if self.use_l2_norm:
                x = F.normalize(x, p=2, dim=1)

            if self.use_final_bn:
                x = self.linear(self.final_bn(x))
            else:
                x = self.linear(x)

        return x


# with predictor;
class R3DNet_Focus_MPred_Pro_v2(nn.Module):
    r"""Forms the overall ResNet feature extractor by initializing 5 layers, with the number of blocks in
    each layer set by layer_sizes, and by performing a global average pool at the end producing a
    512-dimensional vector for each element in the batch.
        Args:
            layer_sizes (tuple): An iterable containing the number of blocks in each layer
            block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock.
    """

    def __init__(self, layer_sizes, block_type=SpatioTemporalResBlock, with_classifier=False, return_conv=False, num_classes=101, modality='rgb',
                 focus_level=5, focus_res=False, conv_level=1, sample_num=1,
                 pro_p=1., pro_clamp_value=0.):  # feat_dim=clip_len//2
        super(R3DNet_Focus_MPred_Pro_v2, self).__init__()
        self.with_classifier = with_classifier
        self.return_conv = return_conv
        self.num_classes = num_classes
        # FEAT_DIMS = [16, 16, 8, 4, 2]
        # self.sample_num = sample_num * 2 #
        self.sample_num = sample_num * 2 * 5 # for new_v5_* dataset
        self.conv_level = conv_level
        FEAT_DIMS = [64, 64, 128, 256, 512]
        NUM_PREDS = [16, 16, 8, 4, 2]
        self.feat_dim = FEAT_DIMS[conv_level - 1]
        self.num_pred = NUM_PREDS[conv_level - 1] // 2

        self.pixpro_p = pro_p
        self.pixpro_clamp_value = pro_clamp_value

        assert focus_level in list(range(-1,6)), 'focus level should be in 0~5, or -1 means no focus.'
        self.focus_level = focus_level # 0~5
        self.F_SIZES = [112, 56, 56, 28, 14, 7]
        self.f_size = self.F_SIZES[self.focus_level]
        self.f_upsample = nn.Upsample(size=(self.f_size, self.f_size), mode='bilinear',align_corners=True)

        self.focus_res = focus_res

        # first conv, with stride 1x2x2 and kernel size 3x7x7
        if modality == 'uv':
            self.conv1 = SpatioTemporalConv(2, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])
            print('[Warning]: using optical flow 3D models')
        else:
            self.conv1 = SpatioTemporalConv(3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])

        self.bn1 = nn.BatchNorm3d(64)
        self.relu1 = nn.ReLU()
        # output of conv2 is same size as of conv1, no downsampling needed. kernel_size 3x3x3
        self.conv2 = SpatioTemporalResLayer(64, 64, 3, layer_sizes[0], block_type=block_type)
        # each of the final three layers doubles num_channels, while performing downsampling
        # inside the first block
        self.conv3 = SpatioTemporalResLayer(64, 128, 3, layer_sizes[1], block_type=block_type, downsample=True)
        self.conv4 = SpatioTemporalResLayer(128, 256, 3, layer_sizes[2], block_type=block_type, downsample=True)
        self.conv5 = SpatioTemporalResLayer(256, 512, 3, layer_sizes[3], block_type=block_type, downsample=True)

        # if self.return_conv:
        self.feature_pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))   # 9216
        # self.feature_pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 4182

        self.value_transform = nn.Sequential(
            conv1x1(self.feat_dim, self.feat_dim),
        )


        self.projector = conv1x1(self.feat_dim, self.feat_dim)

        # global average pooling of the output
        self.pool = nn.AdaptiveAvgPool3d(1)

        if self.with_classifier:
            self.linear = nn.Linear(512, self.num_classes)

    def forward(self, x, focus_map):
        # generate focus map
        # focus_map = self.f_upsample(focus_map.unsqueeze(1)).unsqueeze(1)
        focus_map = self.f_upsample(focus_map).unsqueeze(1)
        # x_tgt_1, x_tgt_2 = None, None
        # x_pred_1, x_pred_2 = None, None
        # pdb.set_trace()
        if self.focus_level==0:
            if self.focus_res:
                tmp = torch.mul(x, focus_map)
                x = x + tmp
            else:
                focus_map = (focus_map-focus_map.min())/(focus_map.max()-focus_map.min())
                x = torch.mul(x, focus_map)

        # print('input: ', x.size())                # [2, 3, 16, 112, 112]
        x = self.relu1(self.bn1(self.conv1(x)))
        # print('level 1: ', x.size())              # [2, 64, 16, 56, 56]
        if self.focus_level==1:
            if self.focus_res:
                x += torch.mul(x, focus_map)
            else:
                focus_map = (focus_map-focus_map.min())/(focus_map.max()-focus_map.min())
                x = torch.mul(x, focus_map)
        if self.conv_level == 1:
            x_proj = self.projector(x)
            x_reshape = x_proj.reshape(-1, self.sample_num, x.size(1), x.size(2), x.size(3), x.size(4)) # [bs,s_num,64,16,56,56]
            x_tgt_1 = x_reshape[:,:self.sample_num//2,:,:,:,:] # [bs,s_num/2,64,16,56,56]
            x_tgt_1 = x_tgt_1.permute(0,2,1,3,4,5)  # [bs,64,s_num/2,16,56,56]
            x_tgt_1 = x_tgt_1.reshape(x_reshape.size(0), x.size(1), -1,  x.size(3), x.size(4)).contiguous()
            x_tgt_2 = x_reshape[:,self.sample_num//2:,:,:,:,:] # [bs,s_num/2,64,16,56,56]
            x_tgt_2 = x_tgt_2.permute(0,2,1,3,4,5)  # [bs,64,s_num/2,16,56,56]
            x_tgt_2 = x_tgt_2.reshape(x_reshape.size(0), x.size(1), -1,  x.size(3), x.size(4)).contiguous()
            x_pred_2 = self.featprop(x_tgt_1)
            x_pred_1 = self.featprop(x_tgt_2)

        x = self.conv2(x)
        # print('level 2: ', x.size())              # [2, 64, 16, 56, 56]
        if self.focus_level==2:
            if self.focus_res:
                x += torch.mul(x, focus_map)
            else:
                focus_map = (focus_map - focus_map.min()) / (focus_map.max() - focus_map.min())
                x = torch.mul(x, focus_map)
        if self.conv_level == 2:
            x_proj = self.projector(x)
            x_reshape = x_proj.reshape(-1, self.sample_num, x.size(1), x.size(2), x.size(3), x.size(4))  # [bs,s_num,64,16,56,56]
            x_tgt_1 = x_reshape[:, :self.sample_num // 2, :, :, :, :]  # [bs,s_num/2,64,16,56,56]
            x_tgt_1 = x_tgt_1.permute(0, 2, 1, 3, 4, 5)  # [bs,64,s_num/2,16,56,56]
            x_tgt_1 = x_tgt_1.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_tgt_2 = x_reshape[:, self.sample_num // 2:, :, :, :, :]  # [bs,s_num/2,64,16,56,56]
            x_tgt_2 = x_tgt_2.permute(0, 2, 1, 3, 4, 5)  # [bs,64,s_num/2,16,56,56]
            x_tgt_2 = x_tgt_2.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_pred_2 = self.featprop(x_tgt_1)
            x_pred_1 = self.featprop(x_tgt_2)


        x = self.conv3(x)
        # print('level 3: ', x.size())              # [2, 128, 8, 28, 28]
        if self.focus_level==3:
            if self.focus_res:
                x += torch.mul(x, focus_map)
            else:
                focus_map = (focus_map - focus_map.min()) / (focus_map.max() - focus_map.min())
                x = torch.mul(x, focus_map)
        if self.conv_level == 3:
            x_proj = self.projector(x)
            x_reshape = x_proj.reshape(-1, self.sample_num, x.size(1), x.size(2), x.size(3), x.size(4))  # [bs,s_num,128, 8, 28, 28]
            x_tgt_1 = x_reshape[:, :self.sample_num // 2, :, :, :, :]  # [bs,s_num/2,128, 8, 28, 28]
            x_tgt_1 = x_tgt_1.permute(0, 2, 1, 3, 4, 5)  # [bs,128,s_num/2, 8, 28, 28]
            x_tgt_1 = x_tgt_1.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_tgt_2 = x_reshape[:, self.sample_num // 2:, :, :, :, :]  # [bs,s_num/2,128, 8, 28, 28]
            x_tgt_2 = x_tgt_2.permute(0, 2, 1, 3, 4, 5)  # [bs,128,s_num/2, 8, 28, 28]
            x_tgt_2 = x_tgt_2.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_pred_2 = self.featprop(x_tgt_1)
            x_pred_1 = self.featprop(x_tgt_2)


        x = self.conv4(x)
        # print('level 4: ', x.size())              # [2, 256, 4, 14, 14]
        if self.focus_level==4:
            if self.focus_res:
                x += torch.mul(x, focus_map)
            else:
                focus_map = (focus_map - focus_map.min()) / (focus_map.max() - focus_map.min())
                x = torch.mul(x, focus_map)
        if self.conv_level == 4:
            x_proj = self.projector(x)
            x_reshape = x_proj.reshape(-1, self.sample_num, x.size(1), x.size(2), x.size(3), x.size(4))  # [bs,s_num,256, 4, 14, 14]
            x_tgt_1 = x_reshape[:, :self.sample_num // 2, :, :, :, :]  # [bs,s_num/2,256, 4, 14, 14]
            x_tgt_1 = x_tgt_1.permute(0, 2, 1, 3, 4, 5)  # [bs,256,s_num/2, 4, 14, 14]
            x_tgt_1 = x_tgt_1.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_tgt_2 = x_reshape[:, self.sample_num // 2:, :, :, :, :]  # [bs,s_num/2,256, 4, 14, 14]
            x_tgt_2 = x_tgt_2.permute(0, 2, 1, 3, 4, 5)  # [bs,256,s_num/2, 4, 14, 14]
            x_tgt_2 = x_tgt_2.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_pred_2 = self.featprop(x_tgt_1)
            x_pred_1 = self.featprop(x_tgt_2)

        x = self.conv5(x)
        # print('level 5: ', x.size())              # [2, 512, 2, 7, 7]
        if self.focus_level==5:
            if self.focus_res:
                x += torch.mul(x, focus_map)
            else:
                focus_map = (focus_map - focus_map.min()) / (focus_map.max() - focus_map.min())
                x = torch.mul(x, focus_map)
        if self.conv_level == 5:
            x_proj = self.projector(x)
            x_reshape = x_proj.reshape(-1, self.sample_num, x.size(1), x.size(2), x.size(3), x.size(4))  # [bs,s_num,512, 2, 7, 7]
            x_tgt_1 = x_reshape[:, :self.sample_num // 2, :, :, :, :]  # [bs,s_num/2,512, 2, 7, 7]
            x_tgt_1 = x_tgt_1.permute(0, 2, 1, 3, 4, 5)  # [bs,512,s_num/2, 2, 7, 7]
            x_tgt_1 = x_tgt_1.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_tgt_2 = x_reshape[:, self.sample_num // 2:, :, :, :, :]  # [bs,s_num/2,512, 2, 7, 7]
            x_tgt_2 = x_tgt_2.permute(0, 2, 1, 3, 4, 5)  # [bs,512,s_num/2, 2, 7, 7]
            x_tgt_2 = x_tgt_2.reshape(x_reshape.size(0), x.size(1), -1, x.size(3), x.size(4)).contiguous()
            x_pred_2 = self.featprop(x_tgt_1)
            x_pred_1 = self.featprop(x_tgt_2)

        if self.return_conv:
            x = self.feature_pool(x)
            # print('feature pool: ', x.size())     # [2, 512, 2, 3, 3]
            pred_x = self.predictor(x.view(x.shape[0], -1)) # [2, 9216] -> [2, 512]
            return x.view(x.shape[0], -1), pred_x

        x = self.pool(x)
        # print('pool: ', x.size())                # [2, 512, 1, 1, 1]
        x = x.view(-1, 512)
        # x_pred = self.predictor(x)

        if self.with_classifier:
            x = self.linear(x)

        if self.conv_level > 0:
            return x, x_pred_1, x_pred_2, x_tgt_1, x_tgt_2
        else:
            return x, None, None, None, None

    def featprop(self, feat):
        N, C, T, H, W = feat.shape

        # Value transformation
        feat_value = self.value_transform(feat)
        feat_value = F.normalize(feat_value, dim=1)
        feat_value = feat_value.view(N, C, -1)

        # Similarity calculation
        feat = F.normalize(feat, dim=1)

        # [N, C, T * H * W]
        feat = feat.view(N, C, -1)

        # [N, H * W, H * W]
        attention = torch.bmm(feat.transpose(1, 2), feat)
        attention = torch.clamp(attention, min=self.pixpro_clamp_value)
        if self.pixpro_p < 1.:
            attention = attention + 1e-6
        attention = attention ** self.pixpro_p

        # [N, C, T * H * W]
        feat = torch.bmm(feat_value, attention.transpose(1, 2))

        return feat.view(N, C, T, H, W)


# but do not consider snum here; only predict within a clip
def conv1x1(in_planes, out_planes):
    """1x1 convolution"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=True)


