import torch
import torch.nn as nn
import math


class Conov2x2(nn.Module):
    """Unpadded 2x2 convolution with a learnable bias.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Step of the convolution window (defaults to 1).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv_options = dict(kernel_size=2, stride=stride, padding=0, bias=True)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Return the 2x2 convolution of ``x``."""
        return self.conv(x)

class DWCONV(nn.Module):
    """3x3 grouped convolution with one group per input channel.

    Uses padding 1 and a learnable bias. Because ``groups`` equals
    ``in_channels``, ``out_channels`` has to be a multiple of ``in_channels``.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Step of the convolution window (defaults to 1).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv_options = dict(
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=True,
        )
        self.depthwise = nn.Conv2d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Return the depthwise convolution of ``x``."""
        return self.depthwise(x)

class LPU(nn.Module):
    """Depthwise 3x3 convolution wrapped in an identity skip connection.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the depthwise
            convolution. The convolution output is added to the input, so
            the two shapes must be compatible for that addition.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.DWConv = DWCONV(in_channels, out_channels)

    def forward(self, x):
        """Return ``DWConv(x) + x``."""
        local_features = self.DWConv(x)
        return local_features + x

class LMHSA(nn.Module):
    """Multi-head self-attention with spatially reduced keys/values.

    Queries come from every spatial position of the (layer-normalised) input,
    while keys and values come from a strided depthwise convolution of the raw
    input, so the attention matrix has shape
    ``[B, heads, H*W, H'*W']`` with ``H' = ceil(H / stride)``. A learned bias
    ``B`` of that shape (batch dimension 1) is added to the attention logits
    before the softmax, and the input is added back to the output.

    Args:
        input_size: Height (= width) of the square feature map this layer will
            receive. It fixes the shape of the learned bias ``B``.
        channels: Number of channels of the input and of the output.
        d_k: Per-head dimension of queries and keys.
        d_v: Per-head dimension of values.
        stride: Stride of the depthwise convolutions producing keys/values.
        heads: Number of attention heads.
        dropout: Stored on the instance as ``self.dropout``; it is not applied
            anywhere in this class.
    """

    def __init__(self, input_size, channels, d_k, d_v, stride, heads, dropout):
        super(LMHSA, self).__init__()
        self.dwconv_k = DWCONV(channels, channels, stride=stride)
        self.dwconv_v = DWCONV(channels, channels, stride=stride)
        self.fc_q = nn.Linear(channels, heads * d_k)
        self.fc_k = nn.Linear(channels, heads * d_k)
        self.fc_v = nn.Linear(channels, heads * d_v)
        # The attention output concatenates `heads` value vectors of size d_v,
        # so the output projection must accept heads * d_v features (the
        # previous heads * d_k only worked when d_k == d_v).
        self.fc_o = nn.Linear(heads * d_v, channels)

        self.channels = channels
        self.d_k = d_k
        self.d_v = d_v
        self.stride = stride
        self.heads = heads
        self.dropout = dropout
        self.scaled_factor = self.d_k ** -0.5
        # Kept for backward compatibility; not read anywhere in this class.
        self.num_patches = (self.d_k // self.stride) ** 2

        # DWCONV is a 3x3 convolution with padding 1, so a side of length
        # `input_size` shrinks to floor((input_size - 1) / stride) + 1, i.e.
        # ceil(input_size / stride). Using floor division here produced a bias
        # that did not match the attention map whenever input_size was not a
        # multiple of stride (and an empty bias when input_size < stride).
        kv_size = (input_size - 1) // stride + 1
        self.B = nn.Parameter(
            torch.empty(1, self.heads, input_size ** 2, kv_size ** 2),
            requires_grad=True
        )
        nn.init.trunc_normal_(self.B, std=0.02)

    def _project_reduced(self, x, reducer, projection, head_dim):
        """Downsample ``x`` with ``reducer`` and project it into per-head vectors.

        Returns a tensor of shape ``[B, heads, H'*W', head_dim]``.
        """
        reduced = reducer(x)  # [B, C, H', W']
        n, ch, rh, rw = reduced.shape
        tokens = reduced.reshape(n, ch, rh * rw).permute(0, 2, 1)  # [B, H'*W', C]
        tokens = projection(tokens)  # [B, H'*W', heads*head_dim]
        return tokens.reshape(n, rh * rw, self.heads, head_dim).permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply attention to ``x`` of shape ``[B, C, H, W]`` and return the same shape.

        ``H`` and ``W`` must both equal the ``input_size`` given at
        construction so that the learned bias matches the attention map.
        """
        b, c, h, w = x.shape

        # reshape (rather than view) so non-contiguous inputs are accepted.
        tokens = x.reshape(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]
        # Parameter-free layer norm over the channel dimension (query path only).
        tokens = torch.nn.functional.layer_norm(tokens, (c,))

        q = self.fc_q(tokens)  # [B, H*W, heads*d_k]
        q = q.reshape(b, h * w, self.heads, self.d_k).permute(0, 2, 1, 3)  # [B, heads, H*W, d_k]

        k = self._project_reduced(x, self.dwconv_k, self.fc_k, self.d_k)  # [B, heads, H'*W', d_k]
        v = self._project_reduced(x, self.dwconv_v, self.fc_v, self.d_v)  # [B, heads, H'*W', d_v]

        # Scaled dot-product attention with the learned additive bias.
        attn = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scaled_factor  # [B, heads, H*W, H'*W']
        attn = attn + self.B
        attn = torch.softmax(attn, dim=-1)

        result = torch.matmul(attn, v)  # [B, heads, H*W, d_v]
        result = result.permute(0, 2, 1, 3).reshape(b, h * w, self.heads * self.d_v)  # [B, H*W, heads*d_v]
        result = self.fc_o(result)  # [B, H*W, C]
        result = result.permute(0, 2, 1).reshape(b, self.channels, h, w)  # [B, C, H, W]
        return result + x

