# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from torch import Tensor

from ldm.util import instantiate_from_config
from ldm.modules.attention import LinearAttention


# Wrapper classes used to opt layers out of orthogonal regularization. They add
# no behaviour of their own; the regularizer (not in this file) presumably
# recognises them through the ``noOrthoRegularization`` marker base class.
class noOrthoRegularization:
    """Marker base: layers that inherit from it are meant to skip orthogonal regularization."""


class noOrthoRegularizationConv1d(nn.Conv1d, noOrthoRegularization):
    """``nn.Conv1d`` tagged with the ``noOrthoRegularization`` marker."""


class noOrthoRegularizationConv2d(nn.Conv2d, noOrthoRegularization):
    """``nn.Conv2d`` tagged with the ``noOrthoRegularization`` marker."""


class noOrthoRegularizationLinear(nn.Linear, noOrthoRegularization):
    """``nn.Linear`` tagged with the ``noOrthoRegularization`` marker."""


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models
    (itself taken from Fairseq / tensor2tensor), which differs slightly from the
    description in Section 3.5 of "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor of timesteps, shape ``(N,)``.
        embedding_dim: number of features per timestep.

    Returns:
        Float32 tensor of shape ``(N, embedding_dim)`` on ``timesteps.device``.
        The first ``embedding_dim // 2`` columns hold ``sin(t * f)`` and the next
        ``embedding_dim // 2`` hold ``cos(t * f)``. The frequencies ``f`` are
        geometrically spaced from 1 down to 1/10000. When ``embedding_dim`` is
        odd, one zero column is appended.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # The divisor was half_dim - 1, which is 0 for embedding_dim 2 or 3 and
    # raised ZeroDivisionError. With a single frequency its value is exp(0) == 1
    # whatever the scale, so clamping the divisor to 1 is exact.
    emb = math.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the final column
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    """Swish activation, ``x * sigmoid(x)``, applied elementwise."""
    gate = torch.sigmoid(x)
    return x * gate


def Normalize(in_channels, num_groups=32):
    """Return an affine ``GroupNorm`` over ``in_channels`` channels in ``num_groups`` groups, eps 1e-6."""
    return torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)


class MuxConvBlock(nn.Module):
    """Optional GroupNorm -> 3x3 conv (stride 1, padding 1) -> optional swish.

    The convolution is a ``noOrthoRegularizationConv2d``. Normalisation is
    skipped when ``num_groups <= 0``, and the activation is skipped when the
    ``nonlinearity`` flag is falsy.
    """

    def __init__(self, *, in_channels, out_channels=None, num_groups=32, nonlinearity=True):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.num_groups = num_groups
        self.in_channels = in_channels
        self.nonlin = nonlinearity
        self.out_channels = out_channels
        if num_groups > 0:
            self.norm_out = Normalize(in_channels, num_groups=num_groups)
        self.conv = noOrthoRegularizationConv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = self.norm_out(x) if self.num_groups > 0 else x
        h = self.conv(h)
        # `nonlinearity` here is the module-level swish; the constructor
        # argument of the same name only sets self.nonlin.
        return nonlinearity(h) if self.nonlin else h

class MuxLinearBlock(nn.Module):
    """Optional GroupNorm -> linear layer -> swish.

    The linear layer is a ``noOrthoRegularizationLinear``. Normalisation is
    skipped when ``num_groups <= 0``.

    NOTE(review): GroupNorm normalises dim 1, while the linear layer acts on
    the last dim. The two line up on ``in_channels`` only for 2-D input
    ``(N, in_channels)``; confirm that callers pass 2-D tensors.
    """

    def __init__(self, *, in_channels, out_channels=None, num_groups=32):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.num_groups = num_groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        if num_groups > 0:
            self.norm_out = Normalize(in_channels, num_groups=num_groups)
        self.lin = noOrthoRegularizationLinear(in_channels, out_channels)

    def forward(self, x):
        h = self.norm_out(x) if self.num_groups > 0 else x
        return nonlinearity(self.lin(h))

