# https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

def pair(x):
    """Return ``x`` unchanged if it is a tuple, otherwise the 2-tuple ``(x, x)``.

    Used to accept either a single size (e.g. ``224``) or an explicit
    ``(height, width)`` tuple. Only ``tuple`` instances are passed through;
    any other value, including a list, is duplicated.
    """
    if isinstance(x, tuple):
        return x
    return (x, x)

class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: computes ``fn(LayerNorm(x)) + x``.

    Args:
        dim: size of the last axis of the input, normalised by
            ``nn.LayerNorm(dim)``.
        fn: module applied to the normalised input; its output must be
            addable to ``x`` (same or broadcast-compatible shape).
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Registration order (fn, then norm) is kept so state_dict key order
        # matches previously saved checkpoints.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        skip = x
        normalized = self.norm(x)
        # Add the untouched input back after the wrapped transform (residual).
        return self.fn(normalized) + skip

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
    """Build a two-layer MLP: dense -> GELU -> Dropout -> dense -> Dropout.

    Args:
        dim: input and output width.
        expansion_factor: multiplier for the hidden width; the hidden size is
            ``int(dim * expansion_factor)`` (fractional factors are truncated).
        dropout: drop probability for both ``nn.Dropout`` layers.
        dense: factory called as ``dense(in_features, out_features)`` that
            returns a module, e.g. ``nn.Linear`` or a 1x1 ``nn.Conv1d`` partial.

    Returns:
        ``nn.Sequential`` whose two dense layers sit at indices 0 and 3.
    """
    hidden = int(dim * expansion_factor)
    # Layers are created in this exact order so parameter init (RNG use) and
    # Sequential indices stay stable.
    layers = [
        dense(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)

# def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
#     image_h, image_w = pair(image_size)
#     assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
#     num_patches = (image_h // patch_size) * (image_w // patch_size)
#     chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear

#     return nn.Sequential(
#         Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
#         nn.Linear((patch_size ** 2) * channels, dim),
#         *[nn.Sequential(
#             PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
#             PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
#         ) for _ in range(depth)],
#         nn.LayerNorm(dim),
#         Reduce('b n c -> b c', 'mean'),
#         nn.Linear(dim, num_classes)
#     )

class MLPMixer(nn.Module):
    """Module version of the lucidrains MLP-Mixer that can also return
    per-layer activations.

    Input is rearranged from ``(batch, channels, H, W)`` into non-overlapping
    patches, embedded to ``dim``, passed through ``depth`` mixer layers,
    layer-normed, mean-pooled over patches and projected to ``num_classes``.

    Args:
        image_size: int, or ``(height, width)`` tuple.
        channels: number of input channels.
        patch_size: int, or ``(patch_h, patch_w)`` tuple. Each image side must
            be divisible by the matching patch side.
        dim: embedding width of each patch.
        depth: number of mixer layers.
        num_classes: size of the output logits.
        expansion_factor: hidden-width multiplier of the token-mixing MLP
            (the one acting across patches).
        expansion_factor_token: hidden-width multiplier of the channel-mixing
            MLP (the one acting on the ``dim`` axis). NOTE: these two names
            follow upstream and read as swapped relative to what they control.
        dropout: dropout probability inside every MLP.
    """

    def __init__(self, *, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor=4, expansion_factor_token=0.5, dropout=0.):
        super().__init__()

        image_h, image_w = pair(image_size)
        # Accept a square int patch or a (patch_h, patch_w) tuple, mirroring
        # image_size. An int produces exactly the same layers as before.
        patch_h, patch_w = pair(patch_size)
        # NOTE(review): assert is stripped under `python -O`; kept to preserve
        # the AssertionError callers may already rely on.
        assert (image_h % patch_h) == 0 and (image_w % patch_w) == 0, 'image must be divisible by patch size'
        num_patches = (image_h // patch_h) * (image_w // patch_w)
        # Token mixing uses a 1x1 Conv1d, which treats axis 1 (patches) as
        # channels; channel mixing uses Linear on the last axis (dim).
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear

        # (b, c, H, W) -> (b, num_patches, patch_h * patch_w * c)
        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_h, p2=patch_w)
        self.patch_embedding = nn.Linear(patch_h * patch_w * channels, dim)

        self.mixer_layers = nn.ModuleList([
            nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)
        ])

        self.layer_norm = nn.LayerNorm(dim)
        # Mean over the patch axis: (b, n, dim) -> (b, dim)
        self.reduce = Reduce('b n c -> b c', 'mean')
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x, return_intermediates=False):
        """Classify ``x`` (expected layout ``(batch, channels, H, W)``).

        Args:
            x: input image batch.
            return_intermediates: when True, also return a list holding a
                clone of the output of every mixer layer (each
                ``(batch, num_patches, dim)``), followed by a clone of the
                final logits.

        Returns:
            Logits of shape ``(batch, num_classes)``, or
            ``(logits, intermediates)`` when ``return_intermediates`` is True.
        """
        intermediates = []

        # Split into patches, then embed each patch to `dim`.
        x = self.rearrange(x)
        x = self.patch_embedding(x)

        for layer in self.mixer_layers:
            x = layer(x)
            if return_intermediates:
                intermediates.append(x.clone())

        # Final norm, global mean pooling over patches, classification head.
        x = self.layer_norm(x)
        x = self.reduce(x)
        x = self.fc(x)

        if return_intermediates:
            intermediates.append(x.clone())
            return x, intermediates
        return x

