from functools import partial

import einops
import torch
from einops import rearrange
from kappamodules.layers import LinearProjection

# from src.models.kappa_overrides.bi_a_dpa_cross_selfattn import BidirectionalAnchoredDotProductAttention
# from src.models.kappa_overrides.bi_a_dpa import BidirectionalAnchoredDotProductAttention
from src.models.kappa_overrides.dpa import DotProductAttention
from src.models.kappa_overrides.dit_xa_block import DitXABlock
from src.models.kappa_overrides.prenorm_xa_block import PrenormWithCrossAttentionBlock
from src.models.kappa_overrides.a_dpa import AnchoredDotProductAttention
from src.models.kappa_overrides.dit_block import DitBlock
from torch import Tensor, nn

from src.models.vision.patching.patchify_w_conv import ConvPatchEmbedding
from src.models.vision.patching.patched_positions import rope_pos_for_patch_midpoints
from src.modules.rope_frequency import RopeFrequency
from src.utils.attn_mask import attn_mask_from_radius
from src.utils.checkpointing import checkpointed_forward

# helpers

def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


class CondVit(nn.Module):
    """Conditioned transformer that alternates spatial and temporal attention.

    The input is patch-embedded into a ``(b, T, h, w, d)`` token grid. Each of
    the ``transformer_depth`` layers first runs a spatial block over the
    ``h * w`` tokens of every frame and then, when ``do_time_mixing`` is set,
    a temporal block over the ``T`` tokens of every spatial location. Unless
    ``ignore_probes`` is set, the spatial blocks are ``DitXABlock`` instances
    that additionally receive encoded probes as ``kv``.
    """

    def __init__(
        self,
        conditioner,
        probe_encoding,
        patch_embedding,
        image_size,
        patch_size,
        transformer_dim,
        transformer_depth,
        transformer_attn_heads,
        x_dim: int,
        condition_dim: int,
        init_weights="truncnormal",
        init_modulation_zero=False,
        init_last_proj_zero=False,
        attn_ctor=DotProductAttention,
        do_time_mixing=True,
        rope_range_max=1000,
        ignore_probes=False,
        checkpoint_interval=-1,
        **kwargs,
    ):
        """Build the embedding, RoPE helpers and the attention block stacks.

        Args:
            conditioner: Callable applied to ``t.unsqueeze(1)`` in ``forward``;
                its output is consumed with the einops pattern ``'b d'``.
            probe_encoding: Callable ``(probe_field, probe_pos) -> (kv, kv_pos)``
                used in ``forward`` when probes are not ignored.
            patch_embedding: Module that embeds the input and provides
                ``unpatch`` to map tokens back.
            image_size: Forwarded to ``rope_pos_for_patch_midpoints``.
            patch_size: Forwarded to ``rope_pos_for_patch_midpoints``.
            transformer_dim: Token width passed as ``dim`` to every block.
            transformer_depth: Number of (spatial, temporal) block pairs.
            transformer_attn_heads: Number of attention heads per block.
            x_dim: Used as ``ndim`` of the spatial RoPE and as the input width
                of ``field_proj``.
            condition_dim: Stored on the instance; not otherwise used here.
            init_weights: Forwarded to the blocks.
            init_modulation_zero: Forwarded to the blocks.
            init_last_proj_zero: Forwarded to the blocks.
            attn_ctor: Attention constructor forwarded to the blocks.
            do_time_mixing: If false, ``forward`` skips the temporal blocks
                (they are still constructed).
            rope_range_max: Upper end of the position range used for RoPE.
            ignore_probes: If truthy, spatial blocks are plain ``DitBlock``
                instances and ``forward`` does not call ``probe_encoding``.
            checkpoint_interval: Layers whose index is a multiple of this value
                are run with checkpointing; values ``<= 0`` disable it.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.conditioner = conditioner
        self.probe_encoding = probe_encoding
        self.transformer_dim = transformer_dim
        self.transformer_depth = transformer_depth
        self.transformer_attn_heads = transformer_attn_heads
        self.init_weights = init_weights
        self.condition_dim = condition_dim
        self.do_time_mixing = do_time_mixing
        self.rope_range_max = rope_range_max
        self.ignore_probes = ignore_probes
        self.checkpoint_interval = checkpoint_interval

        self.patch_embedding = patch_embedding

        head_dim = transformer_dim // transformer_attn_heads
        self.rope_spatial = RopeFrequency(dim=head_dim, ndim=x_dim)
        self.rope_temporal = RopeFrequency(dim=head_dim, ndim=1)

        self.register_buffer(
            "spatial_pos",
            rope_pos_for_patch_midpoints(image_size, patch_size, rope_range_max),
        )

        # NOTE(review): field_proj is not referenced by forward() in this file;
        # it is kept because removing it would change the module's parameters.
        self.field_proj = LinearProjection(x_dim, transformer_dim)

        # Both block types take exactly the same constructor arguments.
        block_kwargs = dict(
            attn_ctor=attn_ctor,
            dim=transformer_dim,
            num_heads=transformer_attn_heads,
            init_weights=init_weights,
            init_modulation_zero=init_modulation_zero,
            init_last_proj_zero=init_last_proj_zero,
        )
        dit_ctor = partial(DitBlock, **block_kwargs)
        spatial_ctor = dit_ctor if ignore_probes else partial(DitXABlock, **block_kwargs)

        # Blocks are created interleaved (spatial, temporal, spatial, ...) so
        # that the order in which parameters are initialised stays the same.
        spatial_blocks = []
        temporal_blocks = []
        for _ in range(transformer_depth):
            spatial_blocks.append(spatial_ctor())
            temporal_blocks.append(dit_ctor())
        self.spatial_blocks = nn.ModuleList(spatial_blocks)
        self.temporal_blocks = nn.ModuleList(temporal_blocks)

    def forward(
        self,
        x,
        t,
        *,
        probe_pos=None,
        probe_field=None,
        **kwargs,
    ):
        """Run the spatial/temporal transformer and un-patch the result.

        Args:
            x: Input passed to ``patch_embedding``. The embedding must return
                a ``(b, T, h, w, d)`` tensor, or a tuple of that tensor and a
                residual that is handed back to ``patch_embedding.unpatch``.
            t: Conditioning tensor. A 0-dim tensor is repeated to the batch
                size; otherwise it is used as given, and
                ``conditioner(t.unsqueeze(1))`` must match the pattern ``'b d'``.
            probe_pos: Second argument to ``probe_encoding``; only used when
                probes are not ignored.
            probe_field: First argument to ``probe_encoding``; only used when
                probes are not ignored.
            **kwargs: Accepted and ignored.

        Returns:
            The output of ``patch_embedding.unpatch``.
        """
        x = self.patch_embedding(x)
        conv_residual = None
        if isinstance(x, tuple):
            x, conv_residual = x
        b, T, h, w, _ = x.shape

        if t.ndim == 0:
            # Scalar t (e.g. at inference): give every sample in the batch the
            # same value so the conditioning matches the batch size of x.
            t = t.unsqueeze(0).repeat(b)

        cond = einops.repeat(
            self.conditioner(t.unsqueeze(1)),
            'b d -> b T h w d',
            T=T, h=h, w=w,
        )

        spatial_kwargs = dict(
            cond=einops.rearrange(cond, 'b T h w d -> (b T) (h w) d'),
            rope_freqs=einops.repeat(
                self.rope_spatial(self.spatial_pos),
                'h w d -> (b T) (h w) d',
                b=b, T=T,
            ),
        )

        # Truthiness check, matching how __init__ selects the spatial block type.
        if not self.ignore_probes:
            kv, kv_pos_for_rope_freqs = self.probe_encoding(probe_field, probe_pos)
            kv_spatial_freqs = self.rope_spatial(kv_pos_for_rope_freqs)
            spatial_kwargs['kv'] = einops.rearrange(kv, 'b T n d -> (b T) n d')
            spatial_kwargs['kv_rope_freqs'] = einops.repeat(
                kv_spatial_freqs, 'b n d -> (b T) n d', T=T
            )

        # Temporal inputs do not change between layers: build them once, and
        # only when the temporal blocks are actually going to run.
        temporal_kwargs = None
        if self.do_time_mixing:
            temporal_pos = torch.arange(T, device=x.device).unsqueeze(1) / T * self.rope_range_max
            temporal_kwargs = dict(
                rope_freqs=einops.repeat(
                    self.rope_temporal(temporal_pos),
                    'T d -> (b h w) T d',
                    b=b, h=h, w=w,
                ),
                cond=einops.rearrange(cond, 'b T h w d -> (b h w) T d', T=T, h=h, w=w),
            )

        # Tokens are kept in the spatial layout between layers.
        x = einops.rearrange(x, 'b T h w d -> (b T) (h w) d')
        for i in range(self.transformer_depth):
            use_checkpointing = (self.checkpoint_interval > 0) and (i % self.checkpoint_interval == 0)

            # spatial attention
            x = checkpointed_forward(self.spatial_blocks[i], use_checkpointing,
                                     x, **spatial_kwargs)

            if self.do_time_mixing:
                # temporal attention: one sequence of length T per spatial location
                x = einops.rearrange(x, '(b T) (h w) d -> (b h w) T d', T=T, h=h, w=w)
                x = checkpointed_forward(self.temporal_blocks[i], use_checkpointing,
                                         x, **temporal_kwargs)
                # back to the spatial layout for the next layer
                x = einops.rearrange(x, '(b h w) T d -> (b T) (h w) d', T=T, h=h, w=w)

        x = einops.rearrange(x, '(b T) (h w) d -> b T h w d', T=T, h=h)

        if conv_residual is None:
            return self.patch_embedding.unpatch(x)
        return self.patch_embedding.unpatch(x, conv_residual)