import torch
import torch.nn as nn
from einops import rearrange
import math
import numpy as np
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch.nn.attention import SDPBackend, sdpa_kernel

class TimestepEmbedder(nn.Module):
    """
    Map scalar diffusion timesteps to ``hidden_size``-dimensional vectors.

    A fixed sinusoidal encoding of width ``frequency_embedding_size`` is computed
    first and then passed through a two-layer MLP with a SiLU in between.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep features.

        :param t: 1-D tensor with one (possibly fractional) timestep per batch element.
        :param dim: width of the returned features; odd widths get one trailing zero column.
        :param max_period: sets the lowest frequency, 1 / max_period.
        :return: float32 tensor of shape (len(t), dim) laid out as [cos | sin (| 0)].
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponent = (
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        )
        freqs = torch.exp(exponent).to(device=t.device)
        angles = t.float()[:, None] * freqs[None]

        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2 == 1:
            # Pad odd widths so the output has exactly `dim` columns.
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Embed a 1-D batch of timesteps into (len(t), hidden_size) vectors."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class LabelEmbedder(nn.Module):
    """
    Look up a learned vector for each class label, with optional label dropout.

    When ``dropout_prob > 0`` the table gets one extra row at index ``num_classes``.
    Dropped labels are redirected to that row, which serves as the "null" label
    for classifier-free guidance.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        null_rows = 1 if dropout_prob > 0 else 0
        self.embedding_table = nn.Embedding(num_classes + null_rows, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Replace some labels with the null label ``num_classes``.

        Without ``force_drop_ids``, each label is dropped independently with
        probability ``dropout_prob``. Otherwise exactly the positions where
        ``force_drop_ids == 1`` are dropped.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        """Embed ``labels``, applying label dropout when training (or when forced)."""
        random_drop = train and self.dropout_prob > 0
        if random_drop or force_drop_ids is not None:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)


