import torch, copy
import torch.nn as nn
import torch.nn.functional as F


class BiRQA(nn.Module):
    """Multi-scale, multi-branch network that regresses one scalar per sample.

    The model works on four scales.  At every scale, four convolutional
    branches process different channel groups of the input, their outputs are
    fused to 64 channels, refined by channel attention, modulated by a spatial
    attention map derived from the previous scale, globally average-pooled and
    finally mixed across scales with cross-gating blocks before a linear head.

    Note: ``spat_att`` holds four modules but ``forward`` only uses the first
    three; ``drop`` and ``relu`` are registered but not used by ``forward``.
    """

    def _get_sequencial(self, layer: int, in_channels: int = 1):
        """Build the convolutional stack of one branch at scale ``layer``.

        Channel widths shrink with the scale index: the first two convolutions
        produce ``128 // 2**layer`` channels, the last two ``256 // 2**layer``.
        Every layer uses stride 1 with padding chosen so that the spatial size
        is preserved.
        """
        narrow = 128 // (2 ** layer)
        wide = 256 // (2 ** layer)

        # The convolutions are created up front, in network order.
        stem = nn.Conv2d(in_channels, narrow, kernel_size=5, stride=1, padding=2)
        mix = nn.Conv2d(narrow, narrow, kernel_size=3, stride=1, padding=1)
        widen = nn.Conv2d(narrow, wide, kernel_size=3, stride=1, padding=1)
        refine = nn.Conv2d(wide, wide, kernel_size=3, padding=1)
        smooth = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        stages = [
            stem, nn.GroupNorm(num_groups=8, num_channels=narrow), nn.ReLU(), nn.Dropout(0.2),
            mix, nn.GroupNorm(num_groups=16, num_channels=narrow), nn.ReLU(), nn.Dropout(0.2),
            widen, nn.GroupNorm(num_groups=16, num_channels=wide), nn.ReLU(),
            refine, nn.GroupNorm(num_groups=16, num_channels=wide), nn.ReLU(), smooth,
        ]
        return nn.Sequential(*stages)

    def __init__(self):
        super().__init__()

        # procc[scale][branch]; branch 2 consumes three channels, the others one.
        per_scale = []
        for scale in range(4):
            branches = []
            for branch in range(4):
                branch_in = 3 if branch == 2 else 1
                branches.append(self._get_sequencial(layer=scale, in_channels=branch_in))
            per_scale.append(nn.ModuleList(branches))
        self.procc = nn.ModuleList(per_scale)

        # Each branch ends with 256 // 2**scale channels and four branches are
        # concatenated before being projected to 64 channels.
        self.fusion_layers = nn.ModuleList(
            [
                AdaptiveFusionFourB(in_channels=4 * (256 // (2 ** scale)), out_channels=64)
                for scale in range(4)
            ]
        )
        self.chan_att = nn.ModuleList(
            [channel_attention_module(ch=64, ratio=8) for _ in range(4)]
        )
        self.spat_att = nn.ModuleList(
            [spatial_attention_module(num_channels=64) for _ in range(4)]
        )
        self.downscale = EnhancedDownScale(in_channels=64)

        self.drop = nn.Dropout(0.4)
        self.relu = nn.ReLU()
        self.output_fc = nn.Linear(64 * 4, 1)

        self.scgb1 = CrossGatingBlock(x_features=64)
        self.scgb2 = CrossGatingBlock(x_features=64)
        self.scgb3 = CrossGatingBlock(x_features=64)

    def procces_layer(self, l, layer):
        """Run the four branches of scale ``layer`` on ``l`` and fuse them.

        ``l`` is indexed along dim 1: channel 0, channel 1, channels 2-4 and
        channel 5 feed branches 0-3 respectively.  The original code labels
        these "spat", "info", "color maps" and "lbp".
        """
        branches = self.procc[layer]
        spat = branches[0](l[:, 0].unsqueeze(1))
        info = branches[1](l[:, 1].unsqueeze(1))
        color = branches[2](l[:, 2:5])
        lbp = branches[3](l[:, 5].unsqueeze(1))
        return self.fusion_layers[layer](spat, info, color, lbp)

    def forward(self, x):
        """Compute one score per sample.

        Args:
            x: indexable with ``x[0] .. x[3]``; each entry is the input for one
               scale and is passed to :meth:`procces_layer`.
               NOTE(review): the element-wise product with the spatial
               attention map below requires scale ``i`` to be
               broadcast-compatible with the downscaled map of scale ``i - 1``
               -- confirm against the data pipeline.

        Returns:
            1-D tensor obtained by flattening the output of the linear head.
        """
        # Branch processing for every scale first, channel attention afterwards.
        per_scale = [self.procces_layer(x[scale], scale) for scale in range(4)]
        per_scale = [self.chan_att[scale](feat) for scale, feat in enumerate(per_scale)]

        # Scales 0..2 are downscaled and used to gate scales 1..3.
        coarse = [self.downscale(feat) for feat in per_scale[:-1]]
        for scale in range(1, 4):
            gate = self.spat_att[scale - 1](coarse[scale - 1])
            per_scale[scale] = per_scale[scale] * gate

        # Global average pooling over the two spatial dimensions.
        pooled = [feat.mean(dim=(-2, -1)) for feat in per_scale]

        # Propagate information from the last scale back to the first one; each
        # step already sees the updated neighbour.
        for scale, gating in ((2, self.scgb3), (1, self.scgb2), (0, self.scgb1)):
            pooled[scale] = pooled[scale] + gating(pooled[scale], pooled[scale + 1])

        score = self.output_fc(torch.cat(pooled, dim=1))
        return score.view(-1)



class CrossGatingBlock(nn.Module):
    """Cross-gating MLP block: ``x`` produces a gate that modulates ``y``.

    All linear layers act on the last dimension, which must have size
    ``x_features``.  ``use_bias`` only affects ``in_project_x`` and
    ``out_project_y``; ``dense_0`` and ``dense_1`` always carry a bias.
    """

    def __init__(self, x_features, use_bias=True, dropout_rate=0.1):
        super().__init__()
        self.x_features = x_features
        self.use_bias = use_bias
        self.drop = dropout_rate  # dropout probability, not a module

        width = self.x_features
        self.dense_0 = nn.Linear(width, width)
        self.dense_1 = nn.Linear(width, width)
        self.in_project_x = nn.Linear(width, width, bias=self.use_bias)
        self.gelu1 = nn.GELU(approximate='tanh')
        self.out_project_y = nn.Linear(width, width, bias=self.use_bias)
        self.dropout1 = nn.Dropout(self.drop)

    def forward(self, x, y):
        """Return ``dropout(W_out(y' * gelu(W_in(W_0 x)))) + y'`` with ``y' = W_1 y``.

        ``x`` and ``y`` must have identical shapes (checked with ``assert``).
        """
        assert y.shape == x.shape

        # Gate derived from x.
        gate = self.gelu1(self.in_project_x(self.dense_0(x)))

        # Projected y doubles as the residual branch.
        residual = self.dense_1(y)

        gated = self.out_project_y(residual * gate)
        return self.dropout1(gated) + residual


class EnhancedDownScale(nn.Module):
    """Downscale a feature map with two unpadded convolutions plus a resized skip.

    The main path applies a 3x3 stride-2 convolution and a 5x5 stride-1
    convolution, each followed by ReLU.  The skip path is a 1x1 convolution of
    the input, resized with ``F.interpolate`` (default mode) to the spatial
    size of the main path.  The channel count is preserved.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.down1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
        self.rel1 = nn.ReLU()
        self.down2 = nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=1, padding=0)
        self.rel2 = nn.ReLU()

        self.skip = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return the sum of the downscaled path and the resized skip path."""
        reduced = self.rel1(self.down1(x))
        reduced = self.rel2(self.down2(reduced))

        # Resize the skip branch to whatever spatial size the main path ended at.
        target_hw = reduced.shape[2:]
        shortcut = F.interpolate(self.skip(x), size=target_hw)
        return reduced + shortcut

class AdaptiveFusionFourB(nn.Module):
    """Fuse four feature streams with learned, per-sample stream weights.

    The concatenated streams are globally average-pooled and passed through a
    small 1x1-conv bottleneck that emits four softmax-normalised weights (one
    per stream).  Each stream is scaled by its weight, the streams are
    concatenated again and projected to ``out_channels`` with a 1x1 conv.

    ``in_channels`` is the total channel count of the four streams combined.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden = in_channels // 2
        # Positions inside the Sequential are kept stable: 1 and 3 hold the convs.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden, 4, kernel_size=1),  # 4 weights, one per stream
            nn.Softmax(dim=1),
        )
        self.proj_out = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x1, x2, x3, x4):
        """Return the weighted, projected fusion of the four streams."""
        streams = (x1, x2, x3, x4)
        weights = self.attention(torch.cat(streams, dim=1))
        # weights[:, k:k+1] keeps the channel dim so it broadcasts over a stream.
        scaled = [weights[:, k:k + 1] * stream for k, stream in enumerate(streams)]
        return self.proj_out(torch.cat(scaled, dim=1))
    
class AdaptiveFusion(nn.Module):
    """Fuse three feature streams with learned, per-sample stream weights.

    The concatenated streams are globally average-pooled and passed through a
    1x1-conv bottleneck that emits three softmax-normalised weights.  Each
    stream is scaled by its weight before the streams are concatenated again
    and projected to ``out_channels`` with a 1x1 conv.

    ``in_channels`` is the total channel count of the three streams combined.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden = in_channels // 2
        # Positions inside the Sequential are kept stable: 1 and 3 hold the convs.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden, 3, kernel_size=1),  # one weight per stream
            nn.Softmax(dim=1),
        )
        self.proj_out = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x1, x2, x3):
        """Return the weighted, projected fusion of the three streams."""
        streams = (x1, x2, x3)
        weights = self.attention(torch.cat(streams, dim=1))
        # weights[:, k:k+1] keeps the channel dim so it broadcasts over a stream.
        scaled = [weights[:, k:k + 1] * stream for k, stream in enumerate(streams)]
        return self.proj_out(torch.cat(scaled, dim=1))



