
import torch
import torch.nn as nn
from vit_pytorch import ViT
from einops.layers.torch import Rearrange
import torch.nn.functional as F


class EquivariantLayer(nn.Module):
    """
    Permutation-equivariant DeepSets layer.

    ``out = phi_self(x) + phi_mean(max over dim -2 of x)``. Every element of
    the set (dim -2) goes through ``phi_self``. The element-wise max over the
    set goes through ``phi_mean`` and is broadcast back onto every element.
    Permuting dim -2 of the input therefore permutes dim -2 of the output the
    same way.

    Args:
        d_in: size of the input's last dimension.
        d_hidden: size of the output's last dimension.
    """

    @staticmethod
    def _mlp(width_in, width_out):
        """Build a two-layer ``Linear -> ReLU -> Linear`` MLP."""
        return nn.Sequential(
            nn.Linear(width_in, width_out),
            nn.ReLU(),
            nn.Linear(width_out, width_out),
        )

    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.phi_self = self._mlp(d_in, d_hidden)
        # Named "phi_mean" but fed a max-pool. The attribute name is kept
        # because it determines the state_dict keys.
        self.phi_mean = self._mlp(d_in, d_hidden)

    def forward(self, x):
        """
        Args:
            x: tensor ``[..., N, d_in]`` whose dim -2 is the set dimension.
                The shape comments in this file use
                ``[B, num_patches, channels, num_pixels]``.

        Returns:
            Tensor ``[..., N, d_hidden]``.
        """
        per_element = self.phi_self(x)
        set_max = x.max(dim=-2, keepdim=True).values  # [..., 1, d_in]
        return per_element + self.phi_mean(set_max)

class InvariantLayer(nn.Module):
    """
    Permutation-invariant DeepSets readout.

    Max-pools over the set dimension (dim -2, which is removed). The pooled
    vector then goes through ``rho`` (``Linear -> ReLU -> Linear``), e.g.
    ``[B, N, d_in] -> [B, d_out]``.

    Args:
        d_in: size of the input's last dimension.
        d_out: size of the output's last dimension.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        layers = [nn.Linear(d_in, d_out), nn.ReLU(), nn.Linear(d_out, d_out)]
        self.rho = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``rho(max of x over dim -2)``; the result drops dim -2 of ``x``."""
        # Max-pooling (not mean-pooling) over the set dimension.
        pooled = x.max(dim=-2).values
        return self.rho(pooled)

class DeepSets(nn.Module):
    """
    DeepSets stack: ``K`` equivariant layers plus an optional invariant readout.

    Each EquivariantLayer is followed by BatchNorm1d over the last (feature)
    dimension, then ReLU, then dropout. The set dimension that the
    equivariant and invariant layers pool over is dim -2 of the input.

    Args:
        d_in: last-dimension size of the input (unused when ``K == 0``).
        d_hidden: feature size produced by every equivariant layer. It is also
            the input size of the invariant readout.
        d_out: output size of the invariant readout (unused when ``inv=False``).
        dropout: dropout probability applied after each equivariant layer.
            ``None`` or ``0`` disables dropout.
        K: number of equivariant layers.
        inv: if True, append an InvariantLayer readout.
    """

    def __init__(self, d_in, d_hidden, d_out, dropout, K=2, inv=True):
        super().__init__()
        self.equivariant_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = dropout
        self.inv = inv

        for i in range(K):
            self.equivariant_layers.append(
                EquivariantLayer(d_in if i == 0 else d_hidden, d_hidden)
            )
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))
        if self.inv:
            self.invariant_layer = InvariantLayer(d_hidden, d_out)

    def forward(self, x):
        """
        Args:
            x: tensor whose dim -2 is the set dimension and whose last
                dimension has size ``d_in`` (``d_hidden`` when ``K == 0``).
                Examples: ``[B, N, d_in]`` or
                ``[B, num_patches, channels, d_in]``.

        Returns:
            With ``inv=True``: the input with dim -2 pooled away and the last
            dimension mapped to ``d_out``. With ``inv=False``: the same
            leading shape as ``x`` with last dimension ``d_hidden``.
        """
        for layer, bn in zip(self.equivariant_layers, self.batch_norms):
            x = layer(x)
            # BatchNorm1d expects [batch, features]. Fold every leading dim
            # into the batch so statistics are per feature, for any rank.
            shape = x.shape
            x = bn(x.reshape(-1, shape[-1])).reshape(shape)
            x = F.relu(x)
            if self.dropout:
                # F.dropout defaults to training=True. Tie it to the module
                # mode so dropout is disabled under model.eval().
                x = F.dropout(x, p=self.dropout, training=self.training)
        if self.inv:
            return self.invariant_layer(x)
        return x



