""" dt_net_2d.py
    DeepThinking network 2D.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import torch
from torch import nn
import math
import random

from .blocks import BasicBlock2D as BasicBlock

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding stored as a non-trainable buffer.

    The buffer ``pe`` has shape ``(max_len, 1, d_model)``; even feature
    indices hold sines and odd feature indices hold cosines of the position
    scaled by geometrically decreasing frequencies.

    Args:
        d_model: number of features of the encoding (even or odd).
        dropout: dropout probability applied in :meth:`forward`.
        max_len: number of positions that are precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd index than even index, so the
        # cosine half only uses the first d_model // 2 frequencies.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the first ``x.size(0)`` encodings to ``x`` and apply dropout.

        ``x`` must broadcast against a ``(x.size(0), 1, d_model)`` slice of
        the buffer.
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)

    def forward_timestamp_t(self, x, t):
        """Add the encoding of position ``t`` to ``x`` (no dropout).

        The selected encoding gets two trailing singleton dimensions before
        the addition, so for an integer ``t`` it broadcasts as
        ``(1, d_model, 1, 1)``.
        """
        x = x + self.pe[t].unsqueeze(-1).unsqueeze(-1)
        return x


class Head(nn.Module):
    """Classification head: two conv/BN/ReLU stages, global pooling, linear.

    Args:
        width: number of channels of the incoming feature map.
        num_class: number of output logits per sample.
    """

    def __init__(self, width, num_class):
        super().__init__()
        # width -> 64 -> 32 channels, then collapse the spatial dimensions.
        self.convs = nn.Sequential(
            nn.Conv2d(width, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(32, num_class)

    def forward(self, x):
        """Return one row of ``num_class`` logits per sample in ``x``."""
        pooled = self.convs(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

class HaltConv(nn.Module):
    """Halting unit: maps a feature map to one probability per sample.

    Two conv/BN/ReLU stages reduce ``ht_channel`` channels to 32, a global
    max pool collapses the spatial dimensions, and a linear layer followed by
    a sigmoid yields a value in ``[0, 1]``.

    Args:
        ht_channel: number of channels of the incoming feature map.
    """

    def __init__(self, ht_channel=128):
        super().__init__()
        halt_conv1 = nn.Conv2d(ht_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        bn1 = nn.BatchNorm2d(64)
        halt_conv2 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False)
        bn2 = nn.BatchNorm2d(32)
        # Adaptive pooling always produces a 1x1 map, so the linear layer
        # receives exactly 32 features for any spatial size. For 32x32
        # inputs this equals a max pool with kernel size 32.
        global_max_pool = nn.AdaptiveMaxPool2d(1)

        self.convs = nn.Sequential(
            halt_conv1, bn1, nn.ReLU(), halt_conv2, bn2, nn.ReLU(), global_max_pool
        )
        self.fc = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of halting probabilities."""
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        p_t = self.sigmoid(x)
        return p_t
    
    