class SpatialCrossAttention(nn.Module):
    """Scaled dot-product cross-attention between two 4-D feature maps.

    Every spatial position of ``x`` (query) attends over all spatial positions
    of ``context`` (keys/values).  Queries, keys and values are produced by
    1x1 convolutions into ``embed_dim`` channels; the attended result is
    projected back to ``in_channels`` by another 1x1 convolution.

    Review note: an earlier revision multiplied ``query * key`` and
    ``attn_weights * value`` element-wise, which contradicted the documented
    ``(B, H*W, H*W)`` score shape, was not a weighted sum over positions and
    failed when ``context`` had a different spatial size than ``x``.  The
    batched matrix products below implement the documented behaviour.  The
    score matrix needs ``O(H*W * Hc*Wc)`` memory per sample.
    """

    def __init__(self, in_channels, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.scale = embed_dim ** -0.5  # 1 / sqrt(E) scaling of the dot products

        # 1x1 convolutions acting as per-position linear projections for Q, K, V
        self.query_conv = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, embed_dim, kernel_size=1)

        # Projection back to the input channel count
        self.out_conv = nn.Conv2d(embed_dim, in_channels, kernel_size=1)

    def forward(self, x, context):
        """Attend from ``x`` to ``context``.

        Args:
            x: query feature map of shape ``(B, in_channels, H, W)``.
            context: key/value feature map of shape ``(B, in_channels, Hc, Wc)``;
                its spatial size may differ from that of ``x``.

        Returns:
            Tensor of shape ``(B, in_channels, H, W)``.
        """
        batch_size, _, H, W = x.size()

        # N = H*W query positions, M = Hc*Wc context positions.
        query = self.query_conv(x).flatten(2).transpose(1, 2)          # (B, N, E)
        key = self.key_conv(context).flatten(2)                        # (B, E, M)
        value = self.value_conv(context).flatten(2).transpose(1, 2)    # (B, M, E)

        # Scaled dot products between every query and every context position.
        attn_scores = torch.bmm(query, key) * self.scale               # (B, N, M)

        # Normalise over the context positions.
        attn_weights = F.softmax(attn_scores, dim=-1)                  # (B, N, M)

        # Weighted sum of the values for each query position.
        attn_output = torch.bmm(attn_weights, value)                   # (B, N, E)

        # Back to channel-first layout with the query's spatial size.
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, self.embed_dim, H, W)
        return self.out_conv(attn_output)