class DSS_Equiv(nn.Module):
    """
    One DSS equivariant layer built from two single-layer DeepSets branches.

    ``out = phi_self(x) + phi_mean(max over dim 1 of x)``. ``phi_self`` sees
    every element along dim 1. ``phi_mean`` sees only their element-wise max
    (kept as a size-1 dim), and its result is broadcast back along dim 1.
    Both branches use ``inv=False``, so there is no invariant readout here.

    Args:
        d_in: last-dimension size of the input.
        d_hidden: last-dimension size of the output.
        dropout: dropout probability used inside both branches.
    """

    def __init__(self, d_in, d_hidden, dropout):
        super().__init__()
        branch_cfg = dict(
            d_in=d_in,
            d_hidden=d_hidden,
            d_out=None,
            dropout=dropout,
            K=1,
            inv=False,
        )
        self.phi_self = DeepSets(**branch_cfg)
        # Fed a max-pool despite the name; the name fixes the state_dict keys.
        self.phi_mean = DeepSets(**branch_cfg)

    def forward(self, x):
        """
        x: presumably ``[B, num_patches, channels, d_in]``, the layout used by
        the shape comments in this file (confirm for other callers).

        Returns: the same leading shape as ``x``, with last dimension
        ``d_hidden``.
        """
        per_element = self.phi_self(x)
        pooled = x.max(dim=1, keepdim=True).values  # dim 1 reduced to size 1
        return per_element + self.phi_mean(pooled)


class DSS_Inv(nn.Module):
    """
    DSS readout built from two invariant-only DeepSets branches (``K=0``).

    With ``K=0`` each branch consists only of its InvariantLayer: a max over
    dim -2 (removed), then an MLP to ``d_out``. ``phi_self`` is applied to
    ``x``, and ``phi_mean`` to the max of ``x`` over dim 1 (kept as size 1).
    The two results are summed with broadcasting.

    Args:
        d_in: last-dimension size of the input. It is passed as the
            branches' ``d_hidden``, which is the InvariantLayer input size.
        d_out: last-dimension size of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        branch_cfg = dict(d_in=None, d_hidden=d_in, d_out=d_out, dropout=None, K=0)
        self.phi_self = DeepSets(**branch_cfg)
        self.phi_mean = DeepSets(**branch_cfg)

    def forward(self, x):
        """
        x: presumably ``[B, num_patches, channels, d_in]`` (confirm).

        For that layout, ``phi_self`` gives ``[B, num_patches, d_out]`` and
        ``phi_mean`` gives ``[B, 1, d_out]``, so the sum is
        ``[B, num_patches, d_out]``.
        """
        per_element = self.phi_self(x)
        pooled = x.max(dim=1, keepdim=True).values
        return per_element + self.phi_mean(pooled)




class DSS(nn.Module):
    """
    DSS model: ``K`` DSS_Equiv layers followed by a DSS_Inv readout.

    Each DSS_Equiv layer is followed by BatchNorm1d over the last (feature)
    dimension, then ReLU, then dropout.

    Args:
        d_in: last-dimension size of the input.
        d_hidden: feature size produced by every DSS_Equiv layer.
        d_out: last-dimension size of the readout output.
        dropout: dropout probability after each layer. ``None`` or ``0``
            disables it here.
        K: number of DSS_Equiv layers.
    """

    def __init__(self, d_in, d_hidden, d_out, dropout, K=2):
        super().__init__()
        self.equivariant_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = dropout

        for i in range(K):
            self.equivariant_layers.append(
                DSS_Equiv(d_in=d_in if i == 0 else d_hidden, d_hidden=d_hidden, dropout=dropout)
            )
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        self.invariant_layer = DSS_Inv(d_hidden, d_out)

    def forward(self, x):
        """
        Args:
            x: presumably ``[B, num_patches, channels, d_in]``, the patch
                layout produced in SdInvVit (confirm for other callers).

        Returns:
            The DSS_Inv output; for the layout above,
            ``[B, num_patches, d_out]``.
        """
        for layer, bn in zip(self.equivariant_layers, self.batch_norms):
            x = layer(x)
            # BatchNorm1d expects [batch, features]. Fold every leading dim
            # into the batch so statistics are per feature, for any rank.
            shape = x.shape
            x = bn(x.reshape(-1, shape[-1])).reshape(shape)
            x = F.relu(x)
            if self.dropout:
                # F.dropout defaults to training=True. Tie it to the module
                # mode so dropout is disabled under model.eval().
                x = F.dropout(x, p=self.dropout, training=self.training)
        return self.invariant_layer(x)




class SdInvVit(ViT):
    """
    ViT whose patch embedding is a set model over the channels of each patch.

    The image is cut into ``patch_size`` patches. Each patch is laid out as
    ``[channels, patch_height * patch_width]`` and fed to ``DS_model``. Its
    output is layer-normed, linearly projected and layer-normed again. Every
    other argument is forwarded unchanged to ``ViT.__init__``.

    DS_model: class/callable taking ``d_in``, ``d_hidden``, ``d_out`` and
    ``dropout`` keyword arguments (``DeepSets`` or ``DSS`` in this file).
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0., DS_model=DeepSets):
        super().__init__(
            image_size=image_size, patch_size=patch_size, num_classes=num_classes,
            dim=dim, depth=depth, heads=heads, mlp_dim=mlp_dim, pool=pool,
            channels=channels, dim_head=dim_head, dropout=dropout,
            emb_dropout=emb_dropout,
        )
        if isinstance(patch_size, tuple):
            patch_height, patch_width = patch_size
        else:
            patch_height = patch_width = patch_size
        pixels_per_patch = patch_height * patch_width

        # Replace the parent's patch embedding: each patch becomes a set of
        # channel vectors of length pixels_per_patch.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) c (p1 p2)', p1=patch_height, p2=patch_width),
            DS_model(d_in=pixels_per_patch, d_hidden=dim, d_out=dim, dropout=dropout),
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )




