# pytorch_diffusion + derived encoder decoder (3D version)
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from functools import partial
import torch.nn.init as init

from ldm.modules.attention import LinearAttention, SpatialCrossAttentionWithPosEmb
from ldm.modules.maxvit import SpatialCrossAttentionWithMax, MaxAttentionBlock


def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Frequencies are ``exp(-k * log(10000) / (half - 1))`` for ``k`` in
    ``[0, half)`` with ``half = embedding_dim // 2``. The first ``half``
    columns hold the sines of ``timesteps * freq`` and the next ``half`` hold
    the cosines. An odd ``embedding_dim`` gets one trailing zero column.

    Args:
        timesteps: 1-D tensor of timesteps (any numeric dtype; cast to float32).
        embedding_dim: width of the returned embedding.

    Returns:
        float32 tensor of shape ``(len(timesteps), embedding_dim)`` on
        ``timesteps.device``.

    Note:
        ``embedding_dim`` of 2 or 3 makes ``half - 1`` zero, so the frequency
        scale raises ``ZeroDivisionError``.
    """
    assert timesteps.dim() == 1
    half = embedding_dim // 2
    log_step = math.log(10000) / (half - 1)
    # Frequencies are computed on the CPU in float32 and then moved to the timesteps' device.
    freqs = torch.exp(torch.arange(half, dtype=torch.float32) * -log_step).to(device=timesteps.device)
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 != 1:
        return emb
    return torch.nn.functional.pad(emb, (0, 1, 0, 0))


def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate


def Normalize(in_channels, num_groups=32):
    """Return an affine ``GroupNorm`` over ``in_channels`` channels (eps=1e-6).

    ``in_channels`` must be divisible by ``num_groups`` (a GroupNorm requirement).
    """
    return torch.nn.GroupNorm(
        num_channels=in_channels,
        num_groups=num_groups,
        affine=True,
        eps=1e-6,
    )


class IdentityWrapper(nn.Module):
    """No-op stand-in for a cross-attention module.

    It accepts the same ``(x, context)`` call signature but ignores
    ``context`` and returns ``x`` unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Identity()

    def forward(self, x, context=None):
        # ``context`` exists only for call-signature compatibility.
        del context
        return self.layer(x)


class Upsample(nn.Module):
    """Double the three spatial dims (D, H, W) of a 5-D tensor.

    Uses trilinear interpolation, optionally followed by a 3x3x3
    channel-preserving convolution when ``with_conv`` is truthy.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="trilinear", align_corners=False
        )
        return self.conv(upsampled) if self.with_conv else upsampled


class Downsample(nn.Module):
    """Halve the three spatial dims (D, H, W) of a 5-D tensor.

    With ``with_conv``, the input is zero-padded by one voxel at the far end of
    each spatial axis and then passed through a stride-2 3x3x3 convolution
    with no padding. Otherwise a 2x2x2 average pool is used.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
        # Asymmetric padding: (W_left, W_right, H_top, H_bottom, D_front, D_back).
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


