from copy import deepcopy
from operator import attrgetter
import re

from diffusers.models import AutoencoderKL
from einops import rearrange
import timm
import torch as pt
import torch.nn as nn
import torch.nn.functional as ptnf


####


class ModelWrap(nn.Module):
    """Wrap a model so that its inputs and outputs are in expected dict struct.

    ``forward`` picks keyword arguments for the model out of an input dict and
    packs the model's output(s) into a dict keyed by ``omap``. ``forward`` is
    wrapped with ``torch.compile``.
    """

    def __init__(self, model, imap, omap):
        """
        - model: module to wrap; called as ``model(**kwargs)``.
        - imap: dict or list.
            If keys in batch mismatch with keys in model.forward, use dict, ie, {key_in_forward: key_in_batch};
            If not, use list (each name is used both as the batch key and as the forward keyword).
        - omap: list of names given, in order, to the model's outputs; a single
            output that is not a list/tuple counts as one output.
        """
        super().__init__()
        assert isinstance(imap, (dict, list, tuple))
        assert isinstance(omap, (list, tuple))
        self.model = model
        self.imap = imap if isinstance(imap, dict) else {_: _ for _ in imap}
        self.omap = omap

    @pt.compile
    def forward(self, input: dict) -> dict:
        # imap maps forward keyword -> batch key
        input2 = {k: input[v] for k, v in self.imap.items()}
        output = self.model(**input2)
        if not isinstance(output, (list, tuple)):
            output = [output]
        assert len(self.omap) == len(output)
        output2 = dict(zip(self.omap, output))
        return output2


class ModelWrap2(nn.Module):
    """Dict-in / dict-out wrapper around a module, with checkpoint, freezing and
    optimizer-param-group helpers. The wrapped module is stored as ``self.m``.
    """

    def __init__(self, m: nn.Module, imap, omap):
        """
        - m: module to wrap; called as ``m(**kwargs)``.
        - imap: dict or list.
            If keys in batch mismatch with keys in m.forward, use dict, ie, {key_in_forward: key_in_batch};
            If not, use list (each name is used both as the batch key and as the forward keyword).
        - omap: list of names given, in order, to the module's outputs.
        """
        super().__init__()
        assert isinstance(imap, (dict, list, tuple))
        assert isinstance(omap, (list, tuple))
        self.m = m
        self.imap = imap if isinstance(imap, dict) else {_: _ for _ in imap}
        self.omap = omap

    # @pt.compile
    def forward(self, input: dict) -> dict:
        """Call ``self.m`` with the mapped inputs; return its outputs keyed by ``omap``."""
        input2 = {k: input[v] for k, v in self.imap.items()}
        output = self.m(**input2)
        if not isinstance(output, (list, tuple)):
            output = [output]
        assert len(self.omap) == len(output)
        output2 = dict(zip(self.omap, output))
        return output2

    def load(
        self,
        ckpt_file: str,
        ckpt_map: list,
        compile_prefix="_orig_mod.",
        verbose=True,
        map_location="cuda",
    ):
        """Load weights from ``ckpt_file``.

        - ckpt_map: None loads the whole state dict via ``load_state_dict``.
            Otherwise a list of ``(dst_prefix, src_prefix)`` pairs: own keys starting
            with ``dst_prefix`` are filled, in order, from checkpoint keys starting
            with ``src_prefix`` or ``compile_prefix + src_prefix``; counts must match.
        - map_location: forwarded to ``torch.load`` (default "cuda", as before;
            pass "cpu" on machines without a GPU).
        Raises ValueError if ``ckpt_map`` is neither None nor a list/tuple.
        """
        state_dict = pt.load(ckpt_file, map_location=map_location)
        if ckpt_map is None:
            if verbose:
                print("fully")
            self.load_state_dict(state_dict)
        elif isinstance(ckpt_map, (list, tuple)):
            # built once: its tensors share storage with the live parameters/buffers,
            # so the in-place writes below update the model itself
            own_state = self.state_dict()
            for dst, src in ckpt_map:
                dkeys = [_ for _ in own_state if _.startswith(dst)]
                skeys = [
                    _
                    for _ in state_dict
                    if _.startswith(src) or _.startswith(compile_prefix + src)
                ]
                assert len(dkeys) == len(skeys) > 0
                for dk, sk in zip(dkeys, skeys):
                    if verbose:
                        print(dk, sk)
                    own_state[dk].data[...] = state_dict[sk]
        else:
            raise ValueError(
                f"ckpt_map must be None or a list/tuple of (dst, src) pairs, got {type(ckpt_map).__name__}"
            )
        if verbose:
            print(f"checkpoint ``{ckpt_file}`` loaded")

    def save(self, save_file, weights_only=True):
        """Save ``self.state_dict()`` (default) or the whole pickled module to ``save_file``."""
        if weights_only:
            save_obj = self.state_dict()
        else:
            save_obj = self
        pt.save(save_obj, save_file)

    def freez(self, freez: list, verbose=True):
        """Set ``requires_grad=False`` on every submodule named (dotted path) in ``freez``."""
        for fk in freez:
            for param in attrgetter(fk)(self).parameters():
                param.requires_grad = False
        if verbose:
            for k, v in self.named_parameters():
                print(k, v.requires_grad)

    def group_params(self, keys, compile_prefix="_orig_mod."):
        """Build optimizer param groups.

        - keys: list of dicts ``{"key": regex, "lr": float}``. A parameter joins every
            group whose regex ``re.match``-es its name (with ``compile_prefix`` stripped).
        The total number of grouped entries must equal the number of parameters.
        """
        named_parameters = list(self.named_parameters())
        param_groups = []
        for pgi in keys:
            param_group = dict(params=[], lr=pgi["lr"])
            for key, param in named_parameters:
                if key.startswith(compile_prefix):
                    key = key[len(compile_prefix) :]
                if re.match(pgi["key"], key):
                    param_group["params"].append(param)
            param_groups.append(param_group)
        assert len(named_parameters) == sum(len(_["params"]) for _ in param_groups)
        return param_groups


