#!/usr/bin/env python3

from functools import partial
from typing import Callable, Literal, Type

import torch
import torch.nn as nn
from timm.layers import (
    get_norm_layer,
    get_act_layer,
    DropPath,
    PatchEmbed,
    Mlp,
    LayerType,
)
from timm.models import VisionTransformer

from modern_hopfield_attention.layers import ModernHopfieldAttention


class LayerScale(nn.Module):
    """Scale the last axis of the input by a learnable vector ``gamma``.

    Args:
        dim: length of ``gamma``; must match the input's last dimension.
        init_values: value every entry of ``gamma`` starts at.
        inplace: when True, ``forward`` multiplies ``x`` in place (``mul_``)
            and returns that same tensor; otherwise a new tensor is returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        initial = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial)
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)


class MHABlock(nn.Module):
    """Pre-norm transformer block built around ``ModernHopfieldAttention``.

    Structure: ``x + DropPath(LayerScale(attn(norm1(x))))`` followed by
    ``x + DropPath(LayerScale(mlp(norm2(x))))``. The attention layer takes and
    returns an extra tensor ``h``; ``forward`` passes the incoming ``h`` in and
    returns the one the attention layer produced, so blocks can be chained as
    ``x, h = block(x, h)``.

    Args:
        dim: token embedding dimension.
        num_heads: number of attention heads.
        attn_alpha: forwarded unchanged to ``ModernHopfieldAttention``.
        skip_alpha: forwarded unchanged to ``ModernHopfieldAttention``.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: forwarded to the attention layer.
        qk_norm: forwarded to the attention layer.
        proj_drop: dropout rate for the attention projection and the MLP.
        attn_drop: attention dropout rate.
        init_values: if truthy, both residual branches are scaled by a
            ``LayerScale`` initialised to this value; otherwise no scaling.
        drop_path: stochastic-depth rate for both residual branches.
        act_layer: activation constructor for the MLP.
        norm_layer: normalisation constructor (also given to the attention).
        mlp_layer: MLP constructor.
        *args, **kwargs: accepted and silently ignored.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        attn_alpha: float = 0.5,
        skip_alpha: float = 0.5,
        mlp_ratio: float = 4,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0,
        attn_drop: float = 0,
        init_values: float | None = None,
        drop_path: float = 0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        # attention branch
        self.norm1 = norm_layer(dim)
        self.attn = ModernHopfieldAttention(
            dim=dim,
            num_heads=num_heads,
            attn_alpha=attn_alpha,
            skip_alpha=skip_alpha,
            causal=False,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # MLP branch
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(
        self, x: torch.Tensor, h: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block; return the updated tokens and the attention's ``h``."""
        # ``x`` itself is never mutated (only rebound), so it serves directly as
        # the residual; no clone is needed.
        attn_out, h = self.attn(self.norm1(x), h)
        # ls1/ls2 are nn.Identity unless ``init_values`` was given, in which
        # case the LayerScale parameters are actually used.
        x = x + self.drop_path1(self.ls1(attn_out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x, h


class MHAVisionTransformer(VisionTransformer):
    """timm ``VisionTransformer`` whose blocks use Modern Hopfield Attention.

    The parent is constructed with placeholder ``nn.Identity`` blocks. These
    are then replaced by ``depth`` independent ``block_fn`` instances with a
    linearly increasing drop-path schedule. ``forward_features`` threads the
    extra tensor ``h`` returned by each block into the next block.

    NOTE(review): ``init_values`` is passed only to the parent (whose blocks
    are placeholders), not to ``block_fn``. Because the real blocks are
    created after ``super().__init__``, any ``weight_init`` scheme the parent
    applies in its constructor presumably does not reach them. Confirm that
    both are intended.
    """

    def __init__(
        self,
        attn_alpha: float,
        skip_alpha: float,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "avgmax", "max", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        class_token: bool = True,
        pos_embed: str = "learn",
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: bool | None = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        fix_init: bool = False,
        embed_layer: Callable = PatchEmbed,
        norm_layer: LayerType | None = None,
        act_layer: LayerType | None = None,
        block_fn: Type[nn.Module] = MHABlock,
        mlp_layer: Type[nn.Module] = Mlp,
    ) -> None:
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            init_values=init_values,
            class_token=class_token,
            pos_embed=pos_embed,
            no_embed_class=no_embed_class,
            reg_tokens=reg_tokens,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            dynamic_img_size=dynamic_img_size,
            dynamic_img_pad=dynamic_img_pad,
            drop_rate=drop_rate,
            pos_drop_rate=pos_drop_rate,
            patch_drop_rate=patch_drop_rate,
            proj_drop_rate=proj_drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            # fix_init rescales per-block attn/mlp weights. At this point the
            # blocks are nn.Identity placeholders with no ``attn``, so defer it
            # until the real blocks exist (see the end of __init__).
            fix_init=False,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=nn.Identity,
            mlp_layer=mlp_layer,
        )

        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        act_layer = get_act_layer(act_layer) or nn.GELU

        # stochastic-depth rate grows linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    attn_alpha=attn_alpha,
                    skip_alpha=skip_alpha,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    mlp_layer=mlp_layer,
                )
                for i in range(depth)
            ]
        )

        if fix_init:
            self.fix_init_weight()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the images, run all blocks, and apply the final norm."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        # ``h`` starts as None and is handed from each block to the next.
        h = None
        for module in self.blocks:
            x, h = module(x=x, h=h)

        x = self.norm(x)

        return x

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Return the classification head output for the input batch ``x``."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    def register_hooks(self) -> None:
        """Capture the input of every block's attention layer.

        On each forward pass, one detached CPU copy per block is appended to
        ``self.hook_input``. Calling this again replaces the previously
        registered hooks instead of stacking duplicates.
        """
        self.remove_hooks()
        self.hook_input = list()

        def hook_fn(module, input, output) -> None:
            if isinstance(module, ModernHopfieldAttention):
                self.hook_input.append(input[0].detach().cpu())

        self._hook_handles = [
            block.attn.register_forward_hook(hook_fn) for block in self.blocks
        ]

    def clear_hooks(self) -> None:
        """Discard captured inputs; registered hooks remain active."""
        self.hook_input = list()

    def remove_hooks(self) -> None:
        """Detach hooks added by ``register_hooks`` (no-op if there are none)."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []


class UniversalMHAViT(VisionTransformer):
    """Weight-shared ("universal") ViT built on Modern Hopfield Attention.

    A single ``block_fn`` instance (``self.layer``) is created, and
    ``self.blocks`` holds ``depth`` references to that same module, so every
    depth shares the same parameters. That shared layer uses a constant
    drop-path rate of ``drop_path_rate``. ``forward_features`` threads the
    extra tensor ``h`` returned by each application into the next.

    ``fix_init`` is rejected. Its per-depth rescaling would compound on the
    single shared set of weights.

    NOTE(review): ``init_values`` is passed only to the parent (whose blocks
    are placeholders), not to ``block_fn``. Any ``weight_init`` scheme the
    parent applies in its constructor presumably does not reach
    ``self.layer``, which is created afterwards. Confirm that both are
    intended.
    """

    def __init__(
        self,
        attn_alpha: float,
        skip_alpha: float,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "avgmax", "max", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        class_token: bool = True,
        pos_embed: str = "learn",
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: bool | None = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        fix_init: bool = False,
        embed_layer: Callable = PatchEmbed,
        norm_layer: LayerType | None = None,
        act_layer: LayerType | None = None,
        block_fn: Type[nn.Module] = MHABlock,
        mlp_layer: Type[nn.Module] = Mlp,
    ) -> None:
        if fix_init:
            # fix_init divides each block's weights by a depth-dependent factor.
            # Every depth here shares one layer, so the rescale would compound on
            # the same weights. Passing it to the parent would instead fail on
            # the placeholder nn.Identity blocks, so reject it up front.
            raise ValueError(
                "fix_init is not supported by UniversalMHAViT because all "
                "blocks share the same layer."
            )

        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            init_values=init_values,
            class_token=class_token,
            pos_embed=pos_embed,
            no_embed_class=no_embed_class,
            reg_tokens=reg_tokens,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            dynamic_img_size=dynamic_img_size,
            dynamic_img_pad=dynamic_img_pad,
            drop_rate=drop_rate,
            pos_drop_rate=pos_drop_rate,
            patch_drop_rate=patch_drop_rate,
            proj_drop_rate=proj_drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            fix_init=False,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=nn.Identity,
            mlp_layer=mlp_layer,
        )

        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        act_layer = get_act_layer(act_layer) or nn.GELU

        self.layer = block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            attn_alpha=attn_alpha,
            skip_alpha=skip_alpha,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_drop=proj_drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=drop_path_rate,
            norm_layer=norm_layer,
            act_layer=act_layer,
            mlp_layer=mlp_layer,
        )
        # ``depth`` references to the same module: parameters are shared.
        self.blocks = nn.ModuleList([self.layer] * depth)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the images, apply the shared block ``depth`` times, then norm."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        # ``h`` starts as None and is handed from each application to the next.
        h = None
        for module in self.blocks:
            x, h = module(x=x, h=h)

        x = self.norm(x)

        return x

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Return the classification head output for the input batch ``x``."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    def register_hooks(self) -> None:
        """Capture the attention input of the first shared-layer application.

        Exactly one detached CPU tensor is appended to ``self.hook_input``.
        Later applications are skipped until ``clear_hooks`` resets
        ``self.hook_count``. Calling this again replaces the previously
        registered hook instead of stacking duplicates.
        """
        self.remove_hooks()
        self.hook_input = list()
        self.hook_count = 0

        def hook_fn(module, input, output) -> None:
            if isinstance(module, ModernHopfieldAttention) and self.hook_count == 0:
                self.hook_input.append(input[0].detach().cpu())
                self.hook_count += 1

        self._hook_handles = [self.layer.attn.register_forward_hook(hook_fn)]

    def clear_hooks(self) -> None:
        """Discard captured inputs and re-arm the hook; it stays registered."""
        self.hook_input = list()
        self.hook_count = 0

    def remove_hooks(self) -> None:
        """Detach hooks added by ``register_hooks`` (no-op if there are none)."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []


__all__ = ["MHAVisionTransformer", "UniversalMHAViT"]

if __name__ == "__main__":
    # Smoke test: build the "large" configuration from VIT_MODEL_SIZE and
    # print a layer summary. ``torch`` is already imported at module level.
    from torchinfo import summary

    from modern_hopfield_attention.data import VIT_MODEL_SIZE

    model_size = "large"
    config = VIT_MODEL_SIZE[model_size]
    model = MHAVisionTransformer(
        attn_alpha=0.1,
        skip_alpha=0.1,
        num_classes=1,
        depth=config["depth"],
        num_heads=config["num_heads"],
        embed_dim=config["embed_dim"],
    )

    summary(model)

    # Optional hook check (uncomment to run); expect one entry per block:
    # dummy = torch.randn(1, 3, 224, 224)
    # model.register_hooks()
    # model(dummy)
    # print(len(model.hook_input))
