import math

from einops import rearrange
import torch as pt
import torch.nn as nn

from .basic import TransformDecode, TransformDecodeBlock


class SLATE(nn.Module):
    """Image model: encoder -> slot refinement -> causal token decoder.

    ``forward`` encodes the image, lets ``correct`` refine the slots produced by
    ``initializ`` against the encoded features, tokenizes the same image with
    ``mediat`` and decodes those tokens (shifted right behind ``decode_bos``)
    with ``decode_backbone`` conditioned on the refined slots.

    Args:
        mediat: tokenizer called as ``mediat(input)``; must return a 4-tuple whose
            2nd and 3rd items are used as ``zidx`` and ``quant``.  # type: dVAE
        h1w1: (height, width) of the ``mediat`` token grid; sizes the causal mask
            and the spatial layout of the decoder output.
        encode_backbone: maps the image to a ``(b,c,h,w)`` feature map.
        h2w2: (height, width) of the encoder feature grid; used to unflatten the
            attention maps.
        encode_posit_embed: applied to the channel-last ``(b,h,w,c)`` features.
        encode_project: applied to the flattened ``(b,h*w,c)`` features.
        initializ: called with the batch size (or ``condit``) to produce slots.
        correct: called as ``correct(features, slots)``; returns
            ``(slots, attention)``.
        decode_bos: begin-of-sequence token, expanded to ``(b,1,c)``.
        decode_posit_embed: applied to the shifted token sequence.
        decode_backbone: called as ``decode_backbone(tokens, slots, mask)``.
        decode_readout: maps decoder outputs to per-token predictions.
    """

    def __init__(
        self,
        mediat,
        h1w1,
        encode_backbone,
        h2w2,
        encode_posit_embed,
        encode_project,
        initializ,
        correct,
        decode_bos,
        decode_posit_embed,
        decode_backbone,
        decode_readout,
    ):
        super().__init__()
        # NOTE(review): ``nn.Module.train()`` recurses into children, so a later
        # ``model.train()`` puts ``mediat`` back into training mode - confirm the
        # training loop re-applies ``eval()`` if that matters.
        mediat.eval()
        self.mediat = mediat  # type: dVAE
        self.h1w1 = h1w1
        self.encode_backbone = encode_backbone
        self.h2w2 = h2w2
        self.encode_posit_embed = encode_posit_embed
        self.encode_project = encode_project
        self.initializ = initializ
        self.correct = correct
        self.decode_bos = decode_bos
        self.decode_posit_embed = decode_posit_embed
        self.decode_backbone = decode_backbone
        # strictly-upper-triangular bool mask over the h1*w1 token positions
        # (presumably True = "may not attend" - confirm against decode_backbone)
        self.register_buffer(
            "mask", pt.triu(pt.ones([math.prod(h1w1)] * 2, dtype=pt.bool), 1)
        )
        self.decode_readout = decode_readout
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the biases of all Conv2d / Linear / GRUCell sub-modules.

        ``mediat`` and ``encode_backbone`` - the modules themselves as well as
        everything nested inside them - are left untouched.
        """
        # ``zero init conv/linear/gru bias`` converges faster
        untouched = ("mediat", "encode_backbone")
        for name, module in self.named_modules():
            # match the attribute itself *and* its descendants; a bare
            # ``startswith("mediat.")`` would miss a leaf module named "mediat"
            if any(name == u or name.startswith(u + ".") for u in untouched):
                continue
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GRUCell):
                # for GRUCell ``bias`` is the bool constructor flag, not a tensor
                if module.bias:
                    nn.init.zeros_(module.bias_ih)
                    nn.init.zeros_(module.bias_hh)

    def forward(self, input, condit=None):
        """
        input: image, in shape (b,c,h,w)
        condit: in shape (b,n,c); when None, ``initializ`` receives the batch size

        Returns ``(zidx, prob, segment, correct, attent)``:
        zidx: 2nd output of ``mediat``, passed through unchanged
        prob: readout of the decoder, in shape (b,c,h,w) with h = h1w1[0]
        segment: argmax of ``attent`` over the slot axis, in shape (b,h,w)
        correct: refined slots, 1st output of ``self.correct``
        attent: attention maps, in shape (b,n,h,w) with h = h2w2[0]
        """
        b, _, _, _ = input.shape  # the unpacking also enforces a 4-D input

        encode = self.encode_backbone(input)  # (b,c,h,w)
        encode_ = encode.permute(0, 2, 3, 1)  # (b,h,w,c)
        encode_ = self.encode_posit_embed(encode_)
        encode_ = encode_.flatten(1, 2)  # (b,h*w,c)
        encode_ = self.encode_project(encode_)

        hidden = self.initializ(b if condit is None else condit)  # (b,n,c)
        correct, attent_ = self.correct(encode_, hidden)
        attent = rearrange(attent_, "b n (h w) -> b n h w", h=self.h2w2[0])

        # 1st and 4th outputs of mediat are not needed here
        _, zidx, quant, _ = self.mediat(input)  # (b,c,h,w)
        quant = quant.detach()  # no gradient flows back into mediat

        quant_ = rearrange(quant, "b c h w -> b (h w) c")
        # shift right: prepend BOS and drop the last token so that position i
        # only ever sees tokens < i as input
        token = pt.cat([self.decode_bos.expand(b, -1, -1), quant_], 1)[:, :-1, :]
        token = self.decode_posit_embed(token)
        token = self.decode_backbone(token, correct, self.mask)
        prob_ = self.decode_readout(token)  # (b,h*w,c)
        prob = rearrange(prob_, "b (h w) c -> b c h w", h=self.h1w1[0])

        segment = attent.argmax(1)  # (b,h,w)
        return zidx, prob, segment, correct, attent


class STEVE(nn.Module):
    """Video model: per-frame encoder -> recurrent slot refinement -> token decoder.

    Frames are encoded independently. Slots are refined frame by frame with
    ``correct`` and carried to the next frame through ``predict``. Every frame
    is tokenized with ``mediat`` and its tokens (shifted right behind
    ``decode_bos``) are decoded by ``decode_backbone`` conditioned on that
    frame's slots.

    Args:
        mediat: tokenizer called as ``mediat(frames)``; must return a 4-tuple whose
            2nd and 3rd items are used as ``zidx`` and ``quant``.  # type: dVAE
        h1w1: (height, width) of the ``mediat`` token grid; sizes the causal mask
            and the spatial layout of the decoder output.
        encode_backbone: maps frames to a ``(b*t,c,h,w)`` feature map.
        h2w2: (height, width) of the encoder feature grid; used to unflatten the
            attention maps.
        encode_posit_embed: applied to the channel-last ``(b*t,h,w,c)`` features.
        encode_project: applied to the flattened ``(b*t,h*w,c)`` features.
        initializ: called with the batch size (or the first-frame ``condit``) to
            produce the initial slots.
        correct: called as ``correct(frame_features, slots)``; returns
            ``(slots, attention)``.
        predict: maps the refined slots of frame ``i`` to the initial slots of
            frame ``i + 1``.
        decode_bos: begin-of-sequence token, expanded to ``(b*t,1,c)``.
        decode_posit_embed: applied to the shifted token sequence.
        decode_backbone: called as ``decode_backbone(tokens, slots, mask)``.
        decode_readout: maps decoder outputs to per-token predictions.
    """

    def __init__(
        self,
        mediat,
        h1w1,
        encode_backbone,
        h2w2,
        encode_posit_embed,
        encode_project,
        initializ,
        correct,
        predict,
        decode_bos,
        decode_posit_embed,
        decode_backbone,
        decode_readout,
    ):
        super().__init__()
        self.mediat = mediat  # type: dVAE
        self.h1w1 = h1w1
        self.encode_backbone = encode_backbone
        self.h2w2 = h2w2
        self.encode_posit_embed = encode_posit_embed
        self.encode_project = encode_project
        self.initializ = initializ
        self.correct = correct
        self.predict = predict
        self.decode_bos = decode_bos
        self.decode_posit_embed = decode_posit_embed
        self.decode_backbone = decode_backbone
        # strictly-upper-triangular bool mask over the h1*w1 token positions
        # (presumably True = "may not attend" - confirm against decode_backbone)
        self.register_buffer(
            "mask", pt.triu(pt.ones([math.prod(h1w1)] * 2, dtype=pt.bool), 1)
        )
        self.decode_readout = decode_readout
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the biases of Conv2d / Linear / GRUCell sub-modules.

        ``mediat`` (the module itself and everything nested inside it) is left
        untouched: ``forward`` detaches its output and returns nothing
        differentiable from it, so it acts as a fixed tokenizer and re-initialising
        its biases would only corrupt weights it was constructed with.
        """
        # ``zero init conv/linear/gru bias`` converges faster
        for name, module in self.named_modules():
            if name == "mediat" or name.startswith("mediat."):
                continue
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GRUCell):
                # for GRUCell ``bias`` is the bool constructor flag, not a tensor
                if module.bias:
                    nn.init.zeros_(module.bias_ih)
                    nn.init.zeros_(module.bias_hh)

    def forward(self, input, condit=None):
        """
        input: video, in shape (b,t,c,h,w)
        condit: in shape (b,t,n,c); only frame 0 is used, and when None
            ``initializ`` receives the batch size instead

        Returns ``(zidx, prob, segment, correct, attent)``:
        zidx: 2nd output of ``mediat`` with its leading axis unflattened to (b,t)
        prob: readout of the decoder, in shape (b,t,c,h,w) with h = h1w1[0]
        segment: argmax of ``attent`` over the slot axis, in shape (b,t,h,w)
        correct: refined slots of every frame, in shape (b,t,n,c)
        attent: attention maps, in shape (b,t,n,h,w) with h = h2w2[0]
        """
        b, t, _, _, _ = input.shape  # the unpacking also enforces a 5-D input
        input = input.flatten(0, 1)  # (b*t,c,h,w)

        encode = self.encode_backbone(input)  # (b*t,c,h,w)
        encode_ = encode.permute(0, 2, 3, 1)  # (b*t,h,w,c)
        encode_ = self.encode_posit_embed(encode_)
        encode_ = encode_.flatten(1, 2)  # (b*t,h*w,c)
        encode_ = self.encode_project(encode_)

        encode_ = rearrange(encode_, "(b t) hw c -> b t hw c", b=b)

        hidden = self.initializ(b if condit is None else condit[:, 0, :, :])  # (b,n,c)
        correct = []
        attent_ = []
        for i in range(t):
            # refine the slots on frame i, then roll them forward to frame i+1
            correct_t, attent_t = self.correct(encode_[:, i, :, :], hidden)
            hidden = self.predict(correct_t)
            correct.append(correct_t)  # [(b,n,c),..]
            attent_.append(attent_t)  # [(b,n,h*w),..]
        correct = pt.stack(correct, 1)  # (b,t,n,c)
        attent_ = pt.stack(attent_, 1)  # (b,t,n,h*w)
        attent = rearrange(attent_, "b t n (h w) -> b t n h w", h=self.h2w2[0])

        # 1st and 4th outputs of mediat are not needed here
        _, zidx, quant, _ = self.mediat(input)
        zidx = zidx.unflatten(0, [b, t])
        quant = quant.detach()  # no gradient flows back into mediat

        quant_ = rearrange(quant, "bt c h w -> bt (h w) c")
        # shift right: prepend BOS and drop the last token so that position i
        # only ever sees tokens < i as input
        token = pt.cat([self.decode_bos.expand(b * t, -1, -1), quant_], 1)[:, :-1, :]
        token = self.decode_posit_embed(token)
        token = self.decode_backbone(token, correct.flatten(0, 1), self.mask)
        prob_ = self.decode_readout(token)  # (b*t,h*w,c)
        prob = rearrange(prob_, "(b t) (h w) c -> b t c h w", b=b, h=self.h1w1[0])

        segment = attent.argmax(2)  # (b,t,h,w)
        return zidx, prob, segment, correct, attent