class RotaryEmbedding(nn.Module):
    """
    Rotary positional embedding (RoPE) with sin/cos tables precomputed for a fixed length.

    Positions are ``0 .. seq_len - 1`` by default, or the (possibly fractional)
    values in ``coords`` when given. The tables have shape (seq_len, dim // 2)
    for even ``dim``. They are registered as non-persistent buffers, so they
    follow ``.to(device)`` but are not written to the state_dict.
    """

    def __init__(self, dim: int, seq_len: int, base: int = 10000, coords: torch.Tensor = None):
        super().__init__()
        self.seq_len = seq_len

        if coords is None:
            positions = torch.arange(seq_len, dtype=torch.float32)
        else:
            assert seq_len == len(coords), f"Provided seq_len {seq_len} does not match coordinate length {len(coords)}"
            positions = coords.float()

        # One inverse frequency per rotated pair: base ** (-2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(positions, inv_freq)  # (seq_len, dim // 2)

        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rotate ``x`` with the cached tables and return it in x's dtype.

        ``x`` is expected as (B, heads, N, D_h) with N == seq_len. The last dim is
        split into halves (x1, x2), and each pair (x1[i], x2[i]) is rotated by
        angle ``position * inv_freq[i]`` (rotate-half convention).
        """
        assert x.shape[2] == self.seq_len, f"Input sequence length {x.shape[2]} does not match the module's fixed length {self.seq_len}."

        cos, sin = self.cos_cached, self.sin_cached
        first_half, second_half = x.chunk(2, dim=-1)
        rotated_first = first_half * cos - second_half * sin
        rotated_second = first_half * sin + second_half * cos
        return torch.cat([rotated_first, rotated_second], dim=-1).type_as(x).contiguous()

class PatchEmbedTFW(nn.Module):
    """
    Patch embedding for per-frame (frequency x covariate) grids.

    Non-overlapping (patch_f x patch_w) tiles are taken over (H=freq, W=covariate).
    Time is not patched: callers fold it into the batch first, so the mapping is
    (B*T, C, H, W) -> (B*T, num_patches, embed_dim). Tokens are ordered
    frequency-major, i.e. token index = f * num_patches_w + w.
    """

    def __init__(
        self, H, W, patch_f=4, patch_w=1, in_chans=3, embed_dim=1152, bias=True
    ):
        super().__init__()
        assert (
            H % patch_f == 0 and W % patch_w == 0
        ), "H/W must be divisible by patch sizes"
        tile = (patch_f, patch_w)
        # A strided conv with kernel == stride is exactly a per-tile linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=tile, stride=tile, bias=bias)
        self.patch_f, self.patch_w = patch_f, patch_w
        self.num_patches_f, self.num_patches_w = H // patch_f, W // patch_w
        self.num_patches = self.num_patches_f * self.num_patches_w

    def forward(self, x):
        """Project (B*T, C, H, W) frames to (B*T, num_patches, embed_dim) tokens."""
        grid = self.proj(x)  # (B*T, D, H/pf, W/pw)
        return grid.flatten(2).transpose(1, 2)

class AttentionWithBias(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = attn_drop  # Dropout is handled by scaled_dot_product_attention
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_bias=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(dim=0)  # (B, num_heads, N, head_dim)

        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_bias,
                dropout_p=self.attn_drop if self.training else 0.0,
            )

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class AttentionWithRoPE(nn.Module):
    def __init__(self, dim, num_heads, rotary_emb_module : RotaryEmbedding, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.rotary_emb = rotary_emb_module
        self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.o_proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = attn_drop  # Dropout is handled by scaled_dot_product_attention

    def forward(self, x, attn_bias=None):
        B, N, D = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
        q = self.rotary_emb(q) # (B, H, N, D_h)
        k = self.rotary_emb(k) # (B, H, N, D_h)
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            output = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_bias,
                is_causal=False,
                dropout_p=self.attn_drop if self.training else 0.0,
            )

        output = rearrange(output, 'b h n d -> b n (h d)')
        return self.o_proj(output)

def coarsen_adj_matrix(A: torch.Tensor, ntokens: int) -> torch.Tensor:
    """
    Average-pool a square matrix of shape (..., N_full, N_full) down to (..., ntokens, ntokens).

    Rows and columns are split into ``ntokens`` contiguous groups of size
    ``g = N_full // ntokens``. Each (g x g) block is replaced by its mean, which
    preserves symmetry. Any leading (batch) dimensions are kept as-is.

    :param A: floating-point tensor whose last two dims are equal.
    :param ntokens: target side length; must divide N_full.
    :return: ``A`` itself when N_full == ntokens, otherwise a new pooled tensor.
    """
    N_full = A.shape[-1]
    if N_full == ntokens:
        return A

    assert A.dim() >= 2 and A.shape[-2] == N_full, (
        f"Expected a square (..., N, N) matrix, got shape {tuple(A.shape)}"
    )
    assert N_full % ntokens == 0, f"N_full={N_full} must be divisible by ntokens={ntokens}"
    g = N_full // ntokens
    blocks = A.reshape(*A.shape[:-2], ntokens, g, ntokens, g)
    return blocks.mean(dim=(-3, -1))

class AxisGraphBias(nn.Module):
    """
    Per-head additive attention bias of shape (num_heads, N, N) for one token axis.

    Construction:
      - ``adj`` (tensor or array-like, shape (N_full, N_full) or (N, N)) seeds the
        bias. If N_full != N_tokens it is coarsened by block averaging via
        ``coarsen_adj_matrix``. With ``adj=None`` the bias starts at zero, or at a
        truncated-normal draw with std ``init_std`` when ``learnable``.
      - The matrix is then symmetrized, clamped to [-1, 1], optionally has its
        diagonal zeroed (``zero_diag``), and is repeated across heads.
      - ``learnable=False`` stores it as a persistent buffer. ``learnable=True``
        stores it as an ``nn.Parameter``; symmetry and the zero diagonal are then
        only guaranteed at initialisation.

    ``scale`` is a per-head ``nn.Parameter`` initialised to ``scale_init``. It is
    always trainable, independent of ``learnable``. ``forward()`` returns
    ``scale[h] * bias[h]``.
    """
    def __init__(self, num_heads, N_tokens, adj=None,
                 learnable=False, scale_init=0.3, init_std=0.02, zero_diag=True):
        super().__init__()
        self.num_heads = num_heads
        self.N_tokens = N_tokens
        self.zero_diag = zero_diag

        if adj is None:
            bias = torch.zeros((N_tokens, N_tokens), dtype=torch.float32)
            if learnable:
                # Small random start so a learnable bias is not stuck at exactly zero.
                nn.init.trunc_normal_(bias, std=init_std)
        else:
            bias = torch.as_tensor(adj, dtype=torch.float32)
            bias = coarsen_adj_matrix(bias, N_tokens)

        bias = 0.5 * (bias + bias.T)  # symmetrize (also yields a fresh tensor)
        bias = bias.clamp(min = -1.0, max = 1.0)  # ensure no extreme values

        if zero_diag:
            bias.fill_diagonal_(0.0)

        # Expand to heads: every head starts from the same matrix.
        bias = bias.unsqueeze(0)                 # (1,N,N)
        if num_heads > 1:
            bias = bias.repeat(num_heads, 1, 1)  # (H,N,N); repeat copies, so heads can diverge

        if learnable:
            self.bias = nn.Parameter(bias)
        else:
            self.register_buffer("bias", bias, persistent=True)

        scale = torch.full((num_heads,), float(scale_init))
        self.scale = nn.Parameter(scale, requires_grad=True)  # always learnable

    def forward(self):
        """Return the scaled per-head bias, shape (num_heads, N_tokens, N_tokens)."""
        return self.scale[:, None, None] * self.bias  # (H,N,N)


class STBlock(nn.Module):
    """
    A Transformer block with axis-factorized attention over a (T, F, W) token grid.

    Tokens arrive as ((B*T), F*W, D) together with a conditioning vector ``c`` of
    shape (B, D). Each attention pass runs along one axis, with the other two axes
    folded into the batch:
      - temporal  (sequence axis T; RoPE via ``time_rotary_embed``),
      - frequency (sequence axis F; RoPE via ``freq_rotary_embed`` plus optional ``freq_bias``),
      - covariate (sequence axis W; optional ``cov_bias``).
    The order is 'tfc' (temporal, frequency, covariate) or 'tcf', followed by an MLP.
    Each attention pass is LayerNorm + shift/scale modulation from ``c``, and its
    output is added residually through a per-channel gate from ``c``. The MLP
    path is modulated the same way but is not gated.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_frames,
        F_tokens,
        W_tokens,
        time_rotary_embed: RotaryEmbedding,
        freq_rotary_embed: RotaryEmbedding,
        cov_bias: AxisGraphBias  = None,
        freq_bias: AxisGraphBias = None,
        mlp_ratio=4.0,
        dropout=0.0,
        attn_order:str = 'tfc'
    ):
        super().__init__()
        self.num_frames = num_frames
        self.F_tokens = F_tokens
        self.W_tokens = W_tokens
        assert attn_order in ['tfc', 'tcf'], "attn_order must be 'tfc' or 'tcf'"
        self.attn_order = attn_order

        # Positional / structural biases (may be shared with other blocks)
        self.time_rotary = time_rotary_embed
        self.freq_rotary = freq_rotary_embed
        self.freq_bias = freq_bias
        self.cov_bias = cov_bias

        # Attentions with per-head bias support
        self.temporal_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.temporal_attn = AttentionWithRoPE(
            dim=hidden_size,
            num_heads=num_heads,
            rotary_emb_module=time_rotary_embed,
            qkv_bias=True,
            attn_drop=dropout,
            proj_drop=dropout,
        )
        self.temporal_fc = nn.Linear(hidden_size, hidden_size)

        self.freq_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.freq_attn = AttentionWithRoPE(
            dim=hidden_size,
            num_heads=num_heads,
            rotary_emb_module=freq_rotary_embed,
            qkv_bias=True,
            attn_drop=dropout,
            proj_drop=dropout,
        )
        self.freq_fc = nn.Linear(hidden_size, hidden_size)

        self.cov_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.cov_attn = AttentionWithBias(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            attn_drop=dropout,
            proj_drop=dropout,
        )
        self.cov_fc = nn.Linear(hidden_size, hidden_size)

        # MLP
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.norm_mlp = nn.LayerNorm(hidden_size, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden, hidden_size),
        )

        # adaLN-style modulation: (shift, scale, gate) per attention path,
        # (shift, scale) for the MLP. NOTE(review): these projections are not
        # zero-initialised in this file, despite the usual "adaLN-Zero" naming.
        self.ada_temporal = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size)
        )
        self.ada_freq = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size)
        )
        self.ada_cov = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size))
        self.ada_mlp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size))

    def _mod(self, x, shift, scale, T):
        # x: ((B*T), N, D); shift/scale: (B, D), broadcast over all frames and tokens.
        B = x.shape[0] // T
        x = rearrange(x, "(b t) n d -> b (t n) d", b=B, t=T)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = rearrange(x, "b (t n) d -> (b t) n d", b=B, t=T)
        return x

    def _gate_bt(self, gate):
        # Broadcast gate (B, D) -> (B*T, 1, D), matching the (b t) batch-major folding.
        T = self.num_frames
        return gate.repeat_interleave(T, dim=0).unsqueeze(1)

    def _axis_pass(self, x, c, ada, norm, attn, fc, seq_layout, bias_module=None):
        """
        Run one modulated, gated, residual attention pass along a single axis.

        ``seq_layout`` is the einops group "<(batch axes)> <sequence axis>", e.g.
        "(b f w) t" for temporal attention. Tokens are regrouped from
        "(b t) (f w) d" to that layout, attended (with ``bias_module()`` as the
        attention bias when given), and regrouped back.
        """
        T = self.num_frames
        sizes = dict(b=c.shape[0], t=T, f=self.F_tokens, w=self.W_tokens)
        shift, scale, gate = ada(c).chunk(3, dim=1)
        h = self._mod(norm(x), shift, scale, T)
        h = rearrange(h, f"(b t) (f w) d -> {seq_layout} d", **sizes)
        attn_bias = bias_module() if bias_module is not None else None
        h = attn(h, attn_bias=attn_bias)
        h = rearrange(h, f"{seq_layout} d -> (b t) (f w) d", **sizes)
        return x + self._gate_bt(gate) * fc(h)

    def _temp_pass(self, x, c):
        # ---------- Temporal attention: sequences over T, batched over (b, f, w) ----------
        return self._axis_pass(
            x, c, self.ada_temporal, self.temporal_norm, self.temporal_attn,
            self.temporal_fc, "(b f w) t",
        )

    def _freq_pass(self, x, c):
        # ---------- Frequency attention: sequences over F, batched over (b, t, w) ----------
        return self._axis_pass(
            x, c, self.ada_freq, self.freq_norm, self.freq_attn,
            self.freq_fc, "(b t w) f", self.freq_bias,
        )

    def _cov_pass(self, x, c):
        # ---------- Covariate attention: sequences over W, batched over (b, t, f) ----------
        return self._axis_pass(
            x, c, self.ada_cov, self.cov_norm, self.cov_attn,
            self.cov_fc, "(b t f) w", self.cov_bias,
        )

    def forward(self, x, c):
        """
        :param x: ((B*T), F*W, D) tokens.
        :param c: (B, D) conditioning vector.
        :return: tensor with the same shape as ``x``.
        """
        T = self.num_frames
        if self.attn_order == 'tfc':
            x = self._temp_pass(x, c)
            x = self._freq_pass(x, c)
            x = self._cov_pass(x, c)
        elif self.attn_order == 'tcf':
            x = self._temp_pass(x, c)
            x = self._cov_pass(x, c)
            x = self._freq_pass(x, c)

        # ---------- MLP (modulated, not gated) ----------
        shift, scale = self.ada_mlp(c).chunk(2, dim=1)
        xm = self._mod(self.norm_mlp(x), shift, scale, T)
        x = x + self.mlp(xm)

        return x


