import torch
import torch.nn as nn

from opencood.models.spconv_utils import ConvertSparseTensor
import spconv
import spconv.pytorch
from spconv.pytorch import SparseSequential, SparseConv2d
from spconv.core import ConvAlgo
    
class Heterogeneity_reflected_Convolution(nn.Module):
    """Apply a stack of ``PyramideConvResBlock`` layers to per-agent feature maps.

    One block is created per entry of ``res_con``; the i-th block is built
    from the i-th kernel/stride/padding entries of ``deconv`` and ``conv``
    and is called with ``res_con[i]`` as its residual flag.

    Args:
        dim: Channel count passed to every ``PyramideConvResBlock``.
        res_con: Sequence of residual flags, one per block.
        deconv: Mapping with ``'kernel'``, ``'stride'`` and ``'padding'``
            sequences (indexed per block) for the transposed convolutions.
        conv: Mapping with ``'kernel'``, ``'stride'`` and ``'padding'``
            sequences (indexed per block) for the regular convolutions.
    """

    def __init__(self, dim, res_con, deconv, conv):
        super(Heterogeneity_reflected_Convolution, self).__init__()

        self.res_con = res_con

        deconv_k = deconv['kernel']
        deconv_s = deconv['stride']
        deconv_p = deconv['padding']

        conv_k = conv['kernel']
        conv_s = conv['stride']
        conv_p = conv['padding']

        self.blocks = nn.ModuleList([
            PyramideConvResBlock(dim, deconv_k[i], deconv_s[i], deconv_p[i],
                                 conv_k[i], conv_s[i], conv_p[i])
            for i in range(len(res_con))
        ])

    def generate_s_mask(self, infra, regroup_s_mask):
        """Build per-agent overlap masks from the agents' spatial masks.

        Args:
            infra: Tensor indexed as ``infra[b]`` and compared against 1;
                presumably shape (B, L) with 1 marking the infrastructure
                agent -- confirm against the caller. At most one entry per
                sample may equal 1 (``.item()`` raises otherwise).
            regroup_s_mask: Tensor of shape (B, L, H, W); non-zero entries
                are treated as "covered" by ``torch.logical_and``.

        Returns:
            Tuple ``(he_o, ho_o, not_o)`` of bool tensors, each (B, L, H, W):
            ``he_o`` is the overlap between the infrastructure mask and each
            other agent's mask (the infrastructure slot holds the union of
            those overlaps), ``ho_o`` is the pairwise overlap between agents,
            and ``not_o`` marks positions set in neither.
        """
        B, L, H, W = regroup_s_mask.shape
        device = regroup_s_mask.device

        he_o_list = []
        ho_o_list = []

        for b in range(B):
            he_o = torch.zeros((L, H, W), dtype=torch.bool, device=device)
            ho_o = torch.zeros((L, H, W), dtype=torch.bool, device=device)

            agent_masks = regroup_s_mask[b]
            is_infra = infra[b] == 1
            inf_index = None

            if bool(is_infra.any()):
                inf_index = torch.nonzero(is_infra).item()
                inf_s_mask = agent_masks[inf_index]

                # Visit every agent, including the last one. The previous
                # implementation reused the range(L - 1) pair loop here and
                # so never compared agent L-1 with the infrastructure unless
                # the infrastructure itself was the last agent.
                for agent in range(L):
                    if agent == inf_index:
                        continue
                    he_o[agent] = torch.logical_and(inf_s_mask, agent_masks[agent])
                    he_o[inf_index] = torch.logical_or(he_o[inf_index], he_o[agent])

            for first in range(L - 1):
                if first == inf_index:
                    # NOTE(review): kept from the original control flow. Pairs
                    # that start at the infrastructure index are skipped, but
                    # pairs (agent, infra) with agent < infra are still
                    # counted. Confirm whether infrastructure should take
                    # part in ho_o at all.
                    continue

                for second in range(first + 1, L):
                    pair_overlap = torch.logical_and(agent_masks[first], agent_masks[second])
                    ho_o[first] = torch.logical_or(ho_o[first], pair_overlap)
                    ho_o[second] = torch.logical_or(ho_o[second], pair_overlap)

            he_o_list.append(he_o)
            ho_o_list.append(ho_o)

        # B, L, H, W
        he_o = torch.stack(he_o_list, dim=0)
        ho_o = torch.stack(ho_o_list, dim=0)

        not_o = torch.logical_and(~he_o, ~ho_o)

        return he_o, ho_o, not_o

    def forward(self, x, infra, regroup_s_mask=None):
        """Run all blocks over the flattened (batch * agent) feature maps.

        Args:
            x: Tensor of shape (B, L, H, W, C).
            infra: Infrastructure indicator forwarded to ``generate_s_mask``.
            regroup_s_mask: Tensor of shape (B, L, MH, MW). Required despite
                the ``None`` default kept for signature compatibility.

        Returns:
            Tensor of shape (B, L, H, W, C_out), where C_out is whatever
            channel count the last block produces.

        Raises:
            ValueError: If ``regroup_s_mask`` is ``None``.
        """
        if regroup_s_mask is None:
            raise ValueError("regroup_s_mask must be provided")

        B, L, H, W, C = x.shape
        _, _, MH, MW = regroup_s_mask.shape

        he_o, ho_o, _ = self.generate_s_mask(infra, regroup_s_mask)

        half = C // 2

        # B, L, H, W -> BL, C//2, MH, MW
        he_o = he_o.unsqueeze(2).expand(-1, -1, half, -1, -1).reshape(-1, half, MH, MW)
        ho_o = ho_o.unsqueeze(2).expand(-1, -1, half, -1, -1).reshape(-1, half, MH, MW)

        # B, L, H, W, C -> BL, C, H, W
        # reshape (not view) so a non-contiguous input does not raise.
        x = x.permute(0, 1, 4, 2, 3).reshape(-1, C, H, W)

        for block, res_flag in zip(self.blocks, self.res_con):
            x = block(x, he_o, ho_o, res_flag)

        # BL, C_out, H, W -> B, L, H, W, C_out
        return x.reshape(B, L, -1, H, W).permute(0, 1, 3, 4, 2)