class TransformDecodeOCL(nn.Module):
    """``TransformDecode`` stack preceded by a LayerNorm on the input.

    Builds ``num_layer`` pre-norm ``TransformDecodeBlock`` layers with a final
    LayerNorm (``norm9``), and replaces ``norm1`` of the first layer with an
    identity. The original author marks both the input norm and that
    replacement as "very beneficial".
    """

    def __init__(self, embed_dim, num_head, ffn_dim, dropout, num_layer):
        super().__init__()
        # input normalisation, applied before the decoder stack (very beneficial)
        self.norm0 = nn.LayerNorm(embed_dim)

        layer = TransformDecodeBlock(
            embed_dim, num_head, ffn_dim, dropout, pre_norm=True
        )
        final_norm = nn.LayerNorm(embed_dim)
        self.module = TransformDecode(layer, num_layer, norm9=final_norm)

        # disable the first layer's ``norm1`` (very beneficial); presumably it
        # would be redundant right after ``norm0`` - confirm
        first_layer = self.module.layers[0]
        first_layer.norm1 = nn.Identity()

    def forward(self, input, memory, attn_mask=None):
        """Normalise ``input`` and run the decoder stack on it with ``memory``."""
        return self.module(self.norm0(input), memory, attn_mask)


class Parameter(nn.Parameter):
    """``nn.Parameter`` whose initial data comes from a named ``torch`` factory.

    ``Parameter(func="randn", size=(1, 1, 8))`` wraps ``torch.randn(size=(1, 1, 8))``.

    Args:
        func: name of a callable in the ``torch`` namespace, invoked as
            ``torch.<func>(**kwds)`` to create the data. A ``torch.Tensor`` is
            also accepted and wrapped directly.
        requires_grad: whether the parameter requires gradient.
        **kwds: keyword arguments forwarded to the factory function.

    Raises:
        KeyError: if ``func`` is a string that is not in the ``torch`` namespace.
        TypeError: if a tensor is passed as ``func`` together with ``kwds``.
    """

    def __new__(cls, func="randn", requires_grad=True, **kwds):
        if isinstance(func, pt.Tensor):
            # ``nn.Parameter.__deepcopy__`` rebuilds the object through
            # ``type(self)(tensor, requires_grad)``, so the tensor arrives in the
            # ``func`` slot; without this branch ``copy.deepcopy`` of any module
            # holding this parameter fails on the namespace lookup below.
            if kwds:
                raise TypeError(
                    "factory keyword arguments are not accepted with a tensor"
                )
            data = func
        else:
            data = pt.__dict__[func](**kwds)
        return super().__new__(cls, data, requires_grad)