class Sequential(nn.Sequential):
    """nn.Sequential whose modules may take several positional inputs.

    Whenever the running value is a list or tuple (including the initial
    ``*input`` pack), it is unpacked into the next module's positional args;
    otherwise it is passed as the single argument.
    """

    def __init__(self, modules: list):
        super().__init__(*modules)

    def forward(self, *input):
        value = input
        for stage in self:
            value = stage(*value) if isinstance(value, (list, tuple)) else stage(value)
        return value


# Re-export of torch.nn.ModuleList under this module's namespace.
ModuleList = nn.ModuleList


####


# Re-exports of torch.nn layer/activation classes under this module's namespace.
(
    Embedding,
    Conv2d,
    ConvTranspose2d,
    AdaptiveAvgPool2d,
    Identity,
    ReLU,
    GELU,
    SiLU,
    Mish,
) = (
    nn.Embedding,
    nn.Conv2d,
    nn.ConvTranspose2d,
    nn.AdaptiveAvgPool2d,
    nn.Identity,
    nn.ReLU,
    nn.GELU,
    nn.SiLU,
    nn.Mish,
)


class Mean(nn.Module):
    """Average the input over the dimension(s) ``dims`` given at construction."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims  # passed unchanged as ``dim`` to torch.mean

    def forward(self, input):
        return pt.mean(input, dim=self.dims)


class Conv2dPixelShuffle(nn.Sequential):
    """Conv2d emitting ``out_channels * upscale**2`` maps, followed by PixelShuffle.

    The output has ``out_channels`` channels and ``upscale`` times the spatial
    size of the convolution's output. All conv arguments match ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        upscale=2,
    ):
        expanded_channels = out_channels * upscale**2
        super().__init__(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=expanded_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            ),
            nn.PixelShuffle(upscale),
        )


# Re-exports of torch.nn linear / normalization layers under this module's namespace.
Linear, GroupNorm, LayerNorm = nn.Linear, nn.GroupNorm, nn.LayerNorm


####