class channel_attention_module_sep_mlp(nn.Module):
    """Channel attention with separate bottleneck MLPs for avg- and max-pooling.

    The input is globally average-pooled and max-pooled; each descriptor runs
    through its own ``ch -> ch // ratio -> ch`` MLP (``mlp1`` for the average,
    ``mlp2`` for the maximum).  The two results are summed, squashed with a
    sigmoid and used to rescale the input channel-wise.
    """

    def __init__(self, ch, ratio=2, bias=False):
        super().__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = ch // ratio
        self.mlp1 = nn.Sequential(
            nn.Linear(ch, hidden, bias=bias),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden, ch, bias=bias),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(ch, hidden, bias=bias),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden, ch, bias=bias),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled per channel by the attention weights."""
        # Both pooled maps end in two singleton dims; index them away.
        avg_logits = self.mlp1(self.avg_pool(x)[..., 0, 0])
        max_logits = self.mlp2(self.max_pool(x)[..., 0, 0])

        # Restore the two trailing dims so the weights broadcast over space.
        scale = self.sigmoid(avg_logits + max_logits)[..., None, None]
        return x * scale

class channel_attention_module(nn.Module):
    """Channel attention with one bottleneck MLP shared by avg- and max-pooling.

    The input is globally average-pooled and max-pooled; both descriptors run
    through the same ``ch -> ch // ratio -> ch`` MLP.  The two results are
    summed, squashed with a sigmoid and used to rescale the input
    channel-wise.
    """

    def __init__(self, ch, ratio=2, bias=False):
        super().__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = ch // ratio
        self.mlp = nn.Sequential(
            nn.Linear(ch, hidden, bias=bias),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden, ch, bias=bias),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled per channel by the attention weights."""
        # The shared MLP sees the average descriptor first, then the maximum.
        avg_logits = self.mlp(self.avg_pool(x)[..., 0, 0])
        max_logits = self.mlp(self.max_pool(x)[..., 0, 0])

        # Restore the two trailing dims so the weights broadcast over space.
        scale = self.sigmoid(avg_logits + max_logits)[..., None, None]
        return x * scale


