# coding=utf-8
# Adapted from Ravens - Transporter Networks, Zeng et al., 2021
# https://github.com/google-research/ravens

"""Resnet module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
import torch.nn.functional as F


def init_xavier_weights(m):
    """Xavier-uniform initialise a ``nn.Conv2d`` and zero its bias.

    Meant to be used as ``module.apply(init_xavier_weights)``. Any module
    whose type is not exactly ``nn.Conv2d`` is left untouched.

    Args:
        m: any ``nn.Module``.
    """
    # Exact type match (not isinstance) so Conv2d subclasses such as
    # LazyConv2d, whose weights may not be materialised yet, are skipped.
    if type(m) is nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        # Convs built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


def Forward3LayersConvBlock(in_channels,
                            kernel_size,
                            out_channels,
                            stride=1,
                            include_batchnorm=False):
    """Build the three-conv main path of a residual block.

    Layout: 1x1 conv (carrying ``stride``) -> ReLU -> ``kernel_size`` conv
    -> ReLU -> 1x1 conv. When ``include_batchnorm`` is True, a BatchNorm2d
    follows every conv. The last conv is not followed by an activation.

    The middle conv is "same"-padded (``kernel_size // 2`` per axis), so for
    odd kernel sizes the main path keeps the spatial size produced by the
    first conv and lines up with a 1x1 shortcut of the same stride.

    Args:
        in_channels: number of input channels.
        kernel_size: int or (kh, kw) kernel size of the middle conv; odd
            sizes are expected.
        out_channels: sequence of exactly three ints, the output channels
            of the three convs.
        stride: stride of the first 1x1 conv.
        include_batchnorm: if True, insert BatchNorm2d after each conv.

    Returns:
        An ``nn.Sequential``. Layer order/indices are the same as before,
        so existing ``state_dict`` keys are unchanged.
    """
    out_channels1, out_channels2, out_channels3 = out_channels

    # A hard-coded padding=1 only preserved size for 3x3 kernels; e.g. the
    # 5x5 blocks would shrink by 2 px and break the residual addition.
    if isinstance(kernel_size, int):
        padding = kernel_size // 2
    else:
        padding = tuple(k // 2 for k in kernel_size)

    def _maybe_bn(channels):
        # BatchNorm2d after a conv when requested, otherwise nothing.
        return [nn.BatchNorm2d(channels)] if include_batchnorm else []

    layers = [nn.Conv2d(in_channels, out_channels1,
                        kernel_size=1, stride=stride)]
    layers += _maybe_bn(out_channels1)
    layers.append(nn.ReLU())

    layers.append(nn.Conv2d(out_channels1, out_channels2,
                            kernel_size, padding=padding))
    layers += _maybe_bn(out_channels2)
    layers.append(nn.ReLU())

    layers.append(nn.Conv2d(out_channels2, out_channels3, kernel_size=1))
    layers += _maybe_bn(out_channels3)

    return nn.Sequential(*layers)


class IdentityBlock(nn.Module):
    """Residual block whose shortcut is the unmodified input.

    Output is ``main(x) + x``, optionally passed through a ReLU, where
    ``main`` is the three-conv path from ``Forward3LayersConvBlock``.
    """

    def __init__(self,
                 in_channels,
                 kernel_size,
                 out_channels,
                 activation=True,
                 include_batchnorm=False):
        """Create the block.

        Args:
            in_channels: number of channels of the input tensor.
            kernel_size: kernel size of the middle conv on the main path
                (there is no default; callers in this file pass 3 or 5).
            out_channels: three ints, the channels of the three main-path
                convs. The last one must be addable to the input's channel
                count, since the input is added back unchanged.
            activation: if True, apply ReLU to the summed output.
            include_batchnorm: if True, the main path uses BatchNorm layers.
        """
        super().__init__()

        self.activation = activation
        self.relu = nn.ReLU()

        # Main path at stride 1 so its output can be added to the input.
        self.forward_block = Forward3LayersConvBlock(
            in_channels,
            kernel_size,
            out_channels,
            include_batchnorm=include_batchnorm,
        )

    def forward(self, x):
        summed = x + self.forward_block(x)
        return self.relu(summed) if self.activation else summed


class ConvBlock(nn.Module):
    """Residual block with a learned 1x1 conv on the shortcut.

    Output is ``main(x) + shortcut(x)``, optionally passed through a ReLU.
    Both paths use the same ``stride``, so the block can change both the
    channel count and the spatial resolution.
    """

    def __init__(self,
                 in_channels,
                 kernel_size,
                 out_channels,
                 stride=(2, 2),
                 activation=True,
                 include_batchnorm=False):
        """Create the block.

        Args:
            in_channels: number of channels of the input tensor.
            kernel_size: kernel size of the middle conv on the main path.
            out_channels: three ints, the channels of the three main-path
                convs; the last is also the shortcut conv's output channels.
            stride: stride of the first main-path conv and of the shortcut
                conv (default ``(2, 2)``).
            activation: if True, apply ReLU to the summed output.
            include_batchnorm: if True, use BatchNorm on the main path and
                after the shortcut conv.
        """
        super().__init__()

        # Main path; ``stride`` is applied by its first (1x1) conv.
        self.forward_block = Forward3LayersConvBlock(
            in_channels,
            kernel_size,
            out_channels,
            stride=stride,
            include_batchnorm=include_batchnorm,
        )

        *_, shortcut_channels = out_channels

        self.activation = activation
        self.relu = nn.ReLU()

        # Projection shortcut matching the main path's channels and stride.
        self.conv_shortcut = nn.Conv2d(
            in_channels, shortcut_channels, kernel_size=1, stride=stride)

        self.include_batchnorm = include_batchnorm
        if include_batchnorm:
            self.bn_shortcut = nn.BatchNorm2d(shortcut_channels)

    def forward(self, x):
        main = self.forward_block(x)

        residual = self.conv_shortcut(x)
        if self.include_batchnorm:
            residual = self.bn_shortcut(residual)

        summed = main + residual
        return self.relu(summed) if self.activation else summed


class ResNet43_8s(nn.Module):
    """Channels-last encoder-decoder ResNet with 8x spatial downsampling.

    Input is ``(B, H, W, in_channels)``. The full path returns
    ``(B, H', W', output_dim)``, with ``H' == H`` and ``W' == W`` when both
    are divisible by 8 (three stride-2 stages, then three x2 upsamplings).
    """

    def __init__(self,
                 in_channels,
                 output_dim,
                 include_batchnorm=False,
                 cutoff_early=False):
        """Build Resnet 43 8s.

        Args:
            in_channels: number of channels of the channels-last input.
            output_dim: number of channels of the output map.
            include_batchnorm: if True, add BatchNorm to the stem and to the
                ``cutoff_early`` head. The encoder/decoder blocks
                (``block_full1``-``block_full8``) are always built without
                BatchNorm.
            cutoff_early: if True, ``forward`` runs only the stem plus a
                short two-block head and returns its output.
        """
        super().__init__()

        self.cutoff_early = cutoff_early

        # Stem: channels-last -> channels-first, 3x3 conv to 64 channels.
        # Layer order matches the original two-branch construction, so
        # checkpoint keys such as ``block_short.1.weight`` are unchanged.
        stem = [
            Rearrange('b h w c -> b c h w'),
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
        ]
        if include_batchnorm:
            stem.append(nn.BatchNorm2d(64))
        stem.append(nn.ReLU())
        self.block_short = nn.Sequential(*stem)
        self.block_short.apply(init_xavier_weights)

        if cutoff_early:
            self.block_cutoff_early = nn.Sequential(
                ConvBlock(64, 5, [64, 64, output_dim], stride=1,
                          include_batchnorm=include_batchnorm),
                IdentityBlock(output_dim, 5, [64, 64, output_dim],
                              include_batchnorm=include_batchnorm),
                Rearrange('b c h w -> b h w c'),
            )
            self.block_cutoff_early.apply(init_xavier_weights)

        # Encoder (block_full1-4) and decoder (block_full5-8). These are
        # built regardless of ``cutoff_early`` (unused in that mode).
        self.block_full1 = nn.Sequential(
            ConvBlock(64, 3, [64, 64, 64], stride=1),
            IdentityBlock(64, 3, [64, 64, 64]))

        self.block_full2 = nn.Sequential(
            ConvBlock(64, 3, [128, 128, 128], stride=2),
            IdentityBlock(128, 3, [128, 128, 128]))

        self.block_full3 = nn.Sequential(
            ConvBlock(128, 3, [256, 256, 256], stride=2),
            IdentityBlock(256, 3, [256, 256, 256]))

        self.block_full4 = nn.Sequential(
            ConvBlock(256, 3, [512, 512, 512], stride=2),
            IdentityBlock(512, 3, [512, 512, 512]))

        self.block_full5 = nn.Sequential(
            ConvBlock(512, 3, [256, 256, 256], stride=1),
            IdentityBlock(256, 3, [256, 256, 256]))

        # align_corners=False is the default bilinear behaviour; stated
        # explicitly.
        self.block_full6 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvBlock(256, 3, [128, 128, 128], stride=1),
            IdentityBlock(128, 3, [128, 128, 128]))

        self.block_full7 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvBlock(128, 3, [64, 64, 64], stride=1),
            IdentityBlock(64, 3, [64, 64, 64]))

        self.block_full8 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvBlock(64, 3, [16, 16, output_dim], stride=1, activation=False),
            IdentityBlock(output_dim, 3, [16, 16, output_dim], activation=False),
            Rearrange('b c h w -> b h w c'),
        )

        # Initialise only after every block exists, keeping the original
        # order of random-number consumption.
        for block in (self.block_full1, self.block_full2, self.block_full3,
                      self.block_full4, self.block_full5, self.block_full6,
                      self.block_full7, self.block_full8):
            block.apply(init_xavier_weights)

    def normalize_features(self, x):
        """Scale each sample of ``x`` to unit L2 norm.

        The norm is taken over all non-batch dimensions. A sample whose norm
        is exactly zero is divided by 1 (left unchanged) instead of turning
        into NaN through 0/0.

        Args:
            x: tensor with the batch as its first dimension (at least 2-D).

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size = x.shape[0]
        norms = x.flatten(1).norm(dim=1)
        # An all-zero sample (e.g. fully inactive ReLU output) has norm 0.
        norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        return x / norms.view(batch_size, *([1] * (x.dim() - 1)))

    def forward(self, x, return_features=False):
        """Run the network on a channels-last batch.

        Args:
            x: input of shape ``(B, H, W, in_channels)``.
            return_features: if True, stop after the encoder and return the
                per-sample L2-normalised bottleneck, flattened over the
                whole batch into one 1-D tensor.

        Returns:
            The ``cutoff_early`` head output when that mode is enabled,
            otherwise the flattened features or the ``(B, H, W, output_dim)``
            output map (exact size for H, W divisible by 8).
        """
        out = self.block_short(x)

        if self.cutoff_early:
            return self.block_cutoff_early(out)

        f1 = self.block_full1(out)  # (B, 64, H, W)
        f2 = self.block_full2(f1)   # (B, 128, H/2, W/2)
        f3 = self.block_full3(f2)   # (B, 256, H/4, W/4)
        f4 = self.block_full4(f3)   # (B, 512, H/8, W/8)
        f4 = self.normalize_features(f4)
        if return_features:
            return f4.view(-1)
        f5 = self.block_full5(f4)   # (B, 256, H/8, W/8)
        f6 = self.block_full6(f5)   # (B, 128, H/4, W/4)
        f7 = self.block_full7(f6)   # (B, 64, H/2, W/2)
        return self.block_full8(f7)  # (B, H, W, output_dim)

class ResNet36_4s(nn.Module):
    """Channels-last encoder-decoder ResNet with 4x spatial downsampling.

    Input is ``(B, H, W, in_channels)``. The full path returns
    ``(B, H', W', output_dim)``, with ``H' == H`` and ``W' == W`` when both
    are divisible by 4 (two stride-2 stages, then two x2 upsamplings).
    """

    def __init__(self,
                 in_channels,
                 output_dim,
                 include_batchnorm=False,
                 cutoff_early=False):
        """Build Resnet 36 4s.

        Args:
            in_channels: number of channels of the channels-last input.
            output_dim: number of channels of the output map.
            include_batchnorm: if True, add BatchNorm to the stem and to the
                ``cutoff_early`` head. ``block_full1``-``block_full5`` are
                always built without BatchNorm.
            cutoff_early: if True, ``forward`` runs only the stem plus a
                short two-block head and returns its output.
        """
        super().__init__()

        self.cutoff_early = cutoff_early

        # Stem: channels-last -> channels-first, 3x3 conv to 64 channels.
        # Layer order matches the original two-branch construction, so
        # checkpoint keys such as ``block_short.1.weight`` are unchanged.
        stem = [
            Rearrange('b h w c -> b c h w'),
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
        ]
        if include_batchnorm:
            stem.append(nn.BatchNorm2d(64))
        stem.append(nn.ReLU())
        self.block_short = nn.Sequential(*stem)
        self.block_short.apply(init_xavier_weights)

        if cutoff_early:
            self.block_cutoff_early = nn.Sequential(
                ConvBlock(64, 5, [64, 64, output_dim], stride=1,
                          include_batchnorm=include_batchnorm),
                IdentityBlock(output_dim, 5, [64, 64, output_dim],
                              include_batchnorm=include_batchnorm),
                Rearrange('b c h w -> b h w c'),
            )
            self.block_cutoff_early.apply(init_xavier_weights)

        # Encoder (block_full1-3) and decoder (block_full4-5), built
        # regardless of ``cutoff_early`` (unused in that mode).
        self.block_full1 = nn.Sequential(
            ConvBlock(64, 3, [64, 64, 64], stride=1),
            IdentityBlock(64, 3, [64, 64, 64]))
        self.block_full2 = nn.Sequential(
            ConvBlock(64, 3, [64, 64, 64], stride=2),
            IdentityBlock(64, 3, [64, 64, 64]))
        self.block_full3 = nn.Sequential(
            ConvBlock(64, 3, [64, 64, 64], stride=2),
            IdentityBlock(64, 3, [64, 64, 64]))
        # align_corners=False is the default bilinear behaviour; stated
        # explicitly.
        self.block_full4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvBlock(64, 3, [64, 64, 64], stride=1),
            IdentityBlock(64, 3, [64, 64, 64]))
        self.block_full5 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvBlock(64, 3, [16, 16, output_dim], stride=1, activation=False),
            IdentityBlock(output_dim, 3, [16, 16, output_dim],
                          activation=False),
            Rearrange('b c h w -> b h w c'),
        )

        # Initialise only after every block exists, keeping the original
        # order of random-number consumption.
        for block in (self.block_full1, self.block_full2, self.block_full3,
                      self.block_full4, self.block_full5):
            block.apply(init_xavier_weights)

    def normalize_features(self, x):
        """Scale each sample of ``x`` to unit L2 norm.

        The norm is taken over all non-batch dimensions. A sample whose norm
        is exactly zero is divided by 1 (left unchanged) instead of turning
        into NaN through 0/0.

        Args:
            x: tensor with the batch as its first dimension (at least 2-D).

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size = x.shape[0]
        norms = x.flatten(1).norm(dim=1)
        # An all-zero sample (e.g. fully inactive ReLU output) has norm 0.
        norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        return x / norms.view(batch_size, *([1] * (x.dim() - 1)))

    def forward(self, x, return_features=False):
        """Run the network on a channels-last batch.

        Args:
            x: input of shape ``(B, H, W, in_channels)``.
            return_features: if True, stop after ``block_full2`` and return
                its per-sample L2-normalised output, flattened over the whole
                batch into one 1-D tensor.

        Returns:
            The ``cutoff_early`` head output when that mode is enabled,
            otherwise the flattened features or the ``(B, H, W, output_dim)``
            output map (exact size for H, W divisible by 4).
        """
        out = self.block_short(x)

        if self.cutoff_early:
            return self.block_cutoff_early(out)

        f1 = self.block_full1(out)  # (B, 64, H, W)
        f2 = self.block_full2(f1)   # (B, 64, H/2, W/2)
        # The normalised f2 also feeds the rest of the network.
        f2 = self.normalize_features(f2)
        if return_features:
            return f2.view(-1)
        f3 = self.block_full3(f2)   # (B, 64, H/4, W/4)
        f4 = self.block_full4(f3)   # (B, 64, H/2, W/2)
        return self.block_full5(f4)  # (B, H, W, output_dim)
        
    

####################################################################################
class SimpleNet_attention(nn.Module):
    """Small channels-last encoder-decoder that runs at quarter resolution.

    The input ``(B, H, W, in_channels)`` is bilinearly downsampled by 4,
    then passed through a ResNet-style encoder (two stride-2 stages) and
    decoder (two x2 upsamplings). The result is resized back to exactly
    ``(H, W)`` and returned channels-last as ``(B, H, W, output_dim)``.
    """

    def __init__(self,
                 in_channels,
                 output_dim,
                 include_batchnorm=False,
                 cutoff_early=False):
        """Build the quarter-resolution attention network.

        Args:
            in_channels: number of channels of the channels-last input.
            output_dim: number of channels of the output map.
            include_batchnorm: if True, add BatchNorm to the stem and to the
                ``cutoff_early`` head. The main blocks are always built
                without BatchNorm.
            cutoff_early: if True, ``forward`` runs only the stem plus a
                short two-block head and returns its output.
        """
        super().__init__()

        self.cutoff_early = cutoff_early

        # Stem without a Rearrange; ``forward`` converts the layout itself.
        # Layer order matches the original two-branch construction, so
        # checkpoint keys such as ``block_short.0.weight`` are unchanged.
        stem = [nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)]
        if include_batchnorm:
            stem.append(nn.BatchNorm2d(64))
        stem.append(nn.ReLU())
        self.block_short = nn.Sequential(*stem)
        self.block_short.apply(init_xavier_weights)

        if cutoff_early:
            self.block_cutoff_early = nn.Sequential(
                ConvBlock(64, 5, [64, 64, output_dim], stride=1,
                          include_batchnorm=include_batchnorm),
                IdentityBlock(output_dim, 5, [64, 64, output_dim],
                              include_batchnorm=include_batchnorm),
                Rearrange('b c h w -> b h w c'),
            )
            self.block_cutoff_early.apply(init_xavier_weights)

        # Encoder/decoder blocks (all 64 channels wide except the head).
        self.rearrange_in = Rearrange('b h w c -> b c h w')
        self.conv_block1 = ConvBlock(64, 3, [64, 64, 64], stride=1)
        self.identity_block1 = IdentityBlock(64, 3, [64, 64, 64])
        self.conv_block2 = ConvBlock(64, 3, [64, 64, 64], stride=2)
        self.identity_block2 = IdentityBlock(64, 3, [64, 64, 64])
        self.conv_block3 = ConvBlock(64, 3, [64, 64, 64], stride=2)
        self.identity_block3 = IdentityBlock(64, 3, [64, 64, 64])
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear',
                                     align_corners=False)
        self.conv_block4 = ConvBlock(64, 3, [64, 64, 64], stride=1)
        self.identity_block4 = IdentityBlock(64, 3, [64, 64, 64])
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear',
                                     align_corners=False)
        self.conv_block5 = ConvBlock(64, 3, [16, 16, output_dim], stride=1,
                                     activation=False)
        self.identity_block5 = IdentityBlock(output_dim, 3,
                                             [16, 16, output_dim],
                                             activation=False)
        self.rearrange_out = Rearrange('b c h w -> b h w c')

        # Initialise only after every block exists, keeping the original
        # order of random-number consumption.
        for block in (self.conv_block1, self.identity_block1,
                      self.conv_block2, self.identity_block2,
                      self.conv_block3, self.identity_block3,
                      self.conv_block4, self.identity_block4,
                      self.conv_block5, self.identity_block5):
            block.apply(init_xavier_weights)

        self.downsample_img = nn.Upsample(scale_factor=0.25, mode='bilinear',
                                          align_corners=False)
        # Kept as an attribute for compatibility. ``forward`` now resizes to
        # the exact input size with ``F.interpolate``, which is identical for
        # H, W divisible by 16 and also correct for other sizes.
        self.upsample_img = nn.Upsample(scale_factor=4, mode='bilinear',
                                        align_corners=False)

    def normalize_features(self, x):
        """Scale each sample of ``x`` to unit L2 norm.

        The norm is taken over all non-batch dimensions. A sample whose norm
        is exactly zero is divided by 1 (left unchanged) instead of turning
        into NaN through 0/0.

        Args:
            x: tensor with the batch as its first dimension (at least 2-D).

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size = x.shape[0]
        norms = x.flatten(1).norm(dim=1)
        # An all-zero sample (e.g. fully inactive ReLU output) has norm 0.
        norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        return x / norms.view(batch_size, *([1] * (x.dim() - 1)))

    def forward(self, x_in, return_features=False):
        """Run the network on a channels-last batch.

        Args:
            x_in: input of shape ``(B, H, W, in_channels)``.
            return_features: if True, stop after the encoder and return its
                per-sample L2-normalised output, flattened over the whole
                batch into one 1-D tensor.

        Returns:
            The ``cutoff_early`` head output when that mode is enabled,
            otherwise the flattened features or a ``(B, H, W, output_dim)``
            map at the input's spatial size.
        """
        # Channels-last -> channels-first; remember the target output size.
        x = self.rearrange_in(x_in)
        out_size = x.shape[-2:]

        # Work at quarter resolution.
        x = self.downsample_img(x)

        out = self.block_short(x)

        if self.cutoff_early:
            # NOTE(review): this branch returns features at 1/4 of the input
            # resolution (nothing upsamples them back), unlike the full path
            # below. Confirm this is intended.
            return self.block_cutoff_early(out)

        x = self.conv_block1(out)
        x = self.identity_block1(x)
        x = self.conv_block2(x)
        x = self.identity_block2(x)
        x = self.conv_block3(x)
        x = self.identity_block3(x)
        x = self.normalize_features(x)
        if return_features:
            return x.reshape(-1)
        x = self.upsample1(x)
        x = self.conv_block4(x)
        x = self.identity_block4(x)
        x = self.upsample2(x)
        x = self.conv_block5(x)
        x = self.identity_block5(x)

        # Resize to the exact input size. A fixed x4 upsample misses it when
        # H or W is not divisible by 16, because the 0.25 downsample and the
        # two stride-2 stages round the size.
        x = F.interpolate(x, size=out_size, mode='bilinear',
                          align_corners=False)

        # Channels-first -> channels-last.
        return self.rearrange_out(x)

