# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified Model definition with RoPE (spatial + temporal)

import torch
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
import math
import warnings
import numpy as np

from timesformer.models.vit_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timesformer.models.helpers import load_pretrained
from timesformer.models.vit_utils import DropPath, to_2tuple, trunc_normal_

from .build import MODEL_REGISTRY
from torch import einsum
from einops import rearrange, reduce, repeat


# ============================================
# Config helper
# ============================================

def _cfg(url='', **kwargs):
    """Build a pretrained-weights config dict.

    Every key gets a ViT/ImageNet default; any keyword argument overrides (or extends)
    the defaults.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head',
    )
    config.update(kwargs)
    return config


# Pretrained-weight metadata keyed by model name. The registered RoPE wrappers at the
# bottom of this file attach the 'vit_base_patch16_224' entry as their default_cfg.
default_cfgs = dict(
    vit_base_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
)


# ============================================
# RoPE helpers (2D spatial + 1D temporal)
# ============================================

def init_random_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    """
    Initial (RoPE-Mixed style) 2D frequencies, one randomly rotated set per head.

    dim: head_dim
    num_heads: number of heads; each draws one random phase when ``rotate`` is True
    theta: base of the geometric magnitude schedule
    rotate: if False every head uses phase 0
    return: freqs (2, num_heads, dim/2); index 0 = x-frequencies, 1 = y-frequencies
    """
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    magnitudes = 1 / (theta ** exponents)  # (dim/4,)

    per_head_x, per_head_y = [], []
    for _ in range(num_heads):
        phase = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
        quarter_turn = torch.pi / 2 + phase
        # the two halves of each head's frequency vector are 90 degrees apart
        per_head_x.append(
            torch.cat([magnitudes * torch.cos(phase), magnitudes * torch.cos(quarter_turn)], dim=-1)
        )
        per_head_y.append(
            torch.cat([magnitudes * torch.sin(phase), magnitudes * torch.sin(quarter_turn)], dim=-1)
        )

    return torch.stack(
        [torch.stack(per_head_x, dim=0), torch.stack(per_head_y, dim=0)],
        dim=0,
    )


def init_t_xy(end_x: int, end_y: int):
    """Column and row index of every cell of an end_y x end_x grid, flattened row-major.

    return: (t_x, t_y), each a float32 tensor of length end_x * end_y
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    column = flat_index.remainder(end_x).float()
    row = flat_index.div(end_x, rounding_mode='floor').float()
    return column, row


def compute_mixed_cis(freqs, t_x, t_y, num_heads):
    """
    Build per-block, per-head 2D rotary tables from learnable (RoPE-Mixed) frequencies.

    freqs: (2, depth, heads * (dim/2)); index 0 = x-frequencies, 1 = y-frequencies
    t_x, t_y: (N,) column / row coordinate of each flattened grid position
    num_heads: number of heads the last axis of ``freqs`` is split into
    return: (depth, heads, N, dim/2) complex cis with angle t_x * f_x + t_y * f_y
    """
    depth = freqs.shape[1]
    dim_half = freqs.shape[-1] // num_heads

    # torch.polar only accepts float32/float64 inputs, so half/bf16 parameters (e.g. a
    # model converted with .half()) would fail; compute in at least float32.
    work_dtype = torch.float32
    for tensor in (freqs, t_x, t_y):
        work_dtype = torch.promote_types(work_dtype, tensor.dtype)

    # No float16 for this range
    with torch.cuda.amp.autocast(enabled=False):
        fx = freqs[0].to(work_dtype).reshape(depth, num_heads, dim_half)  # (depth, heads, dim/2)
        fy = freqs[1].to(work_dtype).reshape(depth, num_heads, dim_half)
        # (N, 1) * (depth, heads, 1, dim/2) -> (depth, heads, N, dim/2)
        freqs_x = t_x.to(work_dtype).unsqueeze(-1) * fx.unsqueeze(2)
        freqs_y = t_y.to(work_dtype).unsqueeze(-1) * fy.unsqueeze(2)
        angles = freqs_x + freqs_y
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

    return freqs_cis  # (depth, heads, N, dim/2)


