""" dt_net_2d.py
    DeepThinking network 2D.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import torch
from torch import nn
import math
import random

from .blocks import BasicBlock2D as BasicBlock
from .blocks import Head
from .transformer import TransformerBlock

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), Bare except / no exception
#     type specified (W0702), Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class CnnTransformer(nn.Module):
    """CNN feature stem followed by a transformer over repeated feature tokens.

    An input batch is encoded by a small convolutional stem (``projection``),
    flattened into one feature vector per sample, repeated ``iters_to_do``
    times along dim 1 and passed to ``TransformerBlock``. Two linear heads map
    transformer features to ``num_class`` logits (``head``) and to a 4-way
    auxiliary output (``ssh_head``).
    """

    def __init__(
        self,
        block,
        num_blocks,
        width,
        in_channels=3,
        recall=True,
        group_norm=False,
        num_class=10,
        depth=4,
        heads=4,
        mlp_dim=512,
        **kwargs
    ):
        """Build the stem, transformer and heads.

        Args:
            block: block class used by ``_make_layer``; called as
                ``block(in_planes, planes, stride, group_norm=...)`` and must
                expose an ``expansion`` attribute.
            num_blocks: accepted for interface compatibility; not used (the
                stem always builds two blocks).
            width: channel width of the stem; cast to ``int``.
            in_channels: number of channels of the input passed to ``forward``.
            recall: stored as ``self.recall``; not otherwise used here.
            group_norm: forwarded to every ``block``.
            num_class: output size of ``head``.
            depth, heads, mlp_dim: forwarded to ``TransformerBlock``.
            **kwargs: accepted and ignored (e.g. ``use_attention``).
        """
        super().__init__()

        # Cast once and use the same value for every layer size, so a
        # non-int width never reaches nn.Conv2d / nn.Linear un-cast.
        width = int(width)
        self.recall = recall
        # Running "in_planes" counter: _make_layer reads and updates it.
        self.width = width
        self.group_norm = group_norm
        self.num_class = num_class
        proj_conv = nn.Conv2d(
            in_channels, width, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Two blocks, the first with stride 2 (original note: 16 x 16 x 32 output).
        conv1 = self._make_layer(block, width, num_blocks=2, stride=2)

        self.projection = nn.Sequential(proj_conv, nn.ReLU(inplace=False), conv1)

        # NOTE(review): the token size is hard-coded to width * 16 * 16, so the
        # projection must yield 16x16 feature maps (presumably 32x32 inputs
        # downsampled by the stride-2 block) -- TODO confirm / parameterise.
        self.recur_block = TransformerBlock(
            width * 16 * 16,
            width * 4,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
        )
        hidden = width * 4  # feature size the heads consume
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(inplace=False), nn.Linear(hidden, num_class)
        )

        self.ssh_head = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(inplace=False), nn.Linear(hidden, 4)
        )

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``.

        Updates ``self.width`` to ``planes * block.expansion`` after each block
        so the next block's input width matches.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    @staticmethod
    def _step_stats(states, iters_to_do):
        """Return per-step relative changes and norms of ``states``.

        For each ``i`` in ``range(iters_to_do - 1)`` compares
        ``states[:, i + 1, :]`` (h_prev) with ``states[:, i + 2, :]`` (h_t);
        norms are taken over the whole batch slice. Returns two lists of
        Python floats: ``||h_t - h_prev|| / (1e-5 + ||h_t||)`` and ``||h_t||``.
        """
        res = []
        norm = []
        # Diagnostics only: no autograd graph is needed for these scalars.
        with torch.no_grad():
            for i in range(iters_to_do - 1):
                h_prev = states[:, i + 1, :]
                h_t = states[:, i + 2, :]
                h_t_norm = h_t.norm().item()
                res.append((h_t - h_prev).norm().item() / (1e-5 + h_t_norm))
                norm.append(h_t_norm)
        return res, norm

    def forward(
        self,
        x,
        iters_to_do,
        interim_thought=None,
        debug=False,
        return_ssh=False,
        **kwargs
    ):
        """Encode ``x`` and run the transformer over ``iters_to_do`` tokens.

        Args:
            x: input batch passed to ``self.projection``.
            iters_to_do: number of copies of the flattened projection stacked
                along dim 1 before the transformer.
            interim_thought: accepted for interface compatibility; unused.
            debug: in eval mode, also return per-step convergence statistics.
            return_ssh: in eval mode, also return the auxiliary-head outputs.
            **kwargs: accepted and ignored.

        Returns:
            Training mode: ``(head(t), t, ssh_head(t))`` where ``t`` is the
            first output of ``recur_block``.
            Eval mode: ``all_outputs`` -- ``head`` applied to the second output
            of ``recur_block`` with position 0 dropped -- or
            ``(all_outputs, res, norm)`` when ``debug`` is set, or
            ``(all_outputs, all_ssh_outputs)`` when ``return_ssh`` is set.
        """
        xt = self.projection(x)
        # One token per sample: (batch, 1, features), repeated along dim 1.
        xt = torch.flatten(xt, start_dim=1).unsqueeze(1)
        tokens = xt.repeat(1, iters_to_do, 1)
        transformer_out, states = self.recur_block(tokens)

        if self.training:
            out = self.head(transformer_out)
            ssh_out = self.ssh_head(transformer_out)
            return out, transformer_out, ssh_out

        # Position 0 is dropped (presumably an initial state -- TODO confirm).
        all_outputs = self.head(states)[:, 1:, :]

        if debug:
            # Only pay for the .item() host syncs when the stats are returned.
            res, norm = self._step_stats(states, iters_to_do)
            return (all_outputs, res, norm)

        if return_ssh:
            all_ssh_outputs = self.ssh_head(states)[:, 1:, :]
            return all_outputs, all_ssh_outputs
        return all_outputs

def cnn_transformer(width, **kwargs):
    """Factory for a ``CnnTransformer`` whose stem uses ``BasicBlock`` (BasicBlock2D).

    Args:
        width: channel width of the convolutional stem.
        **kwargs: must contain ``"in_channels"`` and ``"num_class"`` (a
            ``KeyError`` is raised otherwise); all other keys are ignored.

    Returns:
        A ``CnnTransformer`` built with ``recall=False`` and
        ``group_norm=False``; transformer depth/heads/mlp_dim use the class
        defaults.
    """
    in_channels = kwargs["in_channels"]
    num_class = kwargs["num_class"]
    model = CnnTransformer(
        block=BasicBlock,
        num_blocks=[2],
        width=width,
        in_channels=in_channels,
        recall=False,
        group_norm=False,
        # Swallowed by CnnTransformer's **kwargs; kept for call-site parity.
        use_attention=False,
        num_class=num_class,
    )
    return model