import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two-convolution residual block with a 1x1 projection shortcut.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN -> ReLU
    Shortcut:   conv1x1(stride) -> BN   (always applied, even when the
                channel count and stride leave the shape unchanged)
    Output:     ReLU(main + shortcut)

    Note that the main path already goes through a ReLU *before* the
    residual addition.

    Args:
        in_channels: number of channels of the input feature map.
        feat_out: number of channels produced by both paths.
        stride: stride of the first 3x3 conv and of the shortcut conv.
    """

    def __init__(self, in_channels, feat_out, stride=(1, 1)):
        super().__init__()
        self.feat_out = feat_out
        self.stride = stride

        # Submodules are registered in this exact order because it fixes the
        # state_dict key order and the parameter iteration order.
        self.conv1 = nn.Conv2d(
            in_channels, feat_out, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(feat_out)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(
            feat_out, feat_out, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(feat_out)
        self.relu2 = nn.ReLU(inplace=True)

        self.shortcut_conv = nn.Conv2d(
            in_channels, feat_out, kernel_size=1, stride=stride, bias=False
        )
        self.shortcut_bn = nn.BatchNorm2d(feat_out)

        self.final_relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.relu2(self.bn2(self.conv2(branch)))

        projected = self.shortcut_bn(self.shortcut_conv(x))

        return self.final_relu(branch + projected)

class IdentityBlock(nn.Module):
    """Single-convolution residual block with an identity shortcut.

    Computes ``ReLU(x + ReLU(BN(conv3x3(x))))``. The input is added
    unchanged, so the conv output (``feat_out`` channels) must be addable to
    ``x``; in the usual case ``in_channels == feat_out``. No check is
    performed here.

    Args:
        in_channels: number of channels of the input feature map.
        feat_out: number of channels produced by the convolution.
    """

    def __init__(self, in_channels, feat_out):
        super().__init__()
        self.feat_out = feat_out

        # Registration order fixes the state_dict key order.
        self.conv = nn.Conv2d(
            in_channels, feat_out, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(feat_out)
        self.relu1 = nn.ReLU(inplace=True)
        self.final_relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ReLU(ReLU(BN(conv(x))) + x)."""
        branch = self.relu1(self.bn(self.conv(x)))
        return self.final_relu(branch + x)