class spatial_attention_module(nn.Module):
    """Produce a single-channel spatial attention map in ``(0, 1)``.

    The input is group-normalised, reduced over the channel dimension with
    both mean and max, and the two resulting maps are convolved (bias-free)
    into one channel followed by a sigmoid.  The module returns the attention
    map itself; it does not apply it to the input.
    """

    def __init__(self, num_channels=128, kernel_size=7, padding=3):
        super().__init__()

        self.norm = nn.GroupNorm(num_groups=8, num_channels=num_channels, eps=1e-6, affine=True)
        # Two input channels: channel-wise mean and channel-wise max.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the attention map computed from ``x``."""
        normed = self.norm(x)

        channel_mean = normed.mean(dim=1, keepdim=True)
        channel_max = torch.max(normed, dim=1, keepdim=True).values

        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv(descriptor))
    

class DownScaleBlock(nn.Module):
    """Group-normalise a feature map and shrink it with two unpadded convs.

    After normalisation the input passes through a 3x3 stride-2 convolution
    and a 5x5 stride-1 convolution, each followed by ReLU.  The channel count
    is preserved.
    """

    def __init__(self, in_channels, num_groups=8):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.c1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
        self.rel1 = nn.ReLU()
        self.c2 = nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=1, padding=0)
        self.rel2 = nn.ReLU()

    def forward(self, x):
        """Return the normalised and downscaled feature map."""
        out = self.norm(x)
        out = self.rel1(self.c1(out))
        out = self.rel2(self.c2(out))
        return out