import torch
import torch.nn as nn

# Public builders & weight enums
from torchvision.models.video import mvit_v2_s, mvit_v1_b
from torchvision.models.video import MViT_V2_S_Weights, MViT_V1_B_Weights

# The actual base class + helper used in the stock forward
from torchvision.models.video.mvit import MViT as _MViT
from torchvision.models.video.mvit import _unsqueeze  # handles (B,C,H,W) -> (B,C,1,H,W)

def interpolate_spatial_mvit_embed(embedding, new_size):
    """Resize a flattened spatial positional embedding to a ``new_size`` x ``new_size`` grid.

    Args:
        embedding: 2-D tensor of shape (H*W, C) holding one embedding vector per
            spatial token, flattened row-major (index = h * W + w).
        new_size: side length of the target square token grid.

    Returns:
        Tensor of shape (new_size**2, C), flattened row-major.

    The embedding is resampled as a 2-D grid with bilinear interpolation.
    Interpolating the flattened sequence in 1-D would blend the end of one row
    into the start of the next and would not correspond to a spatial resize.
    When the token count is not a perfect square the source grid shape cannot
    be inferred, so the function falls back to 1-D linear interpolation over
    the flattened sequence.
    """
    num_tokens = embedding.shape[0]
    # F.interpolate expects (batch, channels, *spatial).
    channels_first = embedding.permute(1, 0).unsqueeze(0)

    side = round(num_tokens ** 0.5)
    if side * side != num_tokens:
        # Unknown (non-square) source grid: best-effort 1-D resampling.
        resized = nn.functional.interpolate(
            channels_first,
            size=new_size**2,
            mode="linear",
        )
        return resized.squeeze(0).permute(1, 0)

    grid = channels_first.reshape(1, -1, side, side)
    resized = nn.functional.interpolate(
        grid,
        size=(new_size, new_size),
        mode="bilinear",
        align_corners=False,
    )
    # (1, C, new, new) -> (new*new, C)
    return resized.flatten(2).squeeze(0).permute(1, 0)

def interpolate_temporal_mvit_embed(embedding, T):
    """Linearly resample a temporal positional embedding to ``T // 2`` steps.

    Args:
        embedding: 2-D tensor laid out as (steps, channels).
        T: frame count; the output has ``T // 2`` steps.

    Returns:
        Tensor laid out as (T // 2, channels).
    """
    target_steps = T // 2
    # F.interpolate resamples the trailing axis of a (batch, channels, length) tensor.
    batched = embedding.permute(1, 0)[None]
    resampled = nn.functional.interpolate(batched, size=target_steps, mode="linear")
    return resampled[0].permute(1, 0)


