from functools import partial
from typing import Optional

import timm
import torch
import torch.nn as nn
from timm.layers import to_2tuple, DropPath
from timm.models._builder import build_model_with_cfg
from timm.models.vision_transformer import (
    VisionTransformer,
    checkpoint_filter_fn,
    resample_abs_pos_embed,
)
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed, Block, LayerScale
from torch import Tensor

from ..pos_embed import get_2d_sincos_pos_embed
from .vittm import ttm_model_factory


# Base class for the models in this module: the class returned by
# ``ttm_model_factory`` (from ``.vittm``) when applied to timm's
# ``VisionTransformer``. ``laddersym`` below relies on attributes it provides
# (e.g. ``memory_embedder``, ``read_head``, ``write_head``, ``proj``) —
# presumably set up by the factory-generated class; confirm in ``.vittm``.
BASE_CLS = ttm_model_factory(VisionTransformer)


class Block(nn.Module):
    """Pre-norm transformer block with separate LayerNorms per input modality.

    The attention, MLP, LayerScale and DropPath sub-layers are shared by every
    modality. Only the two normalization layers (before attention and before
    the MLP) are chosen by the ``modality`` argument of :meth:`forward`:

    * ``None`` -> ``norm1`` / ``norm2``
    * ``"a1"`` -> ``norm1_a1`` / ``norm2_a1``
    * ``"a2"`` -> ``norm1_a2`` / ``norm2_a2``

    NOTE: this definition shadows the ``Block`` imported from timm at the top
    of the module.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.1,
        attn_drop: float = 0.1,
        init_values: Optional[float] = None,
        drop_path: float = 0.1,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        """Build the shared sub-layers and the three norm pairs.

        Args:
            dim: Token embedding dimension.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias: Forwarded to ``Attention``.
            qk_norm: Forwarded to ``Attention``.
            proj_drop: Passed as ``Attention``'s ``proj_drop`` and as the
                MLP's ``drop``.
            attn_drop: Passed as ``Attention``'s ``attn_drop``.
            init_values: LayerScale init value; ``None`` disables LayerScale
                (``nn.Identity`` is used instead).
            drop_path: Stochastic-depth rate; ``0`` disables DropPath.
            act_layer: Activation class handed to the MLP.
            norm_layer: Normalization class; instantiated for every norm slot
                and also forwarded to ``Attention``.
            mlp_layer: MLP class, constructed with ``in_features``,
                ``hidden_features``, ``act_layer`` and ``drop``.
        """
        super().__init__()
        # Pre-attention norms: default, modality "a1", modality "a2".
        self.norm1 = norm_layer(dim)
        self.norm1_a1 = norm_layer(dim)
        self.norm1_a2 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values)
            if init_values is not None
            else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Pre-MLP norms: default, modality "a1", modality "a2".
        self.norm2 = norm_layer(dim)
        self.norm2_a1 = norm_layer(dim)
        self.norm2_a2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values)
            if init_values is not None
            else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor, modality: Optional[str] = None) -> torch.Tensor:
        """Apply the attention and MLP residual branches using ``modality``'s norms.

        Args:
            x: Input tokens.
            modality: ``None``, ``"a1"`` or ``"a2"``; selects the norm pair.

        Returns:
            ``x`` after the attention and MLP residual updates.

        Raises:
            ValueError: If ``modality`` is not one of the supported values.
                (Previously an unknown value silently returned ``x`` unchanged.)
        """
        if modality is None:
            pre_attn_norm, pre_mlp_norm = self.norm1, self.norm2
        elif modality == "a1":
            pre_attn_norm, pre_mlp_norm = self.norm1_a1, self.norm2_a1
        elif modality == "a2":
            pre_attn_norm, pre_mlp_norm = self.norm1_a2, self.norm2_a2
        else:
            raise ValueError(
                f"Unknown modality {modality!r}; expected None, 'a1' or 'a2'"
            )
        x = x + self.drop_path1(self.ls1(self.attn(pre_attn_norm(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(pre_mlp_norm(x))))
        return x


class laddersym(BASE_CLS):
    """Two-stream model that alternates blocks between a process and a memory stream.

    ``x1`` is embedded into the *memory* token stream and ``x2`` into the
    *process* token stream. ``self.blocks`` is consumed in pairs: for block
    indices ``i`` and ``i + 1`` (``i`` even), memory is read into process, block
    ``i`` runs on process, process is written back into memory, then block
    ``i + 1`` runs on memory.

    Submodules used here (``memory_embedder``, ``process_embedder``,
    ``read_norm``, ``read_head``, ``read_fusion``, ``write_norm``,
    ``write_head``, ``write_fusion``, ``norm``, ``process_norm``, ``proj``, ...)
    are not defined in this file; they are expected to come from ``BASE_CLS``.
    """

    def embed(
        self, x1: torch.Tensor, x2: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed ``x1`` as memory tokens and ``x2`` as process tokens.

        The inputs are passed to the embedders unchanged. ``forward_features``
        does not call this method: it first inserts an axis at dim 1 and swaps
        axes 2 and 3.

        Returns:
            ``(memory, process)``.
        """
        memory = self.memory_embedder(x1)
        process = self.process_embedder(x2)
        return memory, process

    def forward_features(
        self,
        x1: Tensor,
        x2: Tensor,
        memory_patch_size: int | None = None,
        process_patch_size: int | None = None,
        visualize_ca: bool | None = None,
    ) -> tuple[Tensor, Tensor]:
        """Run both streams through the interleaved read / block / write stack.

        Args:
            x1: Memory-stream input. An axis is inserted at dim 1 and axes 2
                and 3 are swapped before embedding (presumably a 3-D
                ``(batch, X, Y)`` tensor, with the new axis as the single input
                channel -- confirm against callers).
            x2: Process-stream input, preprocessed the same way as ``x1``.
            memory_patch_size: Currently unused.
            process_patch_size: Currently unused.
            visualize_ca: Currently unused.

        Returns:
            ``(memory, process)`` after ``self.norm`` and ``self.process_norm``
            respectively.

        Raises:
            ValueError: If ``self.blocks`` has an odd length, since blocks are
                consumed in (process, memory) pairs. Previously this failed with
                an ``IndexError`` on ``blocks[i + 1]`` after the earlier work.
        """
        num_blocks = len(self.blocks)
        if num_blocks % 2:
            raise ValueError(
                "laddersym expects an even number of blocks "
                f"(process/memory pairs), got {num_blocks}"
            )

        # Embed both inputs after adding an axis at dim 1 and swapping axes 2/3.
        memory = self.memory_embedder(x1.unsqueeze(1).transpose(2, 3))
        process = self.process_embedder(x2.unsqueeze(1).transpose(2, 3))

        # Each stream has its own positional embedding table.
        memory = self._pos_embed(memory, self.memory_pos_embed)
        process = self._pos_embed(process, self.process_pos_embed)

        # NOTE(review): patch dropout and pre-norm are applied to the process
        # stream only; memory skips both.
        process = self.patch_drop(process)
        process = self.norm_pre(process)

        for i in range(0, num_blocks, 2):
            # Norm layers are indexed once per pair; blocks and read/write heads
            # are indexed by the even block index i.
            # TODO: confirm the head indexing (i vs. i // 2) with purvish.
            norm_index = i // 2

            # Read: memory -> process.
            # NOTE(review): the normalized tensors replace the streams themselves,
            # so the residual path also carries the normalized values.
            memory = self.read_norm[norm_index](memory)
            process = self.read_norm[norm_index](process)
            rprocess = self.read_head[i](memory, process)
            process = self.read_fusion(rw=rprocess, target=process)

            # Process-stream block.
            process = self.blocks[i](process)

            # Write: process -> memory.
            memory = self.write_norm[norm_index](memory)
            process = self.write_norm[norm_index](process)
            wmemory = self.write_head[i](process, memory)
            memory = self.write_fusion(rw=wmemory, target=memory)

            # Memory-stream block.
            memory = self.blocks[i + 1](memory)

        return self.norm(memory), self.process_norm(process)

    def forward(
        self, a1: torch.Tensor, a2: torch.Tensor, pre_logits: bool = False, **kwargs
    ) -> torch.Tensor:
        """Encode both inputs and project the concatenated token sequences.

        Args:
            a1: Memory-stream input (``x1`` of ``forward_features``).
            a2: Process-stream input (``x2`` of ``forward_features``).
            pre_logits: Currently ignored.
            **kwargs: Currently ignored.

        Returns:
            ``self.proj`` applied to the memory and process outputs
            concatenated along dim 1.
        """
        memory, process = self.forward_features(a1, a2, visualize_ca=True)
        # Concatenate along dim 1 (presumably the token axis -- confirm).
        features = torch.cat([memory, process], dim=1)
        return self.proj(features)