class DTNet(nn.Module):
    """DeepThinking Network 2D model class.

    The network projects the input to ``width`` channels, applies the same
    recurrent block ``iters_to_do`` times and decodes every intermediate
    state with an output head.

    Args:
        block: residual block class; called as
            ``block(in_planes, planes, stride, group_norm=...)`` and expected
            to expose an ``expansion`` attribute.
        num_blocks: sequence with the number of blocks of each recurrent layer.
        width: number of feature channels (converted with ``int``).
        in_channels: number of channels of the input.
        recall: if true, the input is concatenated to the features before
            every application of the recurrent block.
        group_norm: forwarded to ``block``.
        use_act: if true, a ``HaltConv`` unit produces per-sample halting
            probabilities that gate the recurrent update.
        num_class: number of output classes; values ``<= 0`` select 2.
        img_classifation: if true, ``Head`` modules (one logit vector per
            sample) replace the convolutional per-pixel head.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        block,
        num_blocks,
        width,
        in_channels=3,
        recall=True,
        group_norm=False,
        use_act=False,
        num_class=0,
        img_classifation=False,
        **kwargs
    ):
        super().__init__()

        # Use the integer width everywhere so a float-valued config entry
        # does not break layer construction.
        width = int(width)
        num_class = num_class if num_class > 0 else 2

        self.recall = recall
        self.width = width
        self.group_norm = group_norm
        self.num_class = num_class

        proj_conv = nn.Conv2d(
            in_channels, width, kernel_size=3, stride=1, padding=1, bias=False
        )

        recur_layers = []
        if recall:
            # Maps the concatenation [features, input] back to `width` channels.
            conv_recall = nn.Conv2d(
                width + in_channels,
                width,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )
            recur_layers.append(conv_recall)

        for depth in num_blocks:
            recur_layers.append(self._make_layer(block, width, depth, stride=1))

        self.projection = nn.Sequential(proj_conv, nn.ReLU(inplace=False))
        self.recur_block = nn.Sequential(*recur_layers)

        # The convolutional head is built even when it is replaced below so
        # that seeded parameter initialisation stays the same in both modes.
        self.head = nn.Sequential(
            nn.Conv2d(width, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=False),
            nn.Conv2d(32, 8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=False),
            nn.Conv2d(8, num_class, kernel_size=3, stride=1, padding=1, bias=False),
        )
        if img_classifation:
            self.head = Head(width, num_class)
            self.ssh_head = Head(width, 4)
        else:
            # Same module registered under a second name.
            self.ssh_head = self.head

        self.use_act = use_act
        if use_act:
            # The halting unit reads the recurrent features, so it must match
            # their channel count rather than a fixed 128.
            self.halt_conv = HaltConv(ht_channel=self.width)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.width`` to ``planes * block.expansion`` after every
        block so the next block receives the right input channel count.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def _stack_outputs(self, outputs, x):
        """Stack per-iteration outputs along a new dimension 1."""
        if outputs:
            return torch.stack(outputs, dim=1)
        # No iteration was run: keep the historical empty per-pixel shape.
        return torch.zeros(
            (x.size(0), 0, self.num_class, x.size(2), x.size(3)), device=x.device
        )

    def forward(self, x, iters_to_do, interim_thought=None, debug=False, return_ssh=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch of shape ``(batch, in_channels, H, W)``.
            iters_to_do: number of applications of the recurrent block.
            interim_thought: optional features to start from instead of the
                projection of ``x``.
            debug: in eval mode, also return per-iteration statistics.
            return_ssh: in eval mode, also return the ``ssh_head`` outputs.
            **kwargs: accepted and ignored.

        Returns:
            Training mode: ``(out, h_t, ssh_out)`` for the last iteration,
            plus ``sum_pt`` (accumulated halting probabilities) when
            ``use_act`` is set.
            Eval mode: the head outputs of all iterations stacked along
            dimension 1; ``(all_outputs, res, norm)`` if ``debug`` where
            ``res`` holds relative feature changes and ``norm`` feature
            norms; ``(all_outputs, all_ssh_outputs)`` if ``return_ssh``.
            Stacked outputs keep the dtype produced by the head.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``, because
                there is no last-iteration output to return.
        """
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be at least 1 in training mode")

        # The projection is only needed when no starting features are given.
        h_t = self.projection(x) if interim_thought is None else interim_thought

        batch = x.size(0)
        collect = not self.training
        outputs = []
        ssh_outputs = []
        res = []
        norm = []

        if self.use_act:
            pt_accumulate = torch.zeros((batch, 1), device=x.device)
            p_t = torch.zeros((batch, 1), device=x.device)
            sum_pt = torch.zeros((batch, 1), device=x.device)
            epsilon = 1e-5

        for i in range(iters_to_do):
            h_t_old = h_t
            if self.recall:
                h_t = torch.cat([h_t, x], 1)
            h_t = self.recur_block(h_t)

            if self.use_act:
                if i > 0:
                    p_t = self.halt_conv(h_t_old)
                    pt_accumulate = pt_accumulate + p_t
                    gate = p_t.view(batch, 1, 1, 1)
                    # `mask` still holds the previous iteration's value here:
                    # samples whose accumulated probability had not reached 1.
                    active = mask.view(batch, 1, 1, 1)
                    h_t = h_t + gate * h_t * active
                mask = pt_accumulate < 1 - epsilon
                sum_pt = sum_pt + p_t * mask

            if debug:
                # `.item()` synchronises with the device, so the statistics
                # are only gathered when they are requested.
                h_norm = h_t.norm().item()
                res.append((h_t - h_t_old).norm().item() / (1e-5 + h_norm))
                norm.append(h_norm)

            out = self.head(h_t)
            ssh_out = self.ssh_head(h_t)
            if collect:
                outputs.append(out)
                if return_ssh:
                    ssh_outputs.append(ssh_out)

        if self.training:
            if self.use_act:
                return out, h_t, ssh_out, sum_pt
            return out, h_t, ssh_out

        all_outputs = self._stack_outputs(outputs, x)

        if debug:
            return (all_outputs, res, norm)

        if return_ssh:
            return all_outputs, self._stack_outputs(ssh_outputs, x)
        return all_outputs


