# shallow_titan.py
import torch
import torch.nn as nn
import torch.nn.functional as F


# Multi-scale Attention Module
class MultiScaleAttention(nn.Module):
    """Multi-head self-attention over a feature map with multi-kernel q/k/v.

    Queries, keys and values come from three parallel convolutions (1x1, 3x3
    and 5x5, each padded to preserve H and W). Their outputs are averaged.
    Scaled dot-product attention is then applied across all H*W spatial
    positions, followed by a 1x1 output projection.

    Args:
        dim: Number of input and output channels.
        num_heads: Number of attention heads; must be a positive divisor of
            ``dim``.

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of ``dim``.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        # Fail fast with a clear message instead of a reshape error in forward().
        if num_heads < 1 or dim % num_heads != 0:
            raise ValueError(
                f"num_heads must be a positive divisor of dim, "
                f"got dim={dim}, num_heads={num_heads}"
            )
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Each conv emits q, k and v stacked along channels: [q | k | v].
        self.qkv1 = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.qkv3 = nn.Conv2d(dim, dim * 3, kernel_size=3, padding=1)
        self.qkv5 = nn.Conv2d(dim, dim * 3, kernel_size=5, padding=2)

        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Apply multi-scale self-attention.

        Args:
            x: Feature map of shape (B, dim, H, W).

        Returns:
            Tensor of shape (B, dim, H, W).

        Note:
            The attention matrix has shape (B, num_heads, H*W, H*W), so memory
            grows quadratically with the number of spatial positions.
        """
        B, C, H, W = x.shape
        # Averaging before splitting is identical to splitting each branch and
        # then averaging, and needs only one reshape.
        qkv = (self.qkv1(x) + self.qkv3(x) + self.qkv5(x)) / 3.0
        qkv = qkv.reshape(B, 3, self.num_heads, self.head_dim, H * W)
        # Each of q, k, v becomes (B, num_heads, H*W, head_dim).
        q, k, v = qkv.transpose(-2, -1).unbind(dim=1)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # (B, heads, H*W, head_dim) -> (B, heads, head_dim, H*W) -> (B, C, H, W)
        out = (attn @ v).transpose(-2, -1).reshape(B, C, H, W)
        return self.proj(out)


# ShallowTitan Model
class ShallowTitan(nn.Module):
    """Image classifier built from a conv stem and stacked attention stages.

    Pipeline:
      1. Stem: 7x7 stride-2 conv, BatchNorm, SiLU, then a stride-2 max-pool.
         Together these downsample H and W by 4.
      2. ``num_paths`` parallel attention stages. Their outputs are averaged
         and added back to the stem features as a residual.
      3. ``depth - 1`` sequential residual attention stages.
      4. Global average pooling and a linear classifier.

    Each attention stage is ``MultiScaleAttention -> BatchNorm2d -> SiLU``.

    Args:
        num_classes: Size of the output score vector.
        embed_dim: Channel width used throughout the network. Must be
            divisible by ``num_heads``.
        depth: Total attention depth; ``depth - 1`` sequential stages are
            stacked after the parallel paths (none if ``depth <= 1``).
        num_heads: Heads per attention stage (default 16, as originally
            hard-coded).
        num_paths: Number of parallel attention paths (default 2, as
            originally hard-coded).

    Raises:
        ValueError: If ``num_paths < 1``. Averaging zero paths would
            otherwise divide by zero in ``forward``.
    """

    def __init__(self, num_classes=1000, embed_dim=2048, depth=3,
                 num_heads=16, num_paths=2):
        super().__init__()
        if num_paths < 1:
            raise ValueError(f"num_paths must be >= 1, got {num_paths}")

        self.embedding = nn.Conv2d(3, embed_dim, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(embed_dim)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Module names and creation order match the original layout, so
        # state_dict keys and seeded initialization are unchanged.
        self.attention_paths = nn.ModuleList(
            [self._attention_block(embed_dim, num_heads) for _ in range(num_paths)]
        )
        self.additional_layers = nn.ModuleList(
            [self._attention_block(embed_dim, num_heads) for _ in range(depth - 1)]
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(embed_dim, num_classes)

    @staticmethod
    def _attention_block(embed_dim, num_heads):
        """Build one attention -> BatchNorm -> SiLU stage (submodule keys 0, 1, 2)."""
        return nn.Sequential(
            MultiScaleAttention(embed_dim, num_heads=num_heads),
            nn.BatchNorm2d(embed_dim),
            nn.SiLU(),
        )

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Input tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, num_classes) holding the classifier's raw
            (un-normalized) outputs; no softmax is applied.
        """
        x = self.pool(F.silu(self.bn1(self.embedding(x))))

        # Parallel paths see the same input; their mean is added residually.
        path_outputs = [pathway(x) for pathway in self.attention_paths]
        x = x + sum(path_outputs) / len(path_outputs)

        for layer in self.additional_layers:
            x = x + layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

