import torch
from torch import nn, einsum
import torch.nn.functional
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

from inspect import isfunction


# Code adapted for 3D (volumetric) data

def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)


def default(val, d):
    """Return `val` unless it is None, otherwise fall back to `d`.

    A fallback that `inspect.isfunction` recognises as a Python function is
    called with no arguments and its result is returned; any other fallback
    is returned unchanged.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val


class PreNormResidual(nn.Module):
    """Wrap `fn` with a pre-LayerNorm and a residual connection.

    `forward(x)` computes `fn(norm(x)) + x`. When a context `c` is given it is
    normalised with the same LayerNorm and passed to `fn` as a second
    positional argument: `fn(norm(x), norm(c)) + x`.

    Args:
        dim: size of the trailing feature axis normalised by `nn.LayerNorm`.
        fn: callable / module applied to the normalised input(s).
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, c=None):
        normed_x = self.norm(x)
        if not exists(c):
            return self.fn(normed_x) + x
        # x and c share one LayerNorm (same affine parameters).
        return self.fn(normed_x, self.norm(c)) + x


class SqueezeExcitation(nn.Module):
    """Channel-wise gating for 5D feature maps laid out as `(b, c, d, h, w)`.

    The gate averages the input over its three trailing axes, passes the
    per-channel means through a bias-free bottleneck MLP
    (`dim -> int(dim * shrinkage_rate) -> dim`, SiLU in between), squashes
    the result with a sigmoid and reshapes it to `(b, c, 1, 1, 1)` so it
    broadcasts over the input.

    Args:
        dim: number of channels `c` of the input.
        shrinkage_rate: bottleneck width as a fraction of `dim`.
    """

    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        bottleneck = int(dim * shrinkage_rate)

        # Layer order (and therefore the Sequential indices) must stay fixed
        # so that parameter names in the state dict do not change.
        gate_layers = [
            Reduce('b c d h w -> b c', 'mean'),
            nn.Linear(dim, bottleneck, bias = False),
            nn.SiLU(),
            nn.Linear(bottleneck, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1 1'),
        ]
        self.gate = nn.Sequential(*gate_layers)

    def forward(self, x):
        channel_weights = self.gate(x)
        return x * channel_weights


class FeedForward(nn.Module):
    """Two-layer MLP applied to the trailing feature axis.

    Structure: Linear(dim -> int(dim * mult)) -> GELU -> Dropout ->
    Linear(int(dim * mult) -> dim) -> Dropout.

    Args:
        dim: input and output feature size.
        mult: expansion factor of the hidden layer.
        dropout: dropout probability used after both linear layers' paths.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = int(dim * mult)
        stages = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out


class Attention(nn.Module):
    """Multi-head attention inside 3D windows with a learned relative position bias.

    `forward` expects `x` (and the optional context `c`) with eight axes,
    `(b, x, y, z, w1, w2, w3, d)`. The first four axes are folded into the
    batch, the three window axes are flattened into the token axis, and
    attention is computed between the tokens of each window. Queries come
    from `x`; keys and values come from `c` (or from `x` when `c` is None).

    Args:
        dim: feature size `d`; must be divisible by `dim_head`.
        dim_head: feature size per attention head (`heads = dim // dim_head`).
        dropout: dropout probability applied to the attention weights and to
            the output projection.
        window_size: window extent, either one int used for all three axes or
            a sequence `(w1, w2, w3)`. It fixes the size of the relative
            position bias table, so inputs must use the same window extent.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        # The default (and any int) means a cubic window. Previously an int
        # was unpacked directly, which raised a TypeError.
        if isinstance(window_size, int):
            w1 = w2 = w3 = window_size
        else:
            w1, w2, w3 = window_size

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_k = nn.Linear(dim, dim, bias = False)
        self.to_v = nn.Linear(dim, dim, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # One learned bias per head for every possible 3D offset between two
        # positions of a window: offsets range over (2*w - 1) values per axis.
        self.rel_pos_bias = nn.Embedding((2 * w1 - 1) * (2 * w2 - 1) * (2 * w3 - 1), self.heads)
        pos1 = torch.arange(w1)
        pos2 = torch.arange(w2)
        pos3 = torch.arange(w3)

        # Coordinates of every window position, flattened in (i j k) order so
        # they line up with the '(w1 w2 w3)' token flattening used in forward.
        grid = torch.stack(torch.meshgrid(pos1, pos2, pos3, indexing = 'ij'))
        grid = rearrange(grid, 'c i j k -> (i j k) c')

        # Pairwise coordinate differences, shifted to be non-negative.
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos[..., 0] += w1 - 1
        rel_pos[..., 1] += w2 - 1
        rel_pos[..., 2] += w3 - 1

        # Row-major linearisation of the shifted 3D offset into an embedding index.
        strides = torch.tensor([(2 * w2 - 1) * (2 * w3 - 1), (2 * w3 - 1), 1])
        rel_pos_indices = (rel_pos * strides).sum(dim = -1)

        # Derived purely from window_size, hence not stored in the state dict.
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x, c=None):
        """Attend within each window.

        Args:
            x: query source with axes `(b, x, y, z, w1, w2, w3, d)`.
            c: optional key/value source with the same axis layout; defaults
                to `x` (self-attention).

        Returns:
            Tensor with the same axis layout as `x`.
        """
        c = default(c, x)
        _, height, width, depth, window_height, window_width, window_depth, _ = x.shape
        h = self.heads

        # flatten: windows go to the batch axis, window positions become tokens
        x = rearrange(x, 'b x y z w1 w2 w3 d -> (b x y z) (w1 w2 w3) d')
        c = rearrange(c, 'b x y z w1 w2 w3 d -> (b x y z) (w1 w2 w3) d')

        # project for queries, keys, values
        q = self.to_q(x)
        k = self.to_k(c)
        v = self.to_v(c)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale
        q = q * self.scale

        # sim
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias (looked up as [tokens, tokens, heads])
        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention
        attn = self.attend(sim)

        # aggregate
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads and restore the three window axes
        out = rearrange(out, 'b h (w1 w2 w3) d -> b w1 w2 w3 (h d)', w1 = window_height, w2 = window_width, w3 = window_depth)

        # combine heads out
        out = self.to_out(out)
        return rearrange(out, '(b x y z) ... -> b x y z ...', x = height, y = width, z = depth)


class Dropsample(nn.Module):
    """Randomly zero whole samples of a batch during training.

    In training mode each sample (index along axis 0) is kept with
    probability `1 - prob`; kept samples are rescaled by `1 / (1 - prob)`.
    In eval mode, or when `prob == 0`, the input is returned unchanged.

    Args:
        prob: probability of dropping a sample.
    """

    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        if self.prob == 0. or (not self.training):
            return x

        # One random draw per sample, broadcast over every remaining axis.
        # torch.rand takes a *shape*; the legacy torch.FloatTensor(tuple)
        # constructor used before interpreted the tuple as data (yielding a
        # 1-D tensor of its values) and does not accept non-CPU devices.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        keep_mask = torch.rand(mask_shape, device = x.device) > self.prob
        return x * keep_mask / (1 - self.prob)


class MBConvResidual(nn.Module):
    """Residual wrapper: `dropsample(fn(x)) + x`.

    Args:
        fn: module applied to the input; its output must be addable to `x`.
        dropout: probability handed to `Dropsample`, which is applied to the
            output of `fn` before the skip connection is added.
    """

    def __init__(self, fn, dropout=0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x


def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate=4,
    shrinkage_rate=0.25,
    dropout=0.
):
    """Build a 3D MBConv block (pointwise expand, depthwise conv, SE, pointwise project).

    Args:
        dim_in: input channels.
        dim_out: output channels.
        downsample: when truthy the depthwise convolution uses stride 2,
            otherwise stride 1.
        expansion_rate: hidden width is `int(expansion_rate * dim_out)`.
        shrinkage_rate: bottleneck ratio forwarded to `SqueezeExcitation`.
        dropout: sample-drop probability of the residual wrapper; only used
            when the residual wrapper is applied.

    Returns:
        The `nn.Sequential` stack, wrapped in `MBConvResidual` when the block
        preserves the channel count and does not downsample.
    """
    width = int(expansion_rate * dim_out)
    depthwise_stride = 1 if not downsample else 2

    stages = [
        nn.Conv3d(dim_in, width, 1),
        nn.BatchNorm3d(width),
        nn.GELU(),
        # groups == channels makes this a depthwise convolution
        nn.Conv3d(width, width, 3, stride=depthwise_stride, padding=1, groups=width),
        nn.BatchNorm3d(width),
        nn.GELU(),
        SqueezeExcitation(width, shrinkage_rate=shrinkage_rate),
        nn.Conv3d(width, dim_out, 1),
        nn.BatchNorm3d(dim_out),
    ]
    net = nn.Sequential(*stages)

    # A skip connection is only possible when input and output shapes agree.
    if dim_in != dim_out or downsample:
        return net
    return MBConvResidual(net, dropout=dropout)


class MaxAttentionBlock(nn.Module):
    """Block attention followed by grid attention on a `(b, d, X, Y, Z)` volume.

    Each stage rearranges the volume into windows, applies a pre-norm
    residual `Attention` and a pre-norm residual `FeedForward`, and restores
    the original layout. In the block stage a window holds adjacent voxels
    ('(x w1)' split); in the grid stage a window holds voxels strided across
    the volume ('(w1 x)' split).

    Args:
        in_channels: stored on the module; not otherwise used here. The
            attention layers operate on `dim_head * heads` features.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability for attention and feed-forward layers.
        window_size: an int (cubic window) or a sequence `(w1, w2, w3)`.
    """

    def __init__(self, in_channels, heads=8, dim_head=64, dropout=0., window_size=4):
        super().__init__()
        self.in_channels = in_channels
        self.heads = heads
        self.dim_head = dim_head
        self.dropout = dropout
        self.window_size = window_size

        window = (window_size,) * 3 if isinstance(window_size, int) else tuple(window_size)
        w1, w2, w3 = window
        axes = dict(w1=w1, w2=w2, w3=w3)

        layer_dim = dim_head * heads

        def make_attention():
            return PreNormResidual(
                layer_dim,
                Attention(dim=layer_dim, dim_head=dim_head, dropout=dropout, window_size=(w1, w2, w3)),
            )

        def make_feed_forward():
            return PreNormResidual(layer_dim, FeedForward(dim=layer_dim, dropout=dropout))

        # block-like attention
        self.rearrange_block_in = Rearrange('b d (x w1) (y w2) (z w3) -> b x y z w1 w2 w3 d', **axes)
        self.attn_block = make_attention()
        self.ff_block = make_feed_forward()
        self.rearrange_block_out = Rearrange('b x y z w1 w2 w3 d -> b d (x w1) (y w2) (z w3)')

        # grid-like attention
        self.rearrange_grid_in = Rearrange('b d (w1 x) (w2 y) (w3 z) -> b x y z w1 w2 w3 d', **axes)
        self.attn_grid = make_attention()
        self.ff_grid = make_feed_forward()
        self.rearrange_grid_out = Rearrange('b x y z w1 w2 w3 d -> b d (w1 x) (w2 y) (w3 z)')

    def forward(self, x):
        # Block stage first, then grid stage; every step maps x to x.
        pipeline = (
            self.rearrange_block_in,
            self.attn_block,
            self.ff_block,
            self.rearrange_block_out,
            self.rearrange_grid_in,
            self.attn_grid,
            self.ff_grid,
            self.rearrange_grid_out,
        )
        for step in pipeline:
            x = step(x)
        return x


class MultiDirectionAttention(nn.Module):
    """Windowed 3D attention over tokens that also carry a direction axis.

    `forward` expects nine axes, `(b, x, y, z, md, w1, w2, w3, d)`. The
    direction axis `md` and the three window axes are flattened together
    into the token axis (direction outermost), so every token attends to
    all positions of all directions in its window. Two additive biases are
    applied to the attention logits:

    * a learned relative position bias that depends only on the spatial
      offset between two tokens (shared across direction pairs), and
    * a learned per-head `[directions, directions]` matrix, initialised to
      the identity, that depends only on the two tokens' directions.

    Args:
        dim: feature size `d`; must be divisible by `dim_head`.
        dim_head: feature size per head (`heads = dim // dim_head`).
        dropout: dropout probability for attention weights and output projection.
        window_size: an int (cubic window) or a sequence `(w1, w2, w3)`.
        num_directions: size of the direction axis `md` the module is built
            for; the input's `md` must match it.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7,
        num_directions = 10
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        # The default (and any int) means a cubic window. Previously an int
        # was unpacked directly, which raised a TypeError.
        if isinstance(window_size, int):
            w1 = w2 = w3 = window_size
        else:
            w1, w2, w3 = window_size

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_k = nn.Linear(dim, dim, bias = False)
        self.to_v = nn.Linear(dim, dim, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # One learned bias per head for every possible 3D offset between two
        # positions of a window: offsets range over (2*w - 1) values per axis.
        self.rel_pos_bias = nn.Embedding((2 * w1 - 1) * (2 * w2 - 1) * (2 * w3 - 1), self.heads)
        pos1 = torch.arange(w1)
        pos2 = torch.arange(w2)
        pos3 = torch.arange(w3)

        grid = torch.stack(torch.meshgrid(pos1, pos2, pos3, indexing = 'ij'))
        grid = rearrange(grid, 'c i j k -> (i j k) c')

        # Pairwise coordinate differences, shifted to be non-negative.
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos[..., 0] += w1 - 1
        rel_pos[..., 1] += w2 - 1
        rel_pos[..., 2] += w3 - 1

        # Row-major linearisation of the shifted 3D offset into an embedding index.
        strides = torch.tensor([(2 * w2 - 1) * (2 * w3 - 1), (2 * w3 - 1), 1])
        rel_pos_indices = (rel_pos * strides).sum(dim = -1)

        # Derived purely from window_size, hence not stored in the state dict.
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

        # Learnable direction-pair bias, shape [heads, nd, nd], initialised to
        # the identity so that same-direction token pairs start with +1 logit.
        direction_embeddings = torch.diag_embed(torch.ones(self.heads, num_directions))
        self.direction_embeddings = nn.Parameter(direction_embeddings)

    def forward(self, x, c=None):
        """Attend across directions and window positions.

        Args:
            x: query source with axes `(b, x, y, z, md, w1, w2, w3, d)`.
            c: optional key/value source with the same layout; defaults to `x`.

        Returns:
            Tensor with the same axis layout as `x`.
        """
        c = default(c, x)
        _, height, width, depth, nd, window_height, window_width, window_depth, _ = x.shape
        h = self.heads

        # flatten: direction is the outermost factor of the token axis
        x = rearrange(x, 'b x y z md w1 w2 w3 d -> (b x y z) (md w1 w2 w3) d')
        c = rearrange(c, 'b x y z md w1 w2 w3 d -> (b x y z) (md w1 w2 w3) d')

        # project for queries, keys, values
        q = self.to_q(x)
        k = self.to_k(c)
        v = self.to_v(c)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale
        q = q * self.scale

        # compute similarity matrix
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # spatial bias: [L, L, heads] -> [heads, L, L] with L = w1*w2*w3
        spatial_bias = self.rel_pos_bias(self.rel_pos_indices)
        spatial_bias = rearrange(spatial_bias, 'i j h -> h i j')
        tokens_per_direction = window_height * window_width * window_depth

        # tokens are ordered (md, spatial), so tile the spatial bias once per
        # direction pair to reach [heads, nd*L, nd*L]
        spatial_bias = repeat(spatial_bias, 'h i j -> h (nd1 i) (nd2 j)', nd1=nd, nd2=nd)

        # direction bias: [heads, nd, nd]; each entry is stretched over the
        # L x L block of its direction pair -> [heads, nd*L, nd*L]
        dir_bias = self.direction_embeddings
        dir_bias = dir_bias.repeat_interleave(tokens_per_direction, dim=1).repeat_interleave(tokens_per_direction, dim=2)

        sim = sim + dir_bias + spatial_bias

        # attention
        attn = self.attend(sim)

        # aggregate
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads and restore direction / window axes
        out = rearrange(out, 'b h (md w1 w2 w3) d -> b md w1 w2 w3 (h d)', w1 = window_height, w2 = window_width, w3 = window_depth, md = nd)

        # combine heads out
        out = self.to_out(out)
        return rearrange(out, '(b x y z) ... -> b x y z ...', x = height, y = width, z = depth)

class MultiDirectionMaxAttentionBlock(nn.Module):
    """Block + grid attention for inputs whose batch axis packs 10 directions.

    Input and output are laid out as `((b md), d, X, Y, Z)` with `md = 10`.
    For each stage the direction factor is split off the batch axis and
    moved next to the window axes, a pre-norm residual
    `MultiDirectionAttention` and `FeedForward` are applied, and the layout
    is restored. The block stage uses windows of adjacent voxels
    ('(x w1)' split); the grid stage uses strided windows ('(w1 x)' split).

    Args:
        in_channels: stored on the module; not otherwise used here. The
            attention layers operate on `dim_head * heads` features.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability for attention and feed-forward layers.
        window_size: an int (cubic window) or a sequence `(w1, w2, w3)`.
    """

    def __init__(self, in_channels, heads=8, dim_head=64, dropout=0., window_size=4):
        super().__init__()
        self.in_channels = in_channels
        self.heads = heads
        self.dim_head = dim_head
        self.dropout = dropout
        self.window_size = window_size

        window = (window_size,) * 3 if isinstance(window_size, int) else tuple(window_size)
        w1, w2, w3 = window

        layer_dim = dim_head * heads
        nd = 10  # number of directions packed into the batch axis
        axes = dict(md=nd, w1=w1, w2=w2, w3=w3)

        def make_attention():
            return PreNormResidual(
                layer_dim,
                MultiDirectionAttention(dim=layer_dim, dim_head=dim_head, dropout=dropout, window_size=(w1, w2, w3)),
            )

        def make_feed_forward():
            return PreNormResidual(layer_dim, FeedForward(dim=layer_dim, dropout=dropout))

        # block-like attention
        self.rearrange_block_in = Rearrange('(b md) d (x w1) (y w2) (z w3) -> b x y z md w1 w2 w3 d', **axes)
        self.attn_block = make_attention()
        self.ff_block = make_feed_forward()
        self.rearrange_block_out = Rearrange('b x y z md w1 w2 w3 d -> (b md) d (x w1) (y w2) (z w3)')

        # grid-like attention
        self.rearrange_grid_in = Rearrange('(b md) d (w1 x) (w2 y) (w3 z) -> b x y z md w1 w2 w3 d', **axes)
        self.attn_grid = make_attention()
        self.ff_grid = make_feed_forward()
        self.rearrange_grid_out = Rearrange('b x y z md w1 w2 w3 d -> (b md) d (w1 x) (w2 y) (w3 z)')

    def forward(self, x):
        # Block stage first, then grid stage; every step maps x to x.
        pipeline = (
            self.rearrange_block_in,
            self.attn_block,
            self.ff_block,
            self.rearrange_block_out,
            self.rearrange_grid_in,
            self.attn_grid,
            self.ff_grid,
            self.rearrange_grid_out,
        )
        for step in pipeline:
            x = step(x)
        return x

class SpatialCrossAttentionWithMax(nn.Module):
    """Block + grid cross-attention between a volume `x` and a context volume.

    The context is first projected to `dim_head * heads` channels by an
    MBConv. In each stage (block, then grid) both tensors are rearranged
    into windows, `x` attends to the context windows through a pre-norm
    residual `Attention`, passes a pre-norm residual `FeedForward`, and is
    rearranged back.

    Args:
        in_channels: accepted but not used in this class.
        heads: number of attention heads.
        dim_head: feature size per head.
        ctx_dim: channel count of the context; when None the projection is
            built for `dim_head * heads * 4` input channels.
        dropout: dropout probability for attention and feed-forward layers.
        window_size: an int (cubic window) or a sequence `(w1, w2, w3)`.
    """

    def __init__(self, in_channels, heads=8, dim_head=64, ctx_dim=None, dropout=0., window_size=4):
        super().__init__()
        window = (window_size,) * 3 if isinstance(window_size, int) else tuple(window_size)
        w1, w2, w3 = window
        axes = dict(w1=w1, w2=w2, w3=w3)

        layer_dim = dim_head * heads

        # NOTE(review): with ctx_dim=None the projection expects 4 * layer_dim
        # channels, while forward() falls back to x as context — confirm that
        # callers always pass a matching context in that configuration.
        proj_channels = layer_dim * 4 if ctx_dim is None else ctx_dim
        self.proj_in = MBConv(proj_channels, layer_dim, downsample=False)

        def make_attention():
            return PreNormResidual(
                layer_dim,
                Attention(dim=layer_dim, dim_head=dim_head, dropout=dropout, window_size=(w1, w2, w3)),
            )

        def make_feed_forward():
            return PreNormResidual(layer_dim, FeedForward(dim=layer_dim, dropout=dropout))

        # block-like attention
        self.rearrange_block_in = Rearrange('b d (x w1) (y w2) (z w3) -> b x y z w1 w2 w3 d', **axes)
        self.attn_block = make_attention()
        self.ff_block = make_feed_forward()
        self.rearrange_block_out = Rearrange('b x y z w1 w2 w3 d -> b d (x w1) (y w2) (z w3)')

        # grid-like attention
        self.rearrange_grid_in = Rearrange('b d (w1 x) (w2 y) (w3 z) -> b x y z w1 w2 w3 d', **axes)
        self.attn_grid = make_attention()
        self.ff_grid = make_feed_forward()
        self.rearrange_grid_out = Rearrange('b x y z w1 w2 w3 d -> b d (w1 x) (w2 y) (w3 z)')

        # Registered but not called by forward() in this class.
        self.out_conv = nn.Sequential(
            SqueezeExcitation(dim=layer_dim*2),
            nn.Conv3d(layer_dim*2, layer_dim, kernel_size=3, padding=1)
        )

    def forward(self, x, context=None):
        context = default(context, x)

        # MBConv projection of the context to the attention width
        c = self.proj_in(context)

        # block attention
        x, c = self.rearrange_block_in(x), self.rearrange_block_in(c)
        x = self.ff_block(self.attn_block(x, c))
        x, c = self.rearrange_block_out(x), self.rearrange_block_out(c)

        # grid attention
        x, c = self.rearrange_grid_in(x), self.rearrange_grid_in(c)
        x = self.ff_grid(self.attn_grid(x, c))
        return self.rearrange_grid_out(x)


class SpatialTransformerWithMax(nn.Module):
    """Block + grid cross-attention with automatic context down-sampling.

    The context volume is average-pooled to the spatial size of `x` when
    their sizes differ along axis 2, projected to `n_heads * d_head`
    channels by an MBConv, and then used as key/value source in a
    block-attention stage followed by a grid-attention stage.

    Args:
        in_channels: stored on the module; not otherwise used here.
        n_heads: number of attention heads.
        d_head: feature size per head.
        dropout: dropout probability for attention and feed-forward layers.
        context_dim: channel count of the context. When None it falls back
            to `n_heads * d_head`, the channel count `x` must have for the
            attention layers, which matches `forward` using `x` itself as
            context when none is given.
        w: window extent, an int (cubic window) or a sequence `(w1, w2, w3)`.
    """

    def __init__(self, in_channels, n_heads, d_head, dropout=0., context_dim=None, w=2):
        super().__init__()
        inner_dim = n_heads * d_head

        # Previously None was handed straight to the MBConv / Conv3d
        # constructor, so the default configuration could not be built.
        context_dim = default(context_dim, inner_dim)

        self.in_channels = in_channels
        self.context_dim = context_dim

        self.proj_in = MBConv(context_dim, inner_dim, downsample=False)

        if isinstance(w, int):
            w1 = w2 = w3 = w
        else:
            w1, w2, w3 = w

        self.rearrange_block_in = Rearrange('b d (x w1) (y w2) (z w3) -> b x y z w1 w2 w3 d', w1=w1, w2=w2, w3=w3)  # block-like attention
        self.attn_block = PreNormResidual(inner_dim, Attention(dim=inner_dim, dim_head=d_head, dropout=dropout, window_size=(w1, w2, w3)))
        self.ff_block = PreNormResidual(inner_dim, FeedForward(dim=inner_dim, dropout=dropout))
        self.rearrange_block_out = Rearrange('b x y z w1 w2 w3 d -> b d (x w1) (y w2) (z w3)')

        self.rearrange_grid_in = Rearrange('b d (w1 x) (w2 y) (w3 z) -> b x y z w1 w2 w3 d', w1=w1, w2=w2, w3=w3)  # grid-like attention
        self.attn_grid = PreNormResidual(inner_dim, Attention(dim=inner_dim, dim_head=d_head, dropout=dropout, window_size=(w1, w2, w3)))
        self.ff_grid = PreNormResidual(inner_dim, FeedForward(dim=inner_dim, dropout=dropout))
        self.rearrange_grid_out = Rearrange('b x y z w1 w2 w3 d -> b d (w1 x) (w2 y) (w3 z)')

    def forward(self, x, context=None):
        """Cross-attend `x` to `context` (or to itself when context is None).

        Args:
            x: volume laid out as `(b, d, X, Y, Z)`.
            context: optional context volume `(b, context_dim, X', Y', Z')`.

        Returns:
            Tensor with the same layout as `x`.
        """
        context = default(context, x)

        # down sample context if necessary; the ratio measured on axis 2 is
        # applied to all three spatial axes
        if context.shape[2] != x.shape[2]:
            stride = context.shape[2] // x.shape[2]
            context = torch.nn.functional.avg_pool3d(context, kernel_size=stride, stride=stride)

        # MBConv projection of the context to the attention width
        c = self.proj_in(context)

        # block attention
        x = self.rearrange_block_in(x)
        c = self.rearrange_block_in(c)
        x = self.attn_block(x, c)
        x = self.ff_block(x)
        x = self.rearrange_block_out(x)
        c = self.rearrange_block_out(c)

        # grid attention
        x = self.rearrange_grid_in(x)
        c = self.rearrange_grid_in(c)
        x = self.attn_grid(x, c)
        x = self.ff_grid(x)
        x = self.rearrange_grid_out(x)

        return x


class MultiDirectionCNNBlock(nn.Module):
    """Two-branch grouped 3D CNN block with a grouped 1x1 stage and a skip path.

    The input is laid out as `((b d), c, h, w, z)` with `d = num_directions`
    packed into the batch axis and `c = in_channels`. A local branch
    (kernel `kernel_size`) and a global branch (kernel `kernel_size + 4`)
    are summed, passed through the "direction mixer", projected to
    `out_channels` and added to the (optionally projected) input.

    Args:
        in_channels: input channels; the grouped convolutions in both
            branches use `groups=num_directions`.
        out_channels: output channels; a falsy value means `in_channels`.
        kernel_size: kernel of the local branch's first convolution.
        num_directions: number of directions packed into the batch axis and
            number of groups used by the grouped convolutions.
    """

    @staticmethod
    def _conv_branch(channels, kernel, groups):
        # grouped k x k x k conv -> BN -> GELU -> 1x1x1 conv -> BN -> GELU,
        # with 'same'-style padding of kernel // 2
        return nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size=kernel, padding=kernel//2, groups=groups),
            nn.BatchNorm3d(channels),
            nn.GELU(),
            nn.Conv3d(channels, channels, kernel_size=1),
            nn.BatchNorm3d(channels),
            nn.GELU(),
        )

    def __init__(self, in_channels, out_channels=None, kernel_size=3, num_directions=10):
        super().__init__()
        self.num_directions = num_directions
        if not out_channels:
            out_channels = in_channels

        # Local branch: small receptive field
        self.local_branch = self._conv_branch(in_channels, kernel_size, num_directions)

        # Global branch: larger receptive field
        self.global_branch = self._conv_branch(in_channels, kernel_size+4, num_directions)

        # Direction mixing layer: unpack directions from the batch axis into
        # the channel axis, apply a grouped 1x1x1 conv, and pack them back.
        # NOTE(review): with '(d c)' channel ordering and groups=num_directions
        # each group covers the channels of a single direction, so this conv
        # does not appear to exchange information between directions — confirm intent.
        stacked_channels = in_channels*num_directions
        self.direction_mixer = nn.Sequential(
            Rearrange('(b d) c h w z -> b (d c) h w z', d=num_directions),
            nn.Conv3d(stacked_channels, stacked_channels, kernel_size=1, groups=num_directions),
            nn.BatchNorm3d(stacked_channels),
            nn.GELU(),
            Rearrange('b (d c) h w z -> (b d) c h w z', d=num_directions)
        )

        # Final projection
        self.output_projection = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm3d(out_channels)
        )

        # Skip connection: identity when shapes agree, 1x1x1 conv otherwise
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # Run both branches on the same input and sum them
        local_features = self.local_branch(x)
        global_features = self.global_branch(x)
        fused = self.direction_mixer(local_features + global_features)

        # Project and add the skip path computed from the untouched input
        return self.output_projection(fused) + self.skip(x)