class ResnetBlock(nn.Module):
    """Pre-activation 3D residual block with an optional additive embedding.

    The residual branch is: GroupNorm -> swish -> 3x3x3 conv -> (+ projected
    ``temb``) -> GroupNorm -> swish -> dropout -> 3x3x3 conv. When the channel
    count changes, the skip path uses a 3x3x3 conv (``conv_shortcut=True``) or
    a 1x1x1 conv (default).

    ``temb_proj`` exists only when ``temb_channels > 0``. Passing a non-None
    ``temb`` to a block built with ``temb_channels <= 0`` therefore raises
    AttributeError.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Submodule creation order is kept fixed: it determines random init and state_dict keys.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels,
                                                     kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels,
                                                    kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        out = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            # Project the (batch, temb_channels) embedding and broadcast over D, H, W.
            out = out + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
        out = self.conv2(self.dropout(nonlinearity(self.norm2(out))))

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
        return skip + out
    

class ResnetBlock1D(nn.Module):
    """1-D counterpart of ``ResnetBlock`` operating on ``(batch, channels, length)``.

    The residual branch is: GroupNorm -> swish -> k=3 conv -> (+ projected
    ``temb``) -> GroupNorm -> swish -> dropout -> k=3 conv. When the channel
    count changes, the skip path uses a k=3 conv (``conv_shortcut=True``) or a
    k=1 conv (default). ``temb_proj`` exists only when ``temb_channels > 0``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Submodule creation order is kept fixed: it determines random init and state_dict keys.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(in_channels, out_channels,
                                                     kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv1d(in_channels, out_channels,
                                                    kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        out = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            # Broadcast the projected embedding along the length axis.
            out = out + self.temb_proj(nonlinearity(temb))[:, :, None]
        out = self.conv2(self.dropout(nonlinearity(self.norm2(out))))

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
        return skip + out


class LinAttnBlock(LinearAttention):
    """Single-head ``LinearAttention`` whose head width equals ``in_channels``."""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, dim_head=in_channels, heads=1)


class AttnBlock(nn.Module):
    """Single-head self-attention over all voxels of a 5-D feature map, with residual.

    Queries, keys and values come from 1x1x1 convs of the GroupNorm'd input.
    Every voxel attends to every other voxel, so with N = D*H*W the block
    materialises a (batch, N, N) score tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        conv1x1 = partial(torch.nn.Conv3d, in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = conv1x1()
        self.k = conv1x1()
        self.v = conv1x1()
        self.proj_out = conv1x1()

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        batch, channels, depth, height, width = query.shape
        query = query.reshape(batch, channels, -1).permute(0, 2, 1)  # (batch, N, channels)
        key = key.reshape(batch, channels, -1)                       # (batch, channels, N)
        scores = torch.bmm(query, key) * (int(channels) ** -0.5)      # (batch, N, N)
        weights = torch.nn.functional.softmax(scores, dim=2)

        value = value.reshape(batch, channels, -1)
        attended = torch.bmm(value, weights.permute(0, 2, 1))         # (batch, channels, N)
        attended = attended.reshape(batch, channels, depth, height, width)
        return x + self.proj_out(attended)


class AttnBlock1D(nn.Module):
    """Single-head self-attention along the length axis of a ``(batch, channels, length)`` tensor.

    The output projection is zero-initialised, so at construction the block
    returns its input unchanged.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        conv1x1 = partial(torch.nn.Conv1d, in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = conv1x1()
        self.k = conv1x1()
        self.v = conv1x1()
        self.proj_out = conv1x1()

        # Zero output projection => residual branch contributes nothing at init.
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        batch, channels, length = query.shape
        scores = torch.bmm(query.reshape(batch, channels, -1).permute(0, 2, 1),
                           key.reshape(batch, channels, -1))           # (batch, length, length)
        weights = torch.nn.functional.softmax(scores * (int(channels) ** -0.5), dim=2)

        attended = torch.bmm(value.reshape(batch, channels, -1),
                             weights.permute(0, 2, 1))                  # (batch, channels, length)
        return x + self.proj_out(attended.reshape(batch, channels, length))

class AttnBlock1D_Relative(nn.Module):
    """1-D self-attention with a learned scalar bias per relative offset.

    ``rel_pos_bias`` holds one value per offset ``j - i`` in
    ``[-(max_rel_dist - 1), max_rel_dist - 1]``. Larger offsets are clamped to
    the nearest end entry. The output projection is zero-initialised, so at
    construction the block returns its input unchanged.
    """

    def __init__(self, in_channels, max_rel_dist=64):
        super().__init__()
        self.in_channels = in_channels
        self.max_rel_dist = max_rel_dist
        self.rel_pos_bias = nn.Parameter(torch.zeros(2 * max_rel_dist - 1))

        self.norm = Normalize(in_channels)
        self.q = nn.Conv1d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv1d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv1d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv1d(in_channels, in_channels, kernel_size=1)

        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    def _relative_position_index(self, size):
        """Return a ``(size, size)`` LongTensor of indices into ``rel_pos_bias``.

        Entry ``[i, j]`` is ``(j - i) + max_rel_dist - 1``, clamped to
        ``[0, 2 * max_rel_dist - 2]``.
        """
        positions = torch.arange(size, device=self.rel_pos_bias.device)
        offset = self.max_rel_dist - 1
        index = (positions.unsqueeze(0) - positions.unsqueeze(1)) + offset
        return index.clamp(0, 2 * self.max_rel_dist - 2)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, length = query.shape
        scores = torch.bmm(query.permute(0, 2, 1), key) / (channels ** 0.5)  # (batch, length, length)

        # Same (length, length) bias for every batch element.
        bias = self.rel_pos_bias[self._relative_position_index(length)]
        scores = scores + bias.unsqueeze(0)
        weights = torch.nn.functional.softmax(scores, dim=-1)

        attended = torch.bmm(value.reshape(batch, channels, -1), weights.permute(0, 2, 1))
        return x + self.proj_out(attended)

def make_attn(in_channels, attn_type="vanilla"):
    """Factory for the self-attention module used inside encoder/decoder stages.

    Args:
        in_channels: channel count of the feature map the module will see.
        attn_type: one of ``"vanilla"`` (``AttnBlock``), ``"linear"``
            (``LinAttnBlock``), ``"none"`` (``nn.Identity``) or ``"max"``
            (single-head ``MaxAttentionBlock``).

    Raises:
        AssertionError: for an unknown ``attn_type``. Because this is an
            assert, it is stripped under ``python -O``, and unknown types then
            fall through to ``LinAttnBlock``.
    """
    assert attn_type in ["vanilla", "linear", "none", 'max'], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    builders = {
        "vanilla": lambda: AttnBlock(in_channels),
        "none": lambda: nn.Identity(in_channels),
        "max": lambda: MaxAttentionBlock(in_channels, heads=1, dim_head=in_channels),
    }
    return builders.get(attn_type, lambda: LinAttnBlock(in_channels))()


class DiEncoder(nn.Module):
    """3D convolutional encoder mapping a volume to a latent grid.

    Pipeline:

    1. ``conv_in``.
    2. ``len(ch_mult)`` levels, each with ``num_res_blocks`` ``ResnetBlock``s.
       Each block is followed by attention when the level's resolution is in
       ``attn_resolutions``. A ``Downsample`` sits between consecutive levels.
    3. A middle ResnetBlock -> attention -> ResnetBlock stack.
    4. GroupNorm -> swish -> conv, producing ``2 * z_channels`` channels when
       ``double_z`` is set, else ``z_channels``.

    When ``num_classes`` is given, ``y`` is embedded to width ``4 * ch`` and
    fed to every ResnetBlock. Otherwise the blocks get no embedding, and
    passing ``y`` raises AttributeError because ``label_emb`` does not exist.
    ``out_ch`` and extra keyword arguments are accepted but unused.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 num_classes=None, **ignore_kwargs):
        super().__init__()
        # ``use_linear_attn`` overrides whatever ``attn_type`` was requested.
        attn_type = "linear" if use_linear_attn else attn_type
        self.ch = ch
        self.temb_ch = 0 if num_classes is None else ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.num_classes = num_classes

        # NOTE: submodules are created in a fixed order; it determines random
        # initialisation and state_dict key order.
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, self.temb_ch)

        self.conv_in = torch.nn.Conv3d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        self.in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        res = resolution
        for level in range(self.num_resolutions):
            block_in = int(ch * self.in_ch_mult[level])
            block_out = int(ch * ch_mult[level])
            resnets = nn.ModuleList()
            attns = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                resnets.append(ResnetBlock(in_channels=block_in,
                                           out_channels=block_out,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout))
                block_in = block_out
                if res in attn_resolutions:
                    attns.append(make_attn(block_in, attn_type=attn_type))
            stage = nn.Module()
            stage.block = resnets
            stage.attn = attns
            if level != self.num_resolutions - 1:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                res //= 2
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        self.norm_out = Normalize(block_in)
        out_channels = 2 * z_channels if double_z else z_channels
        self.conv_out = torch.nn.Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, y=None, ret_feature=False):
        """Encode ``x``.

        Args:
            x: input volume (5-D, as required by ``Conv3d``).
            y: optional class labels, used only when ``num_classes`` was set.
            ret_feature: if true, also return the per-level features.

        Returns:
            The latent tensor. With ``ret_feature``, returns ``(latent, phi_list)``,
            where ``phi_list`` holds the output of each level (before its
            downsample) followed by the middle-stack output. It therefore has
            ``len(ch_mult) + 1`` entries.
        """
        temb = None if y is None else self.label_emb(y)

        h = self.conv_in(x)
        phi_list = []
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            stage = self.down[level]
            has_attn = len(stage.attn) > 0
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
                if has_attn:
                    h = stage.attn[i_block](h)
            phi_list.append(h)
            if level != last_level:
                h = stage.downsample(h)

        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        phi_list.append(h)

        h = self.conv_out(nonlinearity(self.norm_out(h)))
        return (h, phi_list) if ret_feature else h


class DiDecoderWithResidual(nn.Module):
    """3D convolutional decoder with optional cross-attention to encoder features.

    Pipeline:

    1. ``conv_in`` maps ``z`` to ``ch * ch_mult[-1]`` channels.
    2. A middle ResnetBlock -> attention -> ResnetBlock stack, followed by a
       cross-attention module.
    3. ``len(ch_mult)`` up-levels, visited from coarsest to finest. Each has
       ``num_res_blocks`` ResnetBlocks (plus optional self-attention), then
       cross-attention, then an ``Upsample`` (except the finest level).
    4. A GroupNorm -> swish -> conv head, followed by ``tanh`` when
       ``tanh_out`` is set.

    The cross-attention modules depend on ``cond_type``:

    - ``'cross_attn'``: ``SpatialCrossAttentionWithPosEmb``.
    - ``'max_cross_attn'``: ``SpatialCrossAttentionWithMax``.
    - ``'max_cross_attn_frame'``: ``SpatialCrossAttentionWithMax`` with ``ctx_dim=6``.
    - anything else: a no-op ``IdentityWrapper``.

    Head layout: when ``num_head_channels == -1``, ``num_heads`` heads are
    used. Otherwise each head has ``num_head_channels`` channels.
    ``out_ch`` is the number of output channels. ``in_channels`` is only stored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", num_head_channels=32, num_heads=1, cond_type=None,
                 num_classes=None, **ignorekwargs):
        super().__init__()
        # Honour ``use_linear_attn`` exactly like DiEncoder does (previously ignored here).
        if use_linear_attn:
            attn_type = "linear"

        self.ch = ch
        self.temb_ch = ch * 4 if num_classes is not None else 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.num_classes = num_classes

        # NOTE: submodule creation order is unchanged; it determines random init
        # and state_dict key order.
        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, self.temb_ch)

        block_in = int(ch * ch_mult[self.num_resolutions - 1])
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res, curr_res)

        # z -> block_in
        self.conv_in = torch.nn.Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.cross_attn = self._make_cross_attn(cond_type, block_in, num_head_channels, num_heads)

        # upsampling (built coarse -> fine, stored fine -> coarse via insert(0, ...))
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = int(ch * ch_mult[i_level])
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            cross_attn = self._make_cross_attn(cond_type, block_in, num_head_channels, num_heads)

            up = nn.Module()
            up.block = block
            up.attn = attn
            up.cross_attn = cross_attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)

        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _make_cross_attn(cond_type, channels, num_head_channels, num_heads):
        """Build the cross-attention module for a stage with ``channels`` channels."""
        if num_head_channels == -1:
            heads, dim_head = num_heads, channels // num_heads
        else:
            heads, dim_head = channels // num_head_channels, num_head_channels
        if cond_type == 'cross_attn':
            # Context is the channel-concatenation of four feature maps.
            return SpatialCrossAttentionWithPosEmb(in_channels=channels,
                                                   ctx_channels=channels * 4,
                                                   heads=heads,
                                                   dim_head=dim_head)
        if cond_type == 'max_cross_attn':
            return SpatialCrossAttentionWithMax(in_channels=channels,
                                                heads=heads,
                                                dim_head=dim_head)
        if cond_type == 'max_cross_attn_frame':
            return SpatialCrossAttentionWithMax(in_channels=channels,
                                                heads=heads,
                                                dim_head=dim_head,
                                                ctx_dim=6)
        return IdentityWrapper()

    @staticmethod
    def _gather_context(cond_dict, level):
        """Concatenate the four conditioning features at ``level`` along dim 1.

        Returns None when ``cond_dict`` is None or when the ``'phi_b0_list'``
        entry at ``level`` is None. In that case the other lists are not read.
        """
        if cond_dict is None:
            return None
        base = cond_dict['phi_b0_list'][level]
        if base is None:
            return None
        return torch.cat([base,
                          cond_dict['phi_b1000x_list'][level],
                          cond_dict['phi_b1000y_list'][level],
                          cond_dict['phi_b1000z_list'][level]], dim=1)

    def forward(self, z, y=None, cond_dict=None):
        """Decode ``z``.

        Args:
            z: latent tensor (5-D, as required by ``Conv3d``).
            y: optional class labels, used only when ``num_classes`` was set.
            cond_dict: optional dict with list values under ``'phi_b0_list'``,
                ``'phi_b1000x_list'``, ``'phi_b1000y_list'`` and
                ``'phi_b1000z_list'``.
                - Index ``i`` conditions up-level ``i``; index ``-1`` conditions
                  the middle stack. Presumably these come from a
                  ``DiEncoder(..., ret_feature=True)`` — confirm against callers.
                - A None entry, or ``cond_dict=None``, means "no context" for
                  that stage.

        Returns:
            The decoded volume. With ``give_pre_end``, returns the features
            before the output head instead.
        """
        temb = None if y is None else self.label_emb(y)

        h = self.conv_in(z)

        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        h = self.mid.cross_attn(h, self._gather_context(cond_dict, -1))

        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            h = stage.cross_attn(h, self._gather_context(cond_dict, i_level))
            if i_level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        # ``tanh_out`` was previously stored but never applied.
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class MultiClass_DiEncoder(nn.Module):
    """Bank of independent ``DiEncoder``s, one per class, selected per sample by label."""

    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 num_classes=None, **ignore_kwargs):
        """
        Multi-class version of DiEncoder that creates separate encoders for each class

        Parameters:
            num_classes (int) -- number of classes to create separate encoders for (>= 1)
            (other parameters are the same as DiEncoder)

        Raises:
            ValueError: if ``num_classes`` is None or smaller than 1.
        """
        super().__init__()
        if num_classes is None or num_classes < 1:
            raise ValueError(f"MultiClass_DiEncoder requires num_classes >= 1, got {num_classes}")

        self.num_classes = num_classes
        # Create a separate encoder for each class
        self.encoders = nn.ModuleList([
            DiEncoder(
                ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
                attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
                in_channels=in_channels, resolution=resolution, z_channels=z_channels,
                double_z=double_z, use_linear_attn=use_linear_attn, attn_type=attn_type,
                num_classes=None  # Individual encoders don't need class conditioning
            ) for _ in range(num_classes)
        ])

    @staticmethod
    def _group_by_class(y, batch_size, num_models, device):
        """Group batch positions by class label.

        Returns:
            A list of ``(class_idx, positions)`` pairs sorted by class.
            ``positions`` is a LongTensor on ``device`` listing, in ascending
            order, the batch indices labelled ``class_idx``.

        Raises:
            ValueError: if ``y`` does not hold exactly ``batch_size`` labels.
            IndexError: if a label lies outside ``[0, num_models)``. This
                includes negative labels, which would otherwise silently
                index the ModuleList from the end.
        """
        # One host transfer for all labels instead of one ``.item()`` per sample.
        labels = [int(v) for v in torch.as_tensor(y).reshape(-1).tolist()]
        if len(labels) != batch_size:
            raise ValueError(f"expected {batch_size} class labels, got {len(labels)}")
        groups = {}
        for pos, label in enumerate(labels):
            if not 0 <= label < num_models:
                raise IndexError(f"class label {label} is out of range for {num_models} classes")
            groups.setdefault(label, []).append(pos)
        return [(label, torch.tensor(pos_list, dtype=torch.long, device=device))
                for label, pos_list in sorted(groups.items())]

    def forward(self, x, y=None, ret_feature=False):
        """
        Args:
            x: Input tensor (b,c,d,h,w)
            y: Class labels (b,)
            ret_feature: Whether to return intermediate features

        Returns:
            The latent batch, in the same sample order as ``x``. With
            ``ret_feature``, returns ``(latent, features)``, where each entry
            of ``features`` is one level's features concatenated over the batch.

        Samples are grouped by class, and each class encoder runs once on its
        sub-batch. This is equivalent to per-sample processing because the
        encoder's normalisation (GroupNorm) acts per sample.
        """
        if y is None:
            raise ValueError("Class labels (y) must be provided for MultiClass_DiEncoder")

        groups = self._group_by_class(y, x.size(0), len(self.encoders), x.device)
        # ``order[k]`` is the original batch position of the k-th processed sample.
        order = torch.cat([positions for _, positions in groups])
        restore = torch.argsort(order)

        outputs = []
        features = []
        for class_idx, positions in groups:
            sub_x = x.index_select(0, positions)
            if ret_feature:
                output, phi_list = self.encoders[class_idx](sub_x, ret_feature=True)
                features.append(phi_list)
            else:
                output = self.encoders[class_idx](sub_x)
            outputs.append(output)

        combined_output = torch.cat(outputs, dim=0).index_select(0, restore)
        if not ret_feature:
            return combined_output

        combined_phi_lists = [
            torch.cat([phi_list[level] for phi_list in features], dim=0).index_select(0, restore)
            for level in range(len(features[0]))
        ]
        return combined_output, combined_phi_lists