class IRFFN(nn.Module):
    """Feed-forward block: 1x1 expand, 3x3 depthwise, 1x1 project, plus skip.

    Args:
        in_channels: Number of channels of the input (and of the output).
        R: Expansion ratio; the hidden width is ``int(in_channels * R)``.
    """

    def __init__(self, in_channels, R):
        super().__init__()
        hidden = int(in_channels * R)

        # Pointwise expansion to the hidden width.
        expand_layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=1),
            nn.BatchNorm2d(hidden),
            nn.GELU(),
        ]
        self.conv1 = nn.Sequential(*expand_layers)

        # Depthwise 3x3 convolution at the hidden width.
        depthwise_layers = [
            DWCONV(hidden, hidden),
            nn.BatchNorm2d(hidden),
            nn.GELU(),
        ]
        self.dwconv = nn.Sequential(*depthwise_layers)

        # Pointwise projection back to the input width (no activation).
        project_layers = [
            nn.Conv2d(hidden, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
        ]
        self.conv2 = nn.Sequential(*project_layers)

    def forward(self, x):
        """Return ``x`` plus the output of the expand/depthwise/project stack."""
        expanded = self.conv1(x)
        mixed = self.dwconv(expanded)
        projected = self.conv2(mixed)
        return x + projected

class Patch_Aggregate(nn.Module):
    """Downsampling layer: stride-2 2x2 convolution followed by layer norm.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output; falls back to
            ``in_channels`` when ``None``.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        target_channels = in_channels if out_channels is None else out_channels
        self.conv = Conov2x2(in_channels, target_channels, stride=2)
        self.init_weight()

    def init_weight(self):
        """Kaiming-initialise conv weights, zero biases, reset LayerNorm modules."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is None:
                    continue
                module.bias.data.zero_()
                continue
            if isinstance(module, nn.LayerNorm):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        """Downsample ``x`` and normalise each sample over (C, H, W)."""
        x = self.conv(x)
        channels, height, width = x.size()[1:]
        # Parameter-free layer norm across all non-batch dimensions.
        return torch.nn.functional.layer_norm(x, (channels, height, width))

class CMTStem(nn.Module):
    """Stem made of three bias-free 3x3 convolutions, each followed by GELU and BatchNorm.

    The first convolution has stride 2 (halving the spatial size); the other
    two have stride 1.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by every convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(CMTStem, self).__init__()

        def conv3x3(source_channels, stride):
            # All stem convolutions share kernel size, padding and bias setting.
            return nn.Conv2d(
                source_channels, out_channels, kernel_size=3,
                stride=stride, padding=1, bias=False
            )

        self.conv1 = conv3x3(in_channels, 2)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = conv3x3(out_channels, 1)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv3 = conv3x3(out_channels, 1)
        self.gelu3 = nn.GELU()
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Kaiming-initialise conv weights, zero biases, reset BatchNorm to identity."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def forward(self, x):
        """Run ``x`` through the three conv -> GELU -> BatchNorm units."""
        units = (
            (self.conv1, self.gelu1, self.bn1),
            (self.conv2, self.gelu2, self.bn2),
            (self.conv3, self.gelu3, self.bn3),
        )
        for conv, activation, norm in units:
            x = norm(activation(conv(x)))
        return x


def weights_init(m):
    """Initialise a single module in place; intended for ``Module.apply``.

    Conv2d and Linear weights are drawn from N(0, 0.02) and their biases (when
    present) set to 0. BatchNorm2d weights are set to 1 and biases to 0. Any
    other module type is left untouched.
    """
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
        return

    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)


class CMTBlock(nn.Module):
    """One block: LPU, then (optionally) LMHSA, then IRFFN.

    Args:
        img_size: Spatial side length handed to LMHSA as its input size.
        stride: Key/value reduction stride handed to LMHSA.
        d_k: Per-head query/key dimension handed to LMHSA.
        d_v: Per-head value dimension handed to LMHSA.
        num_heads: Number of attention heads handed to LMHSA.
        R: Expansion ratio of the IRFFN.
        in_channels: Channel count of the block's input and output.
        apply_lmhsa: When false, no LMHSA is created and the attention step
            is skipped in ``forward``.
    """

    def __init__(self, img_size, stride, d_k, d_v, num_heads, R=3.6, in_channels=46, apply_lmhsa=True):
        super().__init__()

        self.lpu = LPU(in_channels, in_channels)

        self.apply_lmhsa = apply_lmhsa
        if apply_lmhsa:
            # Dropout is fixed to 0.0 for the attention layer.
            self.lmhsa = LMHSA(img_size, in_channels, d_k, d_v, stride, num_heads, 0.0)

        self.irffn = IRFFN(in_channels, R)

    def forward(self, x):
        """Apply the block's sub-modules to ``x`` in sequence."""
        features = self.lpu(x)
        if not self.apply_lmhsa:
            return self.irffn(features)
        return self.irffn(self.lmhsa(features))

# Encoder built from the stem, CMT blocks and patch-aggregation layers defined above.
class Encoder(nn.Module):
    """
    Encoder class that encodes the input image into a feature map.

    The stem halves the spatial size once. Each following stage is a CMTBlock
    followed by a Patch_Aggregate that halves the spatial size again and
    doubles the channel count. A final CMTBlock (plus a 1x1 convolution when
    the channel count differs from ``out_channels``) produces the output.

    Args:
        in_channels: Number of channels of the input image.
        img_size: Side length of the (square) input image.
        out_channels: Number of channels of the encoded feature map.
        out_img_size: Side length of the encoded feature map.

    Raises:
        ValueError: If the sizes are not positive, if
            ``img_size / out_img_size`` is not a power of 2, or if it is
            smaller than 2 (the stem always downsamples once, so an equal
            input and output size cannot be produced).
    """

    def __init__(self, in_channels, img_size, out_channels, out_img_size):
        super(Encoder, self).__init__()
        # Explicit exceptions instead of `assert`, which is stripped under -O.
        if img_size <= 0 or out_img_size <= 0:
            raise ValueError("img_size and out_img_size must be positive")
        ratio, remainder = divmod(img_size, out_img_size)
        if remainder != 0 or (ratio & (ratio - 1)) != 0:
            raise ValueError("img_size / out_img_size must be a power of 2")
        if ratio < 2:
            # Previously ratio == 1 slipped through, yielding num_stages == -1
            # and an output of size img_size // 2 instead of out_img_size.
            raise ValueError(
                "img_size / out_img_size must be at least 2 because the stem always downsamples once"
            )

        # Exact integer log2 of the ratio; subtract 1 because the stem
        # already downsamples once.
        num_stages = (ratio.bit_length() - 1) - 1
        stem_channels = 32  # Adjust as needed
        self.stem = CMTStem(in_channels, stem_channels)
        current_channels = stem_channels
        current_img_size = img_size // 2  # Size after Stem

        # Define stages and downsampling modules
        self.stages = nn.ModuleList()
        self.patch_aggregates = nn.ModuleList()
        for _ in range(num_stages):
            next_channels = current_channels * 2
            self.stages.append(
                nn.Sequential(self._make_block(current_channels, current_img_size))
            )
            # Downsampling
            self.patch_aggregates.append(
                Patch_Aggregate(in_channels=current_channels, out_channels=next_channels)
            )
            current_channels = next_channels
            current_img_size = current_img_size // 2

        # Last stage, no downsampling; adjust channels to out_channels if needed.
        final_layers = [self._make_block(current_channels, current_img_size)]
        if current_channels != out_channels:
            final_layers.append(nn.Conv2d(current_channels, out_channels, kernel_size=1))
        self.stages.append(nn.Sequential(*final_layers))
        self.current_img_size = current_img_size

    @staticmethod
    def _make_block(channels, spatial_size):
        """Build the CMTBlock used by every stage for the given width and size."""
        # Apply LMHSA only when the spatial size is 32 or less to save memory;
        # when it is applied, keys/values are reduced with stride 4.
        apply_lmhsa = spatial_size <= 32
        return CMTBlock(
            img_size=spatial_size,
            stride=4 if apply_lmhsa else 1,
            d_k=8,
            d_v=8,
            num_heads=1,
            R=3.6,
            in_channels=channels,
            apply_lmhsa=apply_lmhsa
        )

    def forward(self, x):
        """Encode ``x`` and return the final feature map."""
        x = self.stem(x)
        # `stages` holds one more entry than `patch_aggregates`; zip pairs each
        # downsampling stage with its aggregate and leaves the last stage out.
        for stage, patch_aggregate in zip(self.stages, self.patch_aggregates):
            x = patch_aggregate(stage(x))
        return self.stages[-1](x)



class Decoder(nn.Module):
    """Upsample an encoded feature map back to the image resolution.

    Each upsampling step is a stride-2 2x2 transposed convolution followed by
    BatchNorm and ReLU, doubling the spatial size. The channel count is halved
    at each step while it is still larger than ``in_channels``; a final 1x1
    convolution maps to ``in_channels`` if the last step did not already.

    Args:
        out_channels: Number of channels of the encoded feature map (the
            decoder's input).
        out_img_size: Side length of the encoded feature map.
        in_channels: Number of channels of the reconstructed image (the
            decoder's output).
        img_size: Side length of the reconstructed image.

    Raises:
        ValueError: If the sizes are not positive or
            ``img_size / out_img_size`` is not a power of 2.
    """

    def __init__(self, out_channels, out_img_size, in_channels, img_size):
        super(Decoder, self).__init__()

        # Explicit exceptions instead of `assert`, which is stripped under -O
        # (and a zero out_img_size used to surface as ZeroDivisionError).
        if img_size <= 0 or out_img_size <= 0:
            raise ValueError("img_size and out_img_size must be positive")
        ratio, remainder = divmod(img_size, out_img_size)
        if remainder != 0 or (ratio & (ratio - 1)) != 0:
            raise ValueError("img_size / out_img_size must be a power of 2")
        # Exact integer log2 of the (power-of-two) ratio.
        num_upsamples = ratio.bit_length() - 1

        self.upsamples = nn.ModuleList()
        current_channels = out_channels
        for _ in range(num_upsamples):
            next_channels = current_channels // 2 if current_channels > in_channels else in_channels
            self.upsamples.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        current_channels, next_channels, kernel_size=2, stride=2
                    ),
                    nn.BatchNorm2d(next_channels),
                    nn.ReLU(inplace=True)
                )
            )
            current_channels = next_channels

        # Halving can land below in_channels, so fix up the width if needed.
        if current_channels != in_channels:
            self.final_conv = nn.Conv2d(current_channels, in_channels, kernel_size=1)
        else:
            self.final_conv = nn.Identity()

    def forward(self, x):
        """Apply every upsampling step, then the final channel adjustment."""
        for upsample in self.upsamples:
            x = upsample(x)
        return self.final_conv(x)



class GIAO(nn.Module):
    """Autoencoder pairing an ``Encoder`` with a mirrored ``Decoder``.

    Args:
        in_channels: Number of channels of the input image.
        img_size: Side length of the input image.
        out_channels: Number of channels of the encoded feature map.
        out_img_size: Side length of the encoded feature map.
    """

    def __init__(self, in_channels, img_size, out_channels, out_img_size):
        super().__init__()
        self.encoder = Encoder(in_channels, img_size, out_channels, out_img_size)
        # The decoder takes the encoder's output description first.
        self.decoder = Decoder(out_channels, out_img_size, in_channels, img_size)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction


if __name__ == "__main__":
    # Smoke test: build the autoencoder, push one random batch through it
    # and print the tensor shapes at each end.
    in_channels, img_size = 3, 256
    out_channels, out_img_size = 8, 64

    autoencoder = GIAO(in_channels, img_size, out_channels, out_img_size)
    autoencoder.apply(weights_init)

    x = torch.randn(4, in_channels, img_size, img_size)
    encoder, output = autoencoder(x)

    report = (
        f"Input shape: {x.shape}",
        f'Encoder shape: {encoder.shape}',
        f"Output shape: {output.shape}",
    )
    for line in report:
        print(line)