class QueryKeyAttention(nn.Module):
    """Head-averaged scaled query-key dot products (attention logits; no softmax).

    TODO XXX correct the wrong impl in modularization/cgv!
    """

    def __init__(self, q_dim, k_dim, mid_dim, num_head=1):
        super().__init__()
        self.proj_q = nn.Linear(q_dim, mid_dim)
        self.proj_k = nn.Linear(k_dim, mid_dim)
        self.num_head = num_head  # mid_dim is split into num_head equal chunks

    def forward(self, query, key):
        """Return logits of shape (b, nq, nk): per-head q.k / sqrt(head_dim), averaged over heads."""
        split_heads = "b n (h d) -> b h n d"
        q_heads = rearrange(self.proj_q(query), split_heads, h=self.num_head)
        k_heads = rearrange(self.proj_k(key), split_heads, h=self.num_head)
        head_dim = q_heads.size(3)
        logits = pt.einsum("bhqd,bhkd->bhqk", q_heads * head_dim ** -0.5, k_heads)
        return logits.mean(1)


class MultiHeadAttention(nn.Module):
    """Batch-first multi-head scaled dot-product attention, cf. nn.MultiheadAttention.

    ``forward`` returns ``(output, attent)``; ``attent`` is the attention map
    averaged over heads.
    """  # TODO XXX modularization/cgv: correct the wrong implementation!

    def __init__(
        self,
        embed_dim,
        num_head,
        dropout=0,
        qkv_bias=False,
        q_dim=None,
        kv_dim=None,
        o_bias=False,
        o_dim=None,
    ):
        """
        - embed_dim: width of the projected q/k/v, split evenly over ``num_head`` heads.
        - dropout: dropout probability on the attention weights (0 disables it).
        - q_dim: query input width (default ``embed_dim``).
        - kv_dim: key/value input width (default ``q_dim``).
        - o_dim: output width (default ``q_dim``).
        """
        assert embed_dim % num_head == 0
        super().__init__()
        self.num_head = num_head
        q_dim = q_dim or embed_dim
        kv_dim = kv_dim or q_dim
        o_dim = o_dim or q_dim
        self.proj_q = nn.Linear(q_dim, embed_dim, bias=qkv_bias)
        self.proj_k = nn.Linear(kv_dim, embed_dim, bias=qkv_bias)
        self.proj_v = nn.Linear(kv_dim, embed_dim, bias=qkv_bias)
        # nn.Identity rather than a lambda keeps the module picklable
        self.dropout = nn.Dropout(dropout) if dropout else nn.Identity()
        # maps the concatenated heads (embed_dim) to the output width; the former
        # Linear(o_dim, embed_dim) only worked when o_dim == embed_dim
        self.proj_o = nn.Linear(embed_dim, o_dim, bias=o_bias)

    def forward(self, query, key, value, attn_mask=None):
        """
        - query: (b, nq, q_dim); key, value: (b, nk, kv_dim).
        - attn_mask: optional bool mask broadcastable to (b, h, nq, nk); True blocks a position.
        Returns output (b, nq, o_dim) and the head-averaged attention (b, nq, nk).
        """
        h = self.num_head
        # (b, n, h*d) -> (b, h, n, d)
        q = self.proj_q(query).unflatten(-1, (h, -1)).transpose(1, 2)
        k = self.proj_k(key).unflatten(-1, (h, -1)).transpose(1, 2)
        v = self.proj_v(value).unflatten(-1, (h, -1)).transpose(1, 2)

        scale = q.size(3) ** -0.5
        a = pt.einsum("bhqd,bhkd->bhqk", q * scale, k)
        if attn_mask is not None:  # backward dim match
            a = a.masked_fill(attn_mask, -pt.inf)
        a = a.softmax(3)
        a = self.dropout(a)
        o = pt.einsum("bhqv,bhvd->bhqd", a, v)

        attent = a.mean(1)  # taken after dropout

        o = o.transpose(1, 2).flatten(2)  # (b, h, n, d) -> (b, n, h*d)
        output = self.proj_o(o)
        return output, attent