class MultiClass_DiDecoderWithResidual(nn.Module):
    """Bank of independent ``DiDecoderWithResidual``s, one per class, selected per sample by label."""

    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", num_head_channels=32, num_heads=1, cond_type=None,
                 num_classes=None, **ignorekwargs):
        """
        Multi-class version of DiDecoderWithResidual that creates separate decoders for each class

        Parameters:
            num_classes (int) -- number of classes to create separate decoders for (>= 1)
            (other parameters are the same as DiDecoderWithResidual)

        Raises:
            ValueError: if ``num_classes`` is None or smaller than 1.
        """
        super().__init__()
        if num_classes is None or num_classes < 1:
            raise ValueError(
                f"MultiClass_DiDecoderWithResidual requires num_classes >= 1, got {num_classes}")

        self.num_classes = num_classes
        # Create a separate decoder for each class
        self.decoders = nn.ModuleList([
            DiDecoderWithResidual(
                ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
                attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
                in_channels=in_channels, resolution=resolution, z_channels=z_channels,
                give_pre_end=give_pre_end, tanh_out=tanh_out, use_linear_attn=use_linear_attn,
                attn_type=attn_type, num_head_channels=num_head_channels, num_heads=num_heads,
                cond_type=cond_type, num_classes=None  # Individual decoders don't need class conditioning
            ) for _ in range(num_classes)
        ])

    @staticmethod
    def _group_by_class(y, batch_size, num_models, device):
        """Group batch positions by class label.

        Returns:
            A list of ``(class_idx, positions)`` pairs sorted by class.
            ``positions`` is a LongTensor on ``device`` listing, in ascending
            order, the batch indices labelled ``class_idx``.

        Raises:
            ValueError: if ``y`` does not hold exactly ``batch_size`` labels.
            IndexError: if a label lies outside ``[0, num_models)``. This
                includes negative labels, which would otherwise silently
                index the ModuleList from the end.
        """
        # One host transfer for all labels instead of one ``.item()`` per sample.
        labels = [int(v) for v in torch.as_tensor(y).reshape(-1).tolist()]
        if len(labels) != batch_size:
            raise ValueError(f"expected {batch_size} class labels, got {len(labels)}")
        groups = {}
        for pos, label in enumerate(labels):
            if not 0 <= label < num_models:
                raise IndexError(f"class label {label} is out of range for {num_models} classes")
            groups.setdefault(label, []).append(pos)
        return [(label, torch.tensor(pos_list, dtype=torch.long, device=device))
                for label, pos_list in sorted(groups.items())]

    @staticmethod
    def _select_cond(cond_dict, positions):
        """Slice every non-None per-level feature in ``cond_dict`` to the given batch positions.

        None entries are kept as None, and a None ``cond_dict`` is returned as None.
        """
        if cond_dict is None:
            return None
        return {
            key: [None if feats is None else feats.index_select(0, positions.to(feats.device))
                  for feats in level_list]
            for key, level_list in cond_dict.items()
        }

    def forward(self, z, y=None, cond_dict=None):
        """
        Args:
            z: Latent tensor (b,c,d,h,w)
            y: Class labels (b,)
            cond_dict: Dictionary mapping keys to lists of per-level conditional
                features (batch-first tensors or None)

        Returns:
            The decoded batch, in the same sample order as ``z``.

        Samples are grouped by class, and each class decoder runs once on its
        sub-batch together with the matching slice of every conditioning feature.
        """
        if y is None:
            raise ValueError("Class labels (y) must be provided for MultiClass_DiDecoderWithResidual")

        groups = self._group_by_class(y, z.size(0), len(self.decoders), z.device)
        # ``order[k]`` is the original batch position of the k-th processed sample.
        order = torch.cat([positions for _, positions in groups])

        outputs = [
            self.decoders[class_idx](z.index_select(0, positions),
                                     cond_dict=self._select_cond(cond_dict, positions))
            for class_idx, positions in groups
        ]
        # Undo the grouping so outputs line up with the input batch.
        return torch.cat(outputs, dim=0).index_select(0, torch.argsort(order))



