""" dt_net_2d.py
    DeepThinking network 2D.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import torch
from torch import nn
import math
import random

from .blocks import BasicBlock2D as BasicBlock

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception type specified
#     (W0702), Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (sin on even channels, cos on odd).

    The table is precomputed once and stored in the ``pe`` buffer with shape
    ``(max_len, 1, d_model)``.

    Args:
        d_model: number of encoding channels. Odd values are supported; the
            last (even-indexed) channel then has no cosine partner.
        dropout: dropout probability applied in ``forward``.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # Only floor(d_model / 2) odd channels exist; trimming the last column
        # keeps an odd d_model from raising a shape mismatch (no-op when even).
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encodings of the first ``x.size(0)`` positions, then dropout.

        The buffer is ``(max_len, 1, d_model)``, so ``x`` is presumably
        ``(seq_len, batch, d_model)`` with positions on dim 0 — TODO confirm
        against callers.
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)

    def forward_timestamp_t(self, x, t):
        """Add the encoding of the single position ``t`` to ``x`` (no dropout).

        For an integer ``t`` the added term has shape ``(1, d_model, 1, 1)``,
        so it broadcasts over a channels-first ``x`` such as
        ``(batch, d_model, H, W)``.
        """
        x = x + self.pe[t].unsqueeze(-1).unsqueeze(-1)
        return x


class Head(nn.Module):
    """Classification head: two conv-BN-ReLU stages, global average pool, linear.

    Args:
        width: number of input channels.
        num_class: number of output logits.
    """

    def __init__(self, width, num_class):
        super().__init__()
        # Index positions inside ``convs`` define the state_dict keys
        # (``convs.0``, ``convs.1``, ``convs.3``, ``convs.4``); keep this order.
        self.convs = nn.Sequential(
            nn.Conv2d(width, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(32, num_class)

    def forward(self, x):
        """Map features ``(batch, width, H, W)`` to ``(batch, num_class)`` logits."""
        pooled = self.convs(x)  # (batch, 32, 1, 1) after the global pool
        return self.fc(pooled.view(pooled.size(0), -1))


class ResnetCifar(nn.Module):
    """DeepThinking-style 2D ResNet with one unshared stage per iteration.

    The input goes through ``projection`` (3x3 conv, ReLU, strided extract
    stages), then through ``feedforward_layers[0]``, ``[1]``, ... in sequence.
    After each stage both ``head`` (task output) and ``ssh_head`` (4 logits)
    are applied to the current hidden state.

    Args:
        block: residual block class. It is called as
            ``block(in_planes, planes, stride, group_norm=..., batch_norm=...)``
            and must expose an ``expansion`` attribute.
        num_blocks: number of blocks in each sub-layer of one stage; every
            stage is an ``nn.Sequential`` of ``len(num_blocks)`` sub-layers.
        width: base channel count.
        in_channels: number of input channels.
        recall: stored as ``self.recall``; not read by this class.
        group_norm, batch_norm: forwarded to every ``block``.
        use_attention, gating: accepted for interface compatibility; unused.
        num_class: if non-zero, ``head`` is a ``Head`` with ``num_class``
            logits; if 0, ``head`` is a 3-conv dense head with 2 output channels.
        max_iters: number of stages built.
        imagenet: three strided extract stages (planes width, 2*width, 4*width).
        tiny_imagenet: two strided extract stages (planes 2*width, 4*width);
            ignored when ``imagenet`` is true.
        **kwargs: ignored.
    """

    def __init__(
        self,
        block,
        num_blocks,
        width,
        in_channels=3,
        recall=True,
        group_norm=False,
        use_attention=False,
        num_class=0,
        gating=False,
        max_iters=8,
        imagenet=False,
        batch_norm=False,
        tiny_imagenet=False,
        **kwargs
    ):
        super().__init__()

        self.recall = recall
        # Running count of input planes; ``_make_layer`` advances it.
        self.width = int(width)
        self.group_norm = group_norm
        self.batch_norm = batch_norm
        self.num_class = num_class
        proj_conv = nn.Conv2d(
            in_channels, width, kernel_size=3, stride=1, padding=1, bias=False
        )

        # Strided feature-extraction stages. Only the selected variant is
        # built, so no discarded layer advances ``self.width`` or consumes
        # parameter-initialization randomness.
        if imagenet:
            extract_layer = [
                self._make_layer(block, width, num_blocks=2, stride=2),
                self._make_layer(block, width * 2, num_blocks=2, stride=2),
                self._make_layer(block, width * 4, num_blocks=2, stride=2),
            ]
            width = width * 4
        elif tiny_imagenet:
            extract_layer = [
                self._make_layer(block, width * 2, num_blocks=2, stride=2),
                self._make_layer(block, width * 4, num_blocks=2, stride=2),
            ]
            width = width * 4
        else:
            extract_layer = [self._make_layer(block, width, num_blocks=2, stride=2)]

        # One independent (weight-unshared) stage per iteration.
        self.feedforward_layers = nn.ModuleList()
        for _ in range(max_iters):
            stage = [
                self._make_layer(block, width, n_blocks, stride=1)
                for n_blocks in num_blocks
            ]
            self.feedforward_layers.append(nn.Sequential(*stage))

        self.projection = nn.Sequential(proj_conv, nn.ReLU(inplace=False), *extract_layer)

        if self.num_class:
            self.head = Head(width, num_class)
        else:
            head_conv1 = nn.Conv2d(
                width, 32, kernel_size=3, stride=1, padding=1, bias=False
            )
            head_conv2 = nn.Conv2d(
                32, 8, kernel_size=3, stride=1, padding=1, bias=False
            )
            head_conv3 = nn.Conv2d(8, 2, kernel_size=3, stride=1, padding=1, bias=False)
            self.head = nn.Sequential(
                head_conv1,
                nn.ReLU(inplace=False),
                head_conv2,
                nn.ReLU(inplace=False),
                head_conv3,
            )
        self.ssh_head = Head(width, 4)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: sets ``self.width`` to ``planes * block.expansion`` after
        each block, so consecutive calls chain their channel counts.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, batch_norm=self.batch_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def _new_output_buffer(self, x, iters_to_do, sample_out=None):
        """Allocate a zero tensor ``(batch, iters_to_do, *per_iter)`` on ``x``'s device.

        ``per_iter`` is taken from ``sample_out`` (one head output) when given,
        so the dense head's spatial size follows whatever downsampling the
        projection applied; otherwise it falls back to ``(num_class,)`` or
        ``(2, H, W)`` of the input.
        """
        if sample_out is not None:
            per_iter = tuple(sample_out.shape[1:])
        elif self.num_class:
            per_iter = (self.num_class,)
        else:
            per_iter = (2, x.size(2), x.size(3))
        return torch.zeros((x.size(0), iters_to_do, *per_iter), device=x.device)

    def forward(self, x, iters_to_do, interim_thought=None, debug=False, return_ssh=False, **kwargs):
        """Run ``iters_to_do`` iterations, applying both heads after each one.

        Args:
            x: input batch; ``x.size(0)`` is the batch size.
            iters_to_do: number of iterations. Iterations beyond
                ``len(self.feedforward_layers)`` reuse the last hidden state.
            interim_thought: optional hidden state to start from instead of
                the projection output (the projection is still evaluated).
            debug: in eval mode, return ``(all_outputs, res, norm)``;
                ``res`` and ``norm`` are always empty lists here.
            return_ssh: in eval mode, also return the ssh-head outputs.
            **kwargs: ignored.

        Returns:
            Training mode: ``(out, h_t, ssh_out)`` of the last iteration.
            Eval mode: ``all_outputs`` (task-head output of every iteration,
            stacked on dim 1), or the tuples described by ``debug`` /
            ``return_ssh``. ``all_ssh_outputs`` is ``(batch, iters_to_do, 4)``.
        """
        xt = self.projection(x)
        h_prev = xt if interim_thought is None else interim_thought

        # Sized lazily from the first head output: the dense head's spatial
        # size depends on the projection's stride-2 stages, not on x.
        all_outputs = None
        all_ssh_outputs = torch.zeros((x.size(0), iters_to_do, 4), device=x.device)
        res = []  # returned (empty) by the debug path
        norm = []

        for i in range(iters_to_do):
            if i < len(self.feedforward_layers):
                h_t = self.feedforward_layers[i](h_prev)
            out = self.head(h_t)
            ssh_out = self.ssh_head(h_t)
            if all_outputs is None:
                all_outputs = self._new_output_buffer(x, iters_to_do, out)
            all_outputs[:, i] = out
            all_ssh_outputs[:, i] = ssh_out
            h_prev = h_t

        if all_outputs is None:
            # iters_to_do == 0: no head output to size the buffer from.
            all_outputs = self._new_output_buffer(x, iters_to_do)

        if self.training:
            return out, h_t, ssh_out

        if debug:
            return (all_outputs, res, norm)

        if return_ssh:
            return all_outputs, all_ssh_outputs
        return all_outputs

def resnet_cifar(width, **kwargs):
    """Build the CIFAR variant of ``ResnetCifar`` (single strided extract stage).

    Required keyword arguments: ``in_channels``, ``num_class``, ``max_iters``
    (a missing one raises ``KeyError``). Other keyword arguments are ignored.
    """
    in_channels = kwargs["in_channels"]
    num_class = kwargs["num_class"]
    max_iters = kwargs["max_iters"]
    return ResnetCifar(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        num_class=num_class,
        max_iters=max_iters,
        recall=False,
        group_norm=False,
        use_attention=False,
    )

def resnet_imagenet(width, **kwargs):
    """Build the ImageNet variant of ``ResnetCifar`` with batch norm enabled.

    Required keyword arguments: ``in_channels`` and ``num_class``. The number
    of iterations is fixed at 4; any ``max_iters`` keyword is ignored.
    """
    in_channels = kwargs["in_channels"]
    num_class = kwargs["num_class"]
    return ResnetCifar(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        num_class=num_class,
        max_iters=4,
        recall=False,
        group_norm=False,
        use_attention=False,
        batch_norm=True,
        imagenet=True,
    )

def resnet_tiny_imagenet(width, **kwargs):
    """Build the Tiny-ImageNet variant of ``ResnetCifar`` with batch norm enabled.

    Required keyword arguments: ``in_channels``, ``num_class``, ``max_iters``.
    Other keyword arguments are ignored.
    """
    in_channels = kwargs["in_channels"]
    num_class = kwargs["num_class"]
    max_iters = kwargs["max_iters"]
    return ResnetCifar(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        num_class=num_class,
        max_iters=max_iters,
        recall=False,
        group_norm=False,
        use_attention=False,
        batch_norm=True,
        imagenet=False,
        tiny_imagenet=True,
    )