class TimeAlignment(nn.Module):
    """Energy-weighted single-head self-attention over a 4-D feature map.

    Takes ``x`` of shape ``(B, C, H, T)`` and returns a sequence of shape
    ``(B, pos_dim * T, embed_dim)`` with ``pos_dim = output_dim // input_dim``.
    ``H`` is the axis the code calls "frequency" and ``T`` the time axis.

    Steps performed by ``forward``:
      1. Square ``x`` and re-weight it by a softmax importance over ``H``
         (mean over C and T) and a softmax importance over ``C`` (mean over
         H and T).
      2. Project the channel axis (``C`` must equal ``embed_dim``) to
         ``pos_dim`` values and add a fixed scalar positional value per time
         step.
      3. Flatten ``(T, pos_dim)`` into one sequence axis of ``H``-dimensional
         tokens (``H`` must equal ``input_dim``), run scaled dot-product
         self-attention, and project the result to ``embed_dim``.

    Args:
        input_dim: required size of the ``H`` axis.
        output_dim: used only to derive ``pos_dim = output_dim // input_dim``
            (floored when not a multiple of ``input_dim``).
        embed_dim: required channel count ``C`` and size of the output features.
        attention_dim: width of the query/key/value projections (default 64).
        max_seq_len: longest supported ``T``; length of the positional table
            (default 1000). The table values depend on this length.
        dropout: dropout probability applied to the attention weights
            (default 0.1).

    Raises:
        ValueError: at construction if ``pos_dim`` would be < 1; in
            ``forward`` if ``x`` is not 4-D or its C/H/T sizes do not match
            the configuration.
    """

    def __init__(self, input_dim, output_dim, embed_dim, *, attention_dim=64,
                 max_seq_len=1000, dropout=0.1):
        super(TimeAlignment, self).__init__()
        self.attention_dim = attention_dim
        self.output_dim = output_dim
        # Number of values each (time, H) cell is projected to; the output
        # sequence length is pos_dim * T.
        self.pos_dim = output_dim // input_dim
        if self.pos_dim < 1:
            # pos_dim == 0 would silently produce an empty output sequence.
            raise ValueError(
                f"output_dim ({output_dim}) must be >= input_dim ({input_dim}) "
                f"so that pos_dim = output_dim // input_dim is at least 1"
            )
        # Creation order is kept stable: it fixes the state_dict layout and
        # the random-initialisation order.
        self.query = nn.Linear(input_dim, self.attention_dim, bias=False)
        self.key = nn.Linear(input_dim, self.attention_dim, bias=False)
        self.value = nn.Linear(input_dim, self.attention_dim, bias=False)
        self.output_proj = nn.Linear(self.attention_dim, embed_dim)
        self.pos_linear = nn.Linear(embed_dim, self.pos_dim)
        self.scale = self.attention_dim ** -0.5
        self.dropout = nn.Dropout(dropout)

        self.max_seq_len = max_seq_len
        # Registered as a buffer so it follows the module across devices and
        # is stored in the state_dict.
        self.register_buffer('pos_encoding', self._create_positional_encoding(self.max_seq_len))

    def _create_positional_encoding(self, max_len):
        """Build a fixed ``(1, max_len)`` table holding one scalar per time step.

        For position ``t`` the divisor is ``10000 ** (t / max_len)``; even
        positions get ``sin(t / divisor)`` and odd positions
        ``cos(t / divisor)``.
        """
        pe = torch.zeros(max_len)
        position = torch.arange(0, max_len, dtype=torch.float)

        div_term = 10000.0 ** (torch.arange(0, max_len, dtype=torch.float) / max_len)
        pe[0::2] = torch.sin(position[0::2] / div_term[0::2])
        pe[1::2] = torch.cos(position[1::2] / div_term[1::2])

        return pe.unsqueeze(0)  # (1, max_len)

    def forward(self, x):
        """Align ``x`` of shape (B, C, H, T) into (B, pos_dim * T, embed_dim)."""
        if x.dim() != 4:
            raise ValueError(
                f"TimeAlignment expects a 4-D input (B, C, H, T), got shape {tuple(x.shape)}"
            )
        B, C, H, T = x.shape
        # Sizes are read from the layers/buffer themselves so the checks stay
        # in sync with the actual weights.
        expected_c = self.pos_linear.in_features
        expected_h = self.query.in_features
        max_t = self.pos_encoding.size(-1)
        if C != expected_c:
            raise ValueError(f"expected C == embed_dim == {expected_c}, got C == {C}")
        if H != expected_h:
            raise ValueError(f"expected H == input_dim == {expected_h}, got H == {H}")
        if T > max_t:
            raise ValueError(
                f"sequence length T == {T} exceeds the positional table length {max_t}"
            )

        energy = x.pow(2)
        # Softmax importance over H (mean over C and T) -> (B, 1, H, 1).
        freq_importance = F.softmax(energy.mean(dim=(1, 3)), dim=-1)
        freq_importance = freq_importance.unsqueeze(1).unsqueeze(-1)

        # Softmax importance over C (mean over H and T) -> (B, C, 1, 1).
        channel_importance = F.softmax(energy.mean(dim=(2, 3)), dim=-1)
        channel_importance = channel_importance.unsqueeze(-1).unsqueeze(-1)

        M = energy * freq_importance * channel_importance  # (B, C, H, T)

        w = M.permute(0, 3, 2, 1)  # (B, T, H, C)
        w = self.pos_linear(w)     # (B, T, H, pos_dim)

        pos_enc = self.pos_encoding[:, :T].unsqueeze(-1).unsqueeze(-1).to(w.device)  # (1, T, 1, 1)

        w = w + pos_enc
        w = w.permute(0, 2, 1, 3)  # (B, H, T, pos_dim)
        w = w.contiguous().view(B, H, self.pos_dim * T)  # (B, H, pos_dim * T)
        w = w.permute(0, 2, 1)    # (B, pos_dim * T, H): tokens of width H

        q = self.query(w)  # (B, pos_dim * T, attention_dim)
        k = self.key(w)    # (B, pos_dim * T, attention_dim)
        v = self.value(w)  # (B, pos_dim * T, attention_dim)

        # Scaled dot-product self-attention (single head, no mask).
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)  # (B, pos_dim * T, attention_dim)
        w = self.output_proj(attn_output)  # (B, pos_dim * T, embed_dim)

        return w

