""" dt_net_2d.py
    DeepThinking network 2D.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import torch
from torch import nn

from .blocks import BasicBlock2D as BasicBlock

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class DTNet(nn.Module):
    """DeepThinking Network 2D model class.

    The input is projected to ``width`` channels and one shared recurrent
    block is then applied ``iters_to_do`` times. After every iteration the
    fresh features are folded into a running attention-weighted average of
    all features produced so far; the weights come from the gated attention
    layer applied to the spatially mean-pooled features. The head maps that
    running average to a 2-channel output.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, **kwargs):
        """Build the projection, recurrent block, head and attention gate.

        Args:
            block: Residual block class. It is called as
                ``block(in_planes, planes, stride, group_norm=...)`` and must
                expose an ``expansion`` attribute.
            num_blocks: Iterable of ints; one stage with that many blocks is
                appended to the recurrent block per entry.
            width: Number of feature channels. Anything ``int()`` accepts.
            in_channels: Number of channels of the network input.
            recall: When true, the input is concatenated to the features
                before every pass through the recurrent block.
            group_norm: Forwarded unchanged to every ``block``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # Cast once and use the int everywhere: previously only ``self.width``
        # was cast, so a float/str ``width`` still reached the layer
        # constructors below and made them fail.
        width = int(width)

        self.recall = recall
        self.width = width
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=False)

        # Built even when ``recall`` is false, exactly as before, so that the
        # random initialisation of the layers created afterwards does not
        # depend on ``recall`` for a given seed.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU(inplace=False))
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(inplace=False),
                                  head_conv2, nn.ReLU(inplace=False),
                                  head_conv3)
        # NOTE: the attribute name is misspelled ("attetion") but it is part
        # of the state_dict keys, so it is kept to stay checkpoint-compatible.
        self.gated_attetion = GatedAttentionLayer(width, width // 2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into one stage.

        Only the first block receives ``stride``; the remaining ones use a
        stride of 1. ``self.width`` is updated after every block to
        ``planes * block.expansion`` and is used as the input width of the
        next block.

        Returns:
            An ``nn.Sequential`` holding the created blocks.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, debug=False, is_hard_net=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: 4-D input tensor; dimension 0 is the batch and dimensions 2 and
                3 are the spatial sizes used for the output buffer.
            iters_to_do: Number of passes through the recurrent block.
            interim_thought: Optional features to start from. When ``None``
                the projection of ``x`` is used.
            debug: When true (and the training shortcut below does not apply)
                the attention weights and derived probabilities are returned
                as well.
            is_hard_net: When true, the training shortcut below is disabled
                and the per-iteration outputs are returned in training too.
            **kwargs: Accepted and ignored.

        Returns:
            * ``(out, interim_thought)`` when the module is in training mode
              and ``is_hard_net`` is false: the head output of the last
              iteration and the final running average of the features.
            * ``(all_outputs, attention_weights, act_probs)`` otherwise when
              ``debug`` is true. ``attention_weights`` and ``act_probs`` are
              only filled in while the module is in eval mode.
            * ``all_outputs`` otherwise, of shape
              ``(batch, iters_to_do, 2, H, W)``.

        Raises:
            ValueError: In training mode with ``is_hard_net`` false when
                ``iters_to_do`` is smaller than 1, because there would be no
                last-iteration output to return.
        """
        if self.training and not is_hard_net and iters_to_do < 1:
            # Previously this surfaced as an UnboundLocalError on ``out``.
            raise ValueError(
                f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        batch_size = x.size(0)

        # The projection is only needed when no starting point is supplied.
        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate directly on the input's device instead of CPU + copy.
        all_outputs = torch.zeros((batch_size, iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)
        attention_weights = []
        act_probs = torch.zeros((batch_size, iters_to_do), device=x.device)
        att_weight_accumulate = torch.zeros((batch_size, 1), device=x.device)

        weighted_sum = None
        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            features = self.recur_block(interim_thought)

            # Spatial mean pooling, laid out as (batch, 1, channels) for the
            # attention layer, which returns one weight per sample.
            pooled = features.mean(dim=(2, 3)).unsqueeze(1)
            att_weight = self.gated_attetion(pooled)
            att_weight_accumulate = att_weight_accumulate + att_weight

            # Per-sample scalar scaling; equivalent to the former
            # einsum('bijkl,bl->bijk') over a singleton last dimension.
            scaled = features * att_weight.view(batch_size, 1, 1, 1)
            weighted_sum = scaled if weighted_sum is None else weighted_sum + scaled

            # Running weighted average of every iteration's features so far;
            # this is what is fed back into the next iteration.
            interim_thought = weighted_sum / att_weight_accumulate.view(batch_size, 1, 1, 1)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

            if not self.training:
                attention_weights.append(att_weight)
                # sigmoid(log(w)) == w / (1 + w)
                act_prob = torch.sigmoid(torch.log(att_weight))
                act_probs[:, i] = act_prob.squeeze(-1)

        if self.training and not is_hard_net:
            return out, interim_thought

        if debug:
            return (all_outputs, attention_weights, act_probs)

        return all_outputs

class GatedAttentionLayer(nn.Module):
    """Gated attention scorer.

    For an input ``X`` whose last dimension is ``d_model`` this computes
    ``exp(w(tanh(V(X)) * sigmoid(U(X))))`` and drops the trailing singleton
    dimension, yielding one strictly positive, unnormalised weight per
    input vector.
    """

    def __init__(self, d_model, hidden_state):
        """Create the three linear maps.

        Args:
            d_model: Size of the last dimension of the input.
            hidden_state: Size of the hidden gated representation.
        """
        super().__init__()
        self.V = nn.Linear(d_model, hidden_state)  # tanh ("content") branch
        self.U = nn.Linear(d_model, hidden_state)  # sigmoid ("gate") branch
        self.w = nn.Linear(hidden_state, 1)        # scalar score

    def forward(self, X):
        """Return ``exp`` of the gated score with the last dimension squeezed."""
        content = torch.tanh(self.V(X))
        gate = torch.sigmoid(self.U(X))
        score = self.w(content * gate)
        return torch.exp(score).squeeze(-1)
        
        


def dt_net_2d(width, **kwargs):
    """Build a ``DTNet`` of ``BasicBlock`` units without recall or group norm.

    Args:
        width: Channel width forwarded to ``DTNet``.
        **kwargs: Must contain ``"in_channels"``; all other entries are ignored.

    Returns:
        A ``DTNet`` whose recurrent block has a single stage of two blocks.

    Raises:
        KeyError: If ``"in_channels"`` is missing from ``kwargs``.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(BasicBlock, [2], width=width, in_channels=in_channels, recall=False)


def dt_net_recall_2d(width, **kwargs):
    """Build a ``DTNet`` of ``BasicBlock`` units with recall and no group norm.

    Args:
        width: Channel width forwarded to ``DTNet``.
        **kwargs: Must contain ``"in_channels"``; all other entries are ignored.

    Returns:
        A ``DTNet`` whose recurrent block has a single stage of two blocks.

    Raises:
        KeyError: If ``"in_channels"`` is missing from ``kwargs``.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(BasicBlock, [2], width=width, in_channels=in_channels, recall=True)


def dt_net_gn_2d(width, **kwargs):
    """Build a ``DTNet`` of ``BasicBlock`` units with group norm and no recall.

    Args:
        width: Channel width forwarded to ``DTNet``.
        **kwargs: Must contain ``"in_channels"``; all other entries are ignored.

    Returns:
        A ``DTNet`` whose recurrent block has a single stage of two blocks.

    Raises:
        KeyError: If ``"in_channels"`` is missing from ``kwargs``.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        recall=False,
        group_norm=True,
    )


def dt_net_recall_gn_2d(width, **kwargs):
    """Build a ``DTNet`` of ``BasicBlock`` units with both recall and group norm.

    Args:
        width: Channel width forwarded to ``DTNet``.
        **kwargs: Must contain ``"in_channels"``; all other entries are ignored.

    Returns:
        A ``DTNet`` whose recurrent block has a single stage of two blocks.

    Raises:
        KeyError: If ``"in_channels"`` is missing from ``kwargs``.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        group_norm=True,
    )