class MultiClass_DiDecoderWithResidual_Joint(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", num_head_channels=32, num_heads=1, cond_type=None,
                 num_classes=None, num_directions=6, **ignorekwargs):
        """
        Joint iso/aniso decoder built from several DiDecoderWithResidual instances.

        ``num_classes`` decoders are created. ``forward_joint`` uses
        ``decoders[0]`` for the first half of the directions ("iso") and
        ``decoders[1]`` for the second half ("aniso"). After the up-sampling
        path, the features of all directions at each voxel are mixed by the
        residual MLP ``self.joint``, and then each decoder's output head runs.

        Parameters:
            num_classes (int) -- number of decoders to create; must be >= 2
            num_directions (int) -- directions per sample; must be a positive even
                number (first half decoded as iso, second half as aniso)
            tanh_out (bool) -- if True, apply tanh to both outputs
            (other parameters are the same as DiDecoderWithResidual)

        Raises:
            ValueError -- if num_classes < 2 / None or num_directions is not a
                positive even number
        """
        super().__init__()

        # forward_joint unconditionally uses decoders[0] and decoders[1].
        if num_classes is None or num_classes < 2:
            raise ValueError(
                f"MultiClass_DiDecoderWithResidual_Joint requires num_classes >= 2, got {num_classes}")
        # Directions are split into two equal halves; an odd count would make the
        # '(b a)' rearranges (which use a = num_directions // 2) mis-group the batch.
        if num_directions <= 0 or num_directions % 2 != 0:
            raise ValueError(
                f"num_directions must be a positive even number, got {num_directions}")

        self.num_classes = num_classes
        self.decoders = nn.ModuleList([
            DiDecoderWithResidual(
                ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
                attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
                in_channels=in_channels, resolution=resolution, z_channels=z_channels,
                give_pre_end=give_pre_end, tanh_out=tanh_out, use_linear_attn=use_linear_attn,
                attn_type=attn_type, num_head_channels=num_head_channels, num_heads=num_heads,
                cond_type=cond_type, num_classes=None  # Individual decoders don't need class conditioning
            ) for _ in range(num_classes)
        ])

        self.ch = ch
        # Kept for attribute compatibility; no timestep embedding is built here
        # (forward() never passes temb to forward_joint).
        self.temb_ch = ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.tanh_out = tanh_out
        self.num_directions = num_directions

        # Per-direction channel width at level 0 (the last up-sampling level).
        # Presumably this is the decoders' final feature width; it must equal the
        # channel dim of h_iso/h_aniso when they reach self.joint.
        block_in = int(ch * ch_mult[0])
        joint_width = block_in * self.num_directions

        # Residual bottleneck MLP over the concatenated features of all directions.
        self.joint = nn.Sequential(
            nn.Linear(in_features=joint_width, out_features=joint_width // 2),
            nn.GELU(),
            nn.Linear(in_features=joint_width // 2, out_features=joint_width)
        )

        # The second layer is zero-initialised, so self.joint starts as a zero map
        # and forward_angle is initially an identity (pure residual).
        with torch.no_grad():
            init.normal_(self.joint[0].weight, std=1e-3)
            init.zeros_(self.joint[0].bias)
            init.zeros_(self.joint[2].weight)
            init.zeros_(self.joint[2].bias)

    def forward_angle(self, h_iso, h_aniso, module):
        """
        Mix iso and aniso features across all directions at every voxel.

        Both inputs are ((b a), c, d, h, w) with a = num_directions // 2. They are
        flattened to one row per (sample, voxel) holding all directions'
        channels. `module` is applied to the concatenation, and its output is
        added back to each half as a residual.

        Returns:
            (h_iso, h_aniso) in the original ((b a), c, d, h, w) layout.
        """
        half = self.num_directions // 2
        d, h, w = [int(s) for s in h_iso.size()[-3:]]
        h_iso = rearrange(h_iso, '(b a) c d h w -> (b d h w) (a c)', a=half)
        h_aniso = rearrange(h_aniso, '(b a) c d h w -> (b d h w) (a c)', a=half)
        h_total = module(torch.cat([h_iso, h_aniso], dim=-1))
        split_size = h_iso.shape[-1]
        h_iso = h_total[:, :split_size] + h_iso
        h_aniso = h_total[:, split_size:] + h_aniso
        h_iso = rearrange(h_iso, '(b d h w) (a c) -> (b a) c d h w', d=d, h=h, w=w, a=half)
        h_aniso = rearrange(h_aniso, '(b d h w) (a c) -> (b a) c d h w', d=d, h=h, w=w, a=half)
        return h_iso, h_aniso

    def center_crop_to_match(self, src, target):
        """
        Center crop src to the spatial size (last three dims) of target.

        Assumes tensors are in (B, C, D, H, W) format. Returns src unchanged
        when the spatial sizes already match.

        Raises:
            ValueError -- if src is smaller than target in any spatial dim
                (cropping cannot enlarge a tensor; negative slice starts would
                otherwise silently yield a wrongly sized result).
        """
        if src.shape[-3:] == target.shape[-3:]:
            return src
        d_t, h_t, w_t = target.shape[-3:]
        d_s, h_s, w_s = src.shape[-3:]
        if d_s < d_t or h_s < h_t or w_s < w_t:
            raise ValueError(
                f"Cannot center crop spatial size {(d_s, h_s, w_s)} to larger size {(d_t, h_t, w_t)}")
        start_d = (d_s - d_t) // 2
        start_h = (h_s - h_t) // 2
        start_w = (w_s - w_t) // 2
        return src[..., start_d:start_d + d_t, start_h:start_h + h_t, start_w:start_w + w_t]

    def forward_joint(self, z_iso, z_aniso, cond_dict=None, temb=None):
        """
        Run the iso and aniso decoders in lockstep and couple them before the output heads.

        Args:
            z_iso: latents for the iso directions, ((b a), c, d, h, w) with
                a = num_directions // 2
            z_aniso: latents for the aniso directions, same layout as z_iso
            cond_dict: dict with 'phi_b0_list', 'phi_b1000x_list',
                'phi_b1000y_list', 'phi_b1000z_list', each indexed by resolution
                level (entries may be None). If cond_dict is None, every
                cross-attention receives a None context.
            temb: optional embedding forwarded to the res blocks

        Returns:
            (out_iso, out_aniso)
        """
        decoder_iso = self.decoders[0]
        decoder_aniso = self.decoders[1]

        if cond_dict is None:
            phi_lists = None
        else:
            phi_lists = [cond_dict[key] for key in
                         ('phi_b0_list', 'phi_b1000x_list', 'phi_b1000y_list', 'phi_b1000z_list')]

        h_iso = decoder_iso.conv_in(z_iso)
        h_aniso = decoder_aniso.conv_in(z_aniso)

        h_iso = decoder_iso.mid.block_1(h_iso, temb)
        h_aniso = decoder_aniso.mid.block_1(h_aniso, temb)
        h_iso = decoder_iso.mid.attn_1(h_iso)
        h_aniso = decoder_aniso.mid.attn_1(h_aniso)
        h_iso = decoder_iso.mid.block_2(h_iso, temb)
        h_aniso = decoder_aniso.mid.block_2(h_aniso, temb)

        for i_level in reversed(range(self.num_resolutions)):
            up_iso = decoder_iso.up[i_level]
            up_aniso = decoder_aniso.up[i_level]
            for i_block in range(decoder_iso.num_res_blocks):
                h_iso = up_iso.block[i_block](h_iso, temb)
                h_aniso = up_aniso.block[i_block](h_aniso, temb)
                if len(up_iso.attn) > 0:
                    h_iso = up_iso.attn[i_block](h_iso)
                if len(up_aniso.attn) > 0:
                    h_aniso = up_aniso.attn[i_block](h_aniso)

            # Shared context for both branches: the four phi features of this
            # level concatenated along dim 1, cropped to h_iso's spatial size.
            ctx = None
            if phi_lists is not None and phi_lists[0][i_level] is not None:
                ctx = torch.cat([phi[i_level] for phi in phi_lists], dim=1)
                ctx = self.center_crop_to_match(ctx, h_iso)
            h_iso = up_iso.cross_attn(h_iso, ctx)
            h_aniso = up_aniso.cross_attn(h_aniso, ctx)

            if i_level != 0:
                h_iso = up_iso.upsample(h_iso)
                h_aniso = up_aniso.upsample(h_aniso)

        h_iso, h_aniso = self.forward_angle(h_iso, h_aniso, self.joint)

        out_iso = decoder_iso.conv_out(nonlinearity(decoder_iso.norm_out(h_iso)))
        out_aniso = decoder_aniso.conv_out(nonlinearity(decoder_aniso.norm_out(h_aniso)))

        # Honor the tanh_out constructor flag (previously stored but never applied).
        if self.tanh_out:
            out_iso = torch.tanh(out_iso)
            out_aniso = torch.tanh(out_aniso)

        return out_iso, out_aniso

    def forward(self, z, y=None, cond_dict=None):
        """
        Args:
            z: Latent tensor ((b a), c, d, h, w) with a == num_directions; the
               first num_directions // 2 directions of each sample are decoded
               as iso, the rest as aniso
            y: Class labels (b,). Must be provided, but are not used to select
               decoders (iso always uses decoders[0], aniso decoders[1])
            cond_dict: Dictionary containing conditional features (see forward_joint)

        Returns:
            Tensor of shape (b, num_directions, c_out, D, H, W).

        Raises:
            ValueError -- if y is None
        """
        if y is None:
            raise ValueError("Class labels (y) must be provided for MultiClass_DiDecoderWithResidual_Joint")

        half = self.num_directions // 2
        z = rearrange(z, '(b a) c d h w -> b a c d h w', a=self.num_directions)
        z_iso = rearrange(z[:, :half], 'b a c d h w -> (b a) c d h w')
        z_aniso = rearrange(z[:, half:], 'b a c d h w -> (b a) c d h w')

        out_iso, out_aniso = self.forward_joint(z_iso, z_aniso, cond_dict)

        out_iso = rearrange(out_iso, '(b a) c d h w -> b a c d h w', a=half)
        out_aniso = rearrange(out_aniso, '(b a) c d h w -> b a c d h w', a=half)

        return torch.cat([out_iso, out_aniso], dim=1)