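""" dt_net_2d.py
    DeepThinking network 2D.

    Defines CnnGRU, a recurrent network that repeatedly applies a
    convolutional GRU-style cell to features extracted from the input,
    together with the recurrent cells it uses and factory functions for
    CIFAR, Tiny-ImageNet and ImageNet configurations.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""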

import torch
from torch import nn
from torch.nn import functional as F
import math
import random

from .blocks import BasicBlock2D as BasicBlock
from .blocks import Head, PositionalEncoding, HaltConv

# Pylint checks disabled for this file:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), Bare except (W0702),
#     Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class CnnGRU(nn.Module):
    """Recurrent model that repeatedly applies a convolutional cell to
    features extracted from the input.

    The input goes through a 3x3 projection conv, a ReLU and optional
    residual stages (chosen by ``imagenet`` / ``tiny_imagenet`` / ``cifar``,
    checked in that order). ``recur_block`` is then applied for a number of
    steps with the projected features as its input, and ``head`` (plus
    ``ssh_head``) decode every step.

    Args:
        block: Residual block class used by ``_make_layer``. It is called as
            ``block(in_planes, planes, stride, group_norm=..., batch_norm=...)``
            and must expose an ``expansion`` attribute.
        gru_layer: Recurrent cell class. It is constructed as
            ``gru_layer(width)`` and called as ``cell(h_prev, xt)``.
        num_blocks: Not used by this class.
        width: Channel count produced by the projection conv.
        in_channels: Channel count of the input.
        recall: Stored on ``self.recall``; not read by this class.
        group_norm: Forwarded to every residual block built by ``_make_layer``.
        num_class: If > 0, ``head`` is ``Head(width, num_class)``. Otherwise a
            three-conv head producing 2 channels per position is used.
        use_act: Enable the adaptive-computation-time (halting) branch.
        batch_norm: Forwarded to every residual block built by ``_make_layer``.
        imagenet: Use three stride-2 stages (width, 2*width, 4*width).
        tiny_imagenet: Use two stride-2 stages (2*width, 4*width).
        cifar: Use one stride-2 stage (width). With all three flags False, no
            extra stage is added.
        pos_enc: Use ``PositionalEncoding`` inside the recurrence.
        ssl: Build a 4-output auxiliary ``ssh_head``.
        **kwargs: Accepted and ignored (e.g. ``use_attention``).
    """

    def __init__(
        self,
        block,
        gru_layer,
        num_blocks,
        width,
        in_channels=3,
        recall=True,
        group_norm=False,
        num_class=0,
        use_act=False,
        batch_norm=False,
        imagenet=False,
        tiny_imagenet=False,
        cifar=True,
        pos_enc=False,
        ssl=True,
        **kwargs
    ):
        super().__init__()

        self.recall = recall
        # Running input-channel count; _make_layer reads and advances it.
        self.width = int(width)
        self.group_norm = group_norm
        self.batch_norm = batch_norm
        self.num_class = num_class

        proj_conv = nn.Conv2d(
            in_channels, width, kernel_size=3, stride=1, padding=1, bias=False
        )

        # Feature-extractor stages. The local `width` ends up equal to the
        # channel count of the last stage, which the cell and heads consume.
        if imagenet:
            extract_layer = []
            extract_layer.append(self._make_layer(block, width, num_blocks=2, stride=2))
            extract_layer.append(self._make_layer(block, width * 2, num_blocks=2, stride=2))
            extract_layer.append(self._make_layer(block, width * 4, num_blocks=2, stride=2))
            width = width * 4
        elif tiny_imagenet:
            extract_layer = []
            width *= 2
            extract_layer.append(self._make_layer(block, width, num_blocks=2, stride=2))
            width *= 2
            extract_layer.append(self._make_layer(block, width, num_blocks=2, stride=2))
        elif cifar:
            extract_layer = [self._make_layer(block, width, num_blocks=2, stride=2)]
        else:
            extract_layer = [nn.Sequential()]

        self.projection = nn.Sequential(proj_conv, nn.ReLU(inplace=False), *extract_layer)

        self.recur_block = gru_layer(width)
        if num_class > 0:
            self.head = Head(width, num_class, batch_norm=True)
        else:
            # Stride-1, padding-1 convs: spatial size is preserved, 2 channels out.
            head_conv1 = nn.Conv2d(
                width, 32, kernel_size=3, stride=1, padding=1, bias=False
            )
            head_conv2 = nn.Conv2d(
                32, 8, kernel_size=3, stride=1, padding=1, bias=False
            )
            head_conv3 = nn.Conv2d(8, 2, kernel_size=3, stride=1, padding=1, bias=False)
            self.head = nn.Sequential(
                head_conv1,
                nn.ReLU(inplace=False),
                head_conv2,
                nn.ReLU(inplace=False),
                head_conv3,
            )
        if ssl:
            self.ssh_head = Head(width, 4, batch_norm=True)
        else:
            self.ssh_head = nn.Sequential()
        self.ssl = ssl

        self.use_act = use_act
        if use_act:
            self.halt_conv = HaltConv(ht_channel=width)
        self.pos_enc = PositionalEncoding(width) if pos_enc else nn.Sequential()
        self.use_pos_enc = pos_enc

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks. Only the first uses ``stride``; the rest use stride 1.

        Input channels are read from ``self.width``, which is set to
        ``planes * block.expansion`` after each block so that consecutive
        calls chain correctly.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, batch_norm=self.batch_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, debug=False, return_ssh=False, **kwargs):
        """Unroll the recurrent cell for ``iters_to_do`` steps.

        Args:
            x: 4-D ``(batch, channels, H, W)`` input, as required by the
                projection's ``Conv2d``.
            iters_to_do: Number of recurrent steps.
            interim_thought: Optional initial hidden state. When ``None``, the
                state starts as zeros shaped like the projected features.
            debug: In eval mode, also return per-step diagnostics.
            return_ssh: In eval mode, also return the auxiliary-head outputs.
            **kwargs: Accepted for caller compatibility and ignored. This
                includes ``return_loss`` / ``gr_truth``: no loss is computed
                or returned.

        Returns:
            Training mode: ``(out, h_t, ssh_out)`` from the last step, with
            the accumulated halting mass ``sum_pt`` appended when ``use_act``.

            Eval mode, checked in this order:

            * ``debug``: ``(all_outputs, residuals, norms)``, where
              ``residuals[i]`` is ``||h_t - h_prev||`` and ``norms[i]`` is
              ``||h_t||`` (Python floats).
            * ``return_ssh``: ``(all_outputs, all_ssh_outputs)``.
            * otherwise: ``all_outputs``.

            ``all_outputs`` is ``(batch, iters, num_class)`` when
            ``num_class`` is set, else ``(batch, iters, 2, H, W)``.
            ``all_ssh_outputs`` is ``(batch, iters, 4)``.
        """
        xt = self.projection(x)
        batch_size = xt.shape[0]
        device = x.device

        h_prev = torch.zeros_like(xt) if interim_thought is None else interim_thought

        # Per-step output buffers.
        if self.num_class:
            all_outputs = torch.zeros((batch_size, iters_to_do, self.num_class), device=device)
        else:
            # NOTE(review): sized with the input's H, W, while the head keeps
            # the projected features' spatial size. This presumably assumes an
            # extractor that does not downsample (e.g. cifar=False); confirm.
            all_outputs = torch.zeros(
                (batch_size, iters_to_do, 2, x.size(2), x.size(3)), device=device
            )
        # ssh_head is Head(width, 4), i.e. 4 values per sample per step.
        all_ssh_outputs = torch.zeros((batch_size, iters_to_do, 4), device=device)

        # Diagnostics are only materialised for the debug return: each
        # .item() forces a host/device synchronisation.
        residuals = []
        norms = []

        # Adaptive-computation-time (halting) state.
        epsilon = 0.01
        pt_accumulate = torch.zeros((batch_size, 1), device=device)
        p_t = torch.zeros((batch_size, 1), device=device)
        if self.use_act:
            h_act = torch.zeros_like(h_prev, device=device)

        # Eval-only early stopping on the relative change of each sample's
        # state. A threshold of 0 disables it.
        threshold = 0.0
        early_stop = not self.training and threshold > 0
        active_mask = torch.ones(batch_size, device=device)

        for i in range(iters_to_do):
            if self.use_pos_enc:
                # NOTE(review): this replaces the carried hidden state with an
                # encoding of xt on every step; confirm that is intended.
                h_prev = self.pos_enc(xt, i)
            h_t = self.recur_block(h_prev, xt)

            if self.use_act:
                if i == 0:
                    sum_pt = p_t
                else:
                    # `mask` still holds the previous step's "not halted" flags.
                    p_t = self.halt_conv(h_prev)
                    pt_accumulate = pt_accumulate + p_t
                    h_act = (
                        h_act
                        + p_t.view(batch_size, 1, 1, 1) * h_t * mask.view(batch_size, 1, 1, 1)
                    )
                mask = pt_accumulate < 1 - epsilon
                sum_pt = sum_pt + p_t * mask

            if early_stop:
                # Relative L2 change between consecutive states, per sample.
                norm_change = torch.norm(h_t - h_prev, p=2, dim=[1, 2, 3]) / torch.norm(
                    h_t, p=2, dim=[1, 2, 3]
                )
                stop_mask = norm_change < threshold
                # Once a sample stops, it stays stopped.
                active_mask = active_mask * (1 - stop_mask.float())
                expand_mask = active_mask.view(batch_size, 1, 1, 1)
                # Stopped samples keep their previous state; active ones advance.
                h_prev = h_t * expand_mask + h_prev * (1 - expand_mask)

            if debug:
                residuals.append((h_t - h_prev).norm().item())
                norms.append(h_t.norm().item())

            out = self.head(h_act if self.use_act else h_t)
            if self.ssl:
                ssh_out = self.ssh_head(h_t)
            else:
                ssh_out = torch.zeros_like(all_ssh_outputs[:, i])
            all_outputs[:, i] = out
            all_ssh_outputs[:, i] = ssh_out

            # Advance the state unless the early-stop blend above already did.
            if not early_stop:
                h_prev = h_t

        if self.training:
            if self.use_act:
                return out, h_t, ssh_out, sum_pt
            return out, h_t, ssh_out

        if debug:
            return (all_outputs, residuals, norms)

        if return_ssh:
            return all_outputs, all_ssh_outputs

        return all_outputs


class GRULayer(nn.Module):
    """Convolutional GRU cell.

    Every transform is a bias-free 3x3 convolution with stride 1 and
    padding 1, mapping ``d_model`` channels to ``d_model`` channels, so the
    state keeps its shape.

    Args:
        d_model: Channel count of both the input ``xt`` and the state ``h_prev``.
        group_norm: If True, a single ``GroupNorm(d_model, d_model)`` (one
            group per channel) is applied to every gate's pre-activation. The
            same module, and hence the same affine parameters, is shared by
            all three gates. Otherwise the normalisation is the identity.
    """

    def __init__(self, d_model, group_norm=False):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)

        # Registration order (Wz, Uz, Wr, Ur, Wh, Uh, gn) fixes parameter
        # ordering and state_dict keys; keep it stable.
        self.Wz = conv3x3()
        self.Uz = conv3x3()
        self.Wr = conv3x3()
        self.Ur = conv3x3()
        self.Wh = conv3x3()
        self.Uh = conv3x3()
        self.gn = nn.GroupNorm(d_model, d_model) if group_norm else nn.Sequential()

    def forward(self, h_prev, xt):
        """Return the next state from the previous state ``h_prev`` and input ``xt``."""
        norm = self.gn
        update_gate = torch.sigmoid(norm(self.Wz(xt) + self.Uz(h_prev)))
        reset_gate = torch.sigmoid(norm(self.Wr(xt) + self.Ur(h_prev)))
        candidate = torch.tanh(norm(self.Wh(xt) + self.Uh(reset_gate * h_prev)))
        # The update gate interpolates between the old state and the candidate.
        return (1 - update_gate) * h_prev + update_gate * candidate
    
class LiGRU(nn.Module):
    """Convolutional GRU variant with a single update gate and a ReLU candidate.

    There is no reset gate. The optional batch norm is applied only to the
    input-side convolutions.

    Args:
        d_model: Channel count of both the input ``xt`` and the state ``h_prev``.
        batch_norm: If True, ``BatchNorm2d(d_model)`` follows ``Wz(xt)``
            (``bn1``) and ``Wh(xt)`` (``bn2``). Otherwise both are the identity.
    """

    def __init__(self, d_model, batch_norm=False):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)

        def input_norm():
            return nn.BatchNorm2d(d_model) if batch_norm else nn.Sequential()

        # Registration order (Wz, Uz, Wh, Uh, bn1, bn2) fixes parameter
        # ordering and state_dict keys; keep it stable.
        self.Wz = conv3x3()
        self.Uz = conv3x3()
        self.Wh = conv3x3()
        self.Uh = conv3x3()
        self.bn1 = input_norm()
        self.bn2 = input_norm()

    def forward(self, h_prev, xt):
        """Return the next state from the previous state ``h_prev`` and input ``xt``."""
        keep_gate = torch.sigmoid(self.bn1(self.Wz(xt)) + self.Uz(h_prev))
        candidate = torch.relu(self.bn2(self.Wh(xt)) + self.Uh(h_prev))
        # Here the gate weights the *old* state (the opposite convention to GRULayer).
        return keep_gate * h_prev + (1 - keep_gate) * candidate
    
class GatedAttentionLayer(nn.Module):
    """Gated attention score: ``exp(w(tanh(norm(V X)) * sigmoid(norm(U X))))``.

    Args:
        d_model: Size of the last dimension of ``X``.
        hidden_state: Output size of the two gating projections ``V`` and ``U``.
        batch_norm: If True, ``BatchNorm1d(hidden_state)`` follows each
            projection. Otherwise the identity is used.
    """

    def __init__(self, d_model, hidden_state, batch_norm=True):
        super().__init__()

        def projection_norm():
            return nn.BatchNorm1d(hidden_state) if batch_norm else nn.Sequential()

        # Registration order (V, bn1, U, bn2, w) fixes parameter ordering and
        # state_dict keys; keep it stable.
        self.V = nn.Linear(d_model, hidden_state)
        self.bn1 = projection_norm()
        self.U = nn.Linear(d_model, hidden_state)
        self.bn2 = projection_norm()
        self.w = nn.Linear(hidden_state, 1)

    def forward(self, X):
        """Return the positive attention scores with the trailing size-1 dim removed."""

        def normalised(linear, norm):
            # squeeze(1)/unsqueeze(1) let BatchNorm1d see (batch, hidden) when
            # X is (batch, 1, d_model); presumably that is the expected shape.
            return norm(linear(X).squeeze(1)).unsqueeze(1)

        content = torch.tanh(normalised(self.V, self.bn1))
        gate = torch.sigmoid(normalised(self.U, self.bn2))
        logits = self.w(content * gate)
        return torch.exp(logits).squeeze(-1)
    
# Model factories.
#
# Every factory reads ``in_channels`` and ``num_class`` from its keyword
# arguments (KeyError if either is missing) and ignores any other keyword.
# NOTE(review): CnnGRU builds its recurrent cell as ``gru_layer(width)``, so
# ``group_norm`` / ``batch_norm`` below reach only the residual blocks created
# by ``CnnGRU._make_layer``. GRULayer's own ``group_norm`` and LiGRU's own
# ``batch_norm`` options are never used by these factories.


def _build_cnn_gru(gru_layer, width, factory_kwargs, **variant):
    """Construct a CnnGRU with the settings shared by every factory.

    All variants use ``BasicBlock`` stages, ``num_blocks=[2]`` and
    ``recall=False``. ``in_channels`` and ``num_class`` come from
    ``factory_kwargs``. ``variant`` carries the per-model flags, forwarded
    verbatim to ``CnnGRU``.
    """
    return CnnGRU(
        BasicBlock,
        gru_layer,
        [2],
        width=width,
        in_channels=factory_kwargs["in_channels"],
        recall=False,
        num_class=factory_kwargs["num_class"],
        **variant,
    )


def cnn_gru(width, **kwargs):
    """GRULayer cell with CnnGRU's default (cifar) feature extractor."""
    return _build_cnn_gru(GRULayer, width, kwargs, group_norm=False)


def cnn_gru_attention(width, **kwargs):
    """Like ``cnn_gru`` but passes ``use_attention=True``.

    NOTE(review): CnnGRU has no ``use_attention`` parameter and discards it
    through ``**kwargs``, so this currently builds the same model as
    ``cnn_gru``.
    """
    return _build_cnn_gru(GRULayer, width, kwargs, group_norm=False, use_attention=True)


def cnn_gru_wo_bn(width, **kwargs):
    """GRULayer cell; residual blocks explicitly built with ``batch_norm=False``."""
    return _build_cnn_gru(GRULayer, width, kwargs, group_norm=False, batch_norm=False)


def cnn_gru_gn(width, **kwargs):
    """GRULayer cell; residual blocks built with ``group_norm=True``."""
    return _build_cnn_gru(GRULayer, width, kwargs, group_norm=True)


def cnn_gru_gn_act(width, **kwargs):
    """``cnn_gru_gn`` with the adaptive-computation-time branch enabled."""
    return _build_cnn_gru(GRULayer, width, kwargs, group_norm=True, use_act=True)


def cnn_gru_pos_enc(width, **kwargs):
    """GRULayer cell with positional encoding enabled."""
    return _build_cnn_gru(GRULayer, width, kwargs, group_norm=False, pos_enc=True)


def cnn_gru_pos_enc_gn(width, **kwargs):
    """``cnn_gru_pos_enc`` with ``group_norm=True`` residual blocks."""
    return _build_cnn_gru(GRULayer, width, kwargs, group_norm=True, pos_enc=True)


def cnn_ligru(width, **kwargs):
    """LiGRU cell with CnnGRU's default (cifar) feature extractor."""
    return _build_cnn_gru(LiGRU, width, kwargs, group_norm=False)


def cnn_ligru_act(width, **kwargs):
    """``cnn_ligru`` with the adaptive-computation-time branch enabled."""
    return _build_cnn_gru(LiGRU, width, kwargs, group_norm=False, use_act=True)


def cnn_gru_imagenet(width, **kwargs):
    """GRULayer cell, ImageNet extractor, ``batch_norm=True`` residual blocks."""
    return _build_cnn_gru(
        GRULayer, width, kwargs, group_norm=False, batch_norm=True, imagenet=True
    )


def cnn_gru_gn_imagenet(width, **kwargs):
    """``cnn_gru_imagenet`` with ``group_norm=True`` residual blocks."""
    return _build_cnn_gru(
        GRULayer, width, kwargs, group_norm=True, imagenet=True, batch_norm=True
    )


def cnn_gru_gn_act_imagenet(width, **kwargs):
    """``cnn_gru_gn_imagenet`` with the adaptive-computation-time branch enabled."""
    return _build_cnn_gru(
        GRULayer, width, kwargs,
        group_norm=True, use_act=True, imagenet=True, batch_norm=True,
    )


def cnn_gru_pos_enc_imagenet(width, **kwargs):
    """``cnn_gru_imagenet`` with positional encoding enabled."""
    return _build_cnn_gru(
        GRULayer, width, kwargs,
        group_norm=False, pos_enc=True, imagenet=True, batch_norm=True,
    )


def cnn_gru_pos_enc_gn_imagenet(width, **kwargs):
    """``cnn_gru_pos_enc_imagenet`` with ``group_norm=True`` residual blocks."""
    return _build_cnn_gru(
        GRULayer, width, kwargs,
        group_norm=True, pos_enc=True, imagenet=True, batch_norm=True,
    )


def cnn_ligru_imagenet(width, **kwargs):
    """LiGRU cell, ImageNet extractor, ``batch_norm=True`` residual blocks."""
    return _build_cnn_gru(
        LiGRU, width, kwargs, group_norm=False, batch_norm=True, imagenet=True
    )


def cnn_ligru_act_imagenet(width, **kwargs):
    """``cnn_ligru_imagenet`` with the adaptive-computation-time branch enabled."""
    return _build_cnn_gru(
        LiGRU, width, kwargs,
        group_norm=False, use_act=True, imagenet=True, batch_norm=True,
    )


def cnn_gru_tiny_imagenet(width, **kwargs):
    """GRULayer cell, Tiny-ImageNet extractor, ``batch_norm=True`` residual blocks."""
    return _build_cnn_gru(
        GRULayer, width, kwargs, group_norm=False, batch_norm=True, tiny_imagenet=True
    )


def cnn_gru_gn_tiny_imagenet(width, **kwargs):
    """``cnn_gru_tiny_imagenet`` with ``group_norm=True`` residual blocks."""
    return _build_cnn_gru(
        GRULayer, width, kwargs, group_norm=True, tiny_imagenet=True, batch_norm=True
    )


def cnn_gru_gn_act_tiny_imagenet(width, **kwargs):
    """``cnn_gru_gn_tiny_imagenet`` with the adaptive-computation-time branch enabled."""
    return _build_cnn_gru(
        GRULayer, width, kwargs,
        group_norm=True, use_act=True, tiny_imagenet=True, batch_norm=True,
    )


def cnn_gru_pos_enc_tiny_imagenet(width, **kwargs):
    """``cnn_gru_tiny_imagenet`` with positional encoding enabled."""
    return _build_cnn_gru(
        GRULayer, width, kwargs,
        group_norm=False, pos_enc=True, tiny_imagenet=True, batch_norm=True,
    )


def cnn_gru_pos_enc_gn_tiny_imagenet(width, **kwargs):
    """``cnn_gru_pos_enc_tiny_imagenet`` with ``group_norm=True`` residual blocks."""
    return _build_cnn_gru(
        GRULayer, width, kwargs,
        group_norm=True, pos_enc=True, tiny_imagenet=True, batch_norm=True,
    )


def cnn_ligru_tiny_imagenet(width, **kwargs):
    """LiGRU cell, Tiny-ImageNet extractor, ``batch_norm=True`` residual blocks."""
    return _build_cnn_gru(
        LiGRU, width, kwargs, group_norm=False, batch_norm=True, tiny_imagenet=True
    )


def cnn_ligru_act_tiny_imagenet(width, **kwargs):
    """``cnn_ligru_tiny_imagenet`` with the adaptive-computation-time branch enabled."""
    return _build_cnn_gru(
        LiGRU, width, kwargs,
        group_norm=False, use_act=True, tiny_imagenet=True, batch_norm=True,
    )