class ResMuxBlock(nn.Module):
    """Pre-activation residual block whose convolutions are ``noOrthoRegularizationConv2d``.

    Main path: GroupNorm -> swish -> 3x3 conv -> (+ projected timestep
    embedding) -> GroupNorm -> swish -> dropout -> 3x3 conv. When the channel
    count changes, the skip path is a 3x3 conv (``conv_shortcut=True``) or a
    1x1 conv. The timestep projection ``temb_proj`` is a plain
    ``torch.nn.Linear`` and exists only when ``temb_channels > 0``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, num_groups=32):
        super().__init__()
        out_channels = out_channels if out_channels is not None else in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        conv = noOrthoRegularizationConv2d
        # Submodules are registered in the original order. That order fixes
        # how default initialisation consumes the RNG and the state_dict layout.
        self.norm1 = Normalize(in_channels, num_groups=num_groups)
        self.conv1 = conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, num_groups=num_groups)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = conv(in_channels, out_channels,
                                          kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = conv(in_channels, out_channels,
                                         kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb=None):
        h = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            # add the projected embedding at every spatial position of the 4-D conv output
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
        return skip + h


class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 same-size conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                        kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(up) if self.with_conv else up


class Downsample(nn.Module):
    """Halve spatial size: a stride-2 3x3 conv (``with_conv``) or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # padding=0 because torch convs only pad symmetrically; forward()
            # adds a single row/column on the right/bottom instead.
            self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                        kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # pad spec (left, right, top, bottom) = (0, 1, 0, 1)
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional timestep conditioning.

    Main path: GroupNorm -> swish -> 3x3 conv -> (+ projected timestep
    embedding) -> GroupNorm -> swish -> dropout -> 3x3 conv. When the channel
    count changes, the skip path is a 3x3 conv (``conv_shortcut=True``) or a
    1x1 conv. ``temb_proj`` exists only when ``temb_channels > 0``, so a
    non-None ``temb`` must only be passed in that case.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, num_groups=32):
        super().__init__()
        out_channels = out_channels if out_channels is not None else in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        conv = torch.nn.Conv2d
        # Submodules are registered in the original order. That order fixes
        # how default initialisation consumes the RNG and the state_dict layout.
        self.norm1 = Normalize(in_channels, num_groups=num_groups)
        self.conv1 = conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, num_groups=num_groups)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = conv(in_channels, out_channels,
                                          kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = conv(in_channels, out_channels,
                                         kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb=None):
        h = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            # add the projected embedding at every spatial position of the 4-D conv output
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
        return skip + h


class LinAttnBlock(LinearAttention):
    """``LinearAttention`` configured to be a drop-in for ``AttnBlock``.

    It takes only ``in_channels`` and uses one head whose width equals ``in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__(heads=1, dim_head=in_channels, dim=in_channels)


class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    The input is GroupNorm-ed and projected to query/key/value with 1x1 convs.
    Every spatial position then attends to every other one (scaled dot
    product, softmax over keys). The result passes through a 1x1 output
    projection and is added back to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def conv1x1():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = conv1x1()
        self.k = conv1x1()
        self.v = conv1x1()
        self.proj_out = conv1x1()

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        batch, channels, height, width = query.shape
        tokens = height * width
        query = query.reshape(batch, channels, tokens).permute(0, 2, 1)  # (b, hw_q, c)
        key = key.reshape(batch, channels, tokens)                        # (b, c, hw_k)
        scores = torch.bmm(query, key) * (int(channels) ** (-0.5))       # (b, hw_q, hw_k)
        weights = torch.nn.functional.softmax(scores, dim=2)              # normalise over keys

        value = value.reshape(batch, channels, tokens)                    # (b, c, hw_k)
        attended = torch.bmm(value, weights.permute(0, 2, 1))             # (b, c, hw_q)
        attended = attended.reshape(batch, channels, height, width)
        return x + self.proj_out(attended)


def make_attn(in_channels, attn_type="vanilla"):
    """Build the attention layer named by ``attn_type``.

    ``"vanilla"`` -> ``AttnBlock``, ``"linear"`` -> ``LinAttnBlock``,
    ``"none"`` -> ``nn.Identity``. Any other value fails the assertion.
    A description of the layer is printed before it is built.
    """
    supported = ("vanilla", "linear", "none")
    assert attn_type in supported, f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    return LinAttnBlock(in_channels)


class Model(nn.Module):
    """Timestep-conditioned U-Net: down path, middle block, up path with skip concatenation.

    Down path: ``conv_in``, then per level ``num_res_blocks`` ResnetBlocks
    (each optionally followed by attention) and a Downsample after every level
    except the last. Every intermediate activation is pushed onto a stack. The
    up path runs ``num_res_blocks + 1`` ResnetBlocks per level. Each block
    pops one stored activation and concatenates it on the channel axis with
    its input. A middle ResnetBlock / attention / ResnetBlock sits between the
    two paths.

    Args:
        ch: base channel count; level ``i`` works at ``ch * ch_mult[i]`` channels.
        out_ch: channels produced by ``conv_out``.
        ch_mult: per-level channel multipliers; its length is the number of levels.
        num_res_blocks: ResnetBlocks per level on the down path (``+ 1`` on the up path).
        attn_resolutions: spatial sizes (tracked from ``resolution``) at which an
            attention layer follows each ResnetBlock.
        dropout: dropout rate inside the ResnetBlocks.
        resamp_with_conv: passed to Downsample/Upsample as ``with_conv``.
        in_channels: channels expected by ``conv_in``. When ``context`` is passed
            to ``forward``, this presumably must count input + context channels;
            confirm with the caller.
        resolution: input spatial size; only used here to place attention layers.
        use_timestep: build (and require in ``forward``) the timestep-embedding MLP.
        use_linear_attn: if true, overrides ``attn_type`` with ``"linear"``.
        attn_type: ``"vanilla"``, ``"linear"`` or ``"none"`` (see ``make_attn``).
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch*4  # width of the embedding passed to every ResnetBlock
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # NOTE: submodule construction order below determines the RNG draws
        # for default initialisation and the state_dict layout.
        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: sinusoidal (ch) -> Linear -> swish -> Linear (temb_ch)
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # level i's input width is ch*in_ch_mult[i], i.e. the previous level's output width
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # channels of the skip activation popped for each block: the matching
            # down-path block output, except for the level's last block, which
            # receives the tensor that entered this level on the down path
            # (conv_in output at level 0, otherwise the previous Downsample output)
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None, context=None):
        """Run the network.

        Args:
            x: input batch (4-D, as consumed by ``conv_in``).
            t: 1-D tensor of timesteps; required when ``use_timestep``.
            context: optional tensor concatenated to ``x`` along dim 1.

        Returns:
            Output of ``conv_out`` with ``out_ch`` channels.
        """
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling; hs is the skip stack consumed (LIFO) by the up path
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Return the weight tensor of the final convolution, ``conv_out``."""
        return self.conv_out.weight


class Encoder(nn.Module):
    """Unconditional downsampling encoder.

    Structure: ``conv_in``, then per level ``num_res_blocks`` ResnetBlocks (each
    optionally followed by attention) with a Downsample after every level
    except the last, then a middle ResnetBlock / attention / ResnetBlock, and
    finally GroupNorm -> swish -> ``conv_out``.

    Args:
        ch: base channel count; level ``i`` works at ``ch * ch_mult[i]`` channels.
        out_ch: accepted for config compatibility; not used by this class.
        ch_mult: per-level channel multipliers; its length is the number of levels.
        num_res_blocks: ResnetBlocks per level.
        attn_resolutions: spatial sizes (tracked from ``resolution``) at which an
            attention layer follows each ResnetBlock.
        dropout: dropout rate inside the ResnetBlocks.
        resamp_with_conv: passed to Downsample as ``with_conv``.
        in_channels: channels of the input image.
        resolution: input spatial size; only used to place attention layers.
        z_channels: latent channels; ``conv_out`` emits ``2 * z_channels`` when
            ``double_z`` is true, otherwise ``z_channels``.
        double_z: see ``z_channels``.
        use_linear_attn: if true, overrides ``attn_type`` with ``"linear"``.
        attn_type: ``"vanilla"``, ``"linear"`` or ``"none"`` (see ``make_attn``).
        **ignore_kwargs: extra config keys, ignored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # NOTE: submodule construction order below determines the RNG draws
        # for default initialisation and the state_dict layout.

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # level i's input width is ch*in_ch_mult[i], i.e. the previous level's output width
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode ``x`` (4-D, ``in_channels`` channels) into the ``conv_out`` latent map."""
        # timestep embedding (none: the encoder is unconditional)
        temb = None

        # downsampling. The encoder has no skip connections, so only the latest
        # activation is needed. Keeping a single running tensor (instead of a
        # list of every intermediate) lets earlier activations be freed, which
        # lowers peak memory under torch.no_grad().
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions-1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """Unconditional upsampling decoder from a latent ``z``.

    Structure: ``conv_in`` on ``z``, a middle ResnetBlock / attention /
    ResnetBlock, then per level (coarsest first) ``num_res_blocks + 1``
    ResnetBlocks (each optionally followed by attention) with an Upsample
    between levels, and finally GroupNorm -> swish -> ``conv_out``, optionally
    followed by ``tanh``.

    Args (beyond those shared with ``Encoder``):
        z_channels: channels of the latent ``z`` fed to ``conv_in``.
        give_pre_end: if true, ``forward`` returns the features before
            ``norm_out``/``conv_out``.
        tanh_out: apply ``torch.tanh`` to the output.

    ``in_channels`` is stored but not used to build any layer, and extra
    keyword arguments are ignored. There is no timestep conditioning
    (``temb_ch = 0``). Construction prints the expected ``z`` shape.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        # (curr_res: spatial size of z, assuming a square `resolution`)
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # NOTE: submodule construction order below determines the RNG draws
        # for default initialisation and the state_dict layout.

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order (self.up[i] is level i)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        """Decode latent ``z``; also records its shape in ``self.last_z_shape``."""
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding (none: the decoder is unconditional)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h  # features before norm_out / conv_out

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class SimpleDecoder(nn.Module):
    """Small decoder that doubles spatial size.

    Layers: 1x1 conv, ResnetBlocks C -> 2C -> 4C -> 2C, 1x1 conv back to C,
    2x Upsample with conv, then GroupNorm -> swish -> 3x3 conv to
    ``out_channels``. Extra positional and keyword arguments are ignored.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        c = in_channels
        # Built one at a time in the original list order, so weight init
        # consumes the RNG identically.
        stem = nn.Conv2d(c, c, 1)
        widen = ResnetBlock(in_channels=c, out_channels=2 * c, temb_channels=0, dropout=0.0)
        widen_more = ResnetBlock(in_channels=2 * c, out_channels=4 * c, temb_channels=0, dropout=0.0)
        narrow = ResnetBlock(in_channels=4 * c, out_channels=2 * c, temb_channels=0, dropout=0.0)
        squeeze = nn.Conv2d(2 * c, c, 1)
        upsample = Upsample(c, with_conv=True)
        # Indices 1..3 are the ResnetBlocks; forward() passes them temb=None.
        self.model = nn.ModuleList([stem, widen, widen_more, narrow, squeeze, upsample])
        # end
        self.norm_out = Normalize(c)
        self.conv_out = torch.nn.Conv2d(c, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for idx, layer in enumerate(self.model):
            x = layer(x, None) if idx in (1, 2, 3) else layer(x)
        h = nonlinearity(self.norm_out(x))
        return self.conv_out(h)


class UpsampleDecoder(nn.Module):
    """Residual stages with 2x upsampling between them, then GroupNorm -> swish -> 3x3 conv.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by ``conv_out``.
        ch: base channel count; stage ``i`` outputs ``int(ch * ch_mult[i])`` channels.
        num_res_blocks: each stage holds ``num_res_blocks + 1`` ResnetBlocks.
        resolution: accepted for interface compatibility; not used to build layers.
        ch_mult: per-stage channel multipliers; ``len(ch_mult) - 1`` Upsample
            layers (with conv) are placed between stages.
        dropout: dropout rate inside the ResnetBlocks.
        norm_num_groups: GroupNorm group count for every normalisation layer.
    """

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0, norm_num_groups=32):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        # (the original tracked an unused `curr_res` here; it was dead code)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level, mult in enumerate(ch_mult):
            block_out = int(ch * mult)
            stage = []
            for _ in range(self.num_res_blocks + 1):
                stage.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout,
                                         num_groups=norm_num_groups))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(stage))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))

        # end
        self.norm_out = Normalize(block_in, num_groups=norm_num_groups)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Run every stage (upsampling between stages), then the output head."""
        h = x
        last_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.res_blocks):
            for block in stage:
                h = block(h, None)
            if i_level != last_level:
                # upsample_blocks has one entry per stage except the last
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    """Resize a latent map by ``factor`` between two stacks of residual blocks.

    Pipeline: 3x3 ``conv_in`` -> ``depth`` ResnetBlocks -> resize spatial dims
    by ``factor`` (``interpolate`` with its default mode) -> AttnBlock ->
    ``depth`` ResnetBlocks -> 1x1 ``conv_out``.
    """

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        self.factor = factor

        def residual_stack():
            return nn.ModuleList([
                ResnetBlock(in_channels=mid_channels, out_channels=mid_channels,
                            temb_channels=0, dropout=0.0)
                for _ in range(depth)
            ])

        # registration order kept: conv_in, res_block1, attn, res_block2, conv_out
        self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.res_block1 = residual_stack()
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = residual_stack()
        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        h = self.conv_in(x)
        for block in self.res_block1:
            h = block(h, None)
        target_hw = tuple(int(round(size * self.factor)) for size in h.shape[2:4])
        h = torch.nn.functional.interpolate(h, size=target_hw)
        h = self.attn(h)
        for block in self.res_block2:
            h = block(h, None)
        return self.conv_out(h)


