import math

from einops import rearrange, repeat
from scipy.optimize import linear_sum_assignment
import torch as pt
import torch.nn as nn
import torch.nn.functional as ptnf

from .basic import MLP
from .utils import gumbel_softmax


class SlotAttention(nn.Module):
    """Slot Attention: iterative slot-competitive cross-attention with a GRU update.

    TODO XXX modularization/cgv: correct the wrong implementation!

    trunc_bp="bi-level": right before the last iteration the slots are
    detached and re-attached to the initial ``query`` through an
    identity-gradient path, so slot gradients only come from that final
    iteration (plus the straight path into ``query``).
    """

    def __init__(
        self, num_iter, embed_dim, ffn_dim, dropout=0, kv_dim=None, trunc_bp=None
    ):
        super().__init__()
        kv_dim = kv_dim or embed_dim
        assert trunc_bp in ["bi-level", None]
        self.num_iter = num_iter
        self.trunc_bp = trunc_bp
        # query branch
        self.norm1q = nn.LayerNorm(embed_dim)
        self.proj_q = nn.Linear(embed_dim, embed_dim, bias=False)
        # key / value branch
        self.norm1kv = nn.LayerNorm(kv_dim)
        self.proj_k = nn.Linear(kv_dim, embed_dim, bias=False)
        self.proj_v = nn.Linear(kv_dim, embed_dim, bias=False)
        # dropout on attention weights; identity when disabled
        self.dropout = nn.Dropout(dropout) if dropout else (lambda x: x)
        # slot update
        self.rnn = nn.GRUCell(embed_dim, embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = MLP(embed_dim, [ffn_dim, embed_dim], dropout)

    def forward(self, input, query, num_iter=None):
        """
        input: in shape (b,h*w,c) -- features providing keys and values
        query: in shape (b,n,c) -- initial slots
        num_iter: number of refinement steps; ``self.num_iter`` is used when falsy
        returns: refined slots in shape (b,n,c) and the last step's attention in
            shape (b,n,h*w), softmax-normalized over the slot axis
        """
        b, n, c = query.shape
        steps = num_iter or self.num_iter
        feat = self.norm1kv(input)
        key = self.proj_k(feat)
        value = self.proj_v(feat)
        slots = query
        for step in range(steps):
            if self.trunc_bp == "bi-level" and step == steps - 1:
                # forward value of the detached slots, gradient of the initial query
                slots = slots.detach() + query - query.detach()
            prev = slots
            attn_q = self.proj_q(self.norm1q(slots))
            update, attn = __class__.inverted_scaled_dot_product_attention(
                attn_q, key, value, self.dropout
            )
            hidden = self.rnn(update.flatten(0, 1), prev.flatten(0, 1))
            hidden = hidden.view(b, n, -1)
            # droppath on ffn seems harmful
            slots = hidden + self.ffn(self.norm2(hidden))
        return slots, attn

    @staticmethod
    def inverted_scaled_dot_product_attention(q, k, v, self_dropout, eps=1e-5):
        """Attention normalized over queries, then re-normalized over keys.

        q: (b,n,c); k: (b,m,c); v: (b,m,d)
        self_dropout: callable applied to the re-normalized weights
        returns: output (b,n,d) and the query-softmaxed attention (b,n,m)
            taken before key re-normalization and dropout
        """
        temperature = q.size(2) ** -0.5
        logit = pt.einsum("bqc,bkc->bqk", q * temperature, k)
        attn = logit.softmax(1)  # inverted: softmax over the query axis
        weight = attn / (attn.sum(2, keepdim=True) + eps)  # re-normalize over keys
        out = pt.einsum("bqv,bvc->bqc", self_dropout(weight), v)
        return out, attn


class CartesianPositionalEmbedding2d(nn.Module):
    """Additive 2-D positional embedding built from a fixed Cartesian grid.

    A frozen grid of 4 features per cell, ``[y, x, 1 - y, 1 - x]`` with
    ``y, x`` the cell-center coordinates, is mapped to ``embed_dim`` by a
    learnable linear layer and added to the input.
    """

    def __init__(self, resolut, embed_dim):
        super().__init__()
        self.pe = nn.Parameter(  # (1,h,w,4), not trained
            __class__.meshgrid(resolut)[None, ...], requires_grad=False
        )
        self.project = nn.Linear(4, embed_dim)

    @staticmethod
    def meshgrid(resolut, low=-1, high=1):
        """Cell-center coordinate features for a ``resolut=(h, w)`` grid over (low, high).

        Returns a tensor of shape (h,w,4): ``[y, x, 1 - y, 1 - x]``.
        Note: with the default range (-1, 1) the last two features lie in
        (0, 2), i.e. they are not mirror images of y, x within (-1, 1).
        """
        assert len(resolut) == 2
        edges = [pt.linspace(low, high, n + 1) for n in resolut]
        centers = [(e[:-1] + e[1:]) / 2 for e in edges]
        # explicit "ij" (matrix) indexing: identical to the legacy default,
        # but without the deprecation warning newer torch emits when omitted
        grid_y, grid_x = pt.meshgrid(*centers, indexing="ij")
        return pt.stack([grid_y, grid_x, 1 - grid_y, 1 - grid_x], 2)

    def forward(self, input):
        """
        input: in shape (b,h,w,c); the grid is cropped to the input's (h, w),
            so h, w must not exceed ``resolut``
        output: in shape (b,h,w,c)
        """
        max_h, max_w = input.shape[1:3]
        return input + self.project(self.pe[:, :max_h, :max_w, :])


class LearntPositionalEmbedding1d(nn.Module):
    """Additive, trainable 1-D positional embedding table (adapted from the
    PositionalEncoding in https://pytorch.org/tutorials/beginner/transformer_tutorial.html,
    but learnt instead of sinusoidal)."""

    def __init__(self, length, embed_dim, dropout=0.1):
        super().__init__()
        table = pt.zeros(1, length, embed_dim)
        self.pe = nn.Parameter(table, requires_grad=True)
        nn.init.trunc_normal_(self.pe)
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = lambda x: x

    def forward(self, input):
        """
        input: in shape (b,n,c), n <= length
        output: in shape (b,n,c) -- input plus the first n table rows, then dropout
        """
        seq_len = input.shape[1]
        return self.dropout(input + self.pe[:, :seq_len, :])


class NormalInitializ(nn.Module):
    """Sample slots from one learnable diagonal Gaussian shared by all slots:
    ``mu + exp(log_sigma) * eps`` with ``eps ~ N(0, I)``."""

    def __init__(self, num_smpl, num_dim):
        super().__init__()
        self.num_smpl = num_smpl
        self.num_dim = num_dim
        self.mu = nn.Parameter(pt.empty([1, 1, num_dim]))
        self.log_sigma = nn.Parameter(pt.empty([1, 1, num_dim]))
        for param in (self.mu, self.log_sigma):
            nn.init.xavier_uniform_(param[0, :, :])

    def forward(self, b):
        """Return ``num_smpl`` samples for each of ``b`` batch items: shape (b,num_smpl,num_dim)."""
        mean = self.mu.expand(b, self.num_smpl, -1)
        std = self.log_sigma.exp().expand(b, self.num_smpl, -1)
        noise = pt.randn_like(mean)
        return mean + noise * std


class LearntInitializ(nn.Module):
    """Learnt per-slot initialization shared across the batch.

    ``sigma`` scales optional multiplicative noise in training mode; it is
    initialized to 0.0, so by default the noise term contributes nothing.
    """

    def __init__(self, num_slot, slot_dim):
        super().__init__()
        self.mu = nn.Parameter(pt.empty(1, num_slot, slot_dim))
        self.sigma = 0.0  # nn.Parameter(pt.tensor(0.0), requires_grad=False)
        nn.init.xavier_uniform_(self.mu[0, :, :])  # very important

    def forward(self, b):
        """Return the learnt slots repeated over a batch of ``b``: shape (b,num_slot,slot_dim)."""
        slots = self.mu.expand(b, -1, -1)
        if not self.training:
            return slots
        # noise proportional to the detached slot values; ``*smpl.detach()``: very beneficial ???
        noise = pt.randn_like(slots)
        return slots + noise * self.sigma * slots.detach()


class dVAE(nn.Module):
    """
    Singh et al. Illiterate DALL-E Learns to Compose. ICLR 2022.

    Not discrete VAE, not dynamic VAE, just a variant for OCL.
    """

    def __init__(self, encode, decode, codebook=None):
        super().__init__()
        self.encode = encode
        self.decode = decode
        self.tau = nn.Parameter(pt.tensor(0.1), requires_grad=False)  # gumbel temperature
        self.codebook = codebook

    def forward(self, input):
        """
        input: image in shape (b,c,h,w)
        returns (zsoft, zidx, quant, decode):
            zsoft: soft code assignment in shape (b,c',h',w')
            zidx: hard code index in shape (b,h',w')
            quant: codebook vectors of zidx in shape (b,d,h',w'), or None without codebook
            decode: reconstruction from zsoft, or None without decoder
        """
        zsoft, zidx = self.pick(self.encode(input))
        quant = self.codebook(zidx).permute(0, 3, 1, 2) if self.codebook else None
        decode = self.decode(zsoft) if self.decode else None
        return zsoft, zidx, quant, decode

    def pick(self, encode):
        """Turn encoder logits (b,c,h,w) into soft codes and their argmax indexes.

        Gumbel (soft) sampling is used only in training mode without a
        codebook; otherwise a plain softmax.

        Why???
        - ``log_softmax`` is very beneficial
            -- GeneralZ: this smooths logit thus might encourage (zhard) exploration
        - custom ``gumbel_softmax`` >> ``ptnf.gumbel_softmax``
        - GeneralZ: hard sampling on zsoft instead of logit to encourage exploration:
            - often good for larger g=8, always bad for g=4,2,1
            - our grouping needs more exploration for zhard
        """
        logit = ptnf.log_softmax(encode, 1)
        use_gumbel = self.training and not self.codebook
        zsoft = (
            gumbel_softmax(logit, self.tau, False, 1) if use_gumbel else logit.softmax(1)
        )
        return zsoft, zsoft.argmax(1)  # (b,c,h,w) (b,h,w)


class dVAEGrouped(dVAE):
    """
    Zhao et al. Grouped Discrete Representation Guides Object-Centric Learning. arXiv:2407.01726.

    ``groups`` lists the group counts of consecutive channel chunks; each chunk
    of ``g * d`` channels is averaged into ``g`` group logits.
    """

    def __init__(self, encode, decode, groups, codebook=None, learn=True):
        super().__init__(encode, decode, codebook)
        self.groups = groups
        self.learn = learn

    def forward(self, input):
        """
        input: image in shape (b,c,h,w)
        returns (zsoft_, zsoft, zidx, quant, decode) when ``learn``,
        else (zsoft, zidx, quant, decode); quant/decode are None when the
        codebook/decoder is absent.
        """
        zsoft_, zsoft, zidx = self.pick(self.encode(input))
        quant = self.codebook(zidx).permute(0, 3, 1, 2) if self.codebook else None
        decode = self.decode(zsoft) if self.decode else None
        if not self.learn:
            return zsoft, zidx, quant, decode
        return zsoft_, zsoft, zidx, quant, decode

    def pick(self, encode):
        """Per channel chunk: full softmax (zsoft_), group-averaged soft code
        (zsoft), and its argmax (zidx, stacked as (b,num_chunks,h,w)).
        Soft outputs of a chunk are weighted by ``g / sum(groups)``."""
        total = sum(self.groups)
        d = encode.size(1) // total
        # gumbel only while training an encoder that still has trainable params
        use_gumbel = self.training and any(
            p.requires_grad for p in self.encode.parameters()
        )
        soft_full, soft_grp, hard_grp = [], [], []
        offset = 0
        for g in self.groups:
            weight = g / total
            chunk = encode[:, offset : offset + g * d, :, :]
            offset += g * d
            logit = ptnf.log_softmax(chunk, 1)
            soft_full.append(logit.softmax(1) * weight)
            logit = rearrange(logit, "b (g d) h w -> b g d h w", g=g).mean(2)
            if use_gumbel:
                soft = gumbel_softmax(logit, self.tau, False, 1)
            else:
                soft = logit.softmax(1)
            soft_grp.append(soft * weight)
            hard_grp.append(soft.argmax(1))
        assert offset == encode.size(1)
        return pt.cat(soft_full, 1), pt.cat(soft_grp, 1), pt.stack(hard_grp, 1)


class VQVAE(nn.Module):
    """
    Oord et al. Neural Discrete Representation Learning. NeurIPS 2017.

    reconstruction loss and codebook alignment/commitment (quantization) loss;
    no loss is computed here -- ``forward`` only returns the tensors.
    """

    def __init__(self, encode, decode, codebook):
        super().__init__()
        self.encode = encode  # should be encoder + quantconv
        self.decode = decode  # should be decoder + postquantconv
        self.codebook = codebook

    def forward(self, input):
        """
        input: image; shape=(b,c,h,w)
        returns (encode, zidx, quant, decode): the continuous encoding, code
        indexes, the selected codebook vectors (no straight-through attached),
        and the reconstruction from the straight-through quantization (or None
        without a decoder).
        """
        encode = self.encode(input)
        zsoft, zidx = self.codebook.match(encode, False)
        quant = self.codebook(zidx).permute(0, 3, 1, 2)  # (b,h,w,c) -> (b,c,h,w)
        quant = quant.to(encode.dtype)  # TODO added after gdr, ogdr; add fp32 to diffuz
        quant2 = __class__.grad_approx(encode, quant)
        decode = None
        if self.decode:
            decode = self.decode(quant2)
        return encode, zidx, quant, decode

    @staticmethod
    def grad_approx(z, q, nu=0):  # nu=1 always harmful; maybe smaller nu???
        """
        straight-through gradient approximation

        The output has the value of ``q``. Its gradient is copied to ``z``
        (delayed update, the default). With ``nu > 0`` the gradient is also
        passed, scaled by ``nu``, to ``q`` itself, so the codebook is updated
        in sync with the encoder.

        synchronized:
        Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks
        """
        assert nu >= 0
        out = z + (q - z).detach()  # delayed: default
        if nu > 0:
            # must use the ORIGINAL q: differentiating ``out`` would only send
            # the extra gradient back into z again
            out = out + nu * (q - q.detach())  # synchronized
        return out


class VQVAEMultiScale(VQVAE):
    """Multi-scale VQ-VAE.

    The image is resized to ``num_scale`` resolutions (each half the previous),
    encoded, and lifted with ``project(..., pinv=True)``. ``ms_fuz`` then
    quantizes the channels in two halves: the first half is fused across
    scales and matched against the shared ``codebook[0]``; the second half is
    matched per scale against ``codebook[1 + i]``. Quantized features are
    projected back with ``project`` before decoding.
    """

    def __init__(
        self,
        encode,
        decode,
        codebook,  # type: nn.ModuleList  # shared*1 + specified*num_scale
        project,  # LinearPinv2d > fc*2 >> conv3*2
        num_scale=3,
        sample=True,  # for stage2 only
        encode_as_quant=False,  # for stage2 only; true for coco
        normaliz=False,  # true for slotdiffuz
        learn=True,
    ):
        super().__init__(encode, decode, codebook)
        self.project = project
        self.num_scale = num_scale
        self.sample = sample  # false or tau=0.1
        self.encode_as_quant = encode_as_quant
        self.normaliz = normaliz
        self.learn = learn
        self.register_buffer("tau", pt.as_tensor(0.1))  # fixed 0.1 > 0.01 > 1
        self.register_buffer("alpha", pt.as_tensor(0.0))  # 0.5 -> 0

    def forward(self, input):
        """
        input: image; shape=(b,c,h,w)

        returns, if ``learn``: the flat tuple
            ``(*encodes, *encodes_, *zidxs, *quants_, *quants, *decodes)``
            with ``num_scale`` items per group, full resolution first;
        otherwise ``(encode, zidx, quant, decode)`` of the full-resolution scale.
        """
        ### resize: scale i is downsampled by 2**i
        inputs = [input] + [  # s*(b,c,h,w)
            ptnf.interpolate(input, scale_factor=2**-_, mode="bilinear")
            for _ in range(1, self.num_scale)
        ]
        ### encode
        encodes = [self.encode(_) for _ in inputs]
        ### project-up (least squares)
        encodes_ = [self.project(_, pinv=True) for _ in encodes]
        ### quant: match and select
        # dispatch via ``self`` so a subclass override of ``ms_fuz`` (e.g. the
        # grouped variant) is actually used; ``__class__`` would pin this class
        zidxs, encodes_, quants_ = self.ms_fuz(  # s*(b,2,h,w) s*(b,c,h,w)x2
            self.learn,
            self.training,
            # codebook: 1+s > 1+1 > non-grouped; num_code: 256gce > 64ce > 1024gce or 128gce
            self.codebook,
            encodes_,
            self.training and (self.sample or self.learn),  # gumbel is beneficial
            self.tau,
        )  # TODO XXX add fp32 to diffuz
        ### residual
        if self.alpha > 0:  # definitely helpful; 0.5 > 1, 0.2 > 0
            quants_ = [
                e.detach() * self.alpha + q * (1 - self.alpha)
                for e, q in zip(encodes_, quants_)
            ]
        ### approx grad
        quants__ = [  # nu=1 bad; try other value ???
            __class__.grad_approx(e, q) for e, q in zip(encodes_, quants_)
        ]
        ### project-down
        quants = [self.project(_) for _ in quants__]
        ### normaliz
        if self.normaliz:
            quants = [ptnf.group_norm(_, 1) for _ in quants]
        ### decode
        decodes = [None] * self.num_scale
        if self.decode:
            decodes = [self.decode(_) for _ in quants]
        if self.learn:
            return *encodes, *encodes_, *zidxs, *quants_, *quants, *decodes
        ### else:  # self.learn is false
        quant = quants[0]
        if self.encode_as_quant:  # for coco-like only
            # keep the quantized first half, take the scale-variant (second)
            # half unquantized from the encoding; also tried: project(encodes_[0])
            # alone and its average with the quantized version
            ch = quants_[0].size(1) // 2
            quant = self.project(
                pt.cat([quants_[0][:, :ch, :, :], encodes_[0][:, ch:, :, :]], 1)
            )
            if self.normaliz:  # for slotdiffuz only
                quant = ptnf.group_norm(quant, 1)
        return encodes[0], zidxs[0], quant, decodes[0]

    @staticmethod
    def ms_fuz(
        self_learn: bool,
        self_training: bool,
        self_codebook: nn.ModuleList,
        encodes: list,
        sample: bool,
        tau=0.1,
    ):
        """Quantize multi-scale encodings.

        encodes: s*(b,c,h,w); scale j is expected to be 2**j coarser than scale 0
        returns zidxs s*(b,2,h,w), fused encodes s*(b,c,h,w), quants s*(b,c,h,w)

        First channel half: for every target scale i1, all scales are resampled
        to i1 (avg-pool from finer, nearest-upsample from coarser) and matched
        against the shared ``self_codebook[0]``. Per position, the scale whose
        best code scores highest (gumbel-sampled over scales when ``sample``)
        supplies both the code index and the encoding.
        Second channel half: matched per scale against ``self_codebook[1 + i]``.
        """
        s = len(encodes)
        assert len(self_codebook) == 1 + s  # shared*1 + specified*s
        b, c = encodes[0].shape[:2]
        assert c % 2 == 0
        ch = c // 2
        dtype = encodes[0].dtype

        zidx1s = []
        encode1s = []
        quant1s = []

        for i1 in range(s):  # fuz234 > nofuz234_fuzto2
            encode1_ = []
            for j1, encode in enumerate(encodes):  # s*(b,c,h,w)
                e1_ = encode[:, :ch, :, :]  # (b,c/2,h,w)
                if j1 < i1:  # avg vs max: max seems to favor fg
                    e1_ = ptnf.avg_pool2d(e1_, 2 ** (i1 - j1))
                elif j1 > i1:
                    e1_ = ptnf.interpolate(
                        e1_, scale_factor=2 ** (j1 - i1), mode="nearest"
                    )
                encode1_.append(e1_)

            encode1_ = pt.stack(encode1_)  # (s,b,c/2,h,w)
            # clustering to initialize the codebook here (self_codebook[0].cluster): harmful
            zsoft1_, _ = self_codebook[0].match(  # (s*b,m,h,w) (s*b,h,w)
                encode1_.flatten(0, 1), sample, tau / 10
            )
            zsoft1_ = zsoft1_.unflatten(0, [s, b])

            zsoft11, zidx11 = zsoft1_.max(2)  # (s,b,m,h,w) -> (s,b,h,w)
            # adding noise nsigma * zsoft11.std(0) to zsoft11[i1]: nsigma=0.001>0 always harmful ???
            if sample:
                zsoft11 = gumbel_softmax(zsoft11, tau=tau, dim=0)
            zidx12 = zsoft11.argmax(0)  # (s,b,h,w) -> (b,h,w): chosen scale
            zidx1 = zidx11.gather(0, zidx12[None, :, :, :])[0]
            encode1 = encode1_.gather(  # (s,b,c/2,h,w) -> (b,c/2,h,w)
                0, zidx12[None, :, None, :, :].expand(-1, -1, ch, -1, -1)
            )[0]
            quant1 = self_codebook[0](zidx1).permute(0, 3, 1, 2)
            # replace to activate dead codes: helpful in stage1 but always harmful in stage2, why ???
            if self_learn and self_training:
                self_codebook[0].replace(encode1, zidx1)

            zidx1s.append(zidx1)  # s*(b,h,w)
            encode1s.append(encode1)  # s*(b,c/2,h,w)
            quant1s.append(quant1.to(dtype))  # s*(b,c/2,h,w)
            # dual supervision (also emitting the unfused scale-i1 result): bad

        zidx2s = []
        encode2s = []
        quant2s = []

        for i2 in range(s):
            encode2 = encodes[i2][:, ch:, :, :]
            # clustering to initialize the codebook here: harmful
            _, zidx2 = self_codebook[1 + i2].match(encode2, sample, tau)
            quant2 = self_codebook[1 + i2](zidx2).permute(0, 3, 1, 2)
            # replace to activate dead codes: helpful in stage1 but always harmful in stage2, why ???
            if self_learn and self_training:
                self_codebook[1 + i2].replace(encode2, zidx2)
            zidx2s.append(zidx2)
            encode2s.append(encode2)
            quant2s.append(quant2.to(dtype))
            # dual supervision: bad

        zidxs = [pt.stack([u, v], 1) for u, v in zip(zidx1s, zidx2s)]  # s*(b,2,h,w)
        encodes = [  # s*(b,c,h,w)
            pt.cat([u, v], 1) for u, v in zip(encode1s, encode2s)
        ]
        quants = [pt.cat([u, v], 1) for u, v in zip(quant1s, quant2s)]  # s*(b,c,h,w)
        return zidxs, encodes, quants


class VQVAEGroupedMultiScale(VQVAEMultiScale):
    """Grouped variant of ``VQVAEMultiScale``.

    Every entry of the codebook list is itself a list of ``Codebook``s whose
    ``embed_dim``s partition the channels of the half they quantize, so code
    indexes carry a group axis: (b,g,h,w).
    """

    @staticmethod
    def ms_fuz(
        self_learn: bool,
        self_training: bool,
        self_codebook: nn.ModuleList,
        encodes: list,
        sample: bool,
        tau=0.1,
    ):
        """Grouped counterpart of ``VQVAEMultiScale.ms_fuz``.

        encodes: s*(b,c,h,w)
        returns zidxs s*(b,g1+g2,h,w) (shared-codebook groups first, then the
        scale-specific ones), fused encodes s*(b,c,h,w), quants s*(b,c,h,w)
        """
        s = len(encodes)
        assert len(self_codebook) == 1 + s  # shared*1 + specified*s
        b, c = encodes[0].shape[:2]
        assert c % 2 == 0
        ch = c // 2
        dtype = encodes[0].dtype

        zidx1s = []
        encode1s = []
        quant1s = []

        for i1 in range(s):  # fuz234 > nofuz234_fuzto2
            # resample the first channel half of every scale to scale i1
            encode1_ = []
            for j1, encode in enumerate(encodes):  # s*(b,c,h,w)
                e1_ = encode[:, :ch, :, :]  # (b,c/2,h,w)
                if j1 < i1:  # avg vs max: max seems to favor fg
                    e1_ = ptnf.avg_pool2d(e1_, 2 ** (i1 - j1))
                elif j1 > i1:
                    e1_ = ptnf.interpolate(
                        e1_, scale_factor=2 ** (j1 - i1), mode="nearest"
                    )
                encode1_.append(e1_)

            encode1_ = pt.stack(encode1_)  # (s,b,c/2,h,w)
            zsoft1_, zidx1_ = __class__.g_match(  # (s*b,sum_m,h,w) (s*b,g,h,w)
                self_codebook[0], encode1_.flatten(0, 1), sample, tau / 10
            )
            zsoft1_ = zsoft1_.unflatten(0, [s, b])
            zidx1_ = zidx1_.unflatten(0, [s, b])  # (s,b,g,h,w)
            g = zidx1_.size(2)

            # per position, score each scale by its best code over all groups
            zsoft11 = zsoft1_.amax(2)  # (s,b,h,w)
            if sample:
                zsoft11 = gumbel_softmax(zsoft11, tau=tau, dim=0)
            zidx12 = zsoft11.argmax(0)  # (s,b,h,w) -> (b,h,w): chosen scale
            # take all group indexes from the chosen scale (g_select needs (b,g,h,w))
            zidx1 = zidx1_.gather(  # (s,b,g,h,w) -> (b,g,h,w)
                0, zidx12[None, :, None, :, :].expand(-1, -1, g, -1, -1)
            )[0]
            encode1 = encode1_.gather(  # (s,b,c/2,h,w) -> (b,c/2,h,w)
                0, zidx12[None, :, None, :, :].expand(-1, -1, ch, -1, -1)
            )[0]
            quant1 = __class__.g_select(self_codebook[0], zidx1).permute(0, 3, 1, 2)
            # replace to activate dead codes: helpful in stage1 but always harmful in stage2, why ???
            if self_learn and self_training:
                __class__.g_replace(self_codebook[0], encode1, zidx1)

            zidx1s.append(zidx1)  # s*(b,g,h,w)
            encode1s.append(encode1)  # s*(b,c/2,h,w)
            quant1s.append(quant1.to(dtype))  # s*(b,c/2,h,w)

        zidx2s = []
        encode2s = []
        quant2s = []

        for i2 in range(s):
            encode2 = encodes[i2][:, ch:, :, :]
            _, zidx2 = __class__.g_match(self_codebook[1 + i2], encode2, sample, tau)
            quant2 = __class__.g_select(self_codebook[1 + i2], zidx2).permute(
                0, 3, 1, 2
            )
            # replace to activate dead codes: helpful in stage1 but always harmful in stage2, why ???
            if self_learn and self_training:
                __class__.g_replace(self_codebook[1 + i2], encode2, zidx2)
            zidx2s.append(zidx2)
            encode2s.append(encode2)
            quant2s.append(quant2.to(dtype))

        # both halves already carry a group axis: concatenate groups
        zidxs = [pt.cat([u, v], 1) for u, v in zip(zidx1s, zidx2s)]  # s*(b,g1+g2,h,w)
        encodes = [  # s*(b,c,h,w)
            pt.cat([u, v], 1) for u, v in zip(encode1s, encode2s)
        ]
        quants = [pt.cat([u, v], 1) for u, v in zip(quant1s, quant2s)]  # s*(b,c,h,w)
        return zidxs, encodes, quants

    @staticmethod
    def g_match(g_codebook: nn.ModuleList, encode, sample: bool, tau=1):
        """Match consecutive channel groups of ``encode`` against per-group codebooks.

        g_codebook: [Codebook(),..]; their ``embed_dim``s must sum to c
        encode: shape=(b,c,h,w)
        zsoft: per-group ``match`` soft outputs concatenated along dim 1
            (assumed (b,m_g,h,w) each, giving (b,sum m_g,h,w) -- confirm)
        zidx: shape=(b,g,h,w)
        """
        zsoft = []
        zidx = []
        start = 0
        for codebook in g_codebook:
            end = start + codebook.embed_dim
            zsoft_g, zidx_g = codebook.match(encode[:, start:end, :, :], sample, tau)
            zsoft.append(zsoft_g)
            zidx.append(zidx_g)
            start = end
        assert start == encode.size(1)
        return pt.cat(zsoft, 1), pt.stack(zidx, 1)

    @staticmethod
    def g_select(g_codebook: nn.ModuleList, zidx):
        """
        g_codebook: [Codebook(),..]
        zidx: indexes, shape=(b,g,h,w)
        output: per-group code vectors concatenated on the last axis, shape=(b,h,w,c)
        """
        assert len(g_codebook) == zidx.size(1)
        output = [codebook(zidx[:, i, ...]) for i, codebook in enumerate(g_codebook)]
        return pt.cat(output, -1)

    @staticmethod
    def g_replace(g_codebook: nn.ModuleList, encode, zidx):
        """Run dead-code replacement on each group's codebook.

        g_codebook: [Codebook(),..]
        encode: shape=(b,c,h,w); channels split by the codebooks' ``embed_dim``s
        zidx: indexes, shape=(b,g,h,w); group i is ``zidx[:, i]``
        """
        start = 0
        for i, codebook in enumerate(g_codebook):
            end = start + codebook.embed_dim
            codebook.replace(encode[:, start:end, :, :], zidx[:, i, ...])
            start = end


class VQVAEGrouped(VQVAE):
    """VQ-VAE that quantizes in a lifted space.

    The encoding is lifted with ``project(..., pinv=True)`` (least squares),
    matched against ``codebook`` there, and the straight-through quantized
    result is mapped back down with ``project`` before decoding.
    """

    def __init__(
        self,
        encode,
        decode,
        codebook,
        project,  # LinearPinv2d
        sample=True,  # for stage2 only
        encode_as_quant=False,  # for stage2 only; true for coco
        normaliz=False,  # true for slotdiffuz
        learn=True,
    ):
        super().__init__(encode, decode, codebook)
        self.project = project
        self.sample = sample
        self.encode_as_quant = encode_as_quant
        self.normaliz = normaliz
        self.learn = learn
        self.register_buffer("tau", pt.as_tensor(0.1))  # 1.0 -> 0.1
        self.register_buffer("alpha", pt.as_tensor(0.0))  # 0.5 -> 0

    def forward(self, input):
        """
        input: image in shape (b,c,h,w)

        learn=True  -> (encode, encode_, zsoft, zidx, quant_, quant, decode)
        learn=False -> (encode, zidx, quant, decode); with ``encode_as_quant``
                       ``quant`` is the raw encoding (group-normed if ``normaliz``)

        Experiment log:
        - GroupNorm on encode (affine or not): bad; noise on encode first: worse
        - lift: lstsq > lstsq with detach > plain projection
        - match with gumbel > without (also in stage2), > noise first
        - residual blend / straight-through in the lifted space > after projection
        - project down (proj2) > lstsq back-projection
        - gn() good for slotdiffuz; gn(affine) a little bit bad for slate/steve, gn bad
        """
        encode = self.encode(input)
        # NOTE(review): CUDA autocast may not accept float32 as its target dtype
        # and then disables itself -- confirm lstsq really runs in fp32 here
        with pt.autocast("cuda", pt.float32):
            encode_ = self.project(encode, pinv=True)
        gumbel = self.training and (self.sample or self.learn)
        zsoft, zidx = self.codebook.match(encode_, sample=gumbel, tau=self.tau)
        quant_ = self.codebook(zidx).permute(0, 3, 1, 2)
        if self.alpha > 0:
            quant_ = encode_.detach() * self.alpha + quant_ * (1 - self.alpha)
        quant = self.project(__class__.grad_approx(encode_, quant_))
        if self.normaliz:
            quant = ptnf.group_norm(quant, 1)
        decode = self.decode(quant) if self.decode else None
        if self.learn:  # TODO remove zsoft; introduce cluster, replace and synchroniz
            return encode, encode_, zsoft, zidx, quant_, quant, decode
        if self.encode_as_quant:  # for coco-like; group norm for slotdiffuz
            quant = ptnf.group_norm(encode, 1) if self.normaliz else encode
        return encode, zidx, quant, decode


class LinearPinv2d(nn.Module):
    """1x1 convolution that can also be applied backwards by least squares.

    ``forward(x)`` maps in_channel -> out_channel per pixel: ``W x + b``.
    ``forward(y, pinv=True)`` maps out_channel -> in_channel by solving
    ``W x = y - b`` per pixel with ``torch.linalg.lstsq`` (same as
    ``pinv(W) @ (y - b)``).
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        # weight-norm (self.wn) was tried: always bad (for slate/steve) on all datasets
        proto = nn.Conv2d(in_channel, out_channel, 1)
        nn.init.zeros_(proto.bias)
        # adopt the conv's initialized parameters; the conv module itself is dropped
        self.weight, self.bias = proto.weight, proto.bias

    def forward(self, input, pinv=False):
        if not pinv:
            return ptnf.conv2d(input, self.weight, self.bias)
        matrix = self.weight[None, :, :, 0, 0]  # (1,out,in)
        target = (input - self.bias[None, :, None, None]).flatten(2, 3)  # (b,out,h*w)
        solution = pt.linalg.lstsq(matrix, target).solution  # (b,in,h*w)
        return solution.unflatten(2, input.shape[2:4]).contiguous()

    def extra_repr(self):
        return f"{self.in_channel}, {self.out_channel}"


class Codebook(nn.Module):
    """
    clust: always negative
    replac: always positive
    sync: always negative
    """

    def __init__(self, num_embed, embed_dim):
        super().__init__()
        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.templat = nn.Embedding(num_embed, embed_dim)
        n = self.templat.weight.size(0)  # good to vqvae pretrain but bad to dvae
        self.templat.weight.data.uniform_(-1 / n, 1 / n)

    def forward(self, input):
        """
        input: indexes in shape (b,..)
        output: in shape (b,..,c)
        """
        output = self.templat(input)
        return output

    def match(self, encode, sample: bool, tau=1, detach="encode"):
        """Match ``encode`` against this codebook's template vectors.

        Thin wrapper: forwards all arguments, plus the template weight matrix,
        to ``match_encode_with_templat`` and returns its result unchanged.
        """
        templat = self.templat.weight
        return __class__.match_encode_with_templat(
            encode, templat, sample, tau, detach
        )

    @pt.no_grad()
    def cluster(self, latent, max_iter=100):  # always harmful
        """Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks

        latent: shape=(b,c,h,w)
        """
        assert self.training  # if not self.training: return
        if not hasattr(self, "cluster_flag"):  # only once
            self.cluster_flag = pt.zeros([], dtype=pt.bool, device=latent.device)
        if self.cluster_flag:
            return
        self.cluster_flag.data[...] = True
        latent = latent.permute(0, 2, 3, 1).flatten(0, -2)  # (b,h,w,c)
        n, c = latent.shape
        if n < self.num_embed:
            raise f"warmup samples should >= codebook size: {n} vs {self.num_embed}"
        print("clustering...")
        assign, centroid = __class__.kmeans_pt(
            latent, self.num_embed, max_iter=max_iter
        )
        self.templat.weight.data[...] = centroid

    @pt.no_grad()
    def replace(self, latent, zidx, rate=1, rho=1e-2, timeout=4096, cluster=0.5):
        """Re-initialize "dead" codes, i.e. codes not selected for ``timeout`` vectors.

        Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks

        latent: shape=(b,c,h,w); flattened to (b*h*w,c) below
        zidx: shape=(b,..); code indexes selected in this call; every code seen
            here has its countdown reset to ``timeout`` (converted to #iter)
        rate: blend weight of the replacement vs. the old template
            (old * (1 - rate) + new * rate); only read on the first call, then
            cached as ``self.replace_rate``
        rho: relative noise scale; each replacement v gets ``rho * ||v|| * N(0, I)`` added
        timeout: in #vector; will be converted to #iter
        cluster: update rate of the running k-means centroids
            (``self.replace_centroid``) that replacements are drawn from;
            the clustering is slow

        Alchemy
        ---
        for stage2 (maynot stand for stage1):
        - replace rate: 1>0.5;
        - noise rho: 1e-2>0;
        - replace timeout: 4096>1024,16384;
        - enabled in half training steps > full;
        - cluster r0.5 > r0.1?

        NOTE(review): with ``cluster == 0`` ``self.replace_centroid`` is never
        created, so the first dead code raises AttributeError below -- confirm
        cluster > 0 is always used, or add a fallback replacement policy.
        """
        assert self.training  # if not self.training: return
        if not hasattr(self, "replace_cnt"):  # only once
            # per-code countdown in #iter; starting at 1 means codes unused in the
            # very first call reach 0 immediately and are replaced right away
            self.replace_cnt = pt.ones(
                self.num_embed, dtype=pt.int, device=latent.device
            )
            self.replace_rate = pt.as_tensor(
                rate, dtype=latent.dtype, device=latent.device
            )
        assert 0 <= self.replace_rate <= 1
        if self.replace_rate == 0:
            return

        latent = latent.permute(0, 2, 3, 1).flatten(0, -2)  # (b*h*w,c)
        m = latent.size(0)
        timeout = math.ceil(timeout * self.num_embed / m)  # from #vector to #iter

        assert 0 <= cluster <= 1
        if self.replace_rate > 0 and cluster > 0:  # cluster update rate
            assert m >= self.num_embed
            if not hasattr(self, "replace_centroid"):
                # first call: full k-means initialized from the current templates
                self.replace_centroid = __class__.kmeans_pt(
                    latent,
                    self.num_embed,
                    self.templat.weight.data.to(latent.dtype),
                    max_iter=100,
                )[1]
            else:
                # later calls: one k-means step from the running centroids,
                # blended in with weight ``cluster``
                centroid = __class__.kmeans_pt(
                    latent, self.num_embed, self.replace_centroid, max_iter=1
                )[1]
                self.replace_centroid = (  # equal to ema ???
                    self.replace_centroid * (1 - cluster) + centroid * cluster
                )

        assert self.replace_cnt.min() >= 0
        self.replace_cnt -= 1

        # reset cnt of recently used codes
        active_idx = pt.unique(zidx)
        self.replace_cnt.index_fill_(0, active_idx, timeout)

        # reset value of unused codes
        dead_idx = (self.replace_cnt == 0).argwhere()[:, 0]  # (n,)->(n,1)->(n,)
        num_dead = dead_idx.size(0)
        if num_dead > 0:
            # debug print: timeout(#iter), codebook size, #vectors, dead indexes
            print("#", timeout, self.num_embed, m, dead_idx)
            # NOTE(review): ``mult`` is only used by the disabled policies below
            mult = num_dead // m + 1

            ### policy: random from input
            """latent = latent[pt.randperm(m)]
            if mult > 1:  # no need to repeat and shuffle all as mult always == 1
                latent = latent.tile([mult, 1])
            replac = latent[:num_dead]"""
            ### policy: random least similar to others from input
            """dist = __class__.euclidean_distance(latent, self.templat(active_idx))
            ridx = dist.mean(1).topk(min(num_dead, m), sorted=False)[1]
            if mult > 1:
                ridx = ridx.tile(mult)[:num_dead]
            replac = latent[ridx]"""
            ### policy: most similar centriod to self from input -- VQ-NeRV: A Vector Quantized Neural Representation for Videos
            # one-to-one assignment of dead codes to distinct running centroids
            # minimizing the summed distance (scipy Hungarian solver on CPU);
            # ``euclidean_distance`` is presumably defined later in this class
            dist = __class__.euclidean_distance(
                self.templat.weight.data[dead_idx], self.replace_centroid
            )
            row_idx, col_idx = linear_sum_assignment(dist.detach().cpu())
            replac = self.replace_centroid[pt.from_numpy(col_idx).to(latent.device)]

            # add noise
            if rho > 0:  # helpful
                norm = replac.norm(p=2, dim=-1, keepdim=True)
                noise = pt.randn_like(replac)
                replac = replac + rho * norm * noise

            # NOTE(review): presumably cloned so that tensors still referencing
            # the old storage are not mutated in place -- confirm
            self.templat.weight.data = self.templat.weight.data.clone()
            self.templat.weight.data[dead_idx] = (
                self.templat.weight.data[dead_idx] * (1 - self.replace_rate)
                + replac * self.replace_rate
            )
            # replaced codes get a fresh countdown
            self.replace_cnt[dead_idx] += timeout

    @staticmethod
    def kmeans_pt(X, num_cluster: int, center=None, tol=1e-4, max_iter=100):
        """Euclidean k-means in PyTorch.

        Adapted from
        https://github.com/subhadarship/kmeans_pytorch/blob/master/kmeans_pytorch/__init__.py

        X: vectors to cluster; shape=(m,c)
        num_cluster: number of clusters n
        center: (initial) centers for clustering; shape=(n,c). If None, n rows of
            X are drawn at random. The tensor passed in is never modified.
        tol: stop early once the squared total center shift is below this value
        max_iter: maximum number of center updates (at least one is always done)
        return: (assign, center)
            assign: nearest-center index for each row of X, computed against the
                centers *before* the last update; shape=(m,)
            center: the updated centers; shape=(n,c)
        """
        m, _ = X.shape
        if center is None:
            center = X[pt.randint(0, m, [num_cluster])]  # advanced indexing copies
        else:
            # work on a copy: callers pass live buffers (codebook weights, a running
            # centroid estimate they later blend with the result) that must not be
            # overwritten in place
            center = center.clone()

        for _ in range(max(max_iter, 1)):
            dist = __class__.euclidean_distance(X, center)  # (m,c) (n,c) -> (m,n)
            assign = dist.argmin(1)  # (m,)
            center_old = center.clone()

            for cid in range(num_cluster):
                # boolean-mask indexing keeps the row axis even for a single member,
                # so mean(0) always averages over rows, never over channels
                member = assign == cid
                if member.any():
                    cluster = X[member]  # (m2,c)
                else:  # empty cluster: re-seed it from one random row
                    cluster = X[pt.randint(0, m, [1])]  # (1,c)
                center[cid] = cluster.mean(0)

            shift = ptnf.pairwise_distance(center, center_old).sum()
            if shift**2 < tol:
                break

        return assign, center

    @staticmethod
    def match_encode_with_templat(encode, templat, sample, tau=1, detach="encode"):
        """
        encode: in shape (b,c,h,w)
        templat: in shape (m,c)
        zsoft: in shape (b,m,h,w)
        zidx: in shape (b,h,w)
        """
        if detach == "encode":
            encode = encode.detach()
        elif detach == "templat":
            templat = templat.detach()
        # b, c, h, w = encode.shape
        # dist = __class__.euclidean_distance(  # (b*h*w,c) (m,c) -> (b*h*w,m)
        #    encode.permute(0, 2, 3, 1).flatten(0, -2), templat
        # )
        # dist = dist.view(b, h, w, -1).permute(0, 3, 1, 2)  # (b,m,h,w)
        # simi = -dist.square()  # better than without  # TODO XXX learnable scale ???
        dist = (  # always better than cdist.square, why ???
            encode.square().sum(1, keepdim=True)  # (b,1,h,w)
            + templat.square().sum(1)[None, :, None, None]
            - 2 * pt.einsum("bchw,mc->bmhw", encode, templat)
        )  # 1 > 0.5 > 2, 4
        simi = -dist
        if sample and tau > 0:
            zsoft = gumbel_softmax(simi, tau, False, 1)
        else:
            zsoft = simi.softmax(1)
        zidx = zsoft.argmax(1)  # (b,m,h,w) -> (b,h,w)
        return zsoft, zidx

    @staticmethod
    def euclidean_distance(source, target, split_size=4096):
        """chunked cdist

        source: shape=(b,m,c) or (m,c)
        target: shape=(b,n,c) or (n,c)
        split_size: in case of oom; can be bigger than m
        dist: shape=(b,m,n) or (m,n)
        """
        assert source.ndim == target.ndim and source.ndim in [2, 3]
        source = source.split(split_size)  # type: list
        dist = []
        for s in source:
            d = pt.cdist(s, target, p=2)  # (m2,n);
            dist.append(d)
        dist = pt.concat(dist)  # (m,n)
        return dist


class CodebookGrouped(nn.Module):
    """Grouped codebook: one ``nn.Embedding`` table per group of codes.

    Group i holds ``groups[i]`` codes with ``gdims_i[i]`` channels each. The
    per-group dims are allotted in proportion to group size and sum to
    ``cat_dim``.

    groups: number of codes in each group
    cat_dim: total channels of the concatenated per-group code vectors
    embed_dim: total channels split the same way into ``gdims_o`` (kept "for
        vqvae"); defaults to cat_dim
    """

    def __init__(self, groups, cat_dim, embed_dim=None):
        super().__init__()
        # None used to reach calc_gdims and crash on `None / int`
        embed_dim = cat_dim if embed_dim is None else embed_dim
        gdims_i = __class__.calc_gdims(groups, cat_dim)
        gdims_o = __class__.calc_gdims(groups, embed_dim)  # for vqvae
        self.templat = nn.ModuleList(
            [nn.Embedding(g, d) for g, d in zip(groups, gdims_i)]
        )  # templat > templat+layernorm
        self.reset_parameters()
        self.gdims_i = gdims_i  # for vqvae
        self.gdims_o = gdims_o  # for vqvae

    @staticmethod
    def calc_gdims(groups, dim):
        """Split ``dim`` channels over the groups in proportion to group size.

        Every group but the last gets ceil(g * dim / sum(groups)) channels. The
        last gets the remainder, so the result sums to exactly ``dim``.

        raises ValueError: if some group would get fewer than one channel (the
            ceil round-up of the earlier groups can exhaust ``dim``)
        """
        total = sum(groups)
        # exact integer ceil-division; avoids float round-up artefacts
        gdims = [-(-g * dim // total) for g in groups]
        gdims[-1] = dim - sum(gdims[:-1])
        if min(gdims) < 1:
            raise ValueError(
                f"cannot split {dim} channels over groups {groups}: "
                f"per-group dims would be {gdims}"
            )
        return gdims

    def reset_parameters(self):
        """Re-initialize every group table uniformly in [-1/ng, 1/ng]."""
        for mg in self.templat:
            ng = mg.weight.size(0)  # number of codes in this group
            mg.weight.data.uniform_(-1 / ng, 1 / ng)

    def forward(self, input):
        """Look up each group's codes and concatenate the vectors.

        input: indexes in shape (b,g,..), g == number of groups
        output: in shape (b,..,c), c == sum(self.gdims_i)
        """
        # checked up front: too many groups would otherwise raise IndexError first
        assert len(self.templat) == input.size(1)
        output = [emb(input[:, i, ...]) for i, emb in enumerate(self.templat)]
        return pt.cat(output, -1)

    def match(self, encode, sample: bool, tau=1, detach="encode"):
        """Match each channel slice of ``encode`` against its own group table.

        encode: features in shape (b,c,h,w), c == sum(self.gdims_i)
        sample, tau, detach: forwarded to Codebook.match_encode_with_templat
        return: zsoft -- per-group soft assignments concatenated on dim 1;
                    shape=(b,sum(groups),h,w)
                zidx -- per-group hard indexes stacked on dim 1; shape=(b,g,h,w)
        """
        assert sum(self.gdims_i) == encode.size(1)
        zsoft = []
        zidx = []
        start = 0
        for gdim, emb in zip(self.gdims_i, self.templat):
            end = start + gdim
            zsoft_g, zidx_g = Codebook.match_encode_with_templat(
                encode[:, start:end, :, :], emb.weight, sample, tau, detach
            )
            start = end
            zsoft.append(zsoft_g)
            zidx.append(zidx_g)
        return pt.cat(zsoft, 1), pt.stack(zidx, 1)