class PatchEmbeddingWithPos(nn.Module):
    """
    Linear patch embedding with a learned within-patch positional embedding.

    The image is cut into ``patch_height x patch_width`` patches. A learned
    embedding indexed by (row-in-patch, col-in-patch, channel) is added to
    the raw pixels of every patch; the same embedding is shared by all
    patches. Each patch is then flattened, layer-normed, projected to
    ``hidden_dim`` and layer-normed again.

    Args:
        patch_height: patch height in pixels.
        patch_width: patch width in pixels.
        patch_dim: flattened patch size, ``patch_height * patch_width * channels``.
        hidden_dim: output embedding size.

    Raises:
        ValueError: if ``patch_dim`` is not a multiple of
            ``patch_height * patch_width``.
    """

    def __init__(self, patch_height, patch_width, patch_dim, hidden_dim):
        super().__init__()
        pixels_per_patch = patch_height * patch_width
        if patch_dim % pixels_per_patch != 0:
            raise ValueError(
                f"patch_dim ({patch_dim}) must be a multiple of "
                f"patch_height * patch_width ({pixels_per_patch})"
            )
        channels = patch_dim // pixels_per_patch

        self.rearrange1 = Rearrange('b c (h p1) (w p2) -> b h w p1 p2 c', p1=patch_height, p2=patch_width)
        self.rearrange2 = Rearrange('b h w p1 p2 c -> b (h w) (p1 p2 c)')
        # Added to pixels laid out as [B, h, w, p1, p2, c], so its trailing
        # dims must be (p1, p2, channels). Sizing the last dim as hidden_dim
        # would only broadcast when hidden_dim == channels.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, patch_height, patch_width, channels))

        self.norm1 = nn.LayerNorm(patch_dim)
        self.linear = nn.Linear(patch_dim, hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """
        Args:
            x: image batch ``[B, c, H, W]``, with H and W multiples of the
                patch size.

        Returns:
            ``[B, num_patches, hidden_dim]``.
        """
        x = self.rearrange1(x)      # [B, h, w, p1, p2, c]
        x = x + self.pos_embedding  # broadcast over B, h, w
        x = self.rearrange2(x)      # [B, h*w, patch_dim]
        x = self.norm1(x)
        x = self.linear(x)          # [B, h*w, hidden_dim]
        return self.norm2(x)