class TransformEncodeBlock(nn.Module):
    """Transformer encoder block (self-attention + MLP), cf. nn.TransformerEncoderLayer.

    ``pre_norm=True`` applies each LayerNorm before its sub-layer inside the
    residual branch; otherwise each LayerNorm follows the residual sum.
    """  # TODO correct the wrong impl in cgv

    def __init__(
        self, embed_dim, num_head, ffn_dim, dropout=0.1, pre_norm=False, bias=False
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attent = MultiHeadAttention(
            embed_dim, num_head, dropout, qkv_bias=bias, o_bias=bias
        )
        # nn.Identity (not a lambda) when dropout is 0, so the block stays picklable
        self.dropout1 = nn.Dropout(dropout) if dropout else nn.Identity()
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = MLP(embed_dim, [ffn_dim, embed_dim], dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout else nn.Identity()

    def forward(self, input, attn_mask=None):
        """Self-attention (with optional ``attn_mask``) then MLP, each with a residual."""
        x = input
        if self.pre_norm:
            z1 = self.norm1(x)
            x = x + self.dropout1(self.attent(z1, z1, z1, attn_mask)[0])
            z2 = self.norm2(x)
            x = x + self.dropout2(self.ffn(z2))
        else:
            z1 = x + self.dropout1(self.attent(x, x, x, attn_mask)[0])
            x = self.norm1(z1)
            z2 = x + self.dropout2(self.ffn(x))
            x = self.norm2(z2)
        return x


class TransformEncode(nn.Module):
    """Stack of ``num_layer`` independent deep copies of an encoder block, cf. nn.TransformerEncoder."""

    def __init__(self, layer, num_layer, norm9=None):
        super().__init__()
        clones = [deepcopy(layer) for _ in range(num_layer)]
        self.layers = nn.ModuleList(clones)
        self.norm9 = norm9  # optional final module applied after the last block

    def forward(self, input, attn_mask=None):
        x = input
        for block in self.layers:
            x = block(x, attn_mask)
        return x if self.norm9 is None else self.norm9(x)


class TransformDecodeBlock(nn.Module):
    """Transformer decoder block (self-attn, cross-attn on ``memory``, MLP), cf. nn.TransformerDecoderLayer.

    ``pre_norm=True`` applies each LayerNorm before its sub-layer inside the
    residual branch; otherwise each LayerNorm follows the residual sum.
    """

    def __init__(
        self,
        embed_dim,
        num_head,
        ffn_dim,
        dropout=0.1,
        pre_norm=False,
        bias=False,
        kv_dim=None,
    ):
        """
        - kv_dim: width of ``memory`` for the cross-attention (default ``embed_dim``).
        """
        super().__init__()
        kv_dim = kv_dim or embed_dim
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attent1 = MultiHeadAttention(
            embed_dim, num_head, dropout, qkv_bias=bias, o_bias=bias
        )
        # nn.Identity (not a lambda) when dropout is 0, so the block stays picklable
        self.dropout1 = nn.Dropout(dropout) if dropout else nn.Identity()
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attent2 = MultiHeadAttention(
            embed_dim, num_head, dropout, qkv_bias=bias, kv_dim=kv_dim, o_bias=bias
        )
        self.dropout2 = nn.Dropout(dropout) if dropout else nn.Identity()
        self.norm3 = nn.LayerNorm(embed_dim)
        self.ffn = MLP(embed_dim, [ffn_dim, embed_dim], dropout)
        self.dropout3 = nn.Dropout(dropout) if dropout else nn.Identity()

    def forward(self, input, memory, attn_mask=None):
        """Self-attention (``attn_mask`` applies here only), cross-attention on ``memory``, then MLP."""
        x = input
        if self.pre_norm:
            z1 = self.norm1(x)
            x = x + self.dropout1(self.attent1(z1, z1, z1, attn_mask)[0])
            z2 = self.norm2(x)
            x = x + self.dropout2(self.attent2(z2, memory, memory)[0])
            z3 = self.norm3(x)
            x = x + self.dropout3(self.ffn(z3))
        else:
            z1 = x + self.dropout1(self.attent1(x, x, x, attn_mask)[0])
            x = self.norm1(z1)
            z2 = x + self.dropout2(self.attent2(x, memory, memory)[0])
            x = self.norm2(z2)
            z3 = x + self.dropout3(self.ffn(x))
            x = self.norm3(z3)
        return x


class TransformDecode(nn.Module):
    """Stack of ``num_layer`` independent deep copies of a decoder block, cf. nn.TransformerDecoder."""

    def __init__(self, layer, num_layer, norm9=None):
        super().__init__()
        clones = [deepcopy(layer) for _ in range(num_layer)]
        self.layers = nn.ModuleList(clones)
        self.norm9 = norm9  # optional final module applied after the last block

    def forward(self, input, memory, attn_mask=None):
        x = input
        for block in self.layers:
            # every block attends to the same, unchanged memory
            x = block(x, memory, attn_mask)
        return x if self.norm9 is None else self.norm9(x)


###


class CNN(nn.Sequential):
    """Conv stack: first conv, then repeated [GroupNorm (optional), activation, conv].

    hyperparam setting of ConvTranspose2d:
    https://blog.csdn.net/pl3329750233/article/details/130283512.
    """

    conv_types = {
        0: nn.Conv2d,
        1: lambda *a, **k: nn.ConvTranspose2d(*a, **k, output_padding=1),
        2: lambda *a, **k: Conv2dPixelShuffle(*a, **k, upscale=2),
    }

    def __init__(
        self, channel0, channels, kernels, strides, ctypes=0, gn=0, act="Mish"
    ):
        """
        - channel0: input channels.
        - channels, kernels, strides: per-layer output channels, kernel sizes and strides.
        - ctypes: per-layer key into ``conv_types``; an int applies to every layer.
        - gn: GroupNorm group count inserted before each non-first conv; 0 disables it.
        - act: name of a torch.nn activation class, constructed with ``inplace=True``.
        """
        if isinstance(ctypes, int):
            ctypes = [ctypes] * len(channels)
        builders = __class__.conv_types
        modules = []
        c_in = channel0
        specs = zip(ctypes, channels, kernels, strides)
        for idx, (ctype, c_out, ksize, stride) in enumerate(specs):
            # k//2 padding for odd kernels, none for even ones
            pad = 0 if ksize % 2 == 0 else ksize // 2  # XXX for k=s=4, requires isize%k==0
            if idx > 0:
                if gn:
                    modules.append(nn.GroupNorm(gn, c_in))
                # SiLU/SELU may fail; Mish>ReLU>Hardswish
                modules.append(nn.__dict__[act](inplace=True))
            modules.append(builders[ctype](c_in, c_out, ksize, stride=stride, padding=pad))
            c_in = c_out
        super().__init__(*modules)


class BigLittle(nn.Module):
    """Big model low output resolution; Little model high output resolution.
    Upsample the former and fuse with the latter.

    Each branch may be wrapped by optional pre/post modules; the optional
    ``post`` module is applied to the fused map.
    """

    def __init__(
        self,
        big,
        little,
        bpre=None,
        lpre=None,
        bpost=None,
        lpost=None,
        fuse="sum",
        post=None,
    ):
        """
        - big, little: modules mapping the input to (b,c,h,w) maps.
        - bpre/bpost, lpre/lpost: optional modules applied before/after ``big``/``little``.
        - fuse: "sum" (element-wise add) or "cat" (concatenate along channels).
        - post: optional module applied to the fused map.
        Raises NotImplementedError for any other ``fuse`` value.
        """
        super().__init__()
        self.big = big
        self.little = little
        self.bpre = bpre
        self.lpre = lpre
        self.bpost = bpost
        self.lpost = lpost
        if fuse not in ("sum", "cat"):
            raise NotImplementedError(f"unsupported fuse mode: {fuse!r}")
        # a plain string plus a method (instead of a lambda) keeps the module picklable
        self.fuse_mode = fuse
        self.post = post

    def fuse(self, b, l):
        """Combine the upsampled big-branch map ``b`` with the little-branch map ``l``."""
        if self.fuse_mode == "sum":
            return b + l
        return pt.cat([b, l], 1)

    def forward(self, input):
        """
        input: in shape (b,c,h,w)
        output: in shape (b,c,h,w)
        """
        xb = xl = input
        if self.bpre is not None:
            xb = self.bpre(xb)
        xb = self.big(xb)
        if self.bpost is not None:
            xb = self.bpost(xb)
        if self.lpre is not None:
            xl = self.lpre(xl)
        xl = self.little(xl)
        if self.lpost is not None:
            xl = self.lpost(xl)
        # nearest-upsample the big branch to exactly the little branch's spatial size;
        # for integer ratios this equals upsampling by that integer factor
        xb = ptnf.interpolate(xb, size=xl.shape[2:], mode="nearest")
        xf = self.fuse(xb, xl)
        return xf if self.post is None else self.post(xf)


class MLP(nn.Sequential):
    """Stack of Linear layers with GELU in between.

    Layer order: Linear, [Dropout], GELU, Linear, [Dropout], GELU, ..., Linear.
    Dropout (when ``dropout > 0``) follows every Linear except the last, i.e.
    it is applied before the GELU.
    """

    def __init__(self, channel0, channels, dropout=0):
        dims = [channel0, *channels]
        last = len(channels) - 1
        layers = []
        for j, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            if j > 0:
                layers.append(nn.GELU())
            layers.append(nn.Linear(d_in, d_out))
            if j < last and dropout > 0:
                layers.append(nn.Dropout(dropout))
        super().__init__(*layers)


class ResNet(nn.Sequential):
    """https://huggingface.co/timm/resnet18.fb_swsl_ig1b_ft_in1k

    Pretrained timm ResNet stem + residual stages as a Sequential, with
    configurable per-stage strides / dilations and optional BatchNorm->GroupNorm.
    """

    def __init__(
        self,
        model_name="resnet18.fb_swsl_ig1b_ft_in1k",
        in_dim=3,
        k0=7,
        strides=(2,) * 5,
        dilats=(1,) * 5,
        gn=0,
        learn_changed_only=False,
    ):
        """
        - model_name: timm model name; loaded with pretrained weights.
        - in_dim: input channels; only 3 is accepted.
        - k0: stem kernel size; values < 7 resize the pretrained kernel by interpolation.
        - strides: [conv1, maxpool, layer2, layer3, layer4] strides, 2 to 5 entries;
            conv1 in {1, 2, 4}, the rest in {1, 2}. A maxpool stride of 1 drops the
            maxpool; layer1 is always kept. A stride of 1 on layer{i} sets stride 1 on
            its first block's conv1 and downsample conv (assumes the block strides in
            conv1 as in BasicBlock -- TODO confirm for Bottleneck models).
        - dilats: dilations aligned with ``strides`` (first two must be 1); dilation d
            is applied to the first block's conv1 of layer{i}.
        - gn: if > 0, every BatchNorm2d becomes a GroupNorm with ``gn`` groups,
            initialized from the BatchNorm affine parameters.
        - learn_changed_only: freeze the pretrained weights except the modules modified here.
        """
        assert in_dim == 3
        assert 2 <= len(strides) <= 5
        assert all(_ in [1, 2] for _ in strides[1:]) and strides[0] in [1, 2, 4]
        assert len(strides) <= len(dilats)
        # tuple() so that list and tuple arguments compare alike
        dilats = tuple(dilats[: len(strides)])
        assert dilats[:2] == (1, 1)

        model = timm.create_model(model_name, pretrained=True, num_classes=0)
        resnet = deepcopy(model)
        if learn_changed_only:
            __class__._set_trainable(resnet, False)

        if gn > 0:
            __class__._batchnorm_to_groupnorm(model, resnet, gn)

        if in_dim != 3 or k0 != 7:  # for dvae, k4d4 > k7d2d2
            assert k0 < 7
            conv1 = nn.Conv2d(in_dim, 64, k0, 2, (k0 - 1) // 2, bias=False)
            conv1.weight.data[...] = ptnf.interpolate(resnet.conv1.weight, [k0, k0])
            resnet.conv1 = conv1
            __class__._set_trainable(conv1, True)

        if strides[0] != 2:
            resnet.conv1.stride = strides[0]
            __class__._set_trainable(resnet.conv1, True)
        layers = [resnet.conv1, resnet.bn1, resnet.act1]

        if strides[1] == 2:
            layers.append(resnet.maxpool)
        layers.append(resnet.layer1)

        # entries 0 and 1 (conv1, maxpool) are handled above; entry i >= 2 maps to layer{i}
        for i in range(2, len(strides)):
            layer = getattr(resnet, f"layer{i}")
            block0 = layer[0]
            if strides[i] == 1:
                block0.conv1.stride = 1
                __class__._set_trainable(block0.conv1, True)
                block0.downsample[0].stride = 1
                __class__._set_trainable(block0.downsample[0], True)
            if dilats[i] != 1:
                __class__._dilate(block0.conv1, dilats[i])
                __class__._set_trainable(block0.conv1, True)
            layers.append(layer)

        super().__init__(*layers)

    @staticmethod
    def _set_trainable(module, flag):
        """Set ``requires_grad = flag`` on every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = flag

    @staticmethod
    def _batchnorm_to_groupnorm(src, dst, num_groups):
        """Replace each BatchNorm2d in ``dst`` with a GroupNorm copying the affine params
        of the same-named BatchNorm2d in ``src``.

        ``dst`` is expected to be a deepcopy of ``src``; iteration runs over ``src``
        while the replacements are written into ``dst``.
        """
        for name, module in src.named_modules():
            if not isinstance(module, nn.BatchNorm2d):
                continue
            group_norm = nn.GroupNorm(num_groups, module.num_features)
            group_norm.weight.data[...] = module.weight
            group_norm.bias.data[...] = module.bias
            parent_name, _, child_name = name.rpartition(".")
            parent = attrgetter(parent_name)(dst) if parent_name else dst
            setattr(parent, child_name, group_norm)
            __class__._set_trainable(group_norm, True)

    @staticmethod
    def _dilate(conv, dilation):
        """Set ``dilation`` on ``conv`` in place with padding ``dilation * (k - 1) // 2``.

        For odd kernels this gives the same output size as the undilated conv with
        padding ``(k - 1) // 2`` at any stride (the former ``padding + d // 2`` rule
        was only right for d == 2).
        """
        kh, kw = conv.kernel_size
        conv.dilation = (dilation, dilation)
        conv.padding = (dilation * (kh - 1) // 2, dilation * (kw - 1) // 2)


class Dinolet(nn.Module):
    """DINO (v1) / DINOv2 ViT backbone loaded from torch.hub, returning patch-token maps."""

    ARCH_DICT = {
        "v1s8": ["facebookresearch/dino:main", "dino_vits8"],
        "v1b8": ["facebookresearch/dino:main", "dino_vitb8"],
        "v2s14": ["facebookresearch/dinov2", "dinov2_vits14"],
        "v2b14": ["facebookresearch/dinov2", "dinov2_vitb14"],
    }

    def __init__(self, arch="v1s8", num_block=None, learn=False):
        """
        - arch: key of ``ARCH_DICT``; its second character is the DINO version.
        - num_block: keep only the first ``num_block`` transformer blocks (None keeps all).
        - learn: if False, switch to eval mode and disable gradients for all parameters.
        """
        super().__init__()
        self.dino = pt.hub.load(*Dinolet.ARCH_DICT[arch])
        self.dino.blocks = self.dino.blocks[:num_block]
        self._v = int(arch[1])
        if not learn:
            self.eval()
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, input):
        """Map images (b, c, h0, w0) to features (b, c', h0 // d, w0 // d).

        ``d`` is 8 for v1 and 14 for v2. Raises NotImplementedError for other versions.
        """
        _, _, h0, w0 = input.shape
        if self._v == 1:
            x = self.dino.prepare_tokens(input)
            d = 8
            skip = 1  # drop the first (non-patch) token
        elif self._v == 2:
            x = self.dino.prepare_tokens_with_masks(input, None)
            d = 14
            skip = self.dino.num_register_tokens + 1  # first token plus register tokens
        else:
            raise NotImplementedError(f"unsupported DINO version: v{self._v}")
        for blk in self.dino.blocks:
            x = blk(x)
        x = self.dino.norm(x)[:, skip:, :]  # (b,h*w,c)
        return rearrange(x, "b (h w) c -> b c h w", h=h0 // d, w=w0 // d)


class ConvNeXt(nn.Sequential):
    """https://huggingface.co/timm/convnext_nano.in12k

    Pretrained timm ConvNeXt stem + first ``len(strides)`` stages as a Sequential.
    The stem and stage 0 are frozen (eval mode, no gradients).
    """

    def __init__(self, model_name="convnext_atto.d2_in1k", strides=(4, 2, 2, 2)):
        """
        - model_name: timm model name; loaded with pretrained weights.
        - strides: strides[0] is the stem stride and must be 4; strides[i] (i >= 1)
            is 1 or 2 for stage i. A stride of 1 turns that stage's stride-2
            downsampling conv into a 1x1 / stride-1 conv.
        """
        assert strides[0] == 4 and all(_ in [1, 2] for _ in strides[1:])
        model = timm.create_model(model_name, pretrained=True, num_classes=0)

        __class__.freez(model.stem)
        __class__.freez(model.stages[0])

        # zip handles a single-entry ``strides`` (previously strides[1] raised
        # IndexError) and ignores entries beyond the model's last stage
        for stage, s in zip(model.stages[1:], strides[1:]):
            if s == 1:
                __class__._to_pointwise(stage.downsample[1])

        super().__init__(model.stem, model.stages[: len(strides)])

    @staticmethod
    def freez(module):
        """Put ``module`` in eval mode and disable gradients for its parameters."""
        module.eval()
        for p in module.parameters():
            p.requires_grad = False

    @staticmethod
    def _to_pointwise(conv):
        """Turn a stride-2 Conv2d into a 1x1 / stride-1 conv in place by averaging its kernel.

        NOTE(review): averaging (rather than summing) the kernel scales the response
        to a spatially constant input by 1/(kh*kw) relative to the original conv --
        confirm this is intended.
        """
        assert isinstance(conv, nn.Conv2d) and conv.stride == (2, 2)
        conv.kernel_size = (1, 1)
        conv.stride = (1, 1)
        conv.weight.data = conv.weight.mean([2, 3], keepdim=True)


class ResEncoderAkl(nn.Sequential):
    """https://huggingface.co/stabilityai/sd-vae-ft-mse

    ``conv_in`` plus the first down blocks of a pretrained diffusers AutoencoderKL encoder.
    """

    def __init__(
        self, model_name="stabilityai/sd-vae-ft-mse", down=4, extra=False, learn=False
    ):
        """
        - down: total downsampling factor, a power of two; log2(down) down blocks are
            kept (assumes each kept block halves the resolution -- TODO confirm for
            VAEs other than sd-vae-ft-mse).
        - extra: keep one more down block, with its downsampler replaced by Identity.
        - learn: if False, switch to eval mode and disable gradients for all parameters.
        Raises ValueError if ``down`` is not a positive power of two.
        """
        # validate ``down`` before loading the pretrained VAE
        num = __class__._num_down_blocks(down)
        if extra:
            num += 1
        vae = AutoencoderKL.from_pretrained(model_name)
        conv_in = vae.encoder.conv_in
        down_blocks = vae.encoder.down_blocks[:num]
        if extra:
            assert len(down_blocks[-1].downsamplers) == 1
            down_blocks[-1].downsamplers[0] = nn.Identity()
        # TODO mid_block, conv_norm_out, conv_act, conv_out
        super().__init__(conv_in, *down_blocks)
        if not learn:
            self.eval()
            for p in self.parameters():
                p.requires_grad = False

    @staticmethod
    def _num_down_blocks(down):
        """Number of 2x down blocks giving a total factor ``down``, i.e. log2(down).

        (The former ``int(down**0.5)`` only coincided with log2 for 2, 4 and 16;
        e.g. down=8 kept 2 blocks, i.e. 4x.)
        """
        n = int(down)
        if n != down or n < 1 or n & (n - 1):
            raise ValueError(f"down must be a positive power of two, got {down!r}")
        return n.bit_length() - 1