class MergedRescaleEncoder(nn.Module):
    """Encoder followed by a LatentRescaler.

    The Encoder emits a single (non-doubled) latent with ``ch * ch_mult[-1]``
    channels. The LatentRescaler then resizes it by ``rescale_factor`` and
    projects it to ``out_ch`` channels.
    """

    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        latent_ch = ch * ch_mult[-1]
        encoder_kwargs = dict(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=latent_ch,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.encoder = Encoder(**encoder_kwargs)
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=latent_ch,
            mid_channels=latent_ch,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Encode ``x`` and rescale the resulting latent."""
        return self.rescaler(self.encoder(x))


class MergedRescaleDecoder(nn.Module):
    """LatentRescaler followed by a Decoder.

    The LatentRescaler maps a ``z_channels`` latent to ``z_channels * ch_mult[-1]``
    channels and resizes it by ``rescale_factor``. The Decoder then decodes that
    widened latent to ``out_ch`` channels.
    """

    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        wide_ch = z_channels * ch_mult[-1]
        # Decoder is built before the rescaler, matching the original creation order.
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=wide_ch,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=wide_ch,
            out_channels=wide_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Rescale the latent ``x`` and decode it."""
        return self.decoder(self.rescaler(x))


class Upsampler(nn.Module):
    """Upsample a latent from ``in_size`` to ``out_size`` spatial resolution.

    A LatentRescaler first resizes by the residual factor that power-of-two
    upsampling cannot cover. A Decoder with ``num_blocks`` levels then upsamples
    and maps the result to ``out_channels``.
    """

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks, factor_up = self._upsampling_plan(in_size, out_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
                                       out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
                               attn_resolutions=[], in_channels=None, ch=in_channels,
                               ch_mult=[ch_mult for _ in range(num_blocks)])

    @staticmethod
    def _upsampling_plan(in_size, out_size):
        """Return ``(num_blocks, factor_up)`` for resizing ``in_size`` -> ``out_size``.

        ``num_blocks`` is ``floor(log2(out_size // in_size)) + 1`` Decoder levels.
        NOTE(review): this assumes the Decoder upsamples by ``2 ** (num_blocks - 1)``,
        as the ldm Decoder does; confirm against the Decoder in this file.

        ``factor_up`` is the residual scale the LatentRescaler must apply so that
        ``in_size * factor_up * 2 ** (num_blocks - 1) == out_size``. The previous
        formula ``1 + out_size % in_size`` only gave the right answer (1.0) for exact
        power-of-two ratios; e.g. 16 -> 40 produced 9.0 instead of 1.25.
        """
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = out_size / (in_size * 2 ** (num_blocks - 1))
        return num_blocks, factor_up

    def forward(self, x):
        """Rescale by the residual factor, then decode/upsample to ``out_size``."""
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):
    """Fixed (non-learned) spatial resize built on ``torch.nn.functional.interpolate``.

    Args:
        in_channels: only meaningful for a learned variant, which is not implemented.
        learned: must be False; True raises NotImplementedError.
        mode: interpolation mode forwarded to ``interpolate``.
    """

    # Modes for which ``interpolate`` accepts an ``align_corners`` argument; passing it
    # (even as False) with e.g. "nearest" or "area" makes interpolate raise ValueError.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # `__name__` (not the name-mangled `__name`, which raised AttributeError).
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            # A learned variant (4x4 stride-2 conv over in_channels) was sketched here but
            # never implemented; fail loudly rather than silently ignoring `learned`.
            raise NotImplementedError()

    def forward(self, x, scale_factor=1.0):
        """Return ``x`` resized by ``scale_factor`` (``x`` itself when the factor is 1.0)."""
        if scale_factor == 1.0:
            return x
        extra = {"align_corners": False} if self.mode in self._ALIGN_CORNERS_MODES else {}
        return torch.nn.functional.interpolate(x, mode=self.mode, scale_factor=scale_factor, **extra)

class FirstStagePostProcessor(nn.Module):
    """Encode with a frozen first-stage model, then project and downsample the code.

    The first-stage model is either given directly (``pretrained_model``) or built
    from ``pretrained_config``; the config takes precedence when both are given.
    Its encoding is processed as follows:
      1. GroupNorm with ``in_channels // 2`` groups.
      2. A 3x3 conv to ``n_channels`` (defaults to ``pretrained_model.encoder.ch``).
      3. Swish.
      4. For each ``m`` in ``ch_mult``, a ResnetBlock to ``m * n_channels`` followed
         by a non-conv Downsample.
    With ``reshape=True`` the result is flattened to ``(b, h*w, c)``.
    """

    def __init__(self, ch_mult:list, in_channels,
                 pretrained_model:nn.Module=None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is not None:
            self.instantiate_pretrained(pretrained_config)
        else:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model

        self.do_reshape = reshape

        width = self.pretrained_model.encoder.ch if n_channels is None else n_channels

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1)

        res_blocks = []
        down_blocks = []
        block_in = width
        for mult in ch_mult:
            block_out = mult * width
            res_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
            down_blocks.append(Downsample(block_out, with_conv=False))
            block_in = block_out

        self.model = nn.ModuleList(res_blocks)
        self.downsampler = nn.ModuleList(down_blocks)

    def instantiate_pretrained(self, config):
        """Build the first-stage model from ``config``, put it in eval mode and freeze its weights.

        NOTE(review): a later ``.train()`` call on this module will also switch the
        frozen model back to train mode (standard nn.Module recursion); the original
        carried a commented-out attempt to prevent that.
        """
        frozen = instantiate_from_config(config).eval()
        frozen.requires_grad_(False)
        self.pretrained_model = frozen

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        """Encode ``x`` without gradients; a DiagonalGaussianDistribution is reduced to its mode."""
        encoded = self.pretrained_model.encode(x)
        return encoded.mode() if isinstance(encoded, DiagonalGaussianDistribution) else encoded

    def forward(self, x):
        """Encode, project, run the ResnetBlock/Downsample stages, and optionally flatten."""
        z = self.encode_with_pretrained(x)
        z = nonlinearity(self.proj(self.proj_norm(z)))

        for res_block, down in zip(self.model, self.downsampler):
            z = down(res_block(z, temb=None))

        return rearrange(z, 'b c h w -> b (h w) c') if self.do_reshape else z

class MuxEncoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 K, expand, mod,
                 **ignore_kwargs):
        """Build a multiplexing encoder.

        ``self.rop`` holds the binding operator(s) selected by ``mod``; most modes
        build one per multiplexed input, ``K`` of them. ``self.mux_proj`` is bound
        to the method that applies them; those methods are defined elsewhere in
        this class. The rest is an ldm-style encoder: conv_in -> down levels
        (ResnetBlocks + optional attention + Downsample) -> mid block -> norm/conv_out.

        Args:
            ch: base channel width of the encoder.
            out_ch: not used by this constructor.
            ch_mult: per-level channel multipliers (tuple or list).
            K: number of multiplexed inputs / binding operators.
            expand: channel expansion factor of the binding operator. Several
                modes reset ``self.expand`` to 1.
            mod: binding variant name (see the branches below).
            double_z: if True, conv_out emits ``2 * z_channels`` channels.

        Raises:
            ValueError: if ``mod`` is unknown, or if a four-level mode
                (``*compress-*`` / ``spatial-*``) is requested with ``len(ch_mult) != 3``.
        """
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.K = K
        self.expand = expand
        self.mod = mod

        # Modes that append a fourth encoder level; only defined for len(ch_mult) == 3.
        four_level_mods = ("compress-upsample", "compress-end", "spatial-upsample",
                           "spatial-end", "chcompress-upsample", "chcompress-end")
        # Number of channels produced by conv_out.
        z_out = 2*z_channels if double_z else z_channels

        # --- binding operator(s) `self.rop` and the method `self.mux_proj` applying them ---
        if self.mod == 'nonlinear-expand':
            # for /data/minkyu/logs/2025-02-19T06-45-07_autoencoder_ffhq_kl_64x64x3_noatt_resdemux_upsample
            # self.rop = torch.nn.ModuleList([torch.nn.Conv2d(in_channels, in_channels*self.expand//2, kernel_size=7, stride=1, padding=3) for _ in range(self.K)])
            # self.rop2 = torch.nn.ModuleList([torch.nn.Conv2d(in_channels*self.expand//2, in_channels*self.expand, kernel_size=7, stride=1, padding=3) for _ in range(self.K)])
            # self.relu = torch.nn.ReLU()
            # Per input: two 7x7 convs + ReLU, in_channels -> in_channels*expand.
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    noOrthoRegularizationConv2d(in_channels, in_channels*self.expand//2, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(),
                    noOrthoRegularizationConv2d(in_channels*self.expand//2, in_channels*self.expand, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU()
                ) for _ in range(self.K)
            ])
            self.mux_proj = self.proj
        elif self.mod == 'nonlinear-expand-one':
            # Same as 'nonlinear-expand' but with 1x1 convs.
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    noOrthoRegularizationConv2d(in_channels, in_channels*self.expand//2, kernel_size=1, stride=1),
                    torch.nn.ReLU(),
                    noOrthoRegularizationConv2d(in_channels*self.expand//2, in_channels*self.expand, kernel_size=1, stride=1),
                    torch.nn.ReLU()
                ) for _ in range(self.K)
            ])
            self.mux_proj = self.proj
        elif self.mod == 'HRR':
            # One circular 1-D kernel of length `ch` per input (groups=K).
            self.rop = noOrthoRegularizationConv1d(in_channels = self.K, out_channels = self.K, kernel_size = self.ch, padding = self.ch//2, padding_mode = 'circular', groups = self.K, bias = False)
            self.expand = 1
            self.mux_proj = self.pointwiseConvThroughChannels
        elif self.mod == 'MBAT':
            # One ch x ch 1x1 mixing matrix per input (grouped conv, groups=K).
            self.rop = noOrthoRegularizationConv2d(in_channels = self.ch*self.K, out_channels = self.ch*self.K, kernel_size = 1, groups = self.K, bias = False)
            self.expand = 1
            self.mux_proj = self.conv1x1Stack
        elif self.mod == 'one-res':
            self.rop = nn.ModuleList([ResnetBlock(in_channels=self.ch, out_channels=self.ch, temb_channels=self.temb_ch, dropout=dropout) for _ in range(self.K)])
            self.expand = 1
            self.mux_proj = self.resstack
        elif self.mod == 'two-res':
            self.rop = nn.ModuleList([torch.nn.Sequential(ResnetBlock(in_channels=self.ch, out_channels=self.ch, temb_channels=self.temb_ch, dropout=dropout),
                                                          ResnetBlock(in_channels=self.ch, out_channels=self.ch, temb_channels=self.temb_ch, dropout=dropout)) for _ in range(self.K)])
            self.expand = 1
            self.mux_proj = self.resstack
        elif self.mod == 'one-res-in':
            self.rop = nn.ModuleList([ResnetBlock(in_channels=in_channels, out_channels = in_channels*self.expand, temb_channels=self.temb_ch, dropout=dropout, num_groups=1) for _ in range(self.K)])
            self.mux_proj = self.proj
        elif self.mod == 'nonlinear-expand-norm':
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    MuxConvBlock(in_channels = in_channels, out_channels = in_channels*self.expand//2, num_groups = 1),
                    MuxConvBlock(in_channels = in_channels*self.expand//2, out_channels = in_channels*self.expand, num_groups = 1),
                    Normalize(in_channels*self.expand, num_groups=1),
                ) for _ in range(self.K)
            ])
            self.mux_proj = self.proj
        elif self.mod == 'datamux':
            # Per input: in_channels -> 16 -> 8 with tanh; conv_in below expects 8 channels.
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    MuxConvBlock(in_channels = in_channels, out_channels = 16, num_groups = 0, nonlinearity=False),
                    torch.nn.Tanh(),
                    MuxConvBlock(in_channels = 16, out_channels = 8, num_groups = 0, nonlinearity=False),
                    torch.nn.Tanh(),
                ) for _ in range(self.K)
            ])
            self.mux_proj = self.proj
        elif self.mod == 'my_mbat':
            self.rop = nn.ModuleList([
                torch.nn.Sequential(
                    MuxLinearBlock(in_channels = in_channels, out_channels = in_channels, num_groups = 1),
                    Normalize(in_channels, num_groups = 1),
                ) for _ in range(self.K)
            ])
            self.mux_proj = self.proj
            self.expand = 1
        elif self.mod == 'last-norm':
            # Operates on the conv_out output: z_out -> 2*z_out -> z_out.
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    MuxConvBlock(in_channels = z_out, out_channels = 2*z_out, num_groups = 1),
                    MuxConvBlock(in_channels = 2*z_out, out_channels = z_out, num_groups = 1, nonlinearity = False),
                ) for _ in range(self.K)
            ])
            self.expand = 1
            self.mux_proj = self.proj
        elif self.mod == 'symmetry':
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    MuxConvBlock(in_channels = 32, out_channels = 64, num_groups = 32),
                    MuxConvBlock(in_channels = 64, out_channels = 128, num_groups = 32),
                    Normalize(128, num_groups=32),
                ) for _ in range(self.K)
            ])
            self.mux_proj = self.proj
            self.expand = 1
        elif self.mod == 'symmetry-res':
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    ResMuxBlock(in_channels=32, out_channels=64, temb_channels=self.temb_ch, dropout=dropout),
                    ResMuxBlock(in_channels=64, out_channels=128, temb_channels=self.temb_ch, dropout=dropout),
                ) for _ in range(self.K)
            ])
            self.mux_proj = self.proj
            self.expand = 1
        elif self.mod == 'symmetry-res-expand':
            self.rop = torch.nn.ModuleList([
                torch.nn.Sequential(
                    ResMuxBlock(in_channels=16, out_channels=32, temb_channels=self.temb_ch, dropout=dropout, num_groups=16),
                    ResMuxBlock(in_channels=32, out_channels=64, temb_channels=self.temb_ch, dropout=dropout),
                ) for _ in range(self.K)
            ])
            self.conv_expand = torch.nn.Conv2d(64, self.ch, kernel_size=3, stride=1, padding=1)
            self.mux_proj = self.proj
        elif self.mod == 'compress-upsample' and self.num_resolutions == 3:
            self.rop = torch.nn.ModuleList([
                    UpsampleDecoder(in_channels=512, out_channels=512, ch=512, num_res_blocks=0, resolution=64, ch_mult=(2,2), dropout=0.0, norm_num_groups=32)
                    for _ in range(self.K)
            ])
            self.expand = 1
            self.mux_proj = self.proj
        elif self.mod == 'compress-end' and self.num_resolutions == 3:
            self.rop = torch.nn.ModuleList([
                    UpsampleDecoder(in_channels=6, out_channels=6, ch=6, num_res_blocks=0, resolution=64, ch_mult=(2,2), dropout=0.0, norm_num_groups=1)
                    for _ in range(self.K)
            ])
            self.expand = 1
            self.mux_proj = self.proj
        elif self.mod == 'spatial-upsample' and self.num_resolutions == 3:
            self.rop = torch.nn.Sequential(
                    noOrthoRegularizationConv2d(512, 1024, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(),
                    noOrthoRegularizationConv2d(1024, 512, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU()
                )
            self.expand = 1
            self.mux_proj = self.spatial
        elif self.mod == 'spatial-end' and self.num_resolutions == 3:
            self.rop = torch.nn.Sequential(
                    noOrthoRegularizationConv2d(8, 16, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(),
                    noOrthoRegularizationConv2d(16, 8, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU()
                )
            self.expand = 1
            self.mux_proj = self.spatial
        elif self.mod == 'chcompress-upsample' and self.num_resolutions == 3:
            self.rop = UpsampleDecoder(in_channels=512*self.K, out_channels=512, ch=512, num_res_blocks=0, resolution=64, ch_mult=(4,2), dropout=0.0, norm_num_groups=32)
            self.expand = 1
            self.mux_proj = self.ch_compress
        elif self.mod == 'chcompress-end' and self.num_resolutions == 3:
            self.rop = UpsampleDecoder(in_channels=6*self.K, out_channels=6, ch=6, num_res_blocks=0, resolution=64, ch_mult=(4,2), dropout=0.0, norm_num_groups=1)
            self.expand = 1
            self.mux_proj = self.ch_compress
        else:
            # Raise instead of `assert False` so the check survives `python -O`, and say
            # why a known four-level mode was rejected.
            if self.mod in four_level_mods:
                raise ValueError(
                    f"mod={self.mod!r} is only supported with len(ch_mult) == 3, got {self.num_resolutions}")
            raise ValueError(f"Unknown mod: {self.mod!r}")

        # --- input convolution (its input width depends on the binding variant) ---
        if self.mod == 'datamux':
            conv_in_in, conv_in_out = 8, self.ch
        elif self.mod in ('symmetry', 'symmetry-res'):
            conv_in_in, conv_in_out = 3, 32
        elif self.mod == 'symmetry-res-expand':
            conv_in_in, conv_in_out = in_channels, 16
        else:
            conv_in_in, conv_in_out = in_channels*self.expand, self.ch
        self.conv_in = torch.nn.Conv2d(conv_in_in,
                                       conv_in_out,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # --- downsampling ---
        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        if self.mod in four_level_mods and self.num_resolutions == 3:
            # Add a fourth level with multiplier 4. tuple() makes this work when ch_mult
            # is a list (e.g. from a config); list + tuple raised TypeError.
            self.num_resolutions = self.num_resolutions+1
            ch_mult = tuple(ch_mult) + (4,)
            in_ch_mult = in_ch_mult + (4,)
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # --- middle ---
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # --- end ---
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        z_out,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def binding_regularization(self):
        r"""computes regularization term with the goal of orthonormalizing binding keys"""
        if self.mod == "HRR":
            weights = self.rop.weight.reshape(self.K, -1)
            normed_weights = nn.functional.normalize(weights, p=2, dim=1)
            # returns average squared cosine of angle between any two pairs of vectors + average squared error of norm to length 1
            avg_abs_inner_product = (torch.norm(torch.triu(torch.matmul(normed_weights, torch.transpose(normed_weights, 0, 1)), diagonal=1)))**2 / (self.K*(self.K-1)/2) if self.K > 1 else 0
            avg_norm_delta = (torch.norm(torch.norm(weights, dim=1) - 1))**2 / self.K
            return avg_abs_inner_product , avg_norm_delta
        # elif self.mod == "MBAT":
        #     # MBAT uses group conv2d with kernel_size=1. 
        #     # shape of channelLinear.weight -> [out_channels, in_channels, kH, kW]
        #     # here out_channels = K*channels, in_channels = K*channels, kH=kW=1
        #     # Because 'groups=K', each group is [channels, channels, 1, 1].
            
        #     weights = self.rop.weight  # shape = [K*channels, K*channels, 1, 1]

        #     # Collect row vectors from each group's (channels x channels) block.
        #     row_list = []
        #     for g in range(self.K):
        #         # Extract the (channels x channels) block for group g
        #         block = weights[
        #             g*self.ch : (g+1)*self.ch,
        #             g*self.ch : (g+1)*self.ch,
        #             0,  # kH = 1
        #             0   # kW = 1
        #         ]  # shape = (channels, channels)
                
        #         # Each row in this block is one "vector" (length = channels)
        #         # We'll stack them up
        #         row_list.append(block)
            
        #     # shape = (K*channels, channels)
        #     all_rows = torch.cat(row_list, dim=0)
            
        #     # Now regularize these row vectors the same way we do for HRR:
        #     normed_rows = nn.functional.normalize(all_rows, p=2, dim=1)
            
        #     if all_rows.shape[0] > 1:
        #         # average squared cosine among distinct pairs of row vectors
        #         avg_abs_inner_product = (
        #             torch.norm(
        #                 torch.triu(torch.matmul(normed_rows, normed_rows.t()), diagonal=1)
        #             ) ** 2
        #             / (all_rows.shape[0] * (all_rows.shape[0] - 1) / 2)
        #         )
        #     else:
        #         avg_abs_inner_product = 0
            
        #     # measure how much each row's norm deviates from 1
        #     avg_norm_delta = (
        #         torch.norm(torch.norm(all_rows, dim=1) - 1) ** 2 / all_rows.shape[0]
        #     )
            
        #     return avg_abs_inner_product , avg_norm_delta
        elif self.mod == "nonlinear-expand" or self.mod == "nonlinear-expand-one":
            #
            # Suppose we interpret each module in self.custom as one "key."
            # We gather all its convolutional weights into a single vector,
            # then repeat that for all modules.
            #

            # Each entry in self.custom is a Sequential containing two (or more) conv layers.
            # We'll flatten out all conv-layer parameters in each module into one vector.
            module_vectors = []
            for mod in self.rop:
                # Gather the conv weights in this module
                flattened_weights = []
                for layer in mod:
                    if isinstance(layer, nn.Conv2d):
                        # Flatten [outC, inC, kH, kW] -> a single 1D vector
                        flattened_weights.append(layer.weight.view(-1))
                # Concatenate all conv weights in this module into one vector
                module_vector = torch.cat(flattened_weights, dim=0)
                module_vectors.append(module_vector)

            # Now we have one vector per module; stack them into a matrix [K, vector_length]
            # where K = number of modules = self.K, typically.
            all_vectors = torch.stack(module_vectors, dim=0)  # shape: [K, length]

            # Normalize each row (module) to length 1
            normed_rows = nn.functional.normalize(all_vectors, p=2, dim=1)

            # --- 1) Encourage orthogonality (average squared inner product off-diagonal) ---
            if self.K > 1:
                # triu(...) picks upper triangular part above the main diagonal
                inner_products = torch.matmul(normed_rows, normed_rows.t())  # [K, K]
                off_diag = torch.triu(inner_products, diagonal=1)
                avg_abs_inner_product = (torch.norm(off_diag) ** 2) / (
                    self.K * (self.K - 1) / 2
                )
            else:
                avg_abs_inner_product = 0.0

            # --- 2) Encourage unit norm (average squared deviation of norms from 1) ---
            # L2-norm along each row, subtract 1, then square
            norms = torch.norm(all_vectors, dim=1)  # shape: [K]
            avg_norm_delta = (
                torch.norm(norms - 1.0) ** 2
            ) / self.K
            return avg_abs_inner_product , avg_norm_delta
        elif self.mod == "one-res" or self.mod == "two-res" or self.mod == "one-res-in":
            module_vectors = []
            for mod in self.rop:
                flattened_weights = []
                if isinstance(mod, nn.Sequential):
                    for layer in mod:
                        if isinstance(layer, ResnetBlock):  # Extract weights from each ResnetBlock
                            for sub_layer in [layer.conv1, layer.conv2]:  # conv1 and conv2 exist inside ResnetBlock
                                flattened_weights.append(sub_layer.weight.view(-1))
                elif isinstance(mod, ResnetBlock):
                    for sub_layer in [mod.conv1, mod.conv2]:
                        flattened_weights.append(sub_layer.weight.view(-1))

                if len(flattened_weights) > 0:
                    module_vector = torch.cat(flattened_weights, dim=0)
                    module_vectors.append(module_vector)

            if len(module_vectors) == 0:
                return 0  # No valid weights found, return zero regularization

            # Stack vectors into a matrix [K, vector_length]
            all_vectors = torch.stack(module_vectors, dim=0)  # shape: [K, length]

            # Normalize each row (module) to length 1
            normed_rows = nn.functional.normalize(all_vectors, p=2, dim=1)

            # --- 1) Encourage orthogonality (average squared inner product off-diagonal) ---
            if self.K > 1:
                inner_products = torch.matmul(normed_rows, normed_rows.t())  # [K, K]
                off_diag = torch.triu(inner_products, diagonal=1)
                avg_abs_inner_product = (torch.norm(off_diag) ** 2) / (
                    self.K * (self.K - 1) / 2
                )
            else:
                avg_abs_inner_product = 0.0

            # --- 2) Encourage unit norm (average squared deviation of norms from 1) ---
            norms = torch.norm(all_vectors, dim=1)  # shape: [K]
            avg_norm_delta = (
                torch.norm(norms - 1.0) ** 2
            ) / self.K

            return avg_abs_inner_product , avg_norm_delta
        elif self.mod == "symmetry-res" or self.mod == "symmetry-res-expand":
            module_vectors = []
            for mod in self.rop:
                flattened_weights = []
                if isinstance(mod, nn.Sequential):
                    for layer in mod:
                        if isinstance(layer, ResMuxBlock):  # Extract weights from each ResnetBlock
                            for sub_layer in [layer.conv1, layer.conv2]:  # conv1 and conv2 exist inside ResnetBlock
                                flattened_weights.append(sub_layer.weight.view(-1))
                elif isinstance(mod, ResnetBlock):
                    for sub_layer in [mod.conv1, mod.conv2]:
                        flattened_weights.append(sub_layer.weight.view(-1))

                if len(flattened_weights) > 0:
                    module_vector = torch.cat(flattened_weights, dim=0)
                    module_vectors.append(module_vector)

            if len(module_vectors) == 0:
                return 0  # No valid weights found, return zero regularization

            # Stack vectors into a matrix [K, vector_length]
            all_vectors = torch.stack(module_vectors, dim=0)  # shape: [K, length]

            # Normalize each row (module) to length 1
            normed_rows = nn.functional.normalize(all_vectors, p=2, dim=1)

            # --- 1) Encourage orthogonality (average squared inner product off-diagonal) ---
            if self.K > 1:
                inner_products = torch.matmul(normed_rows, normed_rows.t())  # [K, K]
                off_diag = torch.triu(inner_products, diagonal=1)
                avg_abs_inner_product = (torch.norm(off_diag) ** 2) / (
                    self.K * (self.K - 1) / 2
                )
            else:
                avg_abs_inner_product = 0.0

            # --- 2) Encourage unit norm (average squared deviation of norms from 1) ---
            norms = torch.norm(all_vectors, dim=1)  # shape: [K]
            avg_norm_delta = (
                torch.norm(norms - 1.0) ** 2
            ) / self.K

            return avg_abs_inner_product , avg_norm_delta
        elif self.mod == "my_mbat":
            #
            # Each item in self.rop is a Sequential:
            #  [
            #    MuxLinearBlock(in_channels=..., out_channels=..., num_groups=1),
            #    Normalize(in_channels, num_groups=1)
            #  ]
            #
            # We'll gather .lin.weight from each MuxLinearBlock and flatten it.
            #
            module_vectors = []
            for mod in self.rop:
                flattened_weights = []
                for layer in mod:
                    if isinstance(layer, MuxLinearBlock):
                        # Flatten the linear layer weights [out_features, in_features]
                        flattened_weights.append(layer.lin.weight.view(-1))
                if len(flattened_weights) > 0:
                    module_vector = torch.cat(flattened_weights, dim=0)
                    module_vectors.append(module_vector)

            if len(module_vectors) == 0:
                return 0  # No valid weights, no penalty

            # Stack vectors into a matrix [K, vector_length]
            all_vectors = torch.stack(module_vectors, dim=0)  # shape: [K, length]

            # Normalize each row (module) to length 1
            normed_rows = nn.functional.normalize(all_vectors, p=2, dim=1)

            # 1) Encourage orthogonality (average squared inner product off-diagonal)
            if self.K > 1:
                inner_products = torch.matmul(normed_rows, normed_rows.t())  # [K, K]
                off_diag = torch.triu(inner_products, diagonal=1)
                avg_abs_inner_product = (torch.norm(off_diag) ** 2) / (
                    self.K * (self.K - 1) / 2
                )
            else:
                avg_abs_inner_product = 0.0

            # 2) Encourage unit norm (average squared deviation of norms from 1)
            norms = torch.norm(all_vectors, dim=1)  # shape: [K]
            avg_norm_delta = (torch.norm(norms - 1.0) ** 2) / self.K

            return avg_abs_inner_product , avg_norm_delta

        elif self.mod == "datamux" or self.mod == "nonlinear-expand-norm" or self.mod == "last-norm" or self.mod == "symmetry":
            #
            # Each item in self.rop is a Sequential of 2 MuxConvBlock:
            #   [
            #     MuxConvBlock(in_channels, out_channels=16, num_groups=0),
            #     MuxConvBlock(in_channels=16, out_channels=8, num_groups=0)
            #   ]
            #
            # We'll gather each .conv.weight from the MuxConvBlock layers and flatten it.
            #
            module_vectors = []
            for mod in self.rop:
                flattened_weights = []
                for layer in mod:
                    if isinstance(layer, MuxConvBlock):
                        # Flatten [out_channels, in_channels, kH, kW]
                        flattened_weights.append(layer.conv.weight.view(-1))
                if len(flattened_weights) > 0:
                    module_vector = torch.cat(flattened_weights, dim=0)
                    module_vectors.append(module_vector)

            if len(module_vectors) == 0:
                return 0  # No valid weights, no penalty

            # Stack vectors into a matrix [K, vector_length]
            all_vectors = torch.stack(module_vectors, dim=0)  # shape: [K, length]

            # Normalize each row (module) to length 1
            normed_rows = nn.functional.normalize(all_vectors, p=2, dim=1)

            # 1) Encourage orthogonality (average squared inner product off-diagonal)
            if self.K > 1:
                inner_products = torch.matmul(normed_rows, normed_rows.t())  # [K, K]
                off_diag = torch.triu(inner_products, diagonal=1)
                avg_abs_inner_product = (torch.norm(off_diag) ** 2) / (
                    self.K * (self.K - 1) / 2
                )
            else:
                avg_abs_inner_product = 0.0

            # 2) Encourage unit norm (average squared deviation of norms from 1)
            norms = torch.norm(all_vectors, dim=1)  # shape: [K]
            avg_norm_delta = (torch.norm(norms - 1.0) ** 2) / self.K

            return avg_abs_inner_product , avg_norm_delta

        elif self.mod in ["compress-upsample", "compress-end"]:
            # For compress-upsample and compress-end, self.rop is a ModuleList of UpsampleDecoder.
            # We'll extract the conv_out weights from each UpsampleDecoder, flatten them,
            # and then compute two regularization metrics:
            # 1) avg_abs_inner_product: the average squared inner product (off-diagonal) among normalized binding vectors.
            # 2) avg_norm_delta: the average squared deviation of each binding vector's norm from 1.
            
            weight_list = []
            for decoder in self.rop:
                # decoder.conv_out.weight has shape [out_channels, in_channels, kH, kW]
                # We flatten it into a 1D vector.
                w = decoder.conv_out.weight.view(-1)
                weight_list.append(w)
            
            if len(weight_list) == 0:
                return 0, 0  # No weights to regularize
            
            # Stack the weight vectors into a matrix of shape [K, vector_length]
            all_weights = torch.stack(weight_list, dim=0)
            
            # Normalize each vector to have unit norm
            normed_weights = nn.functional.normalize(all_weights, p=2, dim=1)
            
            # Compute the average squared inner product between distinct pairs (off-diagonal elements)
            if all_weights.shape[0] > 1:
                inner_products = torch.matmul(normed_weights, normed_weights.t())  # [K, K]
                off_diag = torch.triu(inner_products, diagonal=1)
                num_pairs = all_weights.shape[0] * (all_weights.shape[0] - 1) / 2
                avg_abs_inner_product = (torch.norm(off_diag) ** 2) / num_pairs
            else:
                avg_abs_inner_product = 0.0
            
            # Compute the average squared deviation of the norms from 1
            norms = torch.norm(all_weights, dim=1)
            avg_norm_delta = (torch.norm(norms - 1.0) ** 2) / all_weights.shape[0]
            
            return avg_abs_inner_product, avg_norm_delta


        else:
            return 0, 0



    def proj(self, imgs):
        """Bind each of the K image groups with its own module and superpose them.

        Args:
            imgs: 5-D tensor indexed as ``imgs[i]`` for ``i in range(self.K)``;
                ``forward`` builds it as (K, N/K, C, H, W).

        Returns:
            For every supported ``self.mod`` except 'symmetry-res-expand', the
            plain sum over i of ``self.rop[i]`` applied to group i (no division
            by K here). For 'symmetry-res-expand' the sum is divided by K and
            passed through ``self.conv_expand``.

        Raises:
            NotImplementedError: if ``self.mod`` has no projection defined here.
                (Previously an ``assert``, which ``python -O`` strips, making the
                method silently return an all-zeros tensor.)
        """
        device = imgs.device
        b_div_k, c, h, w = imgs.shape[1:5]
        mod = self.mod

        if mod in ('nonlinear-expand', 'nonlinear-expand-norm', 'last-norm',
                   'nonlinear-expand-one', 'one-res-in'):
            comp_imgs = torch.zeros(b_div_k, c * self.expand, h, w, device=device)
            for i in range(self.K):
                # Kept from the original: moves the module to the input's device.
                self.rop[i].to(device)
                comp_imgs += self.rop[i](imgs[i])
        elif mod == 'datamux':
            # Fixed 8 output channels, matching the last MuxConvBlock of each datamux rop.
            comp_imgs = torch.zeros(b_div_k, 8, h, w, device=device)
            for i in range(self.K):
                comp_imgs += self.rop[i](imgs[i])
        elif mod == 'my_mbat':
            # Per-pixel binding: every pixel becomes a length-C vector, rop[i] maps it
            # to C*expand features, and the result is folded back into image layout.
            k = imgs.shape[0]
            pixels = imgs.permute(0, 1, 3, 4, 2).reshape(k, -1, c).contiguous()
            comp_imgs = torch.zeros(b_div_k, c * self.expand, h, w, device=device)
            for i in range(self.K):
                bound = self.rop[i](pixels[i])
                comp_imgs += bound.reshape(b_div_k, h, w, c * self.expand).permute(0, 3, 1, 2).contiguous()
        elif mod in ('symmetry', 'symmetry-res'):
            comp_imgs = torch.zeros(b_div_k, 128, h, w, device=device)
            for i in range(self.K):
                comp_imgs += self.rop[i](imgs[i])
        elif mod == 'symmetry-res-expand':
            summed = torch.zeros(b_div_k, 64, h, w, device=device)
            for i in range(self.K):
                summed += self.rop[i](imgs[i])
            comp_imgs = self.conv_expand(summed / self.K)
        elif mod in ('compress-upsample', 'compress-end'):
            # Each rop[i] doubles the spatial resolution while keeping C channels.
            comp_imgs = torch.zeros(b_div_k, c, h * 2, w * 2, device=device)
            for i in range(self.K):
                comp_imgs += self.rop[i](imgs[i])
        else:
            raise NotImplementedError('Not implemented')
        return comp_imgs

    def pointwiseConvThroughChannels(self, x: Tensor) -> Tensor:
        r"""Pixel-wise HRR binding: ``self.rop`` runs over each pixel's (K, C) block.

        Every pixel of every sample is turned into one (K, C) item, so ``self.rop``
        sees K as the channel axis and C as the sequence axis. Its output is
        expected to be one element longer along C; the first element is dropped.
        Per the original author note, this accounts for circular padding that
        duplicates one entry for even-length vectors.

        Args:
            x: image tensor of size (K, N/K, C, H, W)
        Returns:
            image tensor of size (K, N/K, C, H, W)
        """
        per_pixel = x.permute(1, 3, 4, 0, 2)  # (N/K, H, W, K, C)
        pixel_shape = per_pixel.shape
        n_div_k, height, width, groups, length = pixel_shape
        bound = self.rop(per_pixel.reshape(n_div_k * height * width, groups, length))
        trimmed = bound[:, :, 1:]  # drop the extra leading element introduced by the padding
        return trimmed.reshape(pixel_shape).permute(3, 0, 4, 1, 2)  # back to (K, N/K, C, H, W)

    def conv1x1Stack(self, x:Tensor) -> Tensor:
        r"""Pixel-wise MBAT binding by stacking the K groups along channels.

        ``self.rop`` receives a single (N/K, K*C, H, W) tensor, where channel
        block k holds group k. Its output is split back into the grouped layout,
        so it must preserve the total element count.

        Args:
            x: image tensor of size (K, N/K, C, H, W)
        Returns:
            image tensor of size (K, N/K, C, H, W)
        """
        per_sample = x.permute(1, 0, 2, 3, 4)  # (N/K, K, C, H, W)
        grouped_shape = per_sample.shape
        n_div_k, groups, channels, height, width = grouped_shape
        stacked = per_sample.reshape(n_div_k, groups * channels, height, width)
        mixed = self.rop(stacked)
        return mixed.reshape(grouped_shape).permute(1, 0, 2, 3, 4)

    def resstack(self, x:Tensor) -> Tensor:
        r"""Pixel-wise ResNet binding: group k goes through its own ``self.rop[k]``.

        Only the first ``self.K`` groups are transformed; any further slices
        along dim 0 stay zero. Each ``self.rop[k]`` must return a tensor that
        can be written back into a (N/K, C, H, W) slot.

        Args:
            x: image tensor of size (K, N/K, C, H, W)
        Returns:
            image tensor of size (K, N/K, C, H, W)
        """
        bound = torch.zeros_like(x)  # already on x.device with x's dtype
        for group_idx in range(self.K):
            block = self.rop[group_idx]
            bound[group_idx] = block(x[group_idx])
        return bound

    def spatial(self, x: Tensor) -> Tensor:
        """Interleave 4 image groups into a 2x2 pixel grid, then apply ``self.rop``.

        Group ``2 * r + c`` fills the output pixels whose row parity is r and
        column parity is c. Within every 2x2 cell, group 0 is top-left, group 1
        is top-right, group 2 is bottom-left and group 3 is bottom-right.

        Args:
            x: Tensor of shape (4, BS//4, C, H, W).

        Returns:
            ``self.rop`` applied to the interleaved (BS//4, C, 2H, 2W) grid.

        Raises:
            ValueError: if x does not hold exactly 4 groups.
        """
        num_groups, per_group, channels, height, width = x.shape
        if num_groups != 4:
            raise ValueError(f"Expected K = 4, but got K = {num_groups}")

        grid = x.new_zeros((per_group, channels, 2 * height, 2 * width))
        for group in range(4):
            row, col = divmod(group, 2)
            grid[:, :, row::2, col::2] = x[group]

        return self.rop(grid)
    
    def ch_compress(self, x: Tensor) -> Tensor:
        """Stack the K image groups along the channel axis and apply ``self.rop``.

        Args:
            x: Tensor of shape (K, BS//K, C, H, W). K is not checked, so any
                group count works as long as ``self.rop`` accepts K*C channels.

        Returns:
            ``self.rop`` applied to the (BS//K, K*C, H, W) stack, in which
            channel block k holds group k.
        """
        _, per_group, _, height, width = x.shape
        stacked = x.transpose(0, 1).reshape(per_group, -1, height, width)
        return self.rop(stacked)



    def _split_k(self, t: Tensor, name: str = 'x') -> Tensor:
        """Reshape a (BS, C, H, W) batch into (K, BS // K, C, H, W).

        Group i holds the i-th contiguous run of BS // K rows. This is the
        (K, N/K, C, H, W) layout the binding methods of this class take.

        Raises:
            ValueError: if BS is not divisible by ``self.K``.
        """
        bs, c, h, w = t.shape
        if bs % self.K != 0:
            raise ValueError(f'{name}.shape: {t.shape} is not compatible with K: {self.K}')
        return t.reshape(self.K, bs // self.K, c, h, w).contiguous()

    def forward(self, x):
        """Encode x, multiplexing K contiguous groups of the batch into one.

        Where the batch is split into K groups and bound by ``self.mux_proj``
        depends on ``self.mod``:
          * on the raw input, before conv_in: 'nonlinear-expand',
            'nonlinear-expand-one', 'nonlinear-expand-norm', 'one-res-in',
            'datamux', 'my_mbat';
          * right after conv_in: 'HRR', 'MBAT', 'one-res', 'two-res',
            'symmetry', 'symmetry-res', 'symmetry-res-expand';
          * before the middle block: 'compress-upsample', 'spatial-upsample',
            'chcompress-upsample';
          * after conv_out: 'last-norm', 'compress-end', 'spatial-end',
            'chcompress-end'.

        Args:
            x: (BS, C, H, W) image batch. BS must be divisible by ``self.K``.

        Returns:
            The encoder output after conv_out, plus the end-stage multiplexing
            for the after-conv_out modes. Its batch size is presumably BS // K,
            but that depends on ``self.mux_proj``.

        Raises:
            ValueError: if the batch cannot be split into K groups.
            NotImplementedError: for an unsupported ``self.mod``.
        """
        # timestep embedding: unused by this encoder; blocks receive None
        temb = None
        mod = self.mod

        if mod in ('nonlinear-expand', 'nonlinear-expand-one', 'nonlinear-expand-norm',
                   'one-res-in', 'datamux', 'my_mbat'):
            # Bind and superpose the raw groups, average over K, then conv_in.
            h = self.conv_in(self.mux_proj(self._split_k(x)) / self.K)
        elif mod in ('HRR', 'MBAT', 'one-res', 'two-res'):
            # Bind after conv_in, then average over dim 0 (the K groups), giving
            # BS // K rows, the layout the decoder's demux heads invert.
            # 'one-res'/'two-res' previously averaged over dim 1, which mixes
            # the N/K samples of each group instead and returned K rows.
            h = self.mux_proj(self._split_k(self.conv_in(x)))
            h = torch.mean(h, dim=0)
        elif mod in ('last-norm', 'compress-upsample', 'compress-end', 'spatial-upsample',
                     'spatial-end', 'chcompress-upsample', 'chcompress-end'):
            # Multiplexing happens later (before mid or after conv_out).
            h = self.conv_in(x)
        elif mod in ('symmetry', 'symmetry-res'):
            h = self.mux_proj(self._split_k(self.conv_in(x))) / self.K
            # Author note (translated): I wonder whether summing everything here
            # and passing it straight on is the problem; should it go through at
            # least one conv? E.g. 3->16->32->64->128, with the first and last
            # layers shared?
        elif mod == 'symmetry-res-expand':
            # No division by K here. Assuming mux_proj is `proj` for this mode, its
            # 'symmetry-res-expand' branch divides by K itself.
            h = self.mux_proj(self._split_k(self.conv_in(x)))
        else:
            raise NotImplementedError(f'Unsupported mod: {mod}')

        # downsampling (only the running activation is needed, so no history list)
        for i_level in range(self.num_resolutions):
            down = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = down.block[i_block](h, temb)
                if len(down.attn) > 0:
                    h = down.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = down.downsample(h)

        # middle: the '*-upsample' modes multiplex here, at the lowest resolution
        if mod == 'compress-upsample':
            h = self.mux_proj(self._split_k(h, 'h')) / self.K
        elif mod in ('spatial-upsample', 'chcompress-upsample'):
            h = self.mux_proj(self._split_k(h, 'h'))

        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if mod in ('last-norm', 'compress-end'):
            h = self.mux_proj(self._split_k(h, 'h')) / self.K
        elif mod in ('spatial-end', 'chcompress-end'):
            h = self.mux_proj(self._split_k(h, 'h'))
        return h


class MuxDecoder(nn.Module):
    """LDM-style decoder that demultiplexes one latent batch into K output batches.

    The trunk is conv_in -> mid (ResNet, attention, ResNet) -> upsampling stack
    -> norm_out / nonlinearity / conv_out. At the stage named by ``demux`` the
    activations go through ``deproj``. It applies K per-group heads (chosen by
    ``demux_mod``) to the same tensor and concatenates the results along dim 0,
    so a batch of B rows becomes K*B rows.

    Multiplexing-specific arguments:
        K: number of demux heads.
        demux: stage at which ``deproj`` runs: 'before' (on z), 'conv_in',
            'mid_1', 'mid_2', 'upsample' or 'end'.
        demux_mod: head type: 'channel-one', 'channel-conv', 'K-Conv',
            'one-res', 'two-res', 'decrease_ch', 'symmetry', 'symmetry-res' or
            'symmetry-res-expand'.
        expand: accepted for config compatibility; not read by this class.
        decrease_ch: number of output channels of the 'decrease_ch' heads.

    Raises:
        ValueError: unknown ``demux``.
        NotImplementedError: unknown ``demux_mod``.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 K, demux, demux_mod, expand, decrease_ch=16,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.relu = torch.nn.ReLU()

        self.K = K
        self.demux = demux
        self.demux_mod = demux_mod
        self.decrease_ch = decrease_ch

        # Channel count of the tensor that reaches deproj at the chosen stage.
        if self.demux == 'before':
            self.demux_dim = z_channels
        elif self.demux in ('conv_in', 'mid_1', 'mid_2'):
            self.demux_dim = ch*ch_mult[self.num_resolutions-1]
        elif self.demux == 'upsample':
            self.demux_dim = ch*ch_mult[0]
        elif self.demux == 'end':
            self.demux_dim = out_ch
        else:
            # Previously fell through and crashed later with an AttributeError.
            raise ValueError(f"Unknown demux position: {self.demux!r}")

        print(f"demux_dim: {self.demux_dim}")

        # K demux heads; deproj applies head i to the shared input.
        if self.demux_mod == 'channel-one':
            self.channel_one = torch.nn.ModuleList([noOrthoRegularizationLinear(self.demux_dim, self.demux_dim) for _ in range(self.K)])
        elif self.demux_mod == 'channel-conv':
            self.channel_conv = torch.nn.ModuleList([torch.nn.Sequential(noOrthoRegularizationConv2d(self.demux_dim, self.demux_dim, kernel_size=1, stride=1),
                                                    torch.nn.ReLU(),) for _ in range(self.K)])
        elif self.demux_mod == 'K-Conv':
            self.kconv = torch.nn.ModuleList([torch.nn.Sequential(
                                            noOrthoRegularizationConv2d(self.demux_dim, self.demux_dim, kernel_size=7, stride=1, padding=3),
                                            torch.nn.ReLU(),
                                            noOrthoRegularizationConv2d(self.demux_dim, self.demux_dim, kernel_size=7, stride=1, padding=3),
                                            torch.nn.ReLU()
                                            ) for _ in range(self.K)])
        elif self.demux_mod == 'one-res':
            self.res = nn.ModuleList([ResnetBlock(in_channels=self.demux_dim, out_channels=self.demux_dim, temb_channels=self.temb_ch, dropout=dropout) for _ in range(self.K)])
        elif self.demux_mod == 'two-res':
            self.res = nn.ModuleList([torch.nn.Sequential(ResnetBlock(in_channels=self.demux_dim, out_channels=self.demux_dim, temb_channels=self.temb_ch, dropout=dropout),
                                    ResnetBlock(in_channels=self.demux_dim, out_channels=self.demux_dim, temb_channels=self.temb_ch, dropout=dropout)) for _ in range(self.K)])
        elif self.demux_mod == 'decrease_ch':
            self.decrease_conv = noOrthoRegularizationConv2d(self.demux_dim, self.decrease_ch, kernel_size=3, stride=1, padding=1)
            # deproj feeds these heads decrease_ch features per pixel. They were
            # hard-coded as Linear(32, 32), which crashed for any decrease_ch != 32
            # (including the default 16).
            self.res = nn.ModuleList([torch.nn.Sequential(noOrthoRegularizationLinear(self.decrease_ch, self.decrease_ch),
            torch.nn.ReLU(),) for _ in range(self.K)])
        elif self.demux_mod == 'symmetry':
            self.res = torch.nn.ModuleList([
                torch.nn.Sequential(
                    MuxConvBlock(in_channels = self.demux_dim, out_channels = 64, num_groups = 32),
                    MuxConvBlock(in_channels = 64, out_channels = 32, num_groups = 32),
                ) for _ in range(self.K)
            ])
        elif self.demux_mod == 'symmetry-res':
            self.res = torch.nn.ModuleList([
                torch.nn.Sequential(
                    ResMuxBlock(in_channels=self.demux_dim, out_channels=64, temb_channels=self.temb_ch, dropout=dropout),
                    ResMuxBlock(in_channels=64, out_channels=32, temb_channels=self.temb_ch, dropout=dropout),
                ) for _ in range(self.K)
            ])
        elif self.demux_mod == 'symmetry-res-expand':
            self.conv_expand = torch.nn.Conv2d(self.demux_dim, 64, kernel_size=3, stride=1, padding=1)
            self.res = torch.nn.ModuleList([
                torch.nn.Sequential(
                    ResMuxBlock(in_channels=64, out_channels=32, temb_channels=self.temb_ch, dropout=dropout, num_groups=32),
                    ResMuxBlock(in_channels=32, out_channels=16, temb_channels=self.temb_ch, dropout=dropout, num_groups=16),
                ) for _ in range(self.K)
            ])
        else:
            # A real exception: an assert would be stripped under `python -O`.
            raise NotImplementedError('Not implemented')

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        # The channel-reducing heads change what norm_out/conv_out receive.
        # NOTE(review): these widths only line up when deproj runs before
        # norm_out (e.g. demux == 'upsample'); not enforced here.
        if self.demux_mod == 'decrease_ch':
            self.norm_out = Normalize(self.decrease_ch, num_groups=self.decrease_ch//4)
            self.conv_out = torch.nn.Conv2d(self.decrease_ch,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        elif self.demux_mod in ('symmetry', 'symmetry-res'):
            self.norm_out = Normalize(32, num_groups=32)
            self.conv_out = torch.nn.Conv2d(32,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        elif self.demux_mod == 'symmetry-res-expand':
            self.norm_out = Normalize(16, num_groups=16)
            self.conv_out = torch.nn.Conv2d(16,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        else:
            self.norm_out = Normalize(block_in)
            self.conv_out = torch.nn.Conv2d(block_in,
                                            out_ch,
                                            kernel_size=3,
                                            stride=1,
                                            padding=1)

    def _per_pixel_heads(self, x, heads):
        """Apply heads[i] (i < K) to every pixel's channel vector, then ReLU.

        x is unpacked as (B, C, H, W); each head must map C features to C
        features. Returns a list of K tensors shaped like x.
        """
        b_div_k, c, h, w = x.shape
        # The pixel-major flattening is identical for every head, so do it once.
        pixels = x.permute(0, 2, 3, 1).reshape(-1, c).contiguous()
        return [
            self.relu(heads[i](pixels)).reshape(b_div_k, h, w, c).permute(0, 3, 1, 2).contiguous()
            for i in range(self.K)
        ]

    def _image_heads(self, x, heads, move_to_device=False):
        """Apply heads[i] (i < K) to the whole tensor x; return the K outputs."""
        outs = []
        for i in range(self.K):
            if move_to_device:
                # Kept from the original code; a no-op once the module is on x.device.
                heads[i].to(x.device)
            outs.append(heads[i](x))
        return outs

    def deproj(self, x):
        """Demultiplex x into K batches stacked along dim 0.

        Every head for ``self.demux_mod`` sees the same input. With an input
        batch of B rows, head i's output occupies rows [i*B, (i+1)*B) of the
        result.

        Raises:
            ValueError: if ``self.demux_mod`` is not a supported head type.
        """
        mod = self.demux_mod
        if mod == 'channel-one':
            outs = self._per_pixel_heads(x, self.channel_one)
        elif mod == 'decrease_ch':
            outs = self._per_pixel_heads(self.decrease_conv(x), self.res)
        elif mod == 'channel-conv':
            outs = self._image_heads(x, self.channel_conv, move_to_device=True)
        elif mod == 'K-Conv':
            outs = self._image_heads(x, self.kconv, move_to_device=True)
        elif mod in ('one-res', 'two-res'):
            # NOTE(review): ResnetBlock heads are called without temb here (forward
            # passes temb elsewhere); confirm ResnetBlock.forward defaults temb.
            outs = self._image_heads(x, self.res)
        elif mod in ('symmetry', 'symmetry-res'):
            outs = self._image_heads(x, self.res, move_to_device=True)
        elif mod == 'symmetry-res-expand':
            outs = self._image_heads(self.conv_expand(x), self.res)
        else:
            # A real exception: an assert would be stripped under `python -O`.
            raise ValueError("Invalid demux_mod type.")
        return torch.cat(outs, dim=0)

    def forward(self, z):
        """Decode latent z, demultiplexing at the stage selected by ``self.demux``.

        Args:
            z: latent batch. It goes through conv_in, or through deproj first
                when ``demux == 'before'``.

        Returns:
            The decoded tensor. From the deproj stage on, the batch is K times
            the incoming batch. If ``give_pre_end`` is set, returns the
            upsampling-stack output before norm_out/conv_out (and before an
            'end' demux). If ``tanh_out`` is set, the output is passed through tanh.
        """
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding: this decoder is unconditioned
        temb = None

        if self.demux == 'before':
            z = self.deproj(z)

        # z to block_in
        h = self.conv_in(z)
        if self.demux == 'conv_in':
            h = self.deproj(h)

        # middle
        h = self.mid.block_1(h, temb)
        if self.demux == 'mid_1':
            h = self.deproj(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        if self.demux == 'mid_2':
            h = self.deproj(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            up = self.up[i_level]
            for i_block in range(self.num_res_blocks+1):
                h = up.block[i_block](h, temb)
                if len(up.attn) > 0:
                    h = up.attn[i_block](h)
            if i_level != 0:
                h = up.upsample(h)
        if self.demux == 'upsample':
            h = self.deproj(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.demux == 'end':
            h = self.deproj(h)

        if self.tanh_out:
            h = torch.tanh(h)
        return h

class ResMuxDecoder(nn.Module):
    """Decoder that splits ("demuxes") its batch into K branches at one stage.

    The layer stack is conv_in -> mid (ResNet, attention, ResNet) ->
    upsampling levels -> norm_out/conv_out. At the single stage named by
    ``demux``, the activation is passed through :meth:`deproj`. That method
    applies K separate residual branches and concatenates their outputs
    along dim 0, so every layer after that stage sees a batch K times larger
    than the batch that entered :meth:`deproj`.

    Demux-specific args:
        K: number of branches produced by :meth:`deproj`.
        demux: stage at which :meth:`deproj` runs. One of:
            ``'before'`` (on ``z``), ``'conv_in'``, ``'mid_1'``, ``'mid_2'``,
            ``'upsample'`` (after the last upsampling level) or ``'end'``
            (after ``conv_out``).
        demux_mod: branch type. One of:
            ``'channel-one'`` (per-pixel Linear + ReLU),
            ``'channel-conv'`` (1x1 Conv2d + ReLU), or
            ``'K-Conv'`` (two 7x7 Conv2d + ReLU layers).

    Raises:
        ValueError: if ``demux`` or ``demux_mod`` is not one of the values above.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 K, demux, demux_mod,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.relu = torch.nn.ReLU()

        self.K = K
        self.demux = demux
        self.demux_mod = demux_mod

        # Channel width of the activation at the chosen demux stage; every
        # deproj branch maps demux_dim -> demux_dim.
        if demux == 'before':
            self.demux_dim = z_channels
        elif demux in ('conv_in', 'mid_1', 'mid_2'):
            self.demux_dim = ch*ch_mult[self.num_resolutions-1]
        elif demux == 'upsample':
            self.demux_dim = ch*ch_mult[0]
        elif demux == 'end':
            self.demux_dim = out_ch
        else:
            # Reject unknown stages up front instead of failing later on a
            # missing demux_dim attribute.
            raise ValueError(f"Invalid demux stage: {demux!r}")

        # K independent branches used by deproj. They are built before the
        # main decoder layers, which keeps parameter-initialization order.
        dim = self.demux_dim
        if demux_mod == 'channel-one':
            self.channel_one = torch.nn.ModuleList(
                [noOrthoRegularizationLinear(dim, dim) for _ in range(K)])
        elif demux_mod == 'channel-conv':
            self.channel_conv = torch.nn.ModuleList([
                torch.nn.Sequential(
                    noOrthoRegularizationConv2d(dim, dim, kernel_size=1, stride=1),
                    # ReLU has no parameters, so a plain torch.nn.ReLU is
                    # sufficient (same as the K-Conv branch).
                    torch.nn.ReLU(),
                ) for _ in range(K)])
        elif demux_mod == 'K-Conv':
            self.kconv = torch.nn.ModuleList([
                torch.nn.Sequential(
                    noOrthoRegularizationConv2d(dim, dim, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(),
                    noOrthoRegularizationConv2d(dim, dim, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(),
                ) for _ in range(K)])
        else:
            # Raise rather than `assert`, which is stripped under `python -O`.
            raise ValueError(f"Invalid demux_mod type: {demux_mod!r}")

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def deproj(self, x):
        """Expand ``x`` into K residually transformed copies stacked on dim 0.

        For each of the K branch modules selected by ``self.demux_mod``, this
        computes ``branch(x) + x``. The K results are concatenated along
        dim 0, so the output has K times as many rows as ``x``.

        Args:
            x: 4-D tensor ``(b, c, h, w)``; ``c`` must equal ``self.demux_dim``.

        Raises:
            ValueError: if ``self.demux_mod`` is not a supported mode.
        """
        # The branch modules live in registered ModuleLists, so they follow
        # the parent module's .to(device); no per-call device move is needed.
        if self.demux_mod == 'channel-one':
            b, c, h, w = x.shape
            # Channels-last rows (one per pixel), shared by all K branches.
            flat = x.permute(0, 2, 3, 1).reshape(-1, c)
            outs = []
            for linear in self.channel_one:
                y = self.relu(linear(flat))
                y = y.reshape(b, h, w, c).permute(0, 3, 1, 2).contiguous()
                outs.append(y + x)
        elif self.demux_mod == 'channel-conv':
            outs = [branch(x) + x for branch in self.channel_conv]
        elif self.demux_mod == 'K-Conv':
            outs = [branch(x) + x for branch in self.kconv]
        else:
            raise ValueError(f"Invalid demux_mod type: {self.demux_mod!r}")
        return torch.cat(outs, dim=0)

    def _maybe_deproj(self, h, stage):
        """Return ``deproj(h)`` if ``stage`` is the configured demux stage, else ``h``."""
        return self.deproj(h) if self.demux == stage else h

    def forward(self, z):
        """Decode latent ``z``, demuxing the batch at the configured stage.

        If ``self.give_pre_end`` is set, the features from before
        ``norm_out``/``conv_out`` are returned, and an ``'end'`` demux is
        not applied. If ``self.tanh_out`` is set, the output is squashed
        with ``tanh``.
        """
        self.last_z_shape = z.shape

        # No timestep conditioning: every ResNet block receives temb=None.
        temb = None

        z = self._maybe_deproj(z, 'before')

        # z to block_in
        h = self._maybe_deproj(self.conv_in(z), 'conv_in')

        # middle
        h = self._maybe_deproj(self.mid.block_1(h, temb), 'mid_1')
        h = self.mid.attn_1(h)
        h = self._maybe_deproj(self.mid.block_2(h, temb), 'mid_2')

        # upsampling, from the coarsest level down to level 0
        for i_level in reversed(range(self.num_resolutions)):
            up = self.up[i_level]
            for i_block in range(self.num_res_blocks+1):
                h = up.block[i_block](h, temb)
                if len(up.attn) > 0:
                    h = up.attn[i_block](h)
            if i_level != 0:
                h = up.upsample(h)
        h = self._maybe_deproj(h, 'upsample')

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        h = self._maybe_deproj(h, 'end')

        if self.tanh_out:
            h = torch.tanh(h)
        return h