class ResidualBlock(nn.Module):
    """ResNet-style feature extractor with three attention alignment heads.

    Despite its name this is a whole backbone. It has a stem
    (7x7 conv -> BN -> ReLU -> 3x3 stride-2 max-pool) followed by three
    stages, each a ``ConvBlock`` plus an ``IdentityBlock``. A
    ``TimeAlignment`` head is tapped after each stage; the deepest one reads
    a feature map adaptively pooled to 10x10.

    Each head's ``input_dim`` (50, 25, 10) must equal the feature-map height
    at its tap point. That holds for an input height of 100, as in the
    ``__main__`` demo. Each ``embed_dim`` (64, 128, 256) matches the stage's
    channel count.

    Args:
        input_channels: channels of the input tensor (default 10).
    """

    def __init__(self, input_channels=10):
        super().__init__()

        # Stem: the 7x7 conv keeps the spatial size, the max-pool downsamples by 2.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn_conv1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 2: 64 channels, stride 1.
        self.conv_block2 = ConvBlock(64, 64, stride=(1, 1))
        self.identity_block2 = IdentityBlock(64, 64)

        # Stage 3: 128 channels, stride 2.
        self.conv_block3 = ConvBlock(64, 128, stride=(2, 2))
        self.identity_block3 = IdentityBlock(128, 128)

        # Stage 4: 256 channels, stride 2.
        self.conv_block4 = ConvBlock(128, 256, stride=(2, 2))
        self.identity_block4 = IdentityBlock(256, 256)

        # One alignment head per stage (registered last, in stage order).
        self.cross_alignment1 = TimeAlignment(input_dim=50, output_dim=1000, embed_dim=64)
        self.cross_alignment2 = TimeAlignment(input_dim=25, output_dim=500, embed_dim=128)
        self.cross_alignment3 = TimeAlignment(input_dim=10, output_dim=250, embed_dim=256)

    def forward(self, input_tensor):
        """Return the three aligned sequences, from shallowest to deepest stage."""
        feats = self.maxpool(self.relu(self.bn_conv1(self.conv1(input_tensor))))

        feats = self.identity_block2(self.conv_block2(feats))
        aligned_low = self.cross_alignment1(feats)

        feats = self.identity_block3(self.conv_block3(feats))
        aligned_mid = self.cross_alignment2(feats)

        feats = self.identity_block4(self.conv_block4(feats))
        pooled = F.adaptive_avg_pool2d(feats, (10, 10))
        aligned_high = self.cross_alignment3(pooled)

        return aligned_low, aligned_mid, aligned_high



if __name__ == "__main__":
    # Shape smoke test: run a random batch of 32 ten-channel 100x100 maps
    # through the backbone and report the three aligned output shapes.
    model = ResidualBlock(input_channels=10)
    test_input = torch.randn(32, 10, 100, 100)
    # Only shapes are inspected, so skip building the autograd graph. The
    # first head's attention matrix alone is (32, 1000, 1000) and would
    # otherwise be kept alive for a backward pass that never happens.
    with torch.no_grad():
        cross1, cross2, cross3 = model(test_input)
    print(f"Cross1 shape: {cross1.shape}")
    print(f"Cross2 shape: {cross2.shape}")
    print(f"Cross3 shape: {cross3.shape}")

   