def _create_vittm(
    variant: str, pretrained: bool = False, **kwargs
) -> VisionTransformer:
    """Build a ``laddersym`` model through timm's ``build_model_with_cfg``.

    Args:
        variant: timm model name forwarded to ``build_model_with_cfg`` (timm
            uses it to resolve the pretrained config), e.g.
            ``"vit_base_patch16_224"``.
        pretrained: Whether to load pretrained weights.
        **kwargs: Forwarded to ``build_model_with_cfg`` / the model constructor.
            ``out_indices`` (default ``3``) is popped and used only for the
            feature config.

    Returns:
        The constructed ``laddersym`` model.
    """
    out_indices = kwargs.pop("out_indices", 3)
    if "flexi" in variant:
        # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
        # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
        # (``partial`` was previously never imported, so this branch raised NameError.)
        _filter_fn = partial(
            checkpoint_filter_fn, interpolation="bilinear", antialias=False
        )
    else:
        _filter_fn = checkpoint_filter_fn

    # Always load pretrained weights non-strictly. The former siglip special
    # case also set False, so it was dead code. Presumably required because
    # laddersym has modules (memory/process embedders, read/write heads) that
    # a plain ViT checkpoint lacks -- confirm before switching to strict.
    strict = False

    return build_model_with_cfg(
        laddersym,
        variant,
        pretrained,
        pretrained_filter_fn=_filter_fn,
        pretrained_strict=strict,
        feature_cfg=dict(out_indices=out_indices, feature_cls="getter"),
        **kwargs,
    )


def vittm_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build the base-size ``laddersym`` model with ViT-B/16 dimensions.

    Defaults: 768-dim embeddings, depth 12, 12 heads, patch size 16 and a
    single input channel. Read/write head type is ``"ca"`` (presumably
    cross-attention), fusion type is ``"residual"``, and the process embedder
    is ``"patch"`` with ``process_ps=16``.

    Any keyword in ``kwargs`` overrides the matching default. The timm variant
    name ``"vit_base_patch16_224"`` is used when resolving pretrained weights.
    """
    config = {
        "patch_size": 16,
        "in_chans": 1,
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "rw_head_type": "ca",
        "fusion_type": "residual",
        "process_embedder_type": "patch",
        "process_ps": 16,
    }
    # Caller-supplied keywords take precedence over the defaults above.
    config.update(kwargs)
    return _create_vittm("vit_base_patch16_224", pretrained=pretrained, **config)
