""" blocks.py
    Neural network blocks.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    BasicBlocks borrowed from ResNet architectures:
    "Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>

    Developed for DeepThinking project
    October 2021
"""
import math

import torch
from torch import nn
import torch.nn.functional as F


class BasicBlock1D(nn.Module):
    """One-dimensional residual block built from two 3-tap convolutions.

    The main path is ``conv1 -> gn1 -> ReLU -> conv2 -> gn2``; its result is
    summed with the shortcut branch and passed through a final ReLU.

    Args:
        in_planes: Number of channels of the incoming tensor.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
        group_norm: When true, each convolution is followed by a 4-group,
            non-affine ``GroupNorm``; otherwise an empty ``nn.Sequential``
            (identity) takes its place.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, group_norm=False):
        super().__init__()
        out_planes = self.expansion * planes

        def norm_or_identity():
            # An empty Sequential acts as a pass-through placeholder.
            if group_norm:
                return nn.GroupNorm(4, planes, affine=False)
            return nn.Sequential()

        self.conv1 = nn.Conv1d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.gn1 = norm_or_identity()
        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.gn2 = norm_or_identity()

        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: the skip connection is the identity.
            self.shortcut = nn.Sequential()
        else:
            # A 1x1 convolution matches channel count and temporal stride.
            projection = nn.Conv1d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.gn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.gn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class BasicBlock2D(nn.Module):
    """Two-dimensional residual block built from two 3x3 convolutions.

    The main path is ``conv1 -> gn1 -> bn1 -> ReLU -> conv2 -> gn2 -> bn2``;
    its result is summed with the shortcut branch and passed through a final
    ReLU. When both normalisations are enabled, group norm runs before batch
    norm.

    Args:
        in_planes: Number of channels of the incoming tensor.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
        group_norm: When true, use a 4-group, non-affine ``GroupNorm`` after
            each convolution; otherwise an empty ``nn.Sequential`` (identity).
        batch_norm: When true, use ``BatchNorm2d`` after each convolution;
            otherwise an empty ``nn.Sequential`` (identity).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, group_norm=False, batch_norm=False):
        super().__init__()
        out_planes = self.expansion * planes

        def group_norm_or_identity():
            if group_norm:
                return nn.GroupNorm(4, planes, affine=False)
            return nn.Sequential()

        def batch_norm_or_identity():
            if batch_norm:
                return nn.BatchNorm2d(planes)
            return nn.Sequential()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.gn1 = group_norm_or_identity()
        self.bn1 = batch_norm_or_identity()
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.gn2 = group_norm_or_identity()
        self.bn2 = batch_norm_or_identity()

        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: the skip connection is the identity.
            self.shortcut = nn.Sequential()
        else:
            # A 1x1 convolution matches channel count and spatial stride.
            projection = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(self.gn1(hidden)))
        hidden = self.conv2(hidden)
        hidden = self.bn2(self.gn2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, looked up by index ``t``.

    A table ``pe`` of shape ``(max_len, 1, d_model)`` is precomputed and
    registered as a buffer. Even feature indices hold sines and odd feature
    indices hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the encoding vector. Both even and odd values are
            supported.
        dropout: Probability used to build ``self.dropout``. The module is
            constructed but is not applied in ``forward``.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        # Constructed for interface compatibility; forward() does not use it.
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd-indexed slot than
        # even-indexed ones, so drop the surplus column instead of failing
        # with a shape mismatch. For even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x, t):
        """Add the encoding stored at index ``t`` to ``x``.

        ``self.pe[t]`` gets two trailing singleton dimensions appended before
        the addition, so for an integer ``t`` the added term has shape
        ``(1, d_model, 1, 1)`` and broadcasts against ``x``.

        Args:
            x: Tensor to offset.
                NOTE(review): presumably ``(batch, d_model, H, W)`` -- confirm
                against callers.
            t: Index into the position table.

        Returns:
            ``x`` plus the broadcast encoding; dropout is not applied.
        """
        x = x + self.pe[t].unsqueeze(-1).unsqueeze(-1)
        return x

class Head(nn.Module):
    """Convolutional classification head.

    Two 3x3 convolutions (``width -> 64`` without bias, then ``64 -> 32``),
    each followed by an optional ``BatchNorm2d`` and a ReLU, feed an adaptive
    average pool to ``1x1`` and a linear layer producing ``num_class`` outputs.

    Args:
        width: Number of channels of the incoming feature map.
        num_class: Number of output units of the final linear layer.
        batch_norm: When true, insert ``BatchNorm2d`` after each convolution;
            otherwise an empty ``nn.Sequential`` (identity) keeps its slot.
    """

    def __init__(self, width, num_class, batch_norm=False):
        super().__init__()

        def norm_or_identity(channels):
            # The placeholder keeps layer indices in `convs` stable.
            if batch_norm:
                return nn.BatchNorm2d(channels)
            return nn.Sequential()

        layers = [
            nn.Conv2d(width, 64, kernel_size=3, stride=1, padding=1, bias=False),
            norm_or_identity(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            norm_or_identity(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.convs = nn.Sequential(*layers)
        self.fc = nn.Linear(32, num_class)

    def forward(self, x):
        """Return the linear layer's output for the pooled features of ``x``."""
        pooled = self.convs(x)
        # Collapse the pooled (N, 32, 1, 1) map to (N, 32) for the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
    
class HaltConv(nn.Module):
    """Convolutional module that maps a feature map to one value in [0, 1].

    Two bias-free 3x3 convolutions (``ht_channel -> 64 -> 32``), each followed
    by ``BatchNorm2d`` and ReLU, are averaged over the two trailing spatial
    dimensions, passed through a linear layer with a single output and
    squashed with a sigmoid.

    Args:
        ht_channel: Number of channels of the incoming feature map.
    """

    def __init__(self, ht_channel=128):
        super().__init__()
        stages = [
            nn.Conv2d(ht_channel, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        ]
        self.convs = nn.Sequential(*stages)
        self.fc = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid output ``p_t`` with one value per sample."""
        features = self.convs(x)
        # Average over dimensions 2 and 3, leaving (N, 32).
        averaged = torch.mean(features, dim=(2, 3))
        averaged = averaged.view(averaged.size(0), -1)
        logits = self.fc(averaged)
        return self.sigmoid(logits)