import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

# helpers
def pair(t):
    """Return ``t`` itself when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """Build fixed 2-D sine/cosine positional embeddings for an ``h x w`` grid.

    Args:
        h: number of grid rows.
        w: number of grid columns.
        dim: embedding width; must be a multiple of 4 because it is split into
            four equal parts: sin(x), cos(x), sin(y), cos(y).
        temperature: base of the geometric frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)``; row ``i * w + j`` embeds grid cell
        (row ``i``, column ``j``).
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    quarter = dim // 4
    # Exponents are spread evenly over [0, 1]. With a single frequency
    # (dim == 4) the denominator would be 0 and 0 / 0 yields NaN, so clamp it.
    omega = torch.arange(quarter) / max(quarter - 1, 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# Prunable classes
class PrunableFeedForward(nn.Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Linear.

    Both projections are created with ``linear_layer``, which receives
    ``prune_reg`` and ``task_mode`` as keyword arguments. No residual
    connection is applied here; the caller adds it.
    """

    def __init__(self, dim, hidden_dim, linear_layer, prune_reg='weight', task_mode='harp_prune'):
        super().__init__()
        prune_kwargs = dict(prune_reg=prune_reg, task_mode=task_mode)
        self.norm = nn.LayerNorm(dim)
        self.linear1 = linear_layer(dim, hidden_dim, **prune_kwargs)
        self.gelu = nn.GELU()
        self.linear2 = linear_layer(hidden_dim, dim, **prune_kwargs)

    def forward(self, x):
        # expand to hidden_dim, apply the non-linearity, project back to dim
        hidden = self.gelu(self.linear1(self.norm(x)))
        return self.linear2(hidden)

class PrunableAttention(nn.Module):
    """Pre-norm multi-head self-attention with bias-free projections.

    Args:
        dim: model width (size of the input's last dimension).
        heads: number of attention heads.
        dim_head: width of each head; the inner width is ``heads * dim_head``.
        linear_layer: factory for the q/k/v and output projections. It is
            called with ``prune_reg`` / ``task_mode`` keywords, except for the
            plain ``nn.Linear`` default, which does not accept them.
        prune_reg, task_mode: forwarded to ``linear_layer``.
    """

    def __init__(self, dim, heads=8, dim_head=64, linear_layer=nn.Linear, prune_reg='weight', task_mode='harp_prune'):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)

        # nn.Linear rejects the pruning keywords; passing them made the
        # default linear_layer raise TypeError.
        if linear_layer is nn.Linear:
            prune_kwargs = {}
        else:
            prune_kwargs = dict(prune_reg=prune_reg, task_mode=task_mode)
        self.to_qkv = linear_layer(dim, inner_dim * 3, bias=False, **prune_kwargs)
        self.to_out = linear_layer(inner_dim, dim, bias=False, **prune_kwargs)

    def _split_heads(self, t):
        # (b, n, h * d) -> (b, h, n, d)
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x):
        """Attend over the sequence ``x``; expects a 3-D ``(batch, tokens, dim)`` input."""
        x = self.norm(x)

        q, k, v = map(self._split_heads, self.to_qkv(x).chunk(3, dim=-1))

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        # (b, h, n, d) -> (b, n, h * d)
        out = out.transpose(1, 2).flatten(2)
        return self.to_out(out)

class PrunableTransformer(nn.Module):
    """Stack of ``depth`` pre-norm residual blocks followed by a final LayerNorm.

    ``self.layers`` holds one ``nn.ModuleList([attention, feed_forward])`` per
    block. ``linear_layer``, ``prune_reg`` and ``task_mode`` are passed through
    unchanged to every attention and feed-forward sub-module.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, linear_layer, prune_reg='weight', task_mode='harp_prune'):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        shared = dict(linear_layer=linear_layer, prune_reg=prune_reg, task_mode=task_mode)
        self.layers = nn.ModuleList(
            nn.ModuleList([
                PrunableAttention(dim, heads=heads, dim_head=dim_head, **shared),
                PrunableFeedForward(dim, mlp_dim, **shared),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        # residual connections around each sub-block, then the final norm
        for attention, feed_forward in self.layers:
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)

class PrunableSimpleViT(nn.Module):
    """SimpleViT (fixed sin-cos positions, mean pooling) built from prunable linear layers.

    Every linear projection (the patch embedding, attention, MLP and
    classification head) is created with ``linear_layer`` and receives
    ``prune_reg`` / ``task_mode``. The default ``nn.Linear`` does not accept
    those keywords, so it is wrapped in an adapter that drops them.

    Args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of each patch; must divide image_size.
        num_classes: number of output logits.
        dim: transformer width.
        depth: number of transformer blocks.
        heads: attention heads per block.
        mlp_dim: hidden width of each feed-forward block.
        channels: number of input image channels.
        dim_head: width of each attention head.
        linear_layer: factory for linear layers (see above).
        mean, std: per-channel normalization statistics. They are applied in
            ``forward`` only when ``normalize`` is True and both are given.
        normalize: whether to normalize inputs with ``mean`` / ``std``.
        prune_reg, task_mode: forwarded to every ``linear_layer`` call.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 channels=3, dim_head=64, linear_layer=nn.Linear, mean=None, std=None,
                 normalize=False, prune_reg='weight', task_mode='harp_prune'):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        # Every layer below (including those built inside the transformer)
        # passes prune_reg / task_mode, which plain nn.Linear rejects.
        make_linear = self._drop_prune_kwargs if linear_layer is nn.Linear else linear_layer

        # Normalization statistics (like the ResNet models). Non-persistent
        # buffers move with model.to()/.cuda() but stay out of state_dict,
        # so existing checkpoints still load.
        self.normalize = normalize
        if mean is not None and std is not None:
            self.register_buffer(
                'mean', torch.as_tensor(mean, dtype=torch.float32).reshape(-1, 1, 1), persistent=False)
            self.register_buffer(
                'std', torch.as_tensor(std, dtype=torch.float32).reshape(-1, 1, 1), persistent=False)
        else:
            self.register_buffer('mean', None, persistent=False)
            self.register_buffer('std', None, persistent=False)

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            make_linear(patch_dim, dim, prune_reg=prune_reg, task_mode=task_mode),
            nn.LayerNorm(dim),
        )

        self.register_buffer(
            'pos_embedding',
            posemb_sincos_2d(
                h=image_height // patch_height,
                w=image_width // patch_width,
                dim=dim,
            ),
            persistent=False,
        )

        self.transformer = PrunableTransformer(dim, depth, heads, dim_head, mlp_dim,
                                               make_linear, prune_reg=prune_reg, task_mode=task_mode)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = make_linear(dim, num_classes, prune_reg=prune_reg, task_mode=task_mode)

    @staticmethod
    def _drop_prune_kwargs(in_features, out_features, prune_reg=None, task_mode=None, **kwargs):
        """Build a plain ``nn.Linear``, ignoring the pruning-only keywords."""
        return nn.Linear(in_features, out_features, **kwargs)

    def forward(self, img):
        """Classify ``img`` (assumed ``(batch, channels, H, W)``); returns ``(batch, num_classes)`` logits."""
        device = img.device

        if self.normalize and self.mean is not None and self.std is not None:
            img = (img - self.mean.to(device)) / self.std.to(device)

        x = self.to_patch_embedding(img)
        x += self.pos_embedding.to(device, dtype=x.dtype)

        x = self.transformer(x)
        x = x.mean(dim=1)  # mean-pool over patch tokens

        x = self.to_latent(x)
        return self.linear_head(x)

    def set_prune_rate(self, k, global_k, alpha, device):
        """Forward pruning hyper-parameters to every prunable linear layer.

        Layers without a ``set_prune_rate`` method (e.g. plain ``nn.Linear``)
        are skipped. The patch embedding and all transformer projections get
        ``(k, global_k, alpha, device)``. The classification head is always
        called with ``(1, 1, 1.0, device)``, presumably to keep it dense;
        confirm against the linear layer's implementation.
        """
        body_layers = [self.to_patch_embedding[2]]
        for attn, ff in self.transformer.layers:
            body_layers.extend((attn.to_qkv, attn.to_out, ff.linear1, ff.linear2))
        for layer in body_layers:
            if hasattr(layer, 'set_prune_rate'):
                layer.set_prune_rate(k, global_k, alpha, device)

        if hasattr(self.linear_head, 'set_prune_rate'):
            self.linear_head.set_prune_rate(1, 1, 1.0, device)


# Helper functions to create different ViT variants with pruning support for CIFAR-10
def prunable_vit_tiny_cifar10(conv_layer, linear_layer, init_type="kaiming_normal", **kwargs):
    """Create a tiny prunable ViT (dim 192, depth 6, 3 heads) for CIFAR-10.

    ``conv_layer`` and ``init_type`` are accepted but unused. Any keyword in
    ``kwargs`` overrides the matching default below. For example,
    ``normalize=True`` turns on built-in normalization with the CIFAR-10 statistics.
    """
    config = dict(
        image_size=32,
        patch_size=4,
        num_classes=10,
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2471, 0.2435, 0.2616],
        normalize=False,
        dim=192,
        depth=6,
        heads=3,
        mlp_dim=768,
        linear_layer=linear_layer,
    )
    return PrunableSimpleViT(**{**config, **kwargs})

def prunable_vit_small_cifar10(conv_layer, linear_layer, init_type="kaiming_normal", **kwargs):
    """Create a small prunable ViT (dim 384, depth 8, 6 heads) for CIFAR-10.

    ``conv_layer`` and ``init_type`` are accepted but unused. Any keyword in
    ``kwargs`` overrides the matching default below.
    """
    config = dict(
        image_size=32,
        patch_size=4,
        num_classes=10,
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2471, 0.2435, 0.2616],
        normalize=False,
        dim=384,
        depth=8,
        heads=6,
        mlp_dim=1536,
        linear_layer=linear_layer,
    )
    return PrunableSimpleViT(**{**config, **kwargs})

def prunable_vit_base_cifar10(conv_layer, linear_layer, init_type="kaiming_normal", **kwargs):
    """Create a base prunable ViT (dim 512, depth 10, 8 heads) for CIFAR-10.

    ``conv_layer`` and ``init_type`` are accepted but unused. Any keyword in
    ``kwargs`` overrides the matching default below.
    """
    config = dict(
        image_size=32,
        patch_size=4,
        num_classes=10,
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2471, 0.2435, 0.2616],
        normalize=False,
        dim=512,
        depth=10,
        heads=8,
        mlp_dim=2048,
        linear_layer=linear_layer,
    )
    return PrunableSimpleViT(**{**config, **kwargs})

# Variants sized for larger datasets (ImageNet, 224x224 inputs), kept for compatibility
def prunable_vit_tiny_imagenet(linear_layer, init_type="kaiming_normal", **kwargs):
    """Create a tiny prunable ViT (dim 192, depth 12, 3 heads) for ImageNet.

    Unlike the CIFAR-10 factories, this takes no ``conv_layer`` argument.
    ``init_type`` is unused. Any keyword in ``kwargs`` overrides the matching default.
    """
    config = dict(
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=192,
        depth=12,
        heads=3,
        mlp_dim=768,
        linear_layer=linear_layer,
    )
    return PrunableSimpleViT(**{**config, **kwargs})

def prunable_vit_small_imagenet(linear_layer, init_type="kaiming_normal", **kwargs):
    """Create a small prunable ViT (dim 384, depth 12, 6 heads) for ImageNet.

    Unlike the CIFAR-10 factories, this takes no ``conv_layer`` argument.
    ``init_type`` is unused. Any keyword in ``kwargs`` overrides the matching default.
    """
    config = dict(
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=384,
        depth=12,
        heads=6,
        mlp_dim=1536,
        linear_layer=linear_layer,
    )
    return PrunableSimpleViT(**{**config, **kwargs})

def prunable_vit_base_imagenet(linear_layer, init_type="kaiming_normal", **kwargs):
    """Create a base prunable ViT (dim 768, depth 12, 12 heads) for ImageNet.

    Unlike the CIFAR-10 factories, this takes no ``conv_layer`` argument.
    ``init_type`` is unused. Any keyword in ``kwargs`` overrides the matching default.
    """
    config = dict(
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=768,
        depth=12,
        heads=12,
        mlp_dim=3072,
        linear_layer=linear_layer,
    )
    return PrunableSimpleViT(**{**config, **kwargs})

# Example usage:
def test_prunable_vit():
    """Smoke test: build a small CIFAR-10 ViT, set prune rates and run one forward pass.

    Uses ``SubnetLinear`` from the project-local ``subnet_layers`` module,
    prints the output shape and parameter count, and returns the model.
    """
    from subnet_layers import SubnetLinear  # project pruning layers

    # The CIFAR-10 factories take ``conv_layer`` as a required first argument
    # (unused by the ViT); omitting it raised TypeError.
    model = prunable_vit_small_cifar10(
        conv_layer=None,
        linear_layer=SubnetLinear,
        prune_reg='weight',
        task_mode='harp_prune',
    )

    # Set pruning rates
    model.set_prune_rate(k=0.8, global_k=0.5, alpha=0.1, device='cpu')

    # Test forward pass
    x = torch.randn(1, 3, 32, 32)
    y = model(x)
    print(f"Output shape: {y.shape}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model


if __name__ == "__main__":
    test_prunable_vit()