def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0):
    """
    Fixed 2D axial RoPE table over an end_y x end_x patch grid (spatial attention).

    The first dim/4 complex channels rotate with the column index, the last dim/4 with
    the row index; both axes share the same frequency schedule.

    dim: head_dim
    return: (end_x * end_y, dim/2) complex cis
    """
    base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    column, row = init_t_xy(end_x, end_y)

    halves = []
    for coord in (column, row):
        phase = torch.outer(coord, base_freqs)  # (N, dim/4)
        halves.append(torch.polar(torch.ones_like(phase), phase))
    return torch.cat(halves, dim=-1)  # (N, dim/2)


def compute_1d_time_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """
    Standard 1D RoPE table for temporal attention.

    dim: head_dim (one frequency per pair of channels)
    seq_len: number of frames T; positions are 0 .. T-1
    return: (T, dim/2) complex cis, entry [t, i] has angle t / theta ** (2i / dim)
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    positions = torch.arange(seq_len, dtype=torch.float32)  # (T,)
    phase = torch.outer(positions, inv_freq)  # (T, dim/2)
    return torch.polar(torch.ones_like(phase), phase)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    View a rotary table so that it broadcasts against ``x``.

    freqs_cis: (N, D), (heads, N, D), or any shape matching x's trailing axes
    x: (..., N, D) complex, D = head_dim/2; must have at least 2 dims
    return: freqs_cis viewed with leading singleton axes up to x.ndim
    """
    rank = x.ndim
    assert rank > 1
    if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
        lead = rank - 2
        target = [1] * lead + list(x.shape[lead:])
    elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
        lead = rank - 3
        target = [1] * lead + list(x.shape[lead:])
    else:
        # left-pad with singleton axes and let broadcasting / view validate the rest
        target = [1] * (rank - freqs_cis.ndim) + list(freqs_cis.shape)
    return freqs_cis.view(*target)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """
    Rotate query and key features by a complex rotary table.

    xq, xk: (..., N, head_dim), typically (B, heads, N, head_dim); head_dim must be
        even. Consecutive channel pairs (2i, 2i+1) are treated as one complex number.
    freqs_cis: (N, head_dim/2) or (heads, N, head_dim/2) complex, broadcast against xq
    return: rotated (xq, xk), cast back to the dtypes of the inputs
    """
    # q/k are upcast to float32 before being viewed as complex numbers.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # flatten(-2) folds (head_dim/2, 2) back into head_dim for inputs of any rank
    # (a fixed flatten(3) only worked for 4-D q/k).
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# ============================================
# MLP
# ============================================