class FlexPositionalEncoding(nn.Module):
    """Class-token + (optional) absolute positional encoding that adapts to the input size.

    Parameter names (``class_token``, ``spatial_pos``, ``temporal_pos``,
    ``class_pos``) are chosen so a state dict with those keys can be loaded
    directly. When ``rel_pos_embed`` is true no absolute tables are created and
    ``forward`` only prepends the class token.
    """

    def __init__(self, embed_size: int, spatial_size: tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
        """
        Args:
            embed_size: channel width of each token.
            spatial_size: (H, W) token grid the spatial table is allocated for.
            temporal_size: number of temporal steps the temporal table is allocated for.
            rel_pos_embed: if true, skip allocating the absolute position tables.
        """
        super().__init__()
        self.spatial_size = spatial_size
        self.temporal_size = temporal_size

        self.class_token = nn.Parameter(torch.zeros(embed_size))
        self.spatial_pos: "nn.Parameter | None" = None
        self.temporal_pos: "nn.Parameter | None" = None
        self.class_pos: "nn.Parameter | None" = None
        if not rel_pos_embed:
            self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
            self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
            self.class_pos = nn.Parameter(torch.zeros(embed_size))

    def forward(self, x: torch.Tensor, new_size, T) -> torch.Tensor:
        """Prepend the class token and add absolute position embeddings if present.

        Args:
            x: token tensor of shape (B, N, C); when absolute embeddings exist,
                N must equal ``(T // 2) * new_size**2``.
            new_size: side length of the (square) spatial token grid of ``x``.
            T: frame count; the temporal token count is taken to be ``T // 2``.

        Returns:
            Tensor of shape (B, N + 1, C).

        Side effect: ``self.spatial_size`` and ``self.temporal_size`` are
        overwritten with the grid used for this call so callers can read the
        active (T, H, W) afterwards.
        """
        temporal_steps = T // 2
        self.spatial_size = (new_size, new_size)
        self.temporal_size = temporal_steps

        class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
        x = torch.cat((class_token, x), dim=1)

        # Relative-position models carry no absolute tables: nothing to add.
        if self.spatial_pos is None or self.temporal_pos is None or self.class_pos is None:
            return x

        # Interpolate only when the requested grid differs from the stored
        # table; the native size is read from the parameters themselves because
        # spatial_size / temporal_size are overwritten on every call.
        spatial_pos = self.spatial_pos
        if spatial_pos.shape[0] != new_size * new_size:
            spatial_pos = interpolate_spatial_mvit_embed(spatial_pos, new_size)
        temporal_pos = self.temporal_pos
        if temporal_pos.shape[0] != temporal_steps:
            temporal_pos = interpolate_temporal_mvit_embed(temporal_pos, T)

        hw_size, embed_size = spatial_pos.shape
        # Token order is (t, h, w): each temporal row is repeated once per
        # spatial token, and the spatial table is tiled once per temporal step.
        pos_embedding = torch.repeat_interleave(temporal_pos, hw_size, dim=0)
        pos_embedding.add_(
            spatial_pos.unsqueeze(0).expand(temporal_steps, -1, -1).reshape(-1, embed_size)
        )
        pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
        x.add_(pos_embedding)

        return x


class FlexMViT(_MViT):
    """
    Subclass of TorchVision's MultiScale Vision Transformer (MViT) that
    loads default pretrained weights and exposes a customizable forward().

    Instances are normally created through :meth:`from_pretrained`, which
    builds a stock model and re-classes it, so ``__init__`` is not overridden.
    """

    # Output selector read by forward(): 'test' returns the normalized token
    # sequence (class token included); any other value returns the logits.
    # Defined at class level so re-classed instances always have it.
    mode: str = "train"

    @classmethod
    def from_pretrained(cls, variant: str = "v2_s", weights: str = "DEFAULT", **kwargs) -> "FlexMViT":
        """
        Build an MViT variant with pretrained weights, then return it as FlexMViT.

        Args:
            variant: "v2_s" (default) or "v1_b" (case-insensitive).
            weights: a weights enum member name (e.g., "DEFAULT", "KINETICS400_V1");
                non-string values are passed to the builder unchanged.
            **kwargs: forwarded to the builder (e.g., num_classes, dropout, etc.).

        Raises:
            ValueError: if ``variant`` is not one of the supported names.
        """
        builders = {
            "v2_s": (mvit_v2_s, MViT_V2_S_Weights),
            "v1_b": (mvit_v1_b, MViT_V1_B_Weights),
        }
        try:
            builder, weights_enum = builders[variant.lower()]
        except KeyError:
            raise ValueError(f"Unknown variant '{variant}'. Use 'v2_s' or 'v1_b'.") from None

        # You can pass strings like "DEFAULT" or the enum directly
        wt = getattr(weights_enum, weights) if isinstance(weights, str) else weights
        base = builder(weights=wt, **kwargs)

        # Swap in the size-flexible positional encoding, mirroring the stock
        # module's configuration.
        old_pe = base.pos_encoding
        embed_size = base.conv_proj.out_channels
        temporal_size = getattr(old_pe, "temporal_size", None)
        spatial_size = getattr(old_pe, "spatial_size", None)
        had_abs = all(
            getattr(old_pe, name, None) is not None
            for name in ("spatial_pos", "temporal_pos", "class_pos")
        )
        rel_pos_embed = not had_abs

        flex_pe = FlexPositionalEncoding(
            embed_size=embed_size,
            spatial_size=spatial_size,
            temporal_size=temporal_size,
            rel_pos_embed=rel_pos_embed,
        )
        # The replacement starts zero-initialised; copy the pretrained class
        # token and position tables so they are not silently discarded.
        flex_pe.load_state_dict(old_pe.state_dict())
        base.pos_encoding = flex_pe

        # Morph the instance into our subclass so our custom forward is used
        base.__class__ = cls
        return base

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass modelled on TorchVision's MViT, with size-flexible positional encoding.
        Expected input shape: (B, C, T, H, W) or (B, C, H, W) (a temporal dim of 1 is inserted).

        Steps:
        1) (Optionally) add a temporal dim
        2) Tubelet/patch embedding conv
        3) Flatten + add positional encoding (incl. class token)
        4) Transformer blocks w/ multiscale attention
        5) Norm, take cls token, head -> logits

        Returns:
            The normalized token sequence when ``self.mode == 'test'``,
            otherwise the output of ``self.head`` applied to the class token.
        """
        # (B,C,H,W) -> (B,C,1,H,W) if needed; must happen before any size is read
        x = _unsqueeze(x, target_dim=5, expand_dim=2)[0]

        # Patchify / tubelet embedding: (B, C, T, H, W) -> (B, embed, T', H', W') -> (B, T'H'W', embed)
        x = self.conv_proj(x)
        t_tokens, grid_h, grid_w = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # pos_encoding halves its frame-count argument, so pass twice the
        # actual temporal token count: equal to the raw frame count for even T,
        # and still consistent with conv_proj's output for odd / single-frame input.
        x = self.pos_encoding(x, grid_w, 2 * t_tokens)

        # Run transformer blocks (track temporal-height-width sizes for multiscale ops).
        # Taken from the conv output rather than from pos_encoding's state.
        thw = (t_tokens, grid_h, grid_w)
        for blk in self.blocks:
            x, thw = blk(x, thw)

        # Final norm, CLS token, and classification head
        x = self.norm(x)
        if self.mode == 'test':
            return x

        cls_token = x[:, 0]
        logits = self.head(cls_token)
        return logits