def dt_net_2d(width, **kwargs):
    """Build a DTNet with one recurrent layer of two BasicBlocks, no recall.

    Requires ``in_channels`` and ``num_class`` in ``kwargs``; other keyword
    arguments are ignored.
    """
    options = {
        "width": width,
        "in_channels": kwargs["in_channels"],
        "recall": False,
        "num_class": kwargs["num_class"],
    }
    return DTNet(BasicBlock, [2], **options)


def dt_net_recall_2d(width, **kwargs):
    """Build a DTNet with one recurrent layer of two BasicBlocks and recall.

    Requires ``in_channels`` and ``num_class`` in ``kwargs``; other keyword
    arguments are ignored.
    """
    options = {
        "width": width,
        "in_channels": kwargs["in_channels"],
        "recall": True,
        "num_class": kwargs["num_class"],
    }
    return DTNet(BasicBlock, [2], **options)

def dt_net_act_2d(width, **kwargs):
    """Build a recall DTNet with the halting (ACT) unit enabled.

    Uses one recurrent layer of two BasicBlocks. Requires ``in_channels``
    and ``num_class`` in ``kwargs``; other keyword arguments are ignored.
    """
    # The former debug print was removed so that constructing a model no
    # longer writes to stdout.
    return DTNet(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        use_act=True,
        num_class=kwargs["num_class"],
    )


def dt_net_attention_2d(width, **kwargs):
    """Build a DTNet with one recurrent layer of two BasicBlocks, no recall.

    No attention-specific option is passed to DTNet here. Requires
    ``in_channels`` and ``num_class`` in ``kwargs``; other keyword arguments
    are ignored.
    """
    options = {
        "width": width,
        "in_channels": kwargs["in_channels"],
        "recall": False,
        "num_class": kwargs["num_class"],
    }
    return DTNet(BasicBlock, [2], **options)
    
def dt_net_attention_gn_2d(width, **kwargs):
    """Build a no-recall DTNet of two BasicBlocks with group norm enabled.

    No attention-specific option is passed to DTNet here. Requires
    ``in_channels`` and ``num_class`` in ``kwargs``; other keyword arguments
    are ignored.
    """
    options = {
        "width": width,
        "in_channels": kwargs["in_channels"],
        "recall": False,
        "group_norm": True,
        "num_class": kwargs["num_class"],
    }
    return DTNet(BasicBlock, [2], **options)

def dt_net_gn_2d(width, **kwargs):
    """Build a no-recall DTNet of two BasicBlocks with group norm enabled.

    Requires ``in_channels`` in ``kwargs``. ``num_class`` is forwarded when
    given, like the other factories in this module; when absent, 0 is passed,
    which DTNet treats as its default of two classes.
    """
    return DTNet(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=False,
        group_norm=True,
        num_class=kwargs.get("num_class", 0),
    )


def dt_net_recall_gn_2d(width, **kwargs):
    """Build a recall DTNet of two BasicBlocks with group norm enabled.

    Requires ``in_channels`` in ``kwargs``. ``num_class`` is forwarded when
    given, like the other factories in this module; when absent, 0 is passed,
    which DTNet treats as its default of two classes.
    """
    return DTNet(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        group_norm=True,
        num_class=kwargs.get("num_class", 0),
    )