class Mlp(nn.Module):
    """Feed-forward sub-layer: Linear -> activation -> Dropout -> Linear -> Dropout.

    hidden_features and out_features default to in_features when not given (or falsy).
    The same dropout module is applied after both layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width = hidden_features or in_features
        width_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


# ============================================
# Attention (with optional RoPE)
# ============================================

class Attention(nn.Module):
    """Multi-head self-attention with an optional rotary position embedding (RoPE).

    With ``with_qkv=True`` the input is projected to q/k/v and the result projected
    back; with ``with_qkv=False`` the input itself serves as q, k and v and no
    projection is applied. A rotary table installed via ``set_rope_freqs`` rotates q and
    k before the dot product; with ``rope_on_cls=False`` the first token is left
    unrotated (callers in this file put the CLS token there).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., with_qkv=True):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5
        self.with_qkv = with_qkv
        if with_qkv:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

        # Rotary table used by forward(); None disables RoPE.
        self.freqs_cis = None
        self.rope_on_cls = False

    def set_rope_freqs(self, freqs_cis: torch.Tensor = None, rope_on_cls: bool = False):
        """Install (or clear, with None) the rotary table applied to q/k in forward."""
        self.freqs_cis = freqs_cis
        self.rope_on_cls = rope_on_cls

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        if self.with_qkv:
            q, k, v = (
                self.qkv(x)
                .reshape(batch, tokens, 3, self.num_heads, head_dim)
                .permute(2, 0, 3, 1, 4)
            )  # each (B, heads, N, head_dim)
        else:
            q = k = v = x.reshape(batch, tokens, self.num_heads, head_dim).permute(0, 2, 1, 3)

        if self.freqs_cis is not None:
            if self.rope_on_cls:
                q, k = apply_rotary_emb(q, k, self.freqs_cis)
            else:
                # rotate only tokens 1.., keep token 0 as-is
                q_rot, k_rot = apply_rotary_emb(q[:, :, 1:, :], k[:, :, 1:, :], self.freqs_cis)
                q = torch.cat([q[:, :, :1, :], q_rot], dim=2)
                k = torch.cat([k[:, :, :1, :], k_rot], dim=2)

        weights = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        weights = self.attn_drop(weights.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        if self.with_qkv:
            out = self.proj_drop(self.proj(out))
        return out


# ============================================
# Block (Space-Time Divided + RoPE)
# ============================================

class Block(nn.Module):
    """Video transformer block: residual attention sub-layer(s) followed by a residual MLP.

    ``divided_space_time``: temporal attention over the T frames of each patch position
    (passed through ``temporal_fc`` and added back), then spatial attention over the
    H*W patches of each frame with a per-frame copy of the CLS token (the per-frame CLS
    updates are averaged), then the MLP.
    ``space_only`` / ``joint_space_time``: one attention over the sequence as given, then
    the MLP; no RoPE decomposition is involved.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop=0.,
        attn_drop=0.,
        drop_path=0.1,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attention_type='divided_space_time'
    ):
        super().__init__()
        self.attention_type = attention_type
        assert attention_type in ['divided_space_time', 'space_only', 'joint_space_time']

        attn_kwargs = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop,
        )

        # spatial (or whole-sequence) attention
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, **attn_kwargs)

        # temporal branch exists only for divided space-time attention
        if attention_type == 'divided_space_time':
            self.temporal_norm1 = norm_layer(dim)
            self.temporal_attn = Attention(dim, **attn_kwargs)
            self.temporal_fc = nn.Linear(dim, dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, B, T, W):
        """
        x: for divided_space_time, (B, 1 + H*W*T, C) with the CLS token first and patch
           tokens ordered (h w t), as the rearrange patterns below assume.
        B, T, W: batch size, number of frames and patch-grid width; H is derived.
        """
        grid_h = (x.size(1) - 1) // T // W

        if self.attention_type == 'divided_space_time':
            return self._divided_space_time(x, B, T, grid_h, W)
        if self.attention_type in ['space_only', 'joint_space_time']:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            return x + self.drop_path(self.mlp(self.norm2(x)))

    def _divided_space_time(self, x, B, T, H, W):
        """Temporal attention, then spatial attention, then MLP (all residual)."""
        cls_in = x[:, 0, :].unsqueeze(1)  # (B, 1, C)
        patches = x[:, 1:, :]  # (B, H*W*T, C)

        # ----- temporal: one length-T sequence per (sample, patch position) -----
        per_position = rearrange(patches, 'b (h w t) m -> (b h w) t m', b=B, h=H, w=W, t=T)
        t_update = self.drop_path(self.temporal_attn(self.temporal_norm1(per_position)))
        t_update = rearrange(t_update, '(b h w) t m -> b (h w t) m', b=B, h=H, w=W, t=T)
        xt = patches + self.temporal_fc(t_update)  # (B, H*W*T, C)

        # ----- spatial: one (1 + H*W) sequence per (sample, frame) -----
        frame_cls = cls_in.repeat(1, T, 1)  # (B, T, C)
        frame_cls = rearrange(frame_cls, 'b t m -> (b t) m', b=B, t=T).unsqueeze(1)  # (B*T, 1, C)
        per_frame = rearrange(xt, 'b (h w t) m -> (b t) (h w) m', b=B, h=H, w=W, t=T)
        per_frame = torch.cat((frame_cls, per_frame), 1)  # (B*T, 1+H*W, C)
        s_update = self.drop_path(self.attn(self.norm1(per_frame)))

        # average the per-frame CLS updates back into one CLS update per sample
        cls_update = rearrange(s_update[:, 0, :], '(b t) m -> b t m', b=B, t=T).mean(1, keepdim=True)
        patch_update = rearrange(
            s_update[:, 1:, :], '(b t) (h w) m -> b (h w t) m', b=B, h=H, w=W, t=T
        )

        x = torch.cat((cls_in, xt), 1) + torch.cat((cls_update, patch_update), 1)

        # ----- MLP -----
        return x + self.drop_path(self.mlp(self.norm2(x)))


# ============================================
# Patch Embedding
# ============================================

class PatchEmbed(nn.Module):
    """Per-frame patch embedding for video clips.

    Each frame is cut into non-overlapping patch_size patches by a strided Conv2d and
    every patch becomes one embed_dim token.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.num_patches = cols * rows
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        """
        x: (B, C, T, H, W)
        return: (tokens (B*T, N, embed_dim), T, patch-grid width)
        """
        _, _, num_frames, _, _ = x.shape
        frames = rearrange(x, 'b c t h w -> (b t) c h w')
        feature_map = self.proj(frames)  # (B*T, embed_dim, H', W')
        grid_w = feature_map.size(-1)
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, num_frames, grid_w


# ============================================
# VisionTransformer with optional RoPE
# ============================================

class VisionTransformer(nn.Module):
    """ Vision Transformer (TimeSformer style) + optional RoPE

    Takes clips shaped (B, C, T, H, W) and returns class logits from the CLS token.

    RoPE (``rope_enabled=True``) is applied only with
    ``attention_type='divided_space_time'``:
    - Temporal attention gets a 1D table over frames.
    - Spatial attention gets a 2D table over the patch grid. The table is fixed axial
      when ``rope_mixed_space`` is False, or learnable per block and head
      (``freqs_space``) when True.

    ``use_ape=False`` drops the learned absolute ``pos_embed`` / ``time_embed``.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.1,
        hybrid_backbone=None,
        norm_layer=nn.LayerNorm,
        num_frames=8,
        attention_type='divided_space_time',
        dropout=0.,
        rope_enabled=False,
        rope_mixed_space=False,
        rope_theta_space=100.0,
        rope_theta_time=10000.0,
        use_ape=True,
    ):
        super().__init__()
        self.attention_type = attention_type
        self.depth = depth
        self.dropout = nn.Dropout(dropout)  # not used by forward_features
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.rope_enabled = rope_enabled
        self.rope_mixed_space = rope_mixed_space
        self.rope_theta_space = rope_theta_space
        self.rope_theta_time = rope_theta_time
        self.use_ape = use_ape

        self.num_heads = num_heads

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size,
            in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # Learned CLS token and (optional) absolute embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        if self.use_ape:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if self.attention_type != 'space_only' and self.use_ape:
            self.time_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
        else:
            self.time_embed = None
        self.time_drop = nn.Dropout(p=drop_rate)

        # Attention Blocks with linearly increasing stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                attention_type=self.attention_type
            )
            for i in range(self.depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        # init
        trunc_normal_(self.cls_token, std=.02)
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.time_embed is not None:
            trunc_normal_(self.time_embed, std=.02)

        self.apply(self._init_weights)

        # Zero-init temporal_fc in every block so each temporal branch initially adds
        # nothing to the residual stream.
        if self.attention_type == 'divided_space_time':
            for blk in self.blocks:
                if hasattr(blk, 'temporal_fc'):
                    nn.init.constant_(blk.temporal_fc.weight, 0)
                    nn.init.constant_(blk.temporal_fc.bias, 0)

        if self.rope_enabled:
            if self.attention_type != 'divided_space_time':
                warnings.warn(
                    "rope_enabled=True only takes effect with attention_type="
                    f"'divided_space_time'; with {self.attention_type!r} no rotary "
                    "embedding is applied."
                )

            head_dim = embed_dim // num_heads

            # Time RoPE (1D)
            self.compute_time_cis = partial(
                compute_1d_time_cis,
                dim=head_dim,
                theta=self.rope_theta_time
            )
            freqs_time = self.compute_time_cis(seq_len=num_frames)
            self.register_buffer('freqs_cis_time', freqs_time, persistent=False)

            # Space RoPE (2D), cached for the constructor's patch grid
            base_h, base_w = self._base_grid()

            if self.rope_mixed_space:
                # mixed: learnable freqs per block per head
                self.compute_space_cis = partial(
                    compute_mixed_cis,
                    num_heads=self.num_heads
                )
                freqs = [
                    init_random_2d_freqs(
                        dim=head_dim,
                        num_heads=self.num_heads,
                        theta=self.rope_theta_space
                    )
                    for _ in range(self.depth)
                ]
                freqs = torch.stack(freqs, dim=1).view(2, self.depth, -1)
                self.freqs_space = nn.Parameter(freqs.clone(), requires_grad=True)

                t_x, t_y = init_t_xy(end_x=base_w, end_y=base_h)
                self.register_buffer('freqs_t_x', t_x, persistent=False)
                self.register_buffer('freqs_t_y', t_y, persistent=False)
            else:
                # axial: fixed cis shared across blocks
                self.compute_space_cis = partial(
                    compute_axial_cis,
                    dim=head_dim,
                    theta=self.rope_theta_space
                )
                freqs_cis_space = self.compute_space_cis(
                    end_x=base_w,
                    end_y=base_h
                )
                self.register_buffer('freqs_cis_space', freqs_cis_space, persistent=False)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        nd = {'cls_token'}
        if hasattr(self, 'pos_embed') and self.pos_embed is not None:
            nd.add('pos_embed')
        if hasattr(self, 'time_embed') and self.time_embed is not None:
            nd.add('time_embed')
        if hasattr(self, 'freqs_space'):
            nd.add('freqs_space')
        return nd

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _base_grid(self):
        """(rows, cols) of the patch grid for the constructor's img_size / patch_size."""
        return (self.patch_embed.img_size[0] // self.patch_embed.patch_size[0],
                self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])

    def _add_pos_embed(self, x, W):
        """Add the learned spatial position embedding to (B*T, 1+N, C) tokens.

        If the token count differs from pos_embed (e.g. another crop size at inference),
        the patch part is nearest-resized to the input grid.
        """
        if self.pos_embed is None:
            return x
        if x.size(1) == self.pos_embed.size(1):
            return x + self.pos_embed
        cls_pos_embed = self.pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
        other_pos_embed = self.pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
        P = int(other_pos_embed.size(2) ** 0.5)
        H = (x.size(1) - 1) // W  # grid rows; the CLS token is not a patch
        other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P)
        new_pos_embed = F.interpolate(other_pos_embed, size=(H, W), mode='nearest')
        new_pos_embed = new_pos_embed.flatten(2).transpose(1, 2)
        return x + torch.cat((cls_pos_embed, new_pos_embed), 1)

    def _add_time_embed(self, x, B, T):
        """Merge (B*T, 1+N, C) frame tokens into (B, 1+N*T, C), adding time embeddings."""
        # Rows are ordered (b t); row b*T is the first frame of clip b.
        cls_tokens = x[::T, 0, :].unsqueeze(1)  # (B, 1, C)
        x = x[:, 1:]  # (B*T, N, C)
        x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)  # (B*N, T, C)

        if self.time_embed is not None:
            if T != self.time_embed.size(1):
                # resize time embeddings when the clip length differs
                time_embed = self.time_embed.transpose(1, 2)
                new_time_embed = F.interpolate(time_embed, size=(T), mode='nearest')
                x = x + new_time_embed.transpose(1, 2)
            else:
                x = x + self.time_embed

        x = self.time_drop(x)
        x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)  # (B, N*T, C)
        return torch.cat((cls_tokens, x), dim=1)  # (B, 1+N*T, C)

    def _rope_freqs(self, T, H, W, device):
        """Temporal table (T, head_dim/2) and a per-block list of spatial tables."""
        if self.freqs_cis_time.shape[0] != T:
            freqs_time = self.compute_time_cis(seq_len=T)
        else:
            freqs_time = self.freqs_cis_time
        freqs_time = freqs_time.to(device)

        # Cached spatial coordinates/tables were built for the constructor's grid.
        # Reuse them only for that exact grid: an equal token count can come from a
        # different H x W, which needs different coordinates.
        same_grid = (H, W) == self._base_grid()

        if self.rope_mixed_space:
            if same_grid:
                t_x, t_y = self.freqs_t_x, self.freqs_t_y
            else:
                t_x, t_y = init_t_xy(end_x=W, end_y=H)
            freqs_space_all = self.compute_space_cis(self.freqs_space, t_x.to(device), t_y.to(device))
            freqs_space_all = freqs_space_all.to(device)  # (depth, heads, H*W, head_dim/2)
            return freqs_time, [freqs_space_all[i] for i in range(self.depth)]

        if same_grid:
            freqs_space_base = self.freqs_cis_space
        else:
            freqs_space_base = self.compute_space_cis(end_x=W, end_y=H)
        freqs_space_base = freqs_space_base.to(device)  # (H*W, head_dim/2), shared by all blocks
        return freqs_time, [freqs_space_base] * self.depth

    def _forward_blocks_with_rope(self, x, B, T, W):
        """Run all blocks with per-block temporal/spatial RoPE tables installed."""
        H = ((x.size(1) - 1) // T) // W
        freqs_time, freqs_space = self._rope_freqs(T, H, W, x.device)

        for blk, freqs_sp_layer in zip(self.blocks, freqs_space):
            if hasattr(blk, 'temporal_attn'):
                blk.temporal_attn.set_rope_freqs(freqs_cis=freqs_time, rope_on_cls=True)
            # spatial sequences start with the CLS token, which stays unrotated
            blk.attn.set_rope_freqs(freqs_cis=freqs_sp_layer, rope_on_cls=False)
            try:
                x = blk(x, B, T, W)
            finally:
                # Don't leave per-call tables on the modules: with learnable mixed
                # freqs they are non-leaf tensors that keep the autograd graph alive
                # between steps and make copy.deepcopy(model) raise.
                if hasattr(blk, 'temporal_attn'):
                    blk.temporal_attn.set_rope_freqs(None)
                blk.attn.set_rope_freqs(None)
        return x

    def forward_features(self, x):
        """
        x: (B, C, T, H, W)
        return: normalized CLS features (B, embed_dim)
        """
        B = x.shape[0]
        x, T, W = self.patch_embed(x)  # x: (B*T, N, C)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)  # (B*T, 1, C)
        x = torch.cat((cls_tokens, x), dim=1)  # (B*T, 1+N, C)

        x = self.pos_drop(self._add_pos_embed(x, W))

        if self.attention_type != 'space_only':
            x = self._add_time_embed(x, B, T)  # (B, 1+N*T, C)

        # RoPE is only used with divided_space_time attention
        if self.rope_enabled and self.attention_type == 'divided_space_time':
            x = self._forward_blocks_with_rope(x, B, T, W)
        else:
            for blk in self.blocks:
                x = blk(x, B, T, W)

        # space_only baseline: average the per-frame sequences
        if self.attention_type == 'space_only':
            x = rearrange(x, '(b t) n m -> b t n m', b=B, t=T)
            x = torch.mean(x, 1)

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


# ============================================
# Pretrained conv filter
# ============================================

def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv

    A non-4-D ``patch_embed.proj.weight`` of shape (embed_dim, 3 * p * p) is reshaped
    into a Conv2d kernel (embed_dim, 3, p, p). ``p`` is inferred from the weight when its
    input width equals 3 * p * p for an integer p, otherwise ``patch_size`` is used.
    Weights that are already 4-D conv kernels, and all other entries, are copied as-is.
    """
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and v.ndim != 4:
            in_features = v.shape[-1]
            side = int(round((in_features / 3) ** 0.5))
            # v.shape[-1] is 3*p*p for a linear weight, not p, so it cannot be used
            # as the patch size directly.
            p = side if 3 * side * side == in_features else patch_size
            v = v.reshape((v.shape[0], 3, p, p))
        out_dict[k] = v
    return out_dict


class _RopeTimeSformerBase(nn.Module):
    """Shared constructor for the registered RoPE TimeSformer (ViT-B/16) variants.

    Subclasses choose the RoPE flavour through the class attributes below. Backbone
    size, the config fields read, and pretrained loading are identical for all of them.
    """

    # Per-variant RoPE settings; overridden by each registered subclass.
    rope_mixed_space = False
    rope_theta_space = 100.0
    use_ape = False

    def __init__(self, cfg, **kwargs):
        super().__init__()
        self.pretrained = True
        patch_size = 16
        self.model = VisionTransformer(
            img_size=cfg.DATA.TRAIN_CROP_SIZE,
            num_classes=cfg.MODEL.NUM_CLASSES,
            patch_size=patch_size,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            num_frames=cfg.DATA.NUM_FRAMES,
            attention_type=cfg.TIMESFORMER.ATTENTION_TYPE,
            rope_enabled=True,
            rope_mixed_space=self.rope_mixed_space,
            rope_theta_space=self.rope_theta_space,
            rope_theta_time=10000.0,
            use_ape=self.use_ape,
            **kwargs
        )

        self.attention_type = cfg.TIMESFORMER.ATTENTION_TYPE
        self.model.default_cfg = default_cfgs['vit_base_patch16_224']
        self.num_patches = (cfg.DATA.TRAIN_CROP_SIZE // patch_size) * (cfg.DATA.TRAIN_CROP_SIZE // patch_size)
        # Checkpoint path comes from the config and is forwarded to load_pretrained
        # unchanged (presumably an empty value falls back to default_cfg['url'] inside
        # load_pretrained — confirm against timesformer.models.helpers).
        pretrained_model = cfg.TIMESFORMER.PRETRAINED_MODEL
        if self.pretrained:
            load_pretrained(
                self.model,
                num_classes=self.model.num_classes,
                in_chans=kwargs.get('in_chans', 3),
                filter_fn=_conv_filter,
                img_size=cfg.DATA.TRAIN_CROP_SIZE,
                num_frames=cfg.DATA.NUM_FRAMES,
                num_patches=self.num_patches,
                attention_type=self.attention_type,
                pretrained_model=pretrained_model
            )

    def forward(self, x):
        return self.model(x)


@MODEL_REGISTRY.register()
class rope_timesformer_axial(_RopeTimeSformerBase):
    """Fixed axial 2D spatial RoPE + 1D temporal RoPE, no absolute embeddings."""
    rope_mixed_space = False
    rope_theta_space = 100.0
    use_ape = False


@MODEL_REGISTRY.register()
class rope_timesformer_mixed(_RopeTimeSformerBase):
    """Learnable mixed 2D spatial RoPE + 1D temporal RoPE, no absolute embeddings."""
    rope_mixed_space = True
    rope_theta_space = 10.0
    use_ape = False


@MODEL_REGISTRY.register()
class rope_timesformer_axial_ape(_RopeTimeSformerBase):
    """Fixed axial 2D spatial RoPE + 1D temporal RoPE, plus learned absolute embeddings."""
    rope_mixed_space = False
    rope_theta_space = 100.0
    use_ape = True


@MODEL_REGISTRY.register()
class rope_timesformer_mixed_ape(_RopeTimeSformerBase):
    """Learnable mixed 2D spatial RoPE + 1D temporal RoPE, plus learned absolute embeddings."""
    rope_mixed_space = True
    rope_theta_space = 10.0
    use_ape = True