# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from modules.attention import STTransformer, ContinuousPositionBias, FeedForward

class AdaForwardDynamics(nn.Module):
    """Predict a per-cell structure delta from a structure grid and a latent action.

    Structure cells and the action are each embedded to ``hidden_dim``. The
    action embedding is used twice: appended to the cell tokens as one extra
    token per frame, and handed to the transformer as its ``cond`` input.
    The transformer output for the cell tokens is then projected back to
    ``structure_dim`` channels.
    """

    def __init__(
        self,
        structure_dim=16,
        action_dim=256,
        hidden_dim=512,
        depth=4,
        dim_head=64,
        heads=8,
        attn_dropout=0.1,
        ff_dropout=0.1,
    ):
        super().__init__()

        self.structure_dim = structure_dim
        self.hidden_dim = hidden_dim

        def make_embedding(in_features):
            # Linear lift to hidden_dim followed by a FeedForward block.
            return nn.Sequential(
                nn.Linear(in_features, hidden_dim),
                FeedForward(hidden_dim, mult=4.0, dropout=ff_dropout),
            )

        # NOTE: submodule attribute names and creation order are kept stable;
        # they determine state_dict keys and the order in which parameters
        # are initialised.
        self.input_proj = make_embedding(structure_dim)
        self.action_proj = make_embedding(action_dim)

        # Position bias over the 2-D grid (num_dims=2), one set per head.
        self.dec_spatial_rel_pos_bias = ContinuousPositionBias(
            dim=hidden_dim,
            heads=heads,
            num_dims=2,
        )

        # Spatio-temporal transformer with conditioning enabled; the
        # conditioning width equals the token width (both hidden_dim).
        self.transformer = STTransformer(
            dim=hidden_dim,
            dim_cond=hidden_dim,
            dim_head=dim_head,
            heads=heads,
            depth=depth,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            causal=True,
            peg=True,
            peg_causal=True,
            enable_conditioning=True,
        )

        # Output head: normalise, squeeze through a half-width bottleneck,
        # then map to structure_dim channels.
        bottleneck_dim = hidden_dim // 2
        self.to_delta = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.SiLU(),
            nn.Linear(bottleneck_dim, structure_dim),
        )

    def forward(self, structure, action):
        """
        Compute the predicted change in structure for every grid cell.

        Args:
            structure: [B, T-1, H, W, C] - current structure representation
            action: [B, T-1, D] - latent action

        Returns:
            delta_structure: [B, T-1, H, W, C] - predicted change in structure
        """
        # The channel size is not needed below; unpacking five values still
        # requires `structure` to be 5-dimensional.
        batch, frames, height, width, _ = structure.shape

        # Merge the spatial grid into a single token axis, then embed.
        cell_tokens = rearrange(structure, 'b t h w c -> b t (h w) c')
        cell_tokens = self.input_proj(cell_tokens)
        action_emb = self.action_proj(action)

        # Append the action embedding as one extra trailing token per frame,
        # giving H*W + 1 tokens along dim 2.
        action_token = rearrange(action_emb, "b t d -> b t 1 d")
        tokens = torch.cat([cell_tokens, action_token], dim=2)

        # Bias computed for the H x W grid; presumably shaped
        # (heads, H*W, H*W) -- TODO confirm against ContinuousPositionBias.
        spatial_bias = self.dec_spatial_rel_pos_bias(
            height,
            width,
            device=tokens.device,
            dtype=tokens.dtype,
        )
        # Grow the last two dims by one trailing zero entry each so the bias
        # also covers the appended action token.
        spatial_bias = F.pad(spatial_bias, (0, 1, 0, 1), value=0.0)

        # NOTE(review): video_shape reports (H, W) although the token axis
        # holds H*W + 1 entries -- confirm STTransformer expects this.
        hidden = self.transformer(
            tokens,
            video_shape=(batch, frames, height, width),
            cond=action_emb,
            spatial_attn_bias=spatial_bias,
        )

        # Discard the trailing action token; only cell tokens are decoded.
        cell_hidden = hidden[:, :, :-1, :]
        delta_tokens = self.to_delta(cell_hidden)

        # Restore the H x W grid layout.
        return rearrange(delta_tokens, 'b t (h w) c -> b t h w c', h=height, w=width)