from functools import partial

import einops
import torch
from einops import rearrange
from kappamodules.layers import LinearProjection

# from src.models.kappa_overrides.bi_a_dpa_cross_selfattn import BidirectionalAnchoredDotProductAttention
# from src.models.kappa_overrides.bi_a_dpa import BidirectionalAnchoredDotProductAttention
from dit_xa_block_old import DitXABlock
from src.models.kappa_overrides.prenorm_xa_block import PrenormWithCrossAttentionBlock
from src.models.kappa_overrides.a_dpa import AnchoredDotProductAttention
from src.models.kappa_overrides.dit_block import DitBlock
from torch import Tensor, nn

from src.modules.rope_frequency import RopeFrequency
from src.utils.attn_mask import attn_mask_from_radius


class DiTAUPT(nn.Module):
    """Transformer stack that alternates spatial and temporal attention blocks.

    Each of the ``transformer_depth`` layers applies a spatial block (``DitXABlock``)
    over the ``n`` points of every time step, which also receives the encoded probes as
    ``kv``. When ``do_time_mixing`` is set, a temporal block (``DitBlock``) then runs over
    the ``T`` time steps of every point. Every block receives a conditioning tensor
    computed from ``t`` by ``conditioner``, plus rotary (RoPE) frequencies.

    Args:
        conditioner: module applied to ``t[:, None]``; its output is broadcast to every
            (point, time) token as the block conditioning.
        probe_encoding: module called as ``probe_encoding(probe_field, probe_pos)``. It
            returns ``(kv, kv_pos)``: ``kv`` is fed to the spatial blocks and ``kv_pos``
            to the spatial RoPE.
        transformer_dim: hidden width of all blocks.
        transformer_depth: number of spatial (and temporal) blocks.
        transformer_attn_heads: number of attention heads per block.
        n_anchors: forwarded to ``attn_ctor`` for the spatial blocks.
        x_dim: input feature size of ``x`` (projected to ``transformer_dim``). Also used
            as the number of coordinates for the spatial RoPE.
        condition_dim: stored on the instance; not otherwise used in this class.
        out_proj: module applied to the final ``(b, n, T, transformer_dim)`` tensor.
        init_weights: weight-init scheme forwarded to every block.
        attn_ctor: attention class. Temporal blocks get ``n_anchors=-1``, which
            disables anchors.
        do_time_mixing: if False, temporal blocks are still built but skipped in
            ``forward``.
        temporal_pos_scale: factor applied to the normalised time indices
            ``arange(T) / T`` before the temporal RoPE. The default 1000 matches the
            previously hard-coded value.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        conditioner,
        probe_encoding,
        transformer_dim,
        transformer_depth,
        transformer_attn_heads,
        n_anchors: int,
        x_dim: int,
        condition_dim: int,
        out_proj: nn.Module,
        init_weights="truncnormal",
        attn_ctor=AnchoredDotProductAttention,
        do_time_mixing=True,
        temporal_pos_scale: float = 1000.0,
        **kwargs,
    ):
        super().__init__()
        self.conditioner = conditioner
        self.probe_encoding = probe_encoding
        self.transformer_dim = transformer_dim
        self.transformer_depth = transformer_depth
        self.transformer_attn_heads = transformer_attn_heads
        self.init_weights = init_weights
        self.condition_dim = condition_dim
        self.n_anchors = n_anchors
        self.do_time_mixing = do_time_mixing
        self.temporal_pos_scale = temporal_pos_scale

        # RoPE frequencies are sized to the per-head dimension.
        head_dim = transformer_dim // transformer_attn_heads
        self.rope_spatial = RopeFrequency(dim=head_dim, ndim=x_dim)
        self.rope_temporal = RopeFrequency(dim=head_dim, ndim=1)

        self.field_proj = LinearProjection(x_dim, transformer_dim)
        self.out_proj = out_proj

        spatial_blocks = []
        temporal_blocks = []
        # Keep spatial/temporal construction interleaved (as originally written) so any
        # random weight initialisation happens in the same order for seeded runs.
        for _ in range(transformer_depth):
            spatial_blocks.append(
                DitXABlock(
                    attn_ctor=partial(attn_ctor, n_anchors=n_anchors),
                    dim=transformer_dim,
                    num_heads=transformer_attn_heads,
                    init_weights=init_weights,
                )
            )
            temporal_blocks.append(
                DitBlock(
                    attn_ctor=partial(attn_ctor, n_anchors=-1),  # disable anchors
                    dim=transformer_dim,
                    num_heads=transformer_attn_heads,
                    init_weights=init_weights,
                )
            )
        self.spatial_blocks = nn.ModuleList(spatial_blocks)
        self.temporal_blocks = nn.ModuleList(temporal_blocks)

    def forward(
        self,
        x,
        t,
        *,
        probe_pos,
        probe_field,
        pos,
        **kwargs,
    ):
        """Run the alternating spatial/temporal block stack.

        Args:
            x: 4-D input ``(b, n, T, x_dim)``.
            t: conditioning time, either shape ``(b,)`` or a 0-d scalar. A scalar is
                repeated to the batch size.
            probe_pos: passed to ``probe_encoding`` together with ``probe_field``. The
                returned ``kv`` must be 4-D ``(b, n_probe, T, d)``.
            probe_field: see ``probe_pos``.
            pos: positions of the ``n`` points of ``x``, fed to the spatial RoPE.
                NOTE(review): presumably ``(b, n, x_dim)``; confirm against callers.
            **kwargs: accepted and ignored.

        Returns:
            ``out_proj`` applied to the ``(b, n, T, transformer_dim)`` features.
        """
        b, n, T, _ = x.shape
        if t.ndim == 0:
            # A scalar time (e.g. at inference) applies to every sample. Repeat it to the
            # batch size so the conditioning lines up with x even when b > 1.
            t = t.repeat(b)

        x = self.field_proj(x)

        # Per-sample conditioning, broadcast to every (point, time) token.
        cond = einops.repeat(self.conditioner(t.unsqueeze(1)), 'b d -> b n T d', n=n, T=T)

        # Spatial RoPE for the n points, shared across all T time steps.
        spatial_rope_freqs = einops.repeat(self.rope_spatial(pos), 'b n d -> (b T) n d', T=T)

        # Temporal RoPE over normalised time indices in [0, 1), scaled by temporal_pos_scale.
        temporal_pos = torch.arange(T, device=pos.device).unsqueeze(1) / T * self.temporal_pos_scale
        temporal_rope_freqs = einops.repeat(
            self.rope_temporal(temporal_pos), 'T d -> (b n) T d', b=b, n=n
        )

        kv, kv_pos_for_rope_freqs = self.probe_encoding(probe_field, probe_pos)
        kv_rope_freqs = einops.repeat(
            self.rope_spatial(kv_pos_for_rope_freqs), 'b n d -> b n T d', T=T
        )

        # Fold time into the batch so spatial blocks attend over the n points of each step.
        x, cond, kv, kv_rope_freqs = m_rearrange(
            [x, cond, kv, kv_rope_freqs], 'b n T d -> (b T) n d'
        )
        for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
            x = spatial_block(
                x,
                cond=cond,
                rope_freqs=spatial_rope_freqs,
                kv=kv,
                kv_rope_freqs=kv_rope_freqs,
            )
            if self.do_time_mixing:
                # Fold points into the batch so the temporal block attends over T steps.
                x, cond = m_rearrange([x, cond], '(b T) n d -> (b n) T d', T=T)
                x = temporal_block(x, cond=cond, rope_freqs=temporal_rope_freqs)
                # Restore the spatial layout for the next layer.
                x, cond = m_rearrange([x, cond], '(b n) T d -> (b T) n d', n=n)

        x = einops.rearrange(x, '(b T) n d -> b n T d', T=T)
        return self.out_proj(x)
    
def m_rearrange(array_list, pattern, *args, **kwargs):
    """Rearrange every tensor in ``array_list`` with one shared einops pattern.

    Args:
        array_list (list): Tensors/arrays to rearrange, in order.
        pattern (str): einops pattern applied to each element
            (e.g. ``'b c h w -> b h w c'``).
        *args, **kwargs: Forwarded unchanged to ``einops.rearrange`` (e.g. axis sizes).

    Returns:
        list: The rearranged tensors/arrays, in the same order as the input.
    """
    def _apply(tensor):
        return einops.rearrange(tensor, pattern, *args, **kwargs)

    return list(map(_apply, array_list))