""" dt_net_2d.py
    DeepThinking network 2D.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import torch
from torch import nn
from torch.nn import functional as F
import torchvision.models as models

import math
import random

from .blocks import Head, PositionalEncoding, HaltConv

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class CnnGRUImagenet(nn.Module):
    """Recurrent refinement model built on a ResNet-style backbone.

    The backbone stem and ``layer1``-``layer3`` run once to produce a feature
    map ``xt``.  A recurrent cell then updates a hidden state from ``xt`` for
    ``iters_to_do`` iterations, and each iteration's state is decoded by the
    backbone's ``layer4`` -> ``avgpool`` -> ``fc`` into a prediction.

    Args:
        base_model: Module exposing ``conv1``, ``bn1``, ``relu``, ``maxpool``,
            ``layer1``-``layer4``, ``avgpool`` and ``fc`` (e.g. a torchvision
            ResNet).
        gru_layer: Callable; ``gru_layer(width)`` must return a module whose
            ``forward(h_prev, xt)`` returns the next hidden state.
        width: Channel count passed to ``gru_layer``.  Presumably must equal
            the channels produced by ``base_model.layer3`` -- TODO confirm the
            2048 default is intended (the factory below passes 1024).
        group_norm: Stored on the instance only.  NOTE(review): it is NOT
            forwarded to ``gru_layer``, so it does not affect the cell.
        num_class: When non-zero, per-iteration outputs are collected into a
            ``(batch, iters, num_class)`` buffer.
        use_act: Enable the adaptive-computation branch (needs ``HaltConv``).
        batch_norm: Stored on the instance only.
        ssl: If True, add a 4-way linear self-supervised head on the pooled
            ``layer4`` features (input size ``2 * width``).
        **kwargs: Accepted and ignored (e.g. ``use_attention``).
    """

    def __init__(
        self,
        base_model,
        gru_layer,
        width=2048,
        group_norm=False,
        num_class=0,
        use_act=False,
        batch_norm=False,
        ssl=False,
        **kwargs
    ):
        super().__init__()

        self.width = int(width)
        self.group_norm = group_norm
        self.batch_norm = batch_norm
        self.num_class = num_class

        self.base_model = base_model

        self.recur_block = gru_layer(width)
        if ssl:
            self.ssh_head = nn.Linear(width * 2, 4)
        else:
            self.ssh_head = nn.Sequential()
        self.ssl = ssl

        self.use_act = use_act
        if use_act:
            self.halt_conv = HaltConv(ht_channel=width)

    def forward(self, x, iters_to_do, interim_thought=None, debug=False, return_ssh=False, **kwargs):
        """Run the backbone once, then ``iters_to_do`` recurrent iterations.

        Args:
            x: Input batch for ``base_model``.
            iters_to_do: Number of recurrent iterations.
            interim_thought: Optional initial hidden state; defaults to zeros
                shaped like the ``layer3`` features.
            debug: In eval mode, also return per-iteration diagnostics.
            return_ssh: In eval mode, also return the per-iteration SSL
                head outputs.
            **kwargs: Accepted and ignored.

        Returns:
            Training mode: ``(out, h_t, ssh_out)`` for the last iteration,
            followed by ``sum_pt`` when ``use_act`` is enabled.
            Eval mode: ``(all_outputs, res, norm)`` if ``debug``; otherwise
            ``(all_outputs, all_ssh_outputs)`` if ``return_ssh``; otherwise
            ``all_outputs``.  ``res``/``norm`` are lists of Python floats:
            the norm of ``h_t - h_prev`` and of ``h_t`` at each iteration.
        """
        x = self.base_model.conv1(x)
        x = self.base_model.bn1(x)
        x = self.base_model.relu(x)
        x = self.base_model.maxpool(x)

        # Backbone up to (but excluding) layer4; layer4 is reused as the
        # per-iteration decoder below.
        x = self.base_model.layer1(x)
        x = self.base_model.layer2(x)
        xt = self.base_model.layer3(x)

        batch_size = xt.shape[0]
        device = x.device
        if interim_thought is None:
            h_prev = torch.zeros_like(xt)
        else:
            h_prev = interim_thought

        # Per-iteration output buffers.  The non-classification shape uses the
        # spatial size of the layer2 features (``x`` at this point).
        if self.num_class:
            all_outputs = torch.zeros(
                (batch_size, iters_to_do, self.num_class), device=device
            )
        else:
            all_outputs = torch.zeros(
                (batch_size, iters_to_do, 2, x.size(2), x.size(3)), device=device
            )
        all_ssh_outputs = torch.zeros((batch_size, iters_to_do, 4), device=device)

        # Diagnostics, only filled when they will actually be returned
        # (eval mode with ``debug``); ``.item()`` forces a device sync.
        collect_stats = debug and not self.training
        res = []
        norm = []

        # Adaptive-computation (ACT) state.
        epsilon = 0.05
        pt_accumulate = torch.zeros((batch_size, 1), device=device)
        p_t = torch.zeros((batch_size, 1), device=device)
        if self.use_act:
            h_act = torch.zeros_like(h_prev, device=device)

        # Early-stop threshold on the relative change between h_t and h_prev.
        # At 0 the early-stop branch is disabled and h_prev simply tracks h_t.
        threshold = 0.00
        active_mask = torch.ones(batch_size, device=device)

        for i in range(iters_to_do):
            h_t = self.recur_block(h_prev, xt)

            ############### act ##################
            if self.use_act:
                if i == 0:
                    sum_pt = p_t
                else:
                    p_t = self.halt_conv(h_prev)
                    pt_accumulate = pt_accumulate + p_t
                    gate = p_t.view(batch_size, 1, 1, 1)
                    # ``mask`` was computed at the end of the previous iteration.
                    still_running = mask.view(batch_size, 1, 1, 1)
                    h_act = h_act + gate * h_t * still_running

                mask = pt_accumulate < 1 - epsilon
                sum_pt = sum_pt + p_t * mask
            ######################################

            if not self.training and threshold > 0:
                # Relative L2 change between h_t and h_prev, per sample.
                norm_change = torch.norm(h_t - h_prev, p=2, dim=[1, 2, 3]) / torch.norm(
                    h_t, p=2, dim=[1, 2, 3]
                )
                # Samples whose change fell below the threshold may stop.
                stop_mask = norm_change < threshold  # (batch_size,)
                # Permanently deactivate only those samples.
                active_mask = active_mask * (1 - stop_mask.float())
                # Advance h_prev only for samples that are still active.
                expand_mask = active_mask.view(batch_size, 1, 1, 1)
                h_prev = h_t * expand_mask + h_prev * (1 - expand_mask)

            if collect_stats:
                res.append((h_t - h_prev).norm().item())
                norm.append(h_t.norm().item())

            if self.use_act:
                # NOTE(review): the ACT accumulator replaces h_t and is fed
                # back as h_prev below; at i == 0 it is still all zeros.
                # Confirm this is intended.
                h_t = h_act

            # Decode the current state with the remaining backbone layers.
            feature_map = self.base_model.layer4(h_t)
            feature_map = self.base_model.avgpool(feature_map)
            feature_map = torch.flatten(feature_map, 1)
            out = self.base_model.fc(feature_map)
            if self.ssl:
                ssh_out = self.ssh_head(feature_map)
            else:
                ssh_out = torch.zeros_like(all_ssh_outputs[:, i])
            all_outputs[:, i] = out
            all_ssh_outputs[:, i] = ssh_out
            if threshold == 0:
                h_prev = h_t

        if self.training:
            if self.use_act:
                return out, h_t, ssh_out, sum_pt
            return out, h_t, ssh_out

        if debug:
            return (all_outputs, res, norm)

        if return_ssh:
            return all_outputs, all_ssh_outputs
        return all_outputs


class GRULayer(nn.Module):
    """Convolutional GRU cell operating on 2-D feature maps.

    Every projection is a bias-free 3x3 convolution with stride 1 and
    padding 1, mapping ``d_model`` channels to ``d_model`` channels.

    Args:
        d_model: Number of input and hidden channels.
        group_norm: If True, apply ``nn.GroupNorm(d_model, d_model)`` to each
            gate pre-activation.  NOTE(review): one GroupNorm instance (and
            its affine parameters) is shared by all three gates.
    """

    def __init__(self, d_model, group_norm=False):
        super().__init__()
        # Creation order is kept stable so seeded initialisation is unchanged.
        self.Wz = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Uz = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Wr = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Ur = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Wh = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Uh = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.gn = nn.GroupNorm(d_model, d_model) if group_norm else nn.Identity()

    def forward(self, h_prev, xt):
        """Return the next hidden state from ``h_prev`` and input ``xt``.

        Uses the update convention ``h_t = (1 - z) * h_prev + z * candidate``.
        """
        # Functional activations avoid constructing modules on every call.
        update_gate = torch.sigmoid(self.gn(self.Wz(xt) + self.Uz(h_prev)))
        reset_gate = torch.sigmoid(self.gn(self.Wr(xt) + self.Ur(h_prev)))
        candidate = torch.tanh(self.gn(self.Wh(xt) + self.Uh(reset_gate * h_prev)))
        return (1 - update_gate) * h_prev + update_gate * candidate
    
class LiGRU(nn.Module):
    """Light convolutional GRU cell (single update gate, ReLU candidate).

    Every projection is a bias-free 3x3 convolution with stride 1 and
    padding 1, mapping ``d_model`` channels to ``d_model`` channels.

    Args:
        d_model: Number of input and hidden channels.
        batch_norm: If True, apply ``nn.BatchNorm2d`` to the input-side
            projections ``Wz(xt)`` and ``Wh(xt)`` (not to the recurrent terms).
    """

    def __init__(self, d_model, batch_norm=False):
        super().__init__()
        # Creation order is kept stable so seeded initialisation is unchanged.
        self.Wz = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Uz = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Wh = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.Uh = nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(d_model) if batch_norm else nn.Identity()
        self.bn2 = nn.BatchNorm2d(d_model) if batch_norm else nn.Identity()

    def forward(self, h_prev, xt):
        """Return the next hidden state from ``h_prev`` and input ``xt``.

        Note the gate convention is the reverse of ``GRULayer``:
        ``h_t = z * h_prev + (1 - z) * candidate``.
        """
        # Functional activations avoid constructing modules on every call.
        update_gate = torch.sigmoid(self.bn1(self.Wz(xt)) + self.Uz(h_prev))
        candidate = torch.relu(self.bn2(self.Wh(xt)) + self.Uh(h_prev))
        return update_gate * h_prev + (1 - update_gate) * candidate
    
class GatedAttentionLayer(nn.Module):
    """Gated attention scorer: ``exp(w(tanh(V X) * sigmoid(U X))))``.

    Args:
        d_model: Size of the last dimension of the input.
        hidden_state: Size of the hidden projection.
        batch_norm: If True, apply ``nn.BatchNorm1d`` to both projections.

    The scores are un-normalised (plain ``exp``); any normalisation is left
    to the caller.
    """

    def __init__(self, d_model, hidden_state, batch_norm=True):
        super().__init__()
        # Creation order is kept stable so seeded initialisation is unchanged.
        self.V = nn.Linear(d_model, hidden_state)
        self.bn1 = nn.BatchNorm1d(hidden_state) if batch_norm else nn.Identity()
        self.U = nn.Linear(d_model, hidden_state)
        self.bn2 = nn.BatchNorm1d(hidden_state) if batch_norm else nn.Identity()
        self.w = nn.Linear(hidden_state, 1)

    def forward(self, X):
        """Return one positive score per row of ``X``.

        Assumes ``X`` is ``(batch, 1, d_model)`` -- the size-1 axis is squeezed
        so ``BatchNorm1d`` sees ``(batch, hidden)`` -- TODO confirm with
        callers.  The trailing size-1 score axis is removed from the result.
        """
        # Functional activations avoid constructing modules on every call.
        values = torch.tanh(self.bn1(self.V(X).squeeze(1)).unsqueeze(1))
        gates = torch.sigmoid(self.bn2(self.U(X).squeeze(1)).unsqueeze(1))
        scores = torch.exp(self.w(values * gates))
        return scores.squeeze(-1)
    
def imagenet_gru(width=1024, **kwargs):
    """Build a ``CnnGRUImagenet`` on a pretrained torchvision ResNet-50.

    The recurrent cell is ``GRULayer`` and the SSL head is enabled.

    Args:
        width: Channel width of the recurrent cell.
        **kwargs: Must contain ``num_class`` (a missing key raises
            ``KeyError``); every other key is ignored.

    NOTE(review): ``group_norm=True`` is only stored by ``CnnGRUImagenet`` and
    is not forwarded to ``GRULayer``; ``use_attention`` is swallowed by the
    constructor's ``**kwargs``.
    """
    # Loads ImageNet weights via torchvision's ``pretrained=True`` API.
    backbone = models.resnet50(pretrained=True)
    return CnnGRUImagenet(
        backbone,
        GRULayer,
        width=width,
        group_norm=True,
        use_attention=False,
        num_class=kwargs["num_class"],
        ssl=True,
    )
    

if __name__ == "__main__":
    # Smoke test: build the ImageNet model (fetches pretrained ResNet-50
    # weights) and run one random batch through 10 recurrent iterations.
    # A freshly constructed nn.Module is in training mode, so ``y`` is the
    # last-iteration tuple returned by the training branch of ``forward``.
    model = imagenet_gru(num_class=1000)
    x = torch.rand(8, 3, 224, 224)
    y = model(x, iters_to_do=10)