class STDiff(nn.Module):
    """
    Diffusion backbone over (time, frequency, covariate) grids.

    Input ``x`` is (B, T, C, H, W), with H = frequency bins and W = covariates.
    Each frame is patched over (H, W) by ``PatchEmbedTFW`` (no patching over T),
    and a learned per-covariate embedding is added. The tokens then pass through
    ``depth`` axis-factorized ``STBlock``s conditioned on the diffusion timestep.
    The head (zero-initialised) projects each token back to a
    (patch_f x patch_w) patch with ``out_channels`` channels, which is
    ``2 * in_channels`` when ``learn_sigma``.

    NOTE(review): H and W must be divisible by patch_size; the default
    input_size (2, 2) with patch_size (4, 1) fails that check.
    """

    def __init__(
        self,
        input_size: tuple = (2, 2),  # freq bins, covariates
        patch_size: tuple = (4, 1),  # patch sizes (f, w)
        in_channels=3,  # trend, Re, Im
        hidden_size=1152,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        dropout=0.0,
        num_frames=16,
        learn_sigma=True,
        freq_centers=None,  # tensor of physical bin centers (Hz or normalized)
        freq_adj=None,  # optional (F, F) adjacency
        cov_adj=None,  # optional (Wc, Wc) adjacency
        use_freq_bias=False,
        use_cov_bias=True,
        alternate_attn_order=False,
        use_checkpoint_every : int = 0,  # if >0, use checkpointing every N blocks
        **kwargs,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.num_frames = num_frames
        self.use_checkpoint_every = use_checkpoint_every

        H, W = input_size
        patch_f, patch_w = patch_size

        # --- Patching over (H, W) ---
        self.x_embedder = PatchEmbedTFW(
            H,
            W,
            patch_f=patch_f,
            patch_w=patch_w,
            in_chans=in_channels,
            embed_dim=hidden_size,
        )
        F_tokens = self.x_embedder.num_patches_f
        W_tokens = self.x_embedder.num_patches_w

        # Fractional RoPE positions for the frequency axis. With coords=None the
        # rotary embedding uses integer positions 0..F_tokens-1. (Previously
        # freq_coords was left unbound here, raising NameError, and the pooling
        # factor overwrote patch_f.)
        freq_coords = None
        if freq_centers is not None:
            freq_coords = self._freq_centers_to_coords(freq_centers, F_tokens)

        rotary_head_dim = hidden_size // num_heads
        assert rotary_head_dim % 2 == 0, "Rotary head dim must be even"

        self.time_rotary_embed = RotaryEmbedding(dim=rotary_head_dim, seq_len=num_frames)
        self.freq_rotary_embed = RotaryEmbedding(dim=rotary_head_dim, seq_len=F_tokens, coords=freq_coords)
        self.cov_embed = nn.Embedding(W_tokens, hidden_size)  # learned covariate embeddings

        use_cov_graph = cov_adj is not None and use_cov_bias

        def make_cov_bias():
            # Fresh learnable covariate bias; each STBlock gets its own instance.
            return AxisGraphBias(num_heads, W_tokens, adj=cov_adj, scale_init=1.0, learnable=True, zero_diag=True)

        # A single frequency bias instance is shared by all blocks.
        self.freq_bias = AxisGraphBias(num_heads, F_tokens, adj=freq_adj, scale_init=0.3, learnable=False, zero_diag=True) if freq_adj is not None and use_freq_bias else None
        # NOTE(review): self.cov_bias is not used in forward (blocks build their own
        # below). It is kept so existing state_dicts still load, but its parameters
        # never receive gradients (DDP may then need find_unused_parameters=True).
        self.cov_bias = make_cov_bias() if use_cov_graph else None

        self.t_embedder = TimestepEmbedder(hidden_size)

        self.blocks = nn.ModuleList(
            [
                STBlock(
                    hidden_size,
                    num_heads,
                    num_frames=num_frames,
                    F_tokens=F_tokens,
                    W_tokens=W_tokens,
                    time_rotary_embed=self.time_rotary_embed,
                    freq_rotary_embed=self.freq_rotary_embed,
                    freq_bias=self.freq_bias,
                    cov_bias=make_cov_bias() if use_cov_graph else None,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    # Optionally alternate tfc / tcf between consecutive blocks.
                    attn_order=('tfc' if (i % 2 == 0) else 'tcf') if alternate_attn_order else 'tfc',
                )
                for i in range(depth)
            ]
        )

        # --- Head ---
        self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6)
        self.linear = nn.Linear(hidden_size, self.out_channels * patch_f * patch_w)

        # Init: zero head, so the model initially outputs zeros.
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

        self.patch_f, self.patch_w = patch_f, patch_w
        self.F_tokens, self.W_tokens = F_tokens, W_tokens

    @staticmethod
    def _freq_centers_to_coords(freq_centers, F_tokens):
        """
        Map physical frequency-bin centers to fractional RoPE positions, one per frequency token.

        The centers are log2-transformed, then mean-pooled in contiguous groups of
        ``numel // F_tokens``; a group size of 1 means the centers are already at
        token resolution. The result is min-max normalised to [0, F_tokens - 1],
        so its range matches that of integer token indices.

        :raises ValueError: if the number of centers is not a multiple of F_tokens.
        """
        centers = torch.as_tensor(freq_centers).flatten()
        if centers.numel() % F_tokens != 0:
            raise ValueError(
                f"freq_centers has {centers.numel()} entries, which is not a multiple of F_tokens={F_tokens}"
            )
        group = centers.numel() // F_tokens
        log_pooled = torch.log2(centers + 1e-8).reshape(F_tokens, group).mean(dim=1)
        lo = log_pooled.min()
        span = log_pooled.max() - lo + 1e-8
        return (log_pooled - lo) / span * (F_tokens - 1)

    def unpatchify(self, x):
        """
        x: (B*T, N=F*W, out_ch * pf * pw) -> (B*T, out_ch, H, W)
        """
        BT = x.shape[0]
        x = x.view(BT, self.F_tokens, self.W_tokens, self.patch_f, self.patch_w, self.out_channels)
        return rearrange(x, "bt f w pf pw c -> bt c (f pf) (w pw)")

    def _apply_cov_embedding(self, tokens):
        """Add the learned per-covariate embedding to every (B*T, F*W, D) token."""
        tokens = rearrange(tokens, "bt (f w) d -> bt f w d", f=self.F_tokens, w=self.W_tokens)
        cov_embeddings = self.cov_embed(torch.arange(self.W_tokens, device=tokens.device))  # Shape: (W, D)
        tokens = tokens + cov_embeddings.view(1, 1, self.W_tokens, -1)
        tokens = rearrange(tokens, "bt f w d -> bt (f w) d")
        return tokens

    def forward(self, x, t, tc=None):
        """
        x: (B, T, C, H, W)
        t: (B,) diffusion timesteps
        tc: accepted for interface compatibility; currently unused.
        Returns (B, T, out_channels, H, W).
        """
        B, T, C, H, W = x.shape
        assert (
            T == self.num_frames
        ), "Fixed num_frames for relative bias; update if variable."

        # Fold time into batch for patching (reshape also accepts non-contiguous x).
        tokens = self.x_embedder(x.reshape(B * T, C, H, W))  # (B*T, N, D)
        tokens = self._apply_cov_embedding(tokens)  # add covariate embeddings

        c = self.t_embedder(t)  # (B, D)

        every = self.use_checkpoint_every
        for i, blk in enumerate(self.blocks):
            if every > 0 and self.training and i % every == 0:
                tokens = checkpoint(blk, tokens, c, use_reentrant=False)
            else:
                tokens = blk(tokens, c)

        tokens = self.norm_final(tokens)
        out = self.linear(tokens)  # (B*T, N, out_ch * pf * pw)
        out = self.unpatchify(out)  # (B*T, out_ch, H, W)
        return out.view(B, T, self.out_channels, H, W)