from torch import nn

from src.models.kappa_overrides.a_dpa import AnchoredDotProductAttention
from kappamodules.init.functional import init_norms_as_noaffine
from src.models.kappa_overrides.xa import CrossAttention
from kappamodules.layers import DropPath, LayerScale
from src.models.kappa_overrides.mlp import Mlp


class PrenormWithCrossAttentionBlock(nn.Module):
    """Pre-norm transformer block with self-attention, cross-attention and MLP.

    The forward pass applies three residual branches in this order:

    1. ``norm1 -> attn  -> ls1`` (wrapped by ``drop_path1``)
    2. ``norm3 -> xattn -> ls3`` (wrapped by ``drop_path3``)
    3. ``norm2 -> mlp   -> ls2`` (wrapped by ``drop_path2``)

    Each branch owns its own norm and layer-scale module.

    Args:
        dim: Feature dimension handed to the norms, both attention modules,
            the MLP and (if enabled) the LayerScale modules.
        num_heads: Forwarded to ``attn_ctor`` and ``xattn_ctor``.
        n_anchors: Forwarded to ``attn_ctor`` only.
        mlp_hidden_dim: Hidden width of the MLP; falls back to ``dim * 4``
            when falsy (``None`` or ``0``).
        qkv_bias: Forwarded to ``attn_ctor``.
        proj_bias: Forwarded to ``attn_ctor``.
        mlp_bias: Forwarded to ``Mlp`` as ``bias``.
        drop_path: Drop probability used for all three ``DropPath`` modules.
        act_ctor: Activation constructor forwarded to ``Mlp``.
        norm_ctor: Norm constructor, called as ``norm_ctor(dim, eps=eps)``.
        attn_ctor: Constructor of the self-attention module.
        xattn_ctor: Constructor of the cross-attention module.
        layerscale: ``None`` disables layer scale (``nn.Identity`` is used);
            otherwise the value is passed to ``LayerScale`` as ``init_scale``.
        eps: Epsilon passed to every norm.
        init_weights: Forwarded to both attention modules and the MLP.
        init_norms: ``"torch"`` leaves the norms as constructed;
            ``"nonaffine"`` applies ``init_norms_as_noaffine`` to each norm.
        init_last_proj_zero: Forwarded to both attention modules and the MLP.

    Raises:
        NotImplementedError: If ``init_norms`` is neither ``"torch"`` nor
            ``"nonaffine"`` (raised from ``reset_parameters`` during
            construction).
    """

    def __init__(
            self,
            dim,
            num_heads,
            n_anchors: int,
            mlp_hidden_dim=None,
            qkv_bias=True,
            proj_bias=True,
            mlp_bias=True,
            drop_path=0.,
            act_ctor=nn.GELU,
            norm_ctor=nn.LayerNorm,
            attn_ctor=AnchoredDotProductAttention,
            xattn_ctor=CrossAttention,
            layerscale=None,
            eps=1e-6,
            init_weights="xavier_uniform",
            init_norms="nonaffine",
            init_last_proj_zero=False,
    ):
        super().__init__()
        self.init_norms = init_norms
        mlp_hidden_dim = mlp_hidden_dim or dim * 4

        # self-attention branch
        self.norm1 = norm_ctor(dim, eps=eps)
        self.attn = attn_ctor(
            dim=dim,
            num_heads=num_heads,
            n_anchors=n_anchors,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            init_weights=init_weights,
            init_last_proj_zero=init_last_proj_zero,
        )
        self.ls1 = nn.Identity() if layerscale is None else LayerScale(dim, init_scale=layerscale)
        self.drop_path1 = DropPath(drop_prob=drop_path)

        # MLP branch
        self.norm2 = norm_ctor(dim, eps=eps)
        self.mlp = Mlp(
            in_dim=dim,
            hidden_dim=mlp_hidden_dim,
            act_ctor=act_ctor,
            init_weights=init_weights,
            init_last_proj_zero=init_last_proj_zero,
            bias=mlp_bias,
        )
        self.ls2 = nn.Identity() if layerscale is None else LayerScale(dim, init_scale=layerscale)
        self.drop_path2 = DropPath(drop_prob=drop_path)

        # cross-attention branch
        self.norm3 = norm_ctor(dim, eps=eps)
        self.xattn = xattn_ctor(
            dim=dim,
            num_heads=num_heads,
            init_weights=init_weights,
            init_last_proj_zero=init_last_proj_zero,
        )
        self.ls3 = nn.Identity() if layerscale is None else LayerScale(dim, init_scale=layerscale)
        self.drop_path3 = DropPath(drop_prob=drop_path)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the three norms according to ``self.init_norms``.

        Raises:
            NotImplementedError: For any value other than ``"torch"`` or
                ``"nonaffine"``.
        """
        if self.init_norms == "torch":
            # keep whatever norm_ctor produced
            return
        if self.init_norms == "nonaffine":
            for norm in (self.norm1, self.norm2, self.norm3):
                init_norms_as_noaffine(norm)
            return
        raise NotImplementedError(f"unsupported init_norms: {self.init_norms!r}")

    def _attn_residual_path(self, x, **kwargs):
        """Self-attention branch: ``ls1(attn(norm1(x), **kwargs))``."""
        return self.ls1(self.attn(self.norm1(x), **kwargs))

    def _xattn_residual_path(self, x, kv, **kwargs):
        """Cross-attention branch: ``ls3(xattn(norm3(x), kv=kv, **kwargs))``.

        Uses this branch's own ``norm3``/``ls3``; ``norm1``/``ls1`` belong to
        the self-attention branch and must not be shared with it.
        """
        return self.ls3(self.xattn(self.norm3(x), kv=kv, **kwargs))

    def _mlp_residual_path(self, x, **kwargs):
        """MLP branch: ``ls2(mlp(norm2(x)))``.

        ``**kwargs`` is accepted for signature parity with the attention
        branches and is ignored.
        """
        return self.ls2(self.mlp(self.norm2(x)))

    def forward(self, x, kv, **kwargs):
        """Run self-attention, cross-attention and MLP branches in sequence.

        Args:
            x: Input passed through all three branches.
            kv: Passed to the cross-attention module as ``kv``.
            **kwargs: Forwarded unchanged to BOTH the self-attention and the
                cross-attention module (not to the MLP).

        Returns:
            The value returned by the last ``DropPath`` call. Each ``DropPath``
            receives ``x`` and the branch callable and returns the updated
            ``x`` (presumably ``x`` plus the, possibly dropped, branch output
            -- confirm against the kappamodules ``DropPath`` implementation).
        """
        x = self.drop_path1(
            x,
            residual_path=self._attn_residual_path,
            residual_path_kwargs=kwargs,
        )
        x = self.drop_path3(
            x,
            residual_path=self._xattn_residual_path,
            residual_path_kwargs=dict(kv=kv, **kwargs),
        )
        x = self.drop_path2(x, residual_path=self._mlp_residual_path)
        return x