class PyramideConvResBlock(nn.Module):
    """Residual block: two transposed-conv stages followed by two conv stages.

    The input is passed through ``upsample_1`` and ``upsample_2``, then
    through ``downsample_1`` (with a skip connection from the output of
    ``upsample_1``) and ``downsample_2`` (with an optional skip connection
    from the block input).

    ``feature_fusion`` is constructed but not used by ``forward``.

    Args:
        dim: Number of input/output channels; inner stages use ``dim // 2``.
        deconv_k, deconv_s, deconv_p: Kernel size, stride and padding of the
            transposed convolutions (``deconv_k`` is also the kernel size of
            ``feature_fusion``).
        conv_k, conv_s, conv_p: Kernel size, stride and padding of the
            regular convolutions.
    """

    def __init__(self, dim, deconv_k, deconv_s, deconv_p, conv_k, conv_s, conv_p):
        super().__init__()
        half = dim // 2

        # Submodule names and construction order are kept as-is so that
        # state_dict keys and seeded parameter initialisation stay the same.
        self.upsample_1 = nn.Sequential(
            nn.ConvTranspose2d(dim, half, kernel_size=deconv_k,
                               stride=deconv_s, padding=deconv_p),
            nn.BatchNorm2d(half),
            nn.ReLU(),
        )

        self.upsample_2 = nn.Sequential(
            nn.ConvTranspose2d(half, half, kernel_size=deconv_k,
                               stride=deconv_s, padding=deconv_p),
            nn.BatchNorm2d(half),
            nn.ReLU(),
        )

        self.feature_fusion = nn.Sequential(
            nn.Conv2d(dim, half, kernel_size=deconv_k, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(half),
            nn.ReLU(),
        )

        self.downsample_1 = nn.Sequential(
            nn.Conv2d(half, half, kernel_size=conv_k, stride=conv_s,
                      padding=conv_p, bias=False),
            nn.BatchNorm2d(half),
        )

        self.downsample_2 = nn.Sequential(
            nn.Conv2d(half, dim, kernel_size=conv_k, stride=conv_s,
                      padding=conv_p, bias=False),
            nn.BatchNorm2d(dim),
        )

        self.act = nn.ReLU()

    def forward(self, x, he_o, ho_o, res_flag=True):
        """Run the block on ``x``.

        Args:
            x: Input tensor fed to ``upsample_1``.
            he_o: Accepted for interface compatibility; not used here.
            ho_o: Accepted for interface compatibility; not used here.
            res_flag: When truthy, the block input is added to the result of
                ``downsample_2`` before the final activation.

        Returns:
            The activated output of the block.
        """
        block_input = x

        first_up = self.upsample_1(x)
        second_up = self.upsample_2(first_up)

        # Inner skip connection: the first downsample is summed with the
        # output of the first upsample stage.
        inner = self.act(self.downsample_1(second_up) + first_up)
        result = self.downsample_2(inner)

        # Outer skip connection back to the block input, if requested.
        if res_flag:
            result = result + block_input

        return self.act(